import os
import textwrap

import gradio as gr
import torch
from huggingface_hub import login
from transformers import AutoProcessor, Gemma3nForConditionalGeneration
|

# Authenticate with the Hugging Face Hub only when a token is provided;
# the Gemma 3n checkpoints are gated and require an accepted license.
hf_token = os.getenv("HUGGINGFACE_HUB_TOKEN")
if hf_token:
    login(token=hf_token)
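# The token is usually exported in the shell before launching the app, e.g.:
#   export HUGGINGFACE_HUB_TOKEN=<your-token>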
|

# Instruction-tuned Gemma 3n E2B checkpoint, loadable with plain transformers.
# The alternatives below are kept for reference but are incompatible with this
# CPU/float32 setup: the MLX build targets Apple Silicon, and the bnb-4bit
# quantization needs bitsandbytes on a GPU.
model_id = "google/gemma-3n-E2B-it"
# model_id = "google/gemma-3n-E2B"                          # base (non-instruct) weights
# model_id = "lmstudio-community/gemma-3n-E2B-it-MLX-4bit"  # MLX format (Apple Silicon)
# model_id = "unsloth/gemma-3n-E2B-it-unsloth-bnb-4bit"     # bitsandbytes 4-bit (GPU)
|

processor = AutoProcessor.from_pretrained(model_id)
model = Gemma3nForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float32,  # full precision is the safe default on CPU
    device_map="cpu",
).eval()
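
# A GPU variant (a sketch, not exercised here): with a CUDA device and enough
# VRAM, bfloat16 plus device_map="auto" is markedly faster than CPU float32.
#
#   model = Gemma3nForConditionalGeneration.from_pretrained(
#       model_id, torch_dtype=torch.bfloat16, device_map="auto"
#   ).eval()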
|


def format_response(text: str) -> str:
    """Re-wrap each line of the model output to at most 100 columns."""
    return "\n".join(textwrap.fill(line, 100) for line in text.split("\n"))
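# For example, format_response("lorem " * 40) re-wraps the 240-character string
# into lines of at most 100 characters; embedded newlines are preserved because
# each line is wrapped independently.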
|


def predict_text(system_prompt: str, user_prompt: str) -> str:
    messages = [
        {"role": "system", "content": [{"type": "text", "text": system_prompt.strip()}]},
        {"role": "user", "content": [{"type": "text", "text": user_prompt.strip()}]},
    ]

    # Render the chat template, tokenize, and move the tensors to the model's device.
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)

    input_len = inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        output = model.generate(
            **inputs,
            max_new_tokens=500,
            do_sample=False,    # greedy decoding for reproducible answers
            use_cache=False,    # trades generation speed for lower memory use
        )

    # Drop the prompt tokens so only the newly generated text is decoded.
    gen = output[0][input_len:]
    decoded = processor.decode(gen, skip_special_tokens=True)
    return format_response(decoded)
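
# A quick smoke test of the pipeline, assuming the weights are already cached
# locally (kept commented out so the module stays import-safe):
#
#   print(predict_text("You are a helpful assistant.", "Name three primary colors."))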
|


demo = gr.Interface(
    fn=predict_text,
    inputs=[
        gr.Textbox(lines=2, label="System Prompt", value="You are a helpful assistant."),
        gr.Textbox(lines=4, label="User Prompt", placeholder="Ask something..."),
    ],
    outputs=gr.Textbox(label="Gemma 3n Response"),
    title="Gemma 3n Text-Only Chat",
    description="Chat with the Gemma 3n language model using plain text; no image input is needed.",
)

if __name__ == "__main__":
    demo.launch()
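
# launch() accepts options for wider access, e.g. demo.launch(share=True) for a
# temporary public URL, or demo.launch(server_name="0.0.0.0") to listen on all
# network interfaces.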