|
""" |
|
Take in a YAML, and output all other splits with this YAML |
|
""" |
|
|
|
import argparse |
|
import os |
|
|
|
import yaml |
|
from tqdm import tqdm |
|
|
|
from lm_eval.utils import eval_logger |
|
|
|
|
|
SUBJECTS = { |
|
"agronomy": "农学", |
|
"anatomy": "解剖学", |
|
"ancient_chinese": "古汉语", |
|
"arts": "艺术学", |
|
"astronomy": "天文学", |
|
"business_ethics": "商业伦理", |
|
"chinese_civil_service_exam": "中国公务员考试", |
|
"chinese_driving_rule": "中国驾驶规则", |
|
"chinese_food_culture": "中国饮食文化", |
|
"chinese_foreign_policy": "中国外交政策", |
|
"chinese_history": "中国历史", |
|
"chinese_literature": "中国文学", |
|
"chinese_teacher_qualification": "中国教师资格", |
|
"clinical_knowledge": "临床知识", |
|
"college_actuarial_science": "大学精算学", |
|
"college_education": "大学教育学", |
|
"college_engineering_hydrology": "大学工程水文学", |
|
"college_law": "大学法律", |
|
"college_mathematics": "大学数学", |
|
"college_medical_statistics": "大学医学统计", |
|
"college_medicine": "大学医学", |
|
"computer_science": "计算机科学", |
|
"computer_security": "计算机安全", |
|
"conceptual_physics": "概念物理学", |
|
"construction_project_management": "建设工程管理", |
|
"economics": "经济学", |
|
"education": "教育学", |
|
"electrical_engineering": "电气工程", |
|
"elementary_chinese": "小学语文", |
|
"elementary_commonsense": "小学常识", |
|
"elementary_information_and_technology": "小学信息技术", |
|
"elementary_mathematics": "初等数学", |
|
"ethnology": "民族学", |
|
"food_science": "食品科学", |
|
"genetics": "遗传学", |
|
"global_facts": "全球事实", |
|
"high_school_biology": "高中生物", |
|
"high_school_chemistry": "高中化学", |
|
"high_school_geography": "高中地理", |
|
"high_school_mathematics": "高中数学", |
|
"high_school_physics": "高中物理学", |
|
"high_school_politics": "高中政治", |
|
"human_sexuality": "人类性行为", |
|
"international_law": "国际法学", |
|
"journalism": "新闻学", |
|
"jurisprudence": "法理学", |
|
"legal_and_moral_basis": "法律与道德基础", |
|
"logical": "逻辑学", |
|
"machine_learning": "机器学习", |
|
"management": "管理学", |
|
"marketing": "市场营销", |
|
"marxist_theory": "马克思主义理论", |
|
"modern_chinese": "现代汉语", |
|
"nutrition": "营养学", |
|
"philosophy": "哲学", |
|
"professional_accounting": "专业会计", |
|
"professional_law": "专业法学", |
|
"professional_medicine": "专业医学", |
|
"professional_psychology": "专业心理学", |
|
"public_relations": "公共关系", |
|
"security_study": "安全研究", |
|
"sociology": "社会学", |
|
"sports_science": "体育学", |
|
"traditional_chinese_medicine": "中医中药", |
|
"virology": "病毒学", |
|
"world_history": "世界历史", |
|
"world_religions": "世界宗教", |
|
} |
|
|
|
|
|
def parse_args(): |
|
parser = argparse.ArgumentParser() |
|
parser.add_argument("--base_yaml_path", required=True) |
|
parser.add_argument("--save_prefix_path", default="cmmlu") |
|
parser.add_argument("--cot_prompt_path", default=None) |
|
parser.add_argument("--task_prefix", default="") |
|
return parser.parse_args() |
|
|
|
|
|
if __name__ == "__main__": |
|
args = parse_args() |
|
|
|
|
|
base_yaml_name = os.path.split(args.base_yaml_path)[-1] |
|
with open(args.base_yaml_path, encoding="utf-8") as f: |
|
base_yaml = yaml.full_load(f) |
|
|
|
if args.cot_prompt_path is not None: |
|
import json |
|
|
|
with open(args.cot_prompt_path, encoding="utf-8") as f: |
|
cot_file = json.load(f) |
|
|
|
for subject_eng, subject_zh in tqdm(SUBJECTS.items()): |
|
if args.cot_prompt_path is not None: |
|
description = cot_file[subject_eng] |
|
else: |
|
description = ( |
|
f"以下是关于{subject_zh}的单项选择题,请直接给出正确答案的选项。\n\n" |
|
) |
|
|
|
yaml_dict = { |
|
"include": base_yaml_name, |
|
"task": f"cmmlu_{args.task_prefix}_{subject_eng}" |
|
if args.task_prefix != "" |
|
else f"cmmlu_{subject_eng}", |
|
"dataset_name": subject_eng, |
|
"description": description, |
|
} |
|
|
|
file_save_path = args.save_prefix_path + f"_{subject_eng}.yaml" |
|
eval_logger.info(f"Saving yaml for subset {subject_eng} to {file_save_path}") |
|
with open(file_save_path, "w", encoding="utf-8") as yaml_file: |
|
yaml.dump( |
|
yaml_dict, |
|
yaml_file, |
|
width=float("inf"), |
|
allow_unicode=True, |
|
default_style='"', |
|
) |
|
|
|
|
|
|
|
group_yaml_dict = { |
|
"group": "cmmlu", |
|
"task": [ |
|
( |
|
f"cmmlu_{args.task_prefix}_{subject_eng}" |
|
if args.task_prefix != "" |
|
else f"cmmlu_{subject_eng}" |
|
) |
|
for subject_eng in SUBJECTS.keys() |
|
], |
|
"aggregate_metric_list": [ |
|
{"metric": "acc", "aggregation": "mean", "weight_by_size": True}, |
|
{"metric": "acc_norm", "aggregation": "mean", "weight_by_size": True}, |
|
], |
|
"metadata": {"version": 0.0}, |
|
} |
|
|
|
file_save_path = "_" + args.save_prefix_path + ".yaml" |
|
|
|
with open(file_save_path, "w", encoding="utf-8") as group_yaml_file: |
|
yaml.dump( |
|
group_yaml_dict, |
|
group_yaml_file, |
|
width=float("inf"), |
|
allow_unicode=True, |
|
default_style='"', |
|
) |
|
|