Spaces:
Sleeping
Sleeping
| import numpy as np | |
| from typing import List, Optional, Tuple | |
| from scipy.optimize import linear_sum_assignment | |
| from collections import deque | |
| import uuid | |
| from detection import Detection # Add this line | |
class Track:
    """Simple track for a single dog.

    A track starts in the ``'tentative'`` state, is promoted to
    ``'confirmed'`` once it has been matched ``CONFIRM_HITS`` times, and
    becomes ``'deleted'`` after it has gone unmatched for too many frames
    (see :meth:`mark_missed`).
    """

    CONFIRM_HITS = 3        # matched detections needed to confirm a track
    MAX_DETECTIONS = 10     # number of recent detections kept per track
    TRAJECTORY_LENGTH = 30  # number of recent centre points kept per track

    def __init__(self, detection: "Detection", track_id: Optional[int] = None):
        """Initialize the track from its first detection.

        Args:
            detection: First detection of the object; its ``bbox``
                (indexed as x1, y1, x2, y2) and ``confidence`` are read.
            track_id: Identifier to use. When ``None`` a pseudo-random ID
                is generated instead.
        """
        # Compare against None (not truthiness) so an explicit ID of 0 is
        # kept instead of being silently replaced by a random one.
        self.track_id = track_id if track_id is not None else self._generate_id()
        self.bbox = detection.bbox
        self.detections = [detection]
        self.confidence = detection.confidence

        # Track state
        self.age = 1
        self.time_since_update = 0
        self.state = 'tentative'  # tentative -> confirmed -> deleted
        self.hits = 1

        # Store center points for trajectory
        self.trajectory = deque(maxlen=self.TRAJECTORY_LENGTH)
        self.trajectory.append(self._center(self.bbox))

    @staticmethod
    def _center(bbox) -> Tuple[float, float]:
        """Return the (cx, cy) centre of an (x1, y1, x2, y2) box."""
        return ((bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2)

    def _generate_id(self) -> int:
        """Generate a pseudo-random track ID in the range [0, 100000)."""
        return int(uuid.uuid4().int % 100000)

    def predict(self):
        """Advance the track by one frame.

        No motion model is applied: the last box is kept as the predicted
        position and only the frame counters are incremented.
        """
        self.age += 1
        self.time_since_update += 1

    def update(self, detection: "Detection"):
        """Update the track with the detection matched in the current frame."""
        self.bbox = detection.bbox
        self.detections.append(detection)
        self.confidence = detection.confidence
        self.hits += 1
        self.time_since_update = 0

        # Update trajectory
        self.trajectory.append(self._center(self.bbox))

        # Promote the track once it has been matched often enough.
        if self.state == 'tentative' and self.hits >= self.CONFIRM_HITS:
            self.state = 'confirmed'

        # Keep only recent detections to save memory
        if len(self.detections) > self.MAX_DETECTIONS:
            self.detections = self.detections[-self.MAX_DETECTIONS:]

    def mark_missed(self, max_age: int = 15):
        """Record that the track was not matched in the current frame.

        The track is marked ``'deleted'`` once ``time_since_update`` exceeds
        ``max_age``. This applies to tentative tracks as well as confirmed
        ones; otherwise a tentative track that stops matching would never
        expire and would be kept (and matched against) forever.

        Args:
            max_age: Number of consecutive unmatched frames tolerated
                before the track is deleted.
        """
        if self.state != 'deleted' and self.time_since_update > max_age:
            self.state = 'deleted'
class SimpleTracker:
    """
    Simplified ByteTrack - IoU-based tracking.

    Each frame, existing tracks are matched to the new detections by
    bounding-box IoU using an optimal (Hungarian) assignment. Matched
    tracks are updated, unmatched tracks are marked as missed and
    unmatched detections start new tracks.
    """

    # Cost given to track/detection pairs whose IoU is below the matching
    # threshold. It only has to dominate the sum of all valid costs (each
    # at most 1) so the solver never trades a valid pair for gated ones.
    _GATED_COST = 1e6

    def __init__(self,
                 match_threshold: float = 0.5,
                 track_buffer: int = 30):
        """
        Initialize tracker

        Args:
            match_threshold: Minimum IoU for a track/detection pair to be
                accepted as a match (0.5 works well)
            track_buffer: Maximum number of consecutive frames a track may
                go unmatched before the tracker drops it. A track's own
                ``mark_missed`` may retire it sooner.
        """
        self.match_threshold = match_threshold
        self.track_buffer = track_buffer
        self.tracks: List["Track"] = []
        self.track_id_count = 1

    def update(self, detections: List["Detection"]) -> List["Track"]:
        """
        Update tracks with new detections

        Args:
            detections: List of detections from current frame

        Returns:
            List of tracks currently in the 'confirmed' state
        """
        # Predict existing tracks
        for track in self.tracks:
            track.predict()

        active_tracks = [t for t in self.tracks if t.state != 'deleted']

        # _associate copes with empty matrices, so "no tracks" and
        # "no detections" need no special-casing: every detection (or every
        # track) simply ends up in the corresponding unmatched list.
        iou_matrix = self._calculate_iou_matrix(active_tracks, detections)
        matched, unmatched_tracks, unmatched_dets = self._associate(
            iou_matrix, self.match_threshold
        )

        # Update matched tracks
        for t_idx, d_idx in matched:
            active_tracks[t_idx].update(detections[d_idx])

        # Mark unmatched tracks as missed
        for t_idx in unmatched_tracks:
            track = active_tracks[t_idx]
            track.mark_missed()
            # Enforce track_buffer for every state so that tracks which
            # never get matched again cannot accumulate indefinitely.
            if track.time_since_update > self.track_buffer:
                track.state = 'deleted'

        # Create new tracks for unmatched detections
        for d_idx in unmatched_dets:
            self.tracks.append(Track(detections[d_idx], self.track_id_count))
            self.track_id_count += 1

        # Remove deleted tracks
        self.tracks = [t for t in self.tracks if t.state != 'deleted']

        # Return confirmed tracks
        return [t for t in self.tracks if t.state == 'confirmed']

    def _calculate_iou_matrix(self, tracks: List["Track"],
                              detections: List["Detection"]) -> np.ndarray:
        """Calculate IoU between all tracks and detections.

        Returns:
            Array of shape (len(tracks), len(detections)) where entry
            [t, d] is the IoU of track t's box with detection d's box.
        """
        matrix = np.zeros((len(tracks), len(detections)))
        for t_idx, track in enumerate(tracks):
            for d_idx, detection in enumerate(detections):
                matrix[t_idx, d_idx] = self._iou(track.bbox, detection.bbox)
        return matrix

    def _iou(self, bbox1: List[float], bbox2: List[float]) -> float:
        """Calculate Intersection over Union of two (x1, y1, x2, y2) boxes.

        Returns 0.0 when the boxes do not overlap or the union has no area.
        """
        x1 = max(bbox1[0], bbox2[0])
        y1 = max(bbox1[1], bbox2[1])
        x2 = min(bbox1[2], bbox2[2])
        y2 = min(bbox1[3], bbox2[3])

        if x2 < x1 or y2 < y1:
            return 0.0

        intersection = (x2 - x1) * (y2 - y1)
        area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
        area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
        union = area1 + area2 - intersection

        return intersection / union if union > 0 else 0.0

    def _associate(self, iou_matrix: np.ndarray,
                   threshold: float) -> Tuple[List, List, List]:
        """Hungarian algorithm for optimal assignment.

        Args:
            iou_matrix: (num_tracks, num_detections) IoU values; either
                dimension may be zero.
            threshold: Minimum IoU for a pair to count as a match.

        Returns:
            ``(matched, unmatched_tracks, unmatched_detections)`` where
            ``matched`` is a list of ``[track_idx, detection_idx]`` pairs
            and the other two are lists of the remaining indices.
        """
        num_tracks, num_dets = iou_matrix.shape
        matched_indices = []

        # Guard on size first: max() of an empty array raises ValueError.
        if iou_matrix.size > 0 and iou_matrix.max() >= threshold:
            # Gate sub-threshold pairs *before* solving. Filtering only
            # afterwards lets the solver prefer several sub-threshold pairs
            # over one valid pair, which would then all be discarded.
            cost_matrix = np.where(
                iou_matrix >= threshold, 1.0 - iou_matrix, self._GATED_COST
            )
            row_ind, col_ind = linear_sum_assignment(cost_matrix)
            for r, c in zip(row_ind, col_ind):
                if iou_matrix[r, c] >= threshold:
                    matched_indices.append([int(r), int(c)])

        # Build the lookup sets once instead of once per index.
        matched_tracks = {m[0] for m in matched_indices}
        matched_dets = {m[1] for m in matched_indices}
        unmatched_tracks = [t for t in range(num_tracks)
                            if t not in matched_tracks]
        unmatched_detections = [d for d in range(num_dets)
                                if d not in matched_dets]

        return matched_indices, unmatched_tracks, unmatched_detections