# BoundrAI / app.py
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import uuid
import os
from datetime import datetime
import spaces # required for ZeroGPU

# ----- Constants -----
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"

with open("system_prompt.txt", "r") as f:
    SYSTEM_PROMPT = f.read()

LOG_DIR = "chat_logs"
os.makedirs(LOG_DIR, exist_ok=True)

# Global vars to hold model and tokenizer
model = None
tokenizer = None
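# One session id per app process: every conversation served by this process
# appends to the same log file.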
session_id = str(uuid.uuid4())

# ----- Log Chat -----
def log_chat(session_id, user_msg, bot_msg):
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    with open(os.path.join(LOG_DIR, f"{session_id}.txt"), "a") as f:
        f.write(f"[{timestamp}] User: {user_msg}\n")
        f.write(f"[{timestamp}] Bot: {bot_msg}\n\n")

# ----- Model Loading -----
# Runs once at startup. On ZeroGPU the @spaces.GPU decorator belongs on the
# inference function below, which is the code that needs the GPU per request.
def load_model():
    global model, tokenizer
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map="auto",
    )
    model.eval()

# ----- Inference Function -----
def format_chat_prompt(history, new_input):
    # gr.ChatInterface passes history as (user, assistant) pairs here
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    for user_msg, bot_msg in history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": bot_msg})
    messages.append({"role": "user", "content": new_input})
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
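
# For reference, Zephyr's chat template renders each turn roughly as
# "<|user|>\n{msg}</s>\n" and, with add_generation_prompt=True, ends the
# string with "<|assistant|>\n" so the model continues as the assistant.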

@spaces.GPU  # required by ZeroGPU: a GPU is attached only while this function runs
@torch.no_grad()
def respond(message, history):
    prompt = format_chat_prompt(history, message)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    output = model.generate(
        **inputs,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
        pad_token_id=tokenizer.eos_token_id,  # Zephyr's tokenizer has no pad token
    )
    # Decode only the newly generated tokens; splitting the full decoded string
    # on the user message is brittle, and keeping only the first line would
    # truncate multi-line replies.
    generated = output[0][inputs["input_ids"].shape[-1]:]
    response = tokenizer.decode(generated, skip_special_tokens=True).strip()
    log_chat(session_id, message, response)
    return response
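
# Load the model once at startup, before the UI starts serving requests.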
load_model()
# ----- Gradio App -----
gr.ChatInterface(
    fn=respond,
    title="BoundrAI",
    theme="soft",
).launch()