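"""Gradio app for an LBW (leg before wicket) DRS review proof of concept.

Pipeline: extract frames from the uploaded clip, detect and track the ball,
fit and project its trajectory, classify the pitch/impact/stump zones, and
render an annotated video alongside the OUT / NOT OUT verdict. The detector,
tracker, trajectory fitter, rules engine, and overlay generator are imported
from this repository's local packages.
"""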
import tempfile
from pathlib import Path

import cv2
import gradio as gr
import numpy as np

from utils.video_utils import extract_frames_from_video, save_frames_to_video
from detection.detector import Detector
from tracking.ball_tracker import BasicBallTracker
from trajectory.fit_trajectory import TrajectoryFitter
from rules.lbw_engine import LBWEngine
from visualization.overlay_generator import OverlayGenerator

# Initialize pipeline components once at module load
detector = Detector()
tracker = BasicBallTracker()
traj_fitter = TrajectoryFitter()
umpire = LBWEngine()
visual = OverlayGenerator(config=None)

def review_video(video_file):
    """Run the LBW review pipeline on an uploaded video and return
    (verdict text, path to the annotated output video)."""
    if video_file is None:
        return "No video provided", None

    tmp_path = Path(video_file)  # Gradio passes the uploaded file as a path string
    if not tmp_path.is_file():
        return "Video file not found", None

    try:
        frames = extract_frames_from_video(str(tmp_path))
    except Exception as e:
        return f"Error extracting frames: {e}", None

    # Ball detection and tracking, frame by frame
    for idx, frm in enumerate(frames):
        dets = detector.infer(frm)
        tracker.update(dets, idx)

    # Track entries are (frame_index, x, y); keep only integer pixel coordinates.
    track_pts = [(int(x), int(y)) for _, x, y in tracker.get_track()]
    if len(track_pts) < 5:
        return "Insufficient ball points detected", None

    # Fit a trajectory to the tracked points and project the ball path
    # between the first and last detected x-positions.
    traj_fitter.fit(track_pts)
    xs = np.linspace(track_pts[0][0], track_pts[-1][0], 100)
    ys = traj_fitter.project(xs)
    if len(xs) < 2 or len(ys) < 2:
        return "Trajectory fitting failed", None
    curve_pts = list(zip(xs.astype(int), np.asarray(ys).astype(int)))

    # Zone classification from the tracked and projected points
    pitch_zone = determine_pitch_zone(track_pts[0])   # where the ball pitched
    impact_zone = determine_impact_zone(track_pts)    # where it struck the batter
    hits_stumps = check_stumps_impact(curve_pts)      # projected path vs. the stumps

    # Decision logic: OUT or NOT OUT from the three checks above
    verdict, reason = umpire.decide({
        "pitch_zone": pitch_zone,
        "impact_zone": impact_zone,
        "hits_stumps": hits_stumps,
        "shot_offered": False,  # could be improved with player pose detection
    })

    # Annotate every frame with the projected trajectory and the verdict
    annotated_frames = [visual.draw(frm.copy(), curve_pts, verdict) for frm in frames]

    # Save the annotated video to a temporary file. Close the handle before
    # writing so the path can be reopened on all platforms (Windows cannot
    # reopen an open NamedTemporaryFile).
    out_file = tempfile.NamedTemporaryFile(suffix="_drs.mp4", delete=False)
    out_file.close()
    save_frames_to_video(annotated_frames, out_file.name)

    return f"{verdict}: {reason}", out_file.name
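
# Local smoke test (hypothetical sample path; when deployed, the function is
# invoked by the Gradio interface defined below):
#   verdict_text, annotated_path = review_video("samples/lbw_appeal.mp4")
#   print(verdict_text, annotated_path)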

# Helper Functions

def determine_pitch_zone(first_point):
    """
    Classify where the ball pitched relative to the stumps, based on the
    x-coordinate of the first detected point. The pixel thresholds below are
    placeholders and assume a fixed camera framing.
    """
    x, _y = first_point        # first_point is (x, y) in pixel coordinates
    if x < 300:                # in line with the stumps (example threshold)
        return "inline"
    elif 300 <= x <= 500:      # outside off stump
        return "outside_off"
    else:                      # outside leg stump
        return "outside_leg"
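
# Example (with the placeholder thresholds above):
#   determine_pitch_zone((250, 400)) -> "inline"
#   determine_pitch_zone((550, 400)) -> "outside_leg"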

def determine_impact_zone(track_points):
    """
    Classify the impact zone (in line with the stumps or outside leg), using
    the last tracked point as an approximation of the impact position.
    """
    x, _y = track_points[-1]   # last tracked point ~ point of impact
    if 200 <= x <= 400:        # example in-line band in pixel coordinates
        return "in_line"
    else:
        return "outside_leg"
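
# Example: determine_impact_zone([(320, 210), (310, 240), (305, 260)]) -> "in_line"
# (the last point's x = 305 falls inside the example 200-400 band).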

def check_stumps_impact(trajectory_points):
    """
    Predict whether the projected trajectory would hit the stumps. Uses the
    final projected point; the x-band and y threshold are example values.
    """
    x, y = trajectory_points[-1]
    return 200 <= x <= 400 and y <= 0  # example stump window
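
# Example: check_stumps_impact([(320, 40), (310, 10), (305, -5)]) -> True
# (the final projected point lies in the 200-400 x-band with y <= 0).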

# Gradio interface setup
gui = gr.Interface(
    fn=review_video,
    inputs=gr.Video(label="Upload LBW Appeal Video"),
    outputs=[gr.Textbox(label="Verdict"), gr.Video(label="Annotated Output")],
    title="LBW DRS AI Review",
    description="Proof-of-concept: detects the ball, projects its trajectory, and renders a decision.",
)

if __name__ == "__main__":
    gui.launch()