DurgaDeepak commited on
Commit
bbc95d9
Β·
verified Β·
1 Parent(s): 9da95df

Update models/detection/detector.py

Browse files
Files changed (1) hide show
  1. models/detection/detector.py +21 -15
models/detection/detector.py CHANGED
@@ -4,6 +4,7 @@ from huggingface_hub import hf_hub_download
4
  from ultralytics import YOLO
5
  import os
6
  import shutil
 
7
 
8
  # Setup logger
9
  logger = logging.getLogger(__name__)
@@ -13,15 +14,16 @@ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(
13
  shutil.rmtree("models/detection/weights", ignore_errors=True)
14
 
15
  class ObjectDetector:
16
- def __init__(self, model_key="yolov8n", device="cpu"):
17
  """
18
  Initializes an Ultralytics YOLO model using HF download path.
19
 
20
  Args:
21
- model_key (str): e.g. 'yolov8n', 'yolov8s', etc.
22
  device (str): 'cpu' or 'cuda'
23
  """
24
- # Optional aliasing
 
25
  alias_map = {
26
  "yolov8n": "yolov8n",
27
  "yolov8s": "yolov8s",
@@ -29,9 +31,6 @@ class ObjectDetector:
29
  "yolov11b": "yolov11b"
30
  }
31
 
32
- resolved_key = model_key.lower().replace(".pt", "")
33
-
34
- # HF repo map
35
  hf_map = {
36
  "yolov8n": ("ultralytics/yolov8", "yolov8n.pt"),
37
  "yolov8s": ("ultralytics/yolov8", "yolov8s.pt"),
@@ -39,25 +38,32 @@ class ObjectDetector:
39
  "yolov11b": ("Ultralytics/YOLO11", "yolov11b.pt"),
40
  }
41
 
 
42
  if resolved_key not in hf_map:
43
  raise ValueError(f"Unsupported model key: {resolved_key}")
44
 
45
  repo_id, filename = hf_map[resolved_key]
46
-
47
- # πŸ“₯ Download from HF Hub
48
- weights_path = hf_hub_download(
49
  repo_id=repo_id,
50
  filename=filename,
51
  cache_dir="models/detection/weights",
52
- force_download=True # Optional: change to False for reuse
53
  )
54
 
55
- logger.info(f"βœ… Loaded YOLO model: {resolved_key} from {weights_path}")
56
- self.device = device
57
- self.model = YOLO(weights_path)
 
 
 
 
 
 
 
58
 
59
  def predict(self, image: Image.Image, conf_threshold=0.25):
60
- logger.info("Running object detection")
 
61
  results = self.model(image)
62
  detections = []
63
  for r in results:
@@ -67,7 +73,7 @@ class ObjectDetector:
67
  "confidence": float(box.conf),
68
  "bbox": box.xyxy[0].tolist()
69
  })
70
- logger.info(f"Detected {len(detections)} objects")
71
  return detections
72
 
73
  def draw(self, image: Image.Image, detections, alpha=0.5):
 
4
  from ultralytics import YOLO
5
  import os
6
  import shutil
7
+ import torch
8
 
9
  # Setup logger
10
  logger = logging.getLogger(__name__)
 
14
  shutil.rmtree("models/detection/weights", ignore_errors=True)
15
 
16
  class ObjectDetector:
17
+ def __init__(self, model_key="yolov8n.pt", device="cpu"):
18
  """
19
  Initializes an Ultralytics YOLO model using HF download path.
20
 
21
  Args:
22
+ model_key (str): e.g. 'yolov8n.pt', 'yolov8s.pt', etc.
23
  device (str): 'cpu' or 'cuda'
24
  """
25
+ self.device = device
26
+ resolved_key = model_key.lower().replace(".pt", "")
27
  alias_map = {
28
  "yolov8n": "yolov8n",
29
  "yolov8s": "yolov8s",
 
31
  "yolov11b": "yolov11b"
32
  }
33
 
 
 
 
34
  hf_map = {
35
  "yolov8n": ("ultralytics/yolov8", "yolov8n.pt"),
36
  "yolov8s": ("ultralytics/yolov8", "yolov8s.pt"),
 
38
  "yolov11b": ("Ultralytics/YOLO11", "yolov11b.pt"),
39
  }
40
 
41
+ resolved_key = alias_map.get(resolved_key, resolved_key)
42
  if resolved_key not in hf_map:
43
  raise ValueError(f"Unsupported model key: {resolved_key}")
44
 
45
  repo_id, filename = hf_map[resolved_key]
46
+ self.weights_path = hf_hub_download(
 
 
47
  repo_id=repo_id,
48
  filename=filename,
49
  cache_dir="models/detection/weights",
50
+ force_download=False
51
  )
52
 
53
+ self.model = None # πŸ” Don't initialize on construction
54
+ logger.info(f"Model path ready for {resolved_key}: {self.weights_path}")
55
+
56
+ def load_model(self):
57
+ if self.model is None:
58
+ logger.info("⏳ Loading YOLO model into memory...")
59
+ self.model = YOLO(self.weights_path)
60
+ if self.device == "cuda" and torch.cuda.is_available():
61
+ self.model.to("cuda")
62
+ logger.info(f"βœ… YOLO model loaded on {self.device}")
63
 
64
  def predict(self, image: Image.Image, conf_threshold=0.25):
65
+ self.load_model()
66
+ logger.info("πŸ” Running object detection")
67
  results = self.model(image)
68
  detections = []
69
  for r in results:
 
73
  "confidence": float(box.conf),
74
  "bbox": box.xyxy[0].tolist()
75
  })
76
+ logger.info(f"βœ… Detected {len(detections)} objects")
77
  return detections
78
 
79
  def draw(self, image: Image.Image, detections, alpha=0.5):