"""Player re-identification in sports video.

Detects players with a YOLO model, tracks them across frames with IoU
matching, and re-identifies players who reappear after leaving the frame
using HSV colour-histogram appearance features.
"""

import os
import tempfile

import cv2
import gdown
import numpy as np
import torch
import ultralytics.nn.tasks
from ultralytics import YOLO

# Model weights: downloaded from Google Drive on first run if missing locally.
MODEL_DOWNLOAD_ID = "1-5fOSHOSB9UXyP_enOoZNAMScrePVcMD"
MODEL_NAME = "yolov11_model.pt"
MODEL_PATH = MODEL_NAME

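# Detection and tracking thresholds: CONFIDENCE_THRESHOLD filters raw player
# detections, IOU_TRACKING_THRESHOLD is the minimum overlap for frame-to-frame
# matching, FEATURE_SIMILARITY_THRESHOLD is the minimum histogram correlation
# for re-identifying an inactive player, and MAX_LOST_FRAMES is how long a
# player may go undetected before being moved to the inactive pool.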
CONFIDENCE_THRESHOLD = 0.5
IOU_TRACKING_THRESHOLD = 0.3
FEATURE_SIMILARITY_THRESHOLD = 0.5
MAX_LOST_FRAMES = 15

# Global tracking state, reset at the start of each processed video.
_next_player_id = 0
_active_players = {}
_inactive_players = {}


class Player:
    """
    Represents a single player being tracked, holding their current state and historical features.
    """

    def __init__(self, player_id, bbox, frame_num, features=None):
        self.player_id = player_id
        self.bbox = bbox
        self.last_seen_frame = frame_num
        self.features = features
        self.lost_frames_count = 0

    def update_bbox(self, new_bbox, frame_num):
        self.bbox = new_bbox
        self.last_seen_frame = frame_num
        self.lost_frames_count = 0

    def __repr__(self):
        return (f"Player(ID:{self.player_id}, Bbox:[{int(self.bbox[0])},{int(self.bbox[1])},"
                f"{int(self.bbox[2])},{int(self.bbox[3])}], LastSeen:{self.last_seen_frame})")


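# Boxes are [x1, y1, x2, y2]. calculate_iou uses the inclusive (+1) pixel
# convention, so e.g. calculate_iou([0, 0, 10, 10], [5, 5, 15, 15]) ≈ 0.17.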
def calculate_iou(boxA, boxB):
    xA = max(boxA[0], boxB[0])
    yA = max(boxA[1], boxB[1])
    xB = min(boxA[2], boxB[2])
    yB = min(boxA[3], boxB[3])

    interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)

    boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
    boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)

    unionArea = boxAArea + boxBArea - interArea
    return interArea / float(unionArea) if unionArea > 0 else 0.0


def extract_features(image, bbox):
    """Return a normalised HSV colour histogram for the player crop, or None if the crop is empty."""
    x1, y1, x2, y2 = map(int, bbox)

    # Clamp the box to the image bounds.
    h, w, _ = image.shape
    x1 = max(0, x1)
    y1 = max(0, y1)
    x2 = min(w, x2)
    y2 = min(h, y2)

    if x2 <= x1 or y2 <= y1:
        return None

    cropped_player = image[y1:y2, x1:x2]
    if cropped_player.size == 0:
        return None

    try:
        # 8x8x8-bin HSV histogram, normalised and flattened to a 512-d vector.
        hsv_cropped = cv2.cvtColor(cropped_player, cv2.COLOR_BGR2HSV)
        hist = cv2.calcHist([hsv_cropped], [0, 1, 2], None, [8, 8, 8],
                            [0, 180, 0, 256, 0, 256])
        hist = cv2.normalize(hist, hist).flatten()
        return hist
    except cv2.error:
        return None


def compare_features(features1, features2):
    """Histogram correlation in [-1, 1]; higher means more similar appearance."""
    if features1 is None or features2 is None or len(features1) != len(features2):
        return 0.0
    return cv2.compareHist(features1, features2, cv2.HISTCMP_CORREL)


# Lazily initialised model state.
_yolo_model = None
_player_class_id = -1
_model_loaded = False


def _load_yolo_model():
    """Helper function to load the YOLO model once."""
    global _yolo_model, _player_class_id, _model_loaded

    if _model_loaded:
        return True

    print("Attempting to load YOLO model...")
    if not os.path.exists(MODEL_PATH):
        print(f"Model {MODEL_PATH} not found. Attempting to download from Google Drive...")
        try:
            gdown.download(id=MODEL_DOWNLOAD_ID, output=MODEL_PATH, quiet=False)
            print(f"Model downloaded to {MODEL_PATH}")
        except Exception as e:
            print(f"Error downloading model: {e}")
            return False

    try:
        # Allow these classes to be unpickled under torch's weights-only loading.
        torch.serialization.add_safe_globals([
            ultralytics.nn.tasks.DetectionModel,
            torch.nn.modules.container.Sequential,
        ])

        _yolo_model = YOLO(MODEL_PATH)
        if torch.cuda.is_available():
            _yolo_model.to('cuda')
            # Half precision is only worthwhile (and reliably supported) on GPU.
            _yolo_model.half()
            print("YOLO model moved to CUDA (GPU) device and set to half-precision (FP16).")
        else:
            print("CUDA (GPU) is not available, YOLO model will run on CPU.")

        print("Model Class Names:", _yolo_model.names)
        found_player_class = False
        for class_id, class_name in _yolo_model.names.items():
            if class_name.lower() == 'player':
                _player_class_id = class_id
                found_player_class = True
                break

        if not found_player_class:
            print("Error: 'player' class not found in model's names. Check model training.")
            return False

        print(f"Detected 'player' class ID: {_player_class_id}")
        _model_loaded = True
        return True

    except Exception as e:
        print(f"Error loading YOLO model: {e}")
        return False


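# Per-frame pipeline: run YOLO detection, match player detections to active
# tracks by IoU, try to re-identify unmatched detections against inactive
# tracks via histogram similarity (otherwise assign a new ID), retire tracks
# unseen for more than MAX_LOST_FRAMES frames, then write the annotated frame.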
def process_reid_video(input_video_path):
    """
    Processes an input video for player re-identification and returns the path to the output video.
    """
    global _next_player_id, _active_players, _inactive_players

    # Reset tracking state for each new video.
    _next_player_id = 0
    _active_players = {}
    _inactive_players = {}

    if not _model_loaded and not _load_yolo_model():
        print("Failed to load YOLO model. Cannot process video.")

        # Return a short placeholder video so callers still receive a file path.
        dummy_output_path = os.path.join(tempfile.gettempdir(), "error_output.mp4")
        dummy_writer = cv2.VideoWriter(dummy_output_path, cv2.VideoWriter_fourcc(*'mp4v'), 10, (640, 480))
        if dummy_writer.isOpened():
            blank_frame = np.zeros((480, 640, 3), dtype=np.uint8)
            cv2.putText(blank_frame, "ERROR: Model Failed to Load!", (50, 240),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
            for _ in range(30):
                dummy_writer.write(blank_frame)
            dummy_writer.release()
        return dummy_output_path

    cap = cv2.VideoCapture(input_video_path)
    if not cap.isOpened():
        print(f"Error: Could not open input video {input_video_path}")
        return None

    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30  # fall back to 30 if FPS is unreadable

    # Write the annotated output to a temporary file.
    temp_output_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
    OUTPUT_VIDEO_PATH = temp_output_file.name
    temp_output_file.close()

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(OUTPUT_VIDEO_PATH, fourcc, fps, (frame_width, frame_height))

    if not out.isOpened():
        print(f"Error: Could not open video writer for {OUTPUT_VIDEO_PATH}.")
        cap.release()
        return None

print(f"Processing video: {input_video_path}") |
|
frame_num = 0 |
|
while True: |
|
ret, frame = cap.read() |
|
if not ret: |
|
break |
|
|
|
frame_num += 1 |
|
|
|
|
|
all_raw_detections = [] |
|
player_detections_for_tracking = [] |
|
|
|
results = _yolo_model(frame, verbose=False,half=True,imgsz=640) |
|
for r in results: |
|
|
|
if r.boxes and r.boxes.data is not None: |
|
for *xyxy, conf, cls in r.boxes.data: |
|
bbox = xyxy |
|
confidence = float(conf) |
|
class_id = int(cls) |
|
class_name = _yolo_model.names.get(class_id, "unknown") |
|
|
|
all_raw_detections.append({ |
|
'bbox': bbox, |
|
'confidence': confidence, |
|
'class_id': class_id, |
|
'class_name': class_name |
|
}) |
|
|
|
if class_id == _player_class_id and confidence > CONFIDENCE_THRESHOLD: |
|
player_detections_for_tracking.append(bbox) |
|
|
|
        current_frame_assigned_ids = []
        matched_detections_indices = set()

        # Stage 1: match detections to already-active players by IoU.
        for i, det_bbox in enumerate(player_detections_for_tracking):
            best_match_player_id = -1
            max_iou = 0.0

            for player_id, player in list(_active_players.items()):
                iou = calculate_iou(player.bbox, det_bbox)
                if iou > max_iou:
                    max_iou = iou
                    best_match_player_id = player_id

            if max_iou >= IOU_TRACKING_THRESHOLD:
                _active_players[best_match_player_id].update_bbox(det_bbox, frame_num)
                current_frame_assigned_ids.append(best_match_player_id)
                matched_detections_indices.add(i)

        # Stage 2: try to re-identify unmatched detections against inactive
        # players using appearance features; otherwise start a new track.
        unmatched_detections = [det_bbox for i, det_bbox in enumerate(player_detections_for_tracking)
                                if i not in matched_detections_indices]

        for det_bbox in unmatched_detections:
            player_features = extract_features(frame, det_bbox)
            if player_features is None:
                continue

            best_reid_match_id = -1
            max_similarity = 0.0

            for player_id, player in list(_inactive_players.items()):
                similarity = compare_features(player_features, player.features)
                if similarity > max_similarity:
                    max_similarity = similarity
                    best_reid_match_id = player_id

            if max_similarity >= FEATURE_SIMILARITY_THRESHOLD:
                reidentified_player = _inactive_players.pop(best_reid_match_id)
                reidentified_player.update_bbox(det_bbox, frame_num)
                reidentified_player.features = player_features
                _active_players[best_reid_match_id] = reidentified_player
                current_frame_assigned_ids.append(best_reid_match_id)
            else:
                new_player = Player(_next_player_id, det_bbox, frame_num, player_features)
                _active_players[_next_player_id] = new_player
                current_frame_assigned_ids.append(_next_player_id)
                _next_player_id += 1

        # Stage 3: retire players that have gone unmatched for too long.
        players_to_deactivate = []
        for player_id, player in list(_active_players.items()):
            if player_id not in current_frame_assigned_ids:
                player.lost_frames_count += 1
                if player.lost_frames_count > MAX_LOST_FRAMES:
                    players_to_deactivate.append(player_id)

        for player_id in players_to_deactivate:
            player = _active_players.pop(player_id)
            _inactive_players[player_id] = player

        # Draw every raw detection with a thin, class-coloured box.
        display_frame = frame.copy()

        for det in all_raw_detections:
            x1, y1, x2, y2 = map(int, det['bbox'])
            class_name = det['class_name']
            confidence = det['confidence']

            color = (0, 0, 255)
            if class_name.lower() == 'player':
                color = (0, 255, 0) if confidence > CONFIDENCE_THRESHOLD else (0, 128, 0)
            elif class_name.lower() == 'ball':
                color = (255, 255, 0)
            elif class_name.lower() == 'goalkeeper':
                color = (0, 165, 255)
            elif class_name.lower() == 'referee':
                color = (255, 0, 255)

            cv2.rectangle(display_frame, (x1, y1), (x2, y2), color, 1)
            cv2.putText(display_frame, f"{class_name}: {confidence:.2f}", (x1, y1 - 25),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.4, color, 1)

        # Overlay tracked players with thick boxes and persistent IDs.
        for player_id, player in _active_players.items():
            x1, y1, x2, y2 = map(int, player.bbox)
            cv2.rectangle(display_frame, (x1, y1), (x2, y2), (0, 255, 255), 3)
            cv2.putText(display_frame, f"ID: {player.player_id}", (x1, y1 - 5),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 255), 2)

        out.write(display_frame)

    cap.release()
    out.release()

    print(f"Processing finished. Output video saved to: {OUTPUT_VIDEO_PATH}")
    return OUTPUT_VIDEO_PATH


# Load the model eagerly at import time so the first call to
# process_reid_video does not pay the download/startup cost.
_load_yolo_model()
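

# Minimal usage sketch (illustrative): "input.mp4" below is a placeholder
# path, not a file shipped with this project.
if __name__ == "__main__":
    result_path = process_reid_video("input.mp4")
    print("Annotated video written to:", result_path)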