import cv2
import numpy as np
import os
import tempfile  # temporary output files handed back to Gradio
import gdown  # downloads the model weights from Google Drive
import torch  # required to allowlist classes for safe checkpoint loading
import ultralytics.nn.tasks  # provides DetectionModel for the allowlist below
from ultralytics import YOLO
# --- Configuration ---
# Model and video paths will be handled dynamically or by Hugging Face Space environment
# MODEL_PATH needs to be downloaded if not present.
MODEL_DOWNLOAD_ID = "1-5fOSHOSB9UXyP_enOoZNAMScrePVcMD" # Google Drive file ID
MODEL_NAME = "yolov11_model.pt"
MODEL_PATH = MODEL_NAME # Will be downloaded to the current working directory
# Thresholds for object detection and tracking
CONFIDENCE_THRESHOLD = 0.5
IOU_TRACKING_THRESHOLD = 0.3
FEATURE_SIMILARITY_THRESHOLD = 0.5
MAX_LOST_FRAMES = 15
# --- Global Variables for Tracking State ---
# These will be reset for each new video processed by Gradio's function call
_next_player_id = 0
_active_players = {}
_inactive_players = {}
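# _active_players maps player_id -> Player for players matched recently;
# _inactive_players holds players unseen for more than MAX_LOST_FRAMES frames,
# which remain candidates for feature-based re-identification below.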
class Player:
    """
    Represents a single player being tracked, holding their current state and historical features.
    """
    def __init__(self, player_id, bbox, frame_num, features=None):
        self.player_id = player_id
        self.bbox = bbox
        self.last_seen_frame = frame_num
        self.features = features
        self.lost_frames_count = 0

    def update_bbox(self, new_bbox, frame_num):
        self.bbox = new_bbox
        self.last_seen_frame = frame_num
        self.lost_frames_count = 0

    def __repr__(self):
        return (f"Player(ID:{self.player_id}, Bbox:[{int(self.bbox[0])},{int(self.bbox[1])},"
                f"{int(self.bbox[2])},{int(self.bbox[3])}], LastSeen:{self.last_seen_frame})")
def calculate_iou(boxA, boxB):
    """Compute the IoU of two [x1, y1, x2, y2] boxes using an inclusive (+1) pixel convention."""
    xA = max(boxA[0], boxB[0])
    yA = max(boxA[1], boxB[1])
    xB = min(boxA[2], boxB[2])
    yB = min(boxA[3], boxB[3])
    interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)
    boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
    boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)
    unionArea = float(boxAArea + boxBArea - interArea)
    return interArea / unionArea if unionArea > 0 else 0.0
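# Hand-computed sanity check under the inclusive +1 convention above:
#   calculate_iou([0, 0, 10, 10], [5, 5, 15, 15])
#   intersection = 6 * 6 = 36, union = 121 + 121 - 36 = 206, IoU = 36 / 206 ≈ 0.175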
def extract_features(image, bbox):
    """Extract a normalized HSV color histogram from the player crop; returns None on an invalid crop."""
    x1, y1, x2, y2 = map(int, bbox)
    h, w, _ = image.shape
    # Clamp the box to the frame boundaries
    x1 = max(0, x1)
    y1 = max(0, y1)
    x2 = min(w, x2)
    y2 = min(h, y2)
    if x2 <= x1 or y2 <= y1:
        return None
    cropped_player = image[y1:y2, x1:x2]
    if cropped_player.size == 0:
        return None
    try:
        hsv_cropped = cv2.cvtColor(cropped_player, cv2.COLOR_BGR2HSV)
        hist = cv2.calcHist([hsv_cropped], [0, 1, 2], None, [8, 8, 8],
                            [0, 180, 0, 256, 0, 256])
        hist = cv2.normalize(hist, hist).flatten()
        return hist
    except cv2.error:
        return None
def compare_features(features1, features2):
    """Compare two feature histograms by correlation; higher means more similar."""
    if features1 is None or features2 is None or len(features1) != len(features2):
        return 0.0
    return cv2.compareHist(features1, features2, cv2.HISTCMP_CORREL)
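# Note: HISTCMP_CORREL scores lie in [-1, 1] and identical histograms score 1.0,
# so FEATURE_SIMILARITY_THRESHOLD = 0.5 only accepts reasonably strong matches.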
# Load model globally to avoid re-loading on every function call (efficiency for Gradio)
# Model will be downloaded if not found.
_yolo_model = None
_player_class_id = -1
_model_loaded = False
def _load_yolo_model():
    """Helper function to load the YOLO model once."""
    global _yolo_model, _player_class_id, _model_loaded
    if _model_loaded:
        return True
    print("Attempting to load YOLO model...")
    if not os.path.exists(MODEL_PATH):
        print(f"Model {MODEL_PATH} not found. Attempting to download from Google Drive...")
        try:
            gdown.download(id=MODEL_DOWNLOAD_ID, output=MODEL_PATH, quiet=False)
            print(f"Model downloaded to {MODEL_PATH}")
        except Exception as e:
            print(f"Error downloading model: {e}")
            return False
    try:
        # Allowlist the custom Ultralytics class and common PyTorch containers so
        # torch.load (weights-only mode) can safely deserialize the checkpoint.
        torch.serialization.add_safe_globals([
            ultralytics.nn.tasks.DetectionModel,
            torch.nn.modules.container.Sequential,
            # Add other Ultralytics model types here if they trigger similar errors,
            # e.g. ultralytics.nn.tasks.SegmentationModel, ultralytics.nn.tasks.PoseModel.
        ])
        _yolo_model = YOLO(MODEL_PATH)
        if torch.cuda.is_available():
            _yolo_model.to('cuda')
            # Half-precision (FP16) is only worthwhile on GPU; convert AFTER moving to the device.
            _yolo_model.half()
            print("YOLO model moved to CUDA (GPU) and set to half-precision (FP16).")
        else:
            print("CUDA (GPU) is not available; YOLO model will run on CPU in full precision.")
        print("Model Class Names:", _yolo_model.names)
        found_player_class = False
        for class_id, class_name in _yolo_model.names.items():
            if class_name.lower() == 'player':
                _player_class_id = class_id
                found_player_class = True
                break
        if not found_player_class:
            print("Error: 'player' class not found in model's names. Check model training.")
            return False
        print(f"Detected 'player' class ID: {_player_class_id}")
        _model_loaded = True
        return True
    except Exception as e:
        print(f"Error loading YOLO model: {e}")
        return False
# Main processing function to be called by Gradio
def process_reid_video(input_video_path):
    """
    Processes an input video for player re-identification and returns the path to the output video.
    """
    global _next_player_id, _active_players, _inactive_players
    # Reset global tracking state for each new video submission
    _next_player_id = 0
    _active_players = {}
    _inactive_players = {}
    # Ensure the model is loaded
    if not _model_loaded and not _load_yolo_model():
        print("Failed to load YOLO model. Cannot process video.")
        # Write a short dummy video so Gradio has something to display
        dummy_output_path = os.path.join(tempfile.gettempdir(), "error_output.mp4")
        dummy_writer = cv2.VideoWriter(dummy_output_path, cv2.VideoWriter_fourcc(*'mp4v'), 10, (640, 480))
        if dummy_writer.isOpened():
            blank_frame = np.zeros((480, 640, 3), dtype=np.uint8)
            cv2.putText(blank_frame, "ERROR: Model Failed to Load!", (50, 240),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
            for _ in range(30):  # write a few frames
                dummy_writer.write(blank_frame)
            dummy_writer.release()
        return dummy_output_path
    cap = cv2.VideoCapture(input_video_path)
    if not cap.isOpened():
        print(f"Error: Could not open input video {input_video_path}")
        return None  # Gradio will show an error if None is returned
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30  # fall back to 30 if FPS metadata is missing
    # Create a temporary output file for Gradio
    temp_output_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
    output_video_path = temp_output_file.name
    temp_output_file.close()  # close the handle so cv2.VideoWriter can open the path
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height))
    if not out.isOpened():
        print(f"Error: Could not open video writer for {output_video_path}.")
        cap.release()
        return None
    print(f"Processing video: {input_video_path}")
    frame_num = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frame_num += 1
        # print(f"Processing Frame: {frame_num}")  # keep commented for cleaner Gradio logs
        # --- 1. Detection ---
        all_raw_detections = []
        player_detections_for_tracking = []
        # Request FP16 inference only when a GPU is available
        results = _yolo_model(frame, verbose=False, half=torch.cuda.is_available(), imgsz=640)
        for r in results:
            # Skip results with no detections
            if r.boxes is not None and len(r.boxes):
                for *xyxy, conf, cls in r.boxes.data:
                    bbox = [float(c) for c in xyxy]
                    confidence = float(conf)
                    class_id = int(cls)
                    class_name = _yolo_model.names.get(class_id, "unknown")
                    all_raw_detections.append({
                        'bbox': bbox,
                        'confidence': confidence,
                        'class_id': class_id,
                        'class_name': class_name
                    })
                    if class_id == _player_class_id and confidence > CONFIDENCE_THRESHOLD:
                        player_detections_for_tracking.append(bbox)
        # --- 2. Tracking & Re-identification ---
        current_frame_assigned_ids = []
        matched_detections_indices = set()
        # 2a. Short-term Tracking (IoU-based)
        for i, det_bbox in enumerate(player_detections_for_tracking):
            best_match_player_id = -1
            max_iou = 0.0
            for player_id, player in list(_active_players.items()):
                if player_id in current_frame_assigned_ids:
                    continue  # each active player absorbs at most one detection per frame
                iou = calculate_iou(player.bbox, det_bbox)
                if iou > max_iou:
                    max_iou = iou
                    best_match_player_id = player_id
            if max_iou >= IOU_TRACKING_THRESHOLD:
                _active_players[best_match_player_id].update_bbox(det_bbox, frame_num)
                current_frame_assigned_ids.append(best_match_player_id)
                matched_detections_indices.add(i)
        # 2b. Re-identification (Feature-based)
        unmatched_detections = [det_bbox for i, det_bbox in enumerate(player_detections_for_tracking)
                                if i not in matched_detections_indices]
        for det_bbox in unmatched_detections:
            player_features = extract_features(frame, det_bbox)
            if player_features is None:
                continue
            best_reid_match_id = -1
            max_similarity = 0.0
            for player_id, player in list(_inactive_players.items()):
                similarity = compare_features(player_features, player.features)
                if similarity > max_similarity:
                    max_similarity = similarity
                    best_reid_match_id = player_id
            if max_similarity >= FEATURE_SIMILARITY_THRESHOLD:
                reidentified_player = _inactive_players.pop(best_reid_match_id)
                reidentified_player.update_bbox(det_bbox, frame_num)
                reidentified_player.features = player_features
                _active_players[best_reid_match_id] = reidentified_player
                current_frame_assigned_ids.append(best_reid_match_id)
            else:
                new_player = Player(_next_player_id, det_bbox, frame_num, player_features)
                _active_players[_next_player_id] = new_player
                current_frame_assigned_ids.append(_next_player_id)
                _next_player_id += 1
        # 2c. Update Player Status
        players_to_deactivate = []
        for player_id, player in list(_active_players.items()):
            if player_id not in current_frame_assigned_ids:
                player.lost_frames_count += 1
                if player.lost_frames_count > MAX_LOST_FRAMES:
                    players_to_deactivate.append(player_id)
        for player_id in players_to_deactivate:
            player = _active_players.pop(player_id)
            _inactive_players[player_id] = player
        # --- 3. Visualization ---
        display_frame = frame.copy()
        # Debug visualization layer (raw YOLO detections)
        for det in all_raw_detections:
            x1, y1, x2, y2 = map(int, det['bbox'])
            class_name = det['class_name']
            confidence = det['confidence']
            color = (0, 0, 255)  # red for other / low-confidence classes
            if class_name.lower() == 'player':
                if confidence > CONFIDENCE_THRESHOLD:
                    color = (0, 255, 0)  # green for players detected above threshold
                else:
                    color = (0, 128, 0)  # darker green for players below threshold
            elif class_name.lower() == 'ball':
                color = (255, 255, 0)  # cyan
            elif class_name.lower() == 'goalkeeper':
                color = (0, 165, 255)  # orange
            elif class_name.lower() == 'referee':
                color = (255, 0, 255)  # magenta
            cv2.rectangle(display_frame, (x1, y1), (x2, y2), color, 1)
            cv2.putText(display_frame, f"{class_name}: {confidence:.2f}", (x1, y1 - 25),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.4, color, 1)
        # Final visualization layer (TRACKED players with consistent IDs)
        for player_id, player in _active_players.items():
            x1, y1, x2, y2 = map(int, player.bbox)
            cv2.rectangle(display_frame, (x1, y1), (x2, y2), (0, 255, 255), 3)  # thicker, yellow box
            cv2.putText(display_frame, f"ID: {player.player_id}", (x1, y1 - 5),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 255), 2)
        out.write(display_frame)
    cap.release()
    out.release()
    print(f"Processing finished. Output video saved to: {output_video_path}")
    return output_video_path
# Call model loading function once when the module is imported
_load_yolo_model()
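# Minimal local smoke test, a sketch assuming a clip at "sample_match.mp4"
# ("sample_match.mp4" is a hypothetical example path, not part of this Space).
# The __main__ guard keeps it from running when Gradio imports this module.
if __name__ == "__main__":
    result_path = process_reid_video("sample_match.mp4")  # hypothetical input clip
    if result_path:
        print(f"Annotated video written to: {result_path}")
    else:
        print("Processing failed; see log output above.")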