File size: 9,306 Bytes
4dca8ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
"""
Gradio demo – visualise benchmark accuracy curves.

Required CSV files (place in the *same* folder as app.py):

    ├── aggregated_accuracy.csv
    ├── qa_accuracy.csv
    ├── ocr_accuracy.csv
    └── temporal_accuracy.csv

Each file has the columns

    Model,<context‑length‑1>,<context‑length‑2>,…

where the context‑length headers are strings such as `30min`, `60min`, `120min`, …

No further cleaning / renaming is done apart from a few cosmetic replacements
(“gpt4.1” → “ChatGPT 4.1”, “Gemini 2.5 pro” → “Gemini 2.5 Pro”,
“LLaMA-3.2B-11B” → “LLaMA-3.2-11B-Vision”).
"""

from pathlib import Path

import pandas as pd
import plotly.graph_objects as go
import gradio as gr
import math

# --------------------------------------------------------------------- #
# Config                                                                #
# --------------------------------------------------------------------- #

# Benchmark id → CSV filename; the files are expected next to app.py.
FILES = {
    "aggregated": "aggregated_accuracy.csv",
    "qa": "qa_accuracy.csv",
    "ocr": "ocr_accuracy.csv",
    "temporal": "temporal_accuracy.csv",
}

# Mapping of internal benchmark keys to nicely formatted display labels
# (these display labels are what the benchmark dropdown shows).
DISPLAY_LABELS = {
    "aggregated": "Aggregated",
    "qa": "QA",
    "ocr": "OCR",
    "temporal": "Temporal",
}

# Optional: choose which models are selected by default for each benchmark.
# Use the *display names* exactly as they appear in the Models list.
# If a benchmark is missing, it falls back to the first six models.
DEFAULT_MODELS: dict[str, list[str]] = {
    "aggregated": [
        "Gemini 2.5 Pro",
        "ChatGPT 4.1",
        "Qwen2.5-VL-7B",
        "InternVL2.5-8B",
        "LLaMA-3.2-11B-Vision",
    ],
}

# Regex pattern → replacement pairs applied (case-sensitively) to the raw
# ``Model`` column via ``Series.replace(..., regex=True)`` in ``_read_csv``.
# NOTE(review): the module docstring mentions "gemini2.5pro", but the pattern
# below matches "Gemini 2.5 pro" (capitalised, space-separated) — confirm
# against the actual CSV contents.
RENAME = {
    r"gpt4\.1":       "ChatGPT 4.1",
    r"Gemini\s2\.5\spro": "Gemini 2.5 Pro",
    r"LLaMA-3\.2B-11B": "LLaMA-3.2-11B-Vision",
}

# --------------------------------------------------------------------- #
# Data loading                                                          #
# --------------------------------------------------------------------- #

def _read_csv(path: str | Path) -> pd.DataFrame:
    """Load one benchmark CSV and normalise its model names.

    The regex substitutions in ``RENAME`` are applied to the ``Model``
    column, which is then forced to ``str`` so downstream comparisons
    behave uniformly.
    """
    frame = pd.read_csv(path)
    cleaned = frame["Model"].replace(RENAME, regex=True)
    frame["Model"] = cleaned.astype(str)
    return frame

dfs: dict[str, pd.DataFrame] = {name: _read_csv(path) for name, path in FILES.items()}

# --------------------------------------------------------------------- #
# Colour palette and model metadata                                     #
# --------------------------------------------------------------------- #

# NOTE(review): imported mid-file rather than with the other imports at the
# top; harmless, but conventionally this would move up.
import plotly.express as px

SAFE_PALETTE = px.colors.qualitative.Safe  # colour-blind-safe qualitative palette (10 colours)

# Deterministic list of all unique model names to ensure consistent colour mapping
ALL_MODELS: list[str] = sorted({m for df in dfs.values() for m in df["Model"].unique()})

# Marker shapes cycled per model (paired with SAFE_PALETTE colours) so lines
# stay distinguishable even without colour.
MARKER_SYMBOLS = [
    "circle",
    "square",
    "triangle-up",
    "diamond",
    "cross",
    "triangle-down",
    "x",
    "triangle-right",
    "triangle-left",
    "pentagon",
]

