|
import argparse |
|
from typing import Dict, List |
|
|
|
import yaml |
|
|
|
|
|
|
|
|
|
|
|
# XWinograd subsets used by gen_lang_yamls: each entry becomes the
# `dataset_name` of one generated `xwinograd_<lang>.yaml` task config.
LANGUAGES = [
    "en",
    "fr",
    "jp",
    "pt",
    "ru",
    "zh",
]
|
|
|
|
|
def doc_to_text(doc: Dict) -> int:
    """
    Map the gold answer label to a zero-based choice index.

    Note: We are using the "multiple input" mode of the multiple-choice
    output-type: each choice gets its own context while the target is shared.
    The "text" for a document is therefore just the index of the correct
    context, i.e. label "1" -> 0 and label "2" -> 1.

    :param doc: A dataset row; its ``"answer"`` field must be "1" or "2".
    :return: 0 or 1.
    :raises KeyError: If ``doc["answer"]`` is neither "1" nor "2".
    """
    return {"1": 0, "2": 1}[doc["answer"]]
|
|
|
|
|
def doc_to_target(doc: Dict) -> str:
    """
    Build the continuation that is scored against every choice.

    The continuation is the part of ``doc["sentence"]`` after the first "_"
    placeholder, with surrounding whitespace removed. It does not depend on
    which option is correct, because in "multiple input" mode only the
    contexts differ between choices.

    :param doc: A dataset row with a ``"sentence"`` containing "_".
    :return: The stripped text following the first "_".
    :raises ValueError: If the sentence contains no "_".
    """
    sentence = doc["sentence"]
    blank_pos = sentence.index("_")
    # Skip the placeholder character itself before taking the remainder.
    continuation = sentence[blank_pos + 1 :]
    return continuation.strip()
|
|
|
|
|
def doc_to_choice(doc: Dict) -> List[str]:
    """
    Produce one context per option for "multiple input" scoring.

    Each context is the text of ``doc["sentence"]`` that precedes the first
    "_", directly followed by ``doc["option1"]`` or ``doc["option2"]``. No
    separator is inserted. The list is ordered option1 then option2.

    :raises ValueError: If the sentence contains no "_".
    :raises KeyError: If either option field is missing.
    """
    sentence = doc["sentence"]
    prefix = sentence[: sentence.index("_")]
    # Fetch both options before concatenating, matching the lookup order.
    first, second = doc["option1"], doc["option2"]
    return [prefix + first, prefix + second]
|
|
|
|
|
def gen_lang_yamls(output_dir: str, overwrite: bool) -> None:
    """
    Write one ``xwinograd_<lang>.yaml`` task config per entry in LANGUAGES.

    Each file begins with a "# Generated by utils.py" line, followed by a
    YAML mapping with ``include``, ``dataset_name`` and ``task`` keys.

    :param output_dir: The directory to output the files to.
    :param overwrite: If True, existing files are replaced. If False, they
        are opened in exclusive-create mode, and any that already exist are
        skipped and reported.
    :raises FileExistsError: When ``overwrite`` is False and one or more
        target files already exist. Every other language is still written
        before this is raised.
    """
    # "x" makes open() fail on an existing file instead of truncating it.
    mode = "w" if overwrite else "x"
    skipped = []

    for lang in LANGUAGES:
        file_name = f"xwinograd_{lang}.yaml"
        config = {
            "include": "xwinograd_common_yaml",
            "dataset_name": lang,
            "task": f"xwinograd_{lang}",
        }
        try:
            with open(f"{output_dir}/{file_name}", mode, encoding="utf-8") as handle:
                handle.write("# Generated by utils.py\n")
                yaml.dump(config, handle)
        except FileExistsError:
            # Remember the clash and continue with the remaining languages.
            skipped.append(file_name)

    if skipped:
        raise FileExistsError(
            "Files were not created because they already exist (use --overwrite flag):"
            f" {', '.join(skipped)}"
        )
|
|
|
|
|
def main() -> None:
    """Entry point: read command-line options, then emit the per-language yaml files."""
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--overwrite",
        action="store_true",
        default=False,
        help="Overwrite files if they already exist",
    )
    cli.add_argument(
        "--output-dir",
        default=".",
        help="Directory to write yaml files to",
    )
    options = cli.parse_args()
    gen_lang_yamls(output_dir=options.output_dir, overwrite=options.overwrite)
|
|
|
|
|
# Run the CLI only when this module is executed directly, not when imported.
if __name__ == "__main__":
    main()
|
|