Spaces:
Sleeping
Sleeping
| """ | |
| reid_adaptive.py - Enhanced ReID with Adaptive Thresholding | |
| Automatically adjusts similarity threshold based on data distribution | |
| """ | |
| import numpy as np | |
| from collections import deque | |
| from typing import List, Dict, Tuple, Optional | |
| import scipy.stats as stats | |
class AdaptiveThreshold:
    """
    Maintain a similarity threshold that adapts to recently observed scores.

    Each call to ``update_and_get_threshold`` records one similarity score.
    Once ``MIN_SAMPLES`` scores are stored, up to three candidate thresholds
    are computed (histogram valley, largest gap between sorted scores, and a
    feedback-driven nudge), blended with fixed weights, and the current
    threshold is moved towards the blend by ``adaptation_rate``. The adapted
    value is clamped to ``[MIN_THRESHOLD, MAX_THRESHOLD]``.
    """

    # Hard limits applied to the threshold after every adaptation step.
    MIN_THRESHOLD = 0.4
    MAX_THRESHOLD = 0.9
    # Stored scores required before any adaptation happens.
    MIN_SAMPLES = 20
    # Stored scores required before the gap strategy contributes.
    MIN_GAP_SAMPLES = 30
    # Feedback entries required before the performance strategy contributes.
    MIN_FEEDBACK = 10
    # Number of newest feedback entries the performance strategy looks at.
    FEEDBACK_WINDOW = 50
    # Amount the performance strategy moves the threshold when accuracy is low.
    FEEDBACK_STEP = 0.05
    # Number of histogram bins used when searching for a valley.
    HISTOGRAM_BINS = 20

    def __init__(self,
                 initial_threshold: float = 0.7,
                 window_size: int = 100,
                 adaptation_rate: float = 0.1):
        """
        Args:
            initial_threshold: Starting threshold value
            window_size: Number of recent similarities to consider
            adaptation_rate: How quickly to adapt (0-1)
        """
        self.base_threshold = initial_threshold
        self.current_threshold = initial_threshold
        self.adaptation_rate = adaptation_rate

        # Sliding windows of recent scores and of True/False feedback.
        self.similarity_history = deque(maxlen=window_size)
        self.match_history = deque(maxlen=window_size)
        # Score that accompanied each feedback entry. Appended together with
        # match_history so that entry i of both deques describes the same
        # decision (similarity_history also holds scores without feedback and
        # therefore cannot be indexed with match_history positions).
        self._feedback_similarities = deque(maxlen=window_size)

        # Trace of the thresholds produced so far, for inspection/plotting.
        self.threshold_history = deque(maxlen=1000)
        self.threshold_history.append(initial_threshold)

    def update_and_get_threshold(self,
                                 new_similarity: float,
                                 was_correct_match: Optional[bool] = None) -> float:
        """
        Record a similarity score and return the threshold to use for it.

        Args:
            new_similarity: Latest similarity score
            was_correct_match: Feedback on whether the match was correct
                (``None`` when unknown; nothing is added to the feedback
                history in that case)

        Returns:
            The current threshold. Until ``MIN_SAMPLES`` scores have been
            recorded this is the unmodified (and unclamped) initial value.
        """
        self.similarity_history.append(new_similarity)
        if was_correct_match is not None:
            self.match_history.append(was_correct_match)
            self._feedback_similarities.append(new_similarity)

        # Too little data: keep the threshold untouched.
        if len(self.similarity_history) < self.MIN_SAMPLES:
            return self.current_threshold

        # (candidate, weight) per strategy; a strategy returns None when it
        # has nothing to contribute. Compare against None explicitly so a
        # legitimate candidate of 0.0 is not discarded.
        candidates = (
            (self._statistical_threshold(), 0.4),
            (self._gap_threshold(), 0.3),
            (self._performance_threshold(), 0.3),
        )
        active = [(value, weight) for value, weight in candidates
                  if value is not None]

        if active:
            values, weights = zip(*active)
            # np.average renormalises the weights of the active strategies.
            blended = float(np.average(values, weights=weights))
            # Exponential smoothing towards the blended candidate.
            smoothed = (
                self.current_threshold * (1 - self.adaptation_rate) +
                blended * self.adaptation_rate
            )
            # float(): np.clip would otherwise hand back a numpy scalar.
            self.current_threshold = float(
                np.clip(smoothed, self.MIN_THRESHOLD, self.MAX_THRESHOLD)
            )
            self.threshold_history.append(self.current_threshold)

        return self.current_threshold

    def _statistical_threshold(self) -> Optional[float]:
        """
        Candidate threshold from the shape of the score histogram.

        Looks for the deepest valley between two histogram peaks (scores of
        matches vs. non-matches) and returns its centre. When the histogram
        has no valley, falls back to ``mean - 1.5 * std`` floored at
        ``MIN_THRESHOLD``. Returns ``None`` with fewer than ``MIN_SAMPLES``
        scores.
        """
        if len(self.similarity_history) < self.MIN_SAMPLES:
            return None

        similarities = np.asarray(self.similarity_history, dtype=float)
        counts, edges = np.histogram(similarities, bins=self.HISTOGRAM_BINS)

        valley = self._most_prominent_valley(counts, edges)
        if valley is not None:
            return valley

        fallback = float(np.mean(similarities) - 1.5 * np.std(similarities))
        return max(self.MIN_THRESHOLD, fallback)

    @staticmethod
    def _most_prominent_valley(counts, edges) -> Optional[float]:
        """
        Return the centre of the deepest histogram valley, or ``None``.

        A valley is a run of equal-count bins whose neighbours on both sides
        are strictly higher; runs are used so that a stretch of empty bins
        between two clusters counts as one valley. Depth is measured against
        the lower of the highest peaks on either side. Ties keep the valley
        at the lowest similarity.
        """
        bin_total = len(counts)
        best_depth = 0
        best_position = None

        start = 1
        while start < bin_total - 1:
            if counts[start] >= counts[start - 1]:
                start += 1
                continue

            # Histogram just dropped: walk across the flat bottom, if any.
            end = start
            while end + 1 < bin_total and counts[end + 1] == counts[start]:
                end += 1

            # It is a valley only if the histogram rises again afterwards.
            if end + 1 < bin_total and counts[end + 1] > counts[start]:
                left_peak = counts[:start].max()
                right_peak = counts[end + 1:].max()
                depth = min(left_peak, right_peak) - counts[start]
                if depth > best_depth:
                    best_depth = depth
                    # Bins start..end span edges[start]..edges[end + 1].
                    best_position = (edges[start] + edges[end + 1]) / 2

            start = end + 1

        if best_position is None:
            return None
        return float(best_position)

    def _gap_threshold(self) -> Optional[float]:
        """
        Candidate threshold from the largest gaps between sorted scores.

        Gaps strictly larger than the 90th percentile of all gaps are
        considered significant; the midpoint of the significant gap closest
        to the middle of the score range is returned. Returns ``None`` with
        fewer than ``MIN_GAP_SAMPLES`` scores or when no gap stands out
        (e.g. all gaps are equal).
        """
        if len(self.similarity_history) < self.MIN_GAP_SAMPLES:
            return None

        ordered = np.sort(np.asarray(self.similarity_history, dtype=float))
        gap_sizes = np.diff(ordered)
        if gap_sizes.size == 0:
            return None
        gap_midpoints = (ordered[1:] + ordered[:-1]) / 2

        cutoff = np.percentile(gap_sizes, 90)
        significant = gap_midpoints[gap_sizes > cutoff]
        if significant.size == 0:
            return None

        mid_range = (ordered[-1] + ordered[0]) / 2
        closest = np.argmin(np.abs(significant - mid_range))
        return float(significant[closest])

    def _performance_threshold(self) -> Optional[float]:
        """
        Candidate threshold derived from match-correctness feedback.

        Returns ``None`` with fewer than ``MIN_FEEDBACK`` feedback entries.
        When at least 70% of the recent feedback is positive the current
        threshold is returned unchanged. Otherwise the incorrect decisions
        are split by whether their score was above the current threshold
        (false positives) or not (false negatives) and the threshold is
        moved by ``FEEDBACK_STEP`` against the dominant error type.
        """
        if len(self.match_history) < self.MIN_FEEDBACK:
            return None

        outcomes = list(self.match_history)[-self.FEEDBACK_WINDOW:]
        # Scores recorded together with each outcome (same positions).
        scores = list(self._feedback_similarities)[-self.FEEDBACK_WINDOW:]

        accuracy = sum(1 for ok in outcomes if ok) / len(outcomes)
        if accuracy >= 0.7:
            return self.current_threshold

        error_scores = [score for score, ok in zip(scores, outcomes) if not ok]
        high_sim_errors = sum(1 for score in error_scores
                              if score > self.current_threshold)
        low_sim_errors = len(error_scores) - high_sim_errors

        if high_sim_errors > low_sim_errors:
            # Mostly wrong accepts: be stricter.
            return self.current_threshold + self.FEEDBACK_STEP
        # Mostly wrong rejects (or a tie): be more lenient.
        return self.current_threshold - self.FEEDBACK_STEP
