# File: gemma_test / app.py
# Uploaded by: broadfield-dev
# Last commit: "Update app.py" (0cfbef9, verified)
import gradio as gr
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
from PIL import Image
import requests
import torch
import io
import os
from huggingface_hub import login
import tempfile
from threading import Thread
# Log in to Hugging Face for gated model access.
# The token is read from the HF_TOKEN environment variable only. There is
# deliberately no placeholder fallback: a dummy string is not a usable
# credential, and encouraging edits to this line invites committing a real
# secret. When HF_TOKEN is unset the explicit login is skipped and model
# loading relies on whatever credentials are already configured.
_hf_token = os.getenv("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)

# Load model and processor once at import time; both are shared by every
# request handled by process_input().
model_id = "google/gemma-3-4b-it"
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id, device_map="auto", torch_dtype=torch.bfloat16
).eval()
processor = AutoProcessor.from_pretrained(model_id)
# Seconds to wait for a remote image before giving up, so a stalled server
# cannot block a request indefinitely.
_IMAGE_FETCH_TIMEOUT = 30


def _save_temp_image(image):
    """Save *image* to a new temporary ``.jpg`` file and return its path.

    The file is created with ``delete=False`` so it outlives this call; the
    caller owns it and must remove it. If saving fails, the partially
    written file is removed before the exception propagates.
    """
    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
        path = tmp.name
    try:
        image.save(path)
    except Exception:
        if os.path.exists(path):
            os.remove(path)
        raise
    return path


def process_input(chat_history, image_file, image_url, text_input):
    """Stream a model reply about an image into the chat history.

    This is a generator: every yielded value is a complete chat history
    (a list of ``(user_message, bot_message)`` tuples) to display.

    Args:
        chat_history: Existing list of ``(user, bot)`` tuples; ``None`` or
            empty is treated as no history. It is never mutated.
        image_file: Path of an uploaded image, or ``None``. Takes precedence
            over ``image_url``.
        image_url: URL to download the image from when no file is given.
        text_input: The user's prompt; must be non-empty.

    Yields:
        The history with an error entry appended (validation failure or any
        exception), or the history with the user prompt and the
        progressively growing model reply.
    """
    history = list(chat_history) if chat_history else []
    image_path = None
    try:
        # Validate text input
        if not text_input:
            yield history + [(None, "Error: Please enter a text prompt.")]
            return

        # Handle image input: an uploaded file wins over a URL.
        if image_file is not None:
            image = Image.open(image_file).convert("RGB")
        elif image_url:
            # NOTE(review): image_url is user-supplied and fetched
            # server-side; consider restricting allowed hosts/schemes.
            response = requests.get(image_url, timeout=_IMAGE_FETCH_TIMEOUT)
            response.raise_for_status()
            image = Image.open(io.BytesIO(response.content)).convert("RGB")
        else:
            yield history + [(text_input, "Error: Please provide an image file or a valid image URL.")]
            return
        image_path = _save_temp_image(image)

        # Prepare chat template (based on original code)
        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": "You are a helpful assistant capable of analyzing images."}],
            },
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image_path},
                    {"type": "text", "text": text_input},
                ],
            },
        ]

        # Process inputs with chat template
        inputs = processor.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=True,
            return_dict=True, return_tensors="pt"
        ).to(model.device, dtype=torch.bfloat16)

        # Verify the image actually made it into the model inputs.
        if "pixel_values" not in inputs:
            yield history + [(text_input, "Error: Image data not processed correctly (missing pixel_values).")]
            return

        # Initialize streamer
        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
            "pixel_values": inputs.get("pixel_values"),
            "max_new_tokens": 500,
            "do_sample": True,
        }

        # Update chat history with user message
        user_message = text_input
        bot_message = ""
        new_history = history + [(user_message, bot_message)]
        yield new_history

        generation_errors = []

        def _generate():
            """Run generation; on failure record the error and end the stream."""
            try:
                model.generate(**generation_kwargs, streamer=streamer)
            except Exception as exc:
                generation_errors.append(exc)
                # Without an explicit end signal the consumer loop below
                # would wait on the streamer's queue forever.
                streamer.end()

        # Run generation in a separate thread for streaming.
        # No torch.inference_mode() wrapper here: that context manager is
        # thread-local, so it never applied to the worker thread, and
        # holding it open across ``yield`` would leave it enabled in the
        # caller's thread while this generator is suspended.
        thread = Thread(target=_generate)
        thread.start()
        for new_text in streamer:
            bot_message += new_text
            new_history[-1] = (user_message, bot_message)
            yield new_history
        thread.join()

        # Surface a failure from the worker thread through the normal
        # error path below instead of silently showing a truncated reply.
        if generation_errors:
            raise generation_errors[0]
    except Exception as e:
        import traceback
        yield history + [(text_input, f"Error: {str(e)}\n{traceback.format_exc()}")]
    finally:
        # Clean up the temporary file on every exit path: success, early
        # return, exception, or the generator being closed mid-stream.
        if image_path and os.path.exists(image_path):
            os.remove(image_path)
# Create Gradio interface: image inputs on the left, chat and prompt box on
# the right.
# NOTE(review): the source had no indentation, so the container nesting
# below (in particular which column holds the prompt textbox) is a
# reconstruction — confirm against the intended layout.
with gr.Blocks() as demo:
    gr.Markdown("# Multimodal Streaming Chat with Gemma-3-4b-it")
    gr.Markdown("Upload an image or provide an image URL, then enter a text prompt (e.g., 'Describe this image in detail').")
    with gr.Row():
        with gr.Column():
            image_file_input = gr.Image(label="Upload Image (Drag-and-Drop)", type="filepath")
            image_url_input = gr.Textbox(label="Or Enter Image URL", placeholder="e.g., https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/bee.jpg")
        with gr.Column():
            chatbot = gr.Chatbot(label="Chat with Gemma", height=400)
            # submit_btn=True gives the textbox its own submit button, so no
            # separate gr.Button is wired up.
            text_input = gr.Textbox(
                label="Enter your prompt here",
                placeholder="e.g., Describe this image in detail",
                lines=2,
                submit_btn=True,
            )
    # process_input is a generator; each yielded history replaces the
    # chatbot's contents, which produces the streaming effect.
    text_input.submit(
        fn=process_input,
        inputs=[chatbot, image_file_input, image_url_input, text_input],
        outputs=chatbot,
    )

# Launch the app only when executed as a script, so importing this module
# (e.g. for testing) does not start a server.
if __name__ == "__main__":
    demo.launch()