mt564-chat-api / app.py
pareshmishra's picture
Update app.py
0d89abd verified
import os
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import BitsAndBytesConfig
import torch
# Load the fine-tuned model and its tokenizer once, at import time.
MODEL_ID = "pareshmishra/mt564-gemma-lora"

# Hub access token used to download MODEL_ID; abort startup if it is unset or empty.
API_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
if not API_TOKEN:
    raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable not set")

# bitsandbytes 4-bit loading: NF4 weight format, nested (double) quantization,
# fp16 as the compute dtype.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=API_TOKEN)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=quantization_config,
    torch_dtype=torch.float16,  # fp16 as per model card
    device_map="auto",  # auto-map weights across the available GPU/CPU
    token=API_TOKEN,
)
# Speaker labels used in the plain-text prompt; history entries with any
# other role are skipped.
_ROLE_LABELS = {"user": "User", "assistant": "Assistant"}


def _as_text(content):
    """Return *content* as stripped text, or "" when it carries no plain text.

    Strings are returned stripped; a dict with a "text" key yields that text;
    anything else (e.g. file/component payloads) yields "".
    """
    if isinstance(content, dict):
        content = content.get("text")
    return content.strip() if isinstance(content, str) else ""


def _build_prompt(message, history, system_message):
    """Render system message, prior turns and the new user message as one prompt.

    Keeps the plain "User: ...\\n" / "Assistant: ...\\n" layout and ends with an
    open "Assistant:" turn for the model to complete.
    """
    parts = [f"{(system_message or '').strip()}\n\n"]
    for turn in history or []:
        if not isinstance(turn, dict):
            continue  # type="messages" history entries are role/content dicts
        label = _ROLE_LABELS.get(turn.get("role"))
        text = _as_text(turn.get("content", ""))
        if label and text:
            parts.append(f"{label}: {text}\n")
    current = _as_text(message)
    if current:
        parts.append(f"User: {current}\n")
    parts.append("Assistant:")
    return "".join(parts)


def respond(messages, chatbot_history, system_message, max_tokens, temperature, top_p):
    """Generate one assistant reply for the Gradio chat UI.

    Gradio's ChatInterface calls this as fn(message, history, *additional_inputs),
    so despite its name ``messages`` is the *current* user message and
    ``chatbot_history`` holds the earlier turns as {"role", "content"} dicts.

    Args:
        messages: The new user message (str; a dict with "text" is also accepted).
        chatbot_history: Previous turns, oldest first.
        system_message: Instruction text placed at the top of the prompt.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Yields:
        Exactly one string: the reply, a warning if the model produced no
        text, or an error description if building/generation raised.
    """
    try:
        prompt = _build_prompt(messages, chatbot_history, system_message)
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        prompt_len = inputs["input_ids"].shape[-1]
        outputs = model.generate(
            **inputs,
            max_new_tokens=int(max_tokens),  # slider values may arrive as floats
            temperature=float(temperature),
            top_p=float(top_p),
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
        # generate() returns prompt + continuation; decode only the new token
        # ids rather than slicing the decoded text by len(prompt), which is
        # wrong whenever decode() does not reproduce the prompt exactly.
        response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        # The plain-text format lets the model invent further turns; keep
        # only its first reply.
        response = response.split("\nUser:", 1)[0].strip()
        yield response if response else "⚠️ No response returned from the model."
    except Exception as e:
        # UI boundary: report the failure in the chat instead of crashing.
        yield f"❌ Error: {str(e)}\nDetails: {e.__class__.__name__}"
# Gradio chat UI. ChatInterface calls respond(message, history, *additional_inputs);
# the additional inputs below map positionally onto respond's system_message,
# max_tokens, temperature and top_p parameters.
demo = gr.ChatInterface(
    fn=respond,
    type="messages",
    additional_inputs=[
        gr.Textbox(
            lines=3,
            label="System message",
            value="You are an expert in SWIFT MT564 financial messaging. Analyze, validate, and answer related user questions.",
        ),
        gr.Slider(50, 2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p sampling"),
    ],
    # Emoji restored: the previous literal held the UTF-8 bytes of this
    # character mis-decoded as single-byte text, which rendered as garbage.
    title="💬 MT564 Chat Assistant",
    description="Analyze SWIFT MT564 messages or ask financial-related questions.",
    theme="default",
)

if __name__ == "__main__":
    demo.launch()