# zhihengchen's picture
# Upload app.py
# 6f7189c verified
import os.path
import gradio as gr
import pandas as pd
from constants import *
# ------------ Download links ------------
def get_download_link_model(task, dataset, example):
    """Return the relative path of the zipped model weights for a selection.

    The path has the form ``data/<task>/<dataset>/weight/<example>.zip``,
    where each variable segment is looked up in the matching
    ``*_PATH_MAPPING`` table.

    Args:
        task: Key of ``TASK_PATH_MAPPING``.
        dataset: Key of ``DATASET_PATH_MAPPING``.
        example: Key of ``EXAMPLE_PATH_MAPPING``.

    Returns:
        The joined path as a string.

    Raises:
        KeyError: If any of the three keys is missing from its mapping.
    """
    segments = (
        "data",
        TASK_PATH_MAPPING[task],
        DATASET_PATH_MAPPING[dataset],
        "weight",
        "{}.zip".format(EXAMPLE_PATH_MAPPING[example]),
    )
    return os.path.join(*segments)
def get_download_link_json(task, dataset, example):
    """Return the relative path of the JSON data file for a selection.

    The path has the form ``data/<task>/<dataset>/json/<example>.<ext>``.
    The extension is ``jsonl`` when the mapped task directory is
    ``"common"`` and ``json`` for every other task.

    Args:
        task: Key of ``TASK_PATH_MAPPING``.
        dataset: Key of ``DATASET_PATH_MAPPING``.
        example: Key of ``EXAMPLE_PATH_MAPPING``.

    Returns:
        The joined path as a string.

    Raises:
        KeyError: If any of the three keys is missing from its mapping.
    """
    task_dir = TASK_PATH_MAPPING[task]
    dataset_dir = DATASET_PATH_MAPPING[dataset]
    stem = EXAMPLE_PATH_MAPPING[example]
    extension = "jsonl" if task_dir == "common" else "json"
    file_name = "{}.{}".format(stem, extension)
    return os.path.join("data", task_dir, dataset_dir, "json", file_name)
# ------------ Data loading + average accuracy ------------
def get_data(task, dataset, example):
    """Load the per-prompt results table and its average accuracy.

    The results are read from ``data/<task>/<dataset>/csv/<example>.csv``,
    where each variable segment is looked up in the matching
    ``*_PATH_MAPPING`` table.

    Args:
        task: Key of ``TASK_PATH_MAPPING``.
        dataset: Key of ``DATASET_PATH_MAPPING``. For the coding task the
            averages are only computed when this name contains
            ``"HumanEval"``.
        example: Key of ``EXAMPLE_PATH_MAPPING``.

    Returns:
        A ``(data, average_acc)`` tuple.

        * ``(None, None)`` when the CSV file does not exist.
        * ``data`` is a ``pandas.DataFrame`` whose leading columns are
          ``COLUMN_NAMES``; it has one row per CSV row for the ``coding``,
          ``common`` and ``math`` tasks and no rows for any other task.
        * ``average_acc`` is the string ``"<pass@1> / <pass@5> / <pass@10>"``
          (percentages) for HumanEval coding results, the mean correctness
          as a percentage for ``common``/``math``, and ``None`` otherwise.

    Raises:
        KeyError: If a key is missing from its mapping or the CSV lacks a
            column required by the selected task.
    """
    _task_path = TASK_PATH_MAPPING[task]
    _dataset_path = DATASET_PATH_MAPPING[dataset]
    _example_path = EXAMPLE_PATH_MAPPING[example]
    csv_file = os.path.join("data", _task_path, _dataset_path, "csv", f"{_example_path}.csv")
    if not os.path.exists(csv_file):
        return None, None

    read_data = pd.read_csv(csv_file)
    rows = []
    average_acc = None

    if _task_path == "coding":
        # Collect plain dicts and build the frame once at the end: growing a
        # DataFrame with one ``pd.concat`` per row copies the whole table on
        # every iteration.
        rows = [
            {
                "Prompt": prompt,
                "Pass@1": round(float(pass_1) * 100, 3),
                "Pass@5": round(float(pass_5) * 100, 3),
                "Pass@10": round(float(pass_10) * 100, 3),
                "Correctness": "N/A",
            }
            for prompt, pass_1, pass_5, pass_10 in zip(
                read_data["prompt"],
                read_data["pass@1"],
                read_data["pass@5"],
                read_data["pass@10"],
            )
        ]
        # The three-column average is only reported for HumanEval datasets.
        if "HumanEval" in dataset:
            p1_mean = round(read_data["pass@1"].mean() * 100, 3)
            p5_mean = round(read_data["pass@5"].mean() * 100, 3)
            p10_mean = round(read_data["pass@10"].mean() * 100, 3)
            average_acc = f"{p1_mean} / {p5_mean} / {p10_mean}"
    elif _task_path in ("common", "math"):
        rows = [
            {
                "Prompt": prompt,
                "Pass@1": None,
                "Pass@5": None,
                "Pass@10": None,
                "Correctness": "✅" if is_correct else "❌",
            }
            for prompt, is_correct in zip(read_data["prompt"], read_data["correctness"])
        ]
        average_acc = round(read_data["correctness"].mean() * 100, 3)

    # Keep COLUMN_NAMES first; any row key that is not listed there is
    # appended afterwards so no produced value is dropped.
    columns = list(COLUMN_NAMES)
    if rows:
        extra_columns = [key for key in rows[0] if key not in columns]
        columns.extend(extra_columns)
    data = pd.DataFrame(rows, columns=columns)
    return data, average_acc
