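"""Gradio chat app for the pareshmishra/mt564-gemma-lora model.

Loads the model with 4-bit quantization (bitsandbytes) and serves a
ChatInterface for analyzing SWIFT MT564 messages and answering related
financial questions.
"""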
import os
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import BitsAndBytesConfig
import torch

# ✅ Model ID and Hugging Face access token
MODEL_ID = "pareshmishra/mt564-gemma-lora"
API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
if not API_TOKEN:
    raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable not set")

# Configure 4-bit quantization
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,  # Enable 4-bit quantization
    bnb_4bit_compute_dtype=torch.float16,  # Use fp16 for computation
    bnb_4bit_quant_type="nf4",  # Normal Float 4-bit quantization
    bnb_4bit_use_double_quant=True  # Nested quantization for efficiency
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=API_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    token=API_TOKEN,
    torch_dtype=torch.float16,  # fp16 as per model card
    device_map="auto",  # Auto-map to GPU/CPU
    quantization_config=quantization_config  # Use BitsAndBytesConfig
)
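# Note (assumption about the runtime): 4-bit loading via bitsandbytes generally
# requires a CUDA GPU; on a CPU-only Space this load is likely to fail unless
# quantization is disabled.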

def respond(message, history, system_message, max_tokens, temperature, top_p):
    try:
        # Build the prompt from the system message, the prior turns, and the new user message.
        # With type="messages", `history` is a list of {"role", "content"} dicts and
        # `message` is the latest user input.
        prompt = f"{system_message.strip()}\n\n"
        for msg in history:
            if isinstance(msg, dict):
                role = msg.get("role")
                content = msg.get("content", "")
                if role == "user":
                    prompt += f"User: {content.strip()}\n"
                elif role == "assistant":
                    prompt += f"Assistant: {content.strip()}\n"
        prompt += f"User: {message.strip()}\nAssistant:"

        # Tokenize and generate
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        outputs = model.generate(
            **inputs,
            max_new_tokens=int(max_tokens),
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )

        # Decode only the newly generated tokens, not the echoed prompt
        generated_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
        response = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

        yield response if response else "⚠️ No response returned from the model."

    except Exception as e:
        yield f"❌ Error ({e.__class__.__name__}): {e}"

# Gradio Interface
demo = gr.ChatInterface(
    fn=respond,
    type="messages",
    additional_inputs=[
        gr.Textbox(
            lines=3,
            label="System message",
            value="You are an expert in SWIFT MT564 financial messaging. Analyze, validate, and answer related user questions.",
        ),
        gr.Slider(50, 2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p sampling"),
    ],
    title="💬 MT564 Chat Assistant",
    description="Analyze SWIFT MT564 messages or ask financial-related questions.",
    theme="default"
)

if __name__ == "__main__":
    demo.launch()