Spaces:
Running
Running
File size: 5,263 Bytes
6f7189c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
import os.path
import gradio as gr
import pandas as pd
from constants import *
# ------------ Download links ------------
def get_download_link_model(task, dataset, example):
    """Build the relative path of the zipped model weights for one selection.

    Each argument is a display name that is translated to its on-disk name via
    TASK_PATH_MAPPING, DATASET_PATH_MAPPING and EXAMPLE_PATH_MAPPING (looked up
    in that order; an unknown name raises KeyError).

    Returns:
        The path ``data/<task>/<dataset>/weight/<example>.zip`` joined with
        ``os.path.join``. Whether the file exists is not checked here.
    """
    task_dir = TASK_PATH_MAPPING[task]
    dataset_dir = DATASET_PATH_MAPPING[dataset]
    example_stem = EXAMPLE_PATH_MAPPING[example]
    archive_name = f"{example_stem}.zip"
    return os.path.join("data", task_dir, dataset_dir, "weight", archive_name)
def get_download_link_json(task, dataset, example):
    """Build the relative path of the exported JSON data for one selection.

    Each argument is a display name that is translated to its on-disk name via
    TASK_PATH_MAPPING, DATASET_PATH_MAPPING and EXAMPLE_PATH_MAPPING (looked up
    in that order; an unknown name raises KeyError).

    Returns:
        ``data/<task>/<dataset>/json/<example>.jsonl`` when the task directory
        is ``"common"``, otherwise ``data/<task>/<dataset>/json/<example>.json``.
        Whether the file exists is not checked here.
    """
    task_dir = TASK_PATH_MAPPING[task]
    dataset_dir = DATASET_PATH_MAPPING[dataset]
    example_stem = EXAMPLE_PATH_MAPPING[example]
    # Only the "common" task stores its export with a .jsonl extension.
    extension = ".jsonl" if task_dir == "common" else ".json"
    return os.path.join("data", task_dir, dataset_dir, "json", f"{example_stem}{extension}")
# ------------ Data loading + average accuracy ------------
def get_data(task, dataset, example):
    """Load one result CSV and shape it for the leaderboard table.

    The CSV lives at ``data/<task>/<dataset>/csv/<example>.csv``, where each
    path component is the display name translated through TASK_PATH_MAPPING,
    DATASET_PATH_MAPPING and EXAMPLE_PATH_MAPPING.

    Args:
        task: Task display name.
        dataset: Dataset display name (also checked for the substring
            "HumanEval" to decide whether coding averages are reported).
        example: Example display name.

    Returns:
        ``(None, None)`` if the CSV file does not exist. Otherwise
        ``(table, average_acc)`` where ``table`` is a DataFrame whose columns
        start with COLUMN_NAMES, and ``average_acc`` is:

        * for the "coding" task on a HumanEval dataset: the string
          ``"<pass@1> / <pass@5> / <pass@10>"`` of column means x100,
          rounded to 3 decimals;
        * for the "common" / "math" tasks: the mean of the ``correctness``
          column x100, rounded to 3 decimals;
        * otherwise ``None`` (and for unknown tasks the table is empty).
    """
    task_dir = TASK_PATH_MAPPING[task]
    dataset_dir = DATASET_PATH_MAPPING[dataset]
    example_stem = EXAMPLE_PATH_MAPPING[example]
    csv_file = os.path.join("data", task_dir, dataset_dir, "csv", f"{example_stem}.csv")
    if not os.path.exists(csv_file):
        return None, None

    read_data = pd.read_csv(csv_file)
    rows = read_data.to_dict("records")
    average_acc = None

    if task_dir == "coding":
        records = [
            {
                "Prompt": row["prompt"],
                "Pass@1": round(float(row["pass@1"]) * 100, 3),
                "Pass@5": round(float(row["pass@5"]) * 100, 3),
                "Pass@10": round(float(row["pass@10"]) * 100, 3),
                "Correctness": "N/A",
            }
            for row in rows
        ]
        # Only HumanEval datasets report the three pass@k averages.
        if "HumanEval" in dataset:
            p1_mean = round(read_data["pass@1"].mean() * 100, 3)
            p5_mean = round(read_data["pass@5"].mean() * 100, 3)
            p10_mean = round(read_data["pass@10"].mean() * 100, 3)
            average_acc = f"{p1_mean} / {p5_mean} / {p10_mean}"
    elif task_dir in ("common", "math"):
        records = [
            {
                "Prompt": row["prompt"],
                "Pass@1": None,
                "Pass@5": None,
                "Pass@10": None,
                "Correctness": "✅" if row["correctness"] else "❌",
            }
            for row in rows
        ]
        average_acc = round(read_data["correctness"].mean() * 100, 3)
    else:
        records = []

    # Build the table in one go: calling pd.concat once per row is quadratic and,
    # with an empty seed frame, triggers a pandas FutureWarning. Columns follow
    # COLUMN_NAMES first, then any row key it does not list (as concat would).
    row_keys = ("Prompt", "Pass@1", "Pass@5", "Pass@10", "Correctness")
    columns = list(COLUMN_NAMES) + [key for key in row_keys if key not in COLUMN_NAMES]
    data = pd.DataFrame(records, columns=columns)
    return data, average_acc
