"""Gradio demo that runs Birder object-detection models on an input image."""

import functools

import birder
import gradio as gr
from birder.inference.detection import infer_image
from huggingface_hub import HfApi


def get_birder_classification_models():
    """Return the short names of the "birder-project" models on the Hugging Face Hub.

    NOTE(review): despite the "classification" in the name, the query filters
    on the "object-detection" tag. The name is kept for backward compatibility.

    Returns:
        A list with the last path component of every matching model id.
    """
    api = HfApi()
    models = api.list_models(author="birder-project", tags="object-detection")
    return [model.modelId.split("/")[-1] for model in models]


def get_selected_models():
    """Return the curated list of model names offered in the dropdown."""
    return [
        "deformable_detr_boxref_coco_convnext_v2_tiny_imagenet21k",
    ]


@functools.lru_cache(maxsize=2)
def _load_model(model_name):
    """Load a pretrained model once and cache everything needed for inference.

    Cached so that repeated prediction requests for the same model do not
    rebuild the network and its transform every time.

    Args:
        model_name: Exact pretrained model name, or a prefix of one.

    Returns:
        A ``(net, transform, idx_to_class)`` tuple.

    Raises:
        IndexError: If neither the exact name nor ``model_name + "*"`` matches
            any pretrained model.
    """
    # Fall back to the first pattern match when the exact name is unknown.
    if not birder.list_pretrained_models(model_name):
        model_name = birder.list_pretrained_models(model_name + "*")[0]

    net, (class_to_idx, signature, rgb_stats, *_) = birder.load_pretrained_model(model_name, inference=True)
    size = birder.get_size_from_signature(signature)
    transform = birder.detection_transform(size, rgb_stats, dynamic_size=signature["dynamic"])

    # Invert the mapping so predicted label indices can be turned into names.
    idx_to_class = {idx: name for name, idx in class_to_idx.items()}

    return (net, transform, idx_to_class)


def load_model_and_detect(image, model_name, min_score):
    """Run object detection on ``image`` with the requested model.

    Args:
        image: Input image (the interface supplies a PIL image).
        model_name: Exact pretrained model name, or a prefix of one.
        min_score: Score threshold forwarded to ``infer_image``.

    Returns:
        A ``(detections, label_names)`` tuple, where ``detections`` is the
        result of ``infer_image`` and ``label_names`` holds the class name of
        each entry in ``detections["labels"]``.
    """
    net, transform, idx_to_class = _load_model(model_name)
    detections = infer_image(net, image, transform, score_threshold=min_score)
    label_names = [idx_to_class[i.item()] for i in detections["labels"]]

    return (detections, label_names)


def predict(image, model_name, min_score):
    """Gradio callback: detect objects and build the annotated-image payload.

    Args:
        image: Input image from the UI, or ``None`` when nothing was provided.
        model_name: Model selected in the dropdown.
        min_score: Minimum score selected with the slider.

    Returns:
        An ``(image, annotations)`` tuple where each annotation is
        ``((x1, y1, x2, y2), "label (score)")`` with integer coordinates.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Surface a readable message instead of an opaque failure inside inference.
        raise gr.Error("Please upload an image or select one of the examples.")

    detections, label_names = load_model_and_detect(image, model_name, min_score)

    annotations = [
        (tuple(int(coord) for coord in box.tolist()), f"{label} ({score:.2f})")
        for box, score, label in zip(detections["boxes"], detections["scores"], label_names)
    ]

    return (image, annotations)


def create_interface():
    """Build the Gradio interface for the detection demo.

    Returns:
        The configured ``gr.Interface`` (not yet launched).
    """
    models = get_selected_models()

    examples = [
        ["safari.jpeg", "deformable_detr_boxref_coco_convnext_v2_tiny_imagenet21k", 0.45],
    ]

    # Create interface
    iface = gr.Interface(
        analytics_enabled=False,
        deep_link=False,
        fn=predict,
        inputs=[
            gr.Image(type="pil", label="Input Image"),
            gr.Dropdown(
                choices=models,
                label="Select Model",
                value=models[0] if models else None,
            ),
            gr.Slider(
                minimum=0.05,
                maximum=0.95,
                step=0.05,
                value=0.5,
                label="Minimum Score Threshold",
                info="Only detections with confidence above this threshold will be shown",
            ),
        ],
        outputs=gr.AnnotatedImage(),
        examples=examples,
        title="Birder Object Detection",
        description="Select a model and upload an image or use one of the examples to get bird detections.",
    )

    return iface


# Launch the app
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()