phanerozoic committed
Commit 9843b35 · verified · 1 Parent(s): 318dc96

Update app.py

Files changed (1)
  1. app.py +44 -101
app.py CHANGED
@@ -1,140 +1,83 @@
-"""
-SchoolSpirit AI – Granite‑3.3‑2B chatbot (Gradio 4.3, messages API)
-────────────────────────────────────────────────────────────────────
-• Persistent HF cache: HF_HOME=/data/.huggingface (25 GB tier)
-• Persistent request log: /data/requests.log
-• Detailed system prompt (brand + guardrails)
-• Traces every request: Received → Prompt → generate() timing
-• Cleans replies & removes any stray “User:” / “AI:” echoes
-"""
-
-# ──────────────────── standard libraries ───────────────────────────────────
-from __future__ import annotations
 import os, re, time, datetime, traceback
-# ───── gradio + hf transformers ────────────────────────────────────────────
 import gradio as gr
 from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 from transformers.utils import logging as hf_logging
 
-# ──────────────────── persistent disk paths ────────────────────────────────
-os.environ["HF_HOME"] = "/data/.huggingface"  # model / tokenizer cache
-LOG_FILE = "/data/requests.log"  # simple persistent log
-
-def log(msg: str) -> None:
-    """Print + append to /data/requests.log with UTC timestamp."""
+# Persistent cache + request log
+os.environ["HF_HOME"] = "/data/.huggingface"
+LOG_FILE = "/data/requests.log"
+def log(msg):
     ts = datetime.datetime.utcnow().strftime("%H:%M:%S.%f")[:-3]
     line = f"[{ts}] {msg}"
     print(line, flush=True)
-    try:  # ignore first‑run errors
-        with open(LOG_FILE, "a") as f:
-            f.write(line + "\n")
-    except FileNotFoundError:
-        pass
-
-# ──────────────────── chatbot configuration ────────────────────────────────
-MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"  # 2 B params, Apache‑2
-MAX_TURNS = 6       # keep last N user/assistant pairs
-MAX_TOKENS = 128    # reply length (raise if you have patience)
-MAX_INPUT_CH = 400  # user message length guard
+    try:
+        with open(LOG_FILE, "a") as f: f.write(line + "\n")
+    except FileNotFoundError: pass
 
-SYSTEM_MSG = (
-    "You are **SchoolSpirit AI**, the friendly digital mascot for a company "
-    "that provides on‑prem AI chat mascots, fine‑tuning services, and turnkey "
-    "GPU hardware for schools.\n\n"
-    "• Keep answers concise, upbeat, and age‑appropriate (K‑12).\n"
-    "• If you are unsure, say so and suggest contacting a human staff member.\n"
-    "• Never request personal data beyond an email if the user volunteers it.\n"
-    "• Do **not** provide medical, legal, or financial advice.\n"
-    "• No politics, mature content, or profanity.\n"
-    "Respond in a friendly, encouraging tone—as a helpful school mascot!"
+# Config
+MODEL_ID, MAX_TURNS, MAX_TOKENS, MAX_INPUT_CH = (
+    "ibm-granite/granite-3.3-2b-instruct", 4, 64, 300
 )
+SYSTEM_MSG = ("You are SchoolSpirit AI, the upbeat mascot for a company that "
+              "installs on‑prem AI chatbots in schools. Keep answers short, "
+              "friendly, and safe.")
 
-# ──────────────────── load model & pipeline ────────────────────────────────
+# Load model
 hf_logging.set_verbosity_error()
 try:
-    log("Loading tokenizer & model …")
+    log("Loading model …")
     tok = AutoTokenizer.from_pretrained(MODEL_ID)
     model = AutoModelForCausalLM.from_pretrained(
-        MODEL_ID, device_map="auto", torch_dtype="auto"
-    )
-    gen = pipeline(
-        "text-generation",
-        model=model,
-        tokenizer=tok,
-        max_new_tokens=MAX_TOKENS,
-        do_sample=True,
-        temperature=0.7,
-    )
+        MODEL_ID, device_map="auto", torch_dtype="auto")
+    gen = pipeline("text-generation", model=model, tokenizer=tok,
+                   max_new_tokens=MAX_TOKENS, do_sample=True, temperature=0.6)
     MODEL_ERR = None
     log("Model loaded ✔")
-except Exception as exc:  # noqa: BLE001
+except Exception as exc:
     MODEL_ERR, gen = f"Model load error: {exc}", None
     log(MODEL_ERR)
 
-# ──────────────────── small helpers ────────────────────────────────────────
-def clean(txt: str) -> str:
-    """Collapse whitespace & guarantee non‑empty string."""
-    return re.sub(r"\s+", " ", txt.strip()) or "…"
-
-def trim_history(msgs: list[dict]) -> list[dict]:
-    """Keep system + last MAX_TURNS pairs."""
-    return msgs if len(msgs) <= 1 + MAX_TURNS * 2 else [msgs[0]] + msgs[-MAX_TURNS * 2 :]
+clean = lambda t: re.sub(r"\s+", " ", t.strip()) or "…"
+trim = lambda m: m if len(m)<=1+MAX_TURNS*2 else [m[0]]+m[-MAX_TURNS*2:]
 
-# ──────────────────── core chat function ───────────────────────────────────
-def chat_fn(user_msg: str, history: list[dict] | None):
+# Chat logic
+def chat_fn(user_msg, history):
     log(f"User sent {len(user_msg)} chars")
-
-    # ensure history list exists & begins with system prompt
-    if not history or history[0]["role"] != "system":
-        history = [{"role": "system", "content": SYSTEM_MSG}]
-
-    # fatal model‑load failure
-    if MODEL_ERR:
-        return MODEL_ERR
-
-    # basic user‑input checks
+    if not history or history[0]["role"]!="system":
+        history=[{"role":"system","content":SYSTEM_MSG}]
+    if MODEL_ERR: return MODEL_ERR
     user_msg = clean(user_msg or "")
-    if not user_msg:
-        return "Please type something."
-    if len(user_msg) > MAX_INPUT_CH:
+    if not user_msg: return "Please type something."
+    if len(user_msg)>MAX_INPUT_CH:
         return f"Message too long (>{MAX_INPUT_CH} chars)."
 
-    # add user message & trim
-    history.append({"role": "user", "content": user_msg})
-    history = trim_history(history)
+    history.append({"role":"user","content":user_msg})
+    history = trim(history)
 
-    # build prompt string
-    prompt_lines: list[str] = []
-    for m in history:
-        if m["role"] == "system":
-            prompt_lines.append(m["content"])
-        elif m["role"] == "user":
-            prompt_lines.append(f"User: {m['content']}")
-        else:
-            prompt_lines.append(f"AI: {m['content']}")
-    prompt_lines.append("AI:")
+    prompt_lines=[m["content"] if m["role"]=="system"
+                  else f'{"User" if m["role"]=="user" else "AI"}: {m["content"]}'
+                  for m in history]+["AI:"]
     prompt = "\n".join(prompt_lines)
-    log(f"Prompt {len(prompt)} chars • generating…")
+    log(f"Prompt {len(prompt)} chars → generating")
 
-    # call generator
-    t0 = time.time()
+    t0=time.time()
     try:
         raw = gen(prompt)[0]["generated_text"]
-        reply = clean(raw.split("AI:", 1)[-1])
-        # ✂ remove any echoed tags
-        reply = re.split(r"\b(?:User:|AI:)", reply, 1)[0].strip()
-        log(f"generate() {time.time() - t0:.2f}s • reply {len(reply)} chars")
-    except Exception:  # noqa: BLE001
-        log("❌ Inference exception:\n" + traceback.format_exc())
-        reply = "Sorry—AI backend crashed. Please try again later."
+        reply = clean(raw.split("AI:",1)[-1])
+        reply = re.split(r"\b(?:User:|AI:)", reply, 1)[0].strip()  # ← cut here
+        log(f"generate() {time.time()-t0:.2f}s, reply {len(reply)} chars")
+    except Exception:
+        log("❌ Inference exception:\n"+traceback.format_exc())
+        reply="Sorry—backend crashed. Please try again later."
 
     return reply
 
-# ──────────────────── Gradio UI ────────────────────────────────────────────
+# UI
 gr.ChatInterface(
     fn=chat_fn,
     chatbot=gr.Chatbot(height=480, type="messages"),
     title="SchoolSpirit AI Chat",
-    theme=gr.themes.Soft(primary_hue="blue"),  # light‑blue accent
-    type="messages",  # modern message dicts
+    theme=gr.themes.Soft(primary_hue="blue"),
+    type="messages",
 ).launch()
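
Below, a minimal standalone sketch (not part of the commit) of the history-trimming and prompt-assembly logic the new version introduces. MAX_TURNS mirrors the new config; the q0/a0 strings are hypothetical placeholder turns:

# Standalone sketch of the trim + prompt-assembly logic above
# (assumptions: MAX_TURNS = 4 as in the new config; q0/a0 are made-up turns).
MAX_TURNS = 4
trim = lambda m: m if len(m) <= 1 + MAX_TURNS * 2 else [m[0]] + m[-MAX_TURNS * 2:]

history = [{"role": "system", "content": "<system prompt>"}]
for i in range(6):  # six user/assistant pairs; the two oldest pairs get dropped
    history += [{"role": "user", "content": f"q{i}"},
                {"role": "assistant", "content": f"a{i}"}]

trimmed = trim(history)
assert len(trimmed) == 1 + MAX_TURNS * 2   # system message + last 4 pairs
assert trimmed[1]["content"] == "q2"       # q0/q1 and their answers are gone

# Same layout chat_fn builds: system text, then "User:"/"AI:" lines, then "AI:"
prompt = "\n".join(m["content"] if m["role"] == "system"
                   else f'{"User" if m["role"] == "user" else "AI"}: {m["content"]}'
                   for m in trimmed) + "\nAI:"
print(prompt)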
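And a quick illustration of the echo-trimming regex carried over from the old version, run on a made-up generation tail:

import re

# The model sometimes keeps going and invents the next "User:" turn;
# splitting on the first echoed tag keeps only the genuine reply.
raw_tail = "We set up on-prem chatbots for schools! User: how much does it cost?"
reply = re.split(r"\b(?:User:|AI:)", raw_tail, 1)[0].strip()
assert reply == "We set up on-prem chatbots for schools!"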