Update app.py
app.py
CHANGED
@@ -1,9 +1,10 @@
 """
-SchoolSpirit AI – …
-…
-• …
-• …
-• …
+SchoolSpirit AI – hardened chatbot Space with “Clear chat” button
+----------------------------------------------------------------
+• Loads Meta Llama-3 3B-Instruct via transformers.
+• Keeps last MAX_TURNS exchanges to fit context.
+• Adds a “Clear chat” button that resets history.
+• Extensive try/except blocks to prevent widget crashes.
 """
 
 import re
@@ -15,96 +16,111 @@ from transformers import (
 )
 from transformers.utils import logging as hf_logging
 
+# ────────────────────────── Config ──────────────────────────────────────────
 hf_logging.set_verbosity_error()
-…
+LOG = hf_logging.get_logger("SchoolSpirit")
 
-# ------------------------ Config -------------------------------------------
 MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"
-MAX_TURNS = 6
-MAX_TOKENS = 220
-MAX_INPUT_CH = 500
+MAX_TURNS = 6
+MAX_TOKENS = 220
+MAX_INPUT_CH = 500
 
 SYSTEM_MSG = (
     "You are SchoolSpirit AI, the friendly digital mascot for a company that "
     "offers on-prem AI chat mascots, fine-tuning services, and turnkey GPU "
     "hardware for schools. Keep answers concise, upbeat, and age-appropriate. "
-    "If unsure, …
-    "personal data."
+    "If unsure, say so and suggest contacting a human. Never request personal data."
 )
 
-# …
+# ────────────────────────── Model Load ──────────────────────────────────────
 try:
-    …
-    model …
-        MODEL_ID,
-        device_map="auto",
-        torch_dtype="auto",
+    tok = AutoTokenizer.from_pretrained(MODEL_ID)
+    model = AutoModelForCausalLM.from_pretrained(
+        MODEL_ID, device_map="auto", torch_dtype="auto"
     )
     generator = pipeline(
         "text-generation",
         model=model,
-        tokenizer=…
+        tokenizer=tok,
        max_new_tokens=MAX_TOKENS,
         do_sample=True,
         temperature=0.7,
     )
     MODEL_LOAD_ERR = None
 except Exception as exc:  # noqa: BLE001
-    generator = None
     MODEL_LOAD_ERR = f"Model load error: {exc}"
-…
+    generator = None
+    LOG.error(MODEL_LOAD_ERR)
+
+# ────────────────────────── Helpers ─────────────────────────────────────────
+def truncate_history(hist, max_turns):
+    """Return only the last `max_turns` (user, bot) pairs."""
+    return hist[-max_turns:] if len(hist) > max_turns else hist
+
+
+def safe_reply(msg: str) -> str:
+    """Post-process model output and ensure it is non-empty."""
+    msg = msg.strip()
+    msg = re.sub(r"\s+", " ", msg)
+    return msg or "…"
 
+# ────────────────────────── Chat Callback ───────────────────────────────────
 def chat(history, user_msg):
-    …
+    # Gradio guarantees history is a list of tuples (user, bot)
+    history = list(history)
 
-    # …
+    # Start-up failure fallback
     if MODEL_LOAD_ERR:
         history.append((user_msg, MODEL_LOAD_ERR))
         return history, ""
 
-    # Basic …
+    # Basic guards
     user_msg = (user_msg or "").strip()
     if not user_msg:
-        …
+        history.append(("", "Please enter a message."))
+        return history, ""
     if len(user_msg) > MAX_INPUT_CH:
         history.append(
-            (user_msg, "Sorry, …
+            (user_msg, "Sorry, that message is too long. Please shorten it.")
         )
         return history, ""
 
-    …
-    if len(history) > MAX_TURNS:
-        history = history[-MAX_TURNS:]
+    history = truncate_history(history, MAX_TURNS)
 
     # Build prompt
-    …
+    prompt_parts = [SYSTEM_MSG]
     for u, a in history:
-        …
-        …
-        prompt.…
-    prompt.append("AI:")
-    …
-    prompt = "\n".join(prompt)
 
-    # Generate
+        prompt_parts += [f"User: {u}", f"AI: {a}"]
+    prompt_parts += [f"User: {user_msg}", "AI:"]
+    prompt = "\n".join(prompt_parts)
 
+    # Generate
     try:
         completion = generator(prompt, truncate=2048)[0]["generated_text"]
-        reply = completion.split("AI:", 1)[-1]
-        reply = re.sub(r"\s+", " ", reply)  # collapse excess whitespace
+        reply = safe_reply(completion.split("AI:", 1)[-1])
     except Exception as err:  # noqa: BLE001
-        …
+        LOG.error(f"Inference error: {err}")
         reply = (
-            "Sorry …
-            "Please try again in a few seconds."
+            "Sorry, I'm having trouble right now. Please try again in a moment."
         )
 
     history.append((user_msg, reply))
     return history, ""
 
-# …
-…
-…
-…
-…
-…
+# ────────────────────────── Clear Chat fn ───────────────────────────────────
+def clear_chat():
+    return [], ""
+
+# ────────────────────────── UI Launch ───────────────────────────────────────
+with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
+    gr.Markdown("# SchoolSpirit AI Chat")
+    chatbot = gr.Chatbot()
+    msg_in = gr.Textbox(placeholder="Ask me anything about SchoolSpirit AI…")
+    send_btn = gr.Button("Send")
+    clear_btn = gr.Button("Clear Chat", variant="secondary")
+
+    send_btn.click(chat, [chatbot, msg_in], [chatbot, msg_in])
+    msg_in.submit(chat, [chatbot, msg_in], [chatbot, msg_in])
+    clear_btn.click(clear_chat, outputs=[chatbot, msg_in])
+
+demo.launch()
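Review note: `generator(prompt, truncate=2048)` is unchanged on both sides of this diff, but as far as I can tell `truncate` is a text-generation-inference server parameter, not a transformers TextGenerationPipeline kwarg. Unknown kwargs are forwarded to `generate()`, which rejects arguments the model does not use, so this call may raise on every turn and route each chat through the except branch. A hedged sketch of the safer spelling, assuming a recent transformers release where the pipeline accepts `truncation` as a preprocess parameter:

    # Sketch, not the committed code: let the tokenizer truncate the prompt
    # to the model's context window instead of passing `truncate=2048`.
    completion = generator(prompt, truncation=True)[0]["generated_text"]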
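Review note: the prompt is a hand-rolled "User:"/"AI:" transcript. Llama-3 Instruct checkpoints ship a chat template, and this commit already loads the tokenizer as `tok`, so the same history could be serialized with `apply_chat_template`. A minimal sketch under that assumption (not what the commit does; names `tok`, `SYSTEM_MSG`, `history`, `user_msg` are taken from app.py above):

    # Sketch: format the conversation with the model's own chat template.
    messages = [{"role": "system", "content": SYSTEM_MSG}]
    for u, a in history:
        messages += [
            {"role": "user", "content": u},
            {"role": "assistant", "content": a},
        ]
    messages.append({"role": "user", "content": user_msg})
    prompt = tok.apply_chat_template(
        messages,
        tokenize=False,              # return a string for the pipeline
        add_generation_prompt=True,  # end with the assistant header
    )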
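Because `chat()` returns `(history, "")` on every path, the guard logic can be smoke-tested without launching the UI. A hypothetical local check, assuming the model loaded (i.e. MODEL_LOAD_ERR is None):

    # Exercise the guard paths directly; none of this is in the commit.
    print(chat([], "")[0][-1])         # ('', 'Please enter a message.')
    print(chat([], "x" * 501)[0][-1])  # too-long guard reply
    history, box = chat([], "What does SchoolSpirit AI offer?")
    print(history[-1][1])              # model reply (or an error message)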