File size: 2,891 Bytes
5f1587b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import numpy as np
import supervision as sv
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction

class YOLO:
    """YOLOv8 ONNX detector wrapper that runs SAHI sliced prediction.

    Typical use: construct the wrapper, call :meth:`load_onnx_model` once,
    then call the instance with a frame to obtain ``supervision.Detections``.
    """

    def __init__(self, confidence_threshold, iou_threshold, slicing_overlap, categories, device):
        """
        Store the detector configuration; the model itself is loaded later.

        Args:
            confidence_threshold (float): Minimum confidence for detections.
            iou_threshold (float): IoU threshold for NMS (stored, but not used
                by any method of this class).
            slicing_overlap (float): Overlap ratio for slicing, applied to both
                the height and the width of the slices.
            categories (list): List of class names; the position of a name in
                the list is its class id.
            device (str): Device to run the model on ('cpu' or 'cuda').
        """
        # Populated by load_onnx_model(); __call__ refuses to run while None.
        self.model = None
        self.confidence_threshold = confidence_threshold
        self.iou_threshold = iou_threshold
        self.slicing_overlap = slicing_overlap
        self.categories = categories
        # String keys ("0", "1", ...) are the form handed to SAHI as
        # ``category_mapping``; get_category_mapping() exposes integer keys.
        self.category_mapping = {str(i): category for i, category in enumerate(categories)}
        self.device = device

    def load_onnx_model(self, path):
        """
        Loads the ONNX model using SAHI's AutoDetectionModel.

        Args:
            path (str): Path to the ONNX model file.
        """
        self.model = AutoDetectionModel.from_pretrained(
            model_type='yolov8onnx',
            model_path=path,
            confidence_threshold=self.confidence_threshold,
            category_mapping=self.category_mapping,
            device=self.device
        )

    def __call__(self, frame):
        """
        Runs sliced prediction on the input frame and returns a supervision.Detections object.

        Args:
            frame: Image passed unchanged to SAHI's ``get_sliced_prediction``.

        Returns:
            supervision.Detections: Boxes in xyxy format with their confidences
            and class ids. When nothing is detected the arrays are empty but
            keep the same shapes/dtypes as a non-empty result.

        Raises:
            RuntimeError: If called before ``load_onnx_model``.
        """
        if self.model is None:
            raise RuntimeError("Model is not loaded; call load_onnx_model() first.")

        # Slice size follows the ONNX model's input size.
        # NOTE(review): assumes a static NCHW input, i.e. shape[2] is the
        # height and shape[3] the width — confirm for models exported with
        # dynamic axes, where these entries may not be integers.
        input_shape = self.model.model.get_inputs()[0].shape
        slice_height, slice_width = input_shape[2], input_shape[3]

        result = get_sliced_prediction(
            frame,
            self.model,
            slice_height=slice_height,
            slice_width=slice_width,
            overlap_height_ratio=self.slicing_overlap,
            overlap_width_ratio=self.slicing_overlap,
            verbose=False,
        )
        boxes, confidences, class_ids = self._predictions_to_arrays(
            result.object_prediction_list
        )
        return sv.Detections(
            xyxy=boxes,
            confidence=confidences,
            class_id=class_ids,
        )

    @staticmethod
    def _predictions_to_arrays(predictions):
        """
        Converts SAHI object predictions into the arrays sv.Detections expects.

        Args:
            predictions: Iterable of objects exposing ``bbox.to_xyxy()``,
                ``score.value`` and ``category.id``.

        Returns:
            tuple: ``(boxes, confidences, class_ids)`` with shapes ``(n, 4)``,
            ``(n,)`` and ``(n,)``; boxes and confidences are float arrays and
            class ids are an integer array, including when ``n == 0``.
        """
        predictions = list(predictions)
        # Explicit dtypes plus reshape keep the empty case consistent with the
        # non-empty one (an untyped empty array would default to float ids).
        boxes = np.array(
            [pred.bbox.to_xyxy() for pred in predictions], dtype=float
        ).reshape(-1, 4)
        confidences = np.array([pred.score.value for pred in predictions], dtype=float)
        class_ids = np.array([pred.category.id for pred in predictions], dtype=int)
        return boxes, confidences, class_ids

    def get_category_mapping(self):
        """
        Returns the category mapping with integer class ids as keys.

        Returns:
            dict: ``{class_id (int): class_name}`` built from the string-keyed
            ``self.category_mapping``.
        """
        # Convert string keys to integers
        return {int(k): v for k, v in self.category_mapping.items()}