import os
import json
import time
import gradio as gr
import numpy as np
import torch
from PIL import Image
from qwen_vl_utils import process_vision_info
from transformers import (
AutoProcessor,
Gemma3ForConditionalGeneration,
Qwen2_5_VLForConditionalGeneration,
)
from spaces import GPU
import supervision as sv
# --- Config ---
# IMPORTANT: Gemma 3 is a gated model. You must be logged in to a Hugging Face
# account that has been granted access to google/gemma-3-4b-it.
from huggingface_hub import login
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)
model_qwen_id = "Qwen/Qwen2.5-VL-3B-Instruct"
model_gemma_id = "google/gemma-3-4b-it"
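# Both models are loaded once at startup; device_map="auto" lets Accelerate
# place the weights on whatever GPU/CPU memory is available.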
# Load Qwen Model
model_qwen = Qwen2_5_VLForConditionalGeneration.from_pretrained(
model_qwen_id, torch_dtype="auto", device_map="auto"
)
min_pixels = 224 * 224
max_pixels = 1024 * 1024
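# min_pixels / max_pixels bound Qwen's dynamic-resolution preprocessing: the
# processor rescales each image so its pixel count stays within this range.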
processor_qwen = AutoProcessor.from_pretrained(
model_qwen_id, min_pixels=min_pixels, max_pixels=max_pixels
)
# Load Gemma Model
model_gemma = Gemma3ForConditionalGeneration.from_pretrained(
model_gemma_id,
torch_dtype=torch.bfloat16, # Recommended dtype for Gemma
device_map="auto"
)
processor_gemma = AutoProcessor.from_pretrained(model_gemma_id)
def extract_model_short_name(model_id):
return model_id.split("/")[-1].replace("-", " ").replace("_", " ")
model_qwen_name = extract_model_short_name(model_qwen_id)    # -> "Qwen2.5 VL 3B Instruct"
model_gemma_name = extract_model_short_name(model_gemma_id)  # -> "gemma 3 4b it"
def create_annotated_image(image, json_data, height, width):
try:
# Standardize parsing for outputs wrapped in markdown
if "```json" in json_data:
parsed_json_data = json_data.split("```json")[1].split("```")[0]
else:
parsed_json_data = json_data
bbox_data = json.loads(parsed_json_data)
except Exception:
# If parsing fails, return the original image
return image
# Ensure bbox_data is a list
if not isinstance(bbox_data, list):
bbox_data = [bbox_data]
original_width, original_height = image.size
x_scale = original_width / width
y_scale = original_height / height
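    # x_scale / y_scale map point coordinates from the resolution the model
    # processed (width x height) back to the original image's pixel grid.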
points = []
point_labels = []
annotated_image = np.array(image.convert("RGB"))
    # Check whether there are bounding boxes in the data to create detections
    if any("box_2d" in item for item in bbox_data):
        # supervision's Qwen2.5-VL parser is used here as a generic VLM parser,
        # assuming both models return Qwen-style `box_2d` boxes expressed in the
        # resolution the model actually processed (resolution_wh).
        detections = sv.Detections.from_vlm(
            vlm=sv.VLM.QWEN_2_5_VL,
            result=json_data,
            resolution_wh=(width, height),
        )
bounding_box_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX)
label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX)
annotated_image = bounding_box_annotator.annotate(
scene=annotated_image, detections=detections
)
annotated_image = label_annotator.annotate(
scene=annotated_image, detections=detections
)
# Handle points separately
for item in bbox_data:
label = item.get("label", "")
if "point_2d" in item:
x, y = item["point_2d"]
scaled_x = int(x * x_scale)
scaled_y = int(y * y_scale)
points.append([scaled_x, scaled_y])
point_labels.append(label)
if points:
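        # sv.KeyPoints expects an (objects, points, 2) array; all points are
        # grouped under a single object here so they share one annotation pass.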
points_array = np.array(points).reshape(1, -1, 2)
key_points = sv.KeyPoints(xy=points_array)
vertex_annotator = sv.VertexAnnotator(radius=5, color=sv.Color.BLUE)
annotated_image = vertex_annotator.annotate(
scene=annotated_image, key_points=key_points
)
return Image.fromarray(annotated_image)
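# The spaces.GPU decorator asks Hugging Face ZeroGPU to attach a GPU for the
# duration of each decorated call.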
@GPU
def detect_qwen(image, prompt):
messages = [
{
"role": "user",
"content": [
{"type": "image", "image": image},
{"type": "text", "text": prompt},
],
}
]
t0 = time.perf_counter()
text = processor_qwen.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
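    # qwen_vl_utils.process_vision_info pulls the image (and any video) payloads
    # out of the chat messages in the format the Qwen processor expects.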
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor_qwen(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
).to(model_qwen.device)
generated_ids = model_qwen.generate(**inputs, max_new_tokens=1024)
generated_ids_trimmed = [
out_ids[len(in_ids) :]
for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
    output_text = processor_qwen.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]
elapsed_ms = (time.perf_counter() - t0) * 1_000
    # Qwen's processor reports the vision grid in 14x14-pixel patch units
    # (image_grid_thw = [t, h, w]), so multiplying by the patch size recovers
    # the resolution the model actually saw.
    input_height = inputs["image_grid_thw"][0][1] * 14
    input_width = inputs["image_grid_thw"][0][2] * 14
annotated_image = create_annotated_image(
image, output_text, input_height, input_width
)
time_taken = f"**Inference time ({model_qwen_name}):** {elapsed_ms:.0f} ms"
return annotated_image, output_text, time_taken
@GPU
def detect_gemma(image, prompt):
messages = [
{
"role": "user",
"content": [
{"type": "image", "image": image},
{"type": "text", "text": prompt},
],
}
]
t0 = time.perf_counter()
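    # Unlike the Qwen path, Gemma's processor tokenizes the prompt and packs the
    # image in a single apply_chat_template call (tokenize=True, return_dict=True).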
inputs = processor_gemma.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt"
).to(model_gemma.device)
input_len = inputs["input_ids"].shape[-1]
with torch.inference_mode():
generation = model_gemma.generate(**inputs, max_new_tokens=1024, do_sample=False)
generation_trimmed = generation[0][input_len:]
output_text = processor_gemma.decode(generation_trimmed, skip_special_tokens=True)
elapsed_ms = (time.perf_counter() - t0) * 1_000
    # Gemma 3's vision encoder resizes every image to a fixed 896x896 input, so
    # coordinates returned by the model are interpreted in that space.
    input_height = 896
    input_width = 896
annotated_image = create_annotated_image(
image, output_text, input_height, input_width
)
time_taken = f"**Inference time ({model_gemma_name}):** {elapsed_ms:.0f} ms"
return annotated_image, output_text, time_taken
def detect(image, prompt_model_1, prompt_model_2):
STANDARD_SIZE = (1024, 1024)
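    # PIL's thumbnail() resizes in place while preserving aspect ratio, so both
    # models receive the same image capped at 1024 px on its longer side.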
image.thumbnail(STANDARD_SIZE)
annotated_image_model_1, output_text_model_1, timing_1 = detect_qwen(
image, prompt_model_1
)
annotated_image_model_2, output_text_model_2, timing_2 = detect_gemma(
image, prompt_model_2
)
return (
annotated_image_model_1,
output_text_model_1,
timing_1,
annotated_image_model_2,
output_text_model_2,
timing_2,
)
css_hide_share = """
button#gradio-share-link-button-0 {
display: none !important;
}
"""
# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(), css=css_hide_share) as demo:
gr.Markdown("# Object Detection & Understanding: Qwen vs. Gemma")
gr.Markdown(
"### Compare object detection, visual grounding, and keypoint detection using natural language prompts with two leading VLMs."
)
    gr.Markdown("""
*Powered by [Qwen2.5-VL 3B](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct) and [Gemma 3 4B IT](https://huggingface.co/google/gemma-3-4b-it). For best results, ask the models to return a JSON list inside a markdown code block. Inspired by the [HF Team's space](https://huggingface.co/spaces/sergiopaniego/vlm_object_understanding), which uses `detect`-style prompts for the "Object Detection" examples, `point`-style prompts for the "Keypoint Detection" examples, and reasoning-based querying for everything else.*
    """)
with gr.Row():
with gr.Column(scale=2):
image_input = gr.Image(label="Upload an image", type="pil", height=400)
prompt_input_model_1 = gr.Textbox(
label=f"Enter your prompt for {model_qwen_name}",
placeholder="e.g., Detect all red cars. Return a JSON list with 'box_2d' and 'label'.",
)
prompt_input_model_2 = gr.Textbox(
label=f"Enter your prompt for {model_gemma_name}",
placeholder="e.g., Detect all red cars. Return a JSON list with 'box_2d' and 'label'.",
)
generate_btn = gr.Button(value="Generate")
with gr.Column(scale=1):
output_image_model_1 = gr.Image(
type="pil", label=f"Annotated image from {model_qwen_name}", height=400
)
output_textbox_model_1 = gr.Textbox(
label=f"Model response from {model_qwen_name}", lines=10
)
output_time_model_1 = gr.Markdown()
with gr.Column(scale=1):
output_image_model_2 = gr.Image(
type="pil",
label=f"Annotated image from {model_gemma_name}",
height=400,
)
output_textbox_model_2 = gr.Textbox(
label=f"Model response from {model_gemma_name}", lines=10
)
output_time_model_2 = gr.Markdown()
gr.Markdown("### Examples")
prompt_obj_detect = "Detect all objects in this image. For each object, provide a 'box_2d' and a 'label'. Return the output as a JSON list inside a markdown block."
prompt_candy_detect = "Detect all individual candies in this image. For each, provide a 'box_2d' and a 'label'. Return the output as a JSON list inside a markdown block."
prompt_car_count = "Count the number of red cars in the image."
prompt_candy_count = "Count the number of blue candies in the image."
prompt_car_keypoint = "Identify the red cars in this image. For each, detect its key points and return their positions as 'point_2d' in a JSON list inside a markdown block."
prompt_candy_keypoint = "Identify the blue candies in this image. For each, detect its key points and return their positions as 'point_2d' in a JSON list inside a markdown block."
prompt_car_ground = "Detect the red car that is leading in this image. Return its location with 'box_2d' and 'label' in a JSON list inside a markdown block."
prompt_candy_ground = "Detect the blue candy at the top of the group. Return its location with 'box_2d' and 'label' in a JSON list inside a markdown block."
example_prompts = [
["examples/example_1.jpg", prompt_obj_detect, prompt_obj_detect],
["examples/example_2.JPG", prompt_candy_detect, prompt_candy_detect],
["examples/example_1.jpg", prompt_car_count, prompt_car_count],
["examples/example_2.JPG", prompt_candy_count, prompt_candy_count],
["examples/example_1.jpg", prompt_car_keypoint, prompt_car_keypoint],
["examples/example_2.JPG", prompt_candy_keypoint, prompt_candy_keypoint],
["examples/example_1.jpg", prompt_car_ground, prompt_car_ground],
["examples/example_2.JPG", prompt_candy_ground, prompt_candy_ground],
]
gr.Examples(
examples=example_prompts,
inputs=[
image_input,
prompt_input_model_1,
prompt_input_model_2,
],
label="Click an example to populate the input",
)
generate_btn.click(
fn=detect,
inputs=[
image_input,
prompt_input_model_1,
prompt_input_model_2,
],
outputs=[
output_image_model_1,
output_textbox_model_1,
output_time_model_1,
output_image_model_2,
output_textbox_model_2,
output_time_model_2,
],
)
if __name__ == "__main__":
demo.launch()