# SchoolSpiritAI / app.py
# phanerozoic's picture
# Update app.py
# 2e445c2 verified
# raw
# history blame
# 6.78 kB
import os
import re
import time
import datetime
import traceback
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from transformers.utils import logging as hf_logging
# ---------------------------------------------------------------------------
# 0. Paths & basic logging helper
# ---------------------------------------------------------------------------
# Both the Hugging Face home and the request log live under /data
# (presumably the Space's persistent volume — confirm in the Space settings).
# NOTE(review): huggingface_hub reads HF_HOME once, when transformers is
# imported above, so this assignment most likely does not move this process's
# download cache; set HF_HOME in the Space environment or pass cache_dir.
_DATA_DIR = "/data"
os.environ.update(HF_HOME=f"{_DATA_DIR}/.huggingface")
LOG_FILE = f"{_DATA_DIR}/requests.log"
def log(msg: str) -> None:
    """Print *msg* with a UTC time-of-day stamp and append it to LOG_FILE.

    Writing the file is best-effort: any OSError (missing directory, no
    permission, path is a directory, ...) is ignored so logging can never
    break request handling. The file is written as UTF-8 because several
    messages in this app contain non-ASCII symbols.
    """
    ts = datetime.datetime.now(datetime.timezone.utc).strftime("%H:%M:%S.%f")[:-3]
    line = f"[{ts}] {msg}"
    print(line, flush=True)
    try:
        with open(LOG_FILE, "a", encoding="utf-8") as f:
            f.write(line + "\n")
    except OSError:
        # Deliberately best-effort: stdout already received the line.
        pass
# ---------------------------------------------------------------------------
# 1. Configuration constants
# ---------------------------------------------------------------------------
MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"  # 2 B model fits Spaces
CONTEXT_TOKENS = 1800  # max prompt tokens kept by build_prompt; the reply's MAX_NEW_TOKENS come on top
MAX_NEW_TOKENS = 64  # upper bound on generated tokens per reply
TEMPERATURE = 0.6  # sampling temperature for the generation pipeline
MAX_INPUT_CH = 300  # UI safeguard: longer user messages are rejected
# System prompt placed at the head of every generated prompt.
SYSTEM_MSG = (
    "You are **SchoolSpirit AI**, the official digital mascot of "
    "SchoolSpirit AI LLC. Founded by Charles Norton in 2025, the company "
    "deploys on‑prem AI chat mascots, fine‑tunes language models, and ships "
    "turnkey GPU servers to K‑12 schools.\n\n"
    "RULES:\n"
    "• Friendly, concise (≤4 sentences unless prompted).\n"
    "• No personal data collection; no medical/legal/financial advice.\n"
    "• If uncertain, admit it & suggest human follow‑up.\n"
    "• avoid profanity, politics, mature themes."
)
WELCOME_MSG = "Welcome to SchoolSpirit AI! Do you have any questions?"


def strip(s: str) -> str:
    """Trim *s* and collapse every internal whitespace run to one space."""
    return re.sub(r"\s+", " ", s.strip())
# ---------------------------------------------------------------------------
# 2. Load tokenizer + model (GPU FP‑16 → CPU)
# ---------------------------------------------------------------------------
hf_logging.set_verbosity_error()
# huggingface_hub resolves HF_HOME once, at import time, and transformers is
# imported before HF_HOME is assigned in this file. Passing cache_dir
# explicitly makes downloads land under HF_HOME as intended.
_hf_home = os.environ.get("HF_HOME")
_CACHE_DIR = os.path.join(_hf_home, "hub") if _hf_home else None
try:
    log("Loading tokenizer …")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, cache_dir=_CACHE_DIR)
    if torch.cuda.is_available():
        log("GPU detected → FP‑16")
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            cache_dir=_CACHE_DIR,
            device_map="auto",
            torch_dtype=torch.float16,
        )
    else:
        log("CPU fallback")
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            cache_dir=_CACHE_DIR,
            device_map="cpu",
            torch_dtype="auto",
            low_cpu_mem_usage=True,
        )
    generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=True,
        temperature=TEMPERATURE,
        return_full_text=False,  # ← only return the newly generated text
    )
    MODEL_ERR = None
    log("Model loaded ✔")
except Exception as exc:
    # Top-level boundary: keep the UI alive and show the error in the chat
    # (chat_fn checks MODEL_ERR), but keep the full traceback in the log.
    MODEL_ERR = f"Model load error: {exc}"
    generator = None
    log(f"{MODEL_ERR}\n{traceback.format_exc()}")
