phanerozoic committed
Commit 2cb9530 · verified · 1 Parent(s): 832bd11

Update app.py

Files changed (1):
  1. app.py (+76, -121)
app.py CHANGED
@@ -1,9 +1,4 @@
-import os
-import re
-import time
-import datetime
-import traceback
-import torch
+import os, re, time, datetime, traceback, torch
 import gradio as gr
 from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 from transformers.utils import logging as hf_logging
@@ -13,70 +8,57 @@ from transformers.utils import logging as hf_logging
 # ---------------------------------------------------------------------------
 os.environ["HF_HOME"] = "/data/.huggingface"
 LOG_FILE = "/data/requests.log"
-
-
-def log(msg: str):
+def log(msg:str):
     ts = datetime.datetime.utcnow().strftime("%H:%M:%S.%f")[:-3]
     line = f"[{ts}] {msg}"
     print(line, flush=True)
-    try:
-        with open(LOG_FILE, "a") as f:
-            f.write(line + "\n")
-    except FileNotFoundError:
-        pass
-
+    try: open(LOG_FILE,"a").write(line+"\n")
+    except FileNotFoundError: pass
 
 # ---------------------------------------------------------------------------
-# 1. Configuration constants
+# 1. Config
 # ---------------------------------------------------------------------------
-MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"
-CONTEXT_TOKENS = 1800
-MAX_NEW_TOKENS = 64
-TEMPERATURE = 0.6
-MAX_INPUT_CH = 300
+MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"
+CONTEXT_TOKENS = 1800
+MAX_NEW_TOKENS = 96
+TEMPERATURE = 0.5
+MAX_INPUT_CH = 300
+RATE_LIMIT_N = 6        # max messages
+RATE_LIMIT_WINDOW = 60  # per seconds
 
 SYSTEM_MSG = (
     "You are **SchoolSpirit AI**, the official digital mascot of "
-    "SchoolSpirit AI LLC. Founded by Charles Norton in 2025, the company "
-    "deploys on‑prem AI chat mascots, fine‑tunes language models, and ships "
-    "turnkey GPU servers to K‑12 schools.\n\n"
+    "SchoolSpirit AI LLC. The company deploys on‑prem AI chat mascots, "
+    "fine‑tunes language models, and ships turnkey GPU servers to K‑12 "
+    "schools.\n\n"
    "RULES:\n"
    "• Friendly, concise (≤4 sentences unless prompted).\n"
    "• No personal data collection; no medical/legal/financial advice.\n"
-    "• If uncertain, admit it & suggest human follow‑up.\n"
-    "• Avoid profanity, politics, mature themes."
+    "• If uncertain, admit it & suggest contacting a human.\n"
+    "• If you can’t answer, politely direct the user to admin@schoolspiritai.com.\n"
+    "• Avoid profanity, politics, or mature themes."
 )
-WELCOME_MSG = "Welcome to SchoolSpirit AI! Do you have any questions?"
+WELCOME_MSG = "Welcome to SchoolSpirit AI! Ask me about our mascots, fine‑tuning, or GPU servers."
 
 strip = lambda s: re.sub(r"\s+", " ", s.strip())
 
-
 # ---------------------------------------------------------------------------
-# 2. Load tokenizer + model (GPU FP‑16 → CPU)
+# 2. Load model
 # ---------------------------------------------------------------------------
 hf_logging.set_verbosity_error()
 try:
-    log("Loading tokenizer …")
-    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
-
-    if torch.cuda.is_available():
-        log("GPU detected FP‑16")
-        model = AutoModelForCausalLM.from_pretrained(
-            MODEL_ID, device_map="auto", torch_dtype=torch.float16
-        )
-    else:
-        log("CPU fallback")
-        model = AutoModelForCausalLM.from_pretrained(
-            MODEL_ID,
-            device_map="cpu",
-            torch_dtype="auto",
-            low_cpu_mem_usage=True,
-        )
-
+    log("Loading tokenizer / model …")
+    tok = AutoTokenizer.from_pretrained(MODEL_ID)
+    model = AutoModelForCausalLM.from_pretrained(
+        MODEL_ID,
+        device_map="auto" if torch.cuda.is_available() else "cpu",
+        torch_dtype=torch.float16 if torch.cuda.is_available() else "auto",
+        low_cpu_mem_usage=not torch.cuda.is_available(),
+    )
     generator = pipeline(
         "text-generation",
         model=model,
-        tokenizer=tokenizer,
+        tokenizer=tok,
         max_new_tokens=MAX_NEW_TOKENS,
        do_sample=True,
        temperature=TEMPERATURE,
@@ -85,115 +67,88 @@ try:
     MODEL_ERR = None
     log("Model loaded ✔")
 except Exception as exc:
-    MODEL_ERR = f"Model load error: {exc}"
-    generator = None
+    MODEL_ERR, generator = f"Model load error: {exc}", None
     log(MODEL_ERR)
 
+# ---------------------------------------------------------------------------
+# 3. Rate‑limiter (IP → timestamps list)
+# ---------------------------------------------------------------------------
+VISITS: dict[str,list[float]] = {}
+def allow(ip:str)->bool:
+    now = time.time()
+    times = VISITS.get(ip,[])
+    times = [t for t in times if now - t < RATE_LIMIT_WINDOW]
+    if len(times) >= RATE_LIMIT_N:
+        VISITS[ip] = times  # cleanup stale entries
+        return False
+    times.append(now)
+    VISITS[ip] = times
+    return True
 
 # ---------------------------------------------------------------------------
-# 3. Helper: build prompt under token budget
+# 4. Build prompt within token budget
 # ---------------------------------------------------------------------------
-def build_prompt(raw_history: list[dict]) -> str:
-    """
-    raw_history: list [{'role':'system'|'user'|'assistant', 'content': str}, ...]
-    Keeps trimming oldest user/assistant pair until total tokens < CONTEXT_TOKENS
-    """
+def build_prompt(raw_history:list[dict])->str:
     def render(msg):
-        if msg["role"] == "system":
-            return msg["content"]
-        prefix = "User:" if msg["role"] == "user" else "AI:"
-        return f"{prefix} {msg['content']}"
-
-    # always include system
-    system_msg = [msg for msg in raw_history if msg["role"] == "system"][0]
-    convo = [m for m in raw_history if m["role"] != "system"]
-
-    # iterative trim
+        prefix = {"user":"User:","assistant":"AI:"}.get(msg["role"],"")
+        return msg["content"] if not prefix else f"{prefix} {msg['content']}"
+    system = raw_history[0]  # first is system
+    convo = raw_history[1:]
     while True:
-        prompt_parts = [system_msg["content"]] + [render(m) for m in convo] + ["AI:"]
-        token_len = len(tokenizer.encode("\n".join(prompt_parts), add_special_tokens=False))
-        if token_len <= CONTEXT_TOKENS or len(convo) <= 2:
+        parts = [system["content"]] + [render(m) for m in convo] + ["AI:"]
+        if len(tok.encode("\n".join(parts), add_special_tokens=False)) <= CONTEXT_TOKENS or len(convo)<=2:
            break
        convo = convo[2:]
-
-    return "\n".join(prompt_parts)
-
+    return "\n".join(parts)
 
 # ---------------------------------------------------------------------------
-# 4. Chat callback
+# 5. Chat callback
 # ---------------------------------------------------------------------------
-def chat_fn(user_msg: str, display_history: list, state: dict):
-    """
-    display_history : list[tuple[str,str]] for UI
-    state["raw"]    : list[dict] for prompting
-    """
+def chat_fn(user_msg:str, display_history:list, state:dict, request:gr.Request):
+    ip = request.client.host if request else "unknown"
+    if not allow(ip):
+        reply = "Rate limit exceeded — please wait a minute and try again."
+        display_history.append((user_msg, reply))
+        return display_history, state
+
     user_msg = strip(user_msg or "")
     if not user_msg:
         return display_history, state
-
     if len(user_msg) > MAX_INPUT_CH:
         display_history.append((user_msg, f"Input >{MAX_INPUT_CH} chars."))
         return display_history, state
-
     if MODEL_ERR:
         display_history.append((user_msg, MODEL_ERR))
         return display_history, state
 
-    # --- Update raw history
-    state["raw"].append({"role": "user", "content": user_msg})
-
-    # --- Build prompt within token budget
+    state["raw"].append({"role":"user","content":user_msg})
     prompt = build_prompt(state["raw"])
 
-    # --- Generate
     try:
         start = time.time()
-        result = generator(prompt)[0]
-        reply = strip(result["generated_text"])
-        if "User:" in reply:
-            reply = reply.split("User:", 1)[0].strip()
-        log(f"Reply in {time.time() - start:.2f}s ({len(reply)} chars)")
+        reply = strip(generator(prompt)[0]["generated_text"])
+        if "User:" in reply: reply = reply.split("User:",1)[0].strip()
+        log(f"{ip} ok {time.time()-start:.2f}s ({len(reply)} chars)")
     except Exception:
-        log("❌ Inference error:\n" + traceback.format_exc())
-        reply = "Apologies—an internal error occurred. Please try again."
+        log("❌ Inference error:\n"+traceback.format_exc())
+        reply = "Apologies—internal error. Please try again."
 
-    # --- Append assistant reply to both histories
     display_history.append((user_msg, reply))
-    state["raw"].append({"role": "assistant", "content": reply})
+    state["raw"].append({"role":"assistant","content":reply})
     return display_history, state
 
-
 # ---------------------------------------------------------------------------
-# 5. Launch Gradio Blocks UI
+# 6. Launch UI
 # ---------------------------------------------------------------------------
 with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
     gr.Markdown("### SchoolSpirit AI Chat")
-
-    chatbot = gr.Chatbot(
-        value=[("", WELCOME_MSG)],
-        height=480,
-        label="SchoolSpirit AI",
-    )
-
-    state = gr.State(
-        {
-            "raw": [
-                {"role": "system", "content": SYSTEM_MSG},
-                {"role": "assistant", "content": WELCOME_MSG},
-            ]
-        }
-    )
-
+    chatbot = gr.Chatbot(value=[("", WELCOME_MSG)], height=480, label="SchoolSpirit AI")
+    state = gr.State({"raw":[{"role":"system","content":SYSTEM_MSG},
+                             {"role":"assistant","content":WELCOME_MSG}]})
     with gr.Row():
-        txt = gr.Textbox(
-            placeholder="Type your question here…",
-            show_label=False,
-            scale=4,
-            lines=1,
-        )
-        send_btn = gr.Button("Send", variant="primary")
-
-    send_btn.click(chat_fn, inputs=[txt, chatbot, state], outputs=[chatbot, state])
-    txt.submit(chat_fn, inputs=[txt, chatbot, state], outputs=[chatbot, state])
+        txt = gr.Textbox(placeholder="Type your question here…", show_label=False, lines=1, scale=4)
+        btn = gr.Button("Send", variant="primary")
+    btn.click(chat_fn, inputs=[txt,chatbot,state], outputs=[chatbot,state])
+    txt.submit(chat_fn, inputs=[txt,chatbot,state], outputs=[chatbot,state])
 
 demo.launch()
 
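Note on the new rate limiter: allow() keeps a sliding window of per-IP timestamps, discards hits older than RATE_LIMIT_WINDOW, and rejects once RATE_LIMIT_N remain in the window. A self-contained sketch of the committed pattern (the IP string is a placeholder for the demo):

```python
import time

RATE_LIMIT_N = 6        # max messages
RATE_LIMIT_WINDOW = 60  # per window, in seconds

VISITS: dict[str, list[float]] = {}

def allow(ip: str) -> bool:
    # Keep only timestamps still inside the window, then check the count.
    now = time.time()
    times = [t for t in VISITS.get(ip, []) if now - t < RATE_LIMIT_WINDOW]
    if len(times) >= RATE_LIMIT_N:
        VISITS[ip] = times  # store the pruned list even when rejecting
        return False
    times.append(now)
    VISITS[ip] = times
    return True

# The seventh call inside one window is rejected.
print([allow("203.0.113.7") for _ in range(7)])
# [True, True, True, True, True, True, False]
```

One caveat: an IP's timestamps are only pruned when that IP sends another message, so entries from one-off visitors stay in VISITS for the life of the process; a periodic sweep would be needed if memory matters.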
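Note on build_prompt(): it re-renders the whole conversation and drops the oldest user/assistant pair until the joined prompt fits CONTEXT_TOKENS, always keeping the system message and at least the newest pair. A toy run with a stand-in whitespace token count in place of tok.encode, and a deliberately tiny budget so the trim is visible:

```python
CONTEXT_TOKENS = 12  # tiny budget for the demo; the app uses 1800

def count_tokens(text: str) -> int:
    return len(text.split())  # stand-in for len(tok.encode(...))

def build_prompt(raw_history: list[dict]) -> str:
    def render(m):
        prefix = {"user": "User:", "assistant": "AI:"}.get(m["role"], "")
        return m["content"] if not prefix else f"{prefix} {m['content']}"
    system, convo = raw_history[0], raw_history[1:]
    while True:
        parts = [system["content"]] + [render(m) for m in convo] + ["AI:"]
        if count_tokens("\n".join(parts)) <= CONTEXT_TOKENS or len(convo) <= 2:
            break
        convo = convo[2:]  # drop the oldest user/assistant pair
    return "\n".join(parts)

history = [
    {"role": "system", "content": "Be brief."},
    {"role": "user", "content": "one two three"},
    {"role": "assistant", "content": "four five six"},
    {"role": "user", "content": "seven"},
    {"role": "assistant", "content": "eight"},
]
print(build_prompt(history))
# Be brief.
# User: seven
# AI: eight
# AI:
```

The first user/assistant pair is trimmed away; the len(convo) <= 2 guard means the latest exchange always survives even if it alone exceeds the budget.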
 
 
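Note on the new gr.Request parameter of chat_fn: Gradio injects the live request into any event handler that declares a parameter annotated with gr.Request; it is not listed in inputs=[...], and request.client.host supplies the address that feeds allow(). A minimal sketch of the same wiring (component and handler names here are illustrative, not from the commit):

```python
import gradio as gr

def echo(msg: str, history: list, request: gr.Request):
    # Gradio fills `request` automatically; only msg and history are inputs.
    ip = request.client.host if request else "unknown"
    history.append((msg, f"echo from {ip}: {msg}"))
    return history

with gr.Blocks() as demo:
    chat = gr.Chatbot()
    box = gr.Textbox()
    box.submit(echo, inputs=[box, chat], outputs=[chat])

demo.launch()
```

Worth noting: behind a reverse proxy, request.client.host may report the proxy address rather than the end user, so distinct visitors can share one rate-limit bucket; inspecting the x-forwarded-for header via request.headers is the usual workaround.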
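Note on the consolidated model load: the commit collapses the separate GPU and CPU branches into a single from_pretrained call whose device_map, torch_dtype, and low_cpu_mem_usage arguments are all driven by the same torch.cuda.is_available() check. The pattern in isolation:

```python
import torch
from transformers import AutoModelForCausalLM

MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"
has_gpu = torch.cuda.is_available()

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto" if has_gpu else "cpu",           # place weights on GPU(s) or CPU
    torch_dtype=torch.float16 if has_gpu else "auto",  # fp16 only when a GPU is present
    low_cpu_mem_usage=not has_gpu,                     # lower peak RAM on the CPU path
)
```

Hoisting the condition into a has_gpu variable, as sketched here, avoids calling torch.cuda.is_available() three times, though the committed inline form behaves identically.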