import os.path

import gradio as gr
import pandas as pd

# NOTE(review): the star import is presumably what supplies TASK_PATH_MAPPING,
# DATASET_PATH_MAPPING, EXAMPLE_PATH_MAPPING, COLUMN_NAMES, DATA_TITLE_TYPE,
# TASK_LIST, TASK_DATASET_LIST, EXAMPLE_LIST, DND_HEADER, DND_INTRODUCTION,
# CITATION_BUTTON_TEXT and CITATION_BUTTON_LABEL -- none are defined here.
from constants import *


# ------------ Download links ------------
def _resolve_paths(task, dataset, example):
    """Translate the three UI selections into their on-disk path components."""
    return (
        TASK_PATH_MAPPING[task],
        DATASET_PATH_MAPPING[dataset],
        EXAMPLE_PATH_MAPPING[example],
    )


def get_download_link_model(task, dataset, example):
    """Return the path of the zipped model weights for the current selection.

    Raises:
        KeyError: if any selection is missing from its path mapping.
    """
    task_dir, dataset_dir, example_name = _resolve_paths(task, dataset, example)
    return os.path.join("data", task_dir, dataset_dir, "weight", f"{example_name}.zip")


def get_download_link_json(task, dataset, example):
    """Return the path of the JSON results file for the current selection.

    The "common" task directory stores ``.jsonl`` files; every other task
    directory stores ``.json`` files.

    Raises:
        KeyError: if any selection is missing from its path mapping.
    """
    task_dir, dataset_dir, example_name = _resolve_paths(task, dataset, example)
    extension = "jsonl" if task_dir == "common" else "json"
    return os.path.join("data", task_dir, dataset_dir, "json", f"{example_name}.{extension}")


# ------------ Data loading + average accuracy ------------
def _to_percent(value):
    """Convert a 0-1 fraction to a percentage rounded to 3 decimals."""
    return round(float(value) * 100, 3)


def get_data(task, dataset, example):
    """Load the result CSV for the selection and build the table to display.

    Returns:
        ``(data, average_acc)``.  ``(None, None)`` when the CSV file does not
        exist.  Otherwise ``data`` is a DataFrame with ``COLUMN_NAMES`` columns
        and ``average_acc`` is:

        * "coding" task directory: the string ``"p1 / p5 / p10"`` of mean
          pass rates (in percent) when ``"HumanEval"`` occurs in ``dataset``,
          else ``None``;
        * "common" / "math" task directories: the mean of the ``correctness``
          column in percent, rounded to 3 decimals;
        * any other task directory: ``None`` (with an empty table).

    Raises:
        KeyError: if a selection is missing from its path mapping, or the CSV
            lacks a column required by the task.
    """
    task_dir, dataset_dir, example_name = _resolve_paths(task, dataset, example)
    csv_file = os.path.join("data", task_dir, dataset_dir, "csv", f"{example_name}.csv")
    if not os.path.exists(csv_file):
        return None, None

    read_data = pd.read_csv(csv_file)
    average_acc = None
    # Collect plain records and build the frame once at the end; concatenating
    # one row at a time copies the whole frame on every iteration.
    rows = []

    if task_dir == "coding":
        rows = [
            {
                "Prompt": prompt,
                "Pass@1": _to_percent(pass_1),
                "Pass@5": _to_percent(pass_5),
                "Pass@10": _to_percent(pass_10),
                "Correctness": "N/A",
            }
            for prompt, pass_1, pass_5, pass_10 in zip(
                read_data["prompt"],
                read_data["pass@1"],
                read_data["pass@5"],
                read_data["pass@10"],
            )
        ]
        # The three-column average is only computed for HumanEval datasets.
        if "HumanEval" in dataset:
            p1_mean = round(read_data["pass@1"].mean() * 100, 3)
            p5_mean = round(read_data["pass@5"].mean() * 100, 3)
            p10_mean = round(read_data["pass@10"].mean() * 100, 3)
            average_acc = f"{p1_mean} / {p5_mean} / {p10_mean}"
    elif task_dir in ("common", "math"):
        rows = [
            {
                "Prompt": prompt,
                "Pass@1": None,
                "Pass@5": None,
                "Pass@10": None,
                "Correctness": "✅" if is_correct else "❌",
            }
            for prompt, is_correct in zip(read_data["prompt"], read_data["correctness"])
        ]
        average_acc = round(read_data["correctness"].mean() * 100, 3)

    # NOTE(review): assumes COLUMN_NAMES lists exactly the five record keys
    # used above (the table below is configured with five column widths).
    data = pd.DataFrame(rows, columns=COLUMN_NAMES)
    return data, average_acc


