# Spaces: Running
import os.path

import gradio as gr
import pandas as pd

from constants import *
# ------------ Download links ------------
def get_download_link_model(task, dataset, example):
    """Build the relative path of the zipped model weights for a selection.

    Args:
        task: Task label; looked up in ``TASK_PATH_MAPPING``.
        dataset: Dataset label; looked up in ``DATASET_PATH_MAPPING``.
        example: Example label; looked up in ``EXAMPLE_PATH_MAPPING``.

    Returns:
        The path ``data/<task>/<dataset>/weight/<example>.zip`` assembled with
        ``os.path.join``. The file's existence is not checked here.

    Note:
        A label missing from its mapping propagates the mapping's lookup
        error (``KeyError`` if the mappings are plain dicts -- they are
        defined in ``constants``, not in this file).
    """
    # Lookups stay in task -> dataset -> example order so the first failing
    # label is the one reported.
    return os.path.join(
        "data",
        TASK_PATH_MAPPING[task],
        DATASET_PATH_MAPPING[dataset],
        "weight",
        f"{EXAMPLE_PATH_MAPPING[example]}.zip",
    )
def get_download_link_json(task, dataset, example):
    """Build the relative path of the JSON results file for a selection.

    Args:
        task: Task label; looked up in ``TASK_PATH_MAPPING``.
        dataset: Dataset label; looked up in ``DATASET_PATH_MAPPING``.
        example: Example label; looked up in ``EXAMPLE_PATH_MAPPING``.

    Returns:
        The path ``data/<task>/<dataset>/json/<example>.<ext>`` assembled with
        ``os.path.join``, where ``<ext>`` is ``jsonl`` when the task maps to
        the ``"common"`` directory and ``json`` otherwise. The file's
        existence is not checked here.
    """
    task_dir = TASK_PATH_MAPPING[task]
    dataset_dir = DATASET_PATH_MAPPING[dataset]
    example_name = EXAMPLE_PATH_MAPPING[example]
    # Only the "common" task directory uses the .jsonl extension.
    extension = "jsonl" if task_dir == "common" else "json"
    return os.path.join("data", task_dir, dataset_dir, "json", f"{example_name}.{extension}")
# ------------ Data loading + average accuracy ------------
def get_data(task, dataset, example):
    """Load the per-prompt results table and average accuracy for a selection.

    Reads ``data/<task>/<dataset>/csv/<example>.csv`` (directory and file
    names come from ``TASK_PATH_MAPPING``, ``DATASET_PATH_MAPPING`` and
    ``EXAMPLE_PATH_MAPPING``) and reshapes it for display.

    Args:
        task: Task label; looked up in ``TASK_PATH_MAPPING``.
        dataset: Dataset label; looked up in ``DATASET_PATH_MAPPING``. Its
            text is also checked for the substring ``"HumanEval"``.
        example: Example label; looked up in ``EXAMPLE_PATH_MAPPING``.

    Returns:
        A ``(table, average)`` tuple:

        * ``(None, None)`` when the CSV file does not exist.
        * For the ``"coding"`` task directory: one row per CSV row with the
          ``pass@1/5/10`` columns converted to percentages rounded to three
          decimals and ``Correctness`` set to ``"N/A"``. ``average`` is the
          string ``"<p1> / <p5> / <p10>"`` of column means (in percent) when
          ``dataset`` contains ``"HumanEval"``, otherwise ``None``.
        * For the ``"common"`` and ``"math"`` task directories: one row per
          CSV row with the Pass@k cells set to ``None`` and ``Correctness``
          rendered as a check or cross mark from the truthiness of the
          ``correctness`` column. ``average`` is the column mean in percent,
          rounded to three decimals.
        * For any other task directory: an empty table and ``None``.

        The table's columns are ``COLUMN_NAMES`` followed by any produced
        column that ``COLUMN_NAMES`` does not list.
    """
    task_dir = TASK_PATH_MAPPING[task]
    dataset_dir = DATASET_PATH_MAPPING[dataset]
    example_name = EXAMPLE_PATH_MAPPING[example]
    csv_file = os.path.join("data", task_dir, dataset_dir, "csv", f"{example_name}.csv")
    if not os.path.exists(csv_file):
        return None, None

    read_data = pd.read_csv(csv_file)
    records = []
    average_acc = None

    if task_dir == "coding":
        records = [
            {
                "Prompt": prompt,
                "Pass@1": round(float(pass_1) * 100, 3),
                "Pass@5": round(float(pass_5) * 100, 3),
                "Pass@10": round(float(pass_10) * 100, 3),
                "Correctness": "N/A",
            }
            for prompt, pass_1, pass_5, pass_10 in zip(
                read_data["prompt"],
                read_data["pass@1"],
                read_data["pass@5"],
                read_data["pass@10"],
            )
        ]
        # The three-column average is only computed for HumanEval datasets.
        if "HumanEval" in dataset:
            p1_mean = round(read_data["pass@1"].mean() * 100, 3)
            p5_mean = round(read_data["pass@5"].mean() * 100, 3)
            p10_mean = round(read_data["pass@10"].mean() * 100, 3)
            average_acc = f"{p1_mean} / {p5_mean} / {p10_mean}"
    elif task_dir in ("common", "math"):
        records = [
            {
                "Prompt": prompt,
                "Pass@1": None,
                "Pass@5": None,
                "Pass@10": None,
                "Correctness": "✅" if is_correct else "❌",
            }
            for prompt, is_correct in zip(read_data["prompt"], read_data["correctness"])
        ]
        average_acc = round(read_data["correctness"].mean() * 100, 3)

    # Build the frame once instead of concatenating row by row (which copies
    # the growing table on every iteration), then lay the columns out as
    # COLUMN_NAMES first, followed by any extra produced column.
    rows = pd.DataFrame(records)
    ordered_columns = list(dict.fromkeys([*COLUMN_NAMES, *rows.columns]))
    return rows.reindex(columns=ordered_columns), average_acc
