# lbw_drs_ai / app.py
# Author: dschandra — commit 4da8df7 ("Upload 2 files", verified)
import gradio as gr
import tempfile
from pathlib import Path
import numpy as np
import cv2
from utils.video_utils import extract_frames_from_video, save_frames_to_video
from detection.detector import Detector
from tracking.ball_tracker import BasicBallTracker
from trajectory.fit_trajectory import TrajectoryFitter
from rules.lbw_engine import LBWEngine
from visualization.overlay_generator import OverlayGenerator
# Initialize components
# Module-level pipeline components, constructed once at import time.
# NOTE(review): the constructors are not visible here, so whether they load
# model weights or keep per-video state is unknown — confirm.
detector = Detector()  # ball detector; exposes .infer(frame)
tracker = BasicBallTracker()  # NOTE(review): shared instance; keeps state across calls unless it resets itself — confirm
traj_fitter = TrajectoryFitter()  # exposes .fit(points) and .project(xs)
umpire = LBWEngine()  # exposes .decide(features) -> (verdict, reason)
visual = OverlayGenerator(config=None)  # presumably config=None selects default overlay settings — confirm
def review_video(video_file):
    """Run the LBW review pipeline on an uploaded video.

    Extracts frames, detects and tracks the ball, fits and projects its
    trajectory, classifies pitch/impact zones, asks the LBW engine for a
    decision, and renders the projected path plus verdict on every frame.

    Args:
        video_file: Path string of the uploaded video as supplied by the
            ``gr.Video`` input, or ``None`` when nothing was uploaded.

    Returns:
        A ``(message, output_path)`` tuple. On success ``message`` is
        ``"<verdict>: <reason>"`` and ``output_path`` is the path of the
        annotated MP4. On failure ``message`` describes the problem and
        ``output_path`` is ``None``.
    """
    # Gradio passes None when no video was uploaded; Path(None) would raise
    # TypeError instead of producing a clean message.
    if not video_file:
        return "Video file not found", None
    video_path = Path(video_file)
    if not video_path.is_file():
        return "Video file not found", None

    try:
        frames = extract_frames_from_video(str(video_path))
    except Exception as exc:  # UI boundary: report the failure, keep the app alive
        return f"Error extracting frames: {exc}", None

    # Ball detection and tracking.
    track_pts = _track_ball(frames)
    if len(track_pts) < 5:
        return "Insufficient ball points detected", None

    # Fit the trajectory and sample the projected ball path.
    curve_pts = _project_trajectory(track_pts)
    if curve_pts is None:
        return "Trajectory fitting failed", None

    # Zone classification (the helpers use placeholder pixel thresholds).
    pitch_zone = determine_pitch_zone(track_pts[0])
    impact_zone = determine_impact_zone(track_pts)
    hits_stumps = check_stumps_impact(curve_pts)

    verdict, reason = umpire.decide({
        "pitch_zone": pitch_zone,
        "impact_zone": impact_zone,
        "hits_stumps": hits_stumps,
        # TODO: infer from player pose detection; hard-coded for now.
        "shot_offered": False,
    })

    # Draw the projected path and verdict on a copy of every frame.
    annotated_frames = [visual.draw(frm.copy(), curve_pts, verdict) for frm in frames]
    out_path = _save_annotated_video(annotated_frames)
    return f"{verdict}: {reason}", out_path


def _track_ball(frames):
    """Detect the ball in each frame and return its tracked ``(x, y)`` points."""
    # Use a fresh tracker per request. Reusing the shared module-level
    # ``tracker`` would mix in points from previously reviewed videos, unless
    # it resets itself, which is not visible here.
    ball_tracker = BasicBallTracker()
    for idx, frm in enumerate(frames):
        ball_tracker.update(detector.infer(frm), idx)
    return [(int(x), int(y)) for _, x, y in ball_tracker.get_track()]


def _project_trajectory(track_pts, num_samples=100):
    """Fit ``track_pts`` and return sampled integer curve points, or None on failure."""
    traj_fitter.fit(track_pts)
    xs = np.linspace(track_pts[0][0], track_pts[-1][0], num_samples)
    # Coerce to a float array so ``.astype`` works even if ``project`` returns
    # a list. A None result becomes a 0-d NaN array and is rejected below.
    ys = np.asarray(traj_fitter.project(xs), dtype=float)
    if ys.ndim != 1 or len(ys) < 2 or not np.isfinite(ys).all():
        return None
    return list(zip(xs.astype(int), ys.astype(int)))


def _save_annotated_video(frames):
    """Write ``frames`` to a new temporary ``*_drs.mp4`` file and return its path."""
    # Reserve a unique name, then close our handle *before* writing. The
    # writer reopens the path itself, which fails on Windows while a
    # NamedTemporaryFile handle is still open.
    with tempfile.NamedTemporaryFile(suffix="_drs.mp4", delete=False) as out_file:
        out_path = out_file.name
    save_frames_to_video(frames, out_path)
    return out_path
# Helper Functions
def determine_pitch_zone(first_point):
    """Classify where the ball pitched, using its first detected position.

    Only the horizontal pixel coordinate is used. The boundaries are
    placeholder example values, not calibrated to any camera view.

    Args:
        first_point: ``(x, y)`` pair; ``y`` is unpacked but ignored.

    Returns:
        ``"inline"`` when x < 300, ``"outside_off"`` when 300 <= x <= 500,
        otherwise ``"outside_leg"``.

    NOTE(review): this returns ``"inline"`` while determine_impact_zone
    returns ``"in_line"``; confirm which spelling LBWEngine expects.
    """
    pos_x, _pos_y = first_point
    if pos_x < 300:
        return "inline"
    if 300 <= pos_x <= 500:
        return "outside_off"
    return "outside_leg"
def determine_impact_zone(track_points):
    """Classify the impact zone as in line with the stumps or outside leg.

    The last tracked point is taken as an approximation of the impact point.
    Only its horizontal pixel coordinate is used, against a placeholder band.

    Args:
        track_points: Non-empty sequence of ``(x, y)`` pairs.

    Returns:
        ``"in_line"`` when 200 <= x <= 400 for the last point, else
        ``"outside_leg"``.
    """
    impact_x, _impact_y = track_points[-1]
    return "in_line" if 200 <= impact_x <= 400 else "outside_leg"
def check_stumps_impact(trajectory_points):
    """Predict whether the projected trajectory would hit the stumps.

    Looks only at the final projected point. That point counts as hitting
    when its x lies in the placeholder band [200, 400] and its y is <= 0.

    NOTE(review): if these are image pixel coordinates (y growing downward),
    y <= 0 means the top edge of the frame — confirm this is intended.

    Args:
        trajectory_points: Non-empty sequence of ``(x, y)`` pairs.

    Returns:
        ``True`` if the final point satisfies both conditions, else ``False``.
    """
    end_x, end_y = trajectory_points[-1]
    within_stumps = 200 <= end_x <= 400
    return bool(within_stumps and end_y <= 0)
# Gradio interface setup: one video in; verdict text and annotated video out.
_video_input = gr.Video(label="Upload LBW Appeal Video")
_result_outputs = [
    gr.Textbox(label="Verdict"),
    gr.Video(label="Annotated Output"),
]
gui = gr.Interface(
    fn=review_video,
    inputs=_video_input,
    outputs=_result_outputs,
    title="LBW DRS AI Review",
    description="Proof‑of‑concept: detects ball, projects trajectory, and renders decision.",
)

if __name__ == "__main__":
    # Start the web server only when run as a script, not on import.
    gui.launch()