Update app.py
app.py
CHANGED
@@ -1,9 +1,4 @@
-import os
-import re
-import time
-import datetime
-import traceback
-import torch
+import os, re, time, datetime, traceback, torch
 import gradio as gr
 from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 from transformers.utils import logging as hf_logging
@@ -13,70 +8,57 @@ from transformers.utils import logging as hf_logging
 # ---------------------------------------------------------------------------
 os.environ["HF_HOME"] = "/data/.huggingface"
 LOG_FILE = "/data/requests.log"
-
-
-def log(msg: str):
+def log(msg:str):
     ts = datetime.datetime.utcnow().strftime("%H:%M:%S.%f")[:-3]
     line = f"[{ts}] {msg}"
     print(line, flush=True)
-    try:
-
-        f.write(line + "\n")
-    except FileNotFoundError:
-        pass
-
+    try: open(LOG_FILE,"a").write(line+"\n")
+    except FileNotFoundError: pass
 
 # ---------------------------------------------------------------------------
-# 1.
+# 1. Config
 # ---------------------------------------------------------------------------
-MODEL_ID
-CONTEXT_TOKENS
-MAX_NEW_TOKENS
-TEMPERATURE
-MAX_INPUT_CH
+MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"
+CONTEXT_TOKENS = 1800
+MAX_NEW_TOKENS = 96
+TEMPERATURE = 0.5
+MAX_INPUT_CH = 300
+RATE_LIMIT_N = 6        # ↲ max messages
+RATE_LIMIT_WINDOW = 60  # ↲ per seconds
 
 SYSTEM_MSG = (
     "You are **SchoolSpirit AI**, the official digital mascot of "
-    "SchoolSpirit AI LLC.
-    "
-    "
+    "SchoolSpirit AI LLC. The company deploys on‑prem AI chat mascots, "
+    "fine‑tunes language models, and ships turnkey GPU servers to K‑12 "
+    "schools.\n\n"
     "RULES:\n"
     "• Friendly, concise (≤4 sentences unless prompted).\n"
     "• No personal data collection; no medical/legal/financial advice.\n"
-    "• If uncertain, admit it & suggest human
-    "•
+    "• If uncertain, admit it & suggest contacting a human.\n"
+    "• If you can’t answer, politely direct the user to admin@schoolspiritai.com.\n"
+    "• Avoid profanity, politics, or mature themes."
 )
-WELCOME_MSG = "Welcome to SchoolSpirit AI!
+WELCOME_MSG = "Welcome to SchoolSpirit AI! Ask me about our mascots, fine‑tuning, or GPU servers."
 
 strip = lambda s: re.sub(r"\s+", " ", s.strip())
 
-
 # ---------------------------------------------------------------------------
-# 2. Load
+# 2. Load model
 # ---------------------------------------------------------------------------
 hf_logging.set_verbosity_error()
 try:
-    log("Loading tokenizer …")
-
-
-
-
-
-
-
-    else:
-        log("CPU fallback")
-        model = AutoModelForCausalLM.from_pretrained(
-            MODEL_ID,
-            device_map="cpu",
-            torch_dtype="auto",
-            low_cpu_mem_usage=True,
-        )
-
+    log("Loading tokenizer / model …")
+    tok = AutoTokenizer.from_pretrained(MODEL_ID)
+    model = AutoModelForCausalLM.from_pretrained(
+        MODEL_ID,
+        device_map="auto" if torch.cuda.is_available() else "cpu",
+        torch_dtype=torch.float16 if torch.cuda.is_available() else "auto",
+        low_cpu_mem_usage=not torch.cuda.is_available(),
+    )
     generator = pipeline(
         "text-generation",
         model=model,
-        tokenizer=
+        tokenizer=tok,
         max_new_tokens=MAX_NEW_TOKENS,
         do_sample=True,
         temperature=TEMPERATURE,
@@ -85,115 +67,88 @@ try:
     MODEL_ERR = None
     log("Model loaded ✔")
 except Exception as exc:
-    MODEL_ERR = f"Model load error: {exc}"
-    generator = None
+    MODEL_ERR, generator = f"Model load error: {exc}", None
     log(MODEL_ERR)
 
+# ---------------------------------------------------------------------------
+# 3. Rate‑limiter (IP → timestamps list)
+# ---------------------------------------------------------------------------
+VISITS: dict[str,list[float]] = {}
+def allow(ip:str)->bool:
+    now = time.time()
+    times = VISITS.get(ip,[])
+    times = [t for t in times if now - t < RATE_LIMIT_WINDOW]
+    if len(times) >= RATE_LIMIT_N:
+        VISITS[ip] = times  # cleanup stale entries
+        return False
+    times.append(now)
+    VISITS[ip] = times
+    return True
 
 # ---------------------------------------------------------------------------
-#
+# 4. Build prompt within token budget
 # ---------------------------------------------------------------------------
-def build_prompt(raw_history:
-    """
-    raw_history: list [{'role':'system'|'user'|'assistant', 'content': str}, ...]
-    Keeps trimming oldest user/assistant pair until total tokens < CONTEXT_TOKENS
-    """
+def build_prompt(raw_history:list[dict])->str:
     def render(msg):
-
-
-
-
-
-    # always include system
-    system_msg = [msg for msg in raw_history if msg["role"] == "system"][0]
-    convo = [m for m in raw_history if m["role"] != "system"]
-
-    # iterative trim
+        prefix = {"user":"User:","assistant":"AI:"}.get(msg["role"],"")
+        return msg["content"] if not prefix else f"{prefix} {msg['content']}"
+    system = raw_history[0]  # first is system
+    convo = raw_history[1:]
     while True:
-
-
-        if token_len <= CONTEXT_TOKENS or len(convo) <= 2:
+        parts = [system["content"]] + [render(m) for m in convo] + ["AI:"]
+        if len(tok.encode("\n".join(parts), add_special_tokens=False)) <= CONTEXT_TOKENS or len(convo)<=2:
             break
         convo = convo[2:]
-
-    return "\n".join(prompt_parts)
-
+    return "\n".join(parts)
 
 # ---------------------------------------------------------------------------
-#
+# 5. Chat callback
 # ---------------------------------------------------------------------------
-def chat_fn(user_msg:
-    ""
-
-
-
+def chat_fn(user_msg:str, display_history:list, state:dict, request:gr.Request):
+    ip = request.client.host if request else "unknown"
+    if not allow(ip):
+        reply = "Rate limit exceeded — please wait a minute and try again."
+        display_history.append((user_msg, reply))
+        return display_history, state
+
     user_msg = strip(user_msg or "")
     if not user_msg:
         return display_history, state
-
     if len(user_msg) > MAX_INPUT_CH:
         display_history.append((user_msg, f"Input >{MAX_INPUT_CH} chars."))
         return display_history, state
-
     if MODEL_ERR:
         display_history.append((user_msg, MODEL_ERR))
         return display_history, state
 
-
-    state["raw"].append({"role": "user", "content": user_msg})
-
-    # --- Build prompt within token budget
+    state["raw"].append({"role":"user","content":user_msg})
     prompt = build_prompt(state["raw"])
 
-    # --- Generate
     try:
         start = time.time()
-
-        reply =
-
-        reply = reply.split("User:", 1)[0].strip()
-        log(f"Reply in {time.time() - start:.2f}s ({len(reply)} chars)")
+        reply = strip(generator(prompt)[0]["generated_text"])
+        if "User:" in reply: reply = reply.split("User:",1)[0].strip()
+        log(f"{ip} ok {time.time()-start:.2f}s ({len(reply)} chars)")
     except Exception:
-        log("❌ Inference error:\n"
-        reply = "Apologies—
+        log("❌ Inference error:\n"+traceback.format_exc())
+        reply = "Apologies—internal error. Please try again."
 
-    # --- Append assistant reply to both histories
     display_history.append((user_msg, reply))
-    state["raw"].append({"role":
+    state["raw"].append({"role":"assistant","content":reply})
     return display_history, state
 
-
 # ---------------------------------------------------------------------------
-#
+# 6. Launch UI
 # ---------------------------------------------------------------------------
 with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
     gr.Markdown("### SchoolSpirit AI Chat")
-
-
-
-        height=480,
-        label="SchoolSpirit AI",
-    )
-
-    state = gr.State(
-        {
-            "raw": [
-                {"role": "system", "content": SYSTEM_MSG},
-                {"role": "assistant", "content": WELCOME_MSG},
-            ]
-        }
-    )
-
+    chatbot = gr.Chatbot(value=[("", WELCOME_MSG)], height=480, label="SchoolSpirit AI")
+    state = gr.State({"raw":[{"role":"system","content":SYSTEM_MSG},
+                             {"role":"assistant","content":WELCOME_MSG}]})
     with gr.Row():
-        txt = gr.Textbox(
-
-
-
-            lines=1,
-        )
-        send_btn = gr.Button("Send", variant="primary")
-
-    send_btn.click(chat_fn, inputs=[txt, chatbot, state], outputs=[chatbot, state])
-    txt.submit(chat_fn, inputs=[txt, chatbot, state], outputs=[chatbot, state])
+        txt = gr.Textbox(placeholder="Type your question here…", show_label=False, lines=1, scale=4)
+        btn = gr.Button("Send", variant="primary")
+    btn.click(chat_fn, inputs=[txt,chatbot,state], outputs=[chatbot,state])
+    txt.submit(chat_fn, inputs=[txt,chatbot,state], outputs=[chatbot,state])
 
 demo.launch()
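The new load block keeps the Space alive even when the model fails to load: generator is set to None and MODEL_ERR carries the message that every later request returns instead of a crash. A minimal sketch of that load-or-degrade pattern, with a hypothetical load_generator that fails on purpose (not part of app.py):

def load_generator():
    raise RuntimeError("out of memory")        # simulate a failed model load

try:
    generator, MODEL_ERR = load_generator(), None
except Exception as exc:
    generator, MODEL_ERR = None, f"Model load error: {exc}"

def respond(msg: str) -> str:
    if MODEL_ERR:                              # every request just reports the load error
        return MODEL_ERR
    return generator(msg)                      # normal path once the pipeline loaded

print(respond("hello"))                        # -> Model load error: out of memory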
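The new allow() helper is a sliding-window limiter: it keeps a per-IP list of request timestamps, discards entries older than RATE_LIMIT_WINDOW, and rejects the call once RATE_LIMIT_N recent entries remain. A small stand-alone sketch of the same idea with an injectable clock, so the window behaviour can be checked without waiting 60 seconds; the Window class is illustrative, not part of app.py:

import time

class Window:
    def __init__(self, limit: int = 6, seconds: float = 60.0, clock=time.time):
        self.limit, self.seconds, self.clock = limit, seconds, clock
        self.hits: dict[str, list[float]] = {}

    def allow(self, key: str) -> bool:
        now = self.clock()
        # keep only timestamps that are still inside the window
        recent = [t for t in self.hits.get(key, []) if now - t < self.seconds]
        if len(recent) >= self.limit:
            self.hits[key] = recent            # store the pruned list, reject the call
            return False
        recent.append(now)
        self.hits[key] = recent
        return True

t = [0.0]                                      # fake clock we can advance by hand
w = Window(limit=6, seconds=60.0, clock=lambda: t[0])
assert all(w.allow("1.2.3.4") for _ in range(6))   # first six calls pass
assert not w.allow("1.2.3.4")                      # seventh call inside the window is blocked
t[0] = 61.0
assert w.allow("1.2.3.4")                          # allowed again once the window has passed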
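build_prompt() renders the system message plus the conversation and, while the encoded prompt exceeds CONTEXT_TOKENS, drops the oldest user/assistant pair. A sketch of that trimming loop with token counting stubbed out as a whitespace split purely for illustration (app.py itself measures length with tok.encode):

def trim_to_budget(system: str, convo: list[dict], budget: int) -> str:
    def render(msg):
        prefix = {"user": "User:", "assistant": "AI:"}.get(msg["role"], "")
        return f"{prefix} {msg['content']}" if prefix else msg["content"]

    while True:
        prompt = "\n".join([system] + [render(m) for m in convo] + ["AI:"])
        if len(prompt.split()) <= budget or len(convo) <= 2:   # word count stands in for tokens
            return prompt
        convo = convo[2:]                                      # drop the oldest user/assistant pair

history = [
    {"role": "user", "content": "tell me about mascots " * 20},
    {"role": "assistant", "content": "we build chat mascots " * 20},
    {"role": "user", "content": "and the GPU servers?"},
]
print(trim_to_budget("You are SchoolSpirit AI.", history, budget=40))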
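Inside chat_fn() the generation result comes back as generator(prompt)[0]["generated_text"], and anything after a stray "User:" turn is cut off because the prompt ends with "AI:". A toy version of that clean-up step, with fake_generator standing in for the real pipeline:

import re

strip = lambda s: re.sub(r"\s+", " ", s.strip())

def fake_generator(prompt):
    # stand-in for the transformers pipeline: a list with one dict per generated sequence
    return [{"generated_text": "We build school mascot chatbots. User: do you sell hardware?"}]

reply = strip(fake_generator("...prompt ending in AI:")[0]["generated_text"])
if "User:" in reply:
    reply = reply.split("User:", 1)[0].strip()
print(reply)   # -> We build school mascot chatbots.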
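The UI wiring passes [txt, chatbot, state] into chat_fn and writes back [chatbot, state] on both the button click and the textbox submit. A miniature of the same Blocks layout with an echo callback, useful for checking the plumbing without downloading a model; everything here except the Gradio calls themselves is illustrative:

import gradio as gr

def echo(user_msg, history, state):
    history.append((user_msg, f"echo: {user_msg}"))   # same (user, bot) tuple shape as app.py
    return history, state

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(value=[("", "Hello!")], height=480)
    state = gr.State({})
    with gr.Row():
        txt = gr.Textbox(show_label=False, lines=1, scale=4)
        btn = gr.Button("Send", variant="primary")
    btn.click(echo, inputs=[txt, chatbot, state], outputs=[chatbot, state])
    txt.submit(echo, inputs=[txt, chatbot, state], outputs=[chatbot, state])

if __name__ == "__main__":
    demo.launch()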