# Stray_Dogs / detection.py
# Author: mustafa2ak — commit 3c82458 ("Create detection.py"), 3.18 kB
import cv2
import numpy as np
import torch
from ultralytics import YOLO
from typing import List, Tuple, Optional
from dataclasses import dataclass
@dataclass()
class Detection:
    """
    Result record for a single detected dog.

    Attributes are documented inline below; only ``image_crop`` is optional.
    """

    # Box corners in the order x1, y1, x2, y2.
    bbox: List[float]
    # Detector score associated with this box.
    confidence: float
    # Image patch covering ``bbox``; stays None when no crop was attached.
    image_crop: Optional[np.ndarray] = None
class DogDetector:
    """
    Simplified YOLOv8 detector optimized for dogs.

    Uses a standard pretrained model - no custom training needed. Inference is
    restricted to a single class id (``dog_class_id``), so only dog boxes are
    returned.
    """

    def __init__(self,
                 confidence_threshold: float = 0.45,
                 device: str = 'cuda',
                 model_path: str = 'yolov8m.pt'):
        """
        Initialize detector.

        Args:
            confidence_threshold: Min confidence for detections (0.45 works well).
            device: 'cuda' for GPU, 'cpu' for CPU. Whenever CUDA is not
                available the detector falls back to 'cpu', whatever was
                requested.
            model_path: Weights passed to ``YOLO(...)``. Defaults to the YOLOv8
                medium model ('yolov8m.pt'), a good balance of speed/accuracy.
        """
        self.confidence_threshold = confidence_threshold
        self.device = device if torch.cuda.is_available() else 'cpu'
        self.model = YOLO(model_path)
        self.model.to(self.device)
        # NOTE(review): Ultralytics' predictor may select its own device at
        # predict time (its `device` predict argument) rather than honouring
        # `.to()`; confirm inference actually runs on `self.device`.
        # COCO class ID for dog
        self.dog_class_id = 16

    def detect(self, frame: np.ndarray) -> List[Detection]:
        """
        Detect dogs in a frame.

        Args:
            frame: BGR image from OpenCV.

        Returns:
            List of Detection objects. Each ``bbox`` is [x1, y1, x2, y2] as
            integer pixel coordinates (truncated with ``int``) clipped to the
            frame; boxes with zero width or height after clipping are skipped.
            ``image_crop`` is an independent copy of the boxed region.
            An empty list is returned for a None or empty frame.
        """
        # Nothing to detect in a missing/empty frame (e.g. end of a video stream).
        if frame is None or frame.size == 0:
            return []

        # Run YOLO inference, restricted to the dog class.
        results = self.model(frame,
                             conf=self.confidence_threshold,
                             classes=[self.dog_class_id],
                             verbose=False)
        if not results:
            return []
        boxes = results[0].boxes
        if boxes is None or len(boxes) == 0:
            return []

        # Frame size is loop-invariant; read it once.
        h, w = frame.shape[:2]
        # One device->host transfer for all boxes instead of one per box.
        corners = boxes.xyxy.cpu().numpy()
        scores = boxes.conf.cpu().numpy()

        detections = []
        for (bx1, by1, bx2, by2), score in zip(corners, scores):
            # Convert to int pixels, then clip to the frame bounds.
            x1 = max(0, int(bx1))
            y1 = max(0, int(by1))
            x2 = min(w, int(bx2))
            y2 = min(h, int(by2))

            # Skip boxes that are empty after clipping.
            if x2 <= x1 or y2 <= y1:
                continue

            detections.append(Detection(
                bbox=[x1, y1, x2, y2],
                confidence=float(score),
                # Copy so the crop does not alias (and keep alive) the frame.
                image_crop=frame[y1:y2, x1:x2].copy(),
            ))
        return detections

    def set_confidence(self, threshold: float):
        """Update the detection confidence threshold, clamped to [0.1, 1.0]."""
        self.confidence_threshold = max(0.1, min(1.0, threshold))