def _refresh_board(task_name, dataset_name, example_name):
    """Return ``(table, average-as-text)`` from a single ``get_data`` call."""
    data, average_acc = get_data(task_name, dataset_name, example_name)
    return data, str(average_acc)


def _dataset_radio_for(task_name):
    """Build the Dataset radio update for the chosen task (first entry selected)."""
    datasets = TASK_DATASET_LIST[task_name]
    return gr.Radio(
        label="Dataset",
        choices=datasets,
        value=datasets[0],
        interactive=True,
    )


# ------------ Gradio UI ------------
with gr.Blocks() as demo_board:
    gr.HTML(DND_HEADER)
    gr.Markdown(DND_INTRODUCTION)

    task = gr.Radio(
        label="Task",
        choices=TASK_LIST,
        value=TASK_LIST[0],
        interactive=True,
    )
    dataset = gr.Radio(
        label="Dataset",
        choices=TASK_DATASET_LIST[task.value],
        value=TASK_DATASET_LIST[task.value][0],
        interactive=True,
    )
    example = gr.Radio(
        label="Example",
        choices=EXAMPLE_LIST,
        value=EXAMPLE_LIST[0],
        interactive=True,
    )

    # Average accuracy (placed above the prompt table).
    average_acc_display = gr.Textbox(
        label="Average Accuracy (%)",
        value=lambda: str(get_data(task.value, dataset.value, example.value)[1]),
        interactive=False,
        visible=True,
        scale=0,
        max_lines=1,
        min_width=160,
    )

    # Prompt table.
    board = gr.components.Dataframe(
        value=lambda: get_data(task.value, dataset.value, example.value)[0],
        column_widths=["60%", "10%", "10%", "10%", "10%"],
        headers=COLUMN_NAMES,
        type="pandas",
        datatype=DATA_TITLE_TYPE,
        interactive=False,
        visible=True,
        max_height=500,
    )

    # Linked update: task -> dataset choices, then table + average accuracy.
    # The refresh is chained with .then() so that it runs after the dataset
    # radio has been replaced and never pairs the new task with the dataset
    # that belonged to the previous task.
    task.change(
        _dataset_radio_for,
        inputs=[task],
        outputs=dataset,
    ).then(
        _refresh_board,
        inputs=[task, dataset, example],
        outputs=[board, average_acc_display],
    )

    # Linked update: dataset / example -> table + average accuracy.
    for component in (dataset, example):
        component.change(
            _refresh_board,
            inputs=[task, dataset, example],
            outputs=[board, average_acc_display],
        )

    # Download buttons.
    with gr.Row():
        json_downloader = gr.DownloadButton("Download JSON", visible=True)
        model_downloader = gr.DownloadButton("Download Model", visible=True)

    json_downloader.click(
        fn=get_download_link_json,
        inputs=[task, dataset, example],
        outputs=json_downloader,
    )
    model_downloader.click(
        fn=get_download_link_model,
        inputs=[task, dataset, example],
        outputs=model_downloader,
    )

    # Citation text.
    citation_button = gr.Textbox(
        value=CITATION_BUTTON_TEXT,
        label=CITATION_BUTTON_LABEL,
        elem_id="citation-button",
        lines=6,
        show_copy_button=True,
    )

# Launch the app.
demo_board.launch()