|
import functools

import birder
import gradio as gr
from birder.inference.detection import infer_image
from huggingface_hub import HfApi
|
|
|
|
|
def get_birder_classification_models():
    """Return the short names of the birder-project models tagged "object-detection".

    Despite the function name, this queries the Hugging Face Hub for
    *object-detection* models (the name is kept for backward compatibility).

    Returns:
        list[str]: Model names with the ``birder-project/`` owner prefix removed.
    """
    api = HfApi()
    models = api.list_models(author="birder-project", tags="object-detection")

    # `ModelInfo.id` is the documented repo id ("owner/name"); the older
    # `modelId` attribute is a legacy alias that is not guaranteed to exist.
    return [model.id.split("/")[-1] for model in models]
|
|
|
|
|
def get_selected_models():
    """List the curated detection models exposed in the UI dropdown.

    A new list is built on each call, so callers may modify the result
    without affecting later calls.
    """
    curated = ("deformable_detr_boxref_coco_convnext_v2_tiny_imagenet21k",)
    return list(curated)
|
|
|
|
|
def _resolve_model_name(model_name):
    """Map *model_name* to a registered Birder pretrained model name.

    An exact match is returned unchanged; otherwise the first model matching
    the wildcard pattern ``model_name + "*"`` is used.

    Raises:
        ValueError: If neither the exact name nor the wildcard pattern matches.
    """
    if birder.list_pretrained_models(model_name):
        return model_name

    candidates = birder.list_pretrained_models(model_name + "*")
    if not candidates:
        raise ValueError(f"No pretrained Birder model matches {model_name!r}")

    return candidates[0]


@functools.lru_cache(maxsize=4)
def _load_pretrained(model_name):
    """Load a pretrained detector plus its inference transform and index->class map.

    Cached (bounded LRU) so repeated UI requests for the same model reuse the
    already-loaded network instead of reloading weights on every call.

    Returns:
        tuple: ``(net, transform, idx_to_class)``.
    """
    (net, (class_to_idx, signature, rgb_stats, *_)) = birder.load_pretrained_model(model_name, inference=True)

    size = birder.get_size_from_signature(signature)
    transform = birder.detection_transform(size, rgb_stats, dynamic_size=signature["dynamic"])

    # Invert the class->index mapping so predicted label ids can be named.
    idx_to_class = {idx: name for name, idx in class_to_idx.items()}

    return (net, transform, idx_to_class)


def load_model_and_detect(image, model_name, min_score):
    """Run the selected Birder detector on *image*.

    Args:
        image: The input image, forwarded as-is to ``infer_image``
            (the UI supplies a PIL image).
        model_name: Exact pretrained model name, or a prefix resolved via a
            ``"*"`` wildcard lookup.
        min_score: Score threshold forwarded to ``infer_image``.

    Returns:
        tuple: ``(detections, label_names)`` where ``detections`` is the mapping
        returned by ``infer_image`` (its ``"labels"`` entries are read here) and
        ``label_names`` holds one class name per detected label.

    Raises:
        ValueError: If *model_name* does not match any pretrained model.
    """
    resolved_name = _resolve_model_name(model_name)
    (net, transform, idx_to_class) = _load_pretrained(resolved_name)

    detections = infer_image(net, image, transform, score_threshold=min_score)

    # Labels are assumed to be tensor scalars (hence `.item()`) - TODO confirm.
    label_names = [idx_to_class[i.item()] for i in detections["labels"]]

    return (detections, label_names)
|
|
|
|
|
def predict(image, model_name, min_score):
    """Detect objects in *image* and format them for ``gr.AnnotatedImage``.

    Args:
        image: The uploaded image, or ``None`` when nothing was uploaded.
        model_name: Model selected in the dropdown.
        min_score: Minimum detection score from the slider.

    Returns:
        tuple: ``(image, annotations)`` where each annotation is
        ``((x1, y1, x2, y2), "label (score)")`` with integer pixel coordinates.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable message
            instead of a stack trace from deep inside the model code.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    (detections, label_names) = load_model_and_detect(image, model_name, min_score)
    if detections is None:
        return (image, [])

    annotations = []
    for box, score, label in zip(detections["boxes"], detections["scores"], label_names):
        (x1, y1, x2, y2) = box.tolist()
        # float() makes the formatting independent of the score's tensor type.
        annotations.append(((int(x1), int(y1), int(x2), int(y2)), f"{label} ({float(score):.2f})"))

    return (image, annotations)
|
|
|
|
|
def create_interface():
    """Build the Gradio interface for the Birder object-detection demo.

    The UI takes an image, a model choice (defaulting to the first curated
    model, if any) and a minimum score threshold, and renders the detections
    produced by ``predict`` on an annotated image.

    Returns:
        gr.Interface: The configured (not yet launched) interface.
    """
    available_models = get_selected_models()
    default_model = available_models[0] if available_models else None

    image_input = gr.Image(type="pil", label="Input Image")
    model_dropdown = gr.Dropdown(
        choices=available_models,
        label="Select Model",
        value=default_model,
    )
    score_slider = gr.Slider(
        minimum=0.05,
        maximum=0.95,
        step=0.05,
        value=0.5,
        label="Minimum Score Threshold",
        info="Only detections with confidence above this threshold will be shown",
    )

    # Each example row is (image path, model name, minimum score).
    example_rows = [
        ["safari.jpeg", "deformable_detr_boxref_coco_convnext_v2_tiny_imagenet21k", 0.45],
    ]

    return gr.Interface(
        fn=predict,
        inputs=[image_input, model_dropdown, score_slider],
        outputs=gr.AnnotatedImage(),
        examples=example_rows,
        title="Birder Object Detection",
        description="Select a model and upload an image or use one of the examples to get bird detections.",
        analytics_enabled=False,
        deep_link=False,
    )
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: build the interface and start the Gradio server.
    # NOTE(review): the module-level name `demo` may be looked up by Gradio's
    # reload tooling (`gradio app.py`) - confirm before renaming or inlining.
    demo = create_interface()
    demo.launch()
|
|