"""Gradio Space that runs a next-goal prediction model over an uploaded Rocket League replay.

The TorchScript model pulled from the ``Rolv-Arild/vortex-ngp`` repository is evaluated on
every frame of the replay. The per-frame probabilities are plotted with Plotly and offered
as a CSV download.
"""
import math
import os
from contextlib import suppress
from tempfile import TemporaryDirectory

import gradio as gr
import numpy as np
import pandas as pd
import spaces
import torch
from huggingface_hub import Repository
from rlgym_tools.rocket_league.misc.serialize import serialize_game_state, serialize_scoreboard, \
    SB_GAME_TIMER_SECONDS, SB_BLUE_SCORE, SB_ORANGE_SCORE
from rlgym_tools.rocket_league.replays.convert import replay_to_rlgym
from rlgym_tools.rocket_league.replays.parsed_replay import ParsedReplay
import rlgym_tools.rocket_league.replays as rlgym_replays
from tqdm import trange, tqdm

# Make the ``carball`` path bundled inside rlgym_tools' replay package executable
# (presumably the replay parser used by ParsedReplay.load -- TODO confirm).
# The path is resolved from the installed package instead of a hard-coded
# ``/usr/local/lib/python3.10/site-packages/...`` so it survives Python/version changes.
os.chmod(os.path.join(next(iter(rlgym_replays.__path__)), "carball"), 0o755)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model output layout: 61 "Blue" time bins, 61 "Orange" time bins, then one "Tie" class.
NUM_OUTPUTS = 123
# Number of frames sent through the model per forward pass.
BATCH_SIZE = 900
# Maximum number of prediction CSVs kept in the cache directory.
MAX_CACHED_FILES = 100

# Fetch the model repository and pull its latest revision.
repo = Repository(local_dir="vortex-ngp", clone_from="Rolv-Arild/vortex-ngp", token=os.getenv("HF_TOKEN"))
repo.git_pull()

MODEL = torch.jit.load("vortex-ngp/vortex-ngp-daily-energy.pt", map_location=DEVICE)
MODEL.eval()


def _load_frames(replay):
    """Convert a parsed replay into RLGym frames plus stacked serialized tensors.

    Returns:
        (replay_frames, states, scoreboards, seconds_remaining): the raw frames, the stacked
        ``serialize_game_state`` / ``serialize_scoreboard`` outputs as tensors, and each
        frame's ``episode_seconds_remaining`` as a tensor.

    Raises:
        ValueError: if the replay yields no frames.
    """
    replay_frames, states, scoreboards, seconds_remaining = [], [], [], []
    with tqdm(replay_to_rlgym(replay), desc="Loading replay", total=len(replay.game_df)) as frames:
        for replay_frame in frames:
            replay_frames.append(replay_frame)
            states.append(serialize_game_state(replay_frame.state))
            scoreboards.append(serialize_scoreboard(replay_frame.scoreboard))
            seconds_remaining.append(replay_frame.episode_seconds_remaining)
    if not replay_frames:
        raise ValueError("Replay contains no frames to evaluate.")
    return (
        replay_frames,
        torch.from_numpy(np.stack(states)),
        torch.from_numpy(np.stack(scoreboards)),
        torch.tensor(seconds_remaining),
    )


def _game_timer(scoreboards, seconds_remaining):
    """Return ``(timer, is_ot)``: the per-frame game timer in seconds and an overtime mask.

    During overtime the timer is replaced by the negated seconds elapsed since overtime
    began, so negative values mean "overtime".
    """
    timer = scoreboards[:, SB_GAME_TIMER_SECONDS].clone()
    # NOTE(review): values above 450 are treated as overtime; presumably the scoreboard
    # reports an out-of-range (e.g. infinite) game timer during OT -- confirm.
    is_ot = timer > 450
    ot_time_remaining = seconds_remaining[is_ot]
    if len(ot_time_remaining) > 0:
        ot_elapsed = ot_time_remaining[0] - ot_time_remaining
        timer[is_ot] = (-ot_elapsed).to(timer.dtype)  # Negate to indicate overtime
    return timer, is_ot


