UVIS / models /detection /detector.py
DurgaDeepak's picture
Update models/detection/detector.py
455843f verified
raw
history blame
3.07 kB
import logging
from PIL import Image, ImageDraw
from huggingface_hub import hf_hub_download
from ultralytics import YOLO
import os
import torch
# Module-level logger used for the detector's progress messages.
logger = logging.getLogger(__name__)
# NOTE(review): basicConfig runs at import time and configures the root logger
# (INFO level, timestamped format) when it has no handlers yet. Applications
# that set up logging themselves may prefer to own this call — confirm.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
class ObjectDetector:
    """YOLO object detector whose weights are fetched from the Hugging Face Hub.

    Constructing the detector only resolves the model key and downloads (or
    reuses the cached copy of) the weights file. The model itself is built
    lazily on the first call to ``load_model`` or ``predict``.
    """

    # Supported model keys -> (Hub repo id, weights filename).
    # NOTE(review): Ultralytics appears to publish YOLO11 weights as
    # "yolo11{n,s,m,l,x}.pt"; confirm that "yolov11b.pt" really exists in the
    # repo, otherwise selecting this key will fail at download time.
    _HF_WEIGHTS = {
        "yolov8n": ("ultralytics/yolov8", "yolov8n.pt"),
        "yolov8s": ("ultralytics/yolov8", "yolov8s.pt"),
        "yolov8l": ("ultralytics/yolov8", "yolov8l.pt"),
        "yolov11b": ("Ultralytics/YOLO11", "yolov11b.pt"),
    }

    # Directory passed to hf_hub_download as its cache location.
    _WEIGHTS_CACHE_DIR = "models/detection/weights"

    def __init__(self, model_key="yolov8n.pt", device="cpu"):
        """Resolve the model key and make sure its weights file is available.

        The YOLO model is not instantiated here; see ``load_model``.

        Args:
            model_key (str): e.g. 'yolov8n.pt', 'yolov8s.pt', etc. Matching is
                case-insensitive and the '.pt' part is optional.
            device (str): 'cpu' or 'cuda'.

        Raises:
            ValueError: If ``model_key`` is not one of the supported models.
        """
        self.device = device
        self.model = None  # built lazily by load_model()

        resolved_key = model_key.lower().replace(".pt", "")
        if resolved_key not in self._HF_WEIGHTS:
            raise ValueError(f"Unsupported model key: {resolved_key}")

        repo_id, filename = self._HF_WEIGHTS[resolved_key]
        self.weights_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            cache_dir=self._WEIGHTS_CACHE_DIR,
            force_download=False,
        )
        logger.info(
            "✅ YOLO weights ready for %s at %s", resolved_key, self.weights_path
        )

    def load_model(self):
        """Instantiate the YOLO model on first use.

        Subsequent calls are no-ops. When ``device`` is 'cuda' and CUDA is
        available the model is moved to the GPU; if CUDA was requested but is
        unavailable a warning is logged and the model is not moved.

        Returns:
            ObjectDetector: ``self``, so calls can be chained.
        """
        if self.model is not None:
            return self

        logger.info("⚙️ Loading YOLO model into memory (runtime-safe)")
        self.model = YOLO(self.weights_path)

        # Report the device the model was actually placed on, not merely the
        # one that was requested.
        active_device = "cpu"
        if self.device == "cuda":
            if torch.cuda.is_available():
                self.model.to("cuda")
                active_device = "cuda"
            else:
                logger.warning(
                    "CUDA requested but not available; YOLO model was not moved to GPU"
                )
        logger.info("✅ YOLO model initialized on %s", active_device)
        return self

    def predict(self, image: Image.Image, conf_threshold=0.25):
        """Run object detection on a single image.

        Args:
            image (PIL.Image.Image): Image to analyse.
            conf_threshold (float): Minimum confidence for a detection to be
                kept; forwarded to the model as its ``conf`` argument.

        Returns:
            list[dict]: One dict per detected box with keys ``class_name``
            (str), ``confidence`` (float) and ``bbox`` (list taken from the
            box's ``xyxy`` coordinates).
        """
        self.load_model()
        logger.info("🔍 Running object detection")

        # Forward the threshold; previously it was accepted but never applied.
        results = self.model(image, conf=conf_threshold)

        detections = []
        for result in results:
            names = result.names
            for box in result.boxes:
                detections.append({
                    "class_name": names[int(box.cls)],
                    "confidence": float(box.conf),
                    "bbox": box.xyxy[0].tolist(),
                })

        logger.info("✅ Detected %d objects", len(detections))
        return detections

    def draw(self, image: Image.Image, detections, alpha=0.5):
        """Return a copy of ``image`` with the detections drawn on it.

        Boxes and labels are painted in red onto a copy of the image, which is
        then blended with the untouched original.

        Args:
            image (PIL.Image.Image): Image the detections refer to; it is not
                modified.
            detections: Iterable of dicts as returned by ``predict``
                (``class_name``, ``confidence``, ``bbox``).
            alpha (float): Blend factor passed to ``Image.blend``; 0.0 yields
                the original image, 1.0 the fully annotated one.

        Returns:
            PIL.Image.Image: The blended, annotated image.
        """
        overlay = image.copy()
        canvas = ImageDraw.Draw(overlay)
        for det in detections:
            bbox = det["bbox"]
            label = f'{det["class_name"]} {det["confidence"]:.2f}'
            canvas.rectangle(bbox, outline="red", width=2)
            # Anchor the label at the first two bbox coordinates.
            canvas.text((bbox[0], bbox[1]), label, fill="red")
        return Image.blend(image, overlay, alpha)