phanerozoic commited on
Commit
747a64c
·
verified Β·
1 Parent(s): ebc65f6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -51
app.py CHANGED
@@ -1,10 +1,11 @@
1
  """
2
- SchoolSpirit AI – hardened chatbot Space with “Clear chat” button
3
- ----------------------------------------------------------------
4
- • Loads Meta Llama‑3 3B‑Instruct via transformers.
5
- • Keeps last MAX_TURNS exchanges to fit context.
6
- • Adds a “Clear chat” button that resets history.
7
- • Extensive try/except blocks to prevent widget crashes.
 
8
  """
9
 
10
  import re
@@ -20,23 +21,25 @@ from transformers.utils import logging as hf_logging
20
  hf_logging.set_verbosity_error()
21
  LOG = hf_logging.get_logger("SchoolSpirit")
22
 
23
- MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"
24
  MAX_TURNS = 6
25
- MAX_TOKENS = 220
26
- MAX_INPUT_CH = 500
27
 
28
  SYSTEM_MSG = (
29
- "You are SchoolSpiritΒ AI, the friendly digital mascot for a company that "
30
  "offers on‑prem AI chat mascots, fine‑tuning services, and turnkey GPU "
31
- "hardware for schools. Keep answers concise, upbeat, and age‑appropriate. "
32
- "If unsure, say so and suggest contacting a human. Never request personal data."
33
  )
34
 
35
  # ────────────────────────── Model Load ──────────────────────────────────────
36
  try:
37
  tok = AutoTokenizer.from_pretrained(MODEL_ID)
38
  model = AutoModelForCausalLM.from_pretrained(
39
- MODEL_ID, device_map="auto", torch_dtype="auto"
 
 
40
  )
41
  generator = pipeline(
42
  "text-generation",
@@ -46,81 +49,72 @@ try:
46
  do_sample=True,
47
  temperature=0.7,
48
  )
49
- MODEL_LOAD_ERR = None
50
  except Exception as exc: # noqa: BLE001
51
- MODEL_LOAD_ERR = f"Model load error: {exc}"
52
  generator = None
53
- LOG.error(MODEL_LOAD_ERR)
54
 
55
  # ────────────────────────── Helpers ────────────────────────────────────────
56
- def truncate_history(hist, max_turns):
57
- """Return only the last `max_turns` (user,bot) pairs."""
58
- return hist[-max_turns:] if len(hist) > max_turns else hist
59
 
60
 
61
- def safe_reply(msg: str) -> str:
62
- """Post‑process model output and ensure it is non‑empty."""
63
- msg = msg.strip()
64
- msg = re.sub(r"\s+", " ", msg)
65
- return msg or "…"
66
 
67
  # ────────────────────────── Chat Callback ───────────────────────────────────
68
  def chat(history, user_msg):
69
- # Gradio guarantees history is a list of tuples (user, bot)
70
- history = list(history)
71
 
72
- # Start‑up failure fallback
73
- if MODEL_LOAD_ERR:
74
- history.append((user_msg, MODEL_LOAD_ERR))
75
  return history, ""
76
 
77
- # Basic guards
78
- user_msg = (user_msg or "").strip()
79
  if not user_msg:
80
  history.append(("", "Please enter a message."))
81
  return history, ""
82
  if len(user_msg) > MAX_INPUT_CH:
83
- history.append(
84
- (user_msg, "Sorry, that message is too long. Please shorten it.")
85
- )
86
  return history, ""
87
 
88
- history = truncate_history(history, MAX_TURNS)
89
 
90
  # Build prompt
91
- prompt_parts = [SYSTEM_MSG]
92
  for u, a in history:
93
- prompt_parts += [f"User: {u}", f"AI: {a}"]
94
- prompt_parts += [f"User: {user_msg}", "AI:"]
95
- prompt = "\n".join(prompt_parts)
96
 
97
- # Generate
98
  try:
99
- completion = generator(prompt, truncate=2048)[0]["generated_text"]
100
- reply = safe_reply(completion.split("AI:", 1)[-1])
101
  except Exception as err: # noqa: BLE001
102
  LOG.error(f"Inference error: {err}")
103
- reply = (
104
- "Sorryβ€”I'm having trouble right now. Please try again in a moment."
105
- )
106
 
107
  history.append((user_msg, reply))
108
  return history, ""
109
 
110
- # ────────────────────────── Clear Chat fn ───────────────────────────────────
111
  def clear_chat():
112
  return [], ""
113
 
114
  # ────────────────────────── UI Launch ───────────────────────────────────────
115
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
116
  gr.Markdown("# SchoolSpiritΒ AI Chat")
117
- chatbot = gr.Chatbot()
118
- msg_in = gr.Textbox(placeholder="Ask me anything about SchoolSpiritΒ AI…")
119
  send_btn = gr.Button("Send")
120
  clear_btn = gr.Button("Clear Chat", variant="secondary")
121
 
