mustafa2ak commited on
Commit
e96f282
·
verified ·
1 Parent(s): 19bce7a

Update tracking.py

Browse files
Files changed (1) hide show
  1. tracking.py +229 -162
tracking.py CHANGED
@@ -1,6 +1,5 @@
1
  """
2
- tracking.py - Production-ready tracking with comprehensive error handling
3
- Includes all bug fixes and defensive programming
4
  """
5
  import numpy as np
6
  from typing import List, Optional, Tuple, Dict
@@ -11,12 +10,9 @@ from detection import Detection
11
  import warnings
12
  warnings.filterwarnings('ignore')
13
 
14
-
15
  class Track:
16
- """Enhanced track with robust state management"""
17
-
18
  def __init__(self, detection: Detection, track_id: Optional[int] = None):
19
- """Initialize track from first detection"""
20
  self.track_id = track_id if track_id else self._generate_id()
21
  self.bbox = detection.bbox.copy() if hasattr(detection, 'bbox') else [0, 0, 100, 100]
22
  self.detections = [detection]
@@ -29,6 +25,11 @@ class Track:
29
  self.hits = 1
30
  self.consecutive_misses = 0
31
 
 
 
 
 
 
32
  # Store center points for trajectory
33
  cx = (self.bbox[0] + self.bbox[2]) / 2
34
  cy = (self.bbox[1] + self.bbox[3]) / 2
@@ -39,54 +40,87 @@ class Track:
39
  self.velocity = np.array([0.0, 0.0])
40
  self.acceleration = np.array([0.0, 0.0])
41
 
42
- # Appearance features for re-association
43
  self.appearance_features = []
44
  if hasattr(detection, 'features'):
45
  self.appearance_features.append(detection.features)
46
-
47
- # Size tracking for scale changes
48
  self.sizes = deque(maxlen=10)
49
  width = max(1, self.bbox[2] - self.bbox[0])
50
  height = max(1, self.bbox[3] - self.bbox[1])
51
  self.sizes.append((width, height))
 
52
 
53
  # Track quality metrics
54
  self.avg_confidence = self.confidence
55
  self.max_confidence = self.confidence
56
-
 
 
57
  def _generate_id(self) -> int:
58
- """Generate unique track ID"""
59
  return int(uuid.uuid4().int % 100000)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
 
 
61
  def predict(self):
62
- """Enhanced motion prediction with safety checks"""
63
  self.age += 1
64
  self.time_since_update += 1
65
  self.consecutive_misses += 1
66
 
67
  try:
68
  if len(self.trajectory) >= 3:
69
- # Calculate velocity and acceleration from recent positions
70
  positions = np.array(list(self.trajectory))[-3:]
71
-
72
- # Velocity from last two positions
73
  self.velocity = positions[-1] - positions[-2]
74
 
75
- # Limit velocity to reasonable values
76
- max_velocity = 50 # pixels per frame
77
  velocity_magnitude = np.linalg.norm(self.velocity)
78
  if velocity_magnitude > max_velocity:
79
  self.velocity = self.velocity / velocity_magnitude * max_velocity
80
 
81
- # Acceleration from velocity change
82
  if len(positions) == 3:
83
  prev_velocity = positions[-2] - positions[-3]
84
  self.acceleration = (self.velocity - prev_velocity) * 0.3
85
 
86
- # Predict next position with damping
87
  predicted_pos = positions[-1] + self.velocity * 0.7 + self.acceleration * 0.1
88
 
89
- # Get average recent size for stable bbox
90
  if self.sizes:
91
  avg_width = np.mean([s[0] for s in self.sizes])
92
  avg_height = np.mean([s[1] for s in self.sizes])
@@ -94,7 +128,6 @@ class Track:
94
  avg_width = max(10, self.bbox[2] - self.bbox[0])
95
  avg_height = max(10, self.bbox[3] - self.bbox[1])
96
 
97
- # Update bbox with predicted center and smoothed size
98
  self.bbox = [
99
  predicted_pos[0] - avg_width/2,
100
  predicted_pos[1] - avg_height/2,
@@ -102,29 +135,38 @@ class Track:
102
  predicted_pos[1] + avg_height/2
103
  ]
104
  except Exception as e:
105
- # Fallback: Keep current bbox
106
- print(f"Track prediction error: {e}")
107
  pass
108
-
109
  def update(self, detection: Detection):
110
- """Update track with new detection"""
111
  try:
 
 
 
 
 
112
  # Update bbox
113
  if hasattr(detection, 'bbox'):
114
  self.bbox = detection.bbox.copy()
115
-
116
- self.detections.append(detection)
117
 
118
  # Update confidence
119
  if hasattr(detection, 'confidence'):
120
  self.confidence = detection.confidence
121
- self.avg_confidence = (self.avg_confidence * 0.9 + self.confidence * 0.1)
 
122
  self.max_confidence = max(self.max_confidence, self.confidence)
123
 
124
  self.hits += 1
125
  self.time_since_update = 0
126
  self.consecutive_misses = 0
127
 
 
 
 
 
 
 
128
  # Update trajectory
129
  cx = (self.bbox[0] + self.bbox[2]) / 2
