3nhance / _app.py
Tiago Caldeira
different approach using unsloth
9f37a6e
import gradio as gr
from transformers import AutoProcessor, Gemma3nForConditionalGeneration
import torch
import textwrap
from huggingface_hub import login
import os
# Authenticate with the Hugging Face Hub using a token supplied through the
# environment (e.g. a Space secret). huggingface_hub's login(token=None) falls
# back to an interactive prompt, which would block a headless server, so only
# call it when a token is actually configured.
hf_token = os.getenv("HUGGINGFACE_HUB_TOKEN") or os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)
# Load the processor (tokenizer + chat template) and the model, on CPU.
# NOTE(review): this checkpoint is bitsandbytes 4-bit quantized; loading it with
# device_map="cpu" relies on a bitsandbytes build with CPU support — confirm on
# the target hardware.
model_id = "unsloth/gemma-3n-E2B-it-unsloth-bnb-4bit"
processor = AutoProcessor.from_pretrained(model_id)
model = Gemma3nForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float32,
    device_map="cpu",
).eval()  # inference only: put layers such as dropout into eval mode
# Helper to format model output for display.
def print_response(text: str, width: int = 100) -> str:
    """Hard-wrap each line of *text* to at most *width* columns.

    Despite its name, this function does not print anything; it returns the
    reformatted string. The text is split on newlines and each line is wrapped
    independently, so existing line breaks (including blank lines) are kept.

    Args:
        text: Text to reformat.
        width: Maximum line width passed to ``textwrap.fill`` (default 100).

    Returns:
        The wrapped text, with lines joined by newlines.
    """
    return "\n".join(textwrap.fill(line, width) for line in text.split("\n"))
# Inference function for text-only input.
def predict_text(system_prompt: str, user_prompt: str, max_new_tokens: int = 500) -> str:
    """Generate a greedy (non-sampled) model reply for one text-only chat turn.

    Args:
        system_prompt: Optional system instruction; left out of the chat when
            blank.
        user_prompt: The user's message. If blank, no generation is run.
        max_new_tokens: Upper bound on the number of generated tokens
            (default 500).

    Returns:
        The decoded reply, hard-wrapped by ``print_response``, or a short
        notice when *user_prompt* is blank.
    """
    # Gradio textboxes normally send "", but tolerate None as well.
    system_text = (system_prompt or "").strip()
    user_text = (user_prompt or "").strip()
    if not user_text:
        # Don't spend a full (uncached, CPU) generation on an empty turn.
        return "Please enter a user prompt."

    messages = []
    if system_text:
        messages.append({"role": "system", "content": [{"type": "text", "text": system_text}]})
    messages.append({"role": "user", "content": [{"type": "text", "text": user_text}]})

    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)
    # Prompt length in tokens; used below to strip the prompt from the output.
    input_len = inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        output = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            # Original author's workaround for a CPU-specific generation bug.
            # Disabling the KV cache means past keys/values are recomputed on
            # every step, so long generations are noticeably slower.
            use_cache=False,
        )

    # Keep only the newly generated tokens of the first (only) sequence.
    generated = output[0][input_len:]
    decoded = processor.decode(generated, skip_special_tokens=True)
    return print_response(decoded)
# Gradio UI: a system-prompt box and a user-prompt box feed predict_text, and
# the model's reply is shown in a single output textbox.
_prompt_inputs = [
    gr.Textbox(lines=2, label="System Prompt", value="You are a helpful assistant."),
    gr.Textbox(lines=4, label="User Prompt", placeholder="Ask something..."),
]
_response_output = gr.Textbox(label="Gemma 3n Response")

demo = gr.Interface(
    fn=predict_text,
    inputs=_prompt_inputs,
    outputs=_response_output,
    title="Gemma 3n Text-Only Chat",
    description="Interact with the Gemma 3n language model using plain text. Image input not required.",
)

if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    demo.launch()