Spaces:
Sleeping
Sleeping
File size: 4,652 Bytes
4da8df7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import gradio as gr
import tempfile
from pathlib import Path
import numpy as np
import cv2
from utils.video_utils import extract_frames_from_video, save_frames_to_video
from detection.detector import Detector
from tracking.ball_tracker import BasicBallTracker
from trajectory.fit_trajectory import TrajectoryFitter
from rules.lbw_engine import LBWEngine
from visualization.overlay_generator import OverlayGenerator
# Initialize components.
# These objects are constructed once, when this module is imported, and live
# for the lifetime of the process.
detector = Detector()  # Ball detector; used as detector.infer(frame) -> detections.
# Ball tracker; fed detections via .update(dets, frame_idx) and read via
# .get_track(), whose items are 3-tuples ending in (x, y).
# NOTE(review): nothing in this module resets this instance; if it keeps
# per-video history, that history would carry over between videos — confirm.
tracker = BasicBallTracker()
traj_fitter = TrajectoryFitter()  # Trajectory model; used as .fit(points) then .project(xs).
umpire = LBWEngine()  # LBW rules engine; .decide(features_dict) returns (verdict, reason).
# Overlay renderer used as .draw(frame, curve_pts, verdict).
# NOTE(review): config=None presumably means "use built-in defaults" — confirm.
visual = OverlayGenerator(config=None)
def review_video(video_file):
    """Run the LBW review pipeline on an uploaded video.

    The pipeline extracts frames and then detects and tracks the ball. It fits
    and projects a trajectory and classifies the pitch and impact zones. It
    predicts whether the ball would hit the stumps and asks the LBW engine for
    a verdict. Finally it renders an annotated copy of the video.

    Args:
        video_file: Path (string or path-like) of the uploaded video, as
            supplied by the Gradio ``Video`` input, or ``None`` when the user
            submitted without uploading anything.

    Returns:
        A ``(message, output_path)`` tuple. On success, ``message`` is
        ``"<verdict>: <reason>"`` and ``output_path`` is the path of the
        annotated ``.mp4`` temp file. On a handled failure, ``message``
        describes the problem and ``output_path`` is ``None``.

    Exceptions raised by the detector, fitter, engine, renderer or video
    writer are not caught here and propagate to the caller.
    """
    # Gradio passes None when nothing was uploaded; Path(None) raises TypeError.
    if video_file is None:
        return "No video uploaded", None
    tmp_path = Path(video_file)
    if not tmp_path.is_file():
        return "Video file not found", None

    try:
        frames = extract_frames_from_video(str(tmp_path))
    except Exception as e:
        return f"Error extracting frames: {e}", None

    # Ball detection and tracking. Use a fresh tracker for every review so that
    # points from a previously reviewed video cannot leak into this one (the
    # module-level tracker is created once at import and never reset here).
    ball_tracker = BasicBallTracker()
    for idx, frm in enumerate(frames):
        ball_tracker.update(detector.infer(frm), idx)
    track_pts = [(int(x), int(y)) for _, x, y in ball_tracker.get_track()]
    if len(track_pts) < 5:
        return "Insufficient ball points detected", None

    # Fit the trajectory and sample it at 100 x positions between the first
    # and last tracked points.
    traj_fitter.fit(track_pts)
    xs = np.linspace(track_pts[0][0], track_pts[-1][0], 100)
    # Normalise project()'s output to a flat float array (it may be a list or
    # an ndarray). Reject non-finite values, which astype(int) would otherwise
    # turn into meaningless pixel coordinates.
    ys = np.asarray(traj_fitter.project(xs), dtype=float).ravel()
    if len(ys) < 2 or not np.all(np.isfinite(ys)):
        return "Trajectory fitting failed", None
    curve_pts = list(zip(xs.astype(int), ys.astype(int)))

    # Zone classification and stump prediction.
    pitch_zone = determine_pitch_zone(track_pts[0])
    impact_zone = determine_impact_zone(track_pts)
    hits_stumps = check_stumps_impact(curve_pts)

    # Decision logic: ask the rules engine for OUT / NOT OUT.
    verdict, reason = umpire.decide({
        "pitch_zone": pitch_zone,
        "impact_zone": impact_zone,
        "hits_stumps": hits_stumps,
        "shot_offered": False,  # This can be improved with player pose detection
    })

    # Annotate every frame with the projected trajectory and the decision.
    annotated_frames = [visual.draw(frm.copy(), curve_pts, verdict) for frm in frames]

    # Reserve a temp file name and close our handle *before* writing.
    # save_frames_to_video presumably opens the path itself, and on Windows a
    # name held open by NamedTemporaryFile cannot be opened a second time.
    with tempfile.NamedTemporaryFile(suffix="_drs.mp4", delete=False) as out_file:
        out_path = out_file.name
    try:
        save_frames_to_video(annotated_frames, out_path)
    except Exception:
        # Do not leave an empty or partial temp file behind on failure.
        Path(out_path).unlink(missing_ok=True)
        raise
    return f"{verdict}: {reason}", out_path
