""" Gradio demo – visualise benchmark accuracy curves. Required CSV files (place in the *same* folder as app.py): ├── aggregated_accuracy.csv ├── qa_accuracy.csv ├── ocr_accuracy.csv └── temporal_accuracy.csv Each file has the columns Model,,,… where the context‑length headers are strings such as `30min`, `60min`, `120min`, … No further cleaning / renaming is done apart from two cosmetic replacements (“gpt4.1” → “ChatGPT 4.1”, “gemini2.5pro” → “Gemini 2.5 Pro”). """ from pathlib import Path import pandas as pd import plotly.graph_objects as go import gradio as gr import math # --------------------------------------------------------------------- # # Config # # --------------------------------------------------------------------- # FILES = { "aggregated": "aggregated_accuracy.csv", "qa": "qa_accuracy.csv", "ocr": "ocr_accuracy.csv", "temporal": "temporal_accuracy.csv", } # Mapping of internal benchmark keys to nicely formatted display labels DISPLAY_LABELS = { "aggregated": "Aggregated", "qa": "QA", "ocr": "OCR", "temporal": "Temporal", } # Optional: choose which models are selected by default for each benchmark. # Use the *display names* exactly as they appear in the Models list. # If a benchmark is missing, it falls back to the first six models. DEFAULT_MODELS: dict[str, list[str]] = { "aggregated": [ "Gemini 2.5 Pro", "ChatGPT 4.1", "Qwen2.5-VL-7B", "InternVL2.5-8B", "LLaMA-3.2-11B-Vision", ], } RENAME = { r"gpt4\.1": "ChatGPT 4.1", r"Gemini\s2\.5\spro": "Gemini 2.5 Pro", r"LLaMA-3\.2B-11B": "LLaMA-3.2-11B-Vision", } # --------------------------------------------------------------------- # # Data loading # # --------------------------------------------------------------------- # def _read_csv(path: str | Path) -> pd.DataFrame: df = pd.read_csv(path) df["Model"] = df["Model"].replace(RENAME, regex=True).astype(str) return df dfs: dict[str, pd.DataFrame] = {name: _read_csv(path) for name, path in FILES.items()} # --------------------------------------------------------------------- # # Colour palette and model metadata # # --------------------------------------------------------------------- # import plotly.express as px SAFE_PALETTE = px.colors.qualitative.Safe # colour-blind-safe qualitative palette (10 colours) # Deterministic list of all unique model names to ensure consistent colour mapping ALL_MODELS: list[str] = sorted({m for df in dfs.values() for m in df["Model"].unique()}) MARKER_SYMBOLS = [ "circle", "square", "triangle-up", "diamond", "cross", "triangle-down", "x", "triangle-right", "triangle-left", "pentagon", ] TIME_COLS = [c for c in dfs["aggregated"].columns if c.lower() != "model"] def _pretty_time(label: str) -> str: """‘30min’ → ‘30min’; ‘120min’ → ‘2hr’; keeps original if no match.""" if label.endswith("min"): minutes = int(label[:-3]) if minutes >= 60: hours = minutes / 60 return f"{hours:.0f}hr" if hours.is_integer() else f"{hours:.1f}hr" return label TIME_LABELS = {c: _pretty_time(c) for c in TIME_COLS} # --------------------------------------------------------------------- # # Plotting # # --------------------------------------------------------------------- # def render_chart( benchmark: str, models: list[str], log_scale: bool, ) -> go.Figure: bench_key = benchmark.lower() df = dfs[bench_key] fig = go.Figure() # Define colour and marker based on deterministic mapping palette = SAFE_PALETTE # Determine minimum non-zero Y value across selected models for log scaling min_y_val = None for idx, m in enumerate(models): row = df.loc[df["Model"] == m] if row.empty: 
            continue

        y = row[TIME_COLS].values.flatten()
        y = [val if val != 0 else None for val in y]  # show gaps for 0 / missing

        # Track minimum non-zero accuracy
        y_non_none = [val for val in y if val is not None]
        if y_non_none:
            cur_min = min(y_non_none)
            if min_y_val is None or cur_min < min_y_val:
                min_y_val = cur_min

        model_idx = ALL_MODELS.index(m) if m in ALL_MODELS else idx
        color = palette[model_idx % len(palette)]
        symbol = MARKER_SYMBOLS[model_idx % len(MARKER_SYMBOLS)]

        fig.add_trace(
            go.Scatter(
                x=[TIME_LABELS[c] for c in TIME_COLS],
                y=y,
                mode="lines+markers",
                name=m,
                line=dict(width=3, color=color),
                marker=dict(size=6, color=color, symbol=symbol),
                connectgaps=False,
            )
        )

    # Set Y-axis properties
    if log_scale:
        # Fall back to 0.1 if there are no valid points
        if min_y_val is None or min_y_val <= 0:
            min_y_val = 0.1
        # Plotly expects log10 values for the range when the axis type is "log"
        yaxis_range = [math.floor(math.log10(min_y_val)), 2]  # max at 10^2 = 100
        yaxis_type = "log"
    else:
        yaxis_range = [0, 100]
        yaxis_type = "linear"

    fig.update_layout(
        title=f"{DISPLAY_LABELS.get(bench_key, bench_key.capitalize())} Accuracy Over Time",
        xaxis_title="Video Duration",
        yaxis_title="Accuracy (%)",
        yaxis_type=yaxis_type,
        yaxis_range=yaxis_range,
        legend_title="Model",
        legend=dict(
            orientation="h",
            y=-0.25,
            x=0.5,
            xanchor="center",
            tracegroupgap=8,
            itemwidth=60,
        ),
        margin=dict(t=40, r=20, b=80, l=60),
        template="plotly_dark",
        font=dict(family="Inter,Helvetica,Arial,sans-serif", size=14),
        title_font=dict(size=20, family="Inter,Helvetica,Arial,sans-serif", color="white"),
        xaxis=dict(gridcolor="rgba(255,255,255,0.15)"),
        yaxis=dict(gridcolor="rgba(255,255,255,0.15)"),
        hoverlabel=dict(bgcolor="#1e1e1e", font_color="#eeeeee", bordercolor="#888"),
    )
    return fig


