|
import numpy as np |
|
import supervision as sv |
|
from sahi import AutoDetectionModel |
|
from sahi.predict import get_sliced_prediction |
|
|
|
class YOLO:
    """YOLO detector wrapper that runs SAHI sliced inference on an ONNX model.

    Typical use: construct the wrapper, call :meth:`load_onnx_model` once,
    then call the instance with a frame to obtain ``supervision.Detections``.
    """

    def __init__(self, confidence_threshold, iou_threshold, slicing_overlap, categories, device):
        """
        Store the detector configuration; no model is loaded here.

        Args:
            confidence_threshold (float): Minimum confidence for detections.
            iou_threshold (float): IoU threshold for NMS. It is stored on the
                instance but not used by any method of this class.
            slicing_overlap (float): Overlap ratio between neighbouring slices,
                applied to both the height and the width.
            categories (list): List of class names; the position of a name in
                the list is its class id.
            device (str): Device to run the model on (e.g. 'cpu' or 'cuda').
        """
        self.model = None
        self.confidence_threshold = confidence_threshold
        self.iou_threshold = iou_threshold
        self.slicing_overlap = slicing_overlap
        self.categories = categories
        # The mapping handed to SAHI is keyed by the class id as a string.
        self.category_mapping = {str(i): category for i, category in enumerate(categories)}
        self.device = device

    def load_onnx_model(self, path):
        """
        Load the ONNX model through SAHI's AutoDetectionModel.

        Args:
            path (str): Path of the ONNX model file.
        """
        self.model = AutoDetectionModel.from_pretrained(
            model_type='yolov8onnx',
            model_path=path,
            confidence_threshold=self.confidence_threshold,
            category_mapping=self.category_mapping,
            device=self.device,
        )

    def __call__(self, frame) -> "sv.Detections":
        """
        Run sliced prediction on a frame and return a supervision.Detections object.

        Args:
            frame: Image passed unchanged to SAHI's ``get_sliced_prediction``.

        Returns:
            supervision.Detections: Boxes in xyxy format (float), confidences
            (float) and class ids (int). All arrays are empty, with the same
            dtypes, when nothing is detected.

        Raises:
            RuntimeError: If no model has been loaded yet.
        """
        if self.model is None:
            # Fail with an actionable message instead of an AttributeError on None.
            raise RuntimeError("Model is not loaded; call load_onnx_model() first.")

        # The slice size is taken from the third dimension of the model's first
        # input and used for both height and width.
        # NOTE(review): assumes a square, statically shaped input — TODO confirm.
        input_shape = self.model.model.get_inputs()[0].shape[2]
        result = get_sliced_prediction(
            frame,
            self.model,
            slice_height=input_shape,
            slice_width=input_shape,
            overlap_height_ratio=self.slicing_overlap,
            overlap_width_ratio=self.slicing_overlap,
            verbose=False,
        )

        boxes, confidences, class_ids = self._predictions_to_arrays(
            result.object_prediction_list
        )
        return sv.Detections(
            xyxy=boxes,
            confidence=confidences,
            class_id=class_ids,
        )

    @staticmethod
    def _predictions_to_arrays(predictions):
        """Convert a list of prediction objects into (boxes, confidences, class_ids) arrays."""
        # Explicit dtypes keep the empty and non-empty results consistent;
        # reshape(-1, 4) turns an empty list into the expected (0, 4) array.
        boxes = np.array(
            [prediction.bbox.to_xyxy() for prediction in predictions], dtype=float
        ).reshape(-1, 4)
        confidences = np.array(
            [prediction.score.value for prediction in predictions], dtype=float
        )
        class_ids = np.array(
            [prediction.category.id for prediction in predictions], dtype=int
        )
        return boxes, confidences, class_ids

    def get_category_mapping(self):
        """
        Return the category mapping keyed by integer class id.

        Returns:
            dict: ``{class_id (int): class_name}``.
        """
        return {int(k): v for k, v in self.category_mapping.items()}