import sys
import types

import gradio as gr
import numpy
import spaces  # ZeroGPU support on Hugging Face Spaces
from PIL import Image
from ultralytics import YOLO  # noqa: F401  (ensures the ultralytics backend is installed)

import sahi.predict
import sahi.slicing
import sahi.utils.cv
import sahi.utils.file
from sahi import AutoDetectionModel
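# Compatibility shim: older sahi releases import RepositoryNotFoundError from
# huggingface_hub.utils._errors, a private module removed in newer huggingface_hub
# versions. Registering a stand-in module keeps those imports from failing.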
if 'huggingface_hub.utils._errors' not in sys.modules:
    mock_errors = types.ModuleType('_errors')
    mock_errors.RepositoryNotFoundError = Exception
    sys.modules['huggingface_hub.utils._errors'] = mock_errors
IMAGE_SIZE = 640  # inference resolution passed to the detection model
# Example images
sahi.utils.file.download_from_url(
    "https://user-images.githubusercontent.com/34196005/142730935-2ace3999-a47b-49bb-83e0-2bdd509f1c90.jpg",
    "apple_tree.jpg",
)
sahi.utils.file.download_from_url(
    "https://user-images.githubusercontent.com/34196005/142730936-1b397756-52e5-43be-a949-42ec0134d5d8.jpg",
    "highway.jpg",
)
sahi.utils.file.download_from_url(
    "https://user-images.githubusercontent.com/34196005/142742871-bf485f84-0355-43a3-be86-96b44e63c3a2.jpg",
    "highway2.jpg",
)
sahi.utils.file.download_from_url(
    "https://user-images.githubusercontent.com/34196005/142742872-1fefcc4d-d7e6-4c43-bbb7-6b5982f7e4ba.jpg",
    "highway3.jpg",
)
# Global model variable
model = None
def load_yolo_model(model_name, confidence_threshold=0.5):
    """
    Loads a YOLOv11 detection model.

    Args:
        model_name (str): The name of the YOLOv11 model to load (e.g., "yolo11n.pt").
        confidence_threshold (float): The confidence threshold for object detection.

    Returns:
        AutoDetectionModel: The loaded SAHI AutoDetectionModel.
    """
    global model
    model_path = model_name
    model = AutoDetectionModel.from_pretrained(
        model_type="ultralytics",
        model_path=model_path,
        device=None,  # auto device selection
        confidence_threshold=confidence_threshold,
        image_size=IMAGE_SIZE,
    )
    return model
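# For example, load_yolo_model("yolo11n.pt", 0.25) downloads the checkpoint on first
# use and stores the wrapped detector in the module-level `model` variable.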
@spaces.GPU  # request ZeroGPU hardware for each call (matches the `spaces` import above)
def sahi_yolo_inference(
    image,
    yolo_model_name,
    confidence_threshold,
    max_detections,
    slice_height=512,
    slice_width=512,
    overlap_height_ratio=0.2,
    overlap_width_ratio=0.2,
    postprocess_type="NMS",
    postprocess_match_metric="IOU",
    postprocess_match_threshold=0.5,
    postprocess_class_agnostic=False,
):
""" | |
Performs object detection using SAHI with a specified YOLOv11 model. | |
Args: | |
image (PIL.Image.Image): The input image for detection. | |
yolo_model_name (str): The name of the YOLOv11 model to use for inference. | |
confidence_threshold (float): The confidence threshold for object detection. | |
max_detections (int): The maximum number of detections to return. | |
slice_height (int): The height of each slice for sliced inference. | |
slice_width (int): The width of each slice for sliced inference. | |
overlap_height_ratio (float): The overlap ratio for slice height. | |
overlap_width_ratio (float): The overlap ratio for slice width. | |
postprocess_type (str): The type of postprocessing to apply ("NMS" or "GREEDYNMM"). | |
postprocess_match_metric (str): The metric for postprocessing matching ("IOU" or "IOS"). | |
postprocess_match_threshold (float): The threshold for postprocessing matching. | |
postprocess_class_agnostic (bool): Whether postprocessing should be class agnostic. | |
Returns: | |
tuple: A tuple containing two PIL.Image.Image objects: | |
- The image with standard YOLO inference results. | |
- The image with SAHI sliced YOLO inference results. | |
""" | |
    load_yolo_model(yolo_model_name, confidence_threshold)

    image_width, image_height = image.size
    sliced_bboxes = sahi.slicing.get_slice_bboxes(
        image_height=image_height,
        image_width=image_width,
        slice_height=int(slice_height),
        slice_width=int(slice_width),
        auto_slice_resolution=False,
        overlap_height_ratio=overlap_height_ratio,
        overlap_width_ratio=overlap_width_ratio,
    )
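    # The number of slices grows quickly as the slice size shrinks relative to the
    # image, so cap it to keep a single request tractable on shared Spaces hardware.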
    if len(sliced_bboxes) > 60:
        raise ValueError(
            f"{len(sliced_bboxes)} slices are too many for a Hugging Face Space; try a larger slice size."
        )
    # Standard inference
    prediction_result_1 = sahi.predict.get_prediction(
        image=image,
        detection_model=model,
    )
    # Filter by max_detections for standard inference
    if max_detections is not None and len(prediction_result_1.object_prediction_list) > max_detections:
        prediction_result_1.object_prediction_list = sorted(
            prediction_result_1.object_prediction_list, key=lambda x: x.score.value, reverse=True
        )[:max_detections]
    visual_result_1 = sahi.utils.cv.visualize_object_predictions(
        image=numpy.array(image),
        object_prediction_list=prediction_result_1.object_prediction_list,
    )
    output_1 = Image.fromarray(visual_result_1["image"])
    # Sliced inference
    prediction_result_2 = sahi.predict.get_sliced_prediction(
        image=image,
        detection_model=model,
        slice_height=int(slice_height),
        slice_width=int(slice_width),
        overlap_height_ratio=overlap_height_ratio,
        overlap_width_ratio=overlap_width_ratio,
        postprocess_type=postprocess_type,
        postprocess_match_metric=postprocess_match_metric,
        postprocess_match_threshold=postprocess_match_threshold,
        postprocess_class_agnostic=postprocess_class_agnostic,
    )
    # Filter by max_detections for sliced inference
    if max_detections is not None and len(prediction_result_2.object_prediction_list) > max_detections:
        prediction_result_2.object_prediction_list = sorted(
            prediction_result_2.object_prediction_list, key=lambda x: x.score.value, reverse=True
        )[:max_detections]
    visual_result_2 = sahi.utils.cv.visualize_object_predictions(
        image=numpy.array(image),
        object_prediction_list=prediction_result_2.object_prediction_list,
    )
    output_2 = Image.fromarray(visual_result_2["image"])

    return output_1, output_2
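# Illustrative call (argument order mirrors the Gradio inputs below):
#   standard_img, sliced_img = sahi_yolo_inference(
#       Image.open("highway.jpg"), "yolo11s.pt", 0.5, 300, 512, 512, 0.2, 0.2,
#       "NMS", "IOU", 0.5, False,
#   )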
with gr.Blocks() as app:
    gr.Markdown("# Small Object Detection with SAHI + YOLOv11")
    gr.Markdown(
        "SAHI + YOLOv11 demo for small object detection. "
        "Upload your own image or select one of the example images below."
    )
    with gr.Row():
        with gr.Column():
            original_image_input = gr.Image(type="pil", label="Original Image")
            yolo_model_dropdown = gr.Dropdown(
                choices=["yolo11n.pt", "yolo11s.pt", "yolo11m.pt", "yolo11l.pt", "yolo11x.pt"],
                value="yolo11s.pt",
                label="YOLOv11 Model",
            )
            confidence_threshold_slider = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.01,
                value=0.5,
                label="Confidence Threshold",
            )
            max_detections_slider = gr.Slider(
                minimum=1,
                maximum=500,
                step=1,
                value=300,
                label="Max Detections",
            )
            slice_height_input = gr.Number(value=512, label="Slice Height")
            slice_width_input = gr.Number(value=512, label="Slice Width")
            overlap_height_ratio_slider = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.01,
                value=0.2,
                label="Overlap Height Ratio",
            )
            overlap_width_ratio_slider = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.01,
                value=0.2,
                label="Overlap Width Ratio",
            )
            postprocess_type_dropdown = gr.Dropdown(
                ["NMS", "GREEDYNMM"],
                type="value",
                value="NMS",
                label="Postprocess Type",
            )
            postprocess_match_metric_dropdown = gr.Dropdown(
                ["IOU", "IOS"], type="value", value="IOU", label="Postprocess Match Metric"
            )
            postprocess_match_threshold_slider = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.01,
                value=0.5,
                label="Postprocess Match Threshold",
            )
            postprocess_class_agnostic_checkbox = gr.Checkbox(value=True, label="Postprocess Class Agnostic")
            submit_button = gr.Button("Run Inference")
        with gr.Column():
            output_standard = gr.Image(type="pil", label="YOLOv11 Standard")
            output_sahi_sliced = gr.Image(type="pil", label="YOLOv11 + SAHI Sliced")
    gr.Examples(
        examples=[
            ["apple_tree.jpg", "yolo11s.pt", 0.5, 300, 256, 256, 0.2, 0.2, "NMS", "IOU", 0.4, True],
            ["highway.jpg", "yolo11s.pt", 0.5, 300, 256, 256, 0.2, 0.2, "NMS", "IOU", 0.4, True],
            ["highway2.jpg", "yolo11s.pt", 0.5, 300, 512, 512, 0.2, 0.2, "NMS", "IOU", 0.4, True],
            ["highway3.jpg", "yolo11s.pt", 0.5, 300, 512, 512, 0.2, 0.2, "NMS", "IOU", 0.4, True],
        ],
        inputs=[
            original_image_input,
            yolo_model_dropdown,
            confidence_threshold_slider,
            max_detections_slider,
            slice_height_input,
            slice_width_input,
            overlap_height_ratio_slider,
            overlap_width_ratio_slider,
            postprocess_type_dropdown,
            postprocess_match_metric_dropdown,
            postprocess_match_threshold_slider,
            postprocess_class_agnostic_checkbox,
        ],
        outputs=[output_standard, output_sahi_sliced],
        fn=sahi_yolo_inference,
        cache_examples=True,  # precompute example outputs once at startup
    )
    submit_button.click(
        fn=sahi_yolo_inference,
        inputs=[
            original_image_input,
            yolo_model_dropdown,
            confidence_threshold_slider,
            max_detections_slider,
            slice_height_input,
            slice_width_input,
            overlap_height_ratio_slider,
            overlap_width_ratio_slider,
            postprocess_type_dropdown,
            postprocess_match_metric_dropdown,
            postprocess_match_threshold_slider,
            postprocess_class_agnostic_checkbox,
        ],
        outputs=[output_standard, output_sahi_sliced],
    )
app.launch(mcp_server=True)  # also expose the app's functions as MCP tools
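# A minimal sketch of calling this Space programmatically with gradio_client, assuming
# the Space is public; the Space id and endpoint name below are placeholders:
#
#   from gradio_client import Client, handle_file
#   client = Client("<owner>/<space-name>")
#   standard_img, sliced_img = client.predict(
#       handle_file("highway.jpg"), "yolo11s.pt", 0.5, 300, 512, 512, 0.2, 0.2,
#       "NMS", "IOU", 0.5, True,
#       api_name="/sahi_yolo_inference",
#   )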