File size: 2,904 Bytes
1bb1c3c
 
c82392b
1bb1c3c
 
c82392b
 
1bb1c3c
 
c82392b
1bb1c3c
 
0232932
c82392b
 
 
 
 
 
 
1bb1c3c
 
 
 
 
 
 
 
c82392b
 
1bb1c3c
 
 
 
 
c82392b
1bb1c3c
 
 
 
 
 
 
 
 
 
c82392b
1bb1c3c
 
 
c82392b
 
 
1bb1c3c
 
c82392b
1bb1c3c
c82392b
1bb1c3c
c82392b
1bb1c3c
 
 
 
c82392b
1bb1c3c
 
 
c82392b
 
 
 
1bb1c3c
c82392b
 
 
1bb1c3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9bcc5e1
1bb1c3c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import csv
import math
import random

import pandas as pd
import gradio as gr

from utils import clean_dir, TMP_DIR, EN_US

# Chinese UI label -> English label; looked up by _L() when EN_US is enabled.
ZH2EN = dict(
    (
        ("输入参与者数量", "Number of participants"),
        (
            "输入分组比率 (格式为用:隔开的数字,生成随机分组数据)",
            "Grouping ratio (numbers separated by : to generate randomized controlled trial)",
        ),
        ("状态栏", "Status"),
        ("下载随机分组数据 CSV", "Download data CSV"),
        ("随机分组数据预览", "Data preview"),
        ("随机对照试验生成", "RCT Generator"),
    )
)


def _L(zh_txt: str) -> str:
    """Return the UI label for *zh_txt* in the active language.

    When ``EN_US`` is truthy, the English translation from ``ZH2EN`` is
    returned. A label with no translation falls back to the original Chinese
    text rather than raising ``KeyError`` while the UI is being built. When
    ``EN_US`` is falsy, *zh_txt* is returned unchanged.
    """
    return ZH2EN.get(zh_txt, zh_txt) if EN_US else zh_txt


def list_to_csv(list_of_dicts: list, filename: str):
    """Write a list of dicts to *filename* as a UTF-8 CSV file.

    The column order comes from the keys of the first dict. Every row is
    expected to use those same keys; ``csv.DictWriter`` raises ``ValueError``
    on unexpected keys. An empty list produces an empty file instead of
    raising ``IndexError``.

    Args:
        list_of_dicts: Rows to write, one mapping per row.
        filename: Destination path; an existing file is overwritten.
    """
    # newline="" lets the csv module control line endings itself.
    with open(filename, "w", newline="", encoding="utf-8") as csvfile:
        if not list_of_dicts:
            # No rows means no columns to infer: leave the file empty.
            return
        writer = csv.DictWriter(csvfile, fieldnames=list(dict(list_of_dicts[0])))
        writer.writeheader()
        writer.writerows(list_of_dicts)


