Spaces:
Sleeping
Sleeping
| """ | |
| Simplified ReID - Basic threshold matching only | |
| No adaptive thresholds, no quality scoring, no smart storage | |
| """ | |
| import numpy as np | |
| import cv2 | |
| import torch | |
| import timm | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| from typing import Dict, Optional | |
| from collections import defaultdict | |
| from datetime import datetime | |
class SimplifiedReID:
    """Simplified ReID with basic threshold matching.

    Each call to ``match_or_register`` embeds one image crop. It compares the
    embedding against every feature vector stored for each session temp_id.
    Features are L2-normalised in ``extract_features``, so a dot product is the
    cosine similarity. The best-scoring temp_id is reused when its score is at
    least ``threshold``; otherwise a new temp_id is allocated.

    There are no adaptive thresholds, no quality scoring and no persistence:
    all state lives in memory until ``reset_session`` is called.
    """

    def __init__(
        self,
        device: str = 'cuda',
        threshold: float = 0.40,
        max_features_per_id: int = 30,
        crop_lookback: int = 3,
    ):
        """Create the matcher and load the embedding model.

        Args:
            device: Requested torch device. Falls back to ``'cpu'`` whenever
                CUDA is not available, regardless of what was requested.
            threshold: Minimum similarity for reusing an existing temp_id.
            max_features_per_id: How many of the most recent feature vectors
                to keep per temp_id (values below 1 are treated as 1).
            crop_lookback: How many of a track's most recent detections to
                search for a usable image crop (values below 1 are treated as 1).
        """
        self.device = device if torch.cuda.is_available() else 'cpu'
        # Single threshold for all matching.
        self.threshold = float(threshold)
        self.max_features_per_id = max(1, int(max_features_per_id))
        self.crop_lookback = max(1, int(crop_lookback))
        # Session tracking: temp_id -> list of feature vectors, oldest first.
        self.temp_id_features = {}
        self.next_temp_id = 1
        # Incremented once per match_or_register call (not per video frame).
        self.current_frame = 0
        self.current_video_source = "unknown"
        self._initialize_model()
        print(f"Simplified ReID initialized on {self.device}")

    def _initialize_model(self):
        """Load MegaDescriptor-L-384 and its eval transform.

        On any failure the error is printed and ``self.model`` is set to None.
        ``extract_features`` then returns None for every input.
        """
        try:
            self.model = timm.create_model(
                'hf-hub:BVRA/MegaDescriptor-L-384',
                pretrained=True
            )
            self.model.to(self.device).eval()
            # NOTE(review): timm's eval transform applies its default crop_pct
            # (resize larger, then centre-crop to 384). A plain 384x384 resize
            # may be what the model expects; confirm before changing, because
            # it shifts every embedding and therefore the tuned threshold.
            self.transform = timm.data.create_transform(
                input_size=(384, 384),
                is_training=False,
                mean=[0.5, 0.5, 0.5],
                std=[0.5, 0.5, 0.5]
            )
            print("MegaDescriptor-L-384 loaded")
        except Exception as e:
            print(f"Error loading model: {e}")
            self.model = None

    def set_threshold(self, threshold: float):
        """Set the matching threshold, clamped to the range [0.10, 0.95]."""
        self.threshold = max(0.10, min(0.95, threshold))
        print(f"Threshold set to: {self.threshold:.2f}")

    def set_video_source(self, video_path: str):
        """Record the current video source (stored only, not used in matching)."""
        self.current_video_source = video_path

    def reset_session(self):
        """Clear all temp_ids and counters (the video source is kept)."""
        self.temp_id_features.clear()
        self.next_temp_id = 1
        self.current_frame = 0
        print("Session reset")

    @staticmethod
    def _to_rgb(image: np.ndarray) -> np.ndarray:
        """Convert an OpenCV-style grayscale, BGR or BGRA image to RGB."""
        if image.ndim == 2 or image.shape[2] == 1:
            return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        if image.shape[2] == 4:
            return cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def extract_features(self, image: np.ndarray) -> Optional[np.ndarray]:
        """Embed one image crop into an L2-normalised 1-D feature vector.

        Args:
            image: Image array in OpenCV channel order (grayscale, BGR or BGRA).

        Returns:
            The unit-norm feature vector, or None if the image is empty, the
            model failed to load, or extraction raised. Errors are printed,
            not raised.
        """
        if image is None or image.size == 0 or self.model is None:
            return None
        try:
            from PIL import Image
            pil_img = Image.fromarray(self._to_rgb(image))
            img_tensor = self.transform(pil_img).unsqueeze(0).to(self.device)
            with torch.no_grad():
                features = self.model(img_tensor)
            # Batch size is 1, so flattening yields the single embedding.
            features = features.flatten().cpu().numpy()
            return features / (np.linalg.norm(features) + 1e-7)
        except Exception as e:
            print(f"Feature extraction error: {e}")
            return None

    def _latest_crop(self, track):
        """Return the newest non-None ``image_crop`` among the last
        ``crop_lookback`` detections of ``track``, or None if there is none."""
        for det in reversed(track.detections[-self.crop_lookback:]):
            if det.image_crop is not None:
                return det.image_crop
        return None

    def _find_best_match(self, features: np.ndarray):
        """Return ``(temp_id, score)`` for the best-matching stored temp_id.

        A temp_id's score is the maximum similarity between ``features`` and
        any of its stored vectors. On ties the first temp_id encountered wins.
        Returns ``(None, -1.0)`` when nothing is stored.
        """
        best_temp_id = None
        best_score = -1.0
        for temp_id, stored in self.temp_id_features.items():
            if not stored:
                continue
            # One matrix-vector product scores all stored vectors at once.
            score = float(np.max(np.stack(stored) @ features))
            if score > best_score:
                best_score = score
                best_temp_id = temp_id
        return best_temp_id, best_score

    def _store_features(self, temp_id: int, features: np.ndarray):
        """Append ``features`` to ``temp_id``, keeping only the newest
        ``max_features_per_id`` vectors."""
        stored = self.temp_id_features[temp_id]
        stored.append(features)
        overflow = len(stored) - self.max_features_per_id
        if overflow > 0:
            del stored[:overflow]

    @staticmethod
    def _no_match() -> Dict:
        """Result used when no crop or no features were available."""
        return {'temp_id': 0, 'confidence': 0.0, 'match_type': 'none'}

    def match_or_register(self, track) -> Dict:
        """Assign a session temp_id to ``track``.

        Takes the newest usable crop from the track's recent detections and
        embeds it. If the best existing temp_id scores at least
        ``self.threshold``, that temp_id is reused and the new features are
        stored under it. Otherwise a fresh temp_id is created.

        Args:
            track: Object with a ``detections`` list whose items have an
                ``image_crop`` attribute (array or None).

        Returns:
            A dict with ``temp_id`` (int), ``confidence`` (float) and
            ``match_type``. ``match_type`` is ``'existing'`` or ``'new'``. When
            no crop or no features were available, ``temp_id`` is 0,
            ``confidence`` is 0.0 and ``match_type`` is ``'none'``.
        """
        self.current_frame += 1

        crop = self._latest_crop(track)
        if crop is None:
            return self._no_match()

        features = self.extract_features(crop)
        if features is None:
            return self._no_match()

        best_temp_id, best_score = self._find_best_match(features)
        if best_temp_id is not None and best_score >= self.threshold:
            self._store_features(best_temp_id, features)
            return {
                'temp_id': best_temp_id,
                'confidence': best_score,
                'match_type': 'existing'
            }

        # No sufficiently similar temp_id: register a new one.
        new_temp_id = self.next_temp_id
        self.next_temp_id += 1
        self.temp_id_features[new_temp_id] = [features]
        print(f"New temp_id: {new_temp_id} (threshold: {self.threshold:.2f})")
        return {
            'temp_id': new_temp_id,
            'confidence': 1.0,
            'match_type': 'new'
        }

    def get_statistics(self) -> Dict:
        """Return the temp_id count, current threshold and call counter."""
        return {
            'temp_ids': len(self.temp_id_features),
            'threshold': self.threshold,
            'current_frame': self.current_frame
        }