import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

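# Unsloth's 4-bit (bitsandbytes) quantized, instruction-tuned Gemma 3n E2B checkpoint.
# Note: bnb-4bit weights normally rely on bitsandbytes support; loading this checkpoint
# on CPU may require a recent bitsandbytes/transformers stack.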
model_id = "unsloth/gemma-3n-E2B-it-unsloth-bnb-4bit"

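# Load the tokenizer that matches the checkpoint.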
tokenizer = AutoTokenizer.from_pretrained(model_id)

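# Load the model in float32 and keep it on CPU in eval (inference) mode.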
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float32,
)
model.to("cpu")
model.eval()

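# Generate one response for a plain-text prompt using greedy decoding.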
def generate_response(user_prompt):
    # Build the chat in the structured content format (a list of typed parts)
    # accepted by the Gemma chat template.
    messages = [
        {
            "role": "system",
            "content": [{"type": "text", "text": "You are a helpful assistant."}],
        },
        {
            "role": "user",
            "content": [{"type": "text", "text": user_prompt.strip()}],
        },
    ]

    # return_dict=True is required so the result can be indexed by key below;
    # without it, apply_chat_template returns a bare tensor of input IDs.
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to("cpu")

    input_len = inputs["input_ids"].shape[-1]

    # Greedy decoding (do_sample=False), capped at 100 new tokens.
    # use_cache=False lowers memory use at the cost of slower generation.
    with torch.inference_mode():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=100,
            do_sample=False,
            use_cache=False,
        )

    # Keep only the newly generated tokens, then decode them to text.
    generated_tokens = outputs[0][input_len:]
    decoded = tokenizer.decode(generated_tokens, skip_special_tokens=True)
    return decoded.strip()

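# Minimal Gradio UI: a text box in, the model's reply out.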
demo = gr.Interface(
    fn=generate_response,
    inputs=gr.Textbox(lines=3, label="Enter your question"),
    outputs=gr.Textbox(label="Gemma 3n Response"),
    title="🧪 Simple Gemma 3n Demo (CPU)",
    description="Test the Gemma 3n model with minimal output. Max 100 tokens.",
)

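# Launch the local web UI (Gradio serves on http://127.0.0.1:7860 by default).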
if __name__ == "__main__":
    demo.launch()