"""
reid.py - Improved Single-Model Dog Re-Identification System
Enhanced with better feature matching and temporal consistency
"""
import numpy as np
import cv2
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from sklearn.metrics.pairwise import cosine_similarity
from typing import Dict, List, Optional, Tuple
import time
from dataclasses import dataclass
import warnings
warnings.filterwarnings('ignore')
@dataclass
class DogFeatures:
"""Container for dog features"""
    features: np.ndarray            # L2-normalized ResNet50 embedding
    confidence: float = 0.0         # detector confidence for the source crop
    quality_score: float = 0.5      # reserved for quality weighting (not used elsewhere in this module)
    frame_num: int = 0              # matching-call counter at the time the feature was stored
class SingleModelReID:
"""Improved ReID with better matching strategies"""
def __init__(self, device: str = 'cuda'):
self.device = device if torch.cuda.is_available() else 'cpu'
# Adaptive thresholds
self.primary_threshold = 0.45 # Main matching threshold
self.secondary_threshold = 0.35 # Lower threshold for recent tracks
self.new_dog_threshold = 0.55 # Higher threshold to create new dog
# In-memory dog database
self.dog_database = {} # dog_id -> list of features
self.next_dog_id = 1
# Track to dog mapping with confidence history
self.track_to_dog = {}
self.track_confidence = {} # track_id -> list of confidences
# Temporal consistency
self.recent_matches = {} # dog_id -> last_frame_seen
self.current_frame = 0
try:
# Initialize ResNet50
self.model = models.resnet50(weights='IMAGENET1K_V1')
self.model = nn.Sequential(*list(self.model.children())[:-1])
self.model.to(self.device).eval()
# Enhanced preprocessing with augmentation options
self.transform = transforms.Compose([
transforms.ToPILImage(),
transforms.Resize((256, 256)), # Slightly larger for better features
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
])
print("Enhanced ResNet50 ReID initialized")
except Exception as e:
print(f"ResNet50 init error: {e}")
self.model = None
def extract_features(self, image: np.ndarray) -> Optional[np.ndarray]:
"""Extract ResNet50 features with quality check"""
if self.model is None or image is None or image.size == 0:
return None
# Quality check - skip blurry images
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
laplacian_var = cv2.Laplacian(gray, cv2.CV_64F).var()
if laplacian_var < 50: # Too blurry
return None
try:
img_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
img_tensor = self.transform(img_rgb).unsqueeze(0).to(self.device)
with torch.no_grad():
features = self.model(img_tensor)
features = features.squeeze().cpu().numpy()
# L2 normalize
features = features / (np.linalg.norm(features) + 1e-7)
return features
        except Exception:
            # Feature extraction failed (e.g., malformed crop or transform error)
            return None
def match_or_register(self, track) -> Tuple[int, float]:
"""Enhanced matching with temporal consistency"""
self.current_frame += 1
# Get latest good quality detection
detection = None
for det in reversed(track.detections[-5:]): # Check last 5 detections
if det.image_crop is not None:
detection = det
break
if detection is None:
return 0, 0.0
# Check track confidence history
if track.track_id not in self.track_confidence:
self.track_confidence[track.track_id] = []
# If track already has consistent dog ID
if track.track_id in self.track_to_dog:
existing_dog_id = self.track_to_dog[track.track_id]
# Update features periodically (not every frame to save memory)
if self.current_frame % 3 == 0:
features = self.extract_features(detection.image_crop)
if features is not None:
if existing_dog_id in self.dog_database:
# Add with frame number for temporal reference
self.dog_database[existing_dog_id].append(
DogFeatures(
features=features,
confidence=detection.confidence,
frame_num=self.current_frame
)
)
# Keep only recent and high-quality features
self._prune_features(existing_dog_id)
# Update recent matches
self.recent_matches[existing_dog_id] = self.current_frame
# Calculate running confidence
            recent_conf = (
                np.mean(self.track_confidence[track.track_id][-10:])
                if self.track_confidence[track.track_id]
                else detection.confidence
            )
return existing_dog_id, recent_conf
# Extract features for matching
features = self.extract_features(detection.image_crop)
if features is None:
return 0, 0.0
# Find best match with adaptive thresholds
best_dog_id = None
best_score = -1.0
match_details = {}
for dog_id, feature_list in self.dog_database.items():
if not feature_list:
continue
# Check temporal proximity (bonus for recently seen dogs)
recency_bonus = 0.0
if dog_id in self.recent_matches:
frames_since = self.current_frame - self.recent_matches[dog_id]
if frames_since < 30: # Within 1 second at 30fps
recency_bonus = 0.05 * (1 - frames_since / 30)
# Weighted similarity based on feature quality and recency
similarities = []
weights = []
for dog_feat in feature_list[-8:]: # Use last 8 features
sim = cosine_similarity(
features.reshape(1, -1),
dog_feat.features.reshape(1, -1)
)[0, 0]
# Weight by confidence and recency
weight = dog_feat.confidence
if hasattr(dog_feat, 'frame_num'):
age = self.current_frame - dog_feat.frame_num
weight *= np.exp(-age / 100) # Exponential decay
similarities.append(sim)
weights.append(weight)
# Weighted average
if weights:
weights = np.array(weights)
weights = weights / weights.sum()
avg_similarity = np.average(similarities, weights=weights) + recency_bonus
else:
avg_similarity = np.mean(similarities) + recency_bonus
match_details[dog_id] = avg_similarity
if avg_similarity > best_score:
best_score = avg_similarity
best_dog_id = dog_id
# Adaptive threshold based on context
threshold = self.primary_threshold
if best_dog_id and best_dog_id in self.recent_matches:
# Lower threshold for recently seen dogs
if self.current_frame - self.recent_matches[best_dog_id] < 60:
threshold = self.secondary_threshold
        # Decision logic
        if best_dog_id is not None and best_score >= threshold:
            # Accept directly for moderate scores or when only a few dogs are registered
            if best_score < self.new_dog_threshold or len(match_details) < 3:
                accept_match = True
            else:
                # Very high score with several candidates - check how close the runner-up is
                second_best_score = sorted(match_details.values(), reverse=True)[1] if len(match_details) > 1 else 0.0
                # Too similar to multiple dogs -> likely a new dog; otherwise accept the match
                accept_match = (best_score - second_best_score) >= 0.1
            if accept_match:
                self.dog_database[best_dog_id].append(
                    DogFeatures(
                        features=features,
                        confidence=detection.confidence,
                        frame_num=self.current_frame
                    )
                )
                self._prune_features(best_dog_id)
                self.track_to_dog[track.track_id] = best_dog_id
                self.track_confidence[track.track_id].append(best_score)
                self.recent_matches[best_dog_id] = self.current_frame
                return best_dog_id, best_score
            # Ambiguous high score: fall through and register a new dog
# Register new dog
new_dog_id = self.next_dog_id
self.next_dog_id += 1
self.dog_database[new_dog_id] = [
DogFeatures(
features=features,
confidence=detection.confidence,
frame_num=self.current_frame
)
]
self.track_to_dog[track.track_id] = new_dog_id
self.track_confidence[track.track_id] = [1.0]
self.recent_matches[new_dog_id] = self.current_frame
return new_dog_id, 1.0
def _prune_features(self, dog_id: int):
"""Keep only best recent features to save memory"""
if dog_id not in self.dog_database:
return
features = self.dog_database[dog_id]
if len(features) > 15:
# Sort by confidence and recency
features.sort(key=lambda x: x.confidence + (0.001 * x.frame_num), reverse=True)
# Keep top 10
self.dog_database[dog_id] = features[:10]
def match_or_register_all(self, track) -> Dict:
"""Compatible interface"""
dog_id, confidence = self.match_or_register(track)
return {
'ResNet50': {
'dog_id': dog_id,
'confidence': confidence,
'processing_time': 0
}
}
def reset_all(self):
"""Reset all temporary data"""
self.dog_database.clear()
self.track_to_dog.clear()
self.track_confidence.clear()
self.recent_matches.clear()
self.next_dog_id = 1
self.current_frame = 0
def set_all_thresholds(self, threshold: float):
"""Set similarity thresholds adaptively"""
self.primary_threshold = max(0.3, min(0.9, threshold))
self.secondary_threshold = max(0.25, self.primary_threshold - 0.1)
self.new_dog_threshold = min(0.9, self.primary_threshold + 0.1)
def get_statistics(self) -> Dict:
"""Get statistics"""
return {
'ResNet50': {
'total_dogs': self.next_dog_id - 1,
'dogs_in_database': len(self.dog_database),
'active_dogs': len([d for d, f in self.recent_matches.items()
if self.current_frame - f < 150]),
'avg_features_per_dog': np.mean([len(f) for f in self.dog_database.values()]) if self.dog_database else 0,
'threshold': self.primary_threshold
}
}
# Compatibility aliases
ImprovedResNet50ReID = SingleModelReID
DualModelReID = SingleModelReID
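
# Minimal usage sketch (not part of the production pipeline). It assumes an upstream
# detector/tracker that produces objects with the attributes this module reads:
# `track_id`, `detections`, and per-detection `image_crop` / `confidence`. The stub
# classes below are illustrative placeholders, not the project's real data model.
if __name__ == "__main__":
    from dataclasses import field

    @dataclass
    class _StubDetection:
        image_crop: np.ndarray
        confidence: float = 0.9

    @dataclass
    class _StubTrack:
        track_id: int
        detections: list = field(default_factory=list)

    reid = SingleModelReID(device='cuda')  # falls back to CPU automatically
    # Feed a synthetic crop; a real pipeline would pass detector crops frame by frame.
    crop = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
    track = _StubTrack(track_id=1, detections=[_StubDetection(image_crop=crop)])
    dog_id, confidence = reid.match_or_register(track)
    print(f"Assigned dog_id={dog_id} with confidence={confidence:.2f}")
    print(reid.get_statistics())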