import os
import re
import time
import datetime
import traceback
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from transformers.utils import logging as hf_logging
# ---------------------------------------------------------------------------
# 0. Paths & basic logging helper
# ---------------------------------------------------------------------------
os.environ["HF_HOME"] = "/data/.huggingface"
LOG_FILE = "/data/requests.log"
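# /data is only mounted on Spaces with persistent storage enabled; without
# it, the append in log() raises FileNotFoundError and file logging is
# silently skipped (stdout logging still works).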
def log(msg: str):
    """Print a timestamped line and append it to LOG_FILE (best effort)."""
    ts = datetime.datetime.utcnow().strftime("%H:%M:%S.%f")[:-3]
    line = f"[{ts}] {msg}"
    print(line, flush=True)
    try:
        with open(LOG_FILE, "a") as f:
            f.write(line + "\n")
    except FileNotFoundError:
        pass
# ---------------------------------------------------------------------------
# 1. Configuration constants
# ---------------------------------------------------------------------------
MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"  # 2B model fits Spaces
CONTEXT_TOKENS = 1800  # leave head-room for the reply inside the 2k window
MAX_NEW_TOKENS = 64
TEMPERATURE = 0.6
MAX_INPUT_CH = 300  # UI safeguard
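# Budget check: CONTEXT_TOKENS + MAX_NEW_TOKENS = 1800 + 64 = 1864, which
# leaves roughly 180 tokens of slack inside the 2,048-token window assumed
# above.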
SYSTEM_MSG = (
    "You are **SchoolSpirit AI**, the official digital mascot of "
    "SchoolSpirit AI LLC. Founded by Charles Norton in 2025, the company "
    "deploys on-prem AI chat mascots, fine-tunes language models, and ships "
    "turnkey GPU servers to K-12 schools.\n\n"
    "RULES:\n"
    "• Friendly, concise (≤4 sentences unless prompted).\n"
    "• No personal data collection; no medical/legal/financial advice.\n"
    "• If uncertain, admit it and suggest human follow-up.\n"
    "• Avoid profanity, politics, and mature themes."
)
WELCOME_MSG = "Welcome to SchoolSpirit AI! Do you have any questions?"
def strip(s: str) -> str:
    """Collapse runs of whitespace into single spaces and trim the ends."""
    return re.sub(r"\s+", " ", s.strip())
# ---------------------------------------------------------------------------
# 2. Load tokenizer + model (GPU FP-16 → CPU fallback)
# ---------------------------------------------------------------------------
hf_logging.set_verbosity_error()
try:
    log("Loading tokenizer …")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    if torch.cuda.is_available():
        log("GPU detected → FP-16")
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID, device_map="auto", torch_dtype=torch.float16
        )
    else:
        log("CPU fallback")
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            device_map="cpu",
            torch_dtype="auto",
            low_cpu_mem_usage=True,
        )
    generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=True,
        temperature=TEMPERATURE,
        return_full_text=False,  # only return the newly generated text
    )
    MODEL_ERR = None
    log("Model loaded ✔")
except Exception as exc:
    MODEL_ERR = f"Model load error: {exc}"
    generator = None
    log(MODEL_ERR)
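# Some checkpoints ship without a pad token, which triggers warnings during
# generation; reusing the EOS token is a common default. (Assumption: this is
# safe for the chosen model — drop it if the tokenizer already defines one.)
if MODEL_ERR is None and tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token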
# ---------------------------------------------------------------------------
# 3. Helper: build prompt under token budget
# ---------------------------------------------------------------------------
def build_prompt(raw_history: list[dict]) -> str:
    """
    raw_history: list of {'role': 'system'|'user'|'assistant', 'content': str}.
    Drops the oldest user/assistant pair until the rendered prompt fits
    within CONTEXT_TOKENS.
    """
    def render(msg):
        if msg["role"] == "system":
            return msg["content"]
        prefix = "User:" if msg["role"] == "user" else "AI:"
        return f"{prefix} {msg['content']}"
    # The system message is always included.
    system_msg = next(m for m in raw_history if m["role"] == "system")
    convo = [m for m in raw_history if m["role"] != "system"]
    # Iteratively trim the oldest exchanges until the prompt fits.
    while True:
        prompt_parts = [system_msg["content"]] + [render(m) for m in convo] + ["AI:"]
        token_len = len(
            tokenizer.encode("\n".join(prompt_parts), add_special_tokens=False)
        )
        if token_len <= CONTEXT_TOKENS or len(convo) <= 2:
            break
        convo = convo[2:]  # drop the oldest user+assistant pair
    return "\n".join(prompt_parts)
# ---------------------------------------------------------------------------
# 4. Chat callback
# ---------------------------------------------------------------------------
def chat_fn(user_msg: str, display_history: list, state: dict):
    """
    display_history : list[tuple[str, str]] rendered by the UI
    state["raw"]    : list[dict] used to build the prompt
    """
    user_msg = strip(user_msg or "")
    if not user_msg:
        return display_history, state
    if len(user_msg) > MAX_INPUT_CH:
        display_history.append((user_msg, f"Input exceeds {MAX_INPUT_CH} characters."))
        return display_history, state
    if MODEL_ERR:
        display_history.append((user_msg, MODEL_ERR))
        return display_history, state
    # --- Update raw history
    state["raw"].append({"role": "user", "content": user_msg})
    # --- Build prompt within the token budget
    prompt = build_prompt(state["raw"])
    # --- Generate
    try:
        start = time.time()
        result = generator(prompt)[0]
        reply = strip(result["generated_text"])
        log(f"Reply in {time.time() - start:.2f}s ({len(reply)} chars)")
    except Exception:
        log("❌ Inference error:\n" + traceback.format_exc())
        reply = "Apologies, an internal error occurred. Please try again."
    # --- Append the assistant reply to both histories
    display_history.append((user_msg, reply))
    state["raw"].append({"role": "assistant", "content": reply})
    return display_history, state
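# With a raw transcript prompt the model can keep writing extra turns; a
# cheap guard (assumption: "User:" never appears legitimately in a reply) is
# to cut the reply at the first hallucinated turn, e.g.:
#   reply = strip(reply.split("User:")[0])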
# ---------------------------------------------------------------------------
# 5. Launch Gradio Blocks UI
# ---------------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
    gr.Markdown("### SchoolSpirit AI Chat")
    chatbot = gr.Chatbot(
        value=[(None, WELCOME_MSG)],  # None hides the empty user bubble
        height=480,
        label="SchoolSpirit AI",
    )
    state = gr.State(
        {
            "raw": [
                {"role": "system", "content": SYSTEM_MSG},
                {"role": "assistant", "content": WELCOME_MSG},
            ]
        }
    )
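    # Two parallel histories: `chatbot` holds (user, bot) tuples for display,
    # while state["raw"] keeps role-tagged dicts for prompt construction.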
    with gr.Row():
        txt = gr.Textbox(
            placeholder="Type your question here…",
            show_label=False,
            scale=4,
            lines=1,
        )
        send_btn = gr.Button("Send", variant="primary")
    send_btn.click(chat_fn, inputs=[txt, chatbot, state], outputs=[chatbot, state]).then(
        lambda: "", None, txt  # clear the textbox after sending
    )
    txt.submit(chat_fn, inputs=[txt, chatbot, state], outputs=[chatbot, state]).then(
        lambda: "", None, txt
    )
demo.launch()
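# On Spaces this script (conventionally app.py) runs automatically and the UI
# is served on the assigned port; locally, run `python app.py` and open the
# printed URL.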