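"""Gradio app: an AI-assisted LBW (leg before wicket) review tool.

Pipeline: YOLOv8 ball detection -> trajectory interpolation/extrapolation
-> simplified LBW decision in pixel space -> slow-motion replay rendering.
"""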
import cv2
import numpy as np
from ultralytics import YOLO
import gradio as gr
from scipy.interpolate import interp1d
import uuid
import os
# Load the trained YOLOv8n model from the Space's root directory
model = YOLO("best.pt") # Assumes best.pt is in the same directory as app.py
# Constants for LBW decision and video processing
STUMPS_WIDTH = 0.2286 # meters (width of stumps)
BALL_DIAMETER = 0.073 # meters (approx. cricket ball diameter)
FRAME_RATE = 30 # Input video frame rate
SLOW_MOTION_FACTOR = 2 # Frames are duplicated and the writer fps is halved, so playback is ~4x slower
CONF_THRESHOLD = 0.3 # Lowered confidence threshold for better detection
def process_video(video_path):
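    """Run YOLOv8 ball detection on every frame of the input video.

    Returns the list of (optionally annotated) frames, the detected ball
    centres in pixel coordinates, and a per-frame debug log string.
    """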
    # Initialize video capture
    if not os.path.exists(video_path):
        return [], [], "Error: Video file not found"
    cap = cv2.VideoCapture(video_path)
    frames = []
    ball_positions = []
    debug_log = []
    frame_count = 0
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        frame_count += 1
        frames.append(frame.copy())  # Store the original frame
        # Detect the ball using the trained YOLOv8n model
        results = model.predict(frame, conf=CONF_THRESHOLD)
        detections = 0
        for detection in results[0].boxes:
            if int(detection.cls) == 0:  # Class 0 is assumed to be the ball
                detections += 1
                x1, y1, x2, y2 = detection.xyxy[0].cpu().numpy()
                ball_positions.append([(x1 + x2) / 2, (y1 + y2) / 2])
                # Draw a bounding box on the frame for visualization
                cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
                frames[-1] = frame  # Replace the stored frame with the annotated one
        debug_log.append(f"Frame {frame_count}: {detections} ball detections")
    cap.release()
    if not ball_positions:
        debug_log.append("No balls detected in any frame")
    else:
        debug_log.append(f"Total ball detections: {len(ball_positions)}")
    return frames, ball_positions, "\n".join(debug_log)
def estimate_trajectory(ball_positions, frames):
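    """Fit smooth curves through the detected ball centres and extrapolate.

    Uses scipy interp1d over frame timestamps; the y-curve is quadratic
    (to approximate ball drop) when at least 3 detections are available.
    Returns the projected (x, y) points, their timestamps, and a status log.
    """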
    # Simplified physics-based trajectory projection
    if len(ball_positions) < 2:
        return None, None, "Error: Fewer than 2 ball detections for trajectory"
    # Extract x, y coordinates
    x_coords = [pos[0] for pos in ball_positions]
    y_coords = [pos[1] for pos in ball_positions]
    times = np.arange(len(ball_positions)) / FRAME_RATE
    # Interpolate to smooth the trajectory; quadratic interpolation needs at
    # least 3 points, so fall back to linear when only 2 detections exist
    y_kind = 'quadratic' if len(ball_positions) >= 3 else 'linear'
    try:
        fx = interp1d(times, x_coords, kind='linear', fill_value="extrapolate")
        fy = interp1d(times, y_coords, kind=y_kind, fill_value="extrapolate")
    except Exception as e:
        return None, None, f"Error in trajectory interpolation: {str(e)}"
    # Project the trajectory 0.5 seconds beyond the last detection
    t_future = np.linspace(times[-1], times[-1] + 0.5, 10)
    x_future = fx(t_future)
    y_future = fy(t_future)
    return list(zip(x_future, y_future)), t_future, "Trajectory estimated successfully"
def lbw_decision(ball_positions, trajectory, frames):
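    """Apply simplified LBW rules in pixel space.

    Checks the pitching line, the impact line, and whether the projected
    trajectory passes through an assumed stump zone at the bottom centre of
    the frame. Returns the decision string and the trajectory used (or None).
    """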
    # Simplified LBW logic
    if not frames:
        return "Error: No frames processed", None
    if not trajectory or len(ball_positions) < 2:
        return "Not enough data (insufficient ball detections)", None
    # Assume the stumps sit at the bottom centre of the frame (calibration needed)
    frame_height, frame_width = frames[0].shape[:2]
    stumps_x = frame_width / 2
    stumps_y = frame_height * 0.9  # Approximate stumps position
    stumps_width_pixels = frame_width * (STUMPS_WIDTH / 3.0)  # Assumes ~3 m of visible pitch width
    # Check the pitching point (first detected position)
    pitch_x, pitch_y = ball_positions[0]
    if pitch_x < stumps_x - stumps_width_pixels / 2 or pitch_x > stumps_x + stumps_width_pixels / 2:
        return "Not Out (Pitched outside line)", None
    # Check the impact point (last detected position)
    impact_x, impact_y = ball_positions[-1]
    if impact_x < stumps_x - stumps_width_pixels / 2 or impact_x > stumps_x + stumps_width_pixels / 2:
        return "Not Out (Impact outside line)", None
    # Check whether the projected trajectory hits the stumps
    for x, y in trajectory:
        if abs(x - stumps_x) < stumps_width_pixels / 2 and abs(y - stumps_y) < frame_height * 0.1:
            return "Out", trajectory
    return "Not Out (Missing stumps)", trajectory
def generate_slow_motion(frames, trajectory, output_path):
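    """Render the replay at reduced fps with duplicated frames.

    Writing at FRAME_RATE / SLOW_MOTION_FACTOR fps while duplicating each
    frame SLOW_MOTION_FACTOR times makes playback SLOW_MOTION_FACTOR**2
    times slower than real time.
    """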
    # Write a slow-motion video with the detections and trajectory overlaid
    if not frames:
        return None
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, FRAME_RATE / SLOW_MOTION_FACTOR,
                          (frames[0].shape[1], frames[0].shape[0]))
    for frame in frames:
        if trajectory:
            for x, y in trajectory:
                cv2.circle(frame, (int(x), int(y)), 5, (255, 0, 0), -1)  # Blue dots for the trajectory
        for _ in range(SLOW_MOTION_FACTOR):  # Duplicate frames to slow playback further
            out.write(frame)
    out.release()
    return output_path
def drs_review(video):
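    """End-to-end pipeline: detect the ball, project its path, decide LBW,
    and render a slow-motion replay. Returns (decision + log text, video path).
    """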
    # Process the video and generate the DRS output
    frames, ball_positions, debug_log = process_video(video)
    if not frames:
        return f"Error: Failed to process video\nDebug Log:\n{debug_log}", None
    trajectory, _, trajectory_log = estimate_trajectory(ball_positions, frames)
    decision, trajectory = lbw_decision(ball_positions, trajectory, frames)
    # Generate the slow-motion replay even if trajectory estimation fails
    output_path = f"output_{uuid.uuid4()}.mp4"
    slow_motion_path = generate_slow_motion(frames, trajectory, output_path)
    # Combine the debug logs for output
    debug_output = f"{debug_log}\n{trajectory_log}"
    return f"DRS Decision: {decision}\nDebug Log:\n{debug_output}", slow_motion_path
# Gradio interface
iface = gr.Interface(
    fn=drs_review,
    inputs=gr.Video(label="Upload Video Clip"),
    outputs=[
        gr.Textbox(label="DRS Decision and Debug Log"),
        gr.Video(label="Very Slow-Motion Replay with Ball Detection and Trajectory")
    ],
    title="AI-Powered DRS for LBW in Local Cricket",
    description="Upload a video clip of a cricket delivery to get an LBW decision and a very slow-motion replay showing ball detection (green boxes) and the projected trajectory (blue dots)."
)
if __name__ == "__main__":
    iface.launch()
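# Minimal sketch of a local, non-Gradio invocation (the clip name
# "delivery.mp4" is illustrative only; any readable video path works):
#
#     result_text, replay_path = drs_review("delivery.mp4")
#     print(result_text)   # "DRS Decision: ..." plus the per-frame debug log
#     print(replay_path)   # path to the generated slow-motion .mp4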