Spaces:
Sleeping
Sleeping
File size: 5,350 Bytes
908c66a 22a9ea9 d9f9a8f 908c66a 5122672 22a9ea9 d9f9a8f 1334447 fa3d596 6a94503 908c66a 1334447 d9f9a8f 1334447 33c3ee8 07e9630 908c66a fa3d596 07e9630 6a94503 1a1967b 22a9ea9 5122672 b5bce8a 5122672 b5bce8a 6a94503 b5bce8a 5122672 b5bce8a 6a94503 908c66a 5122672 1334447 fa3d596 22a9ea9 fa3d596 0cfbef9 22a9ea9 a3c4da8 22a9ea9 6a94503 e3605f4 a3c4da8 22a9ea9 1334447 22a9ea9 6a94503 22a9ea9 5122672 908c66a fa3d596 d9f9a8f 22a9ea9 908c66a fa3d596 b5bce8a 908c66a b5bce8a fa3d596 07e9630 1334447 253c5c2 6a94503 253c5c2 b5bce8a 0cfbef9 b5bce8a 253c5c2 0cfbef9 908c66a 07e9630 2b1972b 0cfbef9 d9f9a8f 22a9ea9 908c66a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
import gradio as gr
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
from PIL import Image
import requests
import torch
import io
import os
from huggingface_hub import login
import tempfile
from threading import Thread
# Authenticate with Hugging Face for gated model access.
# The token is read only from the HF_TOKEN environment variable: passing a
# hard-coded placeholder string to login() can never authenticate, and editing
# a real token into this file would leak it. When HF_TOKEN is unset the login
# call is skipped and from_pretrained() below falls back to whatever
# credentials are already available to huggingface_hub.
_hf_token = os.getenv("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)

# Load the model and its processor once at import time; process_input() uses
# these module-level `model` and `processor` objects for every request.
model_id = "google/gemma-3-4b-it"
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id, device_map="auto", torch_dtype=torch.bfloat16
).eval()
processor = AutoProcessor.from_pretrained(model_id)
# Seconds to wait when downloading a user-supplied image URL, so a stalled
# server cannot block the request handler indefinitely.
_IMAGE_FETCH_TIMEOUT = 30


def _save_temp_jpeg(image):
    """Write *image* to a new temporary ``.jpg`` file and return its path."""
    tmp = tempfile.NamedTemporaryFile(suffix=".jpg", delete=False)
    try:
        with tmp:
            # Write through the already-open handle instead of reopening the
            # file by name, which fails on platforms that lock open files.
            image.save(tmp, format="JPEG")
    except Exception:
        # delete=False means nothing else will remove a half-written file.
        os.remove(tmp.name)
        raise
    return tmp.name


def _load_image_to_temp_file(image_file, image_url):
    """Return the path of a temporary RGB JPEG copy of the input image.

    The uploaded file takes precedence over the URL. Returns ``None`` when
    neither an uploaded file nor a non-blank URL was supplied.
    """
    if image_file is not None:
        with Image.open(image_file) as source:
            return _save_temp_jpeg(source.convert("RGB"))

    image_url = (image_url or "").strip()
    if image_url:
        # NOTE(review): the URL comes straight from the user and is fetched
        # server-side; restrict allowed hosts/schemes if this app can reach
        # internal services.
        response = requests.get(image_url, timeout=_IMAGE_FETCH_TIMEOUT)
        response.raise_for_status()
        with Image.open(io.BytesIO(response.content)) as source:
            return _save_temp_jpeg(source.convert("RGB"))

    return None