130
  cy = (self.bbox[1] + self.bbox[3]) / 2
@@ -135,59 +177,52 @@ class Track:
135
  height = max(1, self.bbox[3] - self.bbox[1])
136
  self.sizes.append((width, height))
137
 
138
- # Store appearance features if available
139
  if hasattr(detection, 'features'):
140
  self.appearance_features.append(detection.features)
141
  if len(self.appearance_features) > 5:
142
  self.appearance_features = self.appearance_features[-5:]
143
 
144
- # Confirm track after 2 hits
145
- if self.state == 'tentative' and self.hits >= 2:
146
  self.state = 'confirmed'
147
-
148
- # Keep only recent detections to save memory
149
  if len(self.detections) > 5:
150
- # Clear old detection images to save memory
151
  for old_det in self.detections[:-5]:
152
  if hasattr(old_det, 'image_crop'):
153
  old_det.image_crop = None
154
  self.detections = self.detections[-5:]
155
-
156
  except Exception as e:
157
  print(f"Track update error: {e}")
158
-
159
  def mark_missed(self):
160
- """Mark track as missed in current frame"""
161
- if self.state == 'confirmed':
162
- # More lenient deletion criteria
163
- if self.consecutive_misses > 15:
164
  self.state = 'deleted'
165
- elif self.time_since_update > 30:
166
  self.state = 'deleted'
167
  elif self.state == 'tentative':
168
  if self.consecutive_misses > 3:
169
  self.state = 'deleted'
 
 
 
170
 
171
 
172
- class RobustTracker:
173
  """
174
- Production-ready tracker with comprehensive error handling
175
  """
176
-
177
  def __init__(self,
178
  match_threshold: float = 0.35,
179
- track_buffer: int = 30,
180
  min_iou_for_match: float = 0.15,
181
  use_appearance: bool = False):
182
- """
183
- Initialize tracker with safe defaults
184
 
185
- Args:
186
- match_threshold: IoU threshold for matching (0.35 is balanced)
187
- track_buffer: Frames to keep lost tracks
188
- min_iou_for_match: Minimum IoU to consider a match
189
- use_appearance: Whether to use appearance features (set False for speed)
190
- """
191
  self.match_threshold = match_threshold
192
  self.track_buffer = track_buffer
193
  self.min_iou_for_match = min_iou_for_match
@@ -196,24 +231,90 @@ class RobustTracker:
196
  self.tracks: List[Track] = []
197
  self.track_id_count = 1
198
 
 
 
 
 
199
  # Enhanced parameters
200
- self.max_center_distance = 150 # pixels (reduced for stricter matching)
201
- self.min_size_similarity = 0.4 # Size change threshold
202
 
203
- # Debug mode
204
  self.debug = False
 
 
 
 
 
205
 
