""" Simplified ReID - Basic threshold matching only No adaptive thresholds, no quality scoring, no smart storage """ import numpy as np import cv2 import torch import timm from sklearn.metrics.pairwise import cosine_similarity from typing import Dict, Optional from collections import defaultdict from datetime import datetime class SimplifiedReID: """Simplified ReID with basic threshold matching""" def __init__(self, device: str = 'cuda'): self.device = device if torch.cuda.is_available() else 'cpu' # Single threshold for all matching self.threshold = 0.40 # Session tracking (temp IDs) self.temp_id_features = {} # temp_id -> list of feature vectors self.next_temp_id = 1 self.current_frame = 0 self.current_video_source = "unknown" # Initialize model self._initialize_model() print(f"Simplified ReID initialized on {self.device}") def _initialize_model(self): """Load MegaDescriptor model""" try: self.model = timm.create_model( 'hf-hub:BVRA/MegaDescriptor-L-384', pretrained=True ) self.model.to(self.device).eval() self.transform = timm.data.create_transform( input_size=(384, 384), is_training=False, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5] ) print("MegaDescriptor-L-384 loaded") except Exception as e: print(f"Error loading model: {e}") self.model = None def set_threshold(self, threshold: float): """Set matching threshold""" self.threshold = max(0.10, min(0.95, threshold)) print(f"Threshold set to: {self.threshold:.2f}") def set_video_source(self, video_path: str): """Set current video source""" self.current_video_source = video_path def reset_session(self): """Clear session data""" self.temp_id_features.clear() self.next_temp_id = 1 self.current_frame = 0 print("Session reset") def extract_features(self, image: np.ndarray) -> Optional[np.ndarray]: """Extract feature vector from image""" if image is None or image.size == 0 or self.model is None: return None try: img_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) from PIL import Image pil_img = Image.fromarray(img_rgb) img_tensor = self.transform(pil_img).unsqueeze(0).to(self.device) with torch.no_grad(): features = self.model(img_tensor) features = features.squeeze().cpu().numpy() features = features / (np.linalg.norm(features) + 1e-7) return features except Exception as e: print(f"Feature extraction error: {e}") return None def match_or_register(self, track) -> Dict: """ Simple matching: compare features against existing temp_ids If match found -> return temp_id If no match -> create new temp_id """ self.current_frame += 1 # Get image crop from track detection = None for det in reversed(track.detections[-3:]): if det.image_crop is not None: detection = det break if detection is None or detection.image_crop is None: return {'temp_id': 0} # Extract features features = self.extract_features(detection.image_crop) if features is None: return {'temp_id': 0} # Search for match in existing temp_ids best_temp_id = None best_score = -1.0 for temp_id, features_list in self.temp_id_features.items(): # Compare against all stored features for this temp_id similarities = [] for stored_features in features_list: sim = np.dot(features, stored_features) similarities.append(sim) if similarities: max_sim = max(similarities) if max_sim > best_score: best_score = max_sim best_temp_id = temp_id # Check if best score passes threshold if best_temp_id is not None and best_score >= self.threshold: # Match found - add features to existing temp_id self.temp_id_features[best_temp_id].append(features) # Limit storage (keep last 30) if len(self.temp_id_features[best_temp_id]) > 
30: self.temp_id_features[best_temp_id] = self.temp_id_features[best_temp_id][-30:] return { 'temp_id': best_temp_id, 'confidence': best_score, 'match_type': 'existing' } else: # No match - create new temp_id new_temp_id = self.next_temp_id self.next_temp_id += 1 self.temp_id_features[new_temp_id] = [features] print(f"New temp_id: {new_temp_id} (threshold: {self.threshold:.2f})") return { 'temp_id': new_temp_id, 'confidence': 1.0, 'match_type': 'new' } def get_statistics(self) -> Dict: """Get simple statistics""" return { 'temp_ids': len(self.temp_id_features), 'threshold': self.threshold, 'current_frame': self.current_frame }
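

# --- Usage sketch (illustrative, not part of the module above) ---
# A minimal example of driving SimplifiedReID end to end. The project's real
# detection/track types are not shown in this file, so `_Detection` and `_Track`
# below are hypothetical stand-ins that only provide the `detections` and
# `image_crop` attributes that match_or_register() expects.
if __name__ == "__main__":
    from dataclasses import dataclass, field
    from typing import List

    @dataclass
    class _Detection:
        # hypothetical stand-in for the project's detection object
        image_crop: Optional[np.ndarray] = None

    @dataclass
    class _Track:
        # hypothetical stand-in for the project's track object
        detections: List["_Detection"] = field(default_factory=list)

    reid = SimplifiedReID(device='cuda')
    reid.set_threshold(0.45)
    reid.set_video_source("example.mp4")

    # Run the same dummy crop through the matcher twice: the first call should
    # register a new temp_id, the second should re-match it (assuming the
    # MegaDescriptor weights downloaded successfully).
    crop = np.random.randint(0, 255, (256, 128, 3), dtype=np.uint8)
    track = _Track(detections=[_Detection(image_crop=crop)])

    print(reid.match_or_register(track))   # expected: match_type 'new'
    print(reid.match_or_register(track))   # expected: match_type 'existing'
    print(reid.get_statistics())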