# ------------ Gradio UI ------------
def _refresh_view(task_name, dataset_name, example_name):
    """Return ``(table, average_text)`` for the given selection.

    ``get_data`` is called exactly once so the CSV is read and the table is
    built a single time per UI event; the average is converted with ``str``
    for display in the textbox.
    """
    table, average = get_data(task_name, dataset_name, example_name)
    return table, str(average)


with gr.Blocks() as demo_board:
    gr.HTML(DND_HEADER)
    gr.Markdown(DND_INTRODUCTION)

    task = gr.Radio(
        label="Task",
        choices=TASK_LIST,
        value=TASK_LIST[0],
        interactive=True,
    )
    dataset = gr.Radio(
        label="Dataset",
        choices=TASK_DATASET_LIST[task.value],
        value=TASK_DATASET_LIST[task.value][0],
        interactive=True,
    )
    example = gr.Radio(
        label="Example",
        choices=EXAMPLE_LIST,
        value=EXAMPLE_LIST[0],
        interactive=True,
    )

    # Average accuracy (placed above the prompt table).
    average_acc_display = gr.Textbox(
        label="Average Accuracy (%)",
        value=lambda: str(get_data(task.value, dataset.value, example.value)[1]),
        interactive=False,
        visible=True,
        scale=0,
        max_lines=1,
        min_width=160,
    )

    # Prompt table.
    board = gr.components.Dataframe(
        value=lambda: get_data(task.value, dataset.value, example.value)[0],
        column_widths=["60%", "10%", "10%", "10%", "10%"],
        headers=COLUMN_NAMES,
        type="pandas",
        datatype=DATA_TITLE_TYPE,
        interactive=False,
        visible=True,
        max_height=500,
    )

    # Linked update: task -> dataset choices (the first dataset is selected).
    task.change(
        lambda t: gr.Radio(
            label="Dataset",
            choices=TASK_DATASET_LIST[t],
            value=TASK_DATASET_LIST[t][0],
            interactive=True,
        ),
        inputs=[task],
        outputs=dataset,
    )

    # Linked update: task / dataset / example -> table + average accuracy.
    # NOTE(review): a task change also triggers this handler, presumably
    # while `dataset` may still hold the previous task's value; get_data
    # returns (None, None) for a missing CSV, so the view is expected to be
    # refreshed again once the dataset radio changes -- confirm event order.
    for component in (task, dataset, example):
        component.change(
            _refresh_view,
            inputs=[task, dataset, example],
            outputs=[board, average_acc_display],
        )

    # Download buttons.
    # NOTE(review): each click handler only assigns the button's file as the
    # result of the click itself. Confirm against the Gradio DownloadButton
    # documentation that the first click after a selection change already
    # downloads the newly selected file rather than the previous value.
    with gr.Row():
        json_downloader = gr.DownloadButton("Download JSON", visible=True)
        model_downloader = gr.DownloadButton("Download Model", visible=True)

    json_downloader.click(
        fn=get_download_link_json,
        inputs=[task, dataset, example],
        outputs=json_downloader,
    )
    model_downloader.click(
        fn=get_download_link_model,
        inputs=[task, dataset, example],
        outputs=model_downloader,
    )

    # Citation text.
    citation_button = gr.Textbox(
        value=CITATION_BUTTON_TEXT,
        label=CITATION_BUTTON_LABEL,
        elem_id="citation-button",
        lines=6,
        show_copy_button=True,
    )
# Launch the demo board app.
demo_board.launch()