206
- def update(self, detections: List[Detection]) -> List[Track]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
  """
208
- Update tracks with robust error handling
 
209
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
  if not detections:
211
- # No detections - just predict existing tracks
212
  for track in self.tracks:
213
  track.predict()
214
  track.mark_missed()
215
 
216
- # Remove deleted tracks
 
 
217
  self.tracks = [t for t in self.tracks if t.state != 'deleted']
218
  return [t for t in self.tracks if t.state == 'confirmed']
219
 
@@ -221,31 +322,29 @@ class RobustTracker:
221
  # Predict existing tracks
222
  for track in self.tracks:
223
  track.predict()
224
-
225
  # Split tracks by state
226
  confirmed_tracks = [t for t in self.tracks if t.state == 'confirmed']
227
  tentative_tracks = [t for t in self.tracks if t.state == 'tentative']
228
 
229
- # Initialize matched indices
230
  matched_track_indices = set()
231
  matched_det_indices = set()
232
 
233
- # Stage 1: Match confirmed tracks with all detections
234
  if confirmed_tracks and detections:
235
  matched_track_indices, matched_det_indices = self._associate_tracks(
236
- confirmed_tracks, detections,
237
  matched_track_indices, matched_det_indices,
238
  threshold_mult=1.0
239
  )
240
 
241
- # Stage 2: Match tentative tracks with unmatched detections
242
  if tentative_tracks:
243
- unmatched_dets = [detections[i] for i in range(len(detections))
244
  if i not in matched_det_indices]
245
 
246
  if unmatched_dets:
247
- # Create temporary mapping
248
- temp_det_mapping = [i for i in range(len(detections))
249
  if i not in matched_det_indices]
250
 
251
  tent_matched_tracks, tent_matched_dets = self._associate_tracks(
@@ -254,114 +353,101 @@ class RobustTracker:
254
  threshold_mult=0.7
255
  )
256
 
257
- # Map back to original detection indices
258
  for det_idx in tent_matched_dets:
259
  matched_det_indices.add(temp_det_mapping[det_idx])
260
 
261
- # Mark unmatched tracks as missed
262
  for i, track in enumerate(confirmed_tracks):
263
  if i not in matched_track_indices:
264
  track.mark_missed()
265
-
266
  for track in tentative_tracks:
267
  if track.time_since_update > 0:
268
  track.mark_missed()
269
 
270
- # Create new tracks for unmatched detections
271
  for det_idx in range(len(detections)):
272
  if det_idx not in matched_det_indices:
273
  detection = detections[det_idx]
274
 
275
- # Check if detection is too close to existing tracks
276
- if self._is_new_track(detection):
277
  new_track = Track(detection, self.track_id_count)
278
  self.track_id_count += 1
279
  self.tracks.append(new_track)
 
 
 
 
 
280
 
281
  # Remove deleted tracks
282
  self.tracks = [t for t in self.tracks if t.state != 'deleted']
283
 
284
- # Return only confirmed tracks
285
- return [t for t in self.tracks if t.state == 'confirmed']
286
-
287
  except Exception as e:
288
  print(f"Tracker update error: {e}")
289
- # Return existing confirmed tracks as fallback
290
  return [t for t in self.tracks if t.state == 'confirmed']
291
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
292
  def _associate_tracks(self, tracks: List[Track], detections: List[Detection],
293
  existing_matched_tracks: set, existing_matched_dets: set,
294
  threshold_mult: float = 1.0) -> Tuple[set, set]:
295
- """
296
- Safe track-detection association
297
-
298
- Returns:
299
- (matched_track_indices, matched_det_indices)
300
- """
301
  if not tracks or not detections:
302
  return existing_matched_tracks, existing_matched_dets
303
-
304
  try:
305
- # Calculate cost matrix
306
  cost_matrix = self._calculate_enhanced_cost_matrix(tracks, detections)
307
 
308
  if cost_matrix.size == 0:
309
  return existing_matched_tracks, existing_matched_dets
310
 
311
- # Hungarian matching
312
  row_ind, col_ind = linear_sum_assignment(cost_matrix)
313
 
314
  matched_tracks = existing_matched_tracks.copy()
315
  matched_dets = existing_matched_dets.copy()
316
 
317
- # Process matches
318
  for r, c in zip(row_ind, col_ind):
319
- # Check bounds
320
  if r >= len(tracks) or c >= len(detections):
321
  continue
322
-
323
- # Check cost threshold
324
  threshold = (1 - self.match_threshold * threshold_mult)
325
  if cost_matrix[r, c] < threshold:
326
  tracks[r].update(detections[c])
327
  matched_tracks.add(r)
328
  matched_dets.add(c)
329
-
330
- return matched_tracks, matched_dets
331
 
 
 
332
  except Exception as e:
333
  print(f"Association error: {e}")
334
  return existing_matched_tracks, existing_matched_dets
335
 
336
- def _is_new_track(self, detection: Detection) -> bool:
337
- """Check if detection represents a new track"""
338
- try:
339
- det_center = self._get_center(detection.bbox)
340
-
341
- for track in self.tracks:
342
- if track.state == 'deleted':
343
- continue
344
-
345
- track_center = self._get_center(track.bbox)
346
- dist = np.linalg.norm(np.array(det_center) - np.array(track_center))
347
-
348
- # Very close to existing track - likely same object
349
- if dist < 30:
350
- return False
351
-
352
- return True
353
-
354
- except Exception as e:
355
- print(f"New track check error: {e}")
356
- return True # Default to creating new track
357
-
358
  def _calculate_enhanced_cost_matrix(self, tracks: List[Track],
359
  detections: List[Detection]) -> np.ndarray:
360
- """Calculate cost matrix with error handling"""
361
  try:
362
  if not tracks or not detections:
363
  return np.array([])
364
-
365
  n_tracks = len(tracks)
366
  n_dets = len(detections)
367
  cost_matrix = np.ones((n_tracks, n_dets))
@@ -369,7 +455,7 @@ class RobustTracker:
369
  for t_idx, track in enumerate(tracks):
370
  if not hasattr(track, 'bbox') or len(track.bbox) != 4:
371
  continue
372
-
373
  track_center = np.array(self._get_center(track.bbox))
374
  track_size = np.array([
375
  max(1, track.bbox[2] - track.bbox[0]),
@@ -379,66 +465,47 @@ class RobustTracker:
379
  for d_idx, detection in enumerate(detections):
380
  if not hasattr(detection, 'bbox') or len(detection.bbox) != 4:
381
  continue
382
-
383
  # IoU cost
384
  iou = self._iou(track.bbox, detection.bbox)
385
 
386
- # Center distance cost
387
  det_center = np.array(self._get_center(detection.bbox))
388
  distance = np.linalg.norm(track_center - det_center)
389
 
390
- # Size similarity cost
391
  det_size = np.array([
392
  max(1, detection.bbox[2] - detection.bbox[0]),
393
  max(1, detection.bbox[3] - detection.bbox[1])
394
  ])
395
 
396
- # Prevent division by zero
397
  size_ratio = np.minimum(track_size, det_size) / (np.maximum(track_size, det_size) + 1e-6)
398
  size_cost = 1 - np.mean(size_ratio)
399
 
400
- # Check basic constraints
401
  if iou >= self.min_iou_for_match and distance < self.max_center_distance:
402
  iou_cost = 1 - iou
403
  dist_cost = distance / self.max_center_distance
404
 
405
- # Weighted combination (IoU is most important)
406
- total_cost = (0.6 * iou_cost +
407
- 0.25 * dist_cost +
408
- 0.15 * size_cost)
409
 
410
- # Add appearance cost if available and enabled
411
- if (self.use_appearance and
412
- hasattr(track, 'appearance_features') and
413
- track.appearance_features and
414
- hasattr(detection, 'features')):
415
- try:
416
- track_feat = np.mean(track.appearance_features, axis=0)
417
- det_feat = detection.features
418
-
419
- # Cosine similarity
420
- feat_norm = np.linalg.norm(track_feat) * np.linalg.norm(det_feat)
421
- if feat_norm > 0:
422
- app_sim = np.dot(track_feat, det_feat) / feat_norm
423
- app_cost = 1 - max(0, min(1, app_sim))
424
- total_cost = (0.5 * iou_cost + 0.2 * dist_cost +
425
- 0.15 * size_cost + 0.15 * app_cost)
426
- except:
427
- pass # Use cost without appearance
428
 
429
  cost_matrix[t_idx, d_idx] = total_cost
430
  else:
431
  cost_matrix[t_idx, d_idx] = 1.0
432
-
433
- return cost_matrix
434
 
 
 
435
  except Exception as e:
436
- print(f"Cost matrix calculation error: {e}")
437
- # Return high cost matrix as fallback
438
  return np.ones((len(tracks), len(detections)))
439
 
440
  def _get_center(self, bbox: List[float]) -> Tuple[float, float]:
441
- """Get bbox center with validation"""
442
  try:
443
  if len(bbox) >= 4:
444
  return ((bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2)
@@ -447,11 +514,11 @@ class RobustTracker:
447
  return (0, 0)
448
 
449
  def _iou(self, bbox1: List[float], bbox2: List[float]) -> float:
450
- """Calculate IoU with validation"""
451
  try:
452
  if len(bbox1) < 4 or len(bbox2) < 4:
453
  return 0.0
454
-
455
  x1 = max(bbox1[0], bbox2[0])
456
  y1 = max(bbox1[1], bbox2[1])
457
  x2 = min(bbox1[2], bbox2[2])
@@ -459,42 +526,42 @@ class RobustTracker:
459
 
460
  if x2 < x1 or y2 < y1:
461
  return 0.0
462
-
463
- intersection = (x2 - x1) * (y2 - y1)
464
 
 
465
  area1 = max(1, (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1]))
466
  area2 = max(1, (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1]))
467
  union = area1 + area2 - intersection
468
 
469
  return max(0, min(1, intersection / (union + 1e-6)))
470
-
471
  except Exception as e:
472
- print(f"IoU calculation error: {e}")
473
  return 0.0
474
 
475
  def set_match_threshold(self, threshold: float):
476
  """Update matching threshold"""
477
  self.match_threshold = max(0.1, min(0.8, threshold))
478
- print(f"Tracking threshold updated to: {self.match_threshold:.2f}")
479
 
480
  def reset(self):
481
- """Reset tracker state"""
482
  self.tracks.clear()
483
  self.track_id_count = 1
484
  print("Tracker reset")
485
 
486
  def get_statistics(self) -> Dict:
487
- """Get tracker statistics"""
488
  confirmed = len([t for t in self.tracks if t.state == 'confirmed'])
 
489
  tentative = len([t for t in self.tracks if t.state == 'tentative'])
490
 
491
  return {
492
  'total_tracks': len(self.tracks),
493
  'confirmed_tracks': confirmed,
 
494
  'tentative_tracks': tentative,
495
  'next_id': self.track_id_count
496
  }
497
 
498
-
499
- # Compatibility alias
500
- SimpleTracker = RobustTracker
 
1
  """
