Spaces:
Running
Running
import os | |
from tempfile import TemporaryDirectory | |
import gradio as gr | |
import numpy as np | |
import pandas as pd | |
import spaces | |
import torch | |
from huggingface_hub import Repository | |
from rlgym_tools.rocket_league.misc.serialize import serialize_game_state, serialize_scoreboard, \ | |
SB_GAME_TIMER_SECONDS, SB_BLUE_SCORE, SB_ORANGE_SCORE | |
from rlgym_tools.rocket_league.replays.convert import replay_to_rlgym | |
from rlgym_tools.rocket_league.replays.parsed_replay import ParsedReplay | |
from tqdm import trange, tqdm | |
# Make the `carball` entry inside rlgym_tools' replays package executable
# (0o755); presumably the replay parser runs it as a binary. The path is
# resolved from the installed package instead of hard-coding a
# Python-version-specific site-packages directory.
import importlib.util

_REPLAYS_PKG_DIR = importlib.util.find_spec(
    "rlgym_tools.rocket_league.replays"
).submodule_search_locations[0]
os.chmod(os.path.join(_REPLAYS_PKG_DIR, "carball"), 0o755)

# Single device used for model weights and every inference batch.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Check out the model repository locally and pull the latest revision.
repo = Repository(local_dir="vortex-ngp", clone_from="Rolv-Arild/vortex-ngp", token=os.getenv("HF_TOKEN"))
repo.git_pull()

# TorchScript model, loaded directly onto DEVICE and switched to eval mode.
MODEL = torch.jit.load("vortex-ngp/vortex-ngp-daily-energy.pt", map_location=DEVICE)
MODEL.eval()
def _load_replay_frames(replay):
    """Convert a parsed replay to RLGym frames and stack the serialized model inputs.

    Returns:
        ``(replay_frames, states, scoreboards, seconds_remaining)``: the list of
        frames yielded by ``replay_to_rlgym`` plus three CPU tensors with one
        entry per frame (stacked ``serialize_game_state`` and
        ``serialize_scoreboard`` outputs, and each frame's
        ``episode_seconds_remaining``).

    Raises:
        ValueError: if the replay yields no frames.
    """
    replay_frames, states, scoreboards, seconds_remaining = [], [], [], []
    with tqdm(replay_to_rlgym(replay), desc="Loading replay", total=len(replay.game_df)) as frames:
        for replay_frame in frames:
            replay_frames.append(replay_frame)
            states.append(serialize_game_state(replay_frame.state))
            scoreboards.append(serialize_scoreboard(replay_frame.scoreboard))
            seconds_remaining.append(replay_frame.episode_seconds_remaining)
    if not replay_frames:
        raise ValueError("Replay contains no frames to predict on")
    return (replay_frames,
            torch.from_numpy(np.stack(states)),
            torch.from_numpy(np.stack(scoreboards)),
            torch.tensor(seconds_remaining))


def _game_clock(scoreboards, seconds_remaining):
    """Return ``(timer, is_ot)`` for every frame.

    ``timer`` starts as a copy of the scoreboard's game-timer column. Frames
    whose value exceeds 450 are flagged in ``is_ot`` and rewritten as
    ``-(seconds_remaining[first_flagged] - seconds_remaining)``, so overtime is
    encoded as a negative clock.
    """
    timer = scoreboards[:, SB_GAME_TIMER_SECONDS].clone()
    is_ot = timer > 450
    ot_time_remaining = seconds_remaining[is_ot]
    if len(ot_time_remaining) > 0:
        timer[is_ot] = -(ot_time_remaining[0] - ot_time_remaining)  # negative = overtime
    return timer, is_ot


