|
import argparse |
|
|
|
import yaml |
|
|
|
|
|
languages = [ |
|
"eng", |
|
"amh", |
|
"ibo", |
|
"fra", |
|
"sna", |
|
"lin", |
|
"wol", |
|
"ewe", |
|
"lug", |
|
"xho", |
|
"kin", |
|
"twi", |
|
"zul", |
|
"orm", |
|
"yor", |
|
"hau", |
|
"sot", |
|
"swa", |
|
] |
|
|
|
languages_REGEX = { |
|
"eng": "The answer is (\\-?[0-9\\.\\,]+)", |
|
"amh": "መልሱ (\\-?[0-9\\.\\,]+)", |
|
"ibo": "Azịza ya bụ (\\-?[0-9\\.\\,]+)", |
|
"fra": "La réponse est(\\-?[0-9\\.\\,]+)", |
|
"sna": "Mhinduro kumubvunzo ndi (\\-?[0-9\\.\\,]+)", |
|
"lin": "Eyano ezali (\\-?[0-9\\.\\,]+)", |
|
"wol": "Tontu li (\\-?[0-9\\.\\,]+)", |
|
"ewe": "ŋuɖoɖoae nye (\\-?[0-9\\.\\,]+)", |
|
"lug": "Ansa eri (\\-?[0-9\\.\\,]+)", |
|
"xho": "Impendulo ngu (\\-?[0-9\\.\\,]+)", |
|
"kin": "Igisubizo ni (\\-?[0-9\\.\\,]+)", |
|
"twi": "Ne nnyiano yɛ (\\-?[0-9\\.\\,]+)", |
|
"zul": "Impendulo ithi (\\-?[0-9\\.\\,]+)", |
|
"orm": "Deebiin isaa (\\-?[0-9\\.\\,]+)", |
|
"yor": "Ìdáhùn náà ni (\\-?[0-9\\.\\,]+)", |
|
"hau": "Amsar ita ce (\\-?[0-9\\.\\,]+)", |
|
"sot": "Karabo ke (\\-?[0-9\\.\\,]+)", |
|
"swa": "Jibu ni (\\-?[0-9\\.\\,]+)", |
|
} |
|
|
|
LANGUAGES = {} |
|
|
|
for lang in languages: |
|
if lang == "amh": |
|
LANGUAGES[lang] = { |
|
"QUESTION": "ጥያቄ:", |
|
"ANSWER": "በቅደም ተከተል መልስ:", |
|
"DIRECT": "Answer:", |
|
"REGEX": languages_REGEX[lang], |
|
} |
|
elif lang == "yor": |
|
LANGUAGES[lang] = { |
|
"QUESTION": "Ìbéèrè:", |
|
"ANSWER": "Ìdáhùn lẹ́sẹsẹ:", |
|
"DIRECT": "Answer:", |
|
"REGEX": languages_REGEX[lang], |
|
} |
|
|
|
else: |
|
LANGUAGES[lang] = { |
|
"QUESTION": "Question:", |
|
"ANSWER": "Step-by-Step Answer:", |
|
"DIRECT": "Answer:", |
|
"REGEX": languages_REGEX[lang], |
|
} |
|
|
|
|
|
def add_regex_pattern(regex_pattern): |
|
if regex_pattern is None: |
|
return {} |
|
return { |
|
"filter_list": [ |
|
{ |
|
"name": "strict-match", |
|
"filter": [ |
|
{ |
|
"function": "regex", |
|
"regex_pattern": f"""{regex_pattern}""", |
|
}, |
|
{ |
|
"function": "take_first", |
|
}, |
|
], |
|
}, |
|
{ |
|
"name": "flexible-extract", |
|
"filter": [ |
|
{ |
|
"function": "regex", |
|
"regex_pattern": """(-?[$0-9.,]{2,})|(-?[0-9]+)""", |
|
"group_select": -1, |
|
}, |
|
{ |
|
"function": "take_first", |
|
}, |
|
], |
|
}, |
|
], |
|
} |
|
|
|
|
|
def gen_lang_yamls(output_dir: str, overwrite: bool, mode: str) -> None: |
|
""" |
|
Generate a yaml file for each language. |
|
|
|
:param output_dir: The directory to output the files to. |
|
:param overwrite: Whether to overwrite files if they already exist. |
|
""" |
|
err = [] |
|
for lang in LANGUAGES.keys(): |
|
try: |
|
yaml_template = "cot_yaml" |
|
filter_list = {} |
|
DELIMITER = None |
|
if mode == "direct": |
|
ANSWER = LANGUAGES["eng"]["DIRECT"] |
|
QUESTION = LANGUAGES["eng"]["QUESTION"] |
|
REGEX = None |
|
task_name = f"afrimgsm_direct_{lang}" |
|
yaml_template = "direct_yaml" |
|
if mode == "direct-native": |
|
ANSWER = LANGUAGES[lang]["DIRECT"] |
|
QUESTION = LANGUAGES[lang]["QUESTION"] |
|
REGEX = None |
|
task_name = f"afrimgsm_direct_native_{lang}" |
|
yaml_template = "direct_native_yaml" |
|
elif mode == "native-cot": |
|
ANSWER = LANGUAGES[lang]["ANSWER"] |
|
REGEX = LANGUAGES[lang]["REGEX"] |
|
QUESTION = LANGUAGES[lang]["QUESTION"] |
|
task_name = f"afrimgsm_native_cot_{lang}" |
|
filter_list = add_regex_pattern(REGEX) |
|
DELIMITER = "" if lang in ["zh", "ja"] else None |
|
elif mode == "en-cot": |
|
ANSWER = LANGUAGES["eng"]["ANSWER"] |
|
REGEX = LANGUAGES["eng"]["REGEX"] |
|
QUESTION = LANGUAGES["eng"]["QUESTION"] |
|
task_name = f"afrimgsm_en_cot_{lang}" |
|
elif mode == "translate-direct": |
|
ANSWER = LANGUAGES["eng"]["DIRECT"] |
|
QUESTION = LANGUAGES["eng"]["QUESTION"] |
|
REGEX = None |
|
task_name = f"afrimgsm_translate_direct_{lang}" |
|
yaml_template = "translate_direct_yaml" |
|
|
|
file_name = f"{task_name}.yaml" |
|
ANSWER_TO_SKIP = len(LANGUAGES[lang]["ANSWER"]) + 1 |
|
with open( |
|
f"{output_dir}/{file_name}", "w" if overwrite else "x", encoding="utf8" |
|
) as f: |
|
f.write("# Generated by utils.py\n") |
|
yaml.dump( |
|
{ |
|
"include": yaml_template, |
|
"dataset_name": lang, |
|
"task": f"{task_name}", |
|
"doc_to_text": f"""{{% if answer is not none %}}""" |
|
f"""{{{{question+"\\n{ANSWER}"}}}}""" |
|
f"""{{% else %}}""" |
|
f"""{{{{"{QUESTION} "+question+"\\n{ANSWER}"}}}}""" |
|
f"""{{% endif %}}""", |
|
"doc_to_target": f"""{{% if answer is not none %}}""" |
|
f"""{{{{answer[{ANSWER_TO_SKIP}:]}}}}""" |
|
f"""{{% else %}}""" |
|
f"""{{{{answer_number|string}}}}""" |
|
f"""{{% endif %}}""", |
|
**filter_list, |
|
"generation_kwargs": { |
|
"until": [QUESTION, "</s>", "<|im_end|>"], |
|
"do_sample": False, |
|
}, |
|
**({"target_delimiter": DELIMITER} if DELIMITER else {}), |
|
}, |
|
f, |
|
allow_unicode=True, |
|
width=float("inf"), |
|
) |
|
except FileExistsError: |
|
err.append(file_name) |
|
|
|
if len(err) > 0: |
|
raise FileExistsError( |
|
"Files were not created because they already exist (use --overwrite flag):" |
|
f" {', '.join(err)}" |
|
) |
|
|
|
|
|
def main() -> None: |
|
"""Parse CLI args and generate language-specific yaml files.""" |
|
parser = argparse.ArgumentParser() |
|
parser.add_argument( |
|
"--overwrite", |
|
default=False, |
|
action="store_true", |
|
help="Overwrite files if they already exist", |
|
) |
|
parser.add_argument( |
|
"--output-dir", default=".", help="Directory to write yaml files to" |
|
) |
|
parser.add_argument( |
|
"--mode", |
|
default="native-cot", |
|
choices=["direct", "direct-native", "native-cot", "en-cot", "translate-direct"], |
|
help="Mode of chain-of-thought", |
|
) |
|
args = parser.parse_args() |
|
|
|
gen_lang_yamls(output_dir=args.output_dir, overwrite=args.overwrite, mode=args.mode) |
|
|
|
|
|
if __name__ == "__main__": |
|
main() |
|
|