File size: 5,095 Bytes
d35349d 38ee4b2 d35349d b026a94 d35349d b026a94 d35349d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
import gradio as gr
import pandas as pd
from pathlib import Path
from src.json_leaderboard import create_leaderboard_df
from src.about import (
CITATION_BUTTON_TEXT,
INTRODUCTION_TEXT,
LINKS_AND_INFO,
TITLE,
)
from src.display.css_html_js import custom_css
# Fixed columns: always displayed, always placed first, and never offered as
# user-toggleable options.
FIXED_COLUMNS = [
    "Model Name (clickable)",
    "Release Date",
    "HF Model",
    "Open Source",
]
def get_json_df():
    """Build the leaderboard DataFrame from ``leaderboard_data.json``.

    The JSON file is resolved relative to the directory containing this
    source file, not the current working directory.

    Returns:
        The object produced by ``create_leaderboard_df`` for that path (used
        as a pandas DataFrame by the rest of this module).
    """
    data_file = Path(__file__).parent.joinpath("leaderboard_data.json")
    return create_leaderboard_df(f"{data_file}")
# Extract the top-level categories and their sub-columns.
def extract_categories_and_subs(df, *, fixed_columns=None):
    """Group the leaderboard's per-category columns.

    A category column is any column whose name ends with ``-Overall``. Every
    column that directly follows it belongs to that category, up to (but not
    including) the next ``-Overall`` column, a fixed column, or the global
    ``Overall`` column.

    Args:
        df: DataFrame whose column order encodes the grouping.
        fixed_columns: Columns that never belong to a category. Defaults to
            the module-level ``FIXED_COLUMNS``.

    Returns:
        ``{category_name: {"overall": <category column>, "subs": [<sub columns>]}}``
        in column order. ``category_name`` is the category column's name with
        only the trailing ``-Overall`` removed.
    """
    if fixed_columns is None:
        fixed_columns = FIXED_COLUMNS
    suffix = "-Overall"
    skip_cols = set(fixed_columns) | {"Overall"}

    category_dict = {}
    current_subs = None  # sub-column list of the category being collected, if any
    for col in df.columns:
        if col.endswith(suffix) and col not in skip_cols:
            # Strip only the suffix; str.replace would also delete "-Overall"
            # appearing elsewhere in the name.
            current_subs = []
            category_dict[col[: -len(suffix)]] = {"overall": col, "subs": current_subs}
        elif col.endswith(suffix) or col in skip_cols:
            # A boundary column ends the current category without starting one.
            current_subs = None
        elif current_subs is not None:
            current_subs.append(col)
    return category_dict
# Column filter: fixed columns first, then the user's selection, preserving the DataFrame's order.
def filtered_leaderboard(df, selected_columns, *, fixed_columns=None):
    """Return ``df`` restricted to the fixed columns plus the selected ones.

    Fixed columns always come first, in ``fixed_columns`` order. Selected
    columns follow in the order they appear in ``df`` rather than in
    selection order, so the table layout stays stable while toggling.

    Args:
        df: Source DataFrame; must contain every fixed column.
        selected_columns: Column names chosen by the user, or ``None``.
            Fixed columns and names not present in ``df`` are ignored.
        fixed_columns: Always-shown columns. Defaults to ``FIXED_COLUMNS``.

    Returns:
        The column-selected DataFrame.

    Raises:
        KeyError: if a fixed column is missing from ``df``.
    """
    if fixed_columns is None:
        fixed_columns = FIXED_COLUMNS
    # Set lookups keep the scan linear in the number of columns.
    wanted = set(selected_columns or ())
    excluded = set(fixed_columns)
    extra = [col for col in df.columns if col in wanted and col not in excluded]
    return df[list(fixed_columns) + extra]
# Gradio change-callbacks: each one re-filters a table when its column selection changes.
def update_leaderboard_overall(selected_cols, df_overall):
    """Re-filter the overall leaderboard for a new column selection.

    Args:
        selected_cols: Column names currently ticked in the CheckboxGroup.
        df_overall: The full leaderboard DataFrame to filter.

    Returns:
        The result of ``filtered_leaderboard`` for those inputs.
    """
    filtered = filtered_leaderboard(df=df_overall, selected_columns=selected_cols)
    return filtered
