"""Gradio demo for SloNSpell, a Slovene word-level spelling error detector."""
import functools
import re

import gradio as gr
import nltk
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

# NLTK sentence/word tokenizer data; fetched at import time (no-op if already present).
nltk.download("punkt")
nltk.download("punkt_tab")

MODEL_NAME = "cjvt/SloBERTa-slo-word-spelling-annotator"
MAX_LENGTH = 512  # tokenizer truncation limit (in sub-word tokens) per sentence

# Runs of newlines, or runs of two or more spaces, are collapsed to a single space.
_WHITESPACE_RUN = re.compile(r"(\n)+|( ){2,}")


@functools.lru_cache(maxsize=1)
def _load_model():
    """Load the tokenizer and model once and return ``(tokenizer, model, device)``.

    Loading is lazy and cached so that ``pred_slonspell`` also works when this
    module is imported rather than executed as a script.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
    model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME)
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    # The inputs are moved to `device` before the forward pass, so the model
    # has to live on the same device.
    model.to(device)
    model.eval()
    return tokenizer, model, device


def pred_slonspell(input_text: str):
    """Flag likely misspelled words in ``input_text``.

    The text is split into sentences and words with NLTK (Slovene models).
    For each sentence, a mask token is placed after every word and the model
    output at each mask position is read as the label of the preceding word.

    Args:
        input_text: Raw user text; newlines and repeated spaces are collapsed
            to single spaces before tokenization.

    Returns:
        A list of ``(word, label)`` tuples in input order, where ``label`` is
        ``"error"`` for words predicted as misspelled and ``None`` otherwise.
        This is the format consumed by ``gr.HighlightedText``.
    """
    tokenizer, model, device = _load_model()
    mask_token = tokenizer.mask_token

    return_values = []
    input_text = _WHITESPACE_RUN.sub(" ", input_text)

    for sentence in nltk.sent_tokenize(input_text, language="slovene"):
        words = nltk.word_tokenize(sentence, language="slovene")
        if not words:
            continue

        # One mask token after every word: "w1<mask> w2<mask> ... wN <mask>".
        formatted_text = f"{mask_token} ".join(words)
        formatted_text = f"{formatted_text} {mask_token}"

        encoded_input = tokenizer(
            formatted_text, return_tensors="pt", max_length=MAX_LENGTH, truncation=True
        )
        mask_positions = encoded_input["input_ids"] == tokenizer.mask_token_id  # bool tensor

        with torch.no_grad():
            model_inputs = {k: v.to(device) for k, v in encoded_input.items()}
            # Only the first four vocabulary logits are used as class scores.
            logits = model(**model_inputs).logits[:, :, [0, 1, 2, 3]].cpu()

        probas = torch.softmax(logits, dim=-1)[0]
        relevant_probas = probas[mask_positions[0]]  # [num_masks, 4]

        # Class 0 is treated as "correct"; classes 1-3 are summed into "error".
        is_ok_proba = relevant_probas[:, [0]]
        is_err_proba = torch.sum(relevant_probas[:, 1:], dim=1, keepdim=True)
        binary_probas = torch.hstack((is_ok_proba, is_err_proba))
        preds = torch.argmax(binary_probas, dim=-1).tolist()

        labels = ["error" if pred else None for pred in preds]
        # Truncation can cut off trailing mask tokens in very long sentences;
        # words without a prediction are left unflagged instead of crashing.
        labels.extend([None] * (len(words) - len(labels)))

        return_values.extend(zip(words, labels))

    return return_values


_description = """\

SloNSpell demo

This is a simple demo setup for SloNSpell, a 🇸🇮 Slovene spelling error detection model. You can find more about the model in the model card \
cjvt/SloBERTa-slo-word-spelling-annotator.

Given an input text:

1. The input is segmented into sentences and tokenized using NLTK to prepare the model input.

2. The model makes predictions on the sentence level.

The model does not work perfectly and can make mistakes, please check the output! """

demo = gr.Interface(
    pred_slonspell,
    gr.Textbox(
        label="Input text",
        info="The text that you want to run through the SloNSpell spell-checking model.",
        lines=3,
        # NOTE(review): the example presumably contains intentional misspellings so
        # the demo has something to flag -- do not "correct" it.
        value="Model vbesedilu o znači besede, v katerih se najajajo napake.",
    ),
    gr.HighlightedText(
        label="Spell-checking prediction",
        show_legend=True,
        color_map={"error": "red"},
    ),
    theme=gr.themes.Base(),
    description=_description,
    allow_flagging="never",  # RIP flagging to HuggingFace dataset
)

if __name__ == "__main__":
    # Load eagerly so the weights are fetched before the UI starts serving.
    _load_model()
    demo.launch()