def _run_model(model, states, scoreboards, nullify_goal_difference, ignore_ties, batch_size=BATCH_SIZE):
    """Run ``model`` over all frames in batches and return softmax probabilities (on DEVICE).

    ``nullify_goal_difference`` (or ``ignore_ties``) zeroes both scores fed to the model;
    ``ignore_ties`` additionally sets the game timer input to infinity. Only per-batch copies
    are modified, never the caller's tensors.
    """
    outputs = []
    with trange(len(states), desc="Running model") as pbar:
        for start in range(0, len(states), batch_size):
            stop = start + batch_size
            # clone() so in-place edits below never touch the caller's tensors
            # (``.to`` is a no-op on CPU and would otherwise alias them).
            state_batch = states[start:stop].clone().to(DEVICE)
            scoreboard_batch = scoreboards[start:stop].clone().to(DEVICE)
            if nullify_goal_difference or ignore_ties:
                scoreboard_batch[:, SB_BLUE_SCORE] = 0
                scoreboard_batch[:, SB_ORANGE_SCORE] = 0
            if ignore_ties:
                scoreboard_batch[:, SB_GAME_TIMER_SECONDS] = float("inf")
            outputs.append(model(state_batch, scoreboard_batch))
            pbar.update(len(state_batch))
    return torch.cat(outputs, dim=0).softmax(dim=-1)


