# Hugging Face Spaces status: Runtime error
import inspect

import streamlit as st
import torch
import torch.nn.functional as F
st.set_page_config(page_title='Review improver', layout='centered')
st.markdown('Hello!')


def _load_models(path='models.pt'):
    """Load the pickled ``(mlm_positive, mlm_negative, classifier, tokenizer)`` tuple.

    ``models.pt`` holds whole pickled Python objects rather than a plain state
    dict. It must therefore be unpickled with ``weights_only=False``; torch >= 2.6
    defaults to ``weights_only=True`` and refuses to load such a file.

    SECURITY: unpickling can execute arbitrary code, so *path* must point to a
    trusted, locally shipped file. Never point it at user-supplied data.

    The three networks are switched to eval mode. If the checkpoint was saved in
    training mode, active dropout would otherwise make predictions random.
    """
    load_kwargs = {'map_location': 'cpu'}
    # ``weights_only`` only exists on torch >= 1.13, so pass it only where supported.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    mlm_positive, mlm_negative, classifier, tok = torch.load(path, **load_kwargs)
    for model in (mlm_positive, mlm_negative, classifier):
        model.eval()
    return mlm_positive, mlm_negative, classifier, tok


# Streamlit re-executes this whole script on every widget interaction. Cache the
# loaded models so they are read from disk once per server process instead of on
# every rerun. ``st.cache_resource`` only exists on newer Streamlit releases; on
# older ones we fall back to loading on each run, as before.
_cache_resource = getattr(st, 'cache_resource', None)
if _cache_resource is not None:
    _load_models = _cache_resource(_load_models)

# load the necessary models
bert_mlm_positive, bert_mlm_negative, bert_classifier, tokenizer = _load_models()
def get_replacements(sentence: str, num_tokens, k_best, epsilon=1e-3):
    """Propose single-token edits of *sentence* aimed at making it more positive.

    Both masked-LM models score the (unmasked) input. For every position we take
    the probability each model assigns to the token actually present there.
    Positions with the smallest ratio ``(p_pos + eps) / (p_neg + eps)`` are the
    ones the positive model likes least relative to the negative model. Up to
    ``num_tokens`` of those positions are chosen, and each is swapped in turn
    for the ``k_best`` tokens the positive model ranks highest at that position.

    Special tokens ([CLS], [SEP], [UNK], ...) are never edit targets and never
    replacements. Candidates are decoded with ``skip_special_tokens=True``, so
    either case would silently insert or delete a word instead of substituting one.

    NOTE(review): the input is not truncated here, because truncating would drop
    the tail of the review from every candidate. Over-long inputs presumably fail
    inside the model (max sequence length); TODO confirm.

    Args:
        sentence: text to edit.
        num_tokens: maximum number of positions to edit (one edit per candidate).
        k_best: number of replacement tokens tried at each chosen position.
        epsilon: smoothing added to both probabilities before taking the ratio.

    Returns:
        List of decoded candidate sentences, each differing from *sentence* in
        one token. If the sentence contains no replaceable token at all, the list
        holds only the decoded original, so callers always get a candidate.
    """
    inputs = tokenizer(sentence, return_tensors='pt')
    tokens = inputs['input_ids'][0]
    positions = torch.arange(len(tokens))

    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        vocab_probs_positive = F.softmax(bert_mlm_positive(**inputs)['logits'][0], dim=1)
        vocab_probs_negative = F.softmax(bert_mlm_negative(**inputs)['logits'][0], dim=1)
    probs_positive = vocab_probs_positive[positions, tokens]
    probs_negative = vocab_probs_negative[positions, tokens]
    ratio = (probs_positive + epsilon) / (probs_negative + epsilon)

    # Boolean mask over the vocabulary marking the tokenizer's special tokens.
    vocab_size = vocab_probs_positive.shape[1]
    special_ids = [i for i in set(tokenizer.all_special_ids) if i < vocab_size]
    special_vocab_mask = torch.zeros(vocab_size, dtype=torch.bool)
    special_vocab_mask[torch.tensor(special_ids, dtype=torch.long)] = True

    # Push special positions to the end of the ranking and never select them.
    special_positions = special_vocab_mask[tokens]
    num_replaceable = int((~special_positions).sum())
    if num_replaceable == 0:
        return [tokenizer.decode(tokens, skip_special_tokens=True)]
    ratio = ratio.masked_fill(special_positions, float('inf'))
    smallest_ratio_ids = torch.argsort(ratio)[:min(num_tokens, num_replaceable)]

    # topk with a clamped k. The original ``argsort(...)[-k_best:]`` returned the
    # entire vocabulary for k_best == 0, because ``[-0:]`` is ``[0:]``.
    k = max(0, min(k_best, vocab_size - len(special_ids)))
    replacements = []
    for idx in smallest_ratio_ids.tolist():
        # Probabilities are >= 0, so -1 guarantees special tokens rank last.
        scores = vocab_probs_positive[idx].masked_fill(special_vocab_mask, -1.0)
        for token in torch.topk(scores, k).indices.tolist():
            cur_replacement = tokens.clone()
            cur_replacement[idx] = token
            replacements.append(cur_replacement)
    return [tokenizer.decode(replacement, skip_special_tokens=True) for replacement in replacements]
def modify_sentence(sentence, num_iters=3, num_tokens=3, k_best=5):
    """Greedily rewrite *sentence* towards a higher classifier score.

    Each round asks ``get_replacements`` for single-token edits and keeps the
    candidate with the highest ``bert_classifier`` logit for class 1. Class 1 is
    presumably the "positive" label; TODO confirm against the training code.
    The current sentence itself is not among the candidates, so every round that
    yields candidates commits an edit.

    Args:
        sentence: review text to improve.
        num_iters: number of greedy edit rounds.
        num_tokens: positions considered per round (forwarded to get_replacements).
        k_best: replacement tokens tried per position (forwarded to get_replacements).

    Returns:
        The edited sentence. If a round produces no candidates, the sentence is
        returned as it stood at that point.
    """
    for _ in range(num_iters):
        replacements = get_replacements(sentence, num_tokens=num_tokens, k_best=k_best)
        if not replacements:
            # Tokenizing an empty batch would raise; there is nothing to choose from.
            break
        # truncation=True keeps an over-long candidate batch within the model limit.
        classifier_inputs = tokenizer(replacements, padding=True, truncation=True, return_tensors='pt')
        with torch.no_grad():
            logits = bert_classifier(**classifier_inputs)['logits'][:, 1]
        sentence = replacements[int(torch.argmax(logits))]
    return sentence
# here we will try to improve the review
user_input = st.text_input('Enter your review here and we will try to improve it:')
# Skip whitespace-only input. It tokenizes to nothing but special tokens, so
# there is nothing to edit, and running the models on it is wasted work.
if user_input and user_input.strip():
    # Three greedy rounds of three BERT forward passes each can take a while on
    # CPU, so show progress instead of a frozen page.
    with st.spinner('Improving your review...'):
        improved_review = modify_sentence(user_input.strip())
    st.markdown('Here is your improved review:')
    st.markdown(improved_review)