import gradio as gr
import spaces

import sahi.predict
import sahi.slicing
import sahi.utils.cv
import sahi.utils.file
from sahi import AutoDetectionModel
from PIL import Image
import numpy
from ultralytics import YOLO
import sys
import types

# Compatibility shim: some libraries still import RepositoryNotFoundError from
# huggingface_hub.utils._errors, a private module removed in newer
# huggingface_hub releases; register a stand-in so those imports do not fail.
if 'huggingface_hub.utils._errors' not in sys.modules:
    mock_errors = types.ModuleType('_errors')
    mock_errors.RepositoryNotFoundError = Exception
    sys.modules['huggingface_hub.utils._errors'] = mock_errors

IMAGE_SIZE = 640  # inference image size (pixels) passed to the detection model

# Download example images used by the demo
sahi.utils.file.download_from_url(
    "https://user-images.githubusercontent.com/34196005/142730935-2ace3999-a47b-49bb-83e0-2bdd509f1c90.jpg",
    "apple_tree.jpg",
)
sahi.utils.file.download_from_url(
    "https://user-images.githubusercontent.com/34196005/142730936-1b397756-52e5-43be-a949-42ec0134d5d8.jpg",
    "highway.jpg",
)

sahi.utils.file.download_from_url(
    "https://user-images.githubusercontent.com/34196005/142742871-bf485f84-0355-43a3-be86-96b44e63c3a2.jpg",
    "highway2.jpg",
)

sahi.utils.file.download_from_url(
    "https://user-images.githubusercontent.com/34196005/142742872-1fefcc4d-d7e6-4c43-bbb7-6b5982f7e4ba.jpg",
    "highway3.jpg",
)

# Global model variable
model = None

def load_yolo_model(model_name, confidence_threshold=0.5):
    """
    Loads a YOLOv11 detection model.

    Args:
        model_name (str): The name of the YOLOv11 model to load (e.g., "yolo11n.pt").
        confidence_threshold (float): The confidence threshold for object detection.

    Returns:
        AutoDetectionModel: The loaded SAHI AutoDetectionModel.
    """
    global model
    model_path = model_name
    model = AutoDetectionModel.from_pretrained(
        model_type="ultralytics",
        model_path=model_path,
        device=None,  # auto device selection
        confidence_threshold=confidence_threshold,
        image_size=IMAGE_SIZE,
    )
    return model
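
# A hypothetical standalone use of the loader (not executed by the app):
#   detector = load_yolo_model("yolo11n.pt", confidence_threshold=0.25)
# Ultralytics typically downloads the official "yolo11*.pt" weights on first use
# if they are not already present locally.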

# Request a GPU allocation for each call on Hugging Face Spaces (ZeroGPU);
# duration hints the expected runtime in seconds.
@spaces.GPU(duration=60)
def sahi_yolo_inference(
    image,
    yolo_model_name,
    confidence_threshold,
    max_detections,
    slice_height=512,
    slice_width=512,
    overlap_height_ratio=0.2,
    overlap_width_ratio=0.2,
    postprocess_type="NMS",
    postprocess_match_metric="IOU",
    postprocess_match_threshold=0.5,
    postprocess_class_agnostic=False,
):
    """
    Performs object detection using SAHI with a specified YOLOv11 model.

    Args:
        image (PIL.Image.Image): The input image for detection.
        yolo_model_name (str): The name of the YOLOv11 model to use for inference.
        confidence_threshold (float): The confidence threshold for object detection.
        max_detections (int): The maximum number of detections to return.
        slice_height (int): The height of each slice for sliced inference.
        slice_width (int): The width of each slice for sliced inference.
        overlap_height_ratio (float): The overlap ratio for slice height.
        overlap_width_ratio (float): The overlap ratio for slice width.
        postprocess_type (str): The type of postprocessing to apply ("NMS" or "GREEDYNMM").
        postprocess_match_metric (str): The metric for postprocessing matching ("IOU" or "IOS").
        postprocess_match_threshold (float): The threshold for postprocessing matching.
        postprocess_class_agnostic (bool): Whether postprocessing should be class agnostic.

    Returns:
        tuple: A tuple containing two PIL.Image.Image objects:
               - The image with standard YOLO inference results.
               - The image with SAHI sliced YOLO inference results.
    """
    load_yolo_model(yolo_model_name, confidence_threshold)

    image_width, image_height = image.size
    sliced_bboxes = sahi.slicing.get_slice_bboxes(
        image_height,
        image_width,
        slice_height,
        slice_width,
        False,
        overlap_height_ratio,
        overlap_width_ratio,
    )
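    # Each slice triggers a separate model forward pass, so very small slices on a
    # large image can exhaust the GPU time requested above. For example, 512x512
    # slices with 0.2 overlap on a 1920x1080 image give roughly 5 x 3 = 15 slices,
    # comfortably under the cap enforced below.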
    if len(sliced_bboxes) > 60:
        raise ValueError(
            f"{len(sliced_bboxes)} slices are too many for Hugging Face Spaces; try a larger slice size."
        )

    # Standard inference
    prediction_result_1 = sahi.predict.get_prediction(
        image=image, detection_model=model,
    )
    
    # Filter by max_detections for standard inference
    if max_detections is not None and len(prediction_result_1.object_prediction_list) > max_detections:
        prediction_result_1.object_prediction_list = sorted(
            prediction_result_1.object_prediction_list, key=lambda x: x.score.value, reverse=True
        )[:max_detections]

    visual_result_1 = sahi.utils.cv.visualize_object_predictions(
        image=numpy.array(image),
        object_prediction_list=prediction_result_1.object_prediction_list,
    )
    output_1 = Image.fromarray(visual_result_1["image"])

    # Sliced inference
    prediction_result_2 = sahi.predict.get_sliced_prediction(
        image=image,
        detection_model=model,
        slice_height=int(slice_height),
        slice_width=int(slice_width),
        overlap_height_ratio=overlap_height_ratio,
        overlap_width_ratio=overlap_width_ratio,
        postprocess_type=postprocess_type,
        postprocess_match_metric=postprocess_match_metric,
        postprocess_match_threshold=postprocess_match_threshold,
        postprocess_class_agnostic=postprocess_class_agnostic,
    )

    # Filter by max_detections for sliced inference
    if max_detections is not None and len(prediction_result_2.object_prediction_list) > max_detections:
        prediction_result_2.object_prediction_list = sorted(
            prediction_result_2.object_prediction_list, key=lambda x: x.score.value, reverse=True
        )[:max_detections]

    visual_result_2 = sahi.utils.cv.visualize_object_predictions(
        image=numpy.array(image),
        object_prediction_list=prediction_result_2.object_prediction_list,
    )

    output_2 = Image.fromarray(visual_result_2["image"])

    return output_1, output_2
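
# A hypothetical direct call, bypassing the Gradio UI (remaining parameters use
# their defaults):
#   img = Image.open("highway.jpg")
#   standard_img, sliced_img = sahi_yolo_inference(img, "yolo11s.pt", 0.5, 300)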


with gr.Blocks() as app:
    gr.Markdown("# Small Object Detection with SAHI + YOLOv11")
    gr.Markdown(
        "SAHI + YOLOv11 demo for small object detection. "
        "Upload your own image or click one of the example images below."
    )

    with gr.Row():
        with gr.Column():
            original_image_input = gr.Image(type="pil", label="Original Image")
            yolo_model_dropdown = gr.Dropdown(
                choices=["yolo11n.pt", "yolo11s.pt", "yolo11m.pt", "yolo11l.pt", "yolo11x.pt"],
                value="yolo11s.pt",
                label="YOLOv11 Model",
            )
            confidence_threshold_slider = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.01,
                value=0.5,
                label="Confidence Threshold",
            )
            max_detections_slider = gr.Slider(
                minimum=1,
                maximum=500,
                step=1,
                value=300,
                label="Max Detections",
            )
            slice_height_input = gr.Number(value=512, label="Slice Height")
            slice_width_input = gr.Number(value=512, label="Slice Width")
            overlap_height_ratio_slider = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.01,
                value=0.2,
                label="Overlap Height Ratio",
            )
            overlap_width_ratio_slider = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.01,
                value=0.2,
                label="Overlap Width Ratio",
            )
            postprocess_type_dropdown = gr.Dropdown(
                ["NMS", "GREEDYNMM"],
                type="value",
                value="NMS",
                label="Postprocess Type",
            )
            postprocess_match_metric_dropdown = gr.Dropdown(
                ["IOU", "IOS"], type="value", value="IOU", label="Postprocess Match Metric"
            )
            postprocess_match_threshold_slider = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.01,
                value=0.5,
                label="Postprocess Match Threshold",
            )
            postprocess_class_agnostic_checkbox = gr.Checkbox(value=True, label="Postprocess Class Agnostic")

            submit_button = gr.Button("Run Inference")

        with gr.Column():
            output_standard = gr.Image(type="pil", label="YOLOv11 Standard")
            output_sahi_sliced = gr.Image(type="pil", label="YOLOv11 + SAHI Sliced")

    gr.Examples(
        examples=[
            ["apple_tree.jpg", "yolo11s.pt", 0.5, 300, 256, 256, 0.2, 0.2, "NMS", "IOU", 0.4, True],
            ["highway.jpg", "yolo11s.pt", 0.5, 300, 256, 256, 0.2, 0.2, "NMS", "IOU", 0.4, True],
            ["highway2.jpg", "yolo11s.pt", 0.5, 300, 512, 512, 0.2, 0.2, "NMS", "IOU", 0.4, True],
            ["highway3.jpg", "yolo11s.pt", 0.5, 300, 512, 512, 0.2, 0.2, "NMS", "IOU", 0.4, True],
        ],
        inputs=[
            original_image_input,
            yolo_model_dropdown,
            confidence_threshold_slider,
            max_detections_slider,
            slice_height_input,
            slice_width_input,
            overlap_height_ratio_slider,
            overlap_width_ratio_slider,
            postprocess_type_dropdown,
            postprocess_match_metric_dropdown,
            postprocess_match_threshold_slider,
            postprocess_class_agnostic_checkbox,
        ],
        outputs=[output_standard, output_sahi_sliced],
        fn=sahi_yolo_inference,
        cache_examples=True,
    )

    submit_button.click(
        fn=sahi_yolo_inference,
        inputs=[
            original_image_input,
            yolo_model_dropdown,
            confidence_threshold_slider,
            max_detections_slider,
            slice_height_input,
            slice_width_input,
            overlap_height_ratio_slider,
            overlap_width_ratio_slider,
            postprocess_type_dropdown,
            postprocess_match_metric_dropdown,
            postprocess_match_threshold_slider,
            postprocess_class_agnostic_checkbox,
        ],
        outputs=[output_standard, output_sahi_sliced],
    )

# Launch the Gradio app; mcp_server=True additionally exposes the app's API
# endpoints as MCP (Model Context Protocol) tools.
app.launch(mcp_server=True)