Spaces:
Running
Running
File size: 2,904 Bytes
1bb1c3c c82392b 1bb1c3c c82392b 1bb1c3c c82392b 1bb1c3c 0232932 c82392b 1bb1c3c c82392b 1bb1c3c c82392b 1bb1c3c c82392b 1bb1c3c c82392b 1bb1c3c c82392b 1bb1c3c c82392b 1bb1c3c c82392b 1bb1c3c c82392b 1bb1c3c c82392b 1bb1c3c c82392b 1bb1c3c 9bcc5e1 1bb1c3c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
import csv
import random
import pandas as pd
import gradio as gr
from utils import clean_dir, TMP_DIR, EN_US
# Chinese UI label -> English UI label; looked up by _L when EN_US is enabled.
ZH2EN = dict(
    [
        ("输入参与者数量", "Number of participants"),
        (
            "输入分组比率 (格式为用:隔开的数字,生成随机分组数据)",
            "Grouping ratio (numbers separated by : to generate randomized controlled trial)",
        ),
        ("状态栏", "Status"),
        ("下载随机分组数据 CSV", "Download data CSV"),
        ("随机分组数据预览", "Data preview"),
        ("随机对照试验生成", "RCT Generator"),
    ]
)
def _L(zh_txt: str) -> str:
    """Return the UI label for the active language.

    When ``EN_US`` is enabled, the English translation from ``ZH2EN`` is
    returned. A label with no entry in ``ZH2EN`` falls back to the original
    Chinese text instead of raising ``KeyError`` while the UI is being built.
    When ``EN_US`` is disabled, ``zh_txt`` is returned unchanged.
    """
    if not EN_US:
        return zh_txt
    return ZH2EN.get(zh_txt, zh_txt)
def list_to_csv(list_of_dicts: list, filename: str):
    """Write a list of dict rows to ``filename`` as a UTF-8 CSV with a header.

    The column order comes from the keys of the first row. An empty list
    produces an empty (truncated) file instead of raising ``IndexError``.

    Args:
        list_of_dicts: Rows to write. Every row may only use keys present in
            the first row; ``csv.DictWriter`` raises ``ValueError`` otherwise.
        filename: Destination path; an existing file is overwritten.
    """
    with open(filename, "w", newline="", encoding="utf-8") as csvfile:
        if not list_of_dicts:
            return  # nothing to write; leave an empty file behind
        writer = csv.DictWriter(csvfile, fieldnames=list(list_of_dicts[0]))
        writer.writeheader()
        writer.writerows(list_of_dicts)
def random_allocate(participants: int, ratio: list, out_csv: str, *, seed=None):
    """Randomly assign participants ``1..participants`` to groups by ratio.

    Group numbers are 1-based, in the order of ``ratio`` (e.g. ``[8, 1, 1]``).
    Each group boundary is computed from the cumulative ratio and then
    truncated. This spreads the rounding shortfall across groups instead of
    piling it all onto the last group. The final boundary is pinned to
    ``participants`` so that everyone is assigned.

    Args:
        participants: Number of participants; must be non-negative.
        ratio: Relative group weights; non-negative, with a positive sum.
        out_csv: Path of the CSV written via ``list_to_csv``.
        seed: Optional seed for a private ``random.Random``, for reproducible
            allocations. ``None`` (the default) shuffles with the module-level
            ``random`` state.

    Returns:
        ``(out_csv, DataFrame)``. The DataFrame has one row per participant
        with keys ``id`` and ``group``, sorted by ``id``.

    Raises:
        ValueError: If ``participants`` is negative, or ``ratio`` is empty,
            contains a negative weight, or sums to zero.
    """
    if participants < 0:
        raise ValueError(f"participants must be non-negative, got {participants}")
    if not ratio:
        raise ValueError("ratio must contain at least one group")
    if any(r < 0 for r in ratio):
        raise ValueError(f"ratio weights must be non-negative, got {ratio}")
    total = sum(ratio)
    if total <= 0:
        raise ValueError("ratio weights must have a positive sum")

    # Cumulative boundaries: multiply before dividing to keep integer-valued
    # weights exact, and truncate each cumulative edge rather than each size.
    splits = [0]
    cumulative = 0
    for r in ratio:
        cumulative += r
        splits.append(int(cumulative * participants / total))
    splits[-1] = participants  # guard against float round-off on the last edge

    ids = list(range(1, participants + 1))
    shuffle = random.shuffle if seed is None else random.Random(seed).shuffle
    shuffle(ids)

    allocation = [
        {"id": pid, "group": group}
        for group, (start, end) in enumerate(zip(splits, splits[1:]), start=1)
        for pid in ids[start:end]
    ]
    allocation.sort(key=lambda row: row["id"])
    list_to_csv(allocation, out_csv)
    return out_csv, pd.DataFrame(allocation)
def infer(participants: float, ratios: str, cache=f"{TMP_DIR}/rct"):
    """Gradio callback: parse the form inputs and run ``random_allocate``.

    Args:
        participants: Value of the participant-count ``gr.Number``; truncated
            with ``int()``.
        ratios: Group ratios separated by ``:``, e.g. ``"8:1:1"``. The
            full-width ``：`` is accepted as well. Entries that are not
            greater than 0 are ignored.
        cache: Working directory, cleared with ``clean_dir`` on each call; the
            CSV is written to ``{cache}/output.csv``.

    Returns:
        ``(status, out_csv, previews)``. ``status`` is ``"Success"``, or the
        text of the exception that stopped the run; in that case ``out_csv``
        and ``previews`` are ``None``.
    """
    status = "Success"
    out_csv = previews = None
    try:
        # Accept the full-width colon that Chinese input methods often produce.
        ratio_list = ratios.replace("：", ":").split(":")
        clean_dir(cache)
        ratio = []
        for r in ratio_list:
            value = float(r.strip())
            if value > 0:  # zero/negative weights are skipped
                ratio.append(value)
        if not ratio:
            raise ValueError("At least one group ratio must be greater than 0")
        out_csv, previews = random_allocate(
            int(participants), ratio, f"{cache}/output.csv"
        )
    except Exception as e:  # UI boundary: report any failure in the status box
        status = f"{e}"
    return status, out_csv, previews
if __name__ == "__main__":
    # Form inputs: participant count and a ":"-separated ratio string.
    input_components = [
        gr.Number(label=_L("输入参与者数量"), value=10),
        gr.Textbox(
            label=_L("输入分组比率 (格式为用:隔开的数字,生成随机分组数据)"),
            value="8:1:1",
        ),
    ]
    # Outputs, matching infer's (status, out_csv, previews) return tuple.
    output_components = [
        gr.Textbox(label=_L("状态栏"), show_copy_button=True),
        gr.File(label=_L("下载随机分组数据 CSV")),
        gr.Dataframe(label=_L("随机分组数据预览")),
    ]
    demo = gr.Interface(
        fn=infer,
        inputs=input_components,
        outputs=output_components,
        flagging_mode="never",
        title=_L("随机对照试验生成"),
    )
    demo.launch()
|