# Hugging Face Spaces page header (Space status: Sleeping)
import io
import os
import tempfile
import traceback
from threading import Thread

import gradio as gr
import requests
import torch
from huggingface_hub import login
from PIL import Image
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
# Log in to Hugging Face for gated model access.
# Only authenticate when a token is actually supplied through the HF_TOKEN
# environment variable: passing a placeholder string to `login` cannot succeed,
# and skipping the call lets any credentials already stored on the machine
# be used instead.
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)
# Load the model weights (bfloat16, automatic device placement) and the
# matching processor for the checkpoint named below.
model_id = "google/gemma-3-4b-it"
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
model.eval()  # evaluation mode; `eval()` returns the module itself
processor = AutoProcessor.from_pretrained(model_id)
# Upper bound (seconds) on fetching a user-supplied image URL, so an
# unresponsive host cannot block the request handler indefinitely.
_IMAGE_FETCH_TIMEOUT_S = 30


def _stage_image(image_file, image_url):
    """Write the uploaded file or downloaded URL as an RGB JPEG temp file.

    Args:
        image_file: Path of an uploaded image, or None.
        image_url: URL of an image to download; ignored when `image_file`
            is given.

    Returns:
        Path of the temporary ``.jpg`` file, or None when neither an upload
        nor a URL was provided. The caller owns the file and must delete it.

    Raises:
        requests.RequestException: The download failed or timed out.
        OSError: The data could not be read as an image or written to disk.
    """
    if image_file is not None:
        source = image_file
    elif image_url:
        response = requests.get(image_url, timeout=_IMAGE_FETCH_TIMEOUT_S)
        response.raise_for_status()
        source = io.BytesIO(response.content)
    else:
        return None

    with Image.open(source) as opened:
        rgb_image = opened.convert("RGB")

    tmp = tempfile.NamedTemporaryFile(suffix=".jpg", delete=False)
    try:
        with tmp:
            # Write through the open handle instead of re-opening by name.
            rgb_image.save(tmp, format="JPEG")
    except Exception:
        # Do not leave a half-written file behind.
        os.remove(tmp.name)
        raise
    return tmp.name


def process_input(chat_history, image_file, image_url, text_input):
    """Stream the model's answer to a text prompt about one image.

    This is a generator: each yielded value is the full chat history, a list
    of ``(user_message, bot_message)`` tuples, to display. Validation
    problems and runtime failures are reported as a single yielded history
    whose last entry carries an ``"Error: ..."`` bot message; exceptions are
    not propagated to the caller.

    Args:
        chat_history: Existing list of ``(user, bot)`` tuples; None is
            treated as an empty history. The list is never mutated.
        image_file: Path of an uploaded image, or None.
        image_url: Image URL, used only when `image_file` is None.
        text_input: The user's prompt; must be non-empty.

    Yields:
        The updated chat history after each newly generated piece of text.
    """
    history = list(chat_history or [])
    image_path = None
    try:
        # Validate text input
        if not text_input:
            yield history + [(None, "Error: Please enter a text prompt.")]
            return

        # Handle image input
        image_path = _stage_image(image_file, image_url)
        if image_path is None:
            yield history + [(text_input, "Error: Please provide an image file or a valid image URL.")]
            return

        # Prepare chat template
        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": "You are a helpful assistant capable of analyzing images."}],
            },
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image_path},
                    {"type": "text", "text": text_input},
                ],
            },
        ]

        # Process inputs with chat template
        inputs = processor.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=True,
            return_dict=True, return_tensors="pt"
        ).to(model.device, dtype=torch.bfloat16)

        if "pixel_values" not in inputs:
            yield history + [(text_input, "Error: Image data not processed correctly (missing pixel_values).")]
            return

        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
        # Forward everything the processor produced, not a hand-picked
        # subset, so no auxiliary model input is silently dropped.
        generation_kwargs = {
            **inputs,
            "max_new_tokens": 500,
            "do_sample": True,
            "streamer": streamer,
        }

        worker_errors = []

        def _generate():
            # inference_mode is thread-local, so it has to be entered in the
            # thread that actually runs `generate`.
            try:
                with torch.inference_mode():
                    model.generate(**generation_kwargs)
            except Exception as exc:
                worker_errors.append(exc)
                # Unblock the consumer loop below; without the end signal it
                # would wait on the streamer forever.
                streamer.end()

        # Show the user's message immediately with an empty reply.
        bot_message = ""
        new_history = history + [(text_input, bot_message)]
        yield new_history

        worker = Thread(target=_generate)
        worker.start()
        for new_text in streamer:
            bot_message += new_text
            new_history[-1] = (text_input, bot_message)
            yield new_history
        worker.join()

        if worker_errors:
            # Surface the worker failure through the handler below.
            raise worker_errors[0]
    except Exception as e:
        yield history + [(text_input, f"Error: {str(e)}\n{traceback.format_exc()}")]
    finally:
        # Clean up the temporary file on every exit path, including errors
        # and the consumer closing the generator early.
        if image_path and os.path.exists(image_path):
            os.remove(image_path)
# Build the Gradio UI: image inputs in the left column, chat in the right.
with gr.Blocks() as demo:
    gr.Markdown("# Multimodal Streaming Chat with Gemma-3-4b-it")
    gr.Markdown("Upload an image or provide an image URL, then enter a text prompt (e.g., 'Describe this image in detail').")

    with gr.Row():
        with gr.Column():
            image_upload = gr.Image(
                label="Upload Image (Drag-and-Drop)",
                type="filepath",
            )
            image_url_box = gr.Textbox(
                label="Or Enter Image URL",
                placeholder="e.g., https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/bee.jpg",
            )
        with gr.Column():
            chatbot = gr.Chatbot(label="Chat with Gemma", height=400)
            prompt_box = gr.Textbox(
                label="Enter your prompt here",
                placeholder="e.g., Describe this image in detail",
                lines=2,
                submit_btn=True,
            )

    # Submission is wired through the prompt textbox (which has its own
    # submit button enabled above); no separate gr.Button is used.
    handler_inputs = [chatbot, image_upload, image_url_box, prompt_box]
    prompt_box.submit(fn=process_input, inputs=handler_inputs, outputs=chatbot)

# Launch the app
demo.launch()