import logging
import os

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

logger = logging.getLogger(__name__)

# ✅ Load the model and tokenizer
MODEL_ID = "pareshmishra/mt564-gemma-lora"
API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
if not API_TOKEN:
    raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable not set")

# Configure 4-bit quantization
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,  # Enable 4-bit quantization
    bnb_4bit_compute_dtype=torch.float16,  # Use fp16 for computation
    bnb_4bit_quant_type="nf4",  # Normal Float 4-bit quantization
    bnb_4bit_use_double_quant=True,  # Nested quantization for better efficiency
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=API_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    token=API_TOKEN,
    torch_dtype=torch.float16,  # fp16 as per model card
    device_map="auto",  # Auto-map to GPU/CPU
    quantization_config=quantization_config,  # Use BitsAndBytesConfig
)

# Speaker labels used by the plain-text prompt format; other roles are skipped.
_ROLE_LABELS = {"user": "User", "assistant": "Assistant"}


def _build_prompt(system_message, history, message):
    """Render the system message, prior turns and the new user message as one prompt."""
    parts = [f"{(system_message or '').strip()}\n\n"]
    for turn in history or []:
        if not isinstance(turn, dict):
            continue  # not a {"role", "content"} message dict
        label = _ROLE_LABELS.get(turn.get("role"))
        content = turn.get("content")
        if label is None or not isinstance(content, str):
            continue  # skip unknown roles and non-text content (e.g. file attachments)
        parts.append(f"{label}: {content.strip()}\n")
    if isinstance(message, dict):
        # A multimodal ChatInterface passes {"text": ..., "files": [...]}.
        message = message.get("text", "")
    parts.append(f"User: {(message or '').strip()}\n")
    parts.append("Assistant:")
    return "".join(parts)


def _extract_reply(text):
    """Trim a decoded completion down to the assistant's first reply."""
    # The plain "User:/Assistant:" format has no dedicated stop token, so the model
    # may keep generating and invent the next user turn; cut the reply there.
    reply, _, _ = text.partition("\nUser:")
    return reply.strip()


def respond(messages, chatbot_history, system_message, max_tokens, temperature, top_p):
    """Generate one assistant reply for ``gr.ChatInterface``.

    Args:
        messages: The new user message. Despite the plural name, gr.ChatInterface
            passes the current message (a string) as the first argument.
        chatbot_history: Prior turns as a list of {"role", "content"} dicts
            (``ChatInterface(type="messages")``).
        system_message: Instruction text placed at the top of the prompt.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Yields:
        The reply text, a warning string if the model returned nothing, or an
        error string if generation failed.
    """
    try:
        prompt = _build_prompt(system_message, chatbot_history, messages)

        # Tokenize and generate
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        outputs = model.generate(
            **inputs,
            max_new_tokens=int(max_tokens),  # slider values may arrive as floats
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

        # Decode only the newly generated tokens. Slicing the decoded text by
        # len(prompt) breaks whenever decode(encode(prompt)) != prompt
        # (special tokens, whitespace normalisation).
        prompt_len = inputs["input_ids"].shape[-1]
        text = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        response = _extract_reply(text)
        yield response if response else "⚠️ No response returned from the model."
    except Exception as e:
        # UI boundary: keep the traceback in the logs and show the error in the chat
        # rather than crashing the interface.
        logger.exception("Generation failed")
        yield f"❌ Error: {str(e)}\nDetails: {e.__class__.__name__}"


# Gradio Interface
demo = gr.ChatInterface(
    fn=respond,
    type="messages",
    additional_inputs=[
        gr.Textbox(
            lines=3,
            label="System message",
            value="You are an expert in SWIFT MT564 financial messaging. Analyze, validate, and answer related user questions.",
        ),
        gr.Slider(50, 2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p sampling"),
    ],
    title="💬 MT564 Chat Assistant",
    description="Analyze SWIFT MT564 messages or ask financial-related questions.",
    theme="default",
)

if __name__ == "__main__":
    demo.launch()