File size: 3,083 Bytes
eb759e8
 
 
 
 
 
 
 
944a37e
2f7b8de
944a37e
 
eb759e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229bda0
eb759e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229bda0
 
eb759e8
 
 
 
847b058
eb759e8
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import re

import gradio as gr
import nltk
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM


# Fetch the NLTK tokenizer data that nltk.sent_tokenize / nltk.word_tokenize need below.
# NOTE(review): presumably "punkt_tab" is what newer NLTK releases load and "punkt"
# covers older ones — confirm before dropping either.
for _nltk_resource in ("punkt", "punkt_tab"):
    nltk.download(_nltk_resource)
del _nltk_resource


def pred_slonspell(input_text: str):
    """Run SloNSpell on ``input_text`` and label every word as an error or not.

    The text is split into sentences and words with NLTK's Slovene models. Each
    sentence is sent to the model with a mask token after every word, and the
    prediction at each mask position is used as the label of the word before it.

    Relies on the module-level ``tokenizer``, ``model`` and ``DEVICE`` globals,
    which this file creates only in its ``__main__`` block.

    Args:
        input_text: Raw user text; may contain several lines and sentences.

    Returns:
        A list of ``(word, label)`` tuples in reading order, where ``label`` is
        ``"error"`` for words predicted as misspelled and ``None`` otherwise
        (the format consumed by the ``gr.HighlightedText`` output).
    """
    return_values = []

    # Runs of newlines, and runs of two or more spaces, each become one space.
    input_text = re.sub(r"(\n)+|( ){2,}", " ", input_text)
    for _sent in nltk.sent_tokenize(input_text, language="slovene"):
        input_words = nltk.word_tokenize(_sent, language="slovene")
        error_flags = _predict_error_flags(input_words)
        return_values.extend(
            (_word, "error" if _flag else None) for _word, _flag in zip(input_words, error_flags)
        )

    return return_values


def _predict_error_flags(words, max_length=512):
    """Return one bool per entry of ``words``: True if the model flags that word as an error.

    A single model pass is capped at ``max_length`` tokens, so a long sentence
    would lose the masks of its trailing words to truncation. Instead, the words
    are labelled in consecutive windows: each pass labels as many leading words
    as fit, and the next pass starts at the first word still without a label.
    """
    flags = []
    start = 0
    while start < len(words):
        window_flags = _predict_window(words[start:], max_length)
        if not window_flags:
            # Not even one word + mask fit in the window; leave that word
            # unflagged and move on rather than looping forever.
            window_flags = [False]
        flags.extend(window_flags)
        start += len(window_flags)
    return flags[: len(words)]


def _predict_window(words, max_length=512):
    """Run one model pass over ``words``; return an error flag for each word whose mask survived truncation.

    Every word is followed by the tokenizer's mask token. If the tokenized text
    exceeds ``max_length`` it is truncated, so the result may cover only a prefix
    of ``words``.
    """
    mask = tokenizer.mask_token
    formatted_text = f" {mask} ".join(words) + f" {mask}"

    encoded_input = tokenizer(formatted_text, return_tensors="pt", max_length=max_length, truncation=True)
    mask_positions = encoded_input["input_ids"][0] == tokenizer.mask_token_id  # bool, one entry per token

    with torch.no_grad():
        logits = model(**{k: v.to(DEVICE) for k, v in encoded_input.items()}).logits

    # Score only vocabulary ids 0-3 (sliced, not list-indexed, to keep the [tokens, 4]
    # layout). Id 0 is treated as "correct" and ids 1-3 as error categories.
    # NOTE(review): label meaning taken from the original aggregation — confirm on the model card.
    probas = torch.softmax(logits[0, :, :4].cpu(), dim=-1)
    relevant_probas = probas[mask_positions]  # [num_surviving_masks, 4]

    is_ok_proba = relevant_probas[:, 0]
    is_err_proba = relevant_probas[:, 1:].sum(dim=1)
    # Strict ">" matches argmax over [ok, err]: a tie resolves to "ok".
    return (is_err_proba > is_ok_proba).tolist()

# HTML text shown in the demo UI (passed to gr.Interface as `description`).
# The backslashes at line ends are continuations inside the literal: the first one
# suppresses the leading newline; the one after the href joins the link text to it.
_description = """\
<h1> SloNSpell demo</h1>
<p>This is a simple demo setup for SloNSpell, a 🇸🇮 Slovene spelling error detection model.
You can find more about the model in the model card <a href='https://huggingface.co/cjvt/SloBERTa-slo-word-spelling-annotator'>\
cjvt/SloBERTa-slo-word-spelling-annotator</a>.</p>
<p>Given an input text: </p>
<p>1. The input is segmented into sentences and tokenized using NLTK to prepare the model input.</p>
<p>2. The model makes predictions on the sentence level. </p>
<b>The model does not work perfectly and can make mistakes, please check the output!</b>
"""

# Free-text input box, pre-filled with an example sentence.
_input_box = gr.Textbox(
    label="Input text",
    info="The text that you want to run through the SloNSpell spell-checking model.",
    lines=3,
    value="Model vbesedilu o znači besede, v katerih se najajajo napake.",
)
# Per-word highlighting of pred_slonspell's (word, label) output; "error" words in red.
_output_box = gr.HighlightedText(
    label="Spell-checking prediction",
    show_legend=True,
    color_map={"error": "red"},
)

demo = gr.Interface(
    fn=pred_slonspell,
    inputs=_input_box,
    outputs=_output_box,
    theme=gr.themes.Base(),
    description=_description,
    allow_flagging="never",  # flagging to a HuggingFace dataset is disabled
)

if __name__ == "__main__":
    # Everything assigned here is a module global; pred_slonspell reads
    # `tokenizer`, `model` and `DEVICE` from these at request time.
    model_name = "cjvt/SloBERTa-slo-word-spelling-annotator"
    DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
    # pred_slonspell moves its inputs to DEVICE, so the weights must live there too;
    # otherwise a CUDA machine fails with a CPU/CUDA device-mismatch error.
    model = AutoModelForMaskedLM.from_pretrained(model_name).to(DEVICE)
    model.eval()  # inference only: make sure dropout is disabled
    mask_token = tokenizer.mask_token

    demo.launch()