2
+ Enhanced Tracking with Track Validation and Improved Accuracy (Enhancement 7)
 
3
  """
4
  import numpy as np
5
  from typing import List, Optional, Tuple, Dict
 
10
  import warnings
11
  warnings.filterwarnings('ignore')
12
 
 
13
  class Track:
14
+ """Enhanced track with validation and quality metrics"""
 
15
  def __init__(self, detection: Detection, track_id: Optional[int] = None):
 
16
  self.track_id = track_id if track_id else self._generate_id()
17
  self.bbox = detection.bbox.copy() if hasattr(detection, 'bbox') else [0, 0, 100, 100]
18
  self.detections = [detection]
 
25
  self.hits = 1
26
  self.consecutive_misses = 0
27
 
28
+ # ENHANCEMENT 7: Track validation
29
+ self.validation_score = 0 # Increases with consistent detections
30
+ self.min_validation_hits = 3 # Require 3 consistent detections
31
+ self.is_validated = False
32
+
33
  # Store center points for trajectory
34
  cx = (self.bbox[0] + self.bbox[2]) / 2
35
  cy = (self.bbox[1] + self.bbox[3]) / 2
 
40
  self.velocity = np.array([0.0, 0.0])
41
  self.acceleration = np.array([0.0, 0.0])
42
 
43
+ # Appearance features
44
  self.appearance_features = []
45
  if hasattr(detection, 'features'):
46
  self.appearance_features.append(detection.features)
47
+
48
+ # Size tracking with validation
49
  self.sizes = deque(maxlen=10)
50
  width = max(1, self.bbox[2] - self.bbox[0])
51
  height = max(1, self.bbox[3] - self.bbox[1])
52
  self.sizes.append((width, height))
53
+ self.initial_size = (width, height)
54
 
55
  # Track quality metrics
56
  self.avg_confidence = self.confidence
57
  self.max_confidence = self.confidence
58
+ self.confidence_history = deque(maxlen=10)
59
+ self.confidence_history.append(self.confidence)
60
+
61
  def _generate_id(self) -> int:
 
62
  return int(uuid.uuid4().int % 100000)
63
+
64
+ def validate_detection_consistency(self, detection: Detection) -> bool:
65
+ """
66
+ ENHANCEMENT 7: Validate detection consistency before accepting
67
+ Checks size similarity and position consistency
68
+ """
69
+ if not hasattr(detection, 'bbox') or len(detection.bbox) != 4:
70
+ return False
71
+
72
+ # Check size similarity
73
+ new_width = detection.bbox[2] - detection.bbox[0]
74
+ new_height = detection.bbox[3] - detection.bbox[1]
75
+
76
+ if self.initial_size[0] > 0 and self.initial_size[1] > 0:
77
+ width_ratio = new_width / self.initial_size[0]
78
+ height_ratio = new_height / self.initial_size[1]
79
+
80
+ # Allow 50% size variation max (prevents wild mismatches)
81
+ if width_ratio < 0.5 or width_ratio > 2.0:
82
+ return False
83
+ if height_ratio < 0.5 or height_ratio > 2.0:
84
+ return False
85
+
86
+ # Check position consistency (not jumping too far)
87
+ new_cx = (detection.bbox[0] + detection.bbox[2]) / 2
88
+ new_cy = (detection.bbox[1] + detection.bbox[3]) / 2
89
+
90
+ if len(self.trajectory) > 0:
91
+ last_cx, last_cy = self.trajectory[-1]
92
+ distance = np.sqrt((new_cx - last_cx)**2 + (new_cy - last_cy)**2)
93
+
94
+ # Max reasonable movement per frame (adjust based on your video)
95
+ max_movement = 100 # pixels
96
+ if distance > max_movement:
97
+ return False
98
 
99
+ return True
100
+
101
  def predict(self):
102
+ """Enhanced motion prediction"""
103
  self.age += 1
104
  self.time_since_update += 1
105
  self.consecutive_misses += 1
106
 
107
  try:
108
  if len(self.trajectory) >= 3:
 
109
  positions = np.array(list(self.trajectory))[-3:]
 
 
110
  self.velocity = positions[-1] - positions[-2]
111
 
112
+ # Limit velocity
113
+ max_velocity = 50
114
  velocity_magnitude = np.linalg.norm(self.velocity)
115
  if velocity_magnitude > max_velocity:
116
  self.velocity = self.velocity / velocity_magnitude * max_velocity
117
 
 
118
  if len(positions) == 3:
119
  prev_velocity = positions[-2] - positions[-3]
120
  self.acceleration = (self.velocity - prev_velocity) * 0.3
121
 
 
122
  predicted_pos = positions[-1] + self.velocity * 0.7 + self.acceleration * 0.1
123
 
 
124
  if self.sizes:
125
  avg_width = np.mean([s[0] for s in self.sizes])
126
  avg_height = np.mean([s[1] for s in self.sizes])
 
128
  avg_width = max(10, self.bbox[2] - self.bbox[0])
129
  avg_height = max(10, self.bbox[3] - self.bbox[1])
130
 
 
131
  self.bbox = [
132
  predicted_pos[0] - avg_width/2,
133
  predicted_pos[1] - avg_height/2,
 
135
  predicted_pos[1] + avg_height/2
136
  ]
137
  except Exception as e:
 
 
138
  pass
139
+
140
  def update(self, detection: Detection):
141
+ """Update track with validation"""
142
  try:
143
+ # ENHANCEMENT 7: Validate consistency
144
+ if not self.validate_detection_consistency(detection):
145
+ print(f" ⚠️ Track {self.track_id}: Rejected inconsistent detection")
146
+ return
147
+
148
  # Update bbox
149
  if hasattr(detection, 'bbox'):
150
  self.bbox = detection.bbox.copy()
151
+ self.detections.append(detection)
 
152
 
153
  # Update confidence
154
  if hasattr(detection, 'confidence'):
155
  self.confidence = detection.confidence
156
+ self.confidence_history.append(self.confidence)
157
+ self.avg_confidence = np.mean(list(self.confidence_history))
158
  self.max_confidence = max(self.max_confidence, self.confidence)
159
 
160
  self.hits += 1
161
  self.time_since_update = 0
162
  self.consecutive_misses = 0
163
 
164
+ # ENHANCEMENT 7: Update validation score
165
+ self.validation_score += 1
166
+ if self.validation_score >= self.min_validation_hits and not self.is_validated:
167
+ self.is_validated = True
168
+ print(f" ✅ Track {self.track_id} validated after {self.validation_score} consistent hits")
169
+
170
  # Update trajectory
171
  cx = (self.bbox[0] + self.bbox[2]) / 2
172
  cy = (self.bbox[1] + self.bbox[3]) / 2
 
177
  height = max(1, self.bbox[3] - self.bbox[1])
178
  self.sizes.append((width, height))
179
 
180
+ # Store appearance features
181
  if hasattr(detection, 'features'):
182
  self.appearance_features.append(detection.features)
183
  if len(self.appearance_features) > 5:
184
  self.appearance_features = self.appearance_features[-5:]
185
 
186
+ # Confirm track after validation
187
+ if self.is_validated and self.state == 'tentative':
188
  self.state = 'confirmed'
189
+
190
+ # Keep only recent detections
191
  if len(self.detections) > 5:
 
192
  for old_det in self.detections[:-5]:
193
  if hasattr(old_det, 'image_crop'):
194
  old_det.image_crop = None
195
  self.detections = self.detections[-5:]
196
+
197
  except Exception as e:
198
  print(f"Track update error: {e}")
199
+
200
  def mark_missed(self):
201
+ """Mark track as missed"""
202
+ # ENHANCEMENT 7: Only delete validated tracks after longer period
203
+ if self.is_validated and self.state == 'confirmed':
204
+ if self.consecutive_misses > 20: # Extended buffer for validated tracks
205
  self.state = 'deleted'
206
+ elif self.time_since_update > 40:
207
  self.state = 'deleted'
208
  elif self.state == 'tentative':
209
  if self.consecutive_misses > 3:
210
  self.state = 'deleted'
211
+
212
+ # Reduce validation score when missed
213
+ self.validation_score = max(0, self.validation_score - 0.5)
214
 
215
 
216
+ class EnhancedTracker:
217
  """
