File size: 5,263 Bytes
6f7189c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import os.path
import gradio as gr
import pandas as pd
from constants import *

# ------------ Download links ------------
def get_download_link_model(task, dataset, example):
    """Return the relative path of the ``.zip`` weight archive for a selection.

    Each argument is translated to a path component through its mapping
    (``TASK_PATH_MAPPING``, ``DATASET_PATH_MAPPING``, ``EXAMPLE_PATH_MAPPING``);
    a value missing from its mapping raises ``KeyError``.

    The result has the form ``data/<task>/<dataset>/weight/<example>.zip``.
    The path is only built here; its existence is not checked.
    """
    path_parts = (
        "data",
        TASK_PATH_MAPPING[task],
        DATASET_PATH_MAPPING[dataset],
        "weight",
        f"{EXAMPLE_PATH_MAPPING[example]}.zip",
    )
    return os.path.join(*path_parts)

def get_download_link_json(task, dataset, example):
    """Return the relative path of the JSON file for a selection.

    Each argument is translated to a path component through its mapping
    (``TASK_PATH_MAPPING``, ``DATASET_PATH_MAPPING``, ``EXAMPLE_PATH_MAPPING``);
    a value missing from its mapping raises ``KeyError``.

    The result has the form ``data/<task>/<dataset>/json/<example>.<ext>``,
    where ``<ext>`` is ``jsonl`` when the task path is ``"common"`` and
    ``json`` for every other task. The path's existence is not checked.
    """
    task_dir = TASK_PATH_MAPPING[task]
    dataset_dir = DATASET_PATH_MAPPING[dataset]
    file_stem = EXAMPLE_PATH_MAPPING[example]
    # Only the "common" task uses the line-delimited extension.
    extension = "jsonl" if task_dir == "common" else "json"
    return os.path.join("data", task_dir, dataset_dir, "json", f"{file_stem}.{extension}")

# ------------ Data loading + average accuracy ------------
def get_data(task, dataset, example):
    """Load the per-prompt results of a selection and compute its average score.

    The results are read from ``data/<task>/<dataset>/csv/<example>.csv``,
    where each path component comes from the corresponding ``*_PATH_MAPPING``
    (an unknown key raises ``KeyError``).

    Returns:
        A ``(table, average)`` tuple, or ``(None, None)`` when the CSV file
        does not exist.

        * Task path ``"coding"``: one row per CSV row with the ``prompt`` and
          the ``pass@1`` / ``pass@5`` / ``pass@10`` values multiplied by 100
          and rounded to 3 decimals; ``Correctness`` is ``"N/A"``. ``average``
          is the string ``"<p1> / <p5> / <p10>"`` of the column means (same
          scaling) when ``"HumanEval"`` occurs in ``dataset``, else ``None``.
        * Task path ``"common"`` or ``"math"``: one row per CSV row with the
          ``prompt``, empty pass columns, and a check/cross mark chosen by the
          truthiness of the ``correctness`` value. ``average`` is the mean of
          ``correctness`` multiplied by 100 and rounded to 3 decimals.
        * Any other task path: an empty table and ``None``.
    """
    _task_path = TASK_PATH_MAPPING[task]
    _dataset_path = DATASET_PATH_MAPPING[dataset]
    _example_path = EXAMPLE_PATH_MAPPING[example]
    csv_file = os.path.join("data", _task_path, _dataset_path, "csv", f"{_example_path}.csv")
    if not os.path.exists(csv_file):
        return None, None

    read_data = pd.read_csv(csv_file)
    # Rows are collected in a plain list and turned into a DataFrame once;
    # concatenating a one-row frame per iteration copies the whole table each
    # time and starts from an empty frame, which recent pandas warns about.
    rows = []
    average_acc = None

    if _task_path == "coding":
        rows = [
            {
                "Prompt": row["prompt"],
                "Pass@1": round(float(row["pass@1"]) * 100, 3),
                "Pass@5": round(float(row["pass@5"]) * 100, 3),
                "Pass@10": round(float(row["pass@10"]) * 100, 3),
                "Correctness": "N/A",
            }
            for _, row in read_data.iterrows()
        ]
        # The three-column average is only computed for HumanEval datasets.
        if "HumanEval" in dataset:
            p1_mean = round(read_data["pass@1"].mean() * 100, 3)
            p5_mean = round(read_data["pass@5"].mean() * 100, 3)
            p10_mean = round(read_data["pass@10"].mean() * 100, 3)
            average_acc = f"{p1_mean} / {p5_mean} / {p10_mean}"
    elif _task_path in ("common", "math"):
        rows = [
            {
                "Prompt": row["prompt"],
                "Pass@1": None,
                "Pass@5": None,
                "Pass@10": None,
                "Correctness": "✅" if row["correctness"] else "❌",
            }
            for _, row in read_data.iterrows()
        ]
        average_acc = round(read_data["correctness"].mean() * 100, 3)

    # COLUMN_NAMES fixes the column set and order (also for the empty table).
    # Assumes COLUMN_NAMES lists exactly the five keys built above — TODO
    # confirm against the constants module.
    data = pd.DataFrame(rows, columns=COLUMN_NAMES)
    return data, average_acc