122
- send_btn.click(chat, [chatbot, msg_in], [chatbot, msg_in])
123
- msg_in.submit(chat, [chatbot, msg_in], [chatbot, msg_in])
124
- clear_btn.click(clear_chat, outputs=[chatbot, msg_in])
125
 
126
  demo.launch()
 
1
  """
2
+ SchoolSpirit AI – Granite‑3.3‑2B chatbot Space
3
+ ----------------------------------------------
4
+ • Uses IBM Granite‑3.3‑2B‑Instruct (public, no access token).
5
+ • Fits HF CPU Space (2‑B params, bfloat16).
6
+ • Keeps last MAX_TURNS exchanges.
7
+ • “Clear chat” button resets context.
8
+ • Robust error handling & logging.
9
  """
10
 
11
  import re
 
21
  hf_logging.set_verbosity_error()
22
  LOG = hf_logging.get_logger("SchoolSpirit")
23
 
24
+ MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"
25
  MAX_TURNS = 6
26
+ MAX_TOKENS = 200
27
+ MAX_INPUT_CH = 400
28
 
29
  SYSTEM_MSG = (
30
+ "You are SchoolSpiritΒ AI, the upbeat digital mascot for a company that "
31
  "offers on‑prem AI chat mascots, fine‑tuning services, and turnkey GPU "
32
+ "hardware for schools. Answer concisely and age‑appropriately. If unsure, "
33
+ "say so and suggest contacting a human. Do not ask for personal data."
34
  )
35
 
36
  # ────────────────────────── Model Load ──────────────────────────────────────
37
  try:
38
  tok = AutoTokenizer.from_pretrained(MODEL_ID)
39
  model = AutoModelForCausalLM.from_pretrained(
40
+ MODEL_ID,
41
+ device_map="auto",
42
+ torch_dtype="auto", # bfloat16/float16 under the hood
43
  )
44
  generator = pipeline(
45
  "text-generation",
 
49
  do_sample=True,
50
  temperature=0.7,
51
  )
52
+ MODEL_ERR = None
53
  except Exception as exc: # noqa: BLE001
54
+ MODEL_ERR = f"Model load error: {exc}"
55
  generator = None
56
+ LOG.error(MODEL_ERR)
57
 
58
  # ────────────────────────── Helpers ────────────────────────────────────────
59
+ def truncate(hist):
60
+ """Return last MAX_TURNS (u,a) pairs."""
61
+ return hist[-MAX_TURNS:] if len(hist) > MAX_TURNS else hist
62
 
63
 
64
+ def clean(text: str) -> str:
65
+ """Collapse whitespace; never return empty string."""
66
+ out = re.sub(r"\s+", " ", text.strip())
67
+ return out or "…"
 
68
 
69
  # ────────────────────────── Chat Callback ───────────────────────────────────
70
  def chat(history, user_msg):
71
+ history = list(history) # Gradio ensures list of tuples
 
72
 
73
+ if MODEL_ERR:
74
+ history.append((user_msg, MODEL_ERR))
 
75
  return history, ""
76
 
77
+ user_msg = clean(user_msg or "")
 
78
  if not user_msg:
79
  history.append(("", "Please enter a message."))
80
  return history, ""
81
  if len(user_msg) > MAX_INPUT_CH:
82
+ history.append((user_msg, "That message is too long."))
 
 
83
  return history, ""
84
 
85
+ history = truncate(history)
86
 
87
  # Build prompt
88
+ prompt_lines = [SYSTEM_MSG]
89
  for u, a in history:
90
+ prompt_lines += [f"User: {u}", f"AI: {a}"]
91
+ prompt_lines += [f"User: {user_msg}", "AI:"]
92
+ prompt = "\n".join(prompt_lines)
93
 
 
94
  try:
95
+ completion = generator(prompt, truncate=4096)[0]["generated_text"]
96
+ reply = clean(completion.split("AI:", 1)[-1])
97
  except Exception as err: # noqa: BLE001
98
  LOG.error(f"Inference error: {err}")
99
+ reply = "Sorryβ€”I'm having trouble right now. Please try again shortly."
 
 
100
 
101
  history.append((user_msg, reply))
102
  return history, ""
103
 
104
+ # ────────────────────────── Clear Chat ──────────────────────────────────────
105
  def clear_chat():
106
  return [], ""
107
 
108
  # ────────────────────────── UI Launch ───────────────────────────────────────
109
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
110
  gr.Markdown("# SchoolSpiritΒ AI Chat")
111
+ chatbot = gr.Chatbot()
112
+ msg_box = gr.Textbox(placeholder="Ask me anything about SchoolSpiritΒ AI…")
113
  send_btn = gr.Button("Send")
114
  clear_btn = gr.Button("Clear Chat", variant="secondary")
115
 
116
+ send_btn.click(chat, [chatbot, msg_box], [chatbot, msg_box])
117
+ msg_box.submit(chat, [chatbot, msg_box], [chatbot, msg_box])
118
+ clear_btn.click(clear_chat, outputs=[chatbot, msg_box])
119
 
120
  demo.launch()