File size: 5,755 Bytes
2f08e4e
e23353b
 
2f08e4e
ed48ba6
e53ac82
 
1b0864d
e53ac82
e23353b
 
 
8d924f5
11e442b
e23353b
 
23b3953
e23353b
e53ac82
19bce7a
e23353b
 
d3ec202
e23353b
 
d3ec202
f6982b8
11e442b
8d924f5
e23353b
 
23b3953
e23353b
 
 
 
8d924f5
1b0864d
d3ec202
1b0864d
 
 
23b3953
9c944d2
d3ec202
9c944d2
 
 
 
e23353b
8d924f5
e23353b
1b0864d
e23353b
 
 
 
 
 
11e442b
e23353b
11e442b
e23353b
 
 
 
 
 
 
 
 
 
11e442b
 
e23353b
d3ec202
11e442b
 
 
 
d3ec202
11e442b
 
e23353b
23b3953
 
11e442b
e23353b
d3ec202
e23353b
11e442b
50025d3
e23353b
23b3953
e23353b
 
 
23b3953
c680668
203884d
e23353b
2f08e4e
70e918d
a26442f
70e918d
 
19bce7a
e23353b
 
d3ec202
e23353b
 
 
 
23b3953
e23353b
 
 
2f08e4e
e23353b
 
 
 
 
 
203884d
e23353b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c680668
e23353b
d3ec202
 
e23353b
c14191c
e23353b
23b3953
e23353b
 
 
 
23b3953
e23353b
d3ec202
e23353b
d3ec202
e23353b
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
"""
Simplified ReID - Basic threshold matching only
No adaptive thresholds, no quality scoring, no smart storage
"""
import numpy as np
import cv2
import torch
import timm
from sklearn.metrics.pairwise import cosine_similarity
from typing import Dict, Optional
from collections import defaultdict
from datetime import datetime


