phanerozoic commited on
Commit
6f67928
Β·
verified Β·
1 Parent(s): e34a054

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -24
app.py CHANGED
@@ -1,10 +1,10 @@
1
  """
2
  SchoolSpiritΒ AI – Granite‑3.3‑2B chatbot Space
3
  ----------------------------------------------
4
- β€’ Uses IBM Granite‑3.3‑2B‑Instruct (Apache‑2).
5
- β€’ Keeps last MAX_TURNS exchanges.
6
- β€’ β€œClear chat” button resets context.
7
- β€’ Robust error handling & logging.
8
  """
9
 
10
  import re
@@ -16,27 +16,29 @@ from transformers import (
16
  )
17
  from transformers.utils import logging as hf_logging
18
 
19
- # ─────────────── Config ────────────────────────────────────────────────────
20
  hf_logging.set_verbosity_error()
21
  LOG = hf_logging.get_logger("SchoolSpirit")
22
 
23
  MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"
24
- MAX_TURNS = 6
25
- MAX_TOKENS = 200
26
- MAX_INPUT_CH = 400
27
 
28
  SYSTEM_MSG = (
29
  "You are SchoolSpiritΒ AI, the upbeat digital mascot for a company that "
30
  "offers on‑prem AI chat mascots, fine‑tuning services, and turnkey GPU "
31
  "hardware for schools. Answer concisely and age‑appropriately. If unsure, "
32
- "say so and suggest contacting a human. Do not ask for personal data."
33
  )
34
 
35
- # ─────────────── Model Load ────────────────────────────────────────────────
36
  try:
37
  tok = AutoTokenizer.from_pretrained(MODEL_ID)
38
  model = AutoModelForCausalLM.from_pretrained(
39
- MODEL_ID, device_map="auto", torch_dtype="auto"
 
 
40
  )
41
  generator = pipeline(
42
  "text-generation",
@@ -52,18 +54,35 @@ except Exception as exc: # noqa: BLE001
52
  generator = None
53
  LOG.error(MODEL_ERR)
54
 
55
- # ─────────────── Helpers ───────────────────────────────────────────────────
56
  def truncate(hist):
 
57
  return hist[-MAX_TURNS:] if len(hist) > MAX_TURNS else hist
58
 
59
 
60
  def clean(text: str) -> str:
 
61
  return re.sub(r"\s+", " ", text.strip()) or "…"
62
 
63
- # ─────────────── Chat Callback ─────────────────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  def chat(history, user_msg):
65
- history = list(history)
66
 
 
67
  if MODEL_ERR:
68
  history.append((user_msg, MODEL_ERR))
69
  return history, ""
@@ -73,35 +92,33 @@ def chat(history, user_msg):
73
  history.append(("", "Please enter a message."))
74
  return history, ""
75
  if len(user_msg) > MAX_INPUT_CH:
76
- history.append((user_msg, "That message is too long."))
 
 
77
  return history, ""
78
 
79
  history = truncate(history)
80
 
 
81
  prompt_lines = [SYSTEM_MSG]
82
  for u, a in history:
83
  prompt_lines += [f"User: {u}", f"AI: {a}"]
84
  prompt_lines += [f"User: {user_msg}", "AI:"]
85
  prompt = "\n".join(prompt_lines)
86
 
87
- try:
88
- completion = generator(prompt)[0]["generated_text"]
89
- reply = clean(completion.split("AI:", 1)[-1])
90
- except Exception as err: # noqa: BLE001
91
- LOG.error(f"Inference error: {err}")
92
- reply = "Sorryβ€”I'm having trouble right now. Please try again shortly."
93
 
94
  history.append((user_msg, reply))
95
  return history, ""
96
 
97
- # ─────────────── Clear Chat ────────────────────────────────────────────────
98
  def clear_chat():
99
  return [], ""
100
 
101
- # ─────────────── UI Launch ────────────────────────────────────────────────
102
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
103
  gr.Markdown("# SchoolSpiritΒ AI Chat")
104
- chatbot = gr.Chatbot(type="tuple") # legacy tuple format
105
  msg_box = gr.Textbox(placeholder="Ask me anything about SchoolSpiritΒ AI…")
106
  send_btn = gr.Button("Send")
107
  clear_btn = gr.Button("Clear Chat", variant="secondary")
 
1
  """
2
  SchoolSpiritΒ AI – Granite‑3.3‑2B chatbot Space
3
  ----------------------------------------------
4
+ β€’ IBM Granite‑3.3‑2B‑Instruct (Apache‑2), runs in HF CPU Space.
5
+ β€’ Keeps last MAX_TURNS exchanges to fit context.
6
+ β€’ β€œClearΒ Chat” button resets conversation.
7
+ β€’ Extensive error‑handling: model‑load, inference, bad input.
8
  """
9
 
10
  import re
 
16
  )
17
  from transformers.utils import logging as hf_logging
18
 
19
+ # ────────── Configuration ───────────────────────────────────────────────────
20
  hf_logging.set_verbosity_error()
21
  LOG = hf_logging.get_logger("SchoolSpirit")
22
 
23
  MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"