# Context-length column headers, taken from the aggregated file.
# Assumes every benchmark CSV shares the same column layout — TODO confirm.
TIME_COLS = [c for c in dfs["aggregated"].columns if c.lower() != "model"]


def _pretty_time(label: str) -> str:
    """‘30min’ → ‘30min’; ‘120min’ → ‘2hr’; keeps original if no match."""
    if label.endswith("min"):
        minutes = int(label[:-3])
        if minutes >= 60:
            hours = minutes / 60
            return f"{hours:.0f}hr" if hours.is_integer() else f"{hours:.1f}hr"
    return label


TIME_LABELS = {c: _pretty_time(c) for c in TIME_COLS}

# --------------------------------------------------------------------- #
# Plotting                                                              #
# --------------------------------------------------------------------- #

def render_chart(
    benchmark: str,
    models: list[str],
    log_scale: bool,
) -> go.Figure:
    """Build the accuracy-vs-video-duration line chart for one benchmark.

    Parameters
    ----------
    benchmark:
        Display label of the benchmark (e.g. "Aggregated"); lower-cased to
        look up the matching DataFrame in ``dfs``.
    models:
        Display names of the models to plot; names absent from the table
        are silently skipped.
    log_scale:
        When True the Y axis uses a log10 scale, ranged from the smallest
        plotted value up to 100.
    """
    bench_key = benchmark.lower()
    df = dfs[bench_key]
    fig = go.Figure()

    # Smallest non-zero accuracy across the selected models — used to pick
    # a sensible lower bound when the log-scale axis is requested.
    min_y_val = None

    for idx, m in enumerate(models):
        row = df.loc[df["Model"] == m]
        if row.empty:
            continue
        y = row[TIME_COLS].values.flatten()
        y = [val if val != 0 else None for val in y]   # show gaps for 0 / missing

        # Track minimum non-zero accuracy
        y_non_none = [val for val in y if val is not None]
        if y_non_none:
            cur_min = min(y_non_none)
            if min_y_val is None or cur_min < min_y_val:
                min_y_val = cur_min

        # Colour/symbol are keyed on the model's global index so a model keeps
        # the same appearance across benchmarks; fall back to the local index
        # for any model missing from ALL_MODELS.
        model_idx = ALL_MODELS.index(m) if m in ALL_MODELS else idx
        color = SAFE_PALETTE[model_idx % len(SAFE_PALETTE)]
        symbol = MARKER_SYMBOLS[model_idx % len(MARKER_SYMBOLS)]
        fig.add_trace(
            go.Scatter(
                x=[TIME_LABELS[c] for c in TIME_COLS],
                y=y,
                mode="lines+markers",
                name=m,
                line=dict(width=3, color=color),
                marker=dict(size=6, color=color, symbol=symbol),
                connectgaps=False,
            )
        )

    # Set Y-axis properties
    if log_scale:
        # Fallback to 0.1 if there are no valid points
        if min_y_val is None or min_y_val <= 0:
            min_y_val = 0.1
        # Plotly expects log10 values for `range` when the axis type is "log"
        yaxis_range = [math.floor(math.log10(min_y_val)), 2]  # max at 10^2 = 100
        yaxis_type = "log"
    else:
        yaxis_range = [0, 100]
        yaxis_type = "linear"

    fig.update_layout(
        title=f"{DISPLAY_LABELS.get(bench_key, bench_key.capitalize())} Accuracy Over Time",
        xaxis_title="Video Duration",
        yaxis_title="Accuracy (%)",
        legend_title="Model",
        legend=dict(
            orientation="h",
            y=-0.25,
            x=0.5,
            xanchor="center",
            tracegroupgap=8,
            itemwidth=60,
        ),
        margin=dict(t=40, r=20, b=80, l=60),
        template="plotly_dark",
        font=dict(family="Inter,Helvetica,Arial,sans-serif", size=14),
        title_font=dict(size=20, family="Inter,Helvetica,Arial,sans-serif", color="white"),
        xaxis=dict(gridcolor="rgba(255,255,255,0.15)"),
        # Fix: the original configured the y-axis twice in this one call —
        # via the magic-underscore kwargs `yaxis_type`/`yaxis_range` AND a
        # separate `yaxis=dict(gridcolor=...)` — which relies on plotly's
        # kwarg-merge behaviour.  Everything is consolidated into one dict.
        yaxis=dict(
            type=yaxis_type,
            range=yaxis_range,
            gridcolor="rgba(255,255,255,0.15)",
        ),
        hoverlabel=dict(bgcolor="#1e1e1e", font_color="#eeeeee", bordercolor="#888"),
    )
    return fig