class SimpleReIDAdaptive:
    """
    Re-identification matcher that uses an adaptive similarity threshold.

    Intended as a drop-in replacement for SimpleReID. Known identities are
    kept in ``dog_database`` (id -> list of feature vectors); a new track is
    either matched to the most similar identity or registered as a new one.

    NOTE(review): ``match_or_register`` calls ``self.extract_features``,
    which is not defined in this class; it has to be supplied (presumably
    copied from SimpleReID) before the class can be used.
    """

    # Newest stored feature vectors compared against per identity.
    FEATURES_COMPARED = 5
    # Maximum number of feature vectors kept per identity.
    MAX_STORED_FEATURES = 20

    def __init__(self,
                 similarity_threshold: float = 0.7,
                 device: str = 'cuda',
                 use_adaptive: bool = True):
        """
        Initialize ReID with optional adaptive thresholding

        Args:
            similarity_threshold: Initial/fallback threshold
            device: 'cuda' or 'cpu'; replaced by 'cpu' when torch is not
                importable or reports that CUDA is unavailable
            use_adaptive: Whether to use adaptive thresholding
        """
        # Imported here because the module does not import torch at the top.
        try:
            import torch
            cuda_available = torch.cuda.is_available()
        except ImportError:
            cuda_available = False
        self.device = device if cuda_available else 'cpu'

        self.base_threshold = similarity_threshold
        self.use_adaptive = use_adaptive

        # ... (rest of initialization same as SimpleReID)
        # State read and written by match_or_register.
        self.dog_database: Dict[int, list] = {}
        self.dog_images: Dict[int, object] = {}
        # Starts at 1 on the assumption that id 0 means "no identity"
        # (match_or_register returns 0 for empty tracks) - TODO confirm
        # against SimpleReID.
        self.next_dog_id = 1

        # Global adaptive threshold manager
        self.adaptive_threshold = AdaptiveThreshold(
            initial_threshold=similarity_threshold
        )
        # Per-dog adaptive thresholds, created when a dog is registered
        self.dog_thresholds: Dict[int, "AdaptiveThreshold"] = {}

    @staticmethod
    def _cosine_similarity(first, second) -> float:
        """
        Cosine similarity of two feature vectors (flattened to 1-D).

        Returns 0.0 when either vector has zero norm instead of dividing by
        zero.
        """
        a = np.asarray(first, dtype=float).ravel()
        b = np.asarray(second, dtype=float).ravel()
        denominator = float(np.linalg.norm(a) * np.linalg.norm(b))
        if denominator == 0.0:
            return 0.0
        return float(np.dot(a, b) / denominator)

    def match_or_register(self, track: "Track") -> Tuple[int, float]:
        """
        Match a track to a known identity or register a new one.

        Args:
            track: Object with a ``detections`` sequence whose items expose
                ``image_crop``.

        Returns:
            ``(dog_id, similarity)``. ``(0, 0.0)`` when the track has no
            detections or no features could be extracted; ``(new_id, 1.0)``
            when a new identity is registered.
        """
        if not track.detections:
            return 0, 0.0

        # Assumes the newest detection is the last element - TODO confirm
        # against the Track definition.
        latest_detection = track.detections[-1]

        features = self.extract_features(latest_detection.image_crop)
        if features is None:
            return 0, 0.0

        # Mean similarity against the newest stored features of every dog.
        dog_similarities: Dict[int, float] = {}
        compared_any = False
        for dog_id, stored_features in self.dog_database.items():
            scores = [
                self._cosine_similarity(features, stored_feat)
                for stored_feat in stored_features[-self.FEATURES_COMPARED:]
            ]
            compared_any = compared_any or bool(scores)
            dog_similarities[dog_id] = float(np.mean(scores)) if scores else 0.0

        best_dog_id = None
        best_similarity = 0.0
        if dog_similarities:
            best_dog_id = max(dog_similarities, key=dog_similarities.get)
            best_similarity = dog_similarities[best_dog_id]

        if self.use_adaptive and compared_any:
            # Global adaptive threshold, fed with the best score.
            threshold = self.adaptive_threshold.update_and_get_threshold(
                best_similarity
            )
            # Per-dog threshold of the best candidate; dict lookup instead
            # of a truthiness test so that an id of 0 is not skipped.
            per_dog = self.dog_thresholds.get(best_dog_id)
            if per_dog is not None:
                dog_specific_threshold = per_dog.update_and_get_threshold(
                    best_similarity
                )
                # Use the more conservative of the two.
                threshold = max(threshold, dog_specific_threshold)
        else:
            threshold = self.base_threshold

        if best_dog_id is not None and best_similarity >= threshold:
            # Update the matched dog, keeping only the newest features.
            gallery = self.dog_database[best_dog_id]
            gallery.append(features)
            if len(gallery) > self.MAX_STORED_FEATURES:
                self.dog_database[best_dog_id] = gallery[-self.MAX_STORED_FEATURES:]
            self.dog_images[best_dog_id] = latest_detection.image_crop
            self._record_match_decision(best_dog_id, best_similarity, True)
            return best_dog_id, best_similarity

        # Register new dog
        new_dog_id = self.next_dog_id
        self.next_dog_id += 1
        self.dog_database[new_dog_id] = [features]
        self.dog_images[new_dog_id] = latest_detection.image_crop

        if self.use_adaptive:
            # New dogs start from wherever the global threshold currently is.
            self.dog_thresholds[new_dog_id] = AdaptiveThreshold(
                initial_threshold=self.adaptive_threshold.current_threshold
            )
        return new_dog_id, 1.0

    def _record_match_decision(self, dog_id: int, similarity: float, was_match: bool):
        """
        Feed a matching decision back into the global adaptive threshold.

        No real feedback is available here, so correctness is guessed from
        the score: an accepted match counts as correct above 0.85, a
        rejection counts as correct below 0.5. ``dog_id`` is currently
        unused.

        NOTE(review): match_or_register has already passed this similarity
        to the global threshold before deciding, so a matched score is
        recorded twice there (once without and once with feedback).
        """
        if was_match:
            was_correct = bool(similarity > 0.85)
        else:
            was_correct = bool(similarity < 0.5)

        if self.use_adaptive:
            self.adaptive_threshold.update_and_get_threshold(
                similarity, was_correct
            )

    def get_threshold_info(self) -> Dict:
        """
        Return a snapshot of the global threshold state for debugging.

        The result holds the current and base thresholds, the adaptive flag,
        the last 20 threshold values and mean/std/min/max of the stored
        similarity scores (all 0 while no score has been recorded).
        """
        tracker = self.adaptive_threshold
        scores = list(tracker.similarity_history)

        if scores:
            similarity_stats = {
                'mean': float(np.mean(scores)),
                'std': float(np.std(scores)),
                'min': float(min(scores)),
                'max': float(max(scores)),
            }
        else:
            similarity_stats = {'mean': 0, 'std': 0, 'min': 0, 'max': 0}

        return {
            'current_threshold': tracker.current_threshold,
            'base_threshold': self.base_threshold,
            'use_adaptive': self.use_adaptive,
            'threshold_history': list(tracker.threshold_history)[-20:],
            'similarity_stats': similarity_stats,
        }
