|
import argparse
import os

import yaml
|
|
|
|
|
|
|
|
|
|
|
|
|
# Per-language prompt fragments, keyed by language code.
#
# gen_lang_yamls() writes one yaml file per key and interpolates the four
# values of each entry into that file's ``doc_to_choice`` template:
#   QUESTION_WORD        word placed after the premise, followed by "?"
#   ENTAILMENT_LABEL     connective used for the first choice
#   NEUTRAL_LABEL        connective used for the second choice
#   CONTRADICTION_LABEL  connective used for the third choice
#
# These strings are emitted verbatim into the generated prompts, so any edit
# here changes the generated task definitions.
LANGUAGES = {
    "ar": {
        "QUESTION_WORD": "صحيح",
        "ENTAILMENT_LABEL": "نعم",
        "NEUTRAL_LABEL": "لذا",
        # NOTE(review): this looks like the word for "number" rather than
        # "no" -- possibly a mistranslation; confirm with a native speaker
        # before changing, since it alters the generated prompts.
        "CONTRADICTION_LABEL": "رقم",
    },
    "bg": {
        "QUESTION_WORD": "правилно",
        "ENTAILMENT_LABEL": "да",
        "NEUTRAL_LABEL": "така",
        "CONTRADICTION_LABEL": "не",
    },
    "de": {
        "QUESTION_WORD": "richtig",
        "ENTAILMENT_LABEL": "Ja",
        "NEUTRAL_LABEL": "Auch",
        "CONTRADICTION_LABEL": "Nein",
    },
    "el": {
        "QUESTION_WORD": "σωστός",
        "ENTAILMENT_LABEL": "Ναί",
        "NEUTRAL_LABEL": "Έτσι",
        "CONTRADICTION_LABEL": "όχι",
    },
    "en": {
        "QUESTION_WORD": "right",
        "ENTAILMENT_LABEL": "Yes",
        "NEUTRAL_LABEL": "Also",
        "CONTRADICTION_LABEL": "No",
    },
    "es": {
        "QUESTION_WORD": "correcto",
        "ENTAILMENT_LABEL": "Sí",
        "NEUTRAL_LABEL": "Asi que",
        "CONTRADICTION_LABEL": "No",
    },
    "fr": {
        "QUESTION_WORD": "correct",
        "ENTAILMENT_LABEL": "Oui",
        "NEUTRAL_LABEL": "Aussi",
        "CONTRADICTION_LABEL": "Non",
    },
    "hi": {
        "QUESTION_WORD": "सही",
        "ENTAILMENT_LABEL": "हाँ",
        "NEUTRAL_LABEL": "इसलिए",
        "CONTRADICTION_LABEL": "नहीं",
    },
    "ru": {
        "QUESTION_WORD": "правильно",
        "ENTAILMENT_LABEL": "Да",
        "NEUTRAL_LABEL": "Так",
        "CONTRADICTION_LABEL": "Нет",
    },
    "sw": {
        "QUESTION_WORD": "sahihi",
        "ENTAILMENT_LABEL": "Ndiyo",
        "NEUTRAL_LABEL": "Hivyo",
        "CONTRADICTION_LABEL": "Hapana",
    },
    "th": {
        "QUESTION_WORD": "ถูกต้อง",
        "ENTAILMENT_LABEL": "ใช่",
        "NEUTRAL_LABEL": "ดังนั้น",
        "CONTRADICTION_LABEL": "ไม่",
    },
    "tr": {
        "QUESTION_WORD": "doğru",
        "ENTAILMENT_LABEL": "Evet",
        "NEUTRAL_LABEL": "Böylece",
        "CONTRADICTION_LABEL": "Hayır",
    },
    "ur": {
        "QUESTION_WORD": "صحیح",
        "ENTAILMENT_LABEL": "جی ہاں",
        "NEUTRAL_LABEL": "اس لئے",
        "CONTRADICTION_LABEL": "نہیں",
    },
    "vi": {
        "QUESTION_WORD": "đúng",
        "ENTAILMENT_LABEL": "Vâng",
        "NEUTRAL_LABEL": "Vì vậy",
        "CONTRADICTION_LABEL": "Không",
    },
    "zh": {
        "QUESTION_WORD": "正确",
        "ENTAILMENT_LABEL": "是的",
        "NEUTRAL_LABEL": "所以",
        "CONTRADICTION_LABEL": "不是的",
    },
}
|
|
|
|
|
def gen_lang_yamls(output_dir: str, overwrite: bool) -> None:
    """
    Generate a yaml file for each language.

    One ``xnli_<lang>.yaml`` file is written per entry in ``LANGUAGES``. Each
    file starts with a "generated" header comment followed by a yaml mapping
    whose ``doc_to_choice`` template lists the three candidate continuations
    (entailment, neutral, contradiction) built from that language's labels.

    :param output_dir: The directory to output the files to.
    :param overwrite: Whether to overwrite files if they already exist.
    :raises FileExistsError: If ``overwrite`` is false and at least one target
        file already exists. Files that did not exist are still written; the
        error message names every file that was skipped.
    """
    label_keys = ("ENTAILMENT_LABEL", "NEUTRAL_LABEL", "CONTRADICTION_LABEL")
    skipped = []

    for lang, labels in LANGUAGES.items():
        file_name = f"xnli_{lang}.yaml"
        question_word = labels["QUESTION_WORD"]

        # One template expression per label, comma-separated, in the fixed
        # order entailment / neutral / contradiction.
        choices = ",".join(
            f'premise+", {question_word}? {labels[key]}, "+hypothesis'
            for key in label_keys
        )
        config = {
            "include": "xnli_common_yaml",
            "dataset_name": lang,
            "task": f"xnli_{lang}",
            "doc_to_text": "",
            "doc_to_choice": "{{[" + choices + "]}}",
        }

        # Mode "x" makes open() fail instead of clobbering an existing file.
        # Only the open() call is guarded so that unrelated errors are not
        # masked; existing files are collected and reported together below.
        try:
            f = open(
                os.path.join(output_dir, file_name),
                "w" if overwrite else "x",
                encoding="utf8",
            )
        except FileExistsError:
            skipped.append(file_name)
            continue

        with f:
            f.write("# Generated by utils.py\n")
            yaml.dump(config, f, allow_unicode=True)

    if skipped:
        raise FileExistsError(
            "Files were not created because they already exist (use --overwrite flag):"
            f" {', '.join(skipped)}"
        )
|
|
|
|
|
def main() -> None:
    """Read the command-line options and write the per-language yaml files."""
    cli = argparse.ArgumentParser()

    # A bare flag: absent means False, present means True.
    cli.add_argument(
        "--overwrite",
        action="store_true",
        help="Overwrite files if they already exist",
    )
    cli.add_argument(
        "--output-dir",
        default=".",
        help="Directory to write yaml files to",
    )

    options = cli.parse_args()
    gen_lang_yamls(options.output_dir, options.overwrite)
|
|
|
|
|
# Run the generator only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
|
|