# BayesTensor's picture
# Upload folder using huggingface_hub
# 9d5b280 verified
import argparse
from typing import Dict, List
import yaml
# Languages that are part of xwinograd.
# Each code corresponds to a dataset name (subset) on HuggingFace and is used
# below as the `dataset_name` of the generated config.
# This script generates one yaml file per entry.
LANGUAGES = ["en", "fr", "jp", "pt", "ru", "zh"]
def doc_to_text(doc: Dict) -> int:
    """
    Map the document's answer label to the zero-based index of the correct choice.

    Note: the task uses the "multiple input" mode of the multiple-choice
    output-type: each choice provides a different context while the target
    stays the same, instead of one shared context with different targets.

    :param doc: A document whose "answer" entry is the string "1" or "2".
    :return: 0 when the answer is "1", 1 when the answer is "2".
    :raises KeyError: If "answer" is missing or is any other value.
    """
    label = doc["answer"]
    return {"1": 0, "2": 1}[label]
def doc_to_target(doc: Dict) -> str:
    """
    Return the target completion: the text after the first "_" in the sentence.

    The result is the same for either choice, because in "multiple input"
    mode the choices differ in their context rather than in their target.

    :param doc: A document whose "sentence" entry contains a "_" placeholder.
    :return: The part of the sentence following the first "_", with
        surrounding whitespace stripped.
    :raises ValueError: If the sentence contains no "_".
    """
    sentence = doc["sentence"]
    # str.index (not str.find) so that a sentence without a blank fails loudly.
    blank_pos = sentence.index("_")
    continuation = sentence[blank_pos + 1 :]
    return continuation.strip()
def doc_to_choice(doc: Dict) -> List[str]:
    """
    Return the contexts used as choices in "multiple input" mode.

    Each context is the sentence text before the first "_" followed by one
    of the two options.

    :param doc: A document with "sentence", "option1" and "option2" entries.
    :return: Two strings, the option1 context first and the option2 context second.
    :raises ValueError: If the sentence contains no "_".
    """
    sentence = doc["sentence"]
    blank_pos = sentence.index("_")
    # Read both options before building any string, so a missing option is
    # reported before any concatenation is attempted.
    first_option, second_option = doc["option1"], doc["option2"]
    prefix = sentence[:blank_pos]
    return [prefix + first_option, prefix + second_option]
def gen_lang_yamls(output_dir: str, overwrite: bool) -> None:
    """
    Write one ``xwinograd_<lang>.yaml`` config file per entry in ``LANGUAGES``.

    Every language is attempted even when an earlier file already exists. The
    names of the files that could not be created are collected and reported
    together in a single error once all languages have been processed.

    :param output_dir: The directory the yaml files are written into.
    :param overwrite: If true, existing files are replaced. Otherwise an
        existing file is left untouched and reported.
    :raises FileExistsError: If ``overwrite`` is false and at least one of the
        target files already exists.
    """
    # "x" is exclusive creation: open() raises FileExistsError if the file exists.
    open_mode = "w" if overwrite else "x"
    skipped = []

    for lang in LANGUAGES:
        file_name = f"xwinograd_{lang}.yaml"
        task_config = {
            "include": "xwinograd_common_yaml",
            "dataset_name": lang,
            "task": f"xwinograd_{lang}",
        }
        try:
            with open(f"{output_dir}/{file_name}", open_mode, encoding="utf-8") as out:
                out.write("# Generated by utils.py\n")
                yaml.dump(task_config, out)
        except FileExistsError:
            skipped.append(file_name)

    if not skipped:
        return

    raise FileExistsError(
        "Files were not created because they already exist (use --overwrite flag):"
        f" {', '.join(skipped)}"
    )
def main() -> None:
    """
    Command-line entry point: parse the arguments and generate the yaml files.

    Recognised flags are ``--overwrite`` (replace files that already exist)
    and ``--output-dir`` (target directory, defaulting to the current one).
    """
    cli = argparse.ArgumentParser()

    # Flags are registered in this order, which is also the order shown by --help.
    flag_specs = (
        (
            "--overwrite",
            {
                "default": False,
                "action": "store_true",
                "help": "Overwrite files if they already exist",
            },
        ),
        (
            "--output-dir",
            {"default": ".", "help": "Directory to write yaml files to"},
        ),
    )
    for flag, spec in flag_specs:
        cli.add_argument(flag, **spec)

    options = cli.parse_args()
    gen_lang_yamls(output_dir=options.output_dir, overwrite=options.overwrite)
# Generate the yaml files only when this file is executed as a script, so that
# importing it for the doc_to_* helpers has no side effects.
if __name__ == "__main__":
    main()