def _run_model(model, states, scoreboards, nullify_goal_difference, ignore_ties, batch_size=900):
    """Run ``model`` over all frames in batches and return the concatenated raw outputs.

    The scoreboard input is optionally altered per batch (scores zeroed, and for
    ``ignore_ties`` the game timer set to ``inf``). Runs under ``torch.no_grad()``
    so no autograd graph is kept and the result can be converted to NumPy.
    """
    predictions = []
    with torch.no_grad(), trange(len(states), desc="Running model") as progress:
        for start in range(0, len(states), batch_size):
            batch_states = states[start:start + batch_size].to(DEVICE)
            # Clone: the in-place edits below must not leak into the caller's tensor.
            batch_scoreboards = scoreboards[start:start + batch_size].clone().to(DEVICE)
            if nullify_goal_difference or ignore_ties:
                batch_scoreboards[:, SB_BLUE_SCORE] = 0
                batch_scoreboards[:, SB_ORANGE_SCORE] = 0
            if ignore_ties:
                batch_scoreboards[:, SB_GAME_TIMER_SECONDS] = float("inf")
            predictions.append(model(batch_states, batch_scoreboards))
            progress.update(len(batch_states))
    return torch.cat(predictions, dim=0)


def _touch_column(replay, replay_frames):
    """Build one ``"Team|Name|Team|Name..."`` string per frame for cars with ``ball_touches > 0``.

    Pipes in player names are replaced by spaces so they cannot clash with the
    separator. Cars whose id has no metadata entry fall back to the id itself.
    """
    pid_to_name = {int(p["unique_id"]): p["name"]
                   for p in replay.metadata["players"]
                   if p["unique_id"] in replay.player_dfs}
    touches = []
    for replay_frame in replay_frames:
        parts = []
        for aid, car in replay_frame.state.cars.items():
            if car.ball_touches > 0:
                team = "Blue" if car.is_blue else "Orange"
                name = str(pid_to_name.get(aid, aid)).replace("|", " ")
                parts.append(f"{team}|{name}")
        touches.append("|".join(parts))
    return touches


def _renormalize_without_ties(preds, mask, drop_tie):
    """Condition team probabilities on "no tie" for the rows selected by ``mask``.

    Every ``Blue*``/``Orange*`` column in those rows is divided by ``1 - Tie``.
    Then the ``Tie`` column is either dropped (``drop_tie``) or zeroed for those
    rows. Returns the resulting frame; ``preds`` is modified in place.
    """
    if not mask.any():
        return preds
    no_tie_prob = 1 - preds.loc[mask, "Tie"]
    team_cols = [c for c in preds.columns if c.startswith(("Blue", "Orange"))]
    preds.loc[mask, team_cols] = preds.loc[mask, team_cols].div(no_tie_prob, axis=0)
    if drop_tie:
        return preds.drop(columns="Tie")
    preds.loc[mask, "Tie"] = 0.0
    return preds


