phanerozoic committed on
Commit
bd8cd8d
·
verified ·
1 Parent(s): 72242ed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -24
app.py CHANGED
@@ -1,62 +1,110 @@
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 
 
 
 
3
  from transformers.utils import logging as hf_logging
4
 
5
  hf_logging.set_verbosity_error()
 
 
 
 
 
 
 
6
 
7
- MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"
8
- MAX_TURNS = 6
9
- MAX_TOKENS = 220
10
  SYSTEM_MSG = (
11
  "You are SchoolSpirit AI, the friendly digital mascot for a company that "
12
- "provides on‑prem AI chat mascots, fine‑tuning services, and turnkey GPU "
13
- "hardware for schools. Keep answers concise, upbeat, and age‑appropriate. "
14
- "If you don’t know, say so and suggest contacting a human. Never request "
15
  "personal data."
16
  )
17
 
 
18
  try:
19
- tok = AutoTokenizer.from_pretrained(MODEL_ID)
20
- model = AutoModelForCausalLM.from_pretrained(
21
- MODEL_ID, device_map="auto", torch_dtype="auto"
 
 
22
  )
23
- gen = pipeline(
24
  "text-generation",
25
  model=model,
26
- tokenizer=tok,
27
  max_new_tokens=MAX_TOKENS,
28
  do_sample=True,
29
  temperature=0.7,
30
  )
31
- model_error = None
32
  except Exception as exc: # noqa: BLE001
33
- model_error = f"Model load error: {exc}"
34
- gen = None
 
35
 
 
36
  def chat(history, user_msg):
37
- if model_error:
38
- return history + [(user_msg, model_error)], ""
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  if len(history) > MAX_TURNS:
41
  history = history[-MAX_TURNS:]
42
 
43
- prompt = SYSTEM_MSG + "\n"
 
44
  for u, a in history:
45
- prompt += f"User: {u}\nAI: {a}\n"
46
- prompt += f"User: {user_msg}\nAI:"
 
 
 
 
47
 
 
48
  try:
49
- completion = gen(prompt)[0]["generated_text"]
50
  reply = completion.split("AI:", 1)[-1].strip()
 
51
  except Exception as err: # noqa: BLE001
52
- reply = "Sorry, an internal error occurred. Please try again later."
53
- hf_logging.get_logger("SchoolSpirit").error(str(err))
 
 
 
54
 
55
  history.append((user_msg, reply))
56
  return history, ""
57
 
 
58
  gr.ChatInterface(
59
  chat,
60
  title="SchoolSpirit AI Chat",
61
- theme=gr.themes.Soft(primary_hue="blue"),
62
  ).launch()
 
1
+ """
2
+ SchoolSpirit AI – public chatbot Space
3
+ --------------------------------------
4
+ • Tiny Llama‑3 3 B model (fits HF CPU Space).
5
+ • Light‑blue Gradio chat widget.
6
+ • Robust error handling.
7
+ """
8
+
9
+ import re
10
  import gradio as gr
11
+ from transformers import (
12
+ AutoTokenizer,
13
+ AutoModelForCausalLM,
14
+ pipeline,
15
+ )
16
  from transformers.utils import logging as hf_logging
17
 
18
  hf_logging.set_verbosity_error()
19
+ LOGGER = hf_logging.get_logger("SchoolSpirit")
20
+
21
# ------------------------ Config -------------------------------------------
MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"  # instruct-tuned chat model
MAX_TURNS = 6  # how many past (user, AI) exchanges are retained
MAX_TOKENS = 220  # cap on generated response length
MAX_INPUT_CH = 500  # reject user messages longer than this

# Persona / behavior contract prepended to every generation prompt.
SYSTEM_MSG = (
    "You are SchoolSpirit AI, the friendly digital mascot for a company that "
    "offers on‑prem AI chat mascots, fine‑tuning services, and turnkey GPU "
    "hardware for schools. Keep answers concise, upbeat, and age‑appropriate. "
    "If unsure, admit it and suggest contacting a human. Never request "
    "personal data."
)
34
 
35
# ------------------------ Model Load ---------------------------------------
# Load tokenizer, model, and generation pipeline once at Space startup.
# Any failure is captured in MODEL_LOAD_ERR so chat() can report it to the
# user instead of the whole app crashing on import.
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, device_map="auto", torch_dtype="auto"
    )
    generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=MAX_TOKENS,
        do_sample=True,
        temperature=0.7,
    )
    MODEL_LOAD_ERR = None  # sentinel: model is ready
except Exception as exc:  # noqa: BLE001
    MODEL_LOAD_ERR = f"Model load error: {exc}"
    generator = None
    LOGGER.error(MODEL_LOAD_ERR)
56
 
57
# ------------------------ Chat Function ------------------------------------
def chat(history, user_msg):
    """Gradio chat callback using (history, user_msg) tuple state.

    Parameters
    ----------
    history : list[tuple[str, str]]
        Prior (user, assistant) exchanges displayed in the chat widget.
    user_msg : str | None
        The newly submitted user message.

    Returns
    -------
    tuple[list[tuple[str, str]], str]
        Updated history, plus an empty string that clears the input box.
    """
    # Hard failure at Space startup: surface the load error as the reply.
    if MODEL_LOAD_ERR:
        history.append((user_msg, MODEL_LOAD_ERR))
        return history, ""

    # Basic user-input guardrails.
    user_msg = (user_msg or "").strip()
    if not user_msg:
        return history + [("", "Please enter a message.")], ""
    if len(user_msg) > MAX_INPUT_CH:
        history.append(
            (user_msg, "Sorry, your message is too long. Please shorten it.")
        )
        return history, ""

    # Keep only the last MAX_TURNS exchanges so the prompt stays bounded.
    if len(history) > MAX_TURNS:
        history = history[-MAX_TURNS:]

    # Build a flat "User:/AI:" transcript prompt, system message first.
    prompt = [SYSTEM_MSG]
    for u, a in history:
        prompt.append(f"User: {u}")
        prompt.append(f"AI: {a}")
    prompt.append(f"User: {user_msg}")
    prompt.append("AI:")
    prompt = "\n".join(prompt)

    # Generate reply.
    # BUG FIX: the original passed `truncate=2048`, which is not a pipeline
    # parameter — unknown kwargs are forwarded to generate() and raise a
    # TypeError on every call. `truncation=True` is the supported preprocess
    # argument that clips over-long prompts to the model's max input length.
    try:
        completion = generator(prompt, truncation=True)[0]["generated_text"]
        # The pipeline echoes the prompt; keep only text after the last "AI:".
        reply = completion.split("AI:", 1)[-1].strip()
        reply = re.sub(r"\s+", " ", reply)  # collapse excess whitespace
    except Exception as err:  # noqa: BLE001
        LOGGER.error(f"Inference error: {err}")
        reply = (
            "Sorry, something went wrong on my end. "
            "Please try again in a few seconds."
        )

    history.append((user_msg, reply))
    return history, ""
104
 
105
# ------------------------ UI -----------------------------------------------
# NOTE(review): gr.ChatInterface conventionally calls fn(message, history)
# and expects a string reply, whereas chat() here takes (history, user_msg)
# and returns (history, "") — the gr.Chatbot/Blocks wiring pattern. Verify
# against the installed Gradio version that this callback shape is accepted.
gr.ChatInterface(
    chat,
    title="SchoolSpirit AI Chat",
    theme=gr.themes.Soft(primary_hue="blue"),  # light‑blue look
).launch()