Spaces:
Runtime error
Runtime error
File size: 2,141 Bytes
930e07a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
import streamlit as st
import torch
import torch.nn.functional as F
st.set_page_config(page_title='Review improver', layout='centered')
st.markdown('Hello!')


def _load_models(path='models.pt'):
    """Load the models and tokenizer that were saved together in *path*.

    Returns:
        ``(bert_mlm_positive, bert_mlm_negative, bert_classifier, tokenizer)``
        in the order they were saved, with the three models put in eval mode.
    """
    try:
        # torch >= 2.6 defaults to weights_only=True, which refuses to unpickle
        # whole model/tokenizer objects. Only do this for a trusted checkpoint.
        loaded = torch.load(path, map_location='cpu', weights_only=False)
    except TypeError:
        # Older torch releases do not accept the weights_only keyword.
        loaded = torch.load(path, map_location='cpu')
    mlm_positive, mlm_negative, classifier, tok = loaded
    # Pickled modules keep the train/eval flag they were saved with; force eval
    # so dropout cannot make the predictions random.
    for model in (mlm_positive, mlm_negative, classifier):
        model.eval()
    return mlm_positive, mlm_negative, classifier, tok


# Streamlit re-executes this whole script on every widget interaction; cache
# the loaded models so the checkpoint is read once per process. On Streamlit
# versions without st.cache_resource this falls back to loading every run.
_cache_resource = getattr(st, 'cache_resource', lambda func: func)
# load the necessary models
bert_mlm_positive, bert_mlm_negative, bert_classifier, tokenizer = _cache_resource(_load_models)()
def get_replacements(sentence: str, num_tokens, k_best, epsilon=1e-3):
    """Return candidate rewrites of *sentence*, each differing in one token.

    Every token of the sentence is scored by the positive-style and the
    negative-style masked-LM models (the input is fed to them unmasked).
    The ``num_tokens`` non-special positions with the lowest
    positive/negative probability ratio are chosen as edit targets. Each
    target is then replaced, one at a time, by the ``k_best`` tokens that
    the positive model rates highest at that position.

    Args:
        sentence: Text to rewrite.
        num_tokens: Number of token positions to try replacing.
        k_best: Number of replacement candidates per position.
        epsilon: Smoothing added to both probabilities so the ratio stays
            finite when the negative probability is tiny.

    Returns:
        Decoded candidate sentences (special tokens stripped), at most
        ``num_tokens * k_best`` of them. The list is empty when
        ``num_tokens`` or ``k_best`` is not positive, or when the sentence
        has no non-special tokens (e.g. an empty string).
    """
    if num_tokens <= 0 or k_best <= 0:
        # Nothing to propose; also avoids ``argsort(...)[-0:]``-style slicing,
        # which would select the whole vocabulary.
        return []

    inputs = tokenizer(sentence, return_tensors='pt')
    tokens = inputs['input_ids'][0]
    positions = torch.arange(len(tokens))

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        vocab_probs_positive = F.softmax(bert_mlm_positive(**inputs)['logits'][0], dim=-1)
        vocab_probs_negative = F.softmax(bert_mlm_negative(**inputs)['logits'][0], dim=-1)

    # Probability each model assigns to the token actually present at each position.
    probs_positive = vocab_probs_positive[positions, tokens]
    probs_negative = vocab_probs_negative[positions, tokens]
    ratio = (probs_positive + epsilon) / (probs_negative + epsilon)

    # Never target [CLS]/[SEP]-style tokens: a normal word written into one of
    # those slots survives skip_special_tokens decoding and adds a word.
    special_ids = set(tokenizer.all_special_ids)
    target_positions = [
        idx for idx in torch.argsort(ratio).tolist()
        if tokens[idx].item() not in special_ids
    ][:num_tokens]

    k = min(k_best, vocab_probs_positive.shape[-1])
    replacements = []
    for idx in target_positions:
        for token in torch.topk(vocab_probs_positive[idx], k).indices:
            cur_replacement = tokens.clone()
            cur_replacement[idx] = token
            replacements.append(cur_replacement)
    return [tokenizer.decode(replacement, skip_special_tokens=True) for replacement in replacements]
def modify_sentence(sentence, num_iters=3, num_tokens=3, k_best=5):
    """Greedily rewrite *sentence* so the classifier rates it higher on class 1.

    Each iteration scores the current sentence together with the one-token
    rewrites produced by ``get_replacements`` and keeps the best-scoring one.
    Class index 1 is presumably the "positive" class; confirm against how the
    classifier was trained. Iteration stops early when no rewrite can be
    generated or none scores above the current sentence, so every accepted
    step improves the classifier score.

    Args:
        sentence: Review text to improve.
        num_iters: Maximum number of greedy rewrite steps.
        num_tokens: Positions per step passed to ``get_replacements``.
        k_best: Candidates per position passed to ``get_replacements``.

    Returns:
        The rewritten sentence, or *sentence* itself if no step was taken.
    """
    for _ in range(num_iters):
        replacements = get_replacements(sentence, num_tokens=num_tokens, k_best=k_best)
        if not replacements:
            break
        # Keep the current sentence in the pool so a step can never make it worse.
        candidates = [sentence] + replacements
        classifier_inputs = tokenizer(candidates, padding=True, return_tensors='pt')
        with torch.no_grad():
            logits = bert_classifier(**classifier_inputs)['logits']
        # Rank by log P(class 1). The raw class-1 logit alone ignores the
        # other classes' logits and so does not order candidates by probability.
        scores = F.log_softmax(logits, dim=-1)[:, 1]
        best_idx = int(torch.argmax(scores))
        if best_idx == 0:
            # No rewrite beats the current sentence: converged.
            break
        sentence = candidates[best_idx]
    return sentence
# here we will try to improve the review
user_input = st.text_input('Enter your review here and we will try to improve it:')
# Whitespace-only input has nothing to rewrite; treat it like empty input
# instead of running three models on it.
if (user_input or '').strip():
    # Several BERT forward passes per step: show progress while they run.
    with st.spinner('Improving your review...'):
        improved_review = modify_sentence(user_input)
    st.markdown('Here is your improved review:')
    st.markdown(improved_review)
|