Spaces:
Sleeping
Sleeping
| from itertools import cycle | |
| import random | |
| import argparse | |
| from simpletransformers.seq2seq import Seq2SeqModel | |
| import pandas as pd | |
# Seed the global RNG for reproducible runs. seed() must be *called*:
# assigning to `random.seed` would replace the function with an int,
# seed nothing, and break any later `random.seed(...)` call.
random.seed(42)
def load_conllu_dataset(datafile, join=False):
    """Read a CoNLL-U file into a DataFrame of lemmatizer inputs and targets.

    Every line that is neither blank nor a ``#`` comment is split on tabs
    and treated as a token row: column 1 is the word form, column 2 the
    lemma, column 3 the POS tag and column 5 the feature string (FEATS in
    the CoNLL-U layout).

    Round and square brackets are removed from the form unless the form is
    exactly one bracket character. Rows tagged ``PROPN`` get a capitalized
    form and lemma; every other form is lower-cased and its lemma is kept
    as written.

    Args:
        datafile: Path to a UTF-8 encoded, tab-separated CoNLL-U file.
        join: When true, ``input_text`` is ``"<form> <pos> <feats>"``
            instead of the bare form.

    Returns:
        A ``pandas.DataFrame`` with the columns ``input_text``,
        ``target_text`` (the lemma) and ``pos``, one row per token line,
        in file order.

    Raises:
        IndexError: If a token line has fewer tab-separated columns than
            are read (4, or 6 when ``join`` is true).
    """
    lone_brackets = ("(", ")", "[", "]")
    # One translation table deletes all four bracket characters in a
    # single pass instead of four chained str.replace() calls.
    drop_brackets = str.maketrans("", "", "()[]")

    rows = []
    with open(datafile, encoding='utf-8') as inp:
        for line in inp:
            # Skip comment lines and blank / whitespace-only lines.
            if line[0] == "#" or not line.strip():
                continue

            columns = line.split('\t')
            raw_form = columns[1]
            pos = columns[3]
            lemma = columns[2]

            # A token that *is* a bracket must survive; brackets embedded
            # in any other token are stripped.
            if raw_form in lone_brackets:
                form = raw_form
            else:
                form = raw_form.translate(drop_brackets)

            if pos == "PROPN":
                form = form.capitalize()
                lemma = lemma.capitalize()
            else:
                form = form.lower()

            if join:
                input_text = form + " " + pos + " " + columns[5]
            else:
                input_text = form

            rows.append([input_text, lemma, pos])

    return pd.DataFrame(rows, columns=["input_text", "target_text", "pos"])
def predict(in_file, out_file, join=False):
    """Write a copy of a CoNLL-U file with model-predicted lemmas.

    Builds a ``Seq2SeqModel`` (type ``bart``, ``use_cuda=False``) from the
    ``Futyn-Maker/RuthLemm-morphology`` checkpoint name when ``join`` is
    true and from ``Futyn-Maker/RuthLemm`` otherwise. It then predicts on
    the ``input_text`` column produced by ``load_conllu_dataset`` and
    substitutes the predictions, in order, into column 2 of every token
    line. Comment lines (starting with ``#``) and blank lines are copied
    through unchanged.

    Args:
        in_file: Path to the CoNLL-U file to read.
        out_file: Path of the file to write; it is opened in ``"w"`` mode.
        join: Passed to ``load_conllu_dataset`` and used to choose the
            checkpoint name.
    """
    model_name = (
        "Futyn-Maker/RuthLemm-morphology" if join else "Futyn-Maker/RuthLemm"
    )
    model = Seq2SeqModel(
        encoder_decoder_type="bart",
        encoder_decoder_name=model_name,
        use_cuda=False
    )

    model_inputs = load_conllu_dataset(in_file, join=join)["input_text"].tolist()
    # NOTE(review): cycle() makes next() wrap around to the first
    # prediction instead of raising when predictions run out, so a short
    # prediction list would go unnoticed. Confirm this is intended.
    lemma_stream = cycle(model.predict(model_inputs))

    with open(in_file, encoding="utf8") as inp:
        source_lines = inp.readlines()

    output_lines = []
    for line in source_lines:
        is_token_line = line[0] != "#" and line.strip()
        if not is_token_line:
            # Comments and blank lines pass through untouched.
            output_lines.append(line)
            continue
        fields = line.split("\t")
        fields[2] = next(lemma_stream)
        output_lines.append("\t".join(fields))

    with open(out_file, "w", encoding="utf8") as out:
        out.write("".join(output_lines))
if __name__ == '__main__':
    # Command-line entry point: two positional paths plus an optional
    # flag that is forwarded to predict() as its third argument.
    cli = argparse.ArgumentParser()
    positional_arguments = (
        ("input_file", "Path to the input file"),
        ("output_file", "Path to the output file"),
    )
    for arg_name, arg_help in positional_arguments:
        cli.add_argument(arg_name, type=str, help=arg_help)
    cli.add_argument(
        "--morphology", "-m",
        action="store_true",
        help="Use morphology",
    )

    options = cli.parse_args()
    predict(options.input_file, options.output_file, options.morphology)
    print("All done!")