import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import uuid
import os
from datetime import datetime
import spaces
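# Model name, system prompt (read from a separate file), and chat-log location.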
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
with open("system_prompt.txt", "r") as f:
    SYSTEM_PROMPT = f.read()
LOG_DIR = "chat_logs"
os.makedirs(LOG_DIR, exist_ok=True)
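# Handles populated by load_model(). Note the session id is generated once per
# process, so every visitor to this instance writes to the same log file.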
model = None
tokenizer = None
session_id = str(uuid.uuid4())
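# Append a timestamped user/bot exchange to this session's log file.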
def log_chat(session_id, user_msg, bot_msg):
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    with open(os.path.join(LOG_DIR, f"{session_id}.txt"), "a") as f:
        f.write(f"[{timestamp}] User: {user_msg}\n")
        f.write(f"[{timestamp}] Bot: {bot_msg}\n\n")
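# Load the tokenizer and weights once at startup. Assuming this runs as a
# ZeroGPU Space (implied by the spaces import), the GPU is requested per call
# via @spaces.GPU on respond() below, so loading itself needs no decorator.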
def load_model():
    global model, tokenizer
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map="auto",
    )
    model.eval()
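# Rebuild the full conversation (system prompt, prior turns, new input) in the
# model's chat-template format. History arrives as (user, assistant) pairs,
# ChatInterface's default tuple format.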
def format_chat_prompt(history, new_input):
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    for user_msg, bot_msg in history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": bot_msg})
    messages.append({"role": "user", "content": new_input})
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
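# Generate a reply to the latest message, log the exchange, and return the
# text to the chat UI.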
@spaces.GPU
@torch.no_grad()
def respond(message, history):
    prompt = format_chat_prompt(history, message)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    output = model.generate(
        **inputs,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode only the newly generated tokens. Splitting the full decoded text
    # on the user's message was fragile (the prompt contains the message too)
    # and kept only the first line, silently truncating multi-line replies.
    new_tokens = output[0][inputs["input_ids"].shape[-1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    log_chat(session_id, message, response)
    return response
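# Load weights at import time so the first request does not pay the startup cost.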
load_model()
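# Minimal chat UI: ChatInterface calls respond(message, history) for each turn.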
gr.ChatInterface(
    fn=respond,
    title="BoundrAI",
    theme="soft",
).launch()