# --------------------------------------------------------------------- #
# UI                                                                    #
# --------------------------------------------------------------------- #

# Custom CSS injected into the Blocks app.  The string is runtime behaviour:
# selectors must match the element ids / classes used in the UI below
# (#controls, .scrollbox) — do not edit casually.
CSS = """
#controls {
    padding: 8px 12px;
}
.scrollbox {
    max-height: 300px;
    overflow-y: auto;
}
body, .gradio-container {
    font-family: 'Inter', 'Helvetica', sans-serif;
}
.gradio-container h1, .gradio-container h2 {
   font-weight: 600;
}

#controls, .scrollbox {
    background: rgba(255,255,255,0.02);
    border-radius: 6px;
}

input[type="checkbox"]:checked {
    accent-color: #FF715E;
}
"""

def available_models(bench: str) -> list[str]:
    """Alphabetically sorted model names present in one benchmark table."""
    unique_names = dfs[bench]["Model"].unique()
    return sorted(unique_names)


def default_models(bench: str) -> list[str]:
    """Return the models pre-selected for *bench*.

    Starts from the configured defaults in ``DEFAULT_MODELS``, drops any
    name not actually present in the benchmark's table, and falls back to
    the first six available models when nothing configured survives.
    """
    opts = available_models(bench)
    present = set(opts)
    picked = [name for name in DEFAULT_MODELS.get(bench, []) if name in present]
    return picked if picked else opts[:6]


# Assemble the Gradio UI: a benchmark dropdown + log-scale toggle on top,
# the plot, and a scrollable model checklist underneath.
with gr.Blocks(theme=gr.themes.Base(), css=CSS) as demo:
    gr.Markdown(
        """
        # 📈 TimeScope
        
        How long can your video model keep up?
        """
    )

    # ---- top controls row ---- #
    with gr.Row():
        benchmark_dd = gr.Dropdown(
            label="Type",
            choices=list(DISPLAY_LABELS.values()),
            value=DISPLAY_LABELS["aggregated"],
            scale=1,
        )
        log_cb = gr.Checkbox(
            label="Log-scale Y-axis",
            value=False,
            scale=1,
        )

    # ---- models list and plot ---- #
    # Seed the plot with the aggregated benchmark so the page is not empty
    # before the first user interaction.
    plot_out = gr.Plot(
        render_chart("Aggregated", default_models("aggregated"), False)
    )

    models_cb = gr.CheckboxGroup(
        label="Models",
        choices=available_models("aggregated"),
        value=default_models("aggregated"),
        interactive=True,
        elem_classes=["scrollbox"],
    )

    # ---- dynamic callbacks ---- #
    def _update_models(bench: str):
        # Refresh the model checklist (choices + defaults) when the
        # benchmark selection changes.
        bench_key = bench.lower()
        opts = available_models(bench_key)
        defaults = default_models(bench_key)
        # Use generic gr.update for compatibility across Gradio versions
        return gr.update(choices=opts, value=defaults)

    benchmark_dd.change(
        fn=_update_models,
        inputs=benchmark_dd,
        outputs=models_cb,
        queue=False,
    )

    # NOTE(review): benchmark_dd.change is wired twice — above to update
    # models_cb and below to re-render the plot.  Whether render_chart sees
    # the *new* models_cb value depends on Gradio's event ordering; chaining
    # via .then(...) would make the ordering explicit — confirm behaviour.
    for ctrl in (benchmark_dd, models_cb, log_cb):
        ctrl.change(
            fn=render_chart,
            inputs=[benchmark_dd, models_cb, log_cb],
            outputs=plot_out,
            queue=False,
        )

    # Make legend interaction clearer: click to toggle traces

# share=True opens a public tunnel link in addition to localhost; convenient
# for demos, but it exposes the app beyond the local machine.
demo.launch(share=True)