def _class_names(num_outputs=NUM_OUTPUTS):
    """Column names for the model outputs: ``"<Team>: <s>s"`` per time bin, then ``"Tie"``."""
    bin_seconds = torch.linspace(0, 60, num_outputs // 2).tolist()
    names = [f"{team}: {s:g}s" for team in ("Blue", "Orange") for s in bin_seconds]
    names.append("Tie")
    return names


def _touch_column(replay, replay_frames):
    """Build the per-frame ``Touch`` strings.

    Each frame's entry is a ``|``-separated sequence of ``team|player`` pairs for every car
    whose ``ball_touches`` is positive, or ``""`` if nobody touched the ball. Pipes inside
    player names are replaced by spaces so they do not clash with the separator.
    """
    pid_to_name = {int(p["unique_id"]): p["name"]
                   for p in replay.metadata["players"]
                   if p["unique_id"] in replay.player_dfs}
    column = []
    for replay_frame in replay_frames:
        touches = []
        for aid, car in replay_frame.state.cars.items():
            if car.ball_touches > 0:
                team = "Blue" if car.is_blue else "Orange"
                # Fall back to the id instead of crashing the whole run on an unknown player.
                name = pid_to_name.get(aid, str(aid)).replace("|", " ")
                touches.append(f"{team}|{name}")
        column.append("|".join(touches))
    return column


@spaces.GPU
@torch.inference_mode()
def infer(model, replay_file, nullify_goal_difference=False, ignore_ties=False):
    """Predict next-goal probabilities for every frame of a replay.

    Args:
        model: TorchScript model called as ``model(states, scoreboards)``.
        replay_file: path to the replay, loaded with ``ParsedReplay.load``.
        nullify_goal_difference: feed the model a 0-0 score on every frame.
        ignore_ties: as above, and also feed an infinite game timer. Tie probability is
            renormalised away on every frame and the ``Tie`` column is dropped.

    Returns:
        DataFrame indexed by ``Frame``. The columns are ``Timer`` (seconds, negative in
        overtime), ``Blue``, ``Orange``, ``Tie`` (absent when ``ignore_ties``), ``Goal``
        (per-frame change of blue-minus-orange score), ``Touch``, and then one column per
        team/time bin.
    """
    replay = ParsedReplay.load(replay_file)
    replay_frames, states, scoreboards, seconds_remaining = _load_frames(replay)

    timer, is_ot = _game_timer(scoreboards, seconds_remaining)
    goal_diff = scoreboards[:, SB_BLUE_SCORE] - scoreboards[:, SB_ORANGE_SCORE]
    goal_diff_diff = goal_diff.diff(prepend=torch.zeros(1, dtype=goal_diff.dtype))

    probs = _run_model(model, states, scoreboards, nullify_goal_difference, ignore_ties)

    preds = pd.DataFrame(data=probs.cpu().numpy(), columns=_class_names(NUM_OUTPUTS))
    preds["Blue"] = preds[[c for c in preds.columns if c.startswith("Blue")]].sum(axis=1)
    preds["Orange"] = preds[[c for c in preds.columns if c.startswith("Orange")]].sum(axis=1)
    preds["Timer"] = timer.numpy()
    preds["Goal"] = goal_diff_diff.numpy()
    preds["Touch"] = _touch_column(replay, replay_frames)

    # Sort columns
    main_cols = ["Timer", "Blue", "Orange", "Tie", "Goal", "Touch"]
    preds = preds[main_cols + [c for c in preds.columns if c not in main_cols]].copy()
    preds.index.name = "Frame"

    # Frames on which a tie is impossible: overtime, or every frame when ignoring ties.
    remove_ties_mask = np.ones(len(preds), dtype=bool) if ignore_ties else is_ot.numpy()
    if remove_ties_mask.any():
        q = 1 - preds.loc[remove_ties_mask, "Tie"]
        team_cols = [c for c in preds.columns if c.startswith(("Blue", "Orange"))]
        preds.loc[remove_ties_mask, team_cols] = preds.loc[remove_ties_mask, team_cols].div(q, axis=0)
        if ignore_ties:
            preds = preds.drop("Tie", axis=1)
        else:
            preds.loc[remove_ties_mask, "Tie"] = 0.0
    return preds


def _format_timer(t):
    """Format a timer value in seconds as ``M:SS``; negative (overtime) values get a ``+``."""
    t = float(t)
    if not math.isfinite(t):
        return str(t)
    sign = "+" if t < 0 else ""
    # Round the total first so e.g. 119.6 s becomes "2:00" rather than "1:60".
    total = int(round(abs(t)))
    return f"{sign}{total // 60}:{total % 60:02d}"


def plot_plotly(preds: pd.DataFrame):
    """Build an interactive Plotly figure from the DataFrame returned by :func:`infer`.

    The figure contains team (and tie) probability lines in percent, a 50% reference
    line, a vertical line with the running score at each goal, touch markers per team,
    and a secondary top axis labelled with the game timer.
    """
    import plotly.graph_objects as go

    preds_df = preds.drop(["Touch", "Timer", "Goal"], axis=1) * 100
    timer_text = [_format_timer(t) for t in preds["Timer"].tolist()]
    frame_timer = dict(zip(preds.index, timer_text))
    hovertemplate = "Frame %{x}<br>Prob: %{y:.3g}%<br>Timer: %{customdata}"

    fig = go.Figure()

    # Probability lines; "Tie" is absent when infer() ran with ignore_ties=True.
    for col, color in (("Blue", "blue"), ("Orange", "orange"), ("Tie", "gray")):
        if col not in preds_df.columns:
            continue
        fig.add_trace(go.Scatter(x=preds_df.index, y=preds_df[col], mode="lines", name=col,
                                 line=dict(color=color), customdata=timer_text,
                                 hovertemplate=hovertemplate))

    fig.add_hline(y=50, line_dash="dash", line_color="black", name="50% Probability")

    # Goal indicators annotated with the running score.
    blue_goals = orange_goals = 0
    goals = preds["Goal"]
    for goal_frame, delta in goals[goals != 0].items():
        if delta > 0:
            blue_goals += 1
        elif delta < 0:
            orange_goals += 1
        fig.add_vline(x=goal_frame, line_dash="dash", line_color="red",
                      annotation_text=f"{blue_goals}-{orange_goals}",
                      annotation_position="top right")

    # Touch indicators: the Touch column holds "team|player|team|player|..." per frame.
    touches = {}
    touch_col = preds["Touch"]
    for touch_frame, touch_str in touch_col[touch_col != ""].items():
        parts = touch_str.split("|")
        for team, player in zip(parts[::2], parts[1::2]):
            touches.setdefault(team.strip(), []).append((touch_frame, player.strip()))
    for team in ("Blue", "Orange"):
        team_touches = touches.get(team, [])
        if not team_touches:
            continue
        frames = [frame for frame, _ in team_touches]
        fig.add_trace(go.Scatter(
            x=frames,
            y=[preds_df.at[frame, team] for frame in frames],
            mode="markers",
            name=f"{team} touches",
            marker=dict(size=5, color=team.lower(), symbol="circle-open-dot"),
            customdata=[f"{frame_timer[frame]}<br>Touch by {player}" for frame, player in team_touches],
            hovertemplate=hovertemplate,
        ))

    # Secondary (top) axis: ~10 evenly spaced frames labelled with the game timer.
    # Uses the same formatter as the hover text. The Timer column is already in seconds.
    n_frames = len(preds)
    positions = sorted({int(round(x)) for x in np.linspace(0, n_frames - 1, 10)}) if n_frames else []
    tick_positions = [preds.index[i] for i in positions]
    tick_labels = [timer_text[i] for i in positions]

    fig.update_layout(
        title="Interactive Probability Plot",
        xaxis=dict(title="Frame", gridcolor="#e5e7eb"),
        yaxis=dict(title="Probability", gridcolor="#e5e7eb"),
        xaxis2=dict(
            title="Timer",
            overlaying="x",
            # No trace is bound to x2, so tie its range to x or the ticks won't line up.
            matches="x",
            side="top",
            tickmode="array",
            tickvals=tick_positions,
            ticktext=tick_labels,
        ),
        legend=dict(x=0.01, y=0.99, yanchor="top", xanchor="left"),
        plot_bgcolor="white",
    )
    return fig


DESCRIPTION = """
# Next Goal Predictor

Upload a replay file to get a plot of the next goal prediction.

The model is trained on about 14 000 hours of SSL and RLCS replays in 1v1, 2v2, and 3v3 using [this dataset](https://www.kaggle.com/datasets/rolvarild/high-level-rocket-league-replay-dataset).<br>
It predicts the probability that each team will score at 1 second intervals up to 60+ seconds. It also predicts ties (ball hitting the ground at 0s)<br>
The plot only shows the totals for each team, but you can download the full predictions if you want.
""".strip()

RADIO_OPTIONS = ["Default", "Nullify goal difference", "Ignore ties"]
RADIO_INFO = """
- **Default**: Uses the model as it is trained, with no modifications.
- **Nullify goal difference**: Makes the model think the goal difference is always 0, so it doesn't have a bias towards one team.
- **Ignore ties**: Makes the model pretend every situation is an overtime (i.e. ties are impossible).
""".strip()


def _evict_oldest(directory, max_files=MAX_CACHED_FILES):
    """Best-effort: delete the least recently used file once ``directory`` holds > max_files."""
    try:
        names = os.listdir(directory)
        if len(names) <= max_files:
            return
        oldest = min(names, key=lambda f: os.path.getmtime(os.path.join(directory, f)))
    except FileNotFoundError:
        # A concurrent request removed a file between listdir and getmtime.
        return
    with suppress(FileNotFoundError):
        os.remove(os.path.join(directory, oldest))


with TemporaryDirectory() as temp_dir:
    with gr.Blocks() as demo:
        gr.Markdown(DESCRIPTION)
        # Use gr.Column to stack components vertically
        with gr.Column():
            file_input = gr.File(label="Upload Replay File", type="filepath", file_types=[".replay"])
            checkboxes = gr.Radio(label="Options", choices=RADIO_OPTIONS, type="index",
                                  value=RADIO_OPTIONS[0], info=RADIO_INFO)
            submit_button = gr.Button("Generate Predictions")
            plot_output = gr.Plot(label="Predictions")
            download_button = gr.DownloadButton("Download Predictions", visible=False)

        def make_plot(replay_file, radio_option, progress=gr.Progress(track_tqdm=True)):
            """Compute (or load cached) predictions for ``replay_file`` and plot them.

            ``radio_option`` is the index into RADIO_OPTIONS. ``progress`` is unused directly;
            its presence makes Gradio mirror tqdm bars. Results are cached as CSV files in
            ``temp_dir``, keyed by replay file name and option.
            """
            if replay_file is None:
                raise gr.Error("Please upload a replay file first.")
            nullify_goal_difference = radio_option == 1
            ignore_ties = radio_option == 2

            print(f"Processing file: {replay_file}")
            replay_stem = os.path.splitext(os.path.basename(replay_file))[0]
            if nullify_goal_difference:
                postfix = "_nullify_goal_difference"
            elif ignore_ties:
                postfix = "_ignore_ties"
            else:
                postfix = ""
            preds_file = os.path.join(temp_dir, f"predictions_{replay_stem}{postfix}.csv")

            is_cached = os.path.exists(preds_file)
            if is_cached:
                print(f"Predictions file already exists: {preds_file}")
                # index_col restores the "Frame" index written by to_csv, so it doesn't
                # come back as an extra data column.
                preds = pd.read_csv(preds_file, index_col="Frame", dtype={"Touch": str})
                preds["Touch"] = preds["Touch"].fillna("")
                os.utime(preds_file)  # Mark as recently used for eviction.
            else:
                preds = infer(MODEL, replay_file, nullify_goal_difference=nullify_goal_difference,
                              ignore_ties=ignore_ties)

            fig = plot_plotly(preds)
            print(f"Plot generated for file: {replay_file}")

            if not is_cached:
                preds.to_csv(preds_file)
                _evict_oldest(temp_dir)
            return fig, gr.DownloadButton(value=preds_file, visible=True)

        submit_button.click(
            fn=make_plot,
            inputs=[file_input, checkboxes],
            outputs=[plot_output, download_button],
            show_progress="full",
        )

    demo.queue(default_concurrency_limit=None)
    demo.launch()