# ------------ Gradio UI ------------
def _existing_file_or_none(path):
    """Return *path* when it names an existing file, otherwise ``None``.

    A download button holding ``None`` simply has nothing to download, whereas
    a path to a missing file is expected to fail when Gradio tries to serve it.
    """
    return path if os.path.isfile(path) else None


def _download_values(t, d, e):
    """Return the (JSON, model) files for a selection; missing files become ``None``."""
    return (
        _existing_file_or_none(get_download_link_json(t, d, e)),
        _existing_file_or_none(get_download_link_model(t, d, e)),
    )


def _board_values(t, d, e):
    """Return the prompt table and the average-accuracy text, reading the CSV once."""
    data, average_acc = get_data(t, d, e)
    return data, str(average_acc)


def _dataset_radio(t):
    """Rebuild the dataset selector with the datasets available for task *t*."""
    return gr.Radio(
        label="Dataset",
        choices=TASK_DATASET_LIST[t],
        value=TASK_DATASET_LIST[t][0],
        interactive=True,
    )


with gr.Blocks() as demo_board:
    gr.HTML(DND_HEADER)
    gr.Markdown(DND_INTRODUCTION)

    task = gr.Radio(
        label="Task",
        choices=TASK_LIST,
        value=TASK_LIST[0],
        interactive=True,
    )
    dataset = gr.Radio(
        label="Dataset",
        choices=TASK_DATASET_LIST[task.value],
        value=TASK_DATASET_LIST[task.value][0],
        interactive=True,
    )
    example = gr.Radio(
        label="Example",
        choices=EXAMPLE_LIST,
        value=EXAMPLE_LIST[0],
        interactive=True,
    )
    selectors = [task, dataset, example]

    # Average accuracy (placed above the prompt table).
    average_acc_display = gr.Textbox(
        label="Average Accuracy (%)",
        value=lambda: _board_values(task.value, dataset.value, example.value)[1],
        interactive=False,
        visible=True,
        scale=0,
        max_lines=1,
        min_width=160,
    )

    # Prompt table.
    board = gr.Dataframe(
        value=lambda: _board_values(task.value, dataset.value, example.value)[0],
        column_widths=["60%", "10%", "10%", "10%", "10%"],
        headers=COLUMN_NAMES,
        type="pandas",
        datatype=DATA_TITLE_TYPE,
        interactive=False,
        visible=True,
        max_height=500,
    )

    # Download buttons. A DownloadButton serves the file it already holds when
    # clicked, so its value is set on load and on every selection change; setting
    # it from its own click handler would always serve the previous selection.
    with gr.Row():
        json_downloader = gr.DownloadButton(
            "Download JSON",
            value=lambda: _download_values(task.value, dataset.value, example.value)[0],
            visible=True,
        )
        model_downloader = gr.DownloadButton(
            "Download Model",
            value=lambda: _download_values(task.value, dataset.value, example.value)[1],
            visible=True,
        )

    # Citation text.
    citation_button = gr.Textbox(
        value=CITATION_BUTTON_TEXT,
        label=CITATION_BUTTON_LABEL,
        elem_id="citation-button",
        lines=6,
        show_copy_button=True,
    )

    # Linked update: task -> dataset choices.
    task.change(_dataset_radio, inputs=[task], outputs=dataset)

    # Linked update: task / dataset / example -> table, average accuracy, downloads.
    for component in selectors:
        component.change(
            _board_values,
            inputs=selectors,
            outputs=[board, average_acc_display],
        )
        component.change(
            _download_values,
            inputs=selectors,
            outputs=[json_downloader, model_downloader],
        )

# Launch.
demo_board.launch()