from dataclasses import dataclass
from typing import List, Tuple, Optional

import cv2
import numpy as np
import torch
from ultralytics import YOLO


@dataclass
class Detection:
    """Simple detection data structure for one detected dog."""
    # [x1, y1, x2, y2] pixel coordinates; DogDetector.detect stores ints
    # clipped to the frame, usable directly as frame[y1:y2, x1:x2].
    bbox: List[float]
    # Model confidence score for this box.
    confidence: float
    # Cropped dog image (a copy of the frame region), if available.
    image_crop: Optional[np.ndarray] = None


class DogDetector:
    """
    Simplified YOLOv8 detector restricted to dogs.

    Uses the standard pretrained ``yolov8m.pt`` weights and limits inference
    to the COCO "dog" class - no custom training needed.
    """

    # Bounds applied to every confidence threshold, whether it comes from the
    # constructor or from set_confidence().
    _MIN_CONFIDENCE = 0.1
    _MAX_CONFIDENCE = 1.0

    def __init__(self, confidence_threshold: float = 0.45, device: str = 'cuda'):
        """
        Initialize detector.

        Args:
            confidence_threshold: Min confidence for detections (0.45 works
                well). Clamped to [0.1, 1.0], the same range that
                ``set_confidence`` enforces.
            device: Torch device string, e.g. 'cuda', 'cuda:1' or 'cpu'.
                A CUDA device falls back to 'cpu' when CUDA is unavailable;
                any other device string is used as given.
        """
        self.confidence_threshold = self._clamp_confidence(confidence_threshold)
        self.device = self._resolve_device(device)

        # Load YOLOv8 medium model (good balance of speed/accuracy)
        self.model = YOLO('yolov8m.pt')
        self.model.to(self.device)

        # COCO class ID for dog
        self.dog_class_id = 16

    def detect(self, frame: np.ndarray) -> List[Detection]:
        """
        Detect dogs in a frame.

        Args:
            frame: BGR image from OpenCV. ``None`` or an empty array (e.g. the
                result of a failed ``cv2.VideoCapture.read``) yields no
                detections instead of being passed to the model.

        Returns:
            List of Detection objects, one per dog box that is non-empty after
            clipping to the frame, each carrying its integer bbox, confidence
            and a copied crop of the frame.
        """
        # Never hand a missing/empty image to the model.
        if frame is None or frame.size == 0:
            return []

        # Run YOLO inference restricted to the dog class.
        results = self.model(frame,
                             conf=self.confidence_threshold,
                             classes=[self.dog_class_id],
                             verbose=False)
        if not results:
            return []

        boxes = results[0].boxes
        if boxes is None:
            return []

        # Frame size is loop-invariant: read it once for clipping.
        h, w = frame.shape[:2]
        # One host transfer for all boxes instead of one per box.
        coords = boxes.xyxy.cpu().numpy()
        scores = boxes.conf.cpu().numpy()

        detections = []
        for (bx1, by1, bx2, by2), score in zip(coords, scores):
            # int() truncates the float coordinates; then clip to the frame.
            x1 = max(0, int(bx1))
            y1 = max(0, int(by1))
            x2 = min(w, int(bx2))
            y2 = min(h, int(by2))

            # Skip boxes that are empty after clipping.
            if x2 <= x1 or y2 <= y1:
                continue

            detections.append(Detection(
                bbox=[x1, y1, x2, y2],
                confidence=float(score),
                # Copy so the crop does not alias the caller's frame buffer.
                image_crop=frame[y1:y2, x1:x2].copy(),
            ))
        return detections

    def set_confidence(self, threshold: float):
        """Update detection confidence threshold, clamped to [0.1, 1.0]."""
        self.confidence_threshold = self._clamp_confidence(threshold)

    @classmethod
    def _clamp_confidence(cls, threshold: float) -> float:
        """Clamp *threshold* into [_MIN_CONFIDENCE, _MAX_CONFIDENCE]."""
        return max(cls._MIN_CONFIDENCE, min(cls._MAX_CONFIDENCE, threshold))

    @staticmethod
    def _resolve_device(device: str) -> str:
        """Return *device*, or 'cpu' if it names CUDA but CUDA is unavailable."""
        if device.startswith('cuda') and not torch.cuda.is_available():
            return 'cpu'
        return device