# Stray_Dogs / reid.py
# Last commit e23353b by mustafa2ak: "Update reid.py" (verified)
"""
Simplified ReID - Basic threshold matching only
No adaptive thresholds, no quality scoring, no smart storage
"""
import numpy as np
import cv2
import torch
import timm
from sklearn.metrics.pairwise import cosine_similarity
from typing import Dict, Optional
from collections import defaultdict
from datetime import datetime
class SimplifiedReID:
    """Session-level re-identification using plain cosine-threshold matching.

    Each call to :meth:`match_or_register` embeds the most recent usable crop
    of a track with MegaDescriptor-L-384. It then compares that embedding
    against the stored feature history of every temporary ID created so far
    in the session:

    - If the best similarity reaches ``self.threshold``, the crop is assigned
      to that temporary ID.
    - Otherwise a new temporary ID is registered.

    Args:
        device: Preferred torch device string. Falls back to ``'cpu'`` when
            CUDA is not available.
        max_features_per_id: Maximum number of feature vectors kept per
            temporary ID. The oldest vectors are dropped first. Defaults to 30.
    """

    def __init__(self, device: str = 'cuda', max_features_per_id: int = 30):
        self.device = device if torch.cuda.is_available() else 'cpu'
        # Single similarity threshold used for every match decision.
        self.threshold = 0.40
        # Per-temp_id history cap; bounds memory and the per-call matching cost.
        self.max_features_per_id = max(1, int(max_features_per_id))
        # Session tracking: temp_id -> list of feature vectors from extract_features().
        self.temp_id_features = {}
        self.next_temp_id = 1
        # Incremented once per match_or_register() call (per query, not per video frame).
        self.current_frame = 0
        self.current_video_source = "unknown"
        # Load model + preprocessing; on failure self.model is None.
        self._initialize_model()
        print(f"Simplified ReID initialized on {self.device}")

    def _initialize_model(self):
        """Load MegaDescriptor-L-384 and its evaluation transform.

        Any exception is printed, and the object is left in a "no model" state
        (``self.model`` and ``self.transform`` are both ``None``). In that
        state :meth:`extract_features` returns ``None``.
        """
        try:
            self.model = timm.create_model(
                'hf-hub:BVRA/MegaDescriptor-L-384',
                pretrained=True
            )
            self.model.to(self.device).eval()
            # NOTE(review): timm's eval transform resizes the shorter side and
            # center-crops (default crop_pct), while the MegaDescriptor model
            # card example squashes the image to 384x384. Confirm which
            # preprocessing is intended for elongated dog crops.
            self.transform = timm.data.create_transform(
                input_size=(384, 384),
                is_training=False,
                mean=[0.5, 0.5, 0.5],
                std=[0.5, 0.5, 0.5]
            )
            print("MegaDescriptor-L-384 loaded")
        except Exception as e:
            print(f"Error loading model: {e}")
            self.model = None
            # Previously left undefined on failure; keep the attribute consistent.
            self.transform = None

    def set_threshold(self, threshold: float):
        """Set the matching threshold, clamped to the range [0.10, 0.95]."""
        self.threshold = max(0.10, min(0.95, threshold))
        print(f"Threshold set to: {self.threshold:.2f}")

    def set_video_source(self, video_path: str):
        """Record the path or identifier of the video currently being processed."""
        self.current_video_source = video_path

    def reset_session(self):
        """Clear all temporary IDs and counters.

        The video source and the threshold are kept.
        """
        self.temp_id_features.clear()
        self.next_temp_id = 1
        self.current_frame = 0
        print("Session reset")

    def extract_features(self, image: np.ndarray) -> Optional[np.ndarray]:
        """Embed one image crop with the loaded model.

        Args:
            image: Crop converted with ``cv2.COLOR_BGR2RGB``, so it is assumed
                to be in OpenCV BGR channel order.

        Returns:
            The squeezed model output, L2-normalised (presumably a 1-D vector).
            Returns ``None`` when the crop is empty, the model is unavailable,
            or extraction raises (the error is printed).
        """
        if image is None or image.size == 0 or self.model is None:
            return None
        try:
            img_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            from PIL import Image
            pil_img = Image.fromarray(img_rgb)
            img_tensor = self.transform(pil_img).unsqueeze(0).to(self.device)
            with torch.no_grad():
                features = self.model(img_tensor)
            features = features.squeeze().cpu().numpy()
            # Normalise so that a dot product between two vectors is cosine similarity.
            features = features / (np.linalg.norm(features) + 1e-7)
            return features
        except Exception as e:
            print(f"Feature extraction error: {e}")
            return None

    @staticmethod
    def _latest_detection_with_crop(track, lookback: int = 3):
        """Return the newest usable detection among the last *lookback* detections.

        A detection is usable when its ``image_crop`` is not ``None`` and is
        not an empty array. Returns ``None`` if no detection qualifies.
        """
        for det in reversed(track.detections[-lookback:]):
            crop = det.image_crop
            if crop is None:
                continue
            # An empty crop can never be embedded; fall back to an older one.
            if isinstance(crop, np.ndarray) and crop.size == 0:
                continue
            return det
        return None

    def _best_match(self, features: np.ndarray):
        """Return ``(temp_id, score)`` for the most similar stored temp_id.

        Returns ``(None, -1.0)`` when nothing is stored. The score is the
        maximum dot product against any stored vector of that ID, as a
        Python float. On ties, the temp_id that was inserted first wins.
        """
        best_temp_id = None
        best_score = -1.0
        for temp_id, features_list in self.temp_id_features.items():
            if not features_list:
                continue
            # Stored vectors come from extract_features(), which L2-normalises,
            # so the dot product is the cosine similarity.
            score = max(float(np.dot(features, stored)) for stored in features_list)
            if score > best_score:
                best_score = score
                best_temp_id = temp_id
        return best_temp_id, best_score

    def match_or_register(self, track) -> Dict:
        """Assign a temporary ID to *track*, reusing an existing one when similar enough.

        Args:
            track: Object with a ``detections`` list. Each detection exposes an
                ``image_crop`` attribute. Only the last 3 detections are
                considered.

        Returns:
            One of three dicts:

            - ``{'temp_id': 0}`` when no usable crop exists or feature
              extraction failed.
            - ``{'temp_id': id, 'confidence': score, 'match_type': 'existing'}``
              when the best similarity is ``>= self.threshold``.
            - ``{'temp_id': new_id, 'confidence': 1.0, 'match_type': 'new'}``
              otherwise.

            ``confidence`` is always a Python ``float``.
        """
        self.current_frame += 1

        detection = self._latest_detection_with_crop(track)
        if detection is None:
            return {'temp_id': 0}

        features = self.extract_features(detection.image_crop)
        if features is None:
            return {'temp_id': 0}

        best_temp_id, best_score = self._best_match(features)

        if best_temp_id is not None and best_score >= self.threshold:
            # Match found: extend that ID's history, keeping only the newest vectors.
            history = self.temp_id_features[best_temp_id]
            history.append(features)
            if len(history) > self.max_features_per_id:
                self.temp_id_features[best_temp_id] = history[-self.max_features_per_id:]
            return {
                'temp_id': best_temp_id,
                'confidence': best_score,
                'match_type': 'existing'
            }

        # No match: register a new temporary ID seeded with this vector.
        new_temp_id = self.next_temp_id
        self.next_temp_id += 1
        self.temp_id_features[new_temp_id] = [features]
        print(f"New temp_id: {new_temp_id} (threshold: {self.threshold:.2f})")
        return {
            'temp_id': new_temp_id,
            'confidence': 1.0,
            'match_type': 'new'
        }

    def get_statistics(self) -> Dict:
        """Return the number of temp IDs, the current threshold and the call counter."""
        return {
            'temp_ids': len(self.temp_id_features),
            'threshold': self.threshold,
            'current_frame': self.current_frame
        }