phanerozoic committed
Commit 9007cad · verified · 1 Parent(s): 9843b35

Update app.py

Files changed (1)
  1. app.py +85 -32
app.py CHANGED
@@ -3,80 +3,133 @@ import gradio as gr
 from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 from transformers.utils import logging as hf_logging
 
+# ---------------------------------------------------------------------------
 # Persistent cache + request log
+# ---------------------------------------------------------------------------
 os.environ["HF_HOME"] = "/data/.huggingface"
 LOG_FILE = "/data/requests.log"
-def log(msg):
+
+
+def log(msg: str):
     ts = datetime.datetime.utcnow().strftime("%H:%M:%S.%f")[:-3]
     line = f"[{ts}] {msg}"
     print(line, flush=True)
-    try:
-        with open(LOG_FILE, "a") as f: f.write(line + "\n")
-    except FileNotFoundError: pass
+    try:
+        with open(LOG_FILE, "a") as f:
+            f.write(line + "\n")
+    except FileNotFoundError:
+        pass
+
+
+# ---------------------------------------------------------------------------
+# Configuration
+# ---------------------------------------------------------------------------
+MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"  # 2‑B fits HF CPU Space
+MAX_TURNS = 4  # keep last N user/AI pairs
+MAX_TOKENS = 64
+MAX_INPUT_CH = 300
 
-# Config
-MODEL_ID, MAX_TURNS, MAX_TOKENS, MAX_INPUT_CH = (
-    "ibm-granite/granite-3.3-2b-instruct", 4, 64, 300
+SYSTEM_MSG = (
+    "You are **SchoolSpirit AI**, the official digital mascot for "
+    "SchoolSpirit AI LLC, founded by Charles Norton in 2025. The company "
+    "specializes in on‑prem AI chat mascots, custom fine‑tuning of language "
+    "models, and turnkey GPU servers for K‑12 schools and education vendors.\n\n"
+    "GUIDELINES:\n"
+    "• Respond in a warm, encouraging tone suitable for students, parents, "
+    "and staff.\n"
+    "• Keep answers concise (≤ 4 sentences) unless asked for detail.\n"
+    "• If unsure or out of scope, say you’re not sure and offer human follow‑up.\n"
+    "• No personal data collection, no medical/legal/financial advice.\n"
+    "• Maintain professionalism—no profanity, politics, or mature themes."
 )
-SYSTEM_MSG = ("You are SchoolSpirit AI, the upbeat mascot for a company that "
-              "installs on‑prem AI chatbots in schools. Keep answers short, "
-              "friendly, and safe.")
 
+WELCOME_MSG = "Welcome to SchoolSpirit AI! Do you have any questions?"
+
+# ---------------------------------------------------------------------------
 # Load model
+# ---------------------------------------------------------------------------
 hf_logging.set_verbosity_error()
 try:
     log("Loading model …")
     tok = AutoTokenizer.from_pretrained(MODEL_ID)
     model = AutoModelForCausalLM.from_pretrained(
-        MODEL_ID, device_map="auto", torch_dtype="auto")
-    gen = pipeline("text-generation", model=model, tokenizer=tok,
-                   max_new_tokens=MAX_TOKENS, do_sample=True, temperature=0.6)
+        MODEL_ID, device_map="auto", torch_dtype="auto"
+    )
+    gen = pipeline(
+        "text-generation",
+        model=model,
+        tokenizer=tok,
+        max_new_tokens=MAX_TOKENS,
+        do_sample=True,
+        temperature=0.6,
+    )
     MODEL_ERR = None
     log("Model loaded ✔")
-except Exception as exc:
+except Exception as exc:  # noqa: BLE001
     MODEL_ERR, gen = f"Model load error: {exc}", None
     log(MODEL_ERR)
 
 clean = lambda t: re.sub(r"\s+", " ", t.strip()) or "…"
-trim = lambda m: m if len(m)<=1+MAX_TURNS*2 else [m[0]]+m[-MAX_TURNS*2:]
+trim = lambda m: m if len(m) <= 1 + MAX_TURNS * 2 else [m[0]] + m[-MAX_TURNS * 2 :]
 
+# ---------------------------------------------------------------------------
 # Chat logic
-def chat_fn(user_msg, history):
+# ---------------------------------------------------------------------------
+
+
+def chat_fn(user_msg: str, history: list):
     log(f"User sent {len(user_msg)} chars")
-    if not history or history[0]["role"]!="system":
-        history=[{"role":"system","content":SYSTEM_MSG}]
-    if MODEL_ERR: return MODEL_ERR
+    # Seed system + welcome messages on first call
+    if not history or history[0]["role"] != "system":
+        history = [
+            {"role": "system", "content": SYSTEM_MSG},
+            {"role": "assistant", "content": WELCOME_MSG},
+        ]
+
+    if MODEL_ERR:
+        return MODEL_ERR
 
     user_msg = clean(user_msg or "")
-    if not user_msg: return "Please type something."
-    if len(user_msg)>MAX_INPUT_CH:
+    if not user_msg:
+        return "Please type something."
+    if len(user_msg) > MAX_INPUT_CH:
         return f"Message too long (>{MAX_INPUT_CH} chars)."
 
-    history.append({"role":"user","content":user_msg})
+    history.append({"role": "user", "content": user_msg})
     history = trim(history)
 
-    prompt_lines=[m["content"] if m["role"]=="system"
-                  else f'{"User" if m["role"]=="user" else "AI"}: {m["content"]}'
-                  for m in history]+["AI:"]
+    prompt_lines = [
+        m["content"]
+        if m["role"] == "system"
+        else f'{"User" if m["role"]=="user" else "AI"}: {m["content"]}'
+        for m in history
+    ] + ["AI:"]
     prompt = "\n".join(prompt_lines)
     log(f"Prompt {len(prompt)} chars → generating")
 
-    t0=time.time()
+    t0 = time.time()
     try:
-        raw = gen(prompt)[0]["generated_text"]
-        reply = clean(raw.split("AI:",1)[-1])
-        reply = re.split(r"\b(?:User:|AI:)", reply, 1)[0].strip()  # ← cut here
+        raw = gen(prompt)[0]["generated_text"]
+        reply = clean(raw.split("AI:", 1)[-1])
+        reply = re.split(r"\b(?:User:|AI:)", reply, 1)[0].strip()
         log(f"generate() {time.time()-t0:.2f}s, reply {len(reply)} chars")
     except Exception:
-        log("❌ Inference exception:\n"+traceback.format_exc())
-        reply="Sorry—backend crashed. Please try again later."
+        log("❌ Inference exception:\n" + traceback.format_exc())
+        reply = "Sorry—backend crashed. Please try again later."
 
     return reply
 
+
+# ---------------------------------------------------------------------------
 # UI
+# ---------------------------------------------------------------------------
 gr.ChatInterface(
     fn=chat_fn,
-    chatbot=gr.Chatbot(height=480, type="messages"),
+    chatbot=gr.Chatbot(
+        height=480,
+        type="messages",
+        value=[("", WELCOME_MSG)],  # pre-populate AI welcome bubble
+    ),
     title="SchoolSpirit AI Chat",
     theme=gr.themes.Soft(primary_hue="blue"),
     type="messages",