Spaces:
Sleeping
Sleeping
File size: 7,912 Bytes
e0fcf03 2bde947 e0fcf03 2bde947 e0fcf03 f558e96 e0fcf03 2bde947 e0fcf03 2bde947 e0fcf03 1ef50c2 2bde947 e0fcf03 2bde947 e0fcf03 2bde947 e0fcf03 2bde947 e0fcf03 2bde947 e0fcf03 2bde947 e0fcf03 2bde947 e0fcf03 2bde947 e0fcf03 2bde947 b267b22 e0fcf03 2bde947 e0fcf03 2bde947 e0fcf03 2bde947 e0fcf03 2bde947 e0fcf03 2bde947 e0fcf03 2bde947 e0fcf03 2bde947 e0fcf03 2bde947 e0fcf03 2bde947 e0fcf03 2bde947 e0fcf03 1ef50c2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 |
import cv2
import numpy as np
import torch
from ultralytics import YOLO
import gradio as gr
from scipy.interpolate import interp1d
import plotly.graph_objects as go
import uuid
import os
import tempfile
# Load the YOLOv8 detector weights and place the model on the GPU when
# torch reports one, otherwise on the CPU.
model = YOLO("best.pt")
model.to('cuda' if torch.cuda.is_available() else 'cpu')

# Find the class id whose label is "cricketball" (case-insensitive). The
# first matching entry of model.names wins; a model without such a label
# cannot be used by this app, so fail fast at import time.
ball_class_index = next(
    (
        class_id
        for class_id, label in model.names.items()
        if label.lower() == "cricketball"
    ),
    None,
)
if ball_class_index is None:
    raise ValueError("Class 'cricketBall' not found in model.names")
# ---------------------------------------------------------------------------
# Physical reference dimensions (presumably metres - TODO confirm).
# Only STUMPS_WIDTH is referenced by the code in this file; the others are
# currently unused here.
# ---------------------------------------------------------------------------
STUMPS_WIDTH = 0.2286
STUMPS_HEIGHT = 0.71
BALL_DIAMETER = 0.073
PITCH_LENGTH = 20.12

# ---------------------------------------------------------------------------
# Video timing. FRAME_RATE is a fixed assumption (not read from the file) and
# is used to turn frame indices into seconds and to set the replay fps.
# Each replay frame is written SLOW_MOTION_FACTOR times.
# ---------------------------------------------------------------------------
FRAME_RATE = 20
SLOW_MOTION_FACTOR = 2

# ---------------------------------------------------------------------------
# Detection / tracking thresholds.
# CONF_THRESHOLD: minimum detector confidence passed to model.predict.
# MAX_POSITION_JUMP: largest allowed distance, in pixels, between consecutive
# accepted ball centres.
# ---------------------------------------------------------------------------
CONF_THRESHOLD = 0.2
MAX_POSITION_JUMP = 30

