phanerozoic committed on
Commit
ebc65f6
·
verified ·
1 Parent(s): bd8cd8d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -49
app.py CHANGED
@@ -1,9 +1,10 @@
1
  """
2
- SchoolSpirit AI – public chatbot Space
3
- --------------------------------------
4
- • Tiny Llama‑3 3 B model (fits HF CPU Space).
5
- • Light‑blue Gradio chat widget.
6
- • Robust error handling.
 
7
  """
8
 
9
  import re
@@ -15,96 +16,111 @@ from transformers import (
15
  )
16
  from transformers.utils import logging as hf_logging
17
 
 
18
  hf_logging.set_verbosity_error()
19
- LOGGER = hf_logging.get_logger("SchoolSpirit")
20
 
21
- # ------------------------ Config -------------------------------------------
22
  MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"
23
- MAX_TURNS = 6 # last N exchanges kept
24
- MAX_TOKENS = 220 # response length
25
- MAX_INPUT_CH = 500 # user message length guard
26
 
27
  SYSTEM_MSG = (
28
  "You are SchoolSpiritΒ AI, the friendly digital mascot for a company that "
29
  "offers on‑prem AI chat mascots, fine‑tuning services, and turnkey GPU "
30
  "hardware for schools. Keep answers concise, upbeat, and age‑appropriate. "
31
- "If unsure, admit it and suggest contacting a human. Never request "
32
- "personal data."
33
  )
34
 
35
- # ------------------------ Model Load ---------------------------------------
36
  try:
37
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
38
- model = AutoModelForCausalLM.from_pretrained(
39
- MODEL_ID,
40
- device_map="auto",
41
- torch_dtype="auto",
42
  )
43
  generator = pipeline(
44
  "text-generation",
45
  model=model,
46
- tokenizer=tokenizer,
47
  max_new_tokens=MAX_TOKENS,
48
  do_sample=True,
49
  temperature=0.7,
50
  )
51
  MODEL_LOAD_ERR = None
52
  except Exception as exc: # noqa: BLE001
53
- generator = None
54
  MODEL_LOAD_ERR = f"Model load error: {exc}"
55
- LOGGER.error(MODEL_LOAD_ERR)
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
- # ------------------------ Chat Function ------------------------------------
58
  def chat(history, user_msg):
59
- """Gradio ChatInterface callback using (history, user_msg) tuples."""
 
60
 
61
- # Hard failure at Space startup
62
  if MODEL_LOAD_ERR:
63
  history.append((user_msg, MODEL_LOAD_ERR))
64
  return history, ""
65
 
66
- # Basic user‑input guardrails
67
  user_msg = (user_msg or "").strip()
68
  if not user_msg:
69
- return history + [("", "Please enter a message.")], ""
 
70
  if len(user_msg) > MAX_INPUT_CH:
71
  history.append(
72
- (user_msg, "Sorry, your message is too long. Please shorten it.")
73
  )
74
  return history, ""
75
 
76
- # Keep only last MAX_TURNS
77
- if len(history) > MAX_TURNS:
78
- history = history[-MAX_TURNS:]
79
 
80
  # Build prompt
81
- prompt = [SYSTEM_MSG]
82
  for u, a in history:
83
- prompt.append(f"User: {u}")
84
- prompt.append(f"AI: {a}")
85
- prompt.append(f"User: {user_msg}")
86
- prompt.append("AI:")
87
-
88
- prompt = "\n".join(prompt)
89
 
90
- # Generate reply
91
  try:
92
  completion = generator(prompt, truncate=2048)[0]["generated_text"]
93
- reply = completion.split("AI:", 1)[-1].strip()
94
- reply = re.sub(r"\s+", " ", reply) # collapse excess whitespace
95
  except Exception as err: # noqa: BLE001
96
- LOGGER.error(f"Inference error: {err}")
97
  reply = (
98
- "Sorry, something went wrong on my end. "
99
- "Please try again in a few seconds."
100
  )
101
 
102
  history.append((user_msg, reply))
103
  return history, ""
104
 
