|
""" |
|
Take in a YAML, and output all other splits with this YAML |
|
""" |
|
|
|
import argparse |
|
import os |
|
|
|
import requests |
|
import yaml |
|
from tqdm import tqdm |
|
|
|
from lm_eval.utils import logging |
|
|
|
|
|
API_URL = "https://datasets-server.huggingface.co/splits?dataset=facebook/belebele" |
|
|
|
|
|
def parse_args(): |
|
parser = argparse.ArgumentParser() |
|
parser.add_argument("--base_yaml_path", required=True) |
|
parser.add_argument("--save_prefix_path", default="belebele") |
|
parser.add_argument("--cot_prompt_path", default=None) |
|
parser.add_argument("--task_prefix", default="") |
|
return parser.parse_args() |
|
|
|
|
|
if __name__ == "__main__": |
|
args = parse_args() |
|
|
|
|
|
base_yaml_name = os.path.split(args.base_yaml_path)[-1] |
|
with open(args.base_yaml_path, encoding="utf-8") as f: |
|
base_yaml = yaml.full_load(f) |
|
|
|
if args.cot_prompt_path is not None: |
|
import json |
|
|
|
with open(args.cot_prompt_path, encoding="utf-8") as f: |
|
cot_file = json.load(f) |
|
|
|
def query(): |
|
response = requests.get(API_URL) |
|
return response.json()["splits"] |
|
|
|
print(query()) |
|
languages = [split["split"] for split in query()] |
|
|
|
for lang in tqdm([lang for lang in languages if "default" not in lang]): |
|
yaml_dict = { |
|
"include": base_yaml_name, |
|
"task": f"belebele_{args.task_prefix}_{lang}" |
|
if args.task_prefix != "" |
|
else f"belebele_{lang}", |
|
"test_split": lang, |
|
"fewshot_split": lang, |
|
} |
|
|
|
file_save_path = args.save_prefix_path + f"_{lang}.yaml" |
|
logging.info(f"Saving yaml for subset {lang} to {file_save_path}") |
|
with open(file_save_path, "w", encoding="utf-8") as yaml_file: |
|
yaml.dump( |
|
yaml_dict, |
|
yaml_file, |
|
width=float("inf"), |
|
allow_unicode=True, |
|
default_style='"', |
|
) |
|
|
|
|
|
|
|
group_yaml_dict = { |
|
"group": f"belebele_{args.task_prefix}" |
|
if args.task_prefix != "" |
|
else "belebele", |
|
"task": [ |
|
( |
|
f"belebele_{args.task_prefix}_{lang}" |
|
if args.task_prefix != "" |
|
else f"belebele_{lang}" |
|
) |
|
for lang in languages |
|
if "default" not in lang |
|
], |
|
"aggregate_metric_list": [ |
|
{"metric": "acc", "aggregation": "mean", "weight_by_size": False}, |
|
{"metric": "acc_norm", "aggregation": "mean", "weight_by_size": False}, |
|
], |
|
"metadata": {"version": 0.0}, |
|
} |
|
|
|
file_save_path = "_" + args.save_prefix_path + f"{args.task_prefix}.yaml" |
|
|
|
with open(file_save_path, "w", encoding="utf-8") as group_yaml_file: |
|
yaml.dump( |
|
group_yaml_dict, |
|
group_yaml_file, |
|
width=float("inf"), |
|
allow_unicode=True, |
|
default_style='"', |
|
) |
|
|