# eadali's picture
# Test yolo onnx model
# 5f1587b
import numpy as np
import supervision as sv
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
class YOLO:
    """YOLO ONNX detector wrapper that runs SAHI sliced prediction.

    Typical usage: construct the wrapper, call ``load_onnx_model(path)``,
    then call the instance on a frame to obtain a ``supervision.Detections``.
    """

    def __init__(self, confidence_threshold, iou_threshold, slicing_overlap, categories, device):
        """
        Store the detector configuration; the model itself is loaded later.

        Args:
            confidence_threshold (float): Minimum confidence for detections.
            iou_threshold (float): IoU threshold for NMS (stored, but not
                used directly by this class).
            slicing_overlap (float): Overlap ratio between neighbouring
                slices, applied to both height and width.
            categories (list): List of class names; the list index is the
                class id.
            device (str): Device to run the model on ('cpu' or 'cuda').
        """
        # Populated by load_onnx_model(); None means "not loaded yet".
        self.model = None
        self.confidence_threshold = confidence_threshold
        self.iou_threshold = iou_threshold
        self.slicing_overlap = slicing_overlap
        self.categories = categories
        # Keys are the stringified class indices ("0", "1", ...); this is the
        # form handed to AutoDetectionModel.from_pretrained below.
        self.category_mapping = {str(i): category for i, category in enumerate(categories)}
        self.device = device

    def load_onnx_model(self, path):
        """
        Load the ONNX model using SAHI's AutoDetectionModel.

        Args:
            path (str): Path to the ONNX model file.
        """
        self.model = AutoDetectionModel.from_pretrained(
            model_type='yolov8onnx',
            model_path=path,
            confidence_threshold=self.confidence_threshold,
            category_mapping=self.category_mapping,
            device=self.device,
        )

    def __call__(self, frame):
        """
        Run sliced prediction on the input frame.

        Args:
            frame: Image passed straight through to SAHI's
                ``get_sliced_prediction``.

        Returns:
            supervision.Detections: Boxes (xyxy), confidences and class ids
            of all detections; empty arrays when nothing was detected.

        Raises:
            RuntimeError: If ``load_onnx_model`` has not been called yet.
        """
        if self.model is None:
            raise RuntimeError("Model is not loaded; call load_onnx_model() first.")

        # Slice size follows the model's input tensor. Assumes an NCHW layout
        # (index 2 = height, index 3 = width) with static dimensions.
        input_dims = self.model.model.get_inputs()[0].shape
        slice_height = input_dims[2]
        slice_width = input_dims[3]

        result = get_sliced_prediction(
            frame,
            self.model,
            slice_height=slice_height,
            slice_width=slice_width,
            overlap_height_ratio=self.slicing_overlap,
            overlap_width_ratio=self.slicing_overlap,
            verbose=False,
        )

        boxes, confidences, class_ids = self._predictions_to_arrays(
            result.object_prediction_list
        )
        return sv.Detections(
            xyxy=boxes,
            confidence=confidences,
            class_id=class_ids,
        )

    @staticmethod
    def _predictions_to_arrays(predictions):
        """Convert SAHI object predictions into (xyxy, confidence, class_id) arrays."""
        predictions = list(predictions)
        if not predictions:
            # Keep the (N, 4) / (N,) shapes and an integer class-id dtype so the
            # empty result matches the layout of a non-empty one.
            return np.zeros((0, 4)), np.zeros((0,)), np.zeros((0,), dtype=int)

        boxes = np.array([det.bbox.to_xyxy() for det in predictions])
        confidences = np.array([det.score.value for det in predictions])
        class_ids = np.array([det.category.id for det in predictions])
        return boxes, confidences, class_ids

    def get_category_mapping(self):
        """
        Return the category mapping with integer class ids as keys.

        Returns:
            dict: ``{class_id (int): class_name}``.
        """
        # The stored mapping uses string keys; convert them to integers.
        return {int(k): v for k, v in self.category_mapping.items()}