""" | |
SchoolSpiritΒ AI β Graniteβ3.3β2B chatbot Space | |
---------------------------------------------- | |
β’ Uses IBM Graniteβ3.3β2BβInstruct (public, no access token). | |
β’ Fits HF CPU Space (2βB params, bfloat16). | |
β’ Keeps last MAX_TURNS exchanges. | |
β’ βClear chatβ button resets context. | |
β’ Robust error handling & logging. | |
""" | |
import re

import gradio as gr
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    pipeline,
)
from transformers.utils import logging as hf_logging

# ────────────────────────── Config ──────────────────────────────────────────
hf_logging.set_verbosity_error()
LOG = hf_logging.get_logger("SchoolSpirit")

MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"
MAX_TURNS = 6
MAX_TOKENS = 200
MAX_INPUT_CH = 400
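# MAX_TURNS bounds how much history is replayed into each prompt; MAX_TOKENS
# caps reply length; MAX_INPUT_CH rejects overlong user messages up front.
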
SYSTEM_MSG = (
    "You are SchoolSpirit AI, the upbeat digital mascot for a company that "
    "offers on-prem AI chat mascots, fine-tuning services, and turnkey GPU "
    "hardware for schools. Answer concisely and age-appropriately. If unsure, "
    "say so and suggest contacting a human. Do not ask for personal data."
)

# ────────────────────────── Model Load ──────────────────────────────────────
try:
    tok = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map="auto",
        torch_dtype="auto",  # bfloat16/float16 under the hood
    )
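    # Sampling (do_sample=True, temperature=0.7) keeps replies varied but
    # on-topic; max_new_tokens caps each reply so CPU generation stays fast.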
    generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tok,
        max_new_tokens=MAX_TOKENS,
        do_sample=True,
        temperature=0.7,
    )
    MODEL_ERR = None
except Exception as exc:  # noqa: BLE001
    MODEL_ERR = f"Model load error: {exc}"
    generator = None
    LOG.error(MODEL_ERR)
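
# MODEL_ERR doubles as a sentinel: chat() checks it on every call and surfaces
# the load failure in the chat window instead of crashing the Space.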

# ────────────────────────── Helpers ─────────────────────────────────────────
def truncate(hist):
    """Return last MAX_TURNS (u, a) pairs."""
    return hist[-MAX_TURNS:] if len(hist) > MAX_TURNS else hist


def clean(text: str) -> str:
    """Collapse whitespace; never return an empty string."""
    out = re.sub(r"\s+", " ", text.strip())
    return out or "…"
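
# For example, clean("  hello\n  world ") -> "hello world", while clean("")
# and clean("   ") -> "…", so the chatbot never renders an empty bubble.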

# ────────────────────────── Chat Callback ───────────────────────────────────
def chat(history, user_msg):
    history = list(history)  # Gradio ensures list of tuples
    if MODEL_ERR:
        history.append((user_msg, MODEL_ERR))
        return history, ""
    user_msg = clean(user_msg or "")
    if not user_msg:
        history.append(("", "Please enter a message."))
        return history, ""
    if len(user_msg) > MAX_INPUT_CH:
        history.append((user_msg, "That message is too long."))
        return history, ""
    history = truncate(history)

    # Build prompt
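    # NOTE: instruct checkpoints such as Granite-3.3 usually ship a tokenizer
    # chat template; a sketch of the equivalent build using it (assuming the
    # template is present) would be:
    #     messages = [{"role": "system", "content": SYSTEM_MSG}]
    #     for u, a in history:
    #         messages += [{"role": "user", "content": u},
    #                      {"role": "assistant", "content": a}]
    #     messages.append({"role": "user", "content": user_msg})
    #     prompt = tok.apply_chat_template(
    #         messages, tokenize=False, add_generation_prompt=True
    #     )
    # The plain-text format below is kept for simplicity.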
    prompt_lines = [SYSTEM_MSG]
    for u, a in history:
        prompt_lines += [f"User: {u}", f"AI: {a}"]
    prompt_lines += [f"User: {user_msg}", "AI:"]
    prompt = "\n".join(prompt_lines)
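    # The assembled prompt is plain text, e.g.:
    #     <SYSTEM_MSG>
    #     User: earlier question
    #     AI: earlier answer
    #     User: latest message
    #     AI:
    # The model's continuation after the final "AI:" is the reply.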
    try:
        # return_full_text=False returns only the newly generated text, so the
        # prompt (with its earlier "AI:" turns) cannot leak into the reply;
        # prompt size is already bounded by MAX_TURNS and MAX_INPUT_CH.
        completion = generator(prompt, return_full_text=False)[0]["generated_text"]
        # Cut off anything after a hallucinated next "User:" turn.
        reply = clean(completion.split("User:", 1)[0])
    except Exception as err:  # noqa: BLE001
        LOG.error(f"Inference error: {err}")
        reply = "Sorry, I'm having trouble right now. Please try again shortly."
    history.append((user_msg, reply))
    return history, ""

# ────────────────────────── Clear Chat ──────────────────────────────────────
def clear_chat():
    return [], ""

# ────────────────────────── UI Launch ───────────────────────────────────────
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
    gr.Markdown("# SchoolSpirit AI Chat")
    chatbot = gr.Chatbot()
    msg_box = gr.Textbox(placeholder="Ask me anything about SchoolSpirit AI…")
    send_btn = gr.Button("Send")
    clear_btn = gr.Button("Clear Chat", variant="secondary")
    send_btn.click(chat, [chatbot, msg_box], [chatbot, msg_box])
    msg_box.submit(chat, [chatbot, msg_box], [chatbot, msg_box])
    clear_btn.click(clear_chat, outputs=[chatbot, msg_box])
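    # Send click and Enter in the textbox both route through chat(); each
    # handler returns the updated history plus "" to clear the textbox.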

demo.launch()