# ------------ Gradio UI ------------
def _format_average(average_acc):
    """Return the text for the average-accuracy box ("N/A" when there is no average)."""
    return "N/A" if average_acc is None else str(average_acc)


def _refresh_results(t, d, e):
    """Load the selection once and return ``(table, average text)`` for the UI."""
    # A single get_data call per event: the CSV is read and parsed only once.
    table, average_acc = get_data(t, d, e)
    return table, _format_average(average_acc)


with gr.Blocks() as demo_board:
    gr.HTML(DND_HEADER)
    gr.Markdown(DND_INTRODUCTION)

    # Selection controls; the dataset choices depend on the selected task.
    task = gr.Radio(
        label="Task",
        choices=TASK_LIST,
        value=TASK_LIST[0],
        interactive=True,
    )
    dataset = gr.Radio(
        label="Dataset",
        choices=TASK_DATASET_LIST[task.value],
        value=TASK_DATASET_LIST[task.value][0],
        interactive=True
    )
    example = gr.Radio(
        label="Example",
        choices=EXAMPLE_LIST,
        value=EXAMPLE_LIST[0],
        interactive=True,
    )

    # Average accuracy (placed above the prompt table).
    average_acc_display = gr.Textbox(
        label="Average Accuracy (%)",
        value=lambda: _format_average(
            get_data(task.value, dataset.value, example.value)[1]
        ),
        interactive=False,
        visible=True,
        scale=0,
        max_lines=1,
        min_width=160
    )

    # Prompt table.
    board = gr.components.Dataframe(
        value=lambda: get_data(task.value, dataset.value, example.value)[0],
        column_widths=["60%", "10%", "10%", "10%", "10%"],
        headers=COLUMN_NAMES,
        type="pandas",
        datatype=DATA_TITLE_TYPE,
        interactive=False,
        visible=True,
        max_height=500,
    )

    # Linked update: task -> dataset choices (reset to the task's first dataset).
    task.change(
        lambda t: gr.Radio(
            label="Dataset",
            choices=TASK_DATASET_LIST[t],
            value=TASK_DATASET_LIST[t][0],
            interactive=True,
        ),
        inputs=[task],
        outputs=dataset
    )

    # Linked update: task / dataset / example -> table + average accuracy.
    for component in [task, dataset, example]:
        component.change(
            _refresh_results,
            inputs=[task, dataset, example],
            outputs=[board, average_acc_display]
        )

    # Download buttons.
    # NOTE(review): each handler returns a file path that becomes the button's
    # value only after the click event has fired — confirm that the first
    # click already delivers the file rather than just arming the button.
    with gr.Row():
        json_downloader = gr.DownloadButton("Download JSON", visible=True)
        model_downloader = gr.DownloadButton("Download Model", visible=True)

        json_downloader.click(
            fn=get_download_link_json,
            inputs=[task, dataset, example],
            outputs=json_downloader,
        )
        model_downloader.click(
            fn=get_download_link_model,
            inputs=[task, dataset, example],
            outputs=model_downloader,
        )

    # Citation text.
    citation_button = gr.Textbox(
        value=CITATION_BUTTON_TEXT,
        label=CITATION_BUTTON_LABEL,
        elem_id="citation-button",
        lines=6,
        show_copy_button=True,
    )

# Launch the app.
demo_board.launch()