24
+ MAX_TURNS = 6 # history turns to keep
25
+ MAX_TOKENS = 200 # response length
26
+ MAX_INPUT_CH = 400 # user message length guard
27
 
28
  SYSTEM_MSG = (
29
  "You are SchoolSpiritΒ AI, the upbeat digital mascot for a company that "
30
  "offers on‑prem AI chat mascots, fine‑tuning services, and turnkey GPU "
31
  "hardware for schools. Answer concisely and age‑appropriately. If unsure, "
32
+ "say so and suggest contacting a human. Never request personal data."
33
  )
34
 
35
+ # ────────── Model loading with fail‑safe ────────────────────────────────────
36
  try:
37
  tok = AutoTokenizer.from_pretrained(MODEL_ID)
38
  model = AutoModelForCausalLM.from_pretrained(
39
+ MODEL_ID,
40
+ device_map="auto",
41
+ torch_dtype="auto",
42
  )
43
  generator = pipeline(
44
  "text-generation",
 
54
  generator = None
55
  LOG.error(MODEL_ERR)
56
 
57
+ # ────────── Helper utilities ────────────────────────────────────────────────
58
def truncate(hist, max_turns=None):
    """Return only the most recent exchanges of *hist*.

    Args:
        hist: List of ``(user, bot)`` tuples.
        max_turns: Optional cap on how many exchanges to keep; defaults to
            the module-level ``MAX_TURNS`` so existing callers are unaffected.

    Returns:
        ``hist`` unchanged when it already fits, otherwise a new list
        containing the last ``max_turns`` tuples.
    """
    limit = MAX_TURNS if max_turns is None else max_turns
    return hist[-limit:] if len(hist) > limit else hist
61
 
62
 
63
def clean(text: str) -> str:
    """Collapse all runs of whitespace to single spaces; never return empty.

    An all-whitespace (or empty) input yields the ellipsis placeholder so the
    chat UI always has something to display.
    """
    collapsed = " ".join(text.split())
    return collapsed if collapsed else "…"
66
 
67
+
68
def safe_generate(prompt: str) -> str:
    """Run the generation pipeline on *prompt*; always return a string.

    Any exception raised during inference (including ``generator`` being
    ``None`` after a failed model load) is logged and replaced with a
    friendly fallback message, so the chat callback never crashes.
    """
    try:
        completion = generator(prompt)[0]["generated_text"]
        # The HF text-generation pipeline echoes the prompt in its output.
        # BUG FIX: split("AI:", 1)[-1] took text after the FIRST "AI:",
        # which sits inside the echoed prompt whenever history is non-empty,
        # leaking the rest of the prompt into the reply. Strip the echoed
        # prompt instead, falling back to the LAST "AI:" marker.
        if completion.startswith(prompt):
            generated = completion[len(prompt):]
        else:
            generated = completion.rsplit("AI:", 1)[-1]
        # Cut off any hallucinated next "User:" turn the model may produce.
        reply = clean(generated.split("User:", 1)[0])
    except Exception as err:  # noqa: BLE001 — UI boundary: log, never raise
        LOG.error("Inference error: %s", err)  # lazy %-args per logging docs
        reply = (
            "Sorry—I'm having trouble right now. "
            "Please try again in a moment."
        )
    return reply
80
+
81
+ # ────────── Chat callback ───────────────────────────────────────────────────
82
  def chat(history, user_msg):
83
+ history = list(history) # guaranteed list of tuples
84
 
85
+ # Fatal start‑up failure
86
  if MODEL_ERR:
87
  history.append((user_msg, MODEL_ERR))
88
  return history, ""
 
92
  history.append(("", "Please enter a message."))
93
  return history, ""
94
  if len(user_msg) > MAX_INPUT_CH:
95
+ history.append(
96
+ (user_msg, f"Message too long (>{MAX_INPUT_CH} chars).")
97
+ )
98
  return history, ""
99
 
100
  history = truncate(history)
101
 
102
+ # Build prompt
103
  prompt_lines = [SYSTEM_MSG]
104
  for u, a in history:
105
  prompt_lines += [f"User: {u}", f"AI: {a}"]
106
  prompt_lines += [f"User: {user_msg}", "AI:"]
107
  prompt = "\n".join(prompt_lines)
108
 
109
+ reply = safe_generate(prompt)
 
 
 
 
 
110
 
111
  history.append((user_msg, reply))
112
  return history, ""
113
 
114
+ # ────────── Clear chat callback ─────────────────────────────────────────────
115
def clear_chat():
    """Reset the conversation: empty chat history plus a blank input box."""
    fresh_history = []
    fresh_textbox = ""
    return fresh_history, fresh_textbox
117
 
118
+ # ────────── UI definition ───────────────────────────────────────────────────
119
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
120
  gr.Markdown("# SchoolSpiritΒ AI Chat")
121
+ chatbot = gr.Chatbot(type="tuples")
122
  msg_box = gr.Textbox(placeholder="Ask me anything about SchoolSpiritΒ AI…")
123
  send_btn = gr.Button("Send")
124
  clear_btn = gr.Button("Clear Chat", variant="secondary")