105
- # ------------------------ UI -----------------------------------------------
106
- gr.ChatInterface(
107
- chat,
108
- title="SchoolSpiritΒ AI Chat",
109
- theme=gr.themes.Soft(primary_hue="blue"), # light‑blue look
110
- ).launch()
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
+ SchoolSpirit AI – hardened chatbot Space with “Clear chat” button
3
+ ----------------------------------------------------------------
4
+ • Loads Meta Llama‑3 3 B‑Instruct via transformers.
5
+ • Keeps last MAX_TURNS exchanges to fit context.
6
+ • Adds a “Clear chat” button that resets history.
7
+ • Extensive try/except blocks to prevent widget crashes.
8
  """
9
 
10
  import re
 
16
  )
17
  from transformers.utils import logging as hf_logging
18
 
19
+ # ────────────────────────── Config ──────────────────────────────────────────
20
  hf_logging.set_verbosity_error()
21
+ LOG = hf_logging.get_logger("SchoolSpirit")
22
 
 
23
  MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"
24
+ MAX_TURNS = 6
25
+ MAX_TOKENS = 220
26
+ MAX_INPUT_CH = 500
27
 
28
  SYSTEM_MSG = (
29
  "You are SchoolSpiritΒ AI, the friendly digital mascot for a company that "
30
  "offers on‑prem AI chat mascots, fine‑tuning services, and turnkey GPU "
31
  "hardware for schools. Keep answers concise, upbeat, and age‑appropriate. "
32
+ "If unsure, say so and suggest contacting a human. Never request personal data."
 
33
  )
34
 
35
+ # ────────────────────────── Model Load ──────────────────────────────────────
36
  try:
37
+ tok = AutoTokenizer.from_pretrained(MODEL_ID)
38
+ model = AutoModelForCausalLM.from_pretrained(
39
+ MODEL_ID, device_map="auto", torch_dtype="auto"
 
 
40
  )
41
  generator = pipeline(
42
  "text-generation",
43
  model=model,
44
+ tokenizer=tok,
45
  max_new_tokens=MAX_TOKENS,
46
  do_sample=True,
47
  temperature=0.7,
48
  )
49
  MODEL_LOAD_ERR = None
50
  except Exception as exc: # noqa: BLE001
 
51
  MODEL_LOAD_ERR = f"Model load error: {exc}"
52
+ generator = None
53
+ LOG.error(MODEL_LOAD_ERR)
54
+
55
+ # ────────────────────────── Helpers ────────────────────────────────────────
56
+ def truncate_history(hist, max_turns):
57
+ """Return only the last `max_turns` (user,bot) pairs."""
58
+ return hist[-max_turns:] if len(hist) > max_turns else hist
59
+
60
+
61
+ def safe_reply(msg: str) -> str:
62
+ """Post‑process model output and ensure it is non‑empty."""
63
+ msg = msg.strip()
64
+ msg = re.sub(r"\s+", " ", msg)
65
+ return msg or "…"
66
 
67
+ # ────────────────────────── Chat Callback ───────────────────────────────────
68
  def chat(history, user_msg):
69
+ # Gradio guarantees history is a list of tuples (user, bot)
70
+ history = list(history)
71
 
72
+ # Start‑up failure fallback
73
  if MODEL_LOAD_ERR:
74
  history.append((user_msg, MODEL_LOAD_ERR))
75
  return history, ""
76
 
77
+ # Basic guards
78
  user_msg = (user_msg or "").strip()
79
  if not user_msg:
80
+ history.append(("", "Please enter a message."))
81
+ return history, ""
82
  if len(user_msg) > MAX_INPUT_CH:
83
  history.append(
84
+ (user_msg, "Sorry, that message is too long. Please shorten it.")
85
  )
86
  return history, ""
87
 
88
+ history = truncate_history(history, MAX_TURNS)
 
 
89
 
90
  # Build prompt
91
+ prompt_parts = [SYSTEM_MSG]
92
  for u, a in history:
93
+ prompt_parts += [f"User: {u}", f"AI: {a}"]
94
+ prompt_parts += [f"User: {user_msg}", "AI:"]
95
+ prompt = "\n".join(prompt_parts)
 
 
 
96
 
97
+ # Generate
98
  try:
99
  completion = generator(prompt, truncate=2048)[0]["generated_text"]
100
+ reply = safe_reply(completion.split("AI:", 1)[-1])
 
101
  except Exception as err: # noqa: BLE001
102
+ LOG.error(f"Inference error: {err}")
103
  reply = (
104
+ "Sorryβ€”I'm having trouble right now. Please try again in a moment."
 
105
  )
106
 
107
  history.append((user_msg, reply))
108
  return history, ""
109
 
110
+ # ────────────────────────── Clear Chat fn ───────────────────────────────────
111
+ def clear_chat():
112
+ return [], ""
113
+
114
+ # ────────────────────────── UI Launch ───────────────────────────────────────
115
+ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
116
+ gr.Markdown("# SchoolSpiritΒ AI Chat")
117
+ chatbot = gr.Chatbot()
118
+ msg_in = gr.Textbox(placeholder="Ask me anything about SchoolSpiritΒ AI…")
119
+ send_btn = gr.Button("Send")
120
+ clear_btn = gr.Button("Clear Chat", variant="secondary")
121
+
122
+ send_btn.click(chat, [chatbot, msg_in], [chatbot, msg_in])
123
+ msg_in.submit(chat, [chatbot, msg_in], [chatbot, msg_in])
124
+ clear_btn.click(clear_chat, outputs=[chatbot, msg_in])
125
+
126
+ demo.launch()