# Helper Functions
def determine_pitch_zone(first_point, *, inline_max_x=300, off_max_x=500):
    """Classify where the ball pitched from its first detected position.

    Only the x pixel coordinate is used, split by two vertical boundaries:

    * ``x < inline_max_x``                -> ``"inline"``
    * ``inline_max_x <= x <= off_max_x``  -> ``"outside_off"``
    * otherwise                           -> ``"outside_leg"``

    The defaults (300 / 500) are the placeholder example thresholds of the
    original proof of concept. They are in pixels, so they depend on the
    video's resolution and camera framing and should be calibrated per setup.

    Args:
        first_point: An ``(x, y)`` pair; ``y`` is ignored.
        inline_max_x: Exclusive upper x bound of the ``"inline"`` zone.
        off_max_x: Inclusive upper x bound of the ``"outside_off"`` zone.

    Returns:
        ``"inline"``, ``"outside_off"`` or ``"outside_leg"``.
    """
    x, _y = first_point
    # NOTE(review): this returns "inline" whereas determine_impact_zone returns
    # "in_line"; confirm which spelling LBWEngine.decide expects for each key.
    if x < inline_max_x:
        return "inline"
    if x <= off_max_x:
        return "outside_off"
    return "outside_leg"
def determine_impact_zone(track_points, *, stumps_min_x=200, stumps_max_x=400):
    """Classify where the ball struck the batter from its tracked path.

    The last tracked point is taken as an approximation of the impact point.
    It is ``"in_line"`` when its x pixel coordinate lies inside the inclusive
    band ``[stumps_min_x, stumps_max_x]`` and ``"outside_leg"`` otherwise.
    Points on either side of the band are reported as ``"outside_leg"``.

    The defaults (200 / 400) are the placeholder example band of the original
    proof of concept. They are in pixels and should be calibrated to the
    video's resolution and framing.

    Args:
        track_points: Sequence of ``(x, y)`` pairs; only the last one is used.
        stumps_min_x: Inclusive lower x bound of the in-line band.
        stumps_max_x: Inclusive upper x bound of the in-line band.

    Returns:
        ``"in_line"`` or ``"outside_leg"``.

    Raises:
        IndexError: If ``track_points`` is empty.
    """
    # NOTE(review): no "outside_off" result is ever produced; confirm that is
    # what LBWEngine.decide expects for "impact_zone".
    x, _y = track_points[-1]
    return "in_line" if stumps_min_x <= x <= stumps_max_x else "outside_leg"
def check_stumps_impact(trajectory_points, *, stumps_min_x=200, stumps_max_x=400, max_y=0):
    """Predict whether the projected ball path would hit the stumps.

    Only the final projected point is examined. It counts as a hit when its x
    lies within the inclusive band ``[stumps_min_x, stumps_max_x]`` and its y
    is ``<= max_y``.

    The defaults (200 / 400 / 0) are the placeholder example values of the
    original proof of concept and should be calibrated to the video.

    Args:
        trajectory_points: Sequence of ``(x, y)`` pairs; only the last is used.
        stumps_min_x: Inclusive lower x bound of the stumps band.
        stumps_max_x: Inclusive upper x bound of the stumps band.
        max_y: Inclusive upper bound on the final point's y.

    Returns:
        ``True`` if the final point is inside the region, else ``False``.

    Raises:
        IndexError: If ``trajectory_points`` is empty.
    """
    x, y = trajectory_points[-1]
    # NOTE(review): if these are image pixel coordinates (origin top-left, y
    # growing downward), y <= 0 means at or above the frame's top edge; confirm
    # this is the intended stump-height test.
    # bool() keeps the return a plain Python bool even when x / y are NumPy
    # scalars (comparisons on them yield numpy.bool_).
    return bool(stumps_min_x <= x <= stumps_max_x and y <= max_y)
# Gradio interface setup.
# Wires review_video into an upload-and-review UI. There is one video input,
# and two outputs that line up with the (message, video_path) tuple that
# review_video returns.
gui = gr.Interface(
    fn=review_video,
    inputs=gr.Video(label="Upload LBW Appeal Video"),
    outputs=[gr.Textbox(label="Verdict"), gr.Video(label="Annotated Output")],
    title="LBW DRS AI Review",
    description="Proof‑of‑concept: detects ball, projects trajectory, and renders decision.",
)
# Launch the app only when run as a script, not when this module is imported.
if __name__ == "__main__":
    gui.launch()
|