class SimplifiedReID:
    """Simplified ReID with basic threshold matching.

    Each call to :meth:`match_or_register` extracts an L2-normalised feature
    vector from a track's most recent image crop and compares it (dot
    product) with every vector stored for every session-local ``temp_id``.
    The best-scoring ``temp_id`` is reused when its score reaches
    ``self.threshold``; otherwise a new ``temp_id`` is allocated.
    """

    # Bounds that set_threshold() clamps its argument to.
    MIN_THRESHOLD = 0.10
    MAX_THRESHOLD = 0.95
    # Maximum number of feature vectors kept per temp_id (most recent win).
    MAX_FEATURES_PER_ID = 30

    def __init__(self, device: str = 'cuda'):
        """Create the matcher and load the embedding model.

        Args:
            device: Torch device string to run the model on. It is only
                honoured when CUDA is available; otherwise 'cpu' is used.
        """
        self.device = device if torch.cuda.is_available() else 'cpu'

        # Single threshold for all matching
        self.threshold = 0.40

        # Session tracking (temp IDs)
        self.temp_id_features = {}  # temp_id -> list of feature vectors
        self.next_temp_id = 1
        self.current_frame = 0
        self.current_video_source = "unknown"

        # Initialize model
        self._initialize_model()

        print(f"Simplified ReID initialized on {self.device}")

    def _initialize_model(self):
        """Load the MegaDescriptor model and its preprocessing transform.

        On any failure the error is printed and both ``self.model`` and
        ``self.transform`` are set to ``None`` so that feature extraction
        degrades to returning ``None`` instead of raising.
        """
        # Define both attributes up front so they always exist, even when
        # loading fails part-way through.
        self.model = None
        self.transform = None
        try:
            model = timm.create_model(
                'hf-hub:BVRA/MegaDescriptor-L-384',
                pretrained=True
            )
            model.to(self.device).eval()

            transform = timm.data.create_transform(
                input_size=(384, 384),
                is_training=False,
                mean=[0.5, 0.5, 0.5],
                std=[0.5, 0.5, 0.5]
            )
        except Exception as e:
            print(f"Error loading model: {e}")
            return

        self.model = model
        self.transform = transform
        print("MegaDescriptor-L-384 loaded")

    def set_threshold(self, threshold: float):
        """Set the matching threshold, clamped to [MIN_THRESHOLD, MAX_THRESHOLD]."""
        self.threshold = max(self.MIN_THRESHOLD, min(self.MAX_THRESHOLD, threshold))
        print(f"Threshold set to: {self.threshold:.2f}")

    def set_video_source(self, video_path: str):
        """Record the current video source."""
        self.current_video_source = video_path

    def reset_session(self):
        """Clear all stored features and restart the temp_id and frame counters."""
        self.temp_id_features.clear()
        self.next_temp_id = 1
        self.current_frame = 0
        print("Session reset")

    def extract_features(self, image: np.ndarray) -> Optional[np.ndarray]:
        """Extract an L2-normalised feature vector from an image.

        Args:
            image: Image array as used by OpenCV. Three-channel input is
                treated as BGR, four-channel as BGRA, and single-channel /
                2-D input as grayscale.

        Returns:
            The normalised feature vector, or ``None`` when the image is
            missing or empty, the model failed to load, or extraction raised.
        """
        if image is None or image.size == 0:
            return None
        if self.model is None or self.transform is None:
            return None

        try:
            # Pick the conversion that matches the channel layout so that
            # grayscale and BGRA crops are usable rather than always failing.
            if image.ndim == 2 or image.shape[2] == 1:
                img_rgb = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
            elif image.shape[2] == 4:
                img_rgb = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
            else:
                img_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            from PIL import Image
            pil_img = Image.fromarray(img_rgb)
            img_tensor = self.transform(pil_img).unsqueeze(0).to(self.device)

            with torch.no_grad():
                features = self.model(img_tensor)

            features = features.squeeze().cpu().numpy()
            # Epsilon guards against division by zero for an all-zero output.
            features = features / (np.linalg.norm(features) + 1e-7)

            return features
        except Exception as e:
            # Best-effort: a bad crop must not abort the caller's loop.
            print(f"Feature extraction error: {e}")
            return None

    @staticmethod
    def _no_match_result() -> Dict:
        """Build the result returned when no usable crop/features exist.

        Carries the same keys as a successful result so callers can read
        'confidence' and 'match_type' unconditionally; temp_id 0 is never
        allocated to a real identity (allocation starts at 1).
        """
        return {'temp_id': 0, 'confidence': 0.0, 'match_type': 'none'}

    def _find_best_match(self, features: np.ndarray) -> tuple:
        """Return ``(temp_id, score)`` of the most similar stored identity.

        The score of a temp_id is the maximum dot product between
        ``features`` and any vector stored for it. Returns ``(None, -1.0)``
        when nothing is stored or no score exceeds -1.0.
        """
        best_temp_id = None
        best_score = -1.0

        for temp_id, features_list in self.temp_id_features.items():
            if not features_list:
                continue
            max_sim = max(
                float(np.dot(features, stored_features))
                for stored_features in features_list
            )
            if max_sim > best_score:
                best_score = max_sim
                best_temp_id = temp_id

        return best_temp_id, best_score

    def match_or_register(self, track) -> Dict:
        """Match a track against known temp_ids or register a new one.

        The newest non-``None`` ``image_crop`` among the track's last three
        detections is embedded and compared with all stored features.

        Args:
            track: Object exposing ``detections``, a sequence whose items
                have an ``image_crop`` attribute.

        Returns:
            A dict with keys ``'temp_id'``, ``'confidence'`` and
            ``'match_type'``:

            * match found: the existing temp_id, the best similarity score,
              ``'existing'``;
            * no match: a newly allocated temp_id, ``1.0``, ``'new'``;
            * no crop or extraction failed: ``0``, ``0.0``, ``'none'``.
        """
        # NOTE: incremented once per call, including calls that fail below.
        self.current_frame += 1

        # Get image crop from track (newest of the last three detections)
        recent = (track.detections or [])[-3:]
        crop = next(
            (det.image_crop for det in reversed(recent)
             if det.image_crop is not None),
            None,
        )
        if crop is None:
            return self._no_match_result()

        # Extract features
        features = self.extract_features(crop)
        if features is None:
            return self._no_match_result()

        # Search for match in existing temp_ids
        best_temp_id, best_score = self._find_best_match(features)

        # Check if best score passes threshold
        if best_temp_id is not None and best_score >= self.threshold:
            # Match found - add features to existing temp_id
            stored = self.temp_id_features[best_temp_id]
            stored.append(features)

            # Limit storage (keep only the most recent vectors)
            if len(stored) > self.MAX_FEATURES_PER_ID:
                del stored[:-self.MAX_FEATURES_PER_ID]

            return {
                'temp_id': best_temp_id,
                'confidence': float(best_score),
                'match_type': 'existing'
            }

        # No match - create new temp_id
        new_temp_id = self.next_temp_id
        self.next_temp_id += 1
        self.temp_id_features[new_temp_id] = [features]

        print(f"New temp_id: {new_temp_id} (threshold: {self.threshold:.2f})")

        return {
            'temp_id': new_temp_id,
            'confidence': 1.0,
            'match_type': 'new'
        }

    def get_statistics(self) -> Dict:
        """Return the number of temp_ids, the threshold and the call counter."""
        return {
            'temp_ids': len(self.temp_id_features),
            'threshold': self.threshold,
            'current_frame': self.current_frame
        }