| # Integration with Gradio UI | |
def create_adaptive_controls(app):
    """
    Build the adaptive-threshold widgets for the Gradio interface.

    Creates, inside a ``gr.Column``, a heading, an enable checkbox, two
    sliders (adaptation rate and history window), a line plot for the
    threshold history and a JSON panel for statistics. The ``app`` argument
    is accepted but not used here, and no event handlers are attached.

    Returns:
        Tuple ``(checkbox, rate slider, window slider, line plot, JSON)``.
    """
    import gradio as gr

    # NOTE(review): presumably called while a gr.Blocks context is active so
    # that the column is attached to the page - confirm with the caller.
    with gr.Column():
        gr.Markdown("### Adaptive Threshold Settings")

        enable_box = gr.Checkbox(
            value=True,
            label="Enable Adaptive Threshold",
            info="Automatically adjust threshold based on data",
        )

        rate_slider = gr.Slider(
            label="Adaptation Rate",
            info="How quickly threshold adapts (lower = more stable)",
            minimum=0.01,
            maximum=0.5,
            step=0.01,
            value=0.1,
        )

        window_slider = gr.Slider(
            label="History Window",
            info="Number of recent matches to consider",
            minimum=20,
            maximum=500,
            step=10,
            value=100,
        )

        # Plot of the threshold over time
        history_plot = gr.LinePlot(
            x="Sample",
            y="Threshold",
            label="Threshold History",
            height=200,
        )

        # Raw statistics panel
        stats_view = gr.JSON(label="Threshold Statistics")

    return enable_box, rate_slider, window_slider, history_plot, stats_view