218
+ Enhanced Tracker with Track Validation (Enhancement 7)
219
  """
 
220
  def __init__(self,
221
  match_threshold: float = 0.35,
222
+ track_buffer: int = 40, # Increased from 30
223
  min_iou_for_match: float = 0.15,
224
  use_appearance: bool = False):
 
 
225
 
 
 
 
 
 
 
226
  self.match_threshold = match_threshold
227
  self.track_buffer = track_buffer
228
  self.min_iou_for_match = min_iou_for_match
 
231
  self.tracks: List[Track] = []
232
  self.track_id_count = 1
233
 
234
+ # ENHANCEMENT 7: Size constraints for valid dogs
235
+ self.min_dog_size = 30 # Minimum width/height in pixels
236
+ self.max_dog_size = 800 # Maximum width/height in pixels
237
+
238
  # Enhanced parameters
239
+ self.max_center_distance = 120
240
+ self.min_size_similarity = 0.4
241
 
 
242
  self.debug = False
243
+
244
+ def _is_valid_detection_size(self, detection: Detection) -> bool:
245
+ """ENHANCEMENT 7: Size-based filtering"""
246
+ if not hasattr(detection, 'bbox') or len(detection.bbox) != 4:
247
+ return False
248
 
249
+ width = detection.bbox[2] - detection.bbox[0]
250
+ height = detection.bbox[3] - detection.bbox[1]
251
+
252
+ # Filter too small or too large
253
+ if width < self.min_dog_size or height < self.min_dog_size:
254
+ return False
255
+ if width > self.max_dog_size or height > self.max_dog_size:
256
+ return False
257
+
258
+ # Aspect ratio check (dogs shouldn't be super wide or super tall)
259
+ if width > 0 and height > 0:
260
+ aspect_ratio = width / height
261
+ if aspect_ratio < 0.3 or aspect_ratio > 3.0:
262
+ return False
263
+
264
+ return True
265
+
266
+ def _check_appearance_similarity(self, detection: Detection) -> bool:
267
  """
