DurgaDeepak commited on
Commit
1044803
Β·
verified Β·
1 Parent(s): 455843f

Update models/detection/detector.py

Browse files
Files changed (1) hide show
  1. models/detection/detector.py +10 -43
models/detection/detector.py CHANGED
@@ -3,64 +3,32 @@ from PIL import Image, ImageDraw
3
  from huggingface_hub import hf_hub_download
4
  from ultralytics import YOLO
5
  import os
6
- import torch
7
 
8
- # Setup logger
9
  logger = logging.getLogger(__name__)
10
- logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
11
 
12
  class ObjectDetector:
13
  def __init__(self, model_key="yolov8n.pt", device="cpu"):
14
- """
15
- Initializes an Ultralytics YOLO model path, defers actual model loading.
16
-
17
- Args:
18
- model_key (str): e.g. 'yolov8n.pt', 'yolov8s.pt', etc.
19
- device (str): 'cpu' or 'cuda'
20
- """
21
  self.device = device
22
- resolved_key = model_key.lower().replace(".pt", "")
23
- alias_map = {
24
- "yolov8n": "yolov8n",
25
- "yolov8s": "yolov8s",
26
- "yolov8l": "yolov8l",
27
- "yolov11b": "yolov11b"
28
- }
29
-
30
- hf_map = {
31
  "yolov8n": ("ultralytics/yolov8", "yolov8n.pt"),
32
  "yolov8s": ("ultralytics/yolov8", "yolov8s.pt"),
33
  "yolov8l": ("ultralytics/yolov8", "yolov8l.pt"),
34
  "yolov11b": ("Ultralytics/YOLO11", "yolov11b.pt"),
35
  }
36
 
37
- resolved_key = alias_map.get(resolved_key, resolved_key)
38
- if resolved_key not in hf_map:
39
- raise ValueError(f"Unsupported model key: {resolved_key}")
40
-
41
- repo_id, filename = hf_map[resolved_key]
42
- self.weights_path = hf_hub_download(
43
- repo_id=repo_id,
44
- filename=filename,
45
- cache_dir="models/detection/weights",
46
- force_download=False
47
- )
48
-
49
- logger.info(f"βœ… YOLO weights ready for {resolved_key} at {self.weights_path}")
50
- self.model = None # defer loading
51
-
52
  def load_model(self):
53
- if self.model is None:
54
- logger.info("βš™οΈ Loading YOLO model into memory (runtime-safe)")
55
- self.model = YOLO(self.weights_path)
56
- if self.device == "cuda" and torch.cuda.is_available():
57
- self.model.to("cuda")
58
- logger.info(f"βœ… YOLO model initialized on {self.device}")
59
- return self
60
 
61
  def predict(self, image: Image.Image, conf_threshold=0.25):
62
  self.load_model()
63
- logger.info("πŸ” Running object detection")
64
  results = self.model(image)
65
  detections = []
66
  for r in results:
@@ -70,7 +38,6 @@ class ObjectDetector:
70
  "confidence": float(box.conf),
71
  "bbox": box.xyxy[0].tolist()
72
  })
73
- logger.info(f"βœ… Detected {len(detections)} objects")
74
  return detections
75
 
76
  def draw(self, image: Image.Image, detections, alpha=0.5):
 
3
  from huggingface_hub import hf_hub_download
4
  from ultralytics import YOLO
5
  import os
 
6
 
 
7
  logger = logging.getLogger(__name__)
 
8
 
9
  class ObjectDetector:
10
  def __init__(self, model_key="yolov8n.pt", device="cpu"):
 
 
 
 
 
 
 
11
  self.device = device
12
+ self.model = None
13
+ self.model_key = model_key.lower().replace(".pt", "")
14
+ self.repo_map = {
 
 
 
 
 
 
15
  "yolov8n": ("ultralytics/yolov8", "yolov8n.pt"),
16
  "yolov8s": ("ultralytics/yolov8", "yolov8s.pt"),
17
  "yolov8l": ("ultralytics/yolov8", "yolov8l.pt"),
18
  "yolov11b": ("Ultralytics/YOLO11", "yolov11b.pt"),
19
  }
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  def load_model(self):
22
+ if self.model is not None:
23
+ return
24
+ if self.model_key not in self.repo_map:
25
+ raise ValueError(f"Unsupported model key: {self.model_key}")
26
+ repo_id, filename = self.repo_map[self.model_key]
27
+ weights_path = hf_hub_download(repo_id=repo_id, filename=filename, cache_dir="models/detection/weights")
28
+ self.model = YOLO(weights_path) # βœ… ZeroGPU-safe: runtime only
29
 
30
  def predict(self, image: Image.Image, conf_threshold=0.25):
31
  self.load_model()
 
32
  results = self.model(image)
33
  detections = []
34
  for r in results:
 
38
  "confidence": float(box.conf),
39
  "bbox": box.xyxy[0].tolist()
40
  })
 
41
  return detections
42
 
43
  def draw(self, image: Image.Image, detections, alpha=0.5):