File size: 4,389 Bytes
75fbee5
 
03c9511
e29184b
75fbee5
d71beb6
5ec100e
 
ce2b58f
487b7f6
bbc95d9
1044803
 
487b7f6
 
b0c7a24
 
 
 
 
cdbafa3
487b7f6
1044803
487b7f6
 
 
 
 
 
 
 
 
 
93d071e
8cfece3
5caf904
 
 
5ec100e
 
 
5caf904
8cfece3
5caf904
 
8cfece3
5caf904
 
8cfece3
178b1f7
ce2b58f
5caf904
b0c7a24
bbc95d9
5052f14
93d071e
5052f14
 
ce2b58f
 
 
 
 
 
 
 
 
 
5ec100e
55317e4
03c9511
 
 
 
 
 
 
 
 
 
 
5ec100e
 
e29184b
03c9511
e29184b
03c9511
e29184b
 
 
03c9511
 
 
 
 
 
 
 
 
 
 
 
 
5ec100e
03c9511
 
 
 
 
e29184b
03c9511
 
e29184b
03c9511
 
e29184b
 
03c9511
 
 
 
99328a4
03c9511
 
 
7374951
03c9511
 
e29184b
03c9511
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import os
import logging
import random
from PIL import Image, ImageDraw, ImageFont
from huggingface_hub import hf_hub_download

logger = logging.getLogger(__name__)

class ObjectDetector:
    """YOLO object detector whose weights are fetched from the Hugging Face Hub.

    The weights file is resolved (and downloaded if it is not already cached)
    when the detector is constructed; the model itself is only instantiated on
    the first call to :meth:`load_model` or :meth:`predict`.
    """

    # Normalized model key -> (Hub repo id, weights filename).
    # NOTE(review): the "yolov11b" repo/filename pair could not be confirmed
    # from this file alone -- verify it against the Hub before relying on it.
    _HF_WEIGHTS = {
        "yolov8n": ("ultralytics/yolov8", "yolov8n.pt"),
        "yolov8s": ("ultralytics/yolov8", "yolov8s.pt"),
        "yolov8l": ("ultralytics/yolov8", "yolov8l.pt"),
        "yolov11b": ("Ultralytics/YOLO11", "yolov11b.pt"),
    }

    def __init__(self, model_key="yolov8n", device="cpu"):
        """Resolve the weights file for ``model_key``.

        Args:
            model_key: Name of a supported model. Matching is case-insensitive
                and any ".pt" in the key is stripped.
            device: "cpu" or "cuda"; consulted when the model is loaded.

        Raises:
            ValueError: If ``model_key`` is not one of the supported keys.
        """
        self.device = device
        self.model = None  # created lazily by load_model()
        self.model_key = model_key.lower().replace(".pt", "")

        try:
            repo_id, filename = self._HF_WEIGHTS[self.model_key]
        except KeyError:
            raise ValueError(f"Unsupported model key: {self.model_key}") from None

        # cache_dir is relative, so it resolves against the current working
        # directory; an already cached file is reused (force_download=False).
        self.weights_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            cache_dir="models/detection/weights",
            force_download=False,
        )

    def load_model(self) -> "ObjectDetector":
        """Instantiate the YOLO model once and return ``self``.

        Subsequent calls are no-ops. When ``device`` is "cpu" this sets the
        process-wide ``CUDA_VISIBLE_DEVICES`` environment variable to "-1".
        """
        if self.model is not None:
            return self

        logger.info("Loading model from path: %s", self.weights_path)

        # Hide GPUs before torch is imported so the setting is in place no
        # matter when CUDA gets initialized.
        if self.device == "cpu":
            os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

        # Heavy dependencies are imported lazily so that importing this module
        # (or constructing the detector) does not require them.
        import torch
        from ultralytics import YOLO

        self.model = YOLO(self.weights_path)

        if self.device == "cuda":
            if torch.cuda.is_available():
                self.model.to("cuda")
            else:
                logger.warning(
                    "CUDA was requested but is not available; "
                    "keeping the model on its default device."
                )

        return self

    def predict(self, image: "Image.Image", conf_threshold=0.25) -> "list":
        """Run detection on ``image`` and return the detections as dicts.

        Args:
            image: Image to run the model on.
            conf_threshold: Minimum confidence, forwarded to the model as its
                ``conf`` inference argument.

        Returns:
            A list of dicts with keys "class_name" (str), "confidence" (float)
            and "bbox" (the box's ``xyxy`` coordinates as a list).
        """
        self.load_model()

        results = self.model(image, conf=conf_threshold)
        return [
            {
                "class_name": result.names[int(box.cls)],
                "confidence": float(box.conf),
                "bbox": box.xyxy[0].tolist(),
            }
            for result in results
            for box in result.boxes
        ]

    def draw(self, image: "Image.Image", detections, alpha=0.5) -> "Image.Image":
        """Draw thick, per-class-colored boxes and labels, blended onto ``image``.

        Args:
            image: Original image; it is not modified.
            detections: Iterable of dicts with "bbox" (x1, y1, x2, y2),
                "class_name" and "confidence" keys.
            alpha: Blend factor passed to ``Image.blend`` (0 returns the
                original image, 1 returns the fully drawn overlay).

        Returns:
            A new image with the overlay blended in.
        """
        box_thickness = 4
        label_pad = 3

        overlay = image.copy()
        draw = ImageDraw.Draw(overlay)

        # Prefer a TrueType font; fall back to PIL's built-in bitmap font when
        # the font file cannot be found or read.
        try:
            font = ImageFont.truetype("arial.ttf", 18)
        except OSError:
            font = ImageFont.load_default()

        class_colors = {}

        def get_color(cls):
            """Return a color for ``cls`` that is the same on every run."""
            if cls not in class_colors:
                # Seeding with the class name makes the color deterministic.
                rnd = random.Random(cls)
                class_colors[cls] = (
                    rnd.randint(100, 255),
                    rnd.randint(100, 255),
                    rnd.randint(100, 255),
                )
            return class_colors[cls]

        for det in detections:
            x1, y1, x2, y2 = det["bbox"]
            cls_name = det["class_name"]
            label = f"{cls_name} {det['confidence']:.2f}"
            color = get_color(cls_name)

            # Thicken the outline by drawing concentric rectangles growing
            # outwards from the detection box.
            for t in range(box_thickness):
                draw.rectangle(
                    (x1 - t, y1 - t, x2 + t, y2 + t),
                    outline=color,
                )

            left, top, right, bottom = draw.textbbox((x1, y1), label, font=font)
            tb_w = right - left
            tb_h = bottom - top

            # The label sits directly above the box; clamp it to the top edge
            # so it stays visible for boxes that touch the top of the image.
            label_top = max(y1 - tb_h, 0)
            draw.rectangle(
                (x1, label_top, x1 + tb_w + 2 * label_pad, label_top + tb_h),
                fill=color,
            )
            draw.text((x1 + label_pad, label_top), label, font=font, fill="black")

        return Image.blend(image, overlay, alpha)