# Author: Rolv-Arild
# Last change: "Update app.py"
# Commit: d33fd9c (verified)
import hashlib
import os
from tempfile import TemporaryDirectory

import gradio as gr
import numpy as np
import pandas as pd
import spaces
import torch
from huggingface_hub import Repository
from rlgym_tools.rocket_league.misc.serialize import (
    SB_BLUE_SCORE,
    SB_GAME_TIMER_SECONDS,
    SB_ORANGE_SCORE,
    serialize_game_state,
    serialize_scoreboard,
)
from rlgym_tools.rocket_league.replays import parsed_replay as parsed_replay_module
from rlgym_tools.rocket_league.replays.convert import replay_to_rlgym
from rlgym_tools.rocket_league.replays.parsed_replay import ParsedReplay
from tqdm import tqdm, trange
# rlgym_tools ships a ``carball`` entry next to its replay modules, and it is
# chmod-ed to 0o755 here (presumably so it can be executed during replay
# parsing; confirm). Resolve it from the installed package instead of a
# hard-coded ``/usr/local/lib/python3.10/site-packages`` path, so the app keeps
# working under other Python versions or install prefixes.
CARBALL_PATH = os.path.join(
    os.path.dirname(os.path.abspath(parsed_replay_module.__file__)), "carball"
)
os.chmod(CARBALL_PATH, 0o755)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Clone the model repository into ./vortex-ngp (using HF_TOKEN from the
# environment, if set) and pull the latest revision.
repo = Repository(local_dir="vortex-ngp", clone_from="Rolv-Arild/vortex-ngp", token=os.getenv("HF_TOKEN"))
repo.git_pull()

# TorchScript model, loaded directly onto DEVICE and switched to eval mode.
MODEL = torch.jit.load("vortex-ngp/vortex-ngp-daily-energy.pt", map_location=DEVICE)
MODEL.eval()
@spaces.GPU
@torch.inference_mode()
def infer(model, replay_file,
          nullify_goal_difference=False,
          ignore_ties=False):
    """Run the next-goal model over every frame of a replay.

    Args:
        model: TorchScript module called as ``model(states, scoreboards)``.
            Assumed to return one row of 123 logits per frame (softmaxed over
            the last dimension below); confirm against the exported model.
        replay_file: Path to a replay, loaded with ``ParsedReplay.load``.
        nullify_goal_difference: Zero both teams' scores before inference.
        ignore_ties: Also zero the scores and set the game timer to infinity
            before inference. Then renormalise the team probabilities of every
            frame without the tie class and drop the ``Tie`` column.

    Returns:
        A DataFrame indexed by ``Frame`` with these columns first:
        ``Timer``: scoreboard seconds. On overtime frames (scoreboard timer
            > 450) it is replaced by the negated difference from the first
            overtime frame's ``episode_seconds_remaining``.
        ``Blue`` / ``Orange``: summed per-team probabilities.
        ``Tie``: tie probability; absent when ``ignore_ties``.
        ``Goal``: per-frame change in blue-minus-orange score.
        ``Touch``: ``"Team|Name"`` pairs joined by ``|``, or ``""``.
        All per-bin class probabilities follow.
    """
    num_outputs = 123  # 61 Blue bins + 61 Orange bins + 1 Tie class
    num_bins = num_outputs // 2

    # --- Load the replay and serialize every frame --------------------------
    replay = ParsedReplay.load(replay_file)
    replay_frames = []
    serialized_states = []
    serialized_scoreboards = []
    seconds_remaining = []
    with tqdm(replay_to_rlgym(replay), desc="Loading replay", total=len(replay.game_df)) as frames:
        for replay_frame in frames:
            replay_frames.append(replay_frame)
            serialized_states.append(serialize_game_state(replay_frame.state))
            serialized_scoreboards.append(serialize_scoreboard(replay_frame.scoreboard))
            seconds_remaining.append(replay_frame.episode_seconds_remaining)
    serialized_states = torch.from_numpy(np.stack(serialized_states))
    serialized_scoreboards = torch.from_numpy(np.stack(serialized_scoreboards))
    seconds_remaining = torch.tensor(seconds_remaining)

    # --- Game clock and goal events ------------------------------------------
    timer = serialized_scoreboards[:, SB_GAME_TIMER_SECONDS].clone()
    is_ot = timer > 450  # scoreboard timer values above 450 are treated as overtime
    ot_time_remaining = seconds_remaining[is_ot]
    if len(ot_time_remaining) > 0:
        ot_timer = ot_time_remaining[0] - ot_time_remaining
        timer[is_ot] = -ot_timer  # Negate to indicate overtime
    goal_diff = serialized_scoreboards[:, SB_BLUE_SCORE] - serialized_scoreboards[:, SB_ORANGE_SCORE]
    # Prepend a zero of the same dtype so frame 0 gets a defined delta.
    goal_diff_diff = goal_diff.diff(prepend=torch.zeros(1, dtype=goal_diff.dtype))

    # --- Batched model inference --------------------------------------------
    batch_size = 900
    predictions = []
    with trange(len(serialized_states), desc="Running model") as pbar:
        for start in range(0, len(serialized_states), batch_size):
            # Clone so the in-place edits below never touch the source tensors.
            states = serialized_states[start:start + batch_size].clone().to(DEVICE)
            scoreboards = serialized_scoreboards[start:start + batch_size].clone().to(DEVICE)
            if nullify_goal_difference or ignore_ties:
                scoreboards[:, SB_BLUE_SCORE] = 0
                scoreboards[:, SB_ORANGE_SCORE] = 0
            if ignore_ties:
                scoreboards[:, SB_GAME_TIMER_SECONDS] = float("inf")
            predictions.append(model(states, scoreboards))
            pbar.update(len(states))
    probs = torch.cat(predictions, dim=0).softmax(dim=-1)

    # --- Assemble the DataFrame ----------------------------------------------
    bin_seconds = torch.linspace(0, 60, num_bins)
    class_names = [f"{team}: {s:g}s" for team in ("Blue", "Orange") for s in bin_seconds.tolist()]
    class_names.append("Tie")

    preds = pd.DataFrame(data=probs.cpu().numpy(), columns=class_names)
    preds["Blue"] = preds[[c for c in preds.columns if c.startswith("Blue")]].sum(axis=1)
    preds["Orange"] = preds[[c for c in preds.columns if c.startswith("Orange")]].sum(axis=1)
    preds["Timer"] = timer.numpy()
    preds["Goal"] = goal_diff_diff.numpy()

    # Ball touches: one "Team|Name" pair per touching car, pairs joined by "|".
    pid_to_name = {int(p["unique_id"]): p["name"]
                   for p in replay.metadata["players"]
                   if p["unique_id"] in replay.player_dfs}
    touch_column = []
    for replay_frame in replay_frames:
        frame_touches = []
        for aid, car in replay_frame.state.cars.items():
            if car.ball_touches > 0:
                team = "Blue" if car.is_blue else "Orange"
                # Replace pipe with space to not conflict with sep
                name = pid_to_name[aid].replace("|", " ")
                frame_touches.append(f"{team}|{name}")
        touch_column.append("|".join(frame_touches))
    preds["Touch"] = touch_column

    # Put summary columns first, then the per-bin probabilities.
    main_cols = ["Timer", "Blue", "Orange", "Tie", "Goal", "Touch"]
    preds = preds[main_cols + [c for c in preds.columns if c not in main_cols]]
    preds.index.name = "Frame"

    # Ties are impossible in overtime (or everywhere, with ignore_ties), so move
    # the tie mass back onto the team columns by renormalising them.
    remove_ties_mask = np.ones(len(preds), dtype=bool) if ignore_ties else is_ot.numpy()
    if remove_ties_mask.any():
        q = 1 - preds.loc[remove_ties_mask, "Tie"]
        team_cols = [c for c in preds.columns if c.startswith(("Blue", "Orange"))]
        preds.loc[remove_ties_mask, team_cols] = preds.loc[remove_ties_mask, team_cols].div(q, axis=0)
        if ignore_ties:
            preds = preds.drop("Tie", axis=1)
        else:
            preds.loc[remove_ties_mask, "Tie"] = 0.0
    return preds
def plot_plotly(preds: pd.DataFrame):
    """Build an interactive Plotly figure of per-frame next-goal probabilities.

    Args:
        preds: Predictions as returned by ``infer``. The index must be a
            0-based positional frame index, because index labels are used to
            look up the per-frame timer text. Required columns:
            ``Timer``: seconds; negative values are rendered with a ``+``
                prefix to mark overtime.
            ``Blue`` and ``Orange``: probabilities.
            ``Goal``: non-zero on frames where the score difference changes;
                positive counts as a Blue goal, negative as Orange.
            ``Touch``: ``"Team|Name"`` pairs joined by ``|``, or ``""``.
            ``Tie`` is plotted when present. Every column except
            Touch/Timer/Goal is multiplied by 100 and shown as a percentage.

    Returns:
        A ``plotly.graph_objects.Figure`` containing:
        - one line per outcome;
        - a dashed 50 % reference line;
        - a vertical marker per goal, annotated with the running score;
        - per-team touch markers;
        - a secondary top x-axis labelled with the game clock.
    """
    import plotly.graph_objects as go

    def format_timer(t):
        # Floor to whole seconds via divmod. Rounding the seconds field (as a
        # ".0f" format would) can yield labels like "1:60" for 119.7 s.
        sign = "+" if t < 0 else ""
        minutes, seconds = divmod(int(abs(t)), 60)
        return f"{sign}{minutes}:{seconds:02d}"

    preds_df = preds.drop(["Touch", "Timer", "Goal"], axis=1) * 100
    timer = preds["Timer"]
    timer_text = [format_timer(t) for t in timer.tolist()]
    hovertemplate = '<b>Frame %{x}</b><br>Prob: %{y:.3g}%<br>Timer: %{customdata}<extra></extra>'

    fig = go.Figure()

    # Probability lines for each outcome.
    outcome_colors = [("Blue", "blue"), ("Orange", "orange")]
    if "Tie" in preds.columns:
        outcome_colors.append(("Tie", "gray"))
    for column, color in outcome_colors:
        fig.add_trace(
            go.Scatter(x=preds_df.index, y=preds_df[column],
                       mode='lines', name=column, line=dict(color=color),
                       customdata=timer_text, hovertemplate=hovertemplate))

    fig.add_hline(y=50, line_dash="dash", line_color="black", name="50% Probability")

    # Vertical goal markers annotated with the running score.
    blue_goals = orange_goals = 0
    goals = preds["Goal"]
    for goal_frame in goals.index[goals != 0]:
        if goals[goal_frame] > 0:
            blue_goals += 1
        elif goals[goal_frame] < 0:
            orange_goals += 1
        fig.add_vline(x=goal_frame, line_dash="dash", line_color="red",
                      annotation_text=f"{blue_goals}-{orange_goals}", annotation_position="top right")

    # Touch markers, grouped by team and placed on that team's probability line.
    touches = {}
    for touch_frame in preds.index[preds["Touch"] != ""]:
        parts = preds.at[touch_frame, "Touch"].split("|")
        for team, player in zip(parts[::2], parts[1::2]):
            touches.setdefault(team.strip(), []).append((touch_frame, player.strip()))
    for team in ("Blue", "Orange"):
        team_touches = touches.get(team, [])
        if not team_touches:
            continue
        x = [frame for frame, _ in team_touches]
        y = [preds_df.at[frame, team] for frame in x]
        custom_data = [f"{timer_text[frame]}<br>Touch by {player}" for frame, player in team_touches]
        fig.add_trace(
            go.Scatter(x=x, y=y,
                       mode='markers',
                       name=f'{team} touches',
                       marker=dict(size=5, color=team.lower(), symbol='circle-open-dot'),
                       customdata=custom_data,
                       hovertemplate=hovertemplate))

    # Secondary (top) axis: 10 evenly spaced frame positions labelled with the
    # game clock. "Timer" is already in seconds, so the labels use the same
    # formatter as the hover text, with no rescaling.
    last_frame = len(timer) - 1
    tick_positions = np.linspace(0, len(preds_df) - 1, 10)
    tick_labels = [timer_text[max(0, min(int(pos), last_frame))] for pos in tick_positions]

    fig.update_layout(
        title="Interactive Probability Plot",
        xaxis=dict(
            title="Frame",
            gridcolor='#e5e7eb',  # light gray grid
        ),
        yaxis=dict(
            title="Probability",
            gridcolor='#e5e7eb',
        ),
        xaxis2=dict(
            title="Timer",
            overlaying='x',  # draw on top of the primary x-axis
            side='top',
            # No trace is attached to x2, so share x's range explicitly;
            # otherwise the frame-based tickvals would not line up with the data.
            matches='x',
            tickmode='array',
            tickvals=tick_positions,
            ticktext=tick_labels,
        ),
        legend=dict(x=0.01, y=0.99, yanchor="top", xanchor="left"),  # legend inside plot
        plot_bgcolor='white',
    )
    return fig
# Markdown shown at the top of the Gradio page.
DESCRIPTION = """
# Next Goal Predictor
Upload a replay file to get a plot of the next goal prediction.
The model is trained on about 14 000 hours of SSL and RLCS replays in 1v1, 2v2, and 3v3 using [this dataset](https://www.kaggle.com/datasets/rolvarild/high-level-rocket-league-replay-dataset).<br>
It predicts the probability that each team will score at 1 second intervals up to 60+ seconds.
It also predicts ties (ball hitting the ground at 0s)<br>
The plot only shows the totals for each team, but you can download the full predictions if you want.
""".strip()

# Inference options. The Radio widget uses type="index", so the handler
# receives the position of the chosen entry in this list (0, 1 or 2).
RADIO_OPTIONS = ["Default", "Nullify goal difference", "Ignore ties"]
RADIO_INFO = """
- **Default**: Uses the model as it is trained, with no modifications.
- **Nullify goal difference**: Makes the model think the goal difference is always 0, so it doesn't have a bias towards one team.
- **Ignore ties**: Makes the model pretend every situation is an overtime (i.e. ties are impossible).
""".strip()
def _evict_stale_predictions(directory, keep, max_files=100):
    """Delete least recently modified files in *directory* until at most
    *max_files* remain. The file *keep* (the one about to be served) is never
    deleted."""
    paths = [os.path.join(directory, name) for name in os.listdir(directory)]
    excess = len(paths) - max_files
    if excess <= 0:
        return
    candidates = sorted((p for p in paths if p != keep), key=os.path.getmtime)
    for path in candidates[:excess]:
        try:
            os.remove(path)
        except FileNotFoundError:
            pass  # already gone, e.g. removed by another request


with TemporaryDirectory() as temp_dir:
    with gr.Blocks() as demo:
        gr.Markdown(DESCRIPTION)
        # Use gr.Column to stack components vertically
        with gr.Column():
            file_input = gr.File(label="Upload Replay File", type="filepath", file_types=[".replay"])
            checkboxes = gr.Radio(label="Options", choices=RADIO_OPTIONS, type="index", value=RADIO_OPTIONS[0],
                                  info=RADIO_INFO)
            submit_button = gr.Button("Generate Predictions")
            plot_output = gr.Plot(label="Predictions")
            download_button = gr.DownloadButton("Download Predictions", visible=False)

        def make_plot(replay_file, radio_option, progress=gr.Progress(track_tqdm=True)):
            """Run (or reuse cached) predictions for *replay_file* and plot them.

            ``radio_option`` is the index into RADIO_OPTIONS.
            ``progress`` is unused in the body; with ``track_tqdm=True``, Gradio
            mirrors the tqdm bars from ``infer`` in the UI.

            Returns the Plotly figure and a visible DownloadButton that points at
            the predictions CSV in ``temp_dir``.
            """
            if replay_file is None:
                raise gr.Error("Please upload a replay file first.")
            nullify_goal_difference = radio_option == 1
            ignore_ties = radio_option == 2
            print(f"Processing file: {replay_file}")

            # Key the cache on the file's contents as well as its name, so two
            # different uploads that share a file name never share predictions.
            with open(replay_file, "rb") as f:
                content_hash = hashlib.sha256(f.read()).hexdigest()[:16]
            replay_stem = os.path.splitext(os.path.basename(replay_file))[0]
            postfix = ""
            if nullify_goal_difference:
                postfix += "_nullify_goal_difference"
            elif ignore_ties:
                postfix += "_ignore_ties"
            preds_file = os.path.join(temp_dir, f"predictions_{replay_stem}_{content_hash}{postfix}.csv")

            if os.path.exists(preds_file):
                print(f"Predictions file already exists: {preds_file}")
                # Restore the "Frame" index that to_csv wrote, so it is not read
                # back as a data column.
                preds = pd.read_csv(preds_file, index_col="Frame", dtype={"Touch": str})
                preds["Touch"] = preds["Touch"].fillna("")
                os.utime(preds_file)  # mark as recently used for eviction
            else:
                preds = infer(MODEL, replay_file,
                              nullify_goal_difference=nullify_goal_difference,
                              ignore_ties=ignore_ties)
                preds.to_csv(preds_file)

            fig = plot_plotly(preds)
            print(f"Plot generated for file: {replay_file}")
            _evict_stale_predictions(temp_dir, keep=preds_file)
            return fig, gr.DownloadButton(value=preds_file, visible=True)

        submit_button.click(
            fn=make_plot,
            inputs=[file_input, checkboxes],
            outputs=[plot_output, download_button],
            show_progress="full",
        )
    demo.queue(default_concurrency_limit=None)
    demo.launch()