File size: 1,978 Bytes
ce2b58f
5567efd
9538100
ce2b58f
b0c7a24
9538100
8386bf1
b0c7a24
ce2b58f
bbc95d9
 
1044803
 
 
b0c7a24
 
 
 
 
cdbafa3
bbc95d9
1044803
 
 
 
 
 
 
ce2b58f
b0c7a24
bbc95d9
ce2b58f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import logging
from PIL import Image, ImageDraw
from huggingface_hub import hf_hub_download
from ultralytics import YOLO
import os

logger = logging.getLogger(__name__)

class ObjectDetector:
    """Lazy-loading wrapper around an Ultralytics YOLO detector.

    Weights are fetched from the Hugging Face Hub on first use (see
    ``load_model``), so constructing an instance performs no download.
    """

    def __init__(self, model_key="yolov8n.pt", device="cpu"):
        """Configure the detector without loading any weights.

        Args:
            model_key: Model identifier. It is lower-cased and ".pt" is
                stripped (e.g. "YOLOv8n.pt" -> "yolov8n"). It must match a key
                of ``self.repo_map`` by the time ``load_model`` runs.
            device: Device string forwarded to the model on every ``predict``
                call (e.g. "cpu" or "cuda:0").
        """
        self.device = device
        self.model = None  # populated lazily by load_model()
        self.model_key = model_key.lower().replace(".pt", "")
        # model_key -> (Hugging Face repo id, weights filename inside that repo)
        self.repo_map = {
            "yolov8n": ("ultralytics/yolov8", "yolov8n.pt"),
            "yolov8s": ("ultralytics/yolov8", "yolov8s.pt"),
            "yolov8l": ("ultralytics/yolov8", "yolov8l.pt"),
            # NOTE(review): Ultralytics publishes YOLO11 weights as
            # "yolo11{n,s,m,l,x}.pt"; confirm "yolov11b.pt" exists in this repo.
            "yolov11b": ("Ultralytics/YOLO11", "yolov11b.pt"),
        }

    def load_model(self) -> None:
        """Download (or reuse cached) weights and build the YOLO model once.

        Calls after the first successful load are no-ops.

        Raises:
            ValueError: if ``self.model_key`` is not a key of ``self.repo_map``.
        """
        if self.model is not None:
            return
        if self.model_key not in self.repo_map:
            raise ValueError(f"Unsupported model key: {self.model_key}")
        repo_id, filename = self.repo_map[self.model_key]
        logger.info("Loading detector weights %s from %s", filename, repo_id)
        # cache_dir is relative, so it resolves against the process working directory.
        weights_path = hf_hub_download(
            repo_id=repo_id, filename=filename, cache_dir="models/detection/weights"
        )
        # ZeroGPU-safe: the model is built at runtime, never at import/construction.
        self.model = YOLO(weights_path)

    def predict(self, image: Image.Image, conf_threshold=0.25) -> list:
        """Run detection on *image* and return plain-Python detections.

        Args:
            image: Input image handed directly to the YOLO model.
            conf_threshold: Minimum confidence a detection needs to be returned.

        Returns:
            A list of dicts with keys "class_name" (str), "confidence" (float)
            and "bbox" (the first row of ``box.xyxy`` as a list: x1, y1, x2, y2).
        """
        self.load_model()
        # Forward the threshold so the model itself filters with it (this lets
        # values below the model's default take effect) and honour the device
        # chosen in __init__, which was previously ignored.
        results = self.model(image, conf=conf_threshold, device=self.device)
        detections = []
        for result in results:
            for box in result.boxes:
                confidence = float(box.conf)
                # Re-check locally so the documented contract holds regardless
                # of how the backend applied the threshold.
                if confidence < conf_threshold:
                    continue
                detections.append({
                    "class_name": result.names[int(box.cls)],
                    "confidence": confidence,
                    "bbox": box.xyxy[0].tolist(),
                })
        return detections

    def draw(self, image: Image.Image, detections, alpha=0.5) -> Image.Image:
        """Return a new image with *detections* drawn as red labelled boxes.

        Args:
            image: Source image; it is not modified.
            detections: Iterable of dicts shaped like ``predict``'s output
                ("bbox", "class_name", "confidence").
            alpha: Blend weight of the annotated copy; 0 returns the original
                pixels, 1 returns fully opaque annotations.

        Returns:
            A new image of the same size as *image*.
        """
        # Image.blend rejects palette ("P"/"PA"), bilevel ("1") and non-8-bit
        # ("I*", "F") images, so normalise those to RGB before drawing.
        if image.mode in ("1", "P", "PA") or image.mode.startswith(("I", "F")):
            image = image.convert("RGB")
        overlay = image.copy()
        pen = ImageDraw.Draw(overlay)
        for det in detections:
            bbox = det["bbox"]
            label = f'{det["class_name"]} {det["confidence"]:.2f}'
            pen.rectangle(bbox, outline="red", width=2)
            pen.text((bbox[0], bbox[1]), label, fill="red")
        return Image.blend(image, overlay, alpha)