File size: 2,859 Bytes
36d1d56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7cc80d0
 
36d1d56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import birder
from birder.inference.detection import infer_image
from huggingface_hub import HfApi

import gradio as gr


def get_birder_classification_models():
    """List the short names of birder-project models published on the Hub.

    Queries the Hugging Face Hub for repositories authored by
    ``birder-project`` that carry the ``object-detection`` tag and returns
    the part of each repository id after the last ``/``.

    NOTE(review): despite "classification" in the function name, the query
    filters on the ``object-detection`` tag.
    """
    hub = HfApi()
    short_names = []
    for info in hub.list_models(author="birder-project", tags="object-detection"):
        # Drop the owner prefix, keeping only the trailing path component.
        short_names.append(info.modelId.rsplit("/", 1)[-1])

    return short_names


def get_selected_models():
    """Return the hand-picked model names offered by the demo.

    A new list is built on every call, so callers may mutate the result.
    """
    selected = ("deformable_detr_boxref_coco_convnext_v2_tiny_imagenet21k",)
    return list(selected)


def load_model_and_detect(image, model_name, min_score):
    """Load a pretrained birder detection model and run it on one image.

    Args:
        image: The image handed to ``infer_image`` (the UI supplies a PIL
            image).
        model_name: Name of a pretrained model. If it does not match any
            pretrained model as given, it is treated as a prefix and the
            first model matching ``model_name + "*"`` is used.
        min_score: Passed to ``infer_image`` as ``score_threshold``.

    Returns:
        A ``(detections, label_names)`` tuple. ``detections`` is whatever
        ``infer_image`` returns (indexed here by ``"labels"``), and
        ``label_names`` holds the class name of each entry in
        ``detections["labels"]``, in the same order.

    Raises:
        IndexError: If neither the exact name nor the prefix pattern
            matches a pretrained model.
    """
    if not birder.list_pretrained_models(model_name):
        # The given name did not match as-is; fall back to a prefix search.
        candidates = birder.list_pretrained_models(model_name + "*")
        if not candidates:
            # IndexError is what the unguarded ``[0]`` lookup used to raise;
            # keep the type but say which model could not be resolved.
            raise IndexError(f"No pretrained model matches '{model_name}'")

        model_name = candidates[0]

    # NOTE(review): the model is loaded again on every call.
    (net, (class_to_idx, signature, rgb_stats, *_)) = birder.load_pretrained_model(model_name, inference=True)

    # Build the preprocessing transform from the model's own signature and
    # RGB statistics.
    size = birder.get_size_from_signature(signature)
    transform = birder.detection_transform(size, rgb_stats, dynamic_size=signature["dynamic"])
    detections = infer_image(net, image, transform, score_threshold=min_score)

    # Invert the name -> index mapping to translate predicted label indices.
    idx_to_class = {idx: name for (name, idx) in class_to_idx.items()}
    label_names = [idx_to_class[i.item()] for i in detections["labels"]]

    return (detections, label_names)


def predict(image, model_name, min_score):
    """Detect objects in *image* and format them for ``gr.AnnotatedImage``.

    Args:
        image: Input image from the UI, or ``None`` when the request was
            submitted without an image.
        model_name: Name of the model to run, forwarded to
            ``load_model_and_detect``.
        min_score: Score threshold, forwarded to ``load_model_and_detect``.

    Returns:
        An ``(image, annotations)`` tuple, where each annotation is
        ``((x1, y1, x2, y2), "<label> (<score>)")`` with integer box
        coordinates and the score formatted to two decimals.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Without this guard the failure happens deep inside the model
        # pipeline and the user only sees a generic error.
        raise gr.Error("Please upload an image or select one of the examples.")

    (detections, label_names) = load_model_and_detect(image, model_name, min_score)
    if detections is None:
        return (image, [])

    annotations = []
    for box, score, label in zip(detections["boxes"], detections["scores"], label_names):
        (x1, y1, x2, y2) = box.tolist()
        # Box coordinates are truncated to integers for the annotation.
        annotations.append(((int(x1), int(y1), int(x2), int(y2)), f"{label} ({score:.2f})"))

    return (image, annotations)


def create_interface():
    """Build the Gradio interface for the detection demo.

    The interface wires ``predict`` to three inputs (image, model dropdown,
    score-threshold slider) and one annotated-image output.

    Returns:
        The configured ``gr.Interface``; the caller is responsible for
        launching it.
    """
    models = get_selected_models()
    default_model = models[0] if models else None

    # One row per example, one value per input: image, model, threshold.
    examples = [
        ["safari.jpeg", "deformable_detr_boxref_coco_convnext_v2_tiny_imagenet21k", 0.45],
    ]

    # Input components, in the positional order predict() expects.
    image_input = gr.Image(type="pil", label="Input Image")
    model_input = gr.Dropdown(
        choices=models,
        label="Select Model",
        value=default_model,
    )
    score_input = gr.Slider(
        minimum=0.05,
        maximum=0.95,
        step=0.05,
        value=0.5,
        label="Minimum Score Threshold",
        info="Only detections with confidence above this threshold will be shown",
    )
    annotated_output = gr.AnnotatedImage()

    return gr.Interface(
        fn=predict,
        inputs=[image_input, model_input, score_input],
        outputs=annotated_output,
        examples=examples,
        title="Birder Object Detection",
        description="Select a model and upload an image or use one of the examples to get bird detections.",
        analytics_enabled=False,
        deep_link=False,
    )


# Launch the app only when this file is executed as a script, so that
# importing the module does not start a server.
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()