DurgaDeepak commited on
Commit
ce2b58f
·
verified ·
1 Parent(s): 377c3c6

Update models/detection/detector.py

Browse files
Files changed (1) hide show
  1. models/detection/detector.py +86 -73
models/detection/detector.py CHANGED
@@ -1,73 +1,86 @@
1
- import os
2
- import numpy as np
3
- from PIL import Image, ImageDraw
4
- import logging
5
- from ultralytics import YOLO
6
- from utils.model_downloader import download_model_if_needed
7
-
8
- logger = logging.getLogger(__name__)
9
-
10
- class ObjectDetector:
11
- """
12
- Generalized Object Detection Wrapper for YOLOv5, YOLOv8, and future variants.
13
- """
14
-
15
- def __init__(self, model_key="yolov5n-seg", weights_dir="models/detection/weights", device="cpu"):
16
- """
17
- Initialize the Object Detection model.
18
-
19
- Args:
20
- model_key (str): Model identifier as defined in model_downloader.py.
21
- weights_dir (str): Directory to store/download model weights.
22
- device (str): Inference device ("cpu" or "cuda").
23
- """
24
- weights_path = os.path.join(weights_dir, f"{model_key}.pt")
25
- download_model_if_needed(model_key, weights_path)
26
-
27
- logger.info(f"Loading Object Detection model '{model_key}' from {weights_path}")
28
- self.device = device
29
- self.model = YOLO(weights_path)
30
-
31
- def predict(self, image: Image.Image):
32
- """
33
- Run object detection.
34
-
35
- Args:
36
- image (PIL.Image.Image): Input image.
37
-
38
- Returns:
39
- List[Dict]: List of detected objects with class name, confidence, and bbox.
40
- """
41
- logger.info("Running object detection")
42
- results = self.model(image)
43
- detections = []
44
- for r in results:
45
- for box in r.boxes:
46
- detections.append({
47
- "class_name": r.names[int(box.cls)],
48
- "confidence": float(box.conf),
49
- "bbox": box.xyxy[0].tolist()
50
- })
51
- logger.info(f"Detected {len(detections)} objects")
52
- return detections
53
-
54
- def draw(self, image: Image.Image, detections, alpha=0.5):
55
- """
56
- Draw bounding boxes on image.
57
-
58
- Args:
59
- image (PIL.Image.Image): Input image.
60
- detections (List[Dict]): Detection results.
61
- alpha (float): Blend strength.
62
-
63
- Returns:
64
- PIL.Image.Image: Image with bounding boxes drawn.
65
- """
66
- overlay = image.copy()
67
- draw = ImageDraw.Draw(overlay)
68
- for det in detections:
69
- bbox = det["bbox"]
70
- label = f'{det["class_name"]} {det["confidence"]:.2f}'
71
- draw.rectangle(bbox, outline="red", width=2)
72
- draw.text((bbox[0], bbox[1]), label, fill="red")
73
- return Image.blend(image, overlay, alpha)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ from PIL import Image, ImageDraw
4
+ import logging
5
+ from ultralytics import YOLO
6
+ from utils.model_downloader import download_model_if_needed
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+ class ObjectDetector:
11
+ """
12
+ Generalized Object Detection Wrapper for YOLOv5, YOLOv8, and future variants.
13
+ """
14
+
15
+ def __init__(self, model_key="yolov5n-seg", device="cpu"):
16
+ """
17
+ Initialize the Object Detection model.
18
+
19
+ Args:
20
+ model_key (str): Model identifier as defined in model_downloader.py.
21
+ weights_dir (str): Directory to store/download model weights.
22
+ device (str): Inference device ("cpu" or "cuda").
23
+ """
24
+ repo_map = {
25
+ "yolov5n": ("ultralytics/yolov5", "yolov5n.pt"),
26
+ "yolov8n": ("ultralytics/yolov8", "yolov8n.pt"),
27
+ # Add more if needed
28
+ }
29
+
30
+ if model_key not in repo_map:
31
+ raise ValueError(f"Unsupported model_key: {model_key}")
32
+
33
+ repo_id, filename = repo_map[model_key]
34
+
35
+ weights_path = hf_hub_download(
36
+ repo_id=repo_id,
37
+ filename=filename,
38
+ cache_dir="models/detection/weights"
39
+ )
40
+
41
+ self.device = device
42
+ self.model = YOLO(weights_path)
43
+
44
+ def predict(self, image: Image.Image):
45
+ """
46
+ Run object detection.
47
+
48
+ Args:
49
+ image (PIL.Image.Image): Input image.
50
+
51
+ Returns:
52
+ List[Dict]: List of detected objects with class name, confidence, and bbox.
53
+ """
54
+ logger.info("Running object detection")
55
+ results = self.model(image)
56
+ detections = []
57
+ for r in results:
58
+ for box in r.boxes:
59
+ detections.append({
60
+ "class_name": r.names[int(box.cls)],
61
+ "confidence": float(box.conf),
62
+ "bbox": box.xyxy[0].tolist()
63
+ })
64
+ logger.info(f"Detected {len(detections)} objects")
65
+ return detections
66
+
67
+ def draw(self, image: Image.Image, detections, alpha=0.5):
68
+ """
69
+ Draw bounding boxes on image.
70
+
71
+ Args:
72
+ image (PIL.Image.Image): Input image.
73
+ detections (List[Dict]): Detection results.
74
+ alpha (float): Blend strength.
75
+
76
+ Returns:
77
+ PIL.Image.Image: Image with bounding boxes drawn.
78
+ """
79
+ overlay = image.copy()
80
+ draw = ImageDraw.Draw(overlay)
81
+ for det in detections:
82
+ bbox = det["bbox"]
83
+ label = f'{det["class_name"]} {det["confidence"]:.2f}'
84
+ draw.rectangle(bbox, outline="red", width=2)
85
+ draw.text((bbox[0], bbox[1]), label, fill="red")
86
+ return Image.blend(image, overlay, alpha)