mustafa2ak committed on
Commit
ed48ba6
·
verified ·
1 Parent(s): 96d23c9

Update reid.py

Browse files
Files changed (1) hide show
  1. reid.py +698 -278
reid.py CHANGED
@@ -1,285 +1,705 @@
1
  """
2
- reid.py - Dog Re-Identification
3
- Uses:
4
- - Dog pose model (YOLOv8 trained on pose)
5
- - Ensemble CNN features (ResNet50, EfficientNet, ViT)
6
- - Color histograms
7
- - Temporal coherence
8
  """
9
-
10
- import time
11
- import numpy as np
12
  import cv2
13
- import torch
14
- import torch.nn as nn
15
- import torchvision.models as models
16
- import torchvision.transforms as transforms
17
- from sklearn.metrics.pairwise import cosine_similarity
18
- from typing import Dict, List, Optional, Tuple
19
- from ultralytics import YOLO
20
- from tracking import Track
21
-
22
- # Try importing timm
23
- try:
24
- import timm
25
- TIMM_AVAILABLE = True
26
- except ImportError:
27
- TIMM_AVAILABLE = False
28
- print("⚠️ timm not available. Run: pip install timm")
29
-
30
-
31
- # ---------------- Pose Feature Extractor ----------------
32
-
33
- class PoseFeatureExtractor:
34
- """Extract keypoint-based features from dogs using YOLO pose model."""
35
-
36
- def __init__(self, model_path="dog-pose-trained.pt", embedding_norm=True):
37
- self.model = YOLO(model_path)
38
- self.embedding_norm = embedding_norm
39
-
40
- def __call__(self, image: np.ndarray) -> Optional[np.ndarray]:
41
- results = self.model(image, verbose=False)
42
- if len(results) == 0 or not hasattr(results[0], "keypoints"):
43
- return None
44
-
45
- if len(results[0].keypoints.xy) == 0:
46
- return None
47
-
48
- kps = results[0].keypoints.xy[0].cpu().numpy().flatten().astype(np.float32)
49
-
50
- if self.embedding_norm:
51
- norm = np.linalg.norm(kps)
52
- if norm > 0:
53
- kps = kps / norm
54
-
55
- return kps
56
-
57
-
58
- # ---------------- Ensemble Feature Extractor ----------------
59
-
60
- class EnsembleFeatureExtractor:
61
- """Extract features from multiple pretrained CNN/ViT models."""
62
-
63
- def __init__(self, device: str = "cuda"):
64
- self.device = device if torch.cuda.is_available() else "cpu"
65
-
66
- # ResNet50
67
- self.resnet = models.resnet50(weights="IMAGENET1K_V1")
68
- self.resnet = nn.Sequential(*list(self.resnet.children())[:-1])
69
- self.resnet.to(self.device).eval()
70
-
71
- # EfficientNet + ViT
72
- if TIMM_AVAILABLE:
73
- self.efficientnet = timm.create_model("efficientnet_b0", pretrained=True, num_classes=0)
74
- self.vit = timm.create_model("vit_small_patch16_224", pretrained=True, num_classes=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  else:
76
- self.efficientnet = models.efficientnet_b0(weights="IMAGENET1K_V1")
77
- self.efficientnet.classifier = nn.Identity()
78
- self.vit = models.resnet152(weights="IMAGENET1K_V1")
79
- self.vit = nn.Sequential(*list(self.vit.children())[:-1])
80
-
81
- self.efficientnet.to(self.device).eval()
82
- self.vit.to(self.device).eval()
83
-
84
- self.transform = transforms.Compose([
85
- transforms.ToPILImage(),
86
- transforms.Resize((224, 224)),
87
- transforms.ToTensor(),
88
- transforms.Normalize([0.485, 0.456, 0.406],
89
- [0.229, 0.224, 0.225])
90
- ])
91
-
92
- def _extract(self, model, image):
93
- t = self.transform(image).unsqueeze(0).to(self.device)
94
- with torch.no_grad():
95
- feat = model(t).squeeze().cpu().numpy()
96
- feat = feat / (np.linalg.norm(feat) + 1e-7)
97
- return feat
98
-
99
- def extract(self, image: np.ndarray) -> Optional[np.ndarray]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  try:
101
- feats = []
102
- feats.append(self._extract(self.resnet, image))
103
- feats.append(self._extract(self.efficientnet, image))
104
- feats.append(self._extract(self.vit, image))
105
- emb = np.concatenate(feats)
106
- emb = emb / (np.linalg.norm(emb) + 1e-7)
107
- return emb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  except Exception as e:
109
- print("❌ Ensemble extraction error:", e)
110
- return None
111
-
112
-
113
- # ---------------- ReID System ----------------
114
-
115
- class DogReID:
116
- """
117
- Dog Re-ID using Pose + CNN + Color + Temporal coherence.
118
- """
119
-
120
- def __init__(self,
121
- pose_model_path="dog-pose-trained.pt",
122
- similarity_threshold=0.7,
123
- device="cuda",
124
- W_pose=0.3, W_cnn=0.4, W_color=0.2, W_temp=0.1):
125
-
126
- self.device = device if torch.cuda.is_available() else "cpu"
127
- self.similarity_threshold = similarity_threshold
128
-
129
- # weights
130
- self.W_pose = W_pose
131
- self.W_cnn = W_cnn
132
- self.W_color = W_color
133
- self.W_temp = W_temp
134
-
135
- # feature extractors
136
- self.pose_extractor = PoseFeatureExtractor(pose_model_path)
137
- self.ensemble_extractor = EnsembleFeatureExtractor(device=self.device)
138
-
139
- # original storage
140
- self.pose_db: Dict[int, List[np.ndarray]] = {}
141
- self.cnn_db: Dict[int, List[np.ndarray]] = {}
142
- self.color_db: Dict[int, List[np.ndarray]] = {}
143
- self.last_seen: Dict[int, Tuple[float, float, float]] = {}
144
- self.track_to_dog: Dict[int, int] = {}
145
- self.next_id = 1
146
-
147
- # ✅ compatibility attributes for app.py
148
- self.dog_database: Dict[int, Dict[str, list]] = {}
149
- self.dog_images: Dict[int, list] = {}
150
- self.next_dog_id: int = 0
151
-
152
- # --------- Color Histograms ---------
153
- def extract_color_histogram(self, image: np.ndarray) -> Optional[np.ndarray]:
154
- hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
155
- hist = cv2.calcHist([hsv], [0, 1, 2], None,
156
- [60, 64, 64], [0, 180, 0, 256, 0, 256])
157
- hist = cv2.normalize(hist, hist).flatten()
158
- return hist
159
-
160
- def compare_color(self, h1, h2) -> float:
161
- return cv2.compareHist(h1.astype("float32"),
162
- h2.astype("float32"),
163
- cv2.HISTCMP_CORREL)
164
-
165
- # --------- Temporal Coherence ---------
166
- def temporal_score(self, track: Track, dog_id: int) -> float:
167
- if dog_id not in self.last_seen:
168
- return 1.0
169
- last_x, last_y, last_t = self.last_seen[dog_id]
170
- bbox = track.bbox
171
- cx = (bbox[0] + bbox[2]) / 2
172
- cy = (bbox[1] + bbox[3]) / 2
173
- dt = time.time() - last_t
174
- dist = np.hypot(cx - last_x, cy - last_y)
175
- max_dist = 500 * dt
176
- return 1.0 if dist <= max_dist else 0.5
177
-
178
- # --------- Match or Register ---------
179
- def match_or_register(self, track: Track) -> Tuple[int, float]:
180
- det = None
181
- for d in reversed(track.detections):
182
- if d.image_crop is not None:
183
- det = d
184
- break
185
- if det is None:
186
- return 0, 0.0
187
-
188
- # extract features
189
- pose_feat = self.pose_extractor(det.image_crop)
190
- cnn_feat = self.ensemble_extractor.extract(det.image_crop)
191
- color_hist = self.extract_color_histogram(det.image_crop)
192
-
193
- if pose_feat is None and cnn_feat is None:
194
- return 0, 0.0
195
-
196
- # match loop
197
- best_id, best_score = None, -1
198
- for dog_id in set(self.pose_db.keys()) | set(self.cnn_db.keys()):
199
- # pose sim
200
- S_pose = 0
201
- if pose_feat is not None and dog_id in self.pose_db:
202
- S_pose = cosine_similarity(pose_feat.reshape(1, -1),
203
- np.mean(self.pose_db[dog_id], axis=0).reshape(1, -1))[0, 0]
204
-
205
- # cnn sim
206
- S_cnn = 0
207
- if cnn_feat is not None and dog_id in self.cnn_db:
208
- S_cnn = cosine_similarity(cnn_feat.reshape(1, -1),
209
- np.mean(self.cnn_db[dog_id], axis=0).reshape(1, -1))[0, 0]
210
-
211
- # color sim
212
- S_color = 0
213
- if color_hist is not None and dog_id in self.color_db:
214
- S_color = np.mean([self.compare_color(color_hist, h) for h in self.color_db[dog_id]])
215
-
216
- # temporal sim
217
- S_temp = self.temporal_score(track, dog_id)
218
-
219
- # weighted score
220
- score = (self.W_pose * S_pose +
221
- self.W_cnn * S_cnn +
222
- self.W_color * S_color +
223
- self.W_temp * S_temp)
224
-
225
- if score > best_score:
226
- best_id, best_score = dog_id, score
227
-
228
- # decide
229
- if best_id is not None and best_score >= self.similarity_threshold:
230
- if pose_feat is not None:
231
- self.pose_db[best_id].append(pose_feat)
232
- if cnn_feat is not None:
233
- self.cnn_db[best_id].append(cnn_feat)
234
- if color_hist is not None:
235
- self.color_db.setdefault(best_id, []).append(color_hist)
236
- bbox = track.bbox
237
- cx, cy = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
238
- self.last_seen[best_id] = (cx, cy, time.time())
239
- self.track_to_dog[track.track_id] = best_id
240
- return best_id, best_score
241
  else:
242
- dog_id = self.next_id
243
- self.next_id += 1
244
- if pose_feat is not None:
245
- self.pose_db[dog_id] = [pose_feat]
246
- if cnn_feat is not None:
247
- self.cnn_db[dog_id] = [cnn_feat]
248
- if color_hist is not None:
249
- self.color_db[dog_id] = [color_hist]
250
- bbox = track.bbox
251
- cx, cy = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
252
- self.last_seen[dog_id] = (cx, cy, time.time())
253
- self.track_to_dog[track.track_id] = dog_id
254
- return dog_id, best_score
255
-
256
- def count(self):
257
- return len(set(self.pose_db.keys()) | set(self.cnn_db.keys()))
258
- @property
259
- def dog_count(self):
260
- return len(self.dog_database)
261
-
262
- def set_threshold(self, threshold: float):
263
- self.similarity_threshold = threshold
264
-
265
- def save_database(self, path="dog_reid_db.json"):
266
- data = {
267
- "dog_database": {str(k): v for k, v in self.dog_database.items()},
268
- "dog_images": {str(k): v for k, v in self.dog_images.items()},
269
- "next_dog_id": self.next_dog_id,
270
- }
271
- with open(path, "w") as f:
272
- json.dump(data, f)
273
-
274
- def load_database(self, path="dog_reid_db.json"):
275
- if not os.path.exists(path):
276
- return
277
- with open(path, "r") as f:
278
- data = json.load(f)
279
- self.dog_database = {int(k): v for k, v in data.get("dog_database", {}).items()}
280
- self.dog_images = {int(k): v for k, v in data.get("dog_images", {}).items()}
281
- self.next_dog_id = data.get("next_dog_id", 0)
282
-
283
- # Backward compatibility
284
- SimpleReID = DogReID
285
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
+ enhanced_gradio_app.py - Enhanced Gradio Interface for Dog Monitoring
3
+ With SQLite database, dataset curation, and export features
 
 
 
 
4
  """
5
+ import gradio as gr
 
 
6
  import cv2
7
+ import numpy as np
8
+ import pandas as pd
9
+ from PIL import Image
10
+ import time
11
+ import json
12
+ import zipfile
13
+ import tempfile
14
+ from pathlib import Path
15
+ from typing import List, Dict, Optional, Tuple
16
+ from datetime import datetime
17
+
18
+ # Import core modules
19
+ from detection import DogDetector
20
+ from tracking import SimpleTracker
21
+ from reid import SimplifiedDogReID
22
+ from database import DogDatabase
23
+ from threshold_optimizer import ThresholdOptimizer
24
+
25
+ class EnhancedDogMonitoringApp:
26
+ """Enhanced app with database and dataset management"""
27
+
28
+ def __init__(self, db_path: str = "dog_monitoring.db"):
29
+ """Initialize the enhanced monitoring system"""
30
+ # Core components
31
+ self.detector = DogDetector(device='cuda')
32
+ self.tracker = SimpleTracker()
33
+ self.reid = SimplifiedDogReID(device='cuda')
34
+
35
+ # Database
36
+ self.db = DogDatabase(db_path)
37
+
38
+ # Threshold optimizer
39
+ self.threshold_optimizer = ThresholdOptimizer()
40
+
41
+ # Processing parameters
42
+ self.detection_confidence = 0.45
43
+ self.reid_threshold = 0.7
44
+ self.process_every_n_frames = 3
45
+
46
+ # Current session info
47
+ self.current_video_path = None
48
+ self.current_frame_count = 0
49
+ self.processing_active = False
50
+
51
+ def process_video(self, video_path: str, progress=None):
52
+ """Process video with database storage"""
53
+ if not video_path:
54
+ return None, [], "No video uploaded"
55
+
56
+ self.current_video_path = video_path
57
+ self.processing_active = True
58
+
59
+ # Reset tracking for new video
60
+ self.tracker = SimpleTracker()
61
+ self.reid.reset()
62
+
63
+ # Open video
64
+ cap = cv2.VideoCapture(video_path)
65
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
66
+ fps = cap.get(cv2.CAP_PROP_FPS)
67
+
68
+ # Prepare output
69
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
70
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
71
+ output_path = "output_video.mp4"
72
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
73
+ out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
74
+
75
+ frame_count = 0
76
+ dogs_in_video = set()
77
+
78
+ while self.processing_active:
79
+ ret, frame = cap.read()
80
+ if not ret:
81
+ break
82
+
83
+ frame_count += 1
84
+ self.current_frame_count = frame_count
85
+
86
+ # Update progress
87
+ if progress and total_frames > 0:
88
+ progress(frame_count / total_frames,
89
+ f"Processing frame {frame_count}/{total_frames}")
90
+
91
+ # Process every Nth frame
92
+ if frame_count % self.process_every_n_frames == 0:
93
+ # Detect dogs
94
+ detections = self.detector.detect(frame)
95
+
96
+ # Update tracker
97
+ tracks = self.tracker.update(detections)
98
+
99
+ # Process each track
100
+ for track in tracks:
101
+ # Re-identify
102
+ dog_id, confidence = self.reid.match_or_register(track)
103
+
104
+ if dog_id > 0:
105
+ dogs_in_video.add(dog_id)
106
+
107
+ # Save to database
108
+ self._save_to_database(
109
+ dog_id, track, confidence,
110
+ frame_count, video_path
111
+ )
112
+
113
+ # Draw on frame
114
+ self._draw_track(frame, track, dog_id, confidence)
115
+
116
+ # Feed to optimizer
117
+ self.threshold_optimizer.add_reid_sample(
118
+ similarity=confidence,
119
+ matched_dog_id=dog_id
120
+ )
121
+
122
+ # Add overlay
123
+ self._add_overlay(frame, frame_count, len(dogs_in_video))
124
+
125
+ # Write frame
126
+ out.write(frame)
127
+
128
+ cap.release()
129
+ out.release()
130
+
131
+ self.processing_active = False
132
+
133
+ # Create summary
134
+ summary = f"Processed {frame_count} frames, detected {len(dogs_in_video)} unique dogs"
135
+
136
+ return output_path, self._get_dog_gallery(), summary
137
+
138
+ def _save_to_database(self, dog_id: int, track, confidence: float,
139
+ frame_number: int, video_source: str):
140
+ """Save dog data to database"""
141
+ # Ensure dog exists in database
142
+ self.db.add_dog(dog_id)
143
+
144
+ # Get latest detection with image
145
+ detection = None
146
+ for det in reversed(track.detections):
147
+ if det.image_crop is not None:
148
+ detection = det
149
+ break
150
+
151
+ if detection:
152
+ # Save image
153
+ self.db.save_image(
154
+ dog_id=dog_id,
155
+ image=detection.image_crop,
156
+ frame_number=frame_number,
157
+ video_source=video_source,
158
+ bbox=detection.bbox,
159
+ confidence=confidence
160
+ )
161
+
162
+ # Save features
163
+ features = self.reid.dog_database.get(dog_id, [])
164
+ if features:
165
+ latest_feature = features[-1]
166
+ self.db.save_features(
167
+ dog_id=dog_id,
168
+ resnet_features=latest_feature.resnet_features,
169
+ color_histogram=latest_feature.color_histogram,
170
+ confidence=confidence
171
+ )
172
+
173
+ # Save sighting
174
+ bbox = detection.bbox
175
+ position = ((bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2)
176
+ self.db.add_sighting(
177
+ dog_id=dog_id,
178
+ position=position,
179
+ video_source=video_source,
180
+ frame_number=frame_number,
181
+ confidence=confidence
182
+ )
183
+
184
+ # Update dog sighting count
185
+ self.db.update_dog_sighting(dog_id)
186
+
187
+ def _draw_track(self, frame: np.ndarray, track, dog_id: int, confidence: float):
188
+ """Draw bounding box with dog ID"""
189
+ bbox = track.bbox
190
+ x1, y1, x2, y2 = map(int, bbox)
191
+
192
+ # Color based on confidence
193
+ if confidence > 0.8:
194
+ color = (0, 255, 0)
195
+ elif confidence > 0.6:
196
+ color = (0, 165, 255)
197
  else:
198
+ color = (0, 0, 255)
199
+
200
+ # Draw box and label
201
+ cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
202
+
203
+ label = f"Dog #{dog_id} ({confidence:.0%})"
204
+ label_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)[0]
205
+ cv2.rectangle(frame, (x1, y1 - label_size[1] - 10),
206
+ (x1 + label_size[0], y1), color, -1)
207
+ cv2.putText(frame, label, (x1, y1 - 5),
208
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
209
+
210
+ def _add_overlay(self, frame: np.ndarray, frame_count: int, dog_count: int):
211
+ """Add info overlay to frame"""
212
+ h, w = frame.shape[:2]
213
+
214
+ # Semi-transparent background
215
+ overlay = frame.copy()
216
+ cv2.rectangle(overlay, (10, 10), (250, 80), (0, 0, 0), -1)
217
+ frame[:] = cv2.addWeighted(overlay, 0.3, frame, 0.7, 0)
218
+
219
+ # Add text
220
+ cv2.putText(frame, f"Frame: {frame_count}", (20, 35),
221
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 1)
222
+ cv2.putText(frame, f"Dogs: {dog_count}", (20, 60),
223
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 1)
224
+
225
+ def _get_dog_gallery(self) -> List[Tuple[np.ndarray, str]]:
226
+ """Get gallery of detected dogs from database"""
227
+ gallery = []
228
+ dogs = self.db.get_all_dogs()
229
+
230
+ for _, dog in dogs.head(12).iterrows():
231
+ images = self.db.get_dog_images(dog['dog_id'], include_discarded=False)
232
+ if images:
233
+ img = images[0]['image']
234
+ img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
235
+ caption = f"Dog #{dog['dog_id']} | Sightings: {dog['total_sightings']}"
236
+ gallery.append((img_rgb, caption))
237
+
238
+ return gallery
239
+
240
+ # ========== Dataset Management Functions ==========
241
+
242
+ def get_dog_list(self) -> pd.DataFrame:
243
+ """Get list of all dogs for management"""
244
+ dogs = self.db.get_all_dogs()
245
+ return dogs[['dog_id', 'name', 'first_seen', 'last_seen', 'total_sightings', 'status']]
246
+
247
+ def get_dog_images_for_review(self, dog_id: int) -> Tuple[List, List]:
248
+ """Get dog images and body parts for review/validation"""
249
+ if dog_id is None:
250
+ return [], []
251
+
252
+ dog_id = int(dog_id)
253
+
254
+ # Get full images
255
+ images = self.db.get_dog_images(dog_id, include_discarded=True)
256
+
257
+ display_images = []
258
+ for img_data in images:
259
+ img_rgb = cv2.cvtColor(img_data['image'], cv2.COLOR_BGR2RGB)
260
+ status = "✓" if img_data['is_validated'] else "✗" if img_data['is_discarded'] else "?"
261
+ display_images.append({
262
+ 'image': img_rgb,
263
+ 'image_id': img_data['image_id'],
264
+ 'status': status,
265
+ 'confidence': img_data['confidence']
266
+ })
267
+
268
+ # Get body parts
269
+ body_parts = self.db.get_body_parts(dog_id, include_discarded=True)
270
+
271
+ display_parts = []
272
+ for part_data in body_parts:
273
+ part_rgb = cv2.cvtColor(part_data['image'], cv2.COLOR_BGR2RGB)
274
+ status = "✓" if part_data['is_validated'] else "✗" if part_data.get('is_discarded') else "?"
275
+ display_parts.append({
276
+ 'image': part_rgb,
277
+ 'part_id': part_data['part_id'],
278
+ 'part_type': part_data['part_type'],
279
+ 'status': status,
280
+ 'confidence': part_data['confidence']
281
+ })
282
+
283
+ return display_images, display_parts
284
+
285
+ def validate_body_parts(self, part_ids: List[int], action: str) -> str:
286
+ """Validate or discard body parts"""
287
+ if not part_ids:
288
+ return "No parts selected"
289
+
290
+ count = 0
291
+ for part_id in part_ids:
292
+ if action == "validate":
293
+ self.db.validate_body_part(part_id, is_valid=True)
294
+ elif action == "discard":
295
+ self.db.validate_body_part(part_id, is_valid=False)
296
+ count += 1
297
+
298
+ return f"Updated {count} body parts"
299
+
300
+ def validate_images(self, dog_id: int, image_ids: List[int], action: str) -> str:
301
+ """Validate or discard images"""
302
+ if not image_ids:
303
+ return "No images selected"
304
+
305
+ count = 0
306
+ for img_id in image_ids:
307
+ if action == "validate":
308
+ self.db.validate_image(img_id, is_valid=True)
309
+ elif action == "discard":
310
+ self.db.validate_image(img_id, is_valid=False)
311
+ count += 1
312
+
313
+ return f"Updated {count} images"
314
+
315
+ def merge_dogs_handler(self, keep_id: int, merge_id: int) -> str:
316
+ """Handle dog merging"""
317
+ if keep_id == merge_id:
318
+ return "Cannot merge dog with itself"
319
+
320
+ if self.db.merge_dogs(keep_id, merge_id):
321
+ return f"Successfully merged Dog #{merge_id} into Dog #{keep_id}"
322
+ else:
323
+ return "Failed to merge dogs"
324
+
325
+ def export_dataset(self, output_format: str, validated_only: bool) -> str:
326
+ """Export dataset for training"""
327
  try:
328
+ # Create temporary directory
329
+ with tempfile.TemporaryDirectory() as temp_dir:
330
+ # Export dataset
331
+ metadata = self.db.export_training_dataset(
332
+ temp_dir,
333
+ validated_only=validated_only
334
+ )
335
+
336
+ # Create zip file
337
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
338
+ zip_path = f"dog_dataset_{timestamp}.zip"
339
+
340
+ with zipfile.ZipFile(zip_path, 'w') as zipf:
341
+ for root, dirs, files in Path(temp_dir).walk():
342
+ for file in files:
343
+ file_path = Path(root) / file
344
+ arcname = file_path.relative_to(temp_dir)
345
+ zipf.write(file_path, arcname)
346
+
347
+ return f"Dataset exported: {zip_path} ({metadata['total_images']} images)"
348
+
349
  except Exception as e:
350
+ return f"Export failed: {str(e)}"
351
+
352
+ def reset_database_handler(self, confirm_text: str) -> str:
353
+ """Handle database reset"""
354
+ if confirm_text.lower() != "reset":
355
+ return "Type 'reset' to confirm database reset"
356
+
357
+ if self.db.reset_database(confirm=True):
358
+ self.reid.reset()
359
+ return "Database reset successfully"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
360
  else:
361
+ return "Database reset failed"
362
+
363
+ def get_database_stats(self) -> str:
364
+ """Get database statistics"""
365
+ stats = self.db.get_database_statistics()
366
+
367
+ return f"""
368
+ 📊 Database Statistics:
369
+ • Active Dogs: {stats['total_active_dogs']}
370
+ • Total Images: {stats['total_images']}
371
+ • Validated Images: {stats['validated_images']}
372
+ • Total Sightings: {stats['total_sightings']}
373
+
374
+ 🏆 Most Seen: {stats.get('most_seen_dog', {}).get('name', 'None')}
375
+ ({stats.get('most_seen_dog', {}).get('sightings', 0)} sightings)
376
+ """
377
+
378
+ def stop_processing(self):
379
+ """Stop video processing"""
380
+ self.processing_active = False
381
+ return "Processing stopped"
382
+
383
+ # ========== Gradio Interface ==========
384
+
385
+ def create_interface(self) -> gr.Blocks:
386
+ """Create enhanced Gradio interface"""
387
+ with gr.Blocks(title="Enhanced Dog Monitoring System", theme=gr.themes.Soft()) as app:
388
+ gr.Markdown("""
389
+ # 🐕 Enhanced Stray Dog Monitoring System
390
+ **Detection • Tracking • Re-ID • Database • Dataset Export**
391
+ """)
392
+
393
+ with gr.Tabs():
394
+ # Tab 1: Video Processing
395
+ with gr.Tab("📹 Process Video"):
396
+ with gr.Row():
397
+ with gr.Column(scale=1):
398
+ video_input = gr.Video(label="Upload Video")
399
+
400
+ with gr.Row():
401
+ process_btn = gr.Button("▶️ Process", variant="primary")
402
+ stop_btn = gr.Button("⏹️ Stop")
403
+
404
+ # Settings
405
+ gr.Markdown("### Settings")
406
+ detection_slider = gr.Slider(
407
+ 0.1, 0.9, 0.45, step=0.05,
408
+ label="Detection Confidence"
409
+ )
410
+ reid_slider = gr.Slider(
411
+ 0.3, 0.95, 0.7, step=0.05,
412
+ label="ReID Threshold"
413
+ )
414
+ frame_skip = gr.Slider(
415
+ 1, 10, 3, step=1,
416
+ label="Process Every N Frames"
417
+ )
418
+
419
+ with gr.Column(scale=2):
420
+ video_output = gr.Video(label="Processed Video")
421
+ processing_status = gr.Textbox(label="Status")
422
+ dog_gallery = gr.Gallery(
423
+ label="Detected Dogs",
424
+ columns=4,
425
+ rows=3
426
+ )
427
+
428
+ # Tab 2: Dog Management
429
+ with gr.Tab("🐶 Manage Dogs"):
430
+ with gr.Row():
431
+ with gr.Column(scale=1):
432
+ gr.Markdown("### Dog Registry")
433
+ refresh_btn = gr.Button("🔄 Refresh List")
434
+ dog_table = gr.Dataframe(
435
+ headers=["ID", "Name", "First Seen", "Last Seen", "Sightings", "Status"],
436
+ interactive=False
437
+ )
438
+
439
+ # Merge dogs
440
+ gr.Markdown("### Merge Dogs")
441
+ with gr.Row():
442
+ keep_dog_id = gr.Number(label="Keep Dog ID", precision=0)
443
+ merge_dog_id = gr.Number(label="Merge Dog ID", precision=0)
444
+ merge_btn = gr.Button("🔀 Merge Dogs")
445
+ merge_status = gr.Textbox(label="Merge Status")
446
+
447
+ with gr.Column(scale=2):
448
+ gr.Markdown("### Review Dog Images & Body Parts")
449
+ selected_dog_id = gr.Number(
450
+ label="Dog ID to Review",
451
+ precision=0
452
+ )
453
+ load_images_btn = gr.Button("📷 Load Images")
454
+
455
+ with gr.Tab("Full Images"):
456
+ dog_images_gallery = gr.Gallery(
457
+ label="Full Dog Images",
458
+ columns=4,
459
+ rows=3,
460
+ height=300
461
+ )
462
+
463
+ with gr.Row():
464
+ selected_images = gr.CheckboxGroup(
465
+ label="Select Images",
466
+ choices=[]
467
+ )
468
+
469
+ with gr.Row():
470
+ validate_imgs_btn = gr.Button("✅ Validate Selected")
471
+ discard_imgs_btn = gr.Button("❌ Discard Selected")
472
+
473
+ with gr.Tab("Body Parts"):
474
+ body_parts_gallery = gr.Gallery(
475
+ label="Body Part Crops",
476
+ columns=4,
477
+ rows=3,
478
+ height=300
479
+ )
480
+
481
+ gr.Markdown("""
482
+ **Part Types**:
483
+ - Head (top 35% of dog)
484
+ - Torso (middle 40%)
485
+ - Rear (bottom 40%)
486
+
487
+ Validate correctly cropped parts, discard mixed/wrong crops.
488
+ """)
489
+
490
+ with gr.Row():
491
+ selected_parts = gr.CheckboxGroup(
492
+ label="Select Parts",
493
+ choices=[]
494
+ )
495
+
496
+ with gr.Row():
497
+ validate_parts_btn = gr.Button("✅ Validate Parts")
498
+ discard_parts_btn = gr.Button("❌ Discard Parts")
499
+
500
+ validation_status = gr.Textbox(label="Validation Status")
501
+
502
+ # Tab 3: Dataset Export
503
+ with gr.Tab("💾 Export Dataset"):
504
+ gr.Markdown("""
505
+ ### Export Training Dataset
506
+ Export validated dog images for ResNet fine-tuning
507
+ """)
508
+
509
+ with gr.Row():
510
+ with gr.Column():
511
+ export_format = gr.Radio(
512
+ ["Images + CSV", "COCO Format", "YOLO Format"],
513
+ value="Images + CSV",
514
+ label="Export Format"
515
+ )
516
+
517
+ validated_only = gr.Checkbox(
518
+ label="Export validated images only",
519
+ value=True
520
+ )
521
+
522
+ export_btn = gr.Button("📦 Export Dataset", variant="primary")
523
+ export_status = gr.Textbox(label="Export Status")
524
+
525
+ gr.Markdown("""
526
+ ### Dataset Info
527
+ The exported dataset includes:
528
+ - Individual dog images organized by ID
529
+ - CSV file with labels and metadata
530
+ - Bounding box annotations
531
+ - Pose keypoints (if available)
532
+
533
+ Use this dataset to fine-tune ResNet for better re-identification!
534
+ """)
535
+
536
+ with gr.Column():
537
+ stats_display = gr.Textbox(
538
+ label="Database Statistics",
539
+ lines=10
540
+ )
541
+ refresh_stats_btn = gr.Button("📊 Refresh Stats")
542
+
543
+ # Tab 4: Database Management
544
+ with gr.Tab("⚙️ Database"):
545
+ gr.Markdown("### Database Management")
546
+
547
+ with gr.Row():
548
+ with gr.Column():
549
+ gr.Markdown("""
550
+ ⚠️ **Warning**: Resetting the database will delete all data!
551
+ Type 'reset' to confirm.
552
+ """)
553
+
554
+ reset_confirm = gr.Textbox(
555
+ label="Type 'reset' to confirm",
556
+ placeholder="reset"
557
+ )
558
+
559
+ reset_btn = gr.Button("🗑️ Reset Database", variant="stop")
560
+ reset_status = gr.Textbox(label="Reset Status")
561
+
562
+ gr.Markdown("### Database Optimization")
563
+ optimize_btn = gr.Button("🔧 Optimize Database")
564
+ optimize_status = gr.Textbox(label="Optimization Status")
565
+
566
+ # Event handlers
567
+
568
+ # Video processing
569
+ process_btn.click(
570
+ self.process_video,
571
+ inputs=[video_input],
572
+ outputs=[video_output, dog_gallery, processing_status]
573
+ )
574
+
575
+ stop_btn.click(
576
+ self.stop_processing,
577
+ outputs=[processing_status]
578
+ )
579
+
580
+ # Settings updates
581
+ detection_slider.change(
582
+ lambda v: setattr(self, 'detection_confidence', v) or f"Detection: {v:.2f}",
583
+ inputs=[detection_slider],
584
+ outputs=[processing_status]
585
+ )
586
+
587
+ reid_slider.change(
588
+ lambda v: setattr(self, 'reid_threshold', v) or f"ReID: {v:.2f}",
589
+ inputs=[reid_slider],
590
+ outputs=[processing_status]
591
+ )
592
+
593
+ frame_skip.change(
594
+ lambda v: setattr(self, 'process_every_n_frames', int(v)) or f"Skip: {int(v)}",
595
+ inputs=[frame_skip],
596
+ outputs=[processing_status]
597
+ )
598
+
599
+ # Dog management
600
+ refresh_btn.click(
601
+ self.get_dog_list,
602
+ outputs=[dog_table]
603
+ )
604
+
605
+ merge_btn.click(
606
+ self.merge_dogs_handler,
607
+ inputs=[keep_dog_id, merge_dog_id],
608
+ outputs=[merge_status]
609
+ )
610
+
611
+ def load_dog_images(dog_id):
612
+ if dog_id is None:
613
+ return [], [], [], [], "No dog selected"
614
+
615
+ images, parts = self.get_dog_images_for_review(int(dog_id))
616
+
617
+ # Format full images
618
+ img_gallery = [(img['image'], f"{img['status']} | {img['confidence']:.1%}")
619
+ for img in images]
620
+ img_choices = [f"Image {i+1}" for i in range(len(images))]
621
+
622
+ # Format body parts with type labels
623
+ part_gallery = [(p['image'], f"{p['part_type'].upper()} {p['status']} | {p['confidence']:.1%}")
624
+ for p in parts]
625
+ part_choices = [f"{p['part_type'].capitalize()} {i+1}" for i, p in enumerate(parts)]
626
+
627
+ return (img_gallery, gr.update(choices=img_choices),
628
+ part_gallery, gr.update(choices=part_choices),
629
+ f"Loaded {len(images)} images, {len(parts)} body parts")
630
+
631
+ load_images_btn.click(
632
+ load_dog_images,
633
+ inputs=[selected_dog_id],
634
+ outputs=[dog_images_gallery, selected_images,
635
+ body_parts_gallery, selected_parts, validation_status]
636
+ )
637
+
638
+ # Validate/discard body parts
639
+ validate_parts_btn.click(
640
+ self.validate_body_parts,
641
+ inputs=[selected_parts, gr.State("validate")],
642
+ outputs=[validation_status]
643
+ )
644
+
645
+ discard_parts_btn.click(
646
+ self.validate_body_parts,
647
+ inputs=[selected_parts, gr.State("discard")],
648
+ outputs=[validation_status]
649
+ )
650
+
651
+ # Dataset export
652
+ export_btn.click(
653
+ self.export_dataset,
654
+ inputs=[export_format, validated_only],
655
+ outputs=[export_status]
656
+ )
657
+
658
+ refresh_stats_btn.click(
659
+ self.get_database_stats,
660
+ outputs=[stats_display]
661
+ )
662
+
663
+ # Database management
664
+ reset_btn.click(
665
+ self.reset_database_handler,
666
+ inputs=[reset_confirm],
667
+ outputs=[reset_status]
668
+ )
669
+
670
+ def optimize_database():
671
+ self.db.vacuum()
672
+ return "Database optimized"
673
+
674
+ optimize_btn.click(
675
+ optimize_database,
676
+ outputs=[optimize_status]
677
+ )
678
+
679
+ # Load initial data
680
+ app.load(
681
+ self.get_dog_list,
682
+ outputs=[dog_table]
683
+ )
684
+
685
+ app.load(
686
+ self.get_database_stats,
687
+ outputs=[stats_display]
688
+ )
689
+
690
+ return app
691
+
692
+ def main():
693
+ """Launch the enhanced application"""
694
+ app = EnhancedDogMonitoringApp()
695
+ interface = app.create_interface()
696
+
697
+ interface.launch(
698
+ server_name="0.0.0.0",
699
+ server_port=7860,
700
+ share=False,
701
+ show_error=True
702
+ )
703
+
704
+ if __name__ == "__main__":
705
+ main()