def random_allocate(participants: int, ratio: list, out_csv: str, rng=None):
    """Randomly assign participants 1..N to groups in proportion to *ratio*.

    Group boundaries come from the rounded *cumulative* share of the ratio.
    Every group's size is therefore within one participant of its exact
    quota, and rounding error is spread across groups instead of piling up in
    the last one. For example, 10 participants split 1:1:1:1:1:1:1 give sizes
    1,2,1,2,1,2,1 rather than 1,1,1,1,1,1,4.

    Args:
        participants: Number of participants, identified as 1..participants.
        ratio: Non-negative group weights, e.g. ``[8, 1, 1]``. Group ``k``
            (1-based) receives about ``ratio[k-1] / sum(ratio)`` of the
            participants.
        out_csv: Path of the CSV file to write, with columns ``id`` and
            ``group``.
        rng: Optional ``random.Random`` instance used for shuffling, for
            reproducible allocations. Defaults to the module-level ``random``
            generator, so ``random.seed`` still applies.

    Returns:
        Tuple ``(out_csv, DataFrame)``. The rows (one per participant, sorted
        by ``id``) are both written to ``out_csv`` and returned as a DataFrame.

    Raises:
        ValueError: If ``participants`` is less than 1, or ``ratio`` is empty,
            contains a negative weight, or does not sum to a positive number.
    """
    if participants < 1:
        raise ValueError(f"participants must be at least 1, got {participants}")
    if not ratio or any(r < 0 for r in ratio):
        raise ValueError(f"ratio must be a non-empty list of non-negative weights, got {ratio}")
    total = sum(ratio)
    if not total > 0:  # also rejects NaN
        raise ValueError(f"ratio must contain at least one positive weight, got {ratio}")

    # bounds[k] is the position in the shuffled list where group k+1 starts.
    bounds = [0]
    running = 0
    for weight in ratio:
        running += weight
        bounds.append(round(running / total * participants))
    # Pin the last boundary so float error can never drop or duplicate anyone.
    bounds[-1] = participants

    ids = list(range(1, participants + 1))
    (rng.shuffle if rng is not None else random.shuffle)(ids)

    group_of = {}
    for group, (start, end) in enumerate(zip(bounds, bounds[1:]), start=1):
        for pid in ids[start:end]:
            group_of[pid] = group

    rows = [{"id": pid, "group": group_of[pid]} for pid in range(1, participants + 1)]
    list_to_csv(rows, out_csv)
    return out_csv, pd.DataFrame(rows)


# outer func
def infer(participants: float, ratios: str, cache=f"{TMP_DIR}/rct"):
    """Gradio callback: validate the UI inputs and run a random allocation.

    Args:
        participants: Participant count from the ``gr.Number`` input. It must
            be a whole number >= 1; a float such as ``10.0`` is accepted.
        ratios: Group ratio string such as ``"8:1:1"``. Zero entries are
            skipped. Negative or non-finite entries are rejected.
        cache: Directory that is cleaned and then receives ``output.csv``.

    Returns:
        ``(status, out_csv, previews)``. ``status`` is ``"Success"`` or the
        error message. ``out_csv`` and ``previews`` are the CSV path and the
        DataFrame preview, or ``None`` when an error occurred.
    """
    status = "Success"
    out_csv = previews = None
    try:
        if participants is None:
            raise ValueError("Number of participants is required")
        # Reject fractional or non-positive counts instead of silently truncating.
        if participants != int(participants) or participants < 1:
            raise ValueError(
                f"Number of participants must be a positive integer, got {participants}"
            )

        ratio = []
        for part in ratios.split(":"):
            value = float(part.strip())
            if not math.isfinite(value) or value < 0:
                raise ValueError(
                    f"Invalid ratio {part.strip()!r}: must be a finite, non-negative number"
                )
            if value > 0:  # zero-weight groups are skipped, as before
                ratio.append(value)
        if not ratio:
            raise ValueError("At least one ratio must be greater than 0")

        # Only wipe previous output once the inputs are known to be valid.
        clean_dir(cache)
        out_csv, previews = random_allocate(
            int(participants), ratio, f"{cache}/output.csv"
        )

    except Exception as e:
        # UI boundary: show any failure in the status box instead of crashing.
        status = f"{e}"

    return status, out_csv, previews


if __name__ == "__main__":
    # Build and launch the Gradio UI; labels go through _L() for localisation.
    gr.Interface(
        fn=infer,
        inputs=[
            # precision=0 makes Gradio round the count and pass an int;
            # minimum=1 rejects non-positive counts before infer() runs.
            gr.Number(label=_L("输入参与者数量"), value=10, precision=0, minimum=1),
            gr.Textbox(
                label=_L("输入分组比率 (格式为用:隔开的数字,生成随机分组数据)"),
                value="8:1:1",
            ),
        ],
        outputs=[
            gr.Textbox(label=_L("状态栏"), show_copy_button=True),
            gr.File(label=_L("下载随机分组数据 CSV")),
            gr.Dataframe(label=_L("随机分组数据预览")),
        ],
        flagging_mode="never",
        title=_L("随机对照试验生成"),
    ).launch()