# ---------------------------------------------------------------------------
# Impact-zone heuristics (not referenced by the code in this file).
# ---------------------------------------------------------------------------
IMPACT_ZONE_Y = 0.85
IMPACT_DELTA_Y = 50
def process_video(video_path):
    """Decode a video, run the ball detector on every frame and collect ball centres.

    Each frame is passed to the module-level YOLO ``model`` (``max_det=1``).
    When a box of class ``ball_class_index`` is found, its centre is recorded
    and a green rectangle is drawn on that frame.

    Args:
        video_path: Path to the video file. ``None`` or an empty string is
            treated like a missing file instead of raising ``TypeError``.
            A Gradio ``Video`` input presumably passes ``None`` when nothing
            was uploaded; confirm this.

    Returns:
        ``(frames, ball_positions, detection_frames, debug_log)``:

        - ``frames``: every decoded frame, annotated where the ball was found.
        - ``ball_positions``: ``[x, y]`` pixel centres.
        - ``detection_frames``: the matching 0-based frame indices.
        - ``debug_log``: a newline-joined per-frame report.

        If the file is missing or cannot be opened, the three lists are empty
        and ``debug_log`` is an error message.
    """
    if not video_path or not os.path.exists(video_path):
        return [], [], [], "Error: Video file not found"

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        return [], [], [], "Error: Could not open video file"

    frames, ball_positions, detection_frames, debug_log = [], [], [], []
    try:
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        frame_index = 0  # 0-based index of the frame being processed
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            results = model.predict(
                frame,
                conf=CONF_THRESHOLD,
                imgsz=(frame_height, frame_width),
                iou=0.5,
                max_det=1,
            )
            detections = 0
            for detection in results[0].boxes:
                if int(detection.cls) != ball_class_index:
                    continue
                detections += 1
                # Only the first ball box in a frame contributes a position.
                if detections == 1:
                    x1, y1, x2, y2 = detection.xyxy[0].cpu().numpy()
                    ball_positions.append([(x1 + x2) / 2, (y1 + y2) / 2])
                    detection_frames.append(frame_index)
                    cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
            # Stored after annotation so the replay shows the detection box.
            frames.append(frame)
            debug_log.append(f"Frame {frame_index + 1}: {detections} ball detections")
            frame_index += 1
    finally:
        # Release the decoder even if inference raises part-way through.
        cap.release()

    if ball_positions:
        debug_log.append(f"Total ball detections: {len(ball_positions)}")
    else:
        debug_log.append("No balls detected in any frame")
    debug_log.append(f"Video resolution: {frame_width}x{frame_height}")
    return frames, ball_positions, detection_frames, "\n".join(debug_log)