# ------------ Gradio UI ------------
def _refresh_outputs(task_name, dataset_name, example_name):
    """Return ``(table, average_text)`` for a selection from a single ``get_data`` call."""
    table, average = get_data(task_name, dataset_name, example_name)
    # The textbox receives the string form of the average (including "None").
    return table, str(average)


def _dataset_radio_for(task_name):
    """Return the Dataset radio configured for ``task_name``, preselecting its first dataset."""
    dataset_choices = TASK_DATASET_LIST[task_name]
    return gr.Radio(
        label="Dataset",
        choices=dataset_choices,
        value=dataset_choices[0],
        interactive=True,
    )


with gr.Blocks() as demo_board:
    gr.HTML(DND_HEADER)
    gr.Markdown(DND_INTRODUCTION)

    task = gr.Radio(
        label="Task",
        choices=TASK_LIST,
        value=TASK_LIST[0],
        interactive=True,
    )
    dataset = gr.Radio(
        label="Dataset",
        choices=TASK_DATASET_LIST[task.value],
        value=TASK_DATASET_LIST[task.value][0],
        interactive=True,
    )
    example = gr.Radio(
        label="Example",
        choices=EXAMPLE_LIST,
        value=EXAMPLE_LIST[0],
        interactive=True,
    )

    # Average accuracy (placed above the prompt table).
    average_acc_display = gr.Textbox(
        label="Average Accuracy (%)",
        value=lambda: _refresh_outputs(task.value, dataset.value, example.value)[1],
        interactive=False,
        visible=True,
        scale=0,
        max_lines=1,
        min_width=160,
    )

    # Prompt table.
    board = gr.components.Dataframe(
        value=lambda: _refresh_outputs(task.value, dataset.value, example.value)[0],
        column_widths=["60%", "10%", "10%", "10%", "10%"],
        headers=COLUMN_NAMES,
        type="pandas",
        datatype=DATA_TITLE_TYPE,
        interactive=False,
        visible=True,
        max_height=500,
    )

    # Linked update: task -> dataset choices.
    task.change(
        _dataset_radio_for,
        inputs=[task],
        outputs=dataset,
    )

    # Linked update: task / dataset / example -> table + average accuracy.
    # The shared handler reads the CSV once per event and feeds both outputs.
    for component in [task, dataset, example]:
        component.change(
            _refresh_outputs,
            inputs=[task, dataset, example],
            outputs=[board, average_acc_display],
        )

    # Download buttons: each click handler returns a file path that is
    # written back to the clicked button itself.
    with gr.Row():
        json_downloader = gr.DownloadButton("Download JSON", visible=True)
        model_downloader = gr.DownloadButton("Download Model", visible=True)

    json_downloader.click(
        fn=get_download_link_json,
        inputs=[task, dataset, example],
        outputs=json_downloader,
    )
    model_downloader.click(
        fn=get_download_link_model,
        inputs=[task, dataset, example],
        outputs=model_downloader,
    )

    # Citation text.
    citation_button = gr.Textbox(
        value=CITATION_BUTTON_TEXT,
        label=CITATION_BUTTON_LABEL,
        elem_id="citation-button",
        lines=6,
        show_copy_button=True,
    )

# Launch the app.
demo_board.launch()