import textwrap

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "unsloth/gemma-3n-E2B-it-unsloth-bnb-4bit"

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Load the model in full precision on CPU (no bitsandbytes)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="cpu",           # Force CPU
    torch_dtype=torch.float32,  # Use FP32 to ensure CPU compatibility
)
model.eval()


# Helper to format the response nicely
def print_response(text: str) -> str:
    return "\n".join(textwrap.fill(line, 100) for line in text.split("\n"))


# Inference function for Gradio
def predict_text(system_prompt: str, user_prompt: str) -> str:
    messages = [
        {"role": "system", "content": [{"type": "text", "text": system_prompt.strip()}]},
        {"role": "user", "content": [{"type": "text", "text": user_prompt.strip()}]},
    ]

    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to("cpu")

    input_len = inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        output = model.generate(
            **inputs,
            max_new_tokens=300,
            do_sample=False,
            use_cache=False,  # Important for CPU compatibility
        )

    # Drop the prompt tokens and decode only the newly generated text
    generated = output[0][input_len:]
    decoded = tokenizer.decode(generated, skip_special_tokens=True)
    return print_response(decoded)


# Gradio UI
demo = gr.Interface(
    fn=predict_text,
    inputs=[
        gr.Textbox(lines=2, label="System Prompt", value="You are a helpful assistant."),
        gr.Textbox(lines=4, label="User Prompt", placeholder="Ask something..."),
    ],
    outputs=gr.Textbox(label="Gemma 3n Response"),
    title="Gemma 3n Chat (CPU-friendly)",
    description="Lightweight CPU-only chatbot using a quantized Gemma 3n model.",
)

if __name__ == "__main__":
    demo.launch()
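
# --- Optional: querying the running app programmatically ---
# A minimal sketch (not part of the original script) of how the launched app
# could be called from another Python process with gradio_client. It assumes
# the default local URL http://127.0.0.1:7860 and the auto-generated
# "/predict" endpoint that a single-function gr.Interface exposes.
# Requires `pip install gradio_client`.
#
#   from gradio_client import Client
#
#   client = Client("http://127.0.0.1:7860/")
#   answer = client.predict(
#       "You are a helpful assistant.",   # System Prompt
#       "Explain what FP32 means.",       # User Prompt
#       api_name="/predict",
#   )
#   print(answer)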