def update_leaderboard_cat(selected_cols, df_cat):
    """Re-filter one category's leaderboard for a new column selection.

    Args:
        selected_cols: Column names currently ticked in the category's CheckboxGroup.
        df_cat: The category's DataFrame (fixed columns plus category columns).

    Returns:
        The result of ``filtered_leaderboard`` for those inputs.
    """
    filtered = filtered_leaderboard(df=df_cat, selected_columns=selected_cols)
    return filtered
# Initialization: load the data once at import time and derive the column groups.
df = get_json_df()
categories = extract_categories_and_subs(df)
ALL_COLUMNS_ORDERED = list(df.columns)
# Optional (user-toggleable) columns: every non-fixed column, in DataFrame order.
optional_columns = [name for name in ALL_COLUMNS_ORDERED if name not in FIXED_COLUMNS]
# Gradio interface
# Columns rendered as HTML (they hold links); every other column is shown as text.
HTML_COLUMNS = ("Model Name (clickable)", "HF Model")


def _column_datatypes(frame):
    """Return the per-column Gradio datatype list matching ``frame``'s column order."""
    return ["html" if col in HTML_COLUMNS else "str" for col in frame.columns]


demo = gr.Blocks(css=custom_css, title="UniGenBench Leaderboard")
with demo:
    gr.HTML(TITLE)
    gr.HTML(LINKS_AND_INFO)
    gr.Markdown(INTRODUCTION_TEXT, elem_classes="markdown-text")
    with gr.Tabs(elem_classes="tab-buttons") as tabs:
        # Overall leaderboard
        with gr.TabItem("🏅 Overall Leaderboard", elem_id="tab-overall"):
            selected_columns_overall = gr.CheckboxGroup(
                choices=optional_columns,
                label="Select additional columns to display",
                value=optional_columns,
            )
            # Build the initial view through the same filter the change-callback
            # uses, so fixed columns lead from the first render (as in the
            # category tabs) and the positional datatype list stays aligned
            # with the columns after the selection changes.
            overall_df = filtered_leaderboard(df, optional_columns)
            leaderboard_table = gr.Dataframe(
                value=overall_df,
                headers=list(overall_df.columns),
                datatype=_column_datatypes(overall_df),
                interactive=False,
                wrap=False,
            )
            selected_columns_overall.change(
                fn=update_leaderboard_overall,
                inputs=[selected_columns_overall, gr.State(value=df)],
                outputs=leaderboard_table,
            )
        # One leaderboard tab per category
        for cat_name, info in categories.items():
            with gr.TabItem(f"🏆 {cat_name}", elem_id=f"tab-{cat_name}"):
                cat_cols = [info["overall"]] + info["subs"]
                cat_df = df[FIXED_COLUMNS + cat_cols]
                optional_columns_cat = [col for col in cat_cols if col not in FIXED_COLUMNS]
                selected_columns_cat = gr.CheckboxGroup(
                    choices=optional_columns_cat,
                    label=f"Select additional columns for {cat_name}",
                    value=optional_columns_cat,
                )
                leaderboard_table_cat = gr.Dataframe(
                    value=cat_df,
                    headers=list(cat_df.columns),
                    datatype=_column_datatypes(cat_df),
                    interactive=False,
                    wrap=False,
                )
                # cat_df is evaluated now, so each tab's State holds its own frame.
                selected_columns_cat.change(
                    fn=update_leaderboard_cat,
                    inputs=[selected_columns_cat, gr.State(value=cat_df)],
                    outputs=leaderboard_table_cat,
                )
    # Citation
    with gr.Row():
        with gr.Column():
            gr.Markdown("## 📙 Citation")
            # NOTE(review): the UniGenBench link target below is empty — fill in the project URL.
            gr.Markdown("If you use [UniGenBench]() in your research, please cite our work:")
            citation_textbox = gr.Textbox(
                value=CITATION_BUTTON_TEXT,
                elem_id="citation-textbox",
                show_label=False,
                interactive=False,
                lines=8,
                show_copy_button=True,
            )

if __name__ == "__main__":
    demo.launch()
|