def find_bounce_point(ball_coords):
    """Pick the point treated as the ball's pitch (bounce) position.

    Scans consecutive triples and returns the first interior point whose y
    value is strictly greater than both of its neighbours. In image
    coordinates y grows downward, so this is presumably the lowest on-screen
    point of the flight.

    If no such local peak exists, the point one third of the way through the
    sequence is returned as a heuristic fallback.

    Raises IndexError when ``ball_coords`` is empty.
    """
    for before, here, after in zip(ball_coords, ball_coords[1:], ball_coords[2:]):
        if before[1] < here[1] > after[1]:
            return here
    # Heuristic fallback: no strict local maximum of y was found.
    return ball_coords[len(ball_coords) // 3]
def lbw_decision(ball_positions, trajectory, frames, pitch_point, impact_point):
    """Apply a simplified, image-space LBW rule set to a tracked delivery.

    The stumps are modelled as a vertical corridor centred horizontally in the
    frame. It is ``frame_width * (STUMPS_WIDTH / 3.0)`` pixels wide, with its
    reference y at 90% of the frame height. Checks run in this order:

    1. Pitch point outside the corridor -> "Not Out (Pitched outside line)".
    2. Impact point outside the corridor -> "Not Out (Impact outside line)".
    3. Any trajectory point inside the corridor and within 10% of the frame
       height of the reference y -> "Out (Ball projected to hit stumps)".
    4. Otherwise -> "Not Out (Missing stumps)".

    Returns:
        ``(decision, trajectory, pitch_point, impact_point)``. The last three
        items are the arguments passed straight through.
    """
    def verdict(text):
        return text, trajectory, pitch_point, impact_point

    if not frames or not trajectory or len(ball_positions) < 2:
        return verdict("Not enough data")

    height, width = frames[0].shape[:2]
    centre_x = width / 2
    base_y = height * 0.9
    half_corridor = width * (STUMPS_WIDTH / 3.0) / 2
    left_edge = centre_x - half_corridor
    right_edge = centre_x + half_corridor

    pitch_x, _ = pitch_point
    impact_x, _ = impact_point

    # NOTE(review): pitching outside EITHER side of the corridor is rejected.
    # The LBW law is usually stated as excluding only balls pitched outside
    # leg stump. This code has no notion of off/leg side; confirm intent.
    if pitch_x < left_edge or pitch_x > right_edge:
        return verdict("Not Out (Pitched outside line)")
    if impact_x < left_edge or impact_x > right_edge:
        return verdict("Not Out (Impact outside line)")

    # NOTE(review): this only inspects the points in `trajectory`. As built
    # by estimate_trajectory in this file, they span the first to the last
    # accepted detection; nothing is projected past the impact point.
    vertical_tolerance = height * 0.1
    hits_stumps = any(
        abs(x - centre_x) < half_corridor and abs(y - base_y) < vertical_tolerance
        for x, y in trajectory
    )
    if hits_stumps:
        return verdict("Out (Ball projected to hit stumps)")
    return verdict("Not Out (Missing stumps)")
def estimate_trajectory(ball_positions, detection_frames, frame_height, frame_width):
    """Reject jumpy detections and fit a smooth path through the remaining ball centres.

    Args:
        ball_positions: ``[x, y]`` pixel centres, in detection order.
        detection_frames: 0-based frame index for each ball position.
        frame_height, frame_width: Frame size in pixels. Currently unused;
            kept for interface compatibility.

    Returns:
        ``(trajectory, pitch_point, impact_point, message)``:

        - ``trajectory``: ``(x, y)`` samples between the first and last
          accepted detection.
        - ``pitch_point``: chosen by ``find_bounce_point``.
        - ``impact_point``: the last accepted detection.

        On failure the first three items are ``None`` and ``message`` says why.
    """
    if len(ball_positions) < 2:
        return None, None, None, "Error: Not enough ball detections"

    # Greedy outlier rejection: accept a detection only if it lies within
    # MAX_POSITION_JUMP pixels of the last ACCEPTED detection.
    filtered_positions = [ball_positions[0]]
    filtered_frames = [detection_frames[0]]
    for position, frame_idx in zip(ball_positions[1:], detection_frames[1:]):
        if np.linalg.norm(np.array(position) - np.array(filtered_positions[-1])) <= MAX_POSITION_JUMP:
            filtered_positions.append(position)
            filtered_frames.append(frame_idx)
    if len(filtered_positions) < 2:
        return None, None, None, "Error: Filtered detections too few"

    x_vals = [p[0] for p in filtered_positions]
    y_vals = [p[1] for p in filtered_positions]
    times = np.array(filtered_frames) / FRAME_RATE

    # Per the SciPy interp1d docs, a spline of order k needs at least k + 1
    # samples. A fixed 'cubic' therefore always failed with 2 or 3 accepted
    # detections, so use the highest order the data supports.
    spline_kind = {2: "linear", 3: "quadratic"}.get(len(filtered_positions), "cubic")
    try:
        fx = interp1d(times, x_vals, kind=spline_kind, fill_value="extrapolate")
        fy = interp1d(times, y_vals, kind=spline_kind, fill_value="extrapolate")
    except Exception as e:
        return None, None, None, f"Interpolation error: {e}"

    # Sample densely enough for the slow-motion replay (at least 5 points).
    total_frames = max(filtered_frames) - min(filtered_frames) + 1
    t_full = np.linspace(times[0], times[-1], max(5, total_frames * SLOW_MOTION_FACTOR))
    trajectory = list(zip(fx(t_full), fy(t_full)))

    pitch_point = find_bounce_point(filtered_positions)
    impact_point = filtered_positions[-1]
    return trajectory, pitch_point, impact_point, "Trajectory estimated successfully"
def generate_replay(frames, trajectory, pitch_point, impact_point, detection_frames):
    """Write a slow-motion MP4 replay with the fitted trajectory overlaid.

    Every input frame is written ``SLOW_MOTION_FACTOR`` times at
    ``FRAME_RATE / SLOW_MOTION_FACTOR`` fps. Overlays are drawn in OpenCV BGR
    colour order:

    - On frames between the first and last detection, the part of
      ``trajectory`` reached so far is drawn as a blue polyline.
    - The pitch point is marked in red on frame ``detection_frames[0]``.
    - The impact point is marked in yellow on frame ``detection_frames[-1]``.

    The caller's ``frames`` are not modified; drawing happens on copies.

    Returns:
        Path of the temporary MP4. Returns ``None`` when there is nothing to
        render (no frames, no trajectory or no detections) or when the video
        writer cannot be opened.
    """
    if not frames or not trajectory or not detection_frames:
        return None

    temp_file = os.path.join(tempfile.gettempdir(), f"drs_output_{uuid.uuid4()}.mp4")
    height, width = frames[0].shape[:2]
    out = cv2.VideoWriter(temp_file, cv2.VideoWriter_fourcc(*'mp4v'), FRAME_RATE / SLOW_MOTION_FACTOR, (width, height))
    if not out.isOpened():
        # Otherwise we would hand back a path to an empty or missing file.
        out.release()
        return None

    min_frame = min(detection_frames)
    max_frame = max(detection_frames)
    total_frames = max_frame - min_frame + 1
    # Spread the trajectory samples evenly over the detected span: the k-th
    # frame of the span shows the path up to sample indices[k].
    traj_per_frame = max(1, len(trajectory) // total_frames)
    indices = [min(i * traj_per_frame, len(trajectory) - 1) for i in range(total_frames)]
    first_detection, last_detection = detection_frames[0], detection_frames[-1]

    try:
        for i, source_frame in enumerate(frames):
            frame = source_frame.copy()
            idx = i - min_frame
            if 0 <= idx < len(indices):
                points = np.array(trajectory[:indices[idx] + 1], dtype=np.int32).reshape((-1, 1, 2))
                cv2.polylines(frame, [points], False, (255, 0, 0), 2)
            if pitch_point is not None and i == first_detection:
                cv2.circle(frame, tuple(map(int, pitch_point)), 6, (0, 0, 255), -1)
            if impact_point is not None and i == last_detection:
                cv2.circle(frame, tuple(map(int, impact_point)), 6, (0, 255, 255), -1)
            for _ in range(SLOW_MOTION_FACTOR):
                out.write(frame)
    finally:
        # Always finalise the container, even if drawing/encoding raises.
        out.release()
    return temp_file
def drs_review(video):
    """Run the full DRS pipeline for the Gradio UI.

    The pipeline is: detection (``process_video``), trajectory fitting
    (``estimate_trajectory``), the LBW check (``lbw_decision``) and replay
    rendering (``generate_replay``).

    Args:
        video: Path of the uploaded video as given by the Gradio ``Video``
            input. It may presumably be ``None`` when nothing was uploaded;
            confirm this.

    Returns:
        ``(report_text, replay_path)``. ``replay_path`` is ``None`` when no
        replay was produced.
    """
    if not video:
        return "Error: No video uploaded", None

    frames, ball_positions, detection_frames, debug_log = process_video(video)
    if not frames or not ball_positions:
        # Surface process_video's log so the user can tell a missing or
        # unreadable file apart from a video with no ball detections.
        return f"No frames or detections found.\n\n{debug_log}", None

    frame_height, frame_width = frames[0].shape[:2]
    trajectory, pitch_point, impact_point, log = estimate_trajectory(ball_positions, detection_frames, frame_height, frame_width)
    if not trajectory:
        return f"{log}\n{debug_log}", None

    decision, _, _, _ = lbw_decision(ball_positions, trajectory, frames, pitch_point, impact_point)
    replay_path = generate_replay(frames, trajectory, pitch_point, impact_point, detection_frames)
    result_log = f"DRS Decision: {decision}\n\n{log}\n\n{debug_log}"
    return result_log, replay_path
# ---------------------------------------------------------------------------
# Gradio front-end: one uploaded video in; a text report and an annotated
# replay video out, both produced by drs_review.
# ---------------------------------------------------------------------------
iface = gr.Interface(
    fn=drs_review,
    inputs=gr.Video(label="Upload Cricket Delivery Video"),
    outputs=[
        gr.Textbox(label="DRS Result and Debug Info"),
        gr.Video(label="Replay with Trajectory & Decision"),
    ],
    title="GullyDRS - AI-Powered LBW Review",
    description=(
        "Upload a cricket delivery video. "
        "The system will track the ball, estimate trajectory, "
        "and return a replay with an OUT/NOT OUT decision."
    ),
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()
|