navjotk's picture
Update app.py
23e2414 verified
raw
history blame
1.84 kB
import zlib

import gradio as gr
from transformers import pipeline
from PIL import Image, ImageDraw, ImageFont
# Object-detection pipeline for the "hustvl/yolos-small" checkpoint.
# Created once at module import and shared by every request.
detector = pipeline(task="object-detection", model="hustvl/yolos-small")
# Minimum confidence score used to filter which detections are displayed.
CONFIDENCE_THRESHOLD = 0.5

# PIL colour names used for box outlines and label text; each class label
# is mapped onto one entry of this list.
COLORS = [
    "red",
    "blue",
    "green",
    "orange",
    "purple",
    "yellow",
    "cyan",
    "magenta",
]
def get_color_for_label(label, palette=None):
    """Return a palette colour for *label* that is stable across runs.

    Uses CRC-32 of the label's UTF-8 encoding instead of the built-in
    ``hash()``: ``str`` hashes are randomised per interpreter process
    (PYTHONHASHSEED), which would give the same class a different colour
    after every restart of the app.

    Args:
        label: Class name reported by the detector (converted with ``str``).
        palette: Optional sequence of colour names to choose from; defaults
            to the module-level ``COLORS`` list.

    Returns:
        One element of *palette*.

    Raises:
        ZeroDivisionError: If *palette* is empty.
    """
    if palette is None:
        palette = COLORS
    digest = zlib.crc32(str(label).encode("utf-8"))
    return palette[digest % len(palette)]
def detect_and_draw(image):
    """Detect objects in *image*, draw them, and build AnnotatedImage output.

    Args:
        image: PIL image from the Gradio ``Image(type="pil")`` input. May be
            ``None`` (e.g. if the form is submitted with no upload).

    Returns:
        ``(annotated_image, annotations)``, where ``annotated_image`` is an RGB
        copy of the input with boxes and captions drawn on it, and
        ``annotations`` is a list of ``((xmin, ymin, xmax, ymax), label)``
        pairs for ``gr.AnnotatedImage``. Returns ``None`` when *image* is
        ``None``.
    """
    if image is None:
        # Nothing to detect on; an empty output beats a pipeline traceback.
        return None

    # Convert before detecting so the detector and the drawing use the same
    # pixels; convert() returns a new image, so the caller's is not drawn on.
    image = image.convert("RGB")
    # The pipeline applies its own score cutoff before returning results;
    # pass ours so changing CONFIDENCE_THRESHOLD (in either direction) works.
    results = detector(image, threshold=CONFIDENCE_THRESHOLD)

    draw = ImageDraw.Draw(image)
    try:
        font = ImageFont.truetype("arial.ttf", 16)
    except (OSError, ImportError):
        # Arial may not be installed (OSError) or FreeType support may be
        # missing (ImportError); fall back to PIL's built-in font.
        font = ImageFont.load_default()

    annotations = []
    for obj in results:
        score = obj["score"]
        # Keep only scores strictly above the threshold, matching the
        # "confidence > threshold" text shown in the UI.
        if score <= CONFIDENCE_THRESHOLD:
            continue

        label = f"{obj['label']} ({score:.2f})"
        box = obj["box"]
        color = get_color_for_label(obj["label"])
        # AnnotatedImage expects (box_tuple, label) with (x1, y1, x2, y2).
        box_coords = (box["xmin"], box["ymin"], box["xmax"], box["ymax"])

        # Outline the detection.
        draw.rectangle(box_coords, outline=color, width=3)
        # Caption just inside the box's top-left corner.
        draw.text(
            (box["xmin"] + 5, box["ymin"] + 5),
            label,
            fill=color,
            font=font,
        )
        annotations.append((box_coords, label))

    return image, annotations
# Gradio interface: one PIL image in, an AnnotatedImage (image + boxes) out.
demo = gr.Interface(
    fn=detect_and_draw,
    inputs=gr.Image(type="pil"),
    outputs=gr.AnnotatedImage(),
    title="YOLOS Object Detection",
    description=f"Upload an image to detect objects using the YOLOS model. Only objects with confidence > {CONFIDENCE_THRESHOLD} are shown.",
)

# Start the web server only when this file is run as a script, so importing
# the module (e.g. from tests) does not launch it as a side effect.
if __name__ == "__main__":
    demo.launch()