# ---------------------------------------------------------------------------
# 3. Helper: build prompt under token budget
# ---------------------------------------------------------------------------
def build_prompt(raw_history: list[dict]) -> str:
    """Render the chat history as a plain-text prompt within the token budget.

    raw_history: list [{'role':'system'|'user'|'assistant', 'content': str}, ...]

    The first system message (if any) is always kept as the header; the other
    messages become "User: ..." / "AI: ..." lines and the prompt ends with
    "AI:" so the model continues as the assistant. While the prompt exceeds
    CONTEXT_TOKENS tokens, the oldest messages are dropped (never the newest
    one), and an assistant reply whose user turn was dropped is removed too,
    so the kept conversation starts with a user turn.
    """

    def render(msg: dict) -> str:
        prefix = "User:" if msg["role"] == "user" else "AI:"
        return f"{prefix} {msg['content']}"

    system_msg = next((m for m in raw_history if m["role"] == "system"), None)
    header = [system_msg["content"]] if system_msg is not None else []
    convo = [m for m in raw_history if m["role"] != "system"]

    def assemble(msgs: list[dict]) -> str:
        return "\n".join(header + [render(m) for m in msgs] + ["AI:"])

    def n_tokens(text: str) -> int:
        return len(tokenizer.encode(text, add_special_tokens=False))

    prompt = assemble(convo)
    while len(convo) > 1 and n_tokens(prompt) > CONTEXT_TOKENS:
        convo = convo[1:]
        # History begins with the assistant's welcome message, so removing
        # fixed-size pairs would leave orphaned replies; drop them explicitly.
        while len(convo) > 1 and convo[0]["role"] != "user":
            convo = convo[1:]
        prompt = assemble(convo)
    return prompt
# ---------------------------------------------------------------------------
# 4. Chat callback
# ---------------------------------------------------------------------------
def chat_fn(user_msg: str, display_history: list, state: dict):
    """Handle one submitted message and append the assistant's reply.

    display_history : list[tuple[str,str]] for UI
    state["raw"]    : list[dict] for prompting

    Returns the (display_history, state) pair. Blank input is ignored; input
    longer than MAX_INPUT_CH, or a failed model load, is answered with an
    error message in the chat without calling the model.
    """
    user_msg = strip(user_msg or "")
    if not user_msg:
        return display_history, state
    if len(user_msg) > MAX_INPUT_CH:
        display_history.append((user_msg, f"Input >{MAX_INPUT_CH} chars."))
        return display_history, state
    if MODEL_ERR:
        display_history.append((user_msg, MODEL_ERR))
        return display_history, state

    # --- Update raw history
    state["raw"].append({"role": "user", "content": user_msg})
    # --- Build prompt within token budget
    prompt = build_prompt(state["raw"])

    # --- Generate (only the model call is guarded)
    start = time.time()
    try:
        generated = generator(prompt)[0]["generated_text"]
    except Exception:
        log("❌ Inference error:\n" + traceback.format_exc())
        reply = "Apologies—an internal error occurred. Please try again."
    else:
        # The prompt is newline-separated "User:"/"AI:" turns, so a sampled
        # continuation may invent further turns; keep only the text before
        # the first new turn marker.
        reply = strip(re.split(r"\n\s*(?:User|AI):", generated, maxsplit=1)[0])
        log(f"Reply in {time.time() - start:.2f}s ({len(reply)} chars)")

    # --- Append assistant reply to both histories
    display_history.append((user_msg, reply))
    state["raw"].append({"role": "assistant", "content": reply})
    return display_history, state
# ---------------------------------------------------------------------------
# 5. Launch Gradio Blocks UI
# ---------------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
    gr.Markdown("### SchoolSpirit AI Chat")
    chatbot = gr.Chatbot(
        # None as the user half shows the greeting without an empty user
        # bubble (tuple-style Chatbot history).
        value=[(None, WELCOME_MSG)],
        height=480,
        label="SchoolSpirit AI",
    )
    # Prompt history for chat_fn, seeded with the greeting the Chatbot shows.
    state = gr.State(
        {
            "raw": [
                {"role": "system", "content": SYSTEM_MSG},
                {"role": "assistant", "content": WELCOME_MSG},
            ]
        }
    )
    with gr.Row():
        txt = gr.Textbox(
            placeholder="Type your question here…",
            show_label=False,
            scale=4,
            lines=1,
        )
        send_btn = gr.Button("Send", variant="primary")
    # Button click and Enter both run chat_fn; afterwards the textbox is
    # cleared so the answered question is not left in place to be re-sent.
    for trigger in (send_btn.click, txt.submit):
        trigger(
            chat_fn, inputs=[txt, chatbot, state], outputs=[chatbot, state]
        ).then(lambda: "", inputs=None, outputs=txt)
demo.launch()