def process_input(chat_history, image_file, image_url, text_input):
    """Stream the model's answer about an image into the chat history.

    This is a generator: each yielded value is the complete chat history
    (a list of ``(user_message, bot_message)`` tuples) to display.

    Args:
        chat_history: Existing chat history; ``None`` is treated as empty.
        image_file: Path of an uploaded image, or ``None``.
        image_url: URL of an image to download; used only when
            ``image_file`` is ``None``.
        text_input: The user's text prompt.

    Yields:
        The history extended with the new exchange. While tokens are being
        generated the last entry grows with each yield. Validation problems
        and exceptions are reported as an "Error: ..." bot message rather
        than raised.
    """
    chat_history = list(chat_history or [])
    image_path = None
    try:
        # Validate text input
        if not text_input:
            yield chat_history + [(None, "Error: Please enter a text prompt.")]
            return

        # Handle image input
        image_path = _load_image_to_temp_file(image_file, image_url)
        if image_path is None:
            yield chat_history + [(text_input, "Error: Please provide an image file or a valid image URL.")]
            return

        # Prepare chat template
        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": "You are a helpful assistant capable of analyzing images."}],
            },
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image_path},
                    {"type": "text", "text": text_input},
                ],
            },
        ]

        # Process inputs with chat template
        inputs = processor.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=True,
            return_dict=True, return_tensors="pt"
        ).to(model.device, dtype=torch.bfloat16)

        # Verify the image actually made it into the model inputs
        if "pixel_values" not in inputs:
            yield chat_history + [(text_input, "Error: Image data not processed correctly (missing pixel_values).")]
            return

        # Initialize streamer
        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
            "pixel_values": inputs["pixel_values"],
            "max_new_tokens": 500,
            "do_sample": True,
            "streamer": streamer,
        }

        worker_errors = []

        def _generate():
            try:
                # inference_mode is thread-local, so it has to be entered
                # inside the worker thread to cover the generate() call.
                with torch.inference_mode():
                    model.generate(**generation_kwargs)
            except Exception as exc:
                worker_errors.append(exc)
                # Unblock the consumer loop below; without this a failed
                # generation would leave it waiting on the streamer forever.
                streamer.end()

        # Update chat history with the user message and an empty reply
        bot_message = ""
        new_history = chat_history + [(text_input, bot_message)]
        yield new_history

        # Run generation in a separate thread so tokens can be streamed
        thread = Thread(target=_generate)
        thread.start()
        for new_text in streamer:
            bot_message += new_text
            new_history[-1] = (text_input, bot_message)
            yield new_history
        thread.join()

        if worker_errors:
            # Surface the worker's failure through the error handler below.
            raise worker_errors[0]
    except Exception as e:
        import traceback
        yield chat_history + [(text_input, f"Error: {str(e)}\n{traceback.format_exc()}")]
    finally:
        # Clean up the temporary file on every exit path, including errors
        # and the generator being closed before it finishes.
        if image_path and os.path.exists(image_path):
            os.remove(image_path)
# Build the Gradio interface: image inputs in the left column, the chat window
# and prompt box in the right column.
with gr.Blocks() as demo:
    gr.Markdown("# Multimodal Streaming Chat with Gemma-3-4b-it")
    gr.Markdown("Upload an image or provide an image URL, then enter a text prompt (e.g., 'Describe this image in detail').")
    with gr.Row():
        with gr.Column():
            # type="filepath" hands process_input a path string for the upload.
            image_file_input = gr.Image(label="Upload Image (Drag-and-Drop)", type="filepath")
            image_url_input = gr.Textbox(label="Or Enter Image URL", placeholder="e.g., https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/bee.jpg")
        with gr.Column():
            chatbot = gr.Chatbot(label="Chat with Gemma", height=400)
            text_input = gr.Textbox(
                label="Enter your prompt here",
                placeholder="e.g., Describe this image in detail",
                lines=2,
                submit_btn=True,
            )

    # Submitting the prompt box runs process_input; every history it yields
    # is written to the chatbot, which is what produces the streaming effect.
    text_input.submit(
        fn=process_input,
        inputs=[chatbot, image_file_input, image_url_input, text_input],
        outputs=chatbot,
    )

# Launch the app
demo.launch()