# SchoolSpiritAI / app.py
# (Hugging Face Space source; commit 747a64c)
"""
SchoolSpiritΒ AI – Granite‑3.3‑2B chatbot Space
----------------------------------------------
β€’ Uses IBM Granite‑3.3‑2B‑Instruct (public, no access token).
β€’ Fits HF CPU Space (2‑B params, bfloat16).
β€’ Keeps last MAX_TURNS exchanges.
β€’ β€œClear chat” button resets context.
β€’ Robust error handling & logging.
"""
import re
import gradio as gr
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
pipeline,
)
from transformers.utils import logging as hf_logging
# ────────────────────────── Config ──────────────────────────────────────────
hf_logging.set_verbosity_error()  # silence transformers' info/progress output
LOG = hf_logging.get_logger("SchoolSpirit")

MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"  # public model; no token needed
MAX_TURNS = 6       # number of (user, ai) exchanges kept as context
MAX_TOKENS = 200    # generation budget per reply
MAX_INPUT_CH = 400  # reject user messages longer than this many characters

# System prompt prepended to every generation request.
# FIX: the original literal contained UTF-8 mojibake ("SchoolSpiritΒ AI",
# non-breaking hyphens rendered as "‑") which the model would see verbatim;
# re-encoded as clean ASCII text.
SYSTEM_MSG = (
    "You are SchoolSpirit AI, the upbeat digital mascot for a company that "
    "offers on-prem AI chat mascots, fine-tuning services, and turnkey GPU "
    "hardware for schools. Answer concisely and age-appropriately. If unsure, "
    "say so and suggest contacting a human. Do not ask for personal data."
)
# ────────────────────────── Model Load ──────────────────────────────────────
# Load tokenizer + model eagerly at import time.  If anything fails (network,
# OOM, missing weights), MODEL_ERR is set and the chat callback surfaces it
# to the user instead of the Space crashing on startup.
try:
    tok = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map="auto",
        torch_dtype="auto",  # bfloat16/float16 under the hood
    )
    # One shared text-generation pipeline; sampling parameters are fixed here.
    generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tok,
        max_new_tokens=MAX_TOKENS,  # caps reply length per turn
        do_sample=True,
        temperature=0.7,
    )
    MODEL_ERR = None  # sentinel: model loaded and is usable
except Exception as exc:  # noqa: BLE001 — report any load failure in-chat
    MODEL_ERR = f"Model load error: {exc}"
    generator = None
    LOG.error(MODEL_ERR)
# ────────────────────────── Helpers ────────────────────────────────────────
def truncate(hist):
    """Keep only the most recent MAX_TURNS (user, ai) pairs of *hist*."""
    if len(hist) <= MAX_TURNS:
        return hist
    return hist[-MAX_TURNS:]
def clean(text: str) -> str:
    """Normalise all whitespace runs to single spaces; never return ""."""
    collapsed = " ".join(text.split())
    return collapsed if collapsed else "…"
# ────────────────────────── Chat Callback ───────────────────────────────────
def chat(history, user_msg):
    """Gradio callback: append (user_msg, reply) to *history*.

    Returns ``(new_history, "")`` — the empty string clears the textbox.
    All failures (model never loaded, inference error) are reported as
    chat replies so the UI never raises.
    """
    history = list(history)  # Gradio supplies a list of (user, ai) tuples

    # Model failed to load at startup: surface the stored error in-chat.
    if MODEL_ERR:
        history.append((user_msg, MODEL_ERR))
        return history, ""

    user_msg = clean(user_msg or "")
    if not user_msg:
        history.append(("", "Please enter a message."))
        return history, ""
    if len(user_msg) > MAX_INPUT_CH:
        history.append((user_msg, "That message is too long."))
        return history, ""

    history = truncate(history)

    # Build a plain-text prompt: system message, then alternating turns.
    prompt_lines = [SYSTEM_MSG]
    for u, a in history:
        prompt_lines += [f"User: {u}", f"AI: {a}"]
    prompt_lines += [f"User: {user_msg}", "AI:"]
    prompt = "\n".join(prompt_lines)

    try:
        # FIX: `truncate=4096` is not a valid pipeline/generate kwarg and
        # caused every call to error; `truncation=True` is the supported
        # way to clip over-long prompts to the model's context window.
        completion = generator(prompt, truncation=True)[0]["generated_text"]
        # The pipeline echoes the prompt, so slice it off and keep only the
        # newly generated text.  (The original `split("AI:", 1)[-1]` split
        # on the FIRST "AI:" and replayed all prior turns once history was
        # non-empty.)  Also stop at a hallucinated "User:" continuation.
        reply = clean(completion[len(prompt):].split("User:", 1)[0])
    except Exception as err:  # noqa: BLE001 — degrade gracefully on inference errors
        LOG.error(f"Inference error: {err}")
        reply = "Sorry — I'm having trouble right now. Please try again shortly."

    history.append((user_msg, reply))
    return history, ""
# ────────────────────────── Clear Chat ──────────────────────────────────────
def clear_chat():
    """Reset the UI: empty chat history and empty textbox."""
    fresh_history = []
    fresh_textbox = ""
    return fresh_history, fresh_textbox
# ────────────────────────── UI Launch ───────────────────────────────────────
# ────────────────────────── UI Launch ───────────────────────────────────────
# FIX: the title and placeholder strings contained UTF-8 mojibake
# ("SchoolSpiritΒ AI") that rendered verbatim in the browser; cleaned up.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
    gr.Markdown("# SchoolSpirit AI Chat")
    chatbot = gr.Chatbot()
    msg_box = gr.Textbox(placeholder="Ask me anything about SchoolSpirit AI…")
    send_btn = gr.Button("Send")
    clear_btn = gr.Button("Clear Chat", variant="secondary")
    # Both the Send button and pressing Enter submit the message.
    send_btn.click(chat, [chatbot, msg_box], [chatbot, msg_box])
    msg_box.submit(chat, [chatbot, msg_box], [chatbot, msg_box])
    clear_btn.click(clear_chat, outputs=[chatbot, msg_box])

demo.launch()