File size: 2,141 Bytes
930e07a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import streamlit as st
import torch
import torch.nn.functional as F


st.set_page_config(page_title='Review improver', layout='centered')

st.markdown('Hello!')


def _passthrough(func):
    """Return *func* unchanged; used when Streamlit offers no resource cache."""
    return func


# Streamlit re-executes this script on every widget interaction, so an
# uncached ``torch.load`` would re-read and unpickle the models each time.
# Pick whichever process-wide resource cache the installed Streamlit exposes
# and fall back to a no-op decorator (the original behaviour) if none exists.
_cache_models = (
    getattr(st, 'cache_resource', None)
    or getattr(st, 'experimental_singleton', None)
    or _passthrough
)


@_cache_models
def _load_models():
    """Load the objects stored in ``models.pt`` onto the CPU.

    Returns:
        The deserialized object exactly as stored in the file; the caller
        unpacks it into four items (two MLM models, a classifier, a
        tokenizer).
    """
    # SECURITY NOTE(review): torch.load unpickles the file, which can execute
    # arbitrary code. Only ship a models.pt that comes from a trusted source.
    loaded = torch.load('models.pt', map_location='cpu')
    for obj in loaded:
        # The app only runs inference; eval mode switches off dropout so the
        # outputs are deterministic even if a model was saved in train mode.
        if isinstance(obj, torch.nn.Module):
            obj.eval()
    return loaded


# load the necessary models
bert_mlm_positive, bert_mlm_negative, bert_classifier, tokenizer = _load_models()


def get_replacements(sentence: str, num_tokens, k_best, epsilon=1e-3):
    """Propose single-token rewrites of *sentence*.

    Every token position is scored by the ratio of the probability that
    ``bert_mlm_positive`` assigns to the current token over the probability
    that ``bert_mlm_negative`` assigns to it. The ``num_tokens`` positions
    with the smallest ratio are selected, and each is replaced in turn by the
    ``k_best`` tokens that ``bert_mlm_positive`` rates most probable there.

    Args:
        sentence: Text to rewrite.
        num_tokens: How many positions (lowest ratio first) to rewrite.
        k_best: How many replacement tokens to try per position. Values
            less than or equal to zero produce no replacements.
        epsilon: Added to numerator and denominator of the ratio so that
            near-zero probabilities do not dominate it.

    Returns:
        A list of decoded candidate sentences, each differing from the
        tokenized input in exactly one position. For each position the
        candidates are ordered from the least to the most probable of the
        ``k_best`` tokens.
    """
    inputs = tokenizer(sentence, return_tensors='pt')
    tokens = inputs['input_ids'][0]
    positions = torch.arange(len(tokens))

    # Inference only: skip building the autograd graph for both forward passes.
    with torch.no_grad():
        vocab_logits_positive = bert_mlm_positive(**inputs)['logits'][0]
        vocab_logits_negative = bert_mlm_negative(**inputs)['logits'][0]
        vocab_probs_positive = F.softmax(vocab_logits_positive, dim=1)
        vocab_probs_negative = F.softmax(vocab_logits_negative, dim=1)

    # Probability each model gives to the token actually present at each position.
    probs_positive = vocab_probs_positive[positions, tokens]
    probs_negative = vocab_probs_negative[positions, tokens]

    # NOTE(review): if the tokenizer adds special tokens (e.g. [CLS]/[SEP]),
    # their positions are also eligible here -- confirm whether that is wanted.
    ratio = (probs_positive + epsilon) / (probs_negative + epsilon)
    smallest_ratio_ids = torch.argsort(ratio)[:num_tokens]

    # Clamp so topk never asks for more entries than the vocabulary holds and
    # so k_best <= 0 yields nothing (a plain ``[-0:]`` slice would yield all).
    top_k = max(0, min(k_best, vocab_probs_positive.shape[1]))

    candidates = []
    for idx in smallest_ratio_ids:
        # topk returns best-first; flip to keep the ascending order callers
        # have always received.
        new_tokens = torch.topk(vocab_probs_positive[idx], top_k).indices.flip(0)
        for token in new_tokens:
            candidate = tokens.clone()
            candidate[idx] = token
            candidates.append(candidate)

    return [tokenizer.decode(candidate, skip_special_tokens=True) for candidate in candidates]


def modify_sentence(sentence, num_iters=3, num_tokens=3, k_best=5):
    """Iteratively rewrite *sentence*, keeping the best-scoring candidate.

    On each iteration the candidates produced by ``get_replacements`` are
    scored with ``bert_classifier`` and the one with the highest logit in
    column 1 becomes the sentence for the next iteration.

    Args:
        sentence: Text to improve.
        num_iters: Number of rewrite rounds; ``0`` returns *sentence* as is.
        num_tokens: Positions rewritten per round (forwarded to
            ``get_replacements``).
        k_best: Replacement tokens tried per position (forwarded to
            ``get_replacements``).

    Returns:
        The sentence selected in the last completed round, or the input
        unchanged when no round produced any candidate.
    """
    # Inference only: no autograd graph is needed for generating or scoring.
    with torch.no_grad():
        for _ in range(num_iters):
            replacements = get_replacements(sentence, num_tokens=num_tokens, k_best=k_best)
            if not replacements:
                # Nothing to score; tokenizing an empty batch would fail.
                break

            classifier_inputs = tokenizer(replacements, padding=True, return_tensors='pt')
            # Column 1 is presumably the "positive" class -- TODO confirm
            # against the classifier's label mapping.
            scores = bert_classifier(**classifier_inputs)['logits'][:, 1]

            # Convert to a plain int before indexing the Python list.
            best_idx = int(torch.argmax(scores))
            sentence = replacements[best_idx]

    return sentence


# here we will try to improve the review
user_input = st.text_input('Enter your review here and we will try to improve it:')
# A whitespace-only entry is truthy but contains no review text, so require
# at least one non-whitespace character before running the models.
if user_input and user_input.strip():
    improved_review = modify_sentence(user_input)
    st.markdown('Here is your improved review:')
    st.markdown(improved_review)