# Spaces status: Runtime error
import os | |
import gradio as gr | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
from transformers import BitsAndBytesConfig | |
import torch | |
# --- Model and tokenizer setup (runs once at import time) -------------------

# Hub repository to load, and the access token read from the environment.
MODEL_ID = "pareshmishra/mt564-gemma-lora"
API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")

# Fail fast with a clear message when the token is missing or empty.
if not API_TOKEN:
    raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable not set")

# 4-bit quantization settings: NF4 quant type, nested (double) quantization,
# and float16 as the compute dtype.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

# The tokenizer is loaded first, then the quantized causal-LM weights.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=API_TOKEN)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    token=API_TOKEN,
    quantization_config=quantization_config,
    torch_dtype=torch.float16,
    device_map="auto",  # let the library decide device placement
)
# Role names accepted in chat turns, mapped to the label used in the prompt.
_ROLE_LABELS = {"user": "User", "assistant": "Assistant"}


def _render_turns(turns):
    """Format chat turns as ``User: ...`` / ``Assistant: ...`` prompt lines.

    Entries that are not dicts, or whose ``role`` is neither ``"user"`` nor
    ``"assistant"``, are skipped. ``None`` content becomes an empty string and
    other non-string content is coerced with ``str()``.
    """
    lines = []
    for turn in turns or ():
        if not isinstance(turn, dict):
            continue
        label = _ROLE_LABELS.get(turn.get("role"))
        if label is None:
            continue
        content = turn.get("content")
        if content is None:
            text = ""
        elif isinstance(content, str):
            text = content
        else:
            text = str(content)
        lines.append(f"{label}: {text.strip()}\n")
    return "".join(lines)


def respond(messages, chatbot_history, system_message, max_tokens, temperature, top_p):
    """Generate one assistant reply for the chat UI.

    This is a generator that yields exactly one string: the model's reply, a
    fallback notice when the reply is empty, or an error description when
    anything raises.

    Args:
        messages: The newest user message. ``gr.ChatInterface`` passes this as
            a plain string. For backward compatibility a list of
            ``{"role", "content"}`` dicts is also accepted, in which case it
            is treated as the whole conversation and ``chatbot_history`` is
            ignored (the previous behaviour).
        chatbot_history: Earlier turns as ``{"role", "content"}`` dicts, used
            only when ``messages`` is a string.
        system_message: Instruction text placed at the top of the prompt.
        max_tokens: Upper bound on newly generated tokens (cast to ``int``).
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_p: Nucleus-sampling threshold forwarded to ``model.generate``.

    Yields:
        str: The reply text, the "no response" notice, or an error message.
    """
    try:
        # Build the prompt: system message, prior turns, newest user turn.
        prompt = f"{system_message.strip()}\n\n"
        if isinstance(messages, str):
            # Chat-UI call convention: history and the new message arrive
            # separately, so both must be written into the prompt.
            prompt += _render_turns(chatbot_history)
            prompt += _render_turns([{"role": "user", "content": messages}])
        else:
            # Legacy convention: ``messages`` already is the conversation.
            prompt += _render_turns(messages)
        prompt += "Assistant:"

        # Tokenize and generate.
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        outputs = model.generate(
            **inputs,
            max_new_tokens=int(max_tokens),  # sliders may deliver a float
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

        # Drop the prompt by token count rather than by character count, so
        # the result does not depend on decode() reproducing the prompt text
        # exactly.
        prompt_length = inputs["input_ids"].shape[-1]
        response = tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        ).strip()
        yield response if response else "β οΈ No response returned from the model."
    except Exception as e:
        # Top-level UI boundary: report the failure in the chat window.
        yield f"β Error: {str(e)}\nDetails: {e.__class__.__name__}"
# Chat UI definition: `respond` is the handler, and the extra controls below
# are passed to it after the message and history, in the order listed.
demo = gr.ChatInterface(
    fn=respond,
    type="messages",
    title="π¬ MT564 Chat Assistant",
    description="Analyze SWIFT MT564 messages or ask financial-related questions.",
    theme="default",
    additional_inputs=[
        # System prompt, editable by the user.
        gr.Textbox(
            label="System message",
            value="You are an expert in SWIFT MT564 financial messaging. Analyze, validate, and answer related user questions.",
            lines=3,
        ),
        # Generation length limit.
        gr.Slider(minimum=50, maximum=2048, value=512, step=1, label="Max new tokens"),
        # Sampling temperature.
        gr.Slider(minimum=0.1, maximum=1.5, value=0.7, step=0.1, label="Temperature"),
        # Nucleus sampling threshold.
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p sampling"),
    ],
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()