import gradio as gr from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer from PIL import Image import requests import torch import io import os from huggingface_hub import login import tempfile from threading import Thread # Log in to Hugging Face for gated model access login(token=os.getenv("HF_TOKEN", "your_hf_token_here")) # Replace with your HF token or set HF_TOKEN env variable # Load model and processor model_id = "google/gemma-3-4b-it" model = Gemma3ForConditionalGeneration.from_pretrained( model_id, device_map="auto", torch_dtype=torch.bfloat16 ).eval() processor = AutoProcessor.from_pretrained(model_id) def process_input(chat_history, image_file, image_url, text_input): try: # Validate text input if not text_input: yield chat_history + [(None, "Error: Please enter a text prompt.")] return # Handle image input image_path = None if image_file is not None: image = Image.open(image_file).convert("RGB") with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp: image.save(tmp.name) image_path = tmp.name elif image_url: response = requests.get(image_url) response.raise_for_status() image = Image.open(io.BytesIO(response.content)).convert("RGB") with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp: image.save(tmp.name) image_path = tmp.name else: yield chat_history + [(text_input, "Error: Please provide an image file or a valid image URL.")] return # Prepare chat template (based on original code) messages = [ { "role": "system", "content": [{"type": "text", "text": "You are a helpful assistant capable of analyzing images."}] }, { "role": "user", "content": [ {"type": "image", "image": image_path}, {"type": "text", "text": text_input} ] } ] # Process inputs with chat template inputs = processor.apply_chat_template( messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" ).to(model.device, dtype=torch.bfloat16) # Debug: Verify input tensors if "pixel_values" not in inputs: yield chat_history + [(text_input, "Error: Image data not processed correctly (missing pixel_values).")] return # print(f"Input keys: {inputs.keys()}, pixel_values shape: {inputs['pixel_values'].shape}") # Uncomment for debugging # Initialize streamer streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True) generation_kwargs = { "input_ids": inputs["input_ids"], "attention_mask": inputs["attention_mask"], "pixel_values": inputs.get("pixel_values"), "max_new_tokens": 500, "do_sample": True, } # Update chat history with user message user_message = text_input bot_message = "" new_history = chat_history + [(user_message, bot_message)] yield new_history # Run generation in a separate thread for streaming with torch.inference_mode(): thread = Thread(target=model.generate, kwargs={**generation_kwargs, "streamer": streamer}) thread.start() for new_text in streamer: bot_message += new_text new_history[-1] = (user_message, bot_message) yield new_history thread.join() # Clean up temporary file if image_path and os.path.exists(image_path): os.remove(image_path) except Exception as e: import traceback yield chat_history + [(text_input, f"Error: {str(e)}\n{traceback.format_exc()}")] # Create Gradio interface with gr.Blocks() as demo: gr.Markdown("# Multimodal Streaming Chat with Gemma-3-4b-it") gr.Markdown("Upload an image or provide an image URL, then enter a text prompt (e.g., 'Describe this image in detail').") with gr.Row(): with gr.Column(): image_file_input = gr.Image(label="Upload Image (Drag-and-Drop)", type="filepath") image_url_input = gr.Textbox(label="Or Enter Image URL", placeholder="e.g., https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/bee.jpg") with gr.Column(): chatbot = gr.Chatbot(label="Chat with Gemma", height=400) text_input = gr.Textbox( label="Enter your prompt here", placeholder="e.g., Describe this image in detail", lines=2, submit_btn=True ) #submit_button = gr.Button("Submit") text_input.submit( fn=process_input, inputs=[chatbot, image_file_input, image_url_input, text_input], outputs=chatbot ) '''submit_button.click( fn=process_input, inputs=[chatbot, image_file_input, image_url_input, text_input], outputs=chatbot )''' # Launch the app demo.launch()