# --------------------------------------------------------------------- #
#  UI                                                                    #
# --------------------------------------------------------------------- #
CSS = """
#controls { padding: 8px 12px; }
.scrollbox { max-height: 300px; overflow-y: auto; }
body, .gradio-container { font-family: 'Inter', 'Helvetica', sans-serif; }
.gradio-container h1, .gradio-container h2 { font-weight: 600; }
#controls, .scrollbox { background: rgba(255,255,255,0.02); border-radius: 6px; }
input[type="checkbox"]:checked { accent-color: #FF715E; }
"""


def available_models(bench: str) -> list[str]:
    return sorted(dfs[bench]["Model"].unique())


def default_models(bench: str) -> list[str]:
    """Return the list of default-selected models for a benchmark."""
    opts = available_models(bench)
    configured = DEFAULT_MODELS.get(bench, [])
    # Keep only the configured defaults that exist for this benchmark
    valid = [m for m in configured if m in opts]
    if not valid:
        # Fall back to the first six models
        valid = opts[:6]
    return valid


with gr.Blocks(theme=gr.themes.Base(), css=CSS) as demo:
    gr.Markdown(
        """
# 📈 TimeScope
How long can your video model keep up?
""" ) # ---- top controls row ---- # with gr.Row(): benchmark_dd = gr.Dropdown( label="Type", choices=list(DISPLAY_LABELS.values()), value=DISPLAY_LABELS["aggregated"], scale=1, ) log_cb = gr.Checkbox( label="Log-scale Y-axis", value=False, scale=1, ) # ---- models list and plot ---- # plot_out = gr.Plot( render_chart("Aggregated", default_models("aggregated"), False) ) models_cb = gr.CheckboxGroup( label="Models", choices=available_models("aggregated"), value=default_models("aggregated"), interactive=True, elem_classes=["scrollbox"], ) # ‑-- dynamic callbacks ‑-- # def _update_models(bench: str): bench_key = bench.lower() opts = available_models(bench_key) defaults = default_models(bench_key) # Use generic gr.update for compatibility across Gradio versions return gr.update(choices=opts, value=defaults) benchmark_dd.change( fn=_update_models, inputs=benchmark_dd, outputs=models_cb, queue=False, ) for ctrl in (benchmark_dd, models_cb, log_cb): ctrl.change( fn=render_chart, inputs=[benchmark_dd, models_cb, log_cb], outputs=plot_out, queue=False, ) # Make legend interaction clearer: click to toggle traces demo.launch(share=True)