DurgaDeepak commited on
Commit
487b7f6
·
verified ·
1 Parent(s): a3c2125

Update models/detection/detector.py

Browse files
Files changed (1) hide show
  1. models/detection/detector.py +20 -27
models/detection/detector.py CHANGED
@@ -1,31 +1,34 @@
1
- import logging
2
- from PIL import Image, ImageDraw
3
- from huggingface_hub import hf_hub_download
4
- from ultralytics import YOLO
5
- import os
6
-
7
- logger = logging.getLogger(__name__)
8
-
9
  class ObjectDetector:
10
- def __init__(self, model_key="yolov8n.pt", device="cpu"):
11
  self.device = device
12
  self.model = None
13
  self.model_key = model_key.lower().replace(".pt", "")
14
- self.repo_map = {
 
15
  "yolov8n": ("ultralytics/yolov8", "yolov8n.pt"),
16
  "yolov8s": ("ultralytics/yolov8", "yolov8s.pt"),
17
  "yolov8l": ("ultralytics/yolov8", "yolov8l.pt"),
18
  "yolov11b": ("Ultralytics/YOLO11", "yolov11b.pt"),
19
  }
20
 
21
- def load_model(self):
22
- if self.model is not None:
23
- return
24
- if self.model_key not in self.repo_map:
25
  raise ValueError(f"Unsupported model key: {self.model_key}")
26
- repo_id, filename = self.repo_map[self.model_key]
27
- weights_path = hf_hub_download(repo_id=repo_id, filename=filename, cache_dir="models/detection/weights")
28
- self.model = YOLO(weights_path) # ✅ ZeroGPU-safe: runtime only
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  def predict(self, image: Image.Image, conf_threshold=0.25):
31
  self.load_model()
@@ -39,13 +42,3 @@ class ObjectDetector:
39
  "bbox": box.xyxy[0].tolist()
40
  })
41
  return detections
42
-
43
- def draw(self, image: Image.Image, detections, alpha=0.5):
44
- overlay = image.copy()
45
- draw = ImageDraw.Draw(overlay)
46
- for det in detections:
47
- bbox = det["bbox"]
48
- label = f'{det["class_name"]} {det["confidence"]:.2f}'
49
- draw.rectangle(bbox, outline="red", width=2)
50
- draw.text((bbox[0], bbox[1]), label, fill="red")
51
- return Image.blend(image, overlay, alpha)
 
 
 
 
 
 
 
 
 
1
  class ObjectDetector:
2
+ def __init__(self, model_key="yolov8n", device="cpu"):
3
  self.device = device
4
  self.model = None
5
  self.model_key = model_key.lower().replace(".pt", "")
6
+
7
+ hf_map = {
8
  "yolov8n": ("ultralytics/yolov8", "yolov8n.pt"),
9
  "yolov8s": ("ultralytics/yolov8", "yolov8s.pt"),
10
  "yolov8l": ("ultralytics/yolov8", "yolov8l.pt"),
11
  "yolov11b": ("Ultralytics/YOLO11", "yolov11b.pt"),
12
  }
13
 
14
+ if self.model_key not in hf_map:
 
 
 
15
  raise ValueError(f"Unsupported model key: {self.model_key}")
16
+
17
+ repo_id, filename = hf_map[self.model_key]
18
+ self.weights_path = hf_hub_download(
19
+ repo_id=repo_id,
20
+ filename=filename,
21
+ cache_dir="models/detection/weights",
22
+ force_download=False
23
+ )
24
+
25
+ def load_model(self):
26
+ if self.model is None:
27
+ from ultralytics import YOLO # Defer import
28
+ self.model = YOLO(self.weights_path)
29
+ if self.device == "cuda":
30
+ self.model.to("cuda")
31
+ return self # So you can chain
32
 
33
  def predict(self, image: Image.Image, conf_threshold=0.25):
34
  self.load_model()
 
42
  "bbox": box.xyxy[0].tolist()
43
  })
44
  return detections