268
+ ENHANCEMENT 7: Check if detection is too similar to existing tracks
269
+ Prevents duplicate tracks for the same dog
270
  """
271
+ if not hasattr(detection, 'bbox'):
272
+ return True
273
+
274
+ det_center = self._get_center(detection.bbox)
275
+ det_size = (detection.bbox[2] - detection.bbox[0],
276
+ detection.bbox[3] - detection.bbox[1])
277
+
278
+ for track in self.tracks:
279
+ if track.state == 'deleted':
280
+ continue
281
+
282
+ track_center = self._get_center(track.bbox)
283
+ track_size = (track.bbox[2] - track.bbox[0],
284
+ track.bbox[3] - track.bbox[1])
285
+
286
+ # Check center distance
287
+ distance = np.linalg.norm(np.array(det_center) - np.array(track_center))
288
+
289
+ # Check size similarity
290
+ size_diff = abs(det_size[0] - track_size[0]) + abs(det_size[1] - track_size[1])
291
+ avg_size = (det_size[0] + det_size[1] + track_size[0] + track_size[1]) / 4
292
+
293
+ # If very close and similar size, it's likely the same dog
294
+ if distance < 40 and size_diff < avg_size * 0.3:
295
+ return False # Too similar to existing track
296
+
297
+ return True
298
+
299
+ def update(self, detections: List[Detection]) -> List[Track]:
300
+ """Update tracks with enhanced validation"""
301
+
302
+ # ENHANCEMENT 7: Filter detections by size
303
+ valid_detections = [d for d in detections if self._is_valid_detection_size(d)]
304
+
305
+ if len(valid_detections) < len(detections):
306
+ print(f" 🔍 Filtered {len(detections) - len(valid_detections)} invalid size detections")
307
+
308
+ detections = valid_detections
309
+
310
  if not detections:
 
311
  for track in self.tracks:
312
  track.predict()
313
  track.mark_missed()
314
 
315
+ # Move lost tracks to sleeping (call ReID's move_to_sleeping)
316
+ self._handle_lost_tracks()
317
+
318
  self.tracks = [t for t in self.tracks if t.state != 'deleted']
319
  return [t for t in self.tracks if t.state == 'confirmed']
320
 
 
322
  # Predict existing tracks
323
  for track in self.tracks:
324
  track.predict()
325
+
326
  # Split tracks by state
327
  confirmed_tracks = [t for t in self.tracks if t.state == 'confirmed']
328
  tentative_tracks = [t for t in self.tracks if t.state == 'tentative']
329
 
 
330
  matched_track_indices = set()
331
  matched_det_indices = set()
332
 
333
+ # Stage 1: Match confirmed tracks
334
  if confirmed_tracks and detections:
335
  matched_track_indices, matched_det_indices = self._associate_tracks(
336
+ confirmed_tracks, detections,
337
  matched_track_indices, matched_det_indices,
338
  threshold_mult=1.0
339
  )
340
 
341
+ # Stage 2: Match tentative tracks
342
  if tentative_tracks:
343
+ unmatched_dets = [detections[i] for i in range(len(detections))
344
  if i not in matched_det_indices]
345
 
346
  if unmatched_dets:
347
+ temp_det_mapping = [i for i in range(len(detections))
 
348
  if i not in matched_det_indices]
349
 
350
  tent_matched_tracks, tent_matched_dets = self._associate_tracks(
 
353
  threshold_mult=0.7
354
  )
355
 
 
356
  for det_idx in tent_matched_dets:
357
  matched_det_indices.add(temp_det_mapping[det_idx])
358
 
359
+ # Mark unmatched tracks
360
  for i, track in enumerate(confirmed_tracks):
361
  if i not in matched_track_indices:
362
  track.mark_missed()
363
+
364
  for track in tentative_tracks:
365
  if track.time_since_update > 0:
366
  track.mark_missed()
367
 
368
+ # ENHANCEMENT 7: Create new tracks with validation
369
  for det_idx in range(len(detections)):
370
  if det_idx not in matched_det_indices:
371
  detection = detections[det_idx]
372
 
373
+ # Check appearance similarity before creating track
374
+ if self._check_appearance_similarity(detection):
375
  new_track = Track(detection, self.track_id_count)
376
  self.track_id_count += 1
377
  self.tracks.append(new_track)
378
+ else:
379
+ print(f" 🚫 Rejected duplicate track candidate")
380
+
381
+ # Handle lost tracks
382
+ self._handle_lost_tracks()
383
 
384
  # Remove deleted tracks
385
  self.tracks = [t for t in self.tracks if t.state != 'deleted']
386
 
387
+ # Return only validated confirmed tracks
388
+ return [t for t in self.tracks if t.state == 'confirmed' and t.is_validated]
389
+
390
  except Exception as e:
391
  print(f"Tracker update error: {e}")
 
392
  return [t for t in self.tracks if t.state == 'confirmed']
393
 
394
+ def _handle_lost_tracks(self):
395
+ """Handle tracks that are about to be deleted"""
396
+ for track in self.tracks:
397
+ # If track is validated and about to be deleted, could move to sleeping
398
+ if (track.is_validated and
399
+ track.state == 'confirmed' and
400
+ track.consecutive_misses == 18): # Just before deletion at 20
401
+
402
+ # This would be called by demo to move to ReID sleeping tracks
403
+ if hasattr(self, '_reid_callback'):
404
+ self._reid_callback(track.track_id)
405
+
406
+ def set_reid_callback(self, callback):
407
+ """Set callback to ReID for moving tracks to sleeping"""
408
+ self._reid_callback = callback
409
+
410
  def _associate_tracks(self, tracks: List[Track], detections: List[Detection],
411
  existing_matched_tracks: set, existing_matched_dets: set,
412
  threshold_mult: float = 1.0) -> Tuple[set, set]:
413
+ """Track-detection association"""
 
 
 
 
 
414
  if not tracks or not detections:
415
  return existing_matched_tracks, existing_matched_dets
416
+
417
  try:
 
418
  cost_matrix = self._calculate_enhanced_cost_matrix(tracks, detections)
419
 
420
  if cost_matrix.size == 0:
421
  return existing_matched_tracks, existing_matched_dets
422
 
 
423
  row_ind, col_ind = linear_sum_assignment(cost_matrix)
424
 
425
  matched_tracks = existing_matched_tracks.copy()
426
  matched_dets = existing_matched_dets.copy()
427
 
 
428
  for r, c in zip(row_ind, col_ind):
 
429
  if r >= len(tracks) or c >= len(detections):
430
  continue
431
+
 
432
  threshold = (1 - self.match_threshold * threshold_mult)
433
  if cost_matrix[r, c] < threshold:
434
  tracks[r].update(detections[c])
435
  matched_tracks.add(r)
436
  matched_dets.add(c)
 
 
437
 
438
+ return matched_tracks, matched_dets
439
+
440
  except Exception as e:
441
  print(f"Association error: {e}")
442
  return existing_matched_tracks, existing_matched_dets
443
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
444
  def _calculate_enhanced_cost_matrix(self, tracks: List[Track],
445
  detections: List[Detection]) -> np.ndarray:
446
+ """Calculate cost matrix"""
447
  try:
448
  if not tracks or not detections:
449
  return np.array([])
450
+
451
  n_tracks = len(tracks)
452
  n_dets = len(detections)
453
  cost_matrix = np.ones((n_tracks, n_dets))
 
455
  for t_idx, track in enumerate(tracks):
456
  if not hasattr(track, 'bbox') or len(track.bbox) != 4:
457
  continue
458
+
459
  track_center = np.array(self._get_center(track.bbox))
460
  track_size = np.array([
461
  max(1, track.bbox[2] - track.bbox[0]),
 
465
  for d_idx, detection in enumerate(detections):
466
  if not hasattr(detection, 'bbox') or len(detection.bbox) != 4:
467
  continue
468
+
469
  # IoU cost
470
  iou = self._iou(track.bbox, detection.bbox)
471
 
472
+ # Center distance
473
  det_center = np.array(self._get_center(detection.bbox))
474
  distance = np.linalg.norm(track_center - det_center)
475
 
476
+ # Size similarity
477
  det_size = np.array([
478
  max(1, detection.bbox[2] - detection.bbox[0]),
479
  max(1, detection.bbox[3] - detection.bbox[1])
480
  ])
481
 
 
482
  size_ratio = np.minimum(track_size, det_size) / (np.maximum(track_size, det_size) + 1e-6)
483
  size_cost = 1 - np.mean(size_ratio)
484
 
 
485
  if iou >= self.min_iou_for_match and distance < self.max_center_distance:
486
  iou_cost = 1 - iou
487
  dist_cost = distance / self.max_center_distance
488
 
489
+ total_cost = (0.6 * iou_cost +
490
+ 0.25 * dist_cost +
491
+ 0.15 * size_cost)
 
492
 
493
+ # Boost validated tracks
494
+ if track.is_validated:
495
+ total_cost *= 0.9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
496
 
497
  cost_matrix[t_idx, d_idx] = total_cost
498
  else:
499
  cost_matrix[t_idx, d_idx] = 1.0
 
 
500
 
501
+ return cost_matrix
502
+
503
  except Exception as e:
504
+ print(f"Cost matrix error: {e}")
 
505
  return np.ones((len(tracks), len(detections)))
506
 
507
  def _get_center(self, bbox: List[float]) -> Tuple[float, float]:
508
+ """Get bbox center"""
509
  try:
510
  if len(bbox) >= 4:
511
  return ((bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2)
 
514
  return (0, 0)
515
 
516
  def _iou(self, bbox1: List[float], bbox2: List[float]) -> float:
517
+ """Calculate IoU"""
518
  try:
519
  if len(bbox1) < 4 or len(bbox2) < 4:
520
  return 0.0
521
+
522
  x1 = max(bbox1[0], bbox2[0])
523
  y1 = max(bbox1[1], bbox2[1])
524
  x2 = min(bbox1[2], bbox2[2])
 
526
 
527
  if x2 < x1 or y2 < y1:
528
  return 0.0
 
 
529
 
530
+ intersection = (x2 - x1) * (y2 - y1)
531
  area1 = max(1, (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1]))
532
  area2 = max(1, (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1]))
533
  union = area1 + area2 - intersection
534
 
535
  return max(0, min(1, intersection / (union + 1e-6)))
536
+
537
  except Exception as e:
 
538
  return 0.0
539
 
540
  def set_match_threshold(self, threshold: float):
541
  """Update matching threshold"""
542
  self.match_threshold = max(0.1, min(0.8, threshold))
543
+ print(f"Tracking threshold: {self.match_threshold:.2f}")
544
 
545
  def reset(self):
546
+ """Reset tracker"""
547
  self.tracks.clear()
548
  self.track_id_count = 1
549
  print("Tracker reset")
550
 
551
  def get_statistics(self) -> Dict:
552
+ """Get statistics"""
553
  confirmed = len([t for t in self.tracks if t.state == 'confirmed'])
554
+ validated = len([t for t in self.tracks if t.is_validated])
555
  tentative = len([t for t in self.tracks if t.state == 'tentative'])
556
 
557
  return {
558
  'total_tracks': len(self.tracks),
559
  'confirmed_tracks': confirmed,
560
+ 'validated_tracks': validated,
561
  'tentative_tracks': tentative,
562
  'next_id': self.track_id_count
563
  }
564
 
565
+ # Compatibility aliases
566
+ SimpleTracker = EnhancedTracker
567
+ RobustTracker = EnhancedTracker