Update app.py
app.py (CHANGED)
@@ -1,62 +1,110 @@
"""
SchoolSpirit AI – public chatbot Space
--------------------------------------
• Small Llama 3.2 3B Instruct model (fits an HF CPU Space).
• Light-blue Gradio chat widget.
• Robust error handling.
"""

import re

import gradio as gr
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    pipeline,
)
from transformers.utils import logging as hf_logging

hf_logging.set_verbosity_error()
LOGGER = hf_logging.get_logger("SchoolSpirit")

# ------------------------ Config -------------------------------------------
MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"
MAX_TURNS = 6         # last N exchanges kept
MAX_TOKENS = 220      # response length
MAX_INPUT_CH = 500    # user message length guard

SYSTEM_MSG = (
    "You are SchoolSpirit AI, the friendly digital mascot for a company that "
    "offers on-prem AI chat mascots, fine-tuning services, and turnkey GPU "
    "hardware for schools. Keep answers concise, upbeat, and age-appropriate. "
    "If unsure, admit it and suggest contacting a human. Never request "
    "personal data."
)

# ------------------------ Model Load ---------------------------------------
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map="auto",    # let Accelerate place weights on CPU/GPU
        torch_dtype="auto",   # use the checkpoint's native dtype
    )
    generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=MAX_TOKENS,
        do_sample=True,
        temperature=0.7,
    )
    MODEL_LOAD_ERR = None
except Exception as exc:  # noqa: BLE001
    generator = None
    MODEL_LOAD_ERR = f"Model load error: {exc}"
    LOGGER.error(MODEL_LOAD_ERR)

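# NOTE: device_map="auto" above needs the `accelerate` package installed
# alongside `torch`. A plausible requirements.txt for this Space (an
# assumption; the dependency file is not shown in this commit):
#
#     gradio
#     transformers
#     torch
#     accelerate
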
# ------------------------ Chat Function ------------------------------------
def chat(user_msg, history):
    """Gradio ChatInterface callback: takes (message, history) and returns
    the reply string; history is a list of (user, assistant) tuples."""

    # Hard failure at Space startup
    if MODEL_LOAD_ERR:
        return MODEL_LOAD_ERR

    # Basic user-input guardrails
    user_msg = (user_msg or "").strip()
    if not user_msg:
        return "Please enter a message."
    if len(user_msg) > MAX_INPUT_CH:
        return "Sorry, your message is too long. Please shorten it."

    # Keep only the last MAX_TURNS exchanges
    if len(history) > MAX_TURNS:
        history = history[-MAX_TURNS:]

    # Build prompt
    parts = [SYSTEM_MSG]
    for u, a in history:
        parts.append(f"User: {u}")
        parts.append(f"AI: {a}")
    parts.append(f"User: {user_msg}")
    parts.append("AI:")
    prompt = "\n".join(parts)

    # Generate reply
    try:
        # truncation=True clips overly long prompts to the model's max length
        completion = generator(prompt, truncation=True)[0]["generated_text"]
        # rsplit: keep the text after the *last* "AI:" marker (the new turn);
        # splitting on the first would grab earlier history turns
        reply = completion.rsplit("AI:", 1)[-1].strip()
        reply = re.sub(r"\s+", " ", reply)  # collapse excess whitespace
    except Exception as err:  # noqa: BLE001
        LOGGER.error(f"Inference error: {err}")
        reply = (
            "Sorry, something went wrong on my end. "
            "Please try again in a few seconds."
        )

    return reply

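# Alternative prompt builder (a sketch, not wired in above): Llama 3.2
# Instruct checkpoints ship a chat template, so the hand-rolled "User:"/"AI:"
# lines could instead be rendered with the model's own template. The helper
# name below is hypothetical; it reuses the `tokenizer` and `SYSTEM_MSG`
# defined earlier and the same tuple-style history.
def build_prompt_with_template(user_msg, history):
    messages = [{"role": "system", "content": SYSTEM_MSG}]
    for u, a in history:
        messages.append({"role": "user", "content": u})
        messages.append({"role": "assistant", "content": a})
    messages.append({"role": "user", "content": user_msg})
    # Leave the assistant header open so generation continues as the assistant
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
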
# ------------------------ UI -----------------------------------------------
gr.ChatInterface(
    chat,
    title="SchoolSpirit AI Chat",
    theme=gr.themes.Soft(primary_hue="blue"),  # light-blue look
).launch()
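# Quick local smoke test (hypothetical, not part of this commit): since
# .launch() blocks, temporarily comment it out and run
#     print(chat("What does SchoolSpirit AI do?", []))
# to exercise the callback once without starting the UI.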