import cv2
import numpy as np
from ultralytics import YOLO
import os
import gdown  # Downloads the model weights from Google Drive
import tempfile  # Temporary file handling for the output video
import torch  # Needed to allowlist checkpoint classes for safe loading
import ultralytics.nn.tasks

# --- Configuration ---
# Model and video paths are handled dynamically by the Hugging Face Space
# environment. The model weights are downloaded on first use if not present.
MODEL_DOWNLOAD_ID = "1-5fOSHOSB9UXyP_enOoZNAMScrePVcMD"  # Google Drive file ID
MODEL_NAME = "yolov11_model.pt"
MODEL_PATH = MODEL_NAME  # Downloaded into the current working directory

# Thresholds for object detection and tracking
CONFIDENCE_THRESHOLD = 0.5
IOU_TRACKING_THRESHOLD = 0.3
FEATURE_SIMILARITY_THRESHOLD = 0.5
MAX_LOST_FRAMES = 15

# --- Global Variables for Tracking State ---
# These are reset for each new video processed by Gradio's function call.
_next_player_id = 0
_active_players = {}
_inactive_players = {}


class Player:
    """
    Represents a single player being tracked, holding their current state
    and appearance features.
    """

    def __init__(self, player_id, bbox, frame_num, features=None):
        self.player_id = player_id
        self.bbox = bbox
        self.last_seen_frame = frame_num
        self.features = features
        self.lost_frames_count = 0

    def update_bbox(self, new_bbox, frame_num):
        self.bbox = new_bbox
        self.last_seen_frame = frame_num
        self.lost_frames_count = 0

    def __repr__(self):
        return (f"Player(ID:{self.player_id}, Bbox:[{int(self.bbox[0])},{int(self.bbox[1])},"
                f"{int(self.bbox[2])},{int(self.bbox[3])}], LastSeen:{self.last_seen_frame})")


def calculate_iou(boxA, boxB):
    """Intersection-over-Union of two [x1, y1, x2, y2] boxes (inclusive pixel coordinates)."""
    xA = max(boxA[0], boxB[0])
    yA = max(boxA[1], boxB[1])
    xB = min(boxA[2], boxB[2])
    yB = min(boxA[3], boxB[3])
    interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)
    boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
    boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)
    union = float(boxAArea + boxBArea - interArea)
    return interArea / union if union > 0 else 0.0


def extract_features(image, bbox):
    """Extract a normalized HSV color histogram from the player crop, or None if the crop is invalid."""
    x1, y1, x2, y2 = map(int, bbox)
    h, w, _ = image.shape
    x1 = max(0, x1)
    y1 = max(0, y1)
    x2 = min(w, x2)
    y2 = min(h, y2)
    if x2 <= x1 or y2 <= y1:
        return None
    cropped_player = image[y1:y2, x1:x2]
    if cropped_player.size == 0:
        return None
    try:
        hsv_cropped = cv2.cvtColor(cropped_player, cv2.COLOR_BGR2HSV)
        hist = cv2.calcHist([hsv_cropped], [0, 1, 2], None, [8, 8, 8],
                            [0, 180, 0, 256, 0, 256])
        hist = cv2.normalize(hist, hist).flatten()
        return hist
    except cv2.error:
        return None


def compare_features(features1, features2):
    """Histogram correlation in [-1, 1]; returns 0.0 if either feature vector is missing."""
    if features1 is None or features2 is None or len(features1) != len(features2):
        return 0.0
    return cv2.compareHist(features1, features2, cv2.HISTCMP_CORREL)
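# Quick illustrative sanity checks for the helpers above (an addition, not part
# of the original pipeline; safe to delete). Note the +1 pixel convention in
# calculate_iou: a box [x1, y1, x2, y2] is inclusive on both ends.
assert abs(calculate_iou([0, 0, 9, 9], [0, 0, 9, 9]) - 1.0) < 1e-6  # identical boxes
assert calculate_iou([0, 0, 9, 9], [20, 20, 29, 29]) == 0.0  # disjoint boxes
assert compare_features(None, None) == 0.0  # missing features never match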
# Load the model globally to avoid re-loading on every function call
# (efficiency for Gradio). The weights are downloaded if not found.
_yolo_model = None
_player_class_id = -1
_model_loaded = False


def _load_yolo_model():
    """Helper function to load the YOLO model once."""
    global _yolo_model, _player_class_id, _model_loaded
    if _model_loaded:
        return True

    print("Attempting to load YOLO model...")
    if not os.path.exists(MODEL_PATH):
        print(f"Model {MODEL_PATH} not found. Attempting to download from Google Drive...")
        try:
            gdown.download(id=MODEL_DOWNLOAD_ID, output=MODEL_PATH, quiet=False)
            print(f"Model downloaded to {MODEL_PATH}")
        except Exception as e:
            print(f"Error downloading model: {e}")
            return False

    try:
        # Allowlist the custom Ultralytics class and common PyTorch modules so
        # the checkpoint can be unpickled under torch.load's weights_only default.
        torch.serialization.add_safe_globals([
            ultralytics.nn.tasks.DetectionModel,
            torch.nn.modules.container.Sequential,
            # Add other Ultralytics model types here if they raise similar
            # errors, e.g. ultralytics.nn.tasks.SegmentationModel.
        ])

        _yolo_model = YOLO(MODEL_PATH)
        if torch.cuda.is_available():
            _yolo_model.to('cuda')
            # Half-precision (FP16) is only applied on GPU, after moving to the
            # device; most CPU kernels do not support FP16 inference.
            _yolo_model.half()
            print("YOLO model moved to CUDA and set to half-precision (FP16).")
        else:
            print("CUDA (GPU) is not available; YOLO model will run on CPU in FP32.")

        print("Model Class Names:", _yolo_model.names)
        found_player_class = False
        for class_id, class_name in _yolo_model.names.items():
            if class_name.lower() == 'player':
                _player_class_id = class_id
                found_player_class = True
                break
        if not found_player_class:
            print("Error: 'player' class not found in model's names. Check model training.")
            return False
        print(f"Detected 'player' class ID: {_player_class_id}")

        _model_loaded = True
        return True
    except Exception as e:
        print(f"Error loading YOLO model: {e}")
        return False
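# --- Tracking pipeline overview ---
# Each frame passes through three stages:
#   2a. Short-term tracking: match new detections to active players by IoU.
#   2b. Re-identification: match leftover detections against the inactive pool
#       by comparing HSV color histograms, or assign a brand-new player ID.
#   2c. Bookkeeping: active players unseen for more than MAX_LOST_FRAMES frames
#       move to the inactive pool, where they stay eligible for re-identification.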
# Main processing function to be called by Gradio
def process_reid_video(input_video_path):
    """
    Processes an input video for player re-identification and returns the
    path to the annotated output video.
    """
    global _next_player_id, _active_players, _inactive_players

    # Reset global tracking state for each new video submission.
    _next_player_id = 0
    _active_players = {}
    _inactive_players = {}

    # Ensure the model is loaded before touching the video.
    if not _model_loaded and not _load_yolo_model():
        print("Failed to load YOLO model. Cannot process video.")
        # Produce a short error video so Gradio has something to display.
        dummy_output_path = os.path.join(tempfile.gettempdir(), "error_output.mp4")
        dummy_writer = cv2.VideoWriter(dummy_output_path,
                                       cv2.VideoWriter_fourcc(*'mp4v'), 10, (640, 480))
        if not dummy_writer.isOpened():
            return None
        blank_frame = np.zeros((480, 640, 3), dtype=np.uint8)
        cv2.putText(blank_frame, "ERROR: Model Failed to Load!", (50, 240),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
        for _ in range(30):  # Write a few frames (3 seconds at 10 fps)
            dummy_writer.write(blank_frame)
        dummy_writer.release()
        return dummy_output_path

    cap = cv2.VideoCapture(input_video_path)
    if not cap.isOpened():
        print(f"Error: Could not open input video {input_video_path}")
        return None  # Gradio will show an error if None is returned

    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30  # Fall back to 30 if FPS is unreadable

    # Create a temporary output file for Gradio.
    temp_output_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
    OUTPUT_VIDEO_PATH = temp_output_file.name
    temp_output_file.close()  # cv2.VideoWriter needs to open the path itself

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(OUTPUT_VIDEO_PATH, fourcc, fps, (frame_width, frame_height))
    if not out.isOpened():
        print(f"Error: Could not open video writer for {OUTPUT_VIDEO_PATH}.")
        cap.release()
        return None

    print(f"Processing video: {input_video_path}")
    frame_num = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frame_num += 1
        # print(f"Processing Frame: {frame_num}")  # Keep commented for cleaner Gradio logs

        # --- 1. Detection ---
        all_raw_detections = []
        player_detections_for_tracking = []

        # Request FP16 inference only when a GPU is available; the model is
        # only converted to half-precision on CUDA.
        results = _yolo_model(frame, verbose=False,
                              half=torch.cuda.is_available(), imgsz=640)

        for r in results:
            # Ensure results.boxes.data is not empty.
            if r.boxes is None or r.boxes.data is None:
                continue
            for *xyxy, conf, cls in r.boxes.data:
                bbox = xyxy
                confidence = float(conf)
                class_id = int(cls)
                class_name = _yolo_model.names.get(class_id, "unknown")

                all_raw_detections.append({
                    'bbox': bbox,
                    'confidence': confidence,
                    'class_id': class_id,
                    'class_name': class_name
                })

                if class_id == _player_class_id and confidence > CONFIDENCE_THRESHOLD:
                    player_detections_for_tracking.append(bbox)

        current_frame_assigned_ids = []
        matched_detections_indices = set()

        # 2a. Short-term Tracking (IoU-based)
        for i, det_bbox in enumerate(player_detections_for_tracking):
            best_match_player_id = -1
            max_iou = 0.0
            for player_id, player in list(_active_players.items()):
                iou = calculate_iou(player.bbox, det_bbox)
                if iou > max_iou:
                    max_iou = iou
                    best_match_player_id = player_id

            if max_iou >= IOU_TRACKING_THRESHOLD:
                _active_players[best_match_player_id].update_bbox(det_bbox, frame_num)
                current_frame_assigned_ids.append(best_match_player_id)
                matched_detections_indices.add(i)

        # 2b. Re-identification (Feature-based)
        unmatched_detections = [det_bbox for i, det_bbox
                                in enumerate(player_detections_for_tracking)
                                if i not in matched_detections_indices]

        for det_bbox in unmatched_detections:
            player_features = extract_features(frame, det_bbox)
            if player_features is None:
                continue

            best_reid_match_id = -1
            max_similarity = 0.0
            for player_id, player in list(_inactive_players.items()):
                similarity = compare_features(player_features, player.features)
                if similarity > max_similarity:
                    max_similarity = similarity
                    best_reid_match_id = player_id

            if max_similarity >= FEATURE_SIMILARITY_THRESHOLD:
                # Revive a previously lost player under their old ID.
                reidentified_player = _inactive_players.pop(best_reid_match_id)
                reidentified_player.update_bbox(det_bbox, frame_num)
                reidentified_player.features = player_features
                _active_players[best_reid_match_id] = reidentified_player
                current_frame_assigned_ids.append(best_reid_match_id)
            else:
                # No convincing match: assign a brand-new ID.
                new_player = Player(_next_player_id, det_bbox, frame_num, player_features)
                _active_players[_next_player_id] = new_player
                current_frame_assigned_ids.append(_next_player_id)
                _next_player_id += 1

        # 2c. Update Player Status
        players_to_deactivate = []
        for player_id, player in list(_active_players.items()):
            if player_id not in current_frame_assigned_ids:
                player.lost_frames_count += 1
                if player.lost_frames_count > MAX_LOST_FRAMES:
                    players_to_deactivate.append(player_id)

        for player_id in players_to_deactivate:
            _inactive_players[player_id] = _active_players.pop(player_id)

        # --- 3. Visualization ---
        display_frame = frame.copy()

        # Debug Visualization Layer (raw YOLO detections; colors are BGR).
        for det in all_raw_detections:
            x1, y1, x2, y2 = map(int, det['bbox'])
            class_name = det['class_name']
            confidence = det['confidence']

            color = (0, 0, 255)  # Red for other classes
            if class_name.lower() == 'player':
                if confidence > CONFIDENCE_THRESHOLD:
                    color = (0, 255, 0)  # Green for players above threshold
                else:
                    color = (0, 128, 0)  # Darker green for players below threshold
            elif class_name.lower() == 'ball':
                color = (255, 255, 0)  # Cyan
            elif class_name.lower() == 'goalkeeper':
                color = (0, 165, 255)  # Orange
            elif class_name.lower() == 'referee':
                color = (255, 0, 255)  # Magenta

            cv2.rectangle(display_frame, (x1, y1), (x2, y2), color, 1)
            cv2.putText(display_frame, f"{class_name}: {confidence:.2f}", (x1, y1 - 25),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.4, color, 1)

        # Final Visualization Layer (tracked players with consistent IDs).
        for player_id, player in _active_players.items():
            x1, y1, x2, y2 = map(int, player.bbox)
            cv2.rectangle(display_frame, (x1, y1), (x2, y2), (0, 255, 255), 3)  # Thick yellow box
            cv2.putText(display_frame, f"ID: {player.player_id}", (x1, y1 - 5),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 255), 2)

        out.write(display_frame)

    cap.release()
    out.release()
    print(f"Processing finished. Output video saved to: {OUTPUT_VIDEO_PATH}")
    return OUTPUT_VIDEO_PATH
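# --- Optional refinement (an illustrative sketch, not used by the pipeline
# above). The matching in step 2a is greedy: each detection independently picks
# the active player with the highest IoU, so two detections can claim the same
# player in one frame. A standard alternative is globally optimal one-to-one
# assignment via the Hungarian algorithm; this helper assumes SciPy is
# installed and is a starting point, not a drop-in replacement.
def match_detections_optimal(player_bboxes, det_bboxes,
                             iou_threshold=IOU_TRACKING_THRESHOLD):
    """Return (player_index, detection_index) pairs whose IoU clears the threshold."""
    from scipy.optimize import linear_sum_assignment
    if not player_bboxes or not det_bboxes:
        return []
    # Cost matrix of (1 - IoU): lower cost means a better match.
    cost = np.ones((len(player_bboxes), len(det_bboxes)), dtype=np.float32)
    for i, p_box in enumerate(player_bboxes):
        for j, d_box in enumerate(det_bboxes):
            cost[i, j] = 1.0 - calculate_iou(p_box, d_box)
    rows, cols = linear_sum_assignment(cost)
    # Keep only assignments that recover an IoU at or above the threshold.
    return [(i, j) for i, j in zip(rows, cols) if 1.0 - cost[i, j] >= iou_threshold]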
# Load the model once when the module is imported so the first Gradio call
# does not pay the download/initialization cost.
_load_yolo_model()
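# --- Hypothetical local entry point (a minimal sketch: the original file only
# defines the processing function, so this Gradio wiring is an assumption, not
# the Space's actual app definition) ---
if __name__ == "__main__":
    import gradio as gr  # Assumed available, as on a Hugging Face Space

    demo = gr.Interface(
        fn=process_reid_video,
        inputs=gr.Video(label="Input match footage"),
        outputs=gr.Video(label="Tracked output"),
        title="Player Re-Identification Demo",
    )
    demo.launch()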