def infer(model, replay_file,
          nullify_goal_difference=False,
          ignore_ties=False):
    """Predict, for every frame of a replay, which team scores next and when.

    Args:
        model: TorchScript model called as ``model(states, scoreboards)``. Its
            output is treated as 123 logits per frame: 61 Blue bins, 61 Orange
            bins (labelled 0..60 s) and a final Tie class.
        replay_file: path passed to ``ParsedReplay.load``.
        nullify_goal_difference: feed the model a 0-0 score on every frame.
        ignore_ties: also zero the scores, set the model's timer input to
            ``inf``, renormalize team probabilities without the tie outcome and
            drop the ``Tie`` column.

    Returns:
        DataFrame indexed by ``Frame`` with columns ``Timer`` (scoreboard clock;
        negative during overtime), ``Blue``, ``Orange`` (summed bin
        probabilities), ``Tie`` (absent when ``ignore_ties``), ``Goal``
        (per-frame change of Blue-minus-Orange score), ``Touch``
        (``"Team|Name"`` pairs joined by ``|``), then the per-bin columns
        ``"Blue: 0s"``, ..., ``"Orange: 60s"``.

    Raises:
        ValueError: if the replay yields no frames.
    """
    num_outputs = 123
    replay = ParsedReplay.load(replay_file)
    replay_frames, states, scoreboards, seconds_remaining = _load_replay_frames(replay)

    timer, is_ot = _game_clock(scoreboards, seconds_remaining)
    goal_diff = scoreboards[:, SB_BLUE_SCORE] - scoreboards[:, SB_ORANGE_SCORE]
    goal_diff_diff = goal_diff.diff(prepend=goal_diff.new_zeros(1))

    logits = _run_model(model, states, scoreboards, nullify_goal_difference, ignore_ties)
    probs = logits.softmax(dim=-1).cpu().numpy()

    bin_seconds = torch.linspace(0, 60, num_outputs // 2)
    class_names = [f"{t}: {s:g}s" for t in ["Blue", "Orange"] for s in bin_seconds.tolist()]
    class_names.append("Tie")

    preds = pd.DataFrame(data=probs, columns=class_names)
    preds["Blue"] = preds[[c for c in preds.columns if c.startswith("Blue")]].sum(axis=1)
    preds["Orange"] = preds[[c for c in preds.columns if c.startswith("Orange")]].sum(axis=1)
    preds["Timer"] = timer.numpy()
    preds["Goal"] = goal_diff_diff.numpy()
    preds["Touch"] = _touch_column(replay, replay_frames)

    # Overtime frames (or all frames when ignore_ties) cannot end in a tie.
    remove_ties_mask = np.ones(len(preds), dtype=bool) if ignore_ties else is_ot.numpy()
    preds = _renormalize_without_ties(preds, remove_ties_mask, drop_tie=ignore_ties)

    # Summary columns first, per-bin columns after.
    main_cols = [c for c in ("Timer", "Blue", "Orange", "Tie", "Goal", "Touch") if c in preds.columns]
    preds = preds[main_cols + [c for c in preds.columns if c not in main_cols]]
    preds.index.name = "Frame"
    return preds
def format_timer(t):
    """Format a clock value in seconds as ``M:SS``.

    Negative values (overtime, as produced by ``infer``) get a ``+`` prefix.
    The total is rounded before it is split into minutes and seconds, so 59.6
    renders as ``1:00`` instead of ``0:60``. Non-finite values render as ``"-"``.
    """
    if not np.isfinite(t):
        return "-"
    sign = "+" if t < 0 else ""
    minutes, seconds = divmod(int(round(abs(t))), 60)
    return f"{sign}{minutes}:{seconds:02d}"


def plot_plotly(preds: pd.DataFrame):
    """Build an interactive Plotly figure of next-goal probabilities.

    Args:
        preds: frame-indexed DataFrame as returned by ``infer``. It has
            ``Timer`` (seconds, negative in overtime), ``Blue``, ``Orange``,
            optional ``Tie``, ``Goal``, ``Touch`` and per-bin columns.

    Returns:
        A ``plotly.graph_objects.Figure`` containing:
        - probability lines in percent;
        - a dashed 50% line;
        - dashed red vertical lines at goals, annotated with the running score;
        - per-team touch markers;
        - a top axis labelled with the game clock.
    """
    import plotly.graph_objects as go

    preds_df = preds.drop(["Touch", "Timer", "Goal"], axis=1) * 100
    timer = preds["Timer"]
    timer_text = [format_timer(float(t)) for t in timer.to_numpy()]
    timer_by_frame = dict(zip(preds.index, timer_text))
    hovertemplate = '<b>Frame %{x}</b><br>Prob: %{y:.3g}%<br>Timer: %{customdata}<extra></extra>'

    fig = go.Figure()

    # Probability lines for every outcome present (Tie is absent with "Ignore ties").
    for outcome, color in (("Blue", "blue"), ("Orange", "orange"), ("Tie", "gray")):
        if outcome not in preds_df.columns:
            continue
        fig.add_trace(
            go.Scatter(x=preds_df.index, y=preds_df[outcome],
                       mode='lines', name=outcome, line=dict(color=color),
                       customdata=timer_text, hovertemplate=hovertemplate))

    fig.add_hline(y=50, line_dash="dash", line_color="black", name="50% Probability")

    # Goal indicators annotated with the running score (blue-orange).
    blue_goals = orange_goals = 0
    goals = preds["Goal"]
    for goal_frame, delta in goals[goals != 0].items():
        if delta > 0:
            blue_goals += 1
        elif delta < 0:
            orange_goals += 1
        fig.add_vline(x=goal_frame, line_dash="dash", line_color="red",
                      annotation_text=f"{blue_goals}-{orange_goals}", annotation_position="top right")

    # Touch indicators: "Team|Name|Team|Name..." strings split into per-team points.
    touches = {}
    touch_col = preds["Touch"].fillna("")
    for touch_frame, touch_str in touch_col[touch_col != ""].items():
        fields = touch_str.split("|")
        for team, player in zip(fields[::2], fields[1::2]):
            touches.setdefault(team.strip(), []).append((touch_frame, player.strip()))
    for team in ("Blue", "Orange"):
        team_touches = touches.get(team)
        if not team_touches:
            continue
        frames = [frame for frame, _ in team_touches]
        fig.add_trace(
            go.Scatter(x=frames,
                       y=[preds_df.at[frame, team] for frame in frames],
                       mode='markers',
                       name=f'{team} touches',
                       marker=dict(size=5, color=team.lower(), symbol='circle-open-dot'),
                       customdata=[f"{timer_by_frame[frame]}<br>Touch by {player}"
                                   for frame, player in team_touches],
                       hovertemplate=hovertemplate))

    # Ten evenly spaced clock labels for the top axis. Timer is already in
    # seconds, exactly as used for the hover text, so no rescaling is applied.
    n_frames = len(preds_df)
    if n_frames > 0:
        tick_positions = np.linspace(0, n_frames - 1, 10)
        tick_labels = [format_timer(float(timer.iloc[int(pos)])) for pos in tick_positions]
    else:
        tick_positions, tick_labels = [], []

    fig.update_layout(
        title="Interactive Probability Plot",
        xaxis=dict(
            title="Frame",
            gridcolor='#e5e7eb'  # A light gray grid for a modern look
        ),
        yaxis=dict(
            title="Probability",
            gridcolor='#e5e7eb'
        ),
        # Secondary (top) x-axis: no traces live on it, so `matches` ties its
        # range to the frame axis; otherwise the ticks would not line up.
        xaxis2=dict(
            title="Timer",
            overlaying='x',
            matches='x',
            side='top',
            tickmode='array',
            tickvals=tick_positions,
            ticktext=tick_labels
        ),
        legend=dict(x=0.01, y=0.99, yanchor="top", xanchor="left"),  # Position legend inside plot
        plot_bgcolor='white'  # A clean white background
    )
    return fig
# Markdown shown at the top of the app.
DESCRIPTION = """
# Next Goal Predictor
Upload a replay file to get a plot of the next goal prediction.
The model is trained on about 14 000 hours of SSL and RLCS replays in 1v1, 2v2, and 3v3 using [this dataset](https://www.kaggle.com/datasets/rolvarild/high-level-rocket-league-replay-dataset).<br>
It predicts the probability that each team will score at 1 second intervals up to 60+ seconds.
It also predicts ties (ball hitting the ground at 0s)<br>
The plot only shows the totals for each team, but you can download the full predictions if you want.
""".strip()

# Order matters: the Radio uses type="index", so the handler receives the
# position in this list (0 = default, 1 = nullify goal difference, 2 = ignore ties).
RADIO_OPTIONS = ["Default", "Nullify goal difference", "Ignore ties"]

# Markdown help text shown under the options radio.
RADIO_INFO = """
- **Default**: Uses the model as it is trained, with no modifications.
- **Nullify goal difference**: Makes the model think the goal difference is always 0, so it doesn't have a bias towards one team.
- **Ignore ties**: Makes the model pretend every situation is an overtime (i.e. ties are impossible).
""".strip()
import hashlib

# Maximum number of cached prediction CSVs kept in the temporary directory.
MAX_CACHED_PREDICTIONS = 100

with TemporaryDirectory() as temp_dir:
    with gr.Blocks() as demo:
        gr.Markdown(DESCRIPTION)
        # Use gr.Column to stack components vertically
        with gr.Column():
            file_input = gr.File(label="Upload Replay File", type="filepath", file_types=[".replay"])
            checkboxes = gr.Radio(label="Options", choices=RADIO_OPTIONS, type="index", value=RADIO_OPTIONS[0],
                                  info=RADIO_INFO)
            submit_button = gr.Button("Generate Predictions")
            plot_output = gr.Plot(label="Predictions")
            download_button = gr.DownloadButton("Download Predictions", visible=False)

        def _evict_old_predictions():
            """Best-effort: keep at most MAX_CACHED_PREDICTIONS files in temp_dir, deleting the oldest by mtime."""
            try:
                paths = [os.path.join(temp_dir, name) for name in os.listdir(temp_dir)]
                paths.sort(key=os.path.getmtime)
                for path in paths[:max(0, len(paths) - MAX_CACHED_PREDICTIONS)]:
                    os.remove(path)
            except FileNotFoundError:
                # A concurrent request removed a file first; skip eviction this round.
                pass

        def make_plot(replay_file, radio_option, progress=gr.Progress(track_tqdm=True)):
            """Handle a "Generate Predictions" click.

            Args:
                replay_file: local path of the uploaded replay (``None`` when
                    nothing was uploaded).
                radio_option: index into ``RADIO_OPTIONS`` (1 = nullify goal
                    difference, 2 = ignore ties, otherwise default).
                progress: Gradio progress tracker; ``track_tqdm=True`` mirrors
                    the tqdm bars created inside ``infer`` in the UI.

            Returns:
                ``(figure, DownloadButton)``: the plot and a visible button
                pointing at the predictions CSV.

            Raises:
                gr.Error: if no replay file was uploaded.
            """
            if replay_file is None:
                raise gr.Error("Please upload a replay file first.")
            nullify_goal_difference = radio_option == 1
            ignore_ties = radio_option == 2
            print(f"Processing file: {replay_file}")

            replay_stem = os.path.splitext(os.path.basename(replay_file))[0]
            if nullify_goal_difference:
                postfix = "_nullify_goal_difference"
            elif ignore_ties:
                postfix = "_ignore_ties"
            else:
                postfix = ""
            # Include a content hash so different replays that share a file
            # name never reuse each other's cached predictions.
            with open(replay_file, "rb") as f:
                digest = hashlib.sha256(f.read()).hexdigest()[:16]
            preds_file = os.path.join(temp_dir, f"predictions_{replay_stem}_{digest}{postfix}.csv")

            is_cached = os.path.exists(preds_file)
            if is_cached:
                print(f"Predictions file already exists: {preds_file}")
                # to_csv wrote the "Frame" index as the first column; restore it as the index
                # so it is not re-saved as an extra data column.
                preds = pd.read_csv(preds_file, index_col="Frame", dtype={"Touch": str})
                preds["Touch"] = preds["Touch"].fillna("")
                os.utime(preds_file)  # mark as recently used for eviction
            else:
                preds = infer(MODEL, replay_file,
                              nullify_goal_difference=nullify_goal_difference,
                              ignore_ties=ignore_ties)

            fig = plot_plotly(preds)
            print(f"Plot generated for file: {replay_file}")

            if not is_cached:
                preds.to_csv(preds_file)
                _evict_old_predictions()
            return fig, gr.DownloadButton(value=preds_file, visible=True)

        submit_button.click(
            fn=make_plot,
            inputs=[file_input, checkboxes],
            outputs=[plot_output, download_button],
            show_progress="full",
        )

    demo.queue(default_concurrency_limit=None)
    demo.launch()