# system's picture
# system HF Staff
# initial commit
# 930e07a
import streamlit as st
import torch
import torch.nn.functional as F
# Page setup comes first, ahead of any other Streamlit output.
st.set_page_config(layout="centered", page_title="Review improver")
st.markdown("Hello!")

# Restore the two masked-LM models, the classifier and the tokenizer that
# were saved together in one file; everything is mapped onto the CPU.
# NOTE(review): torch.load unpickles its input, so models.pt has to come
# from a trusted source.
(
    bert_mlm_positive,
    bert_mlm_negative,
    bert_classifier,
    tokenizer,
) = torch.load("models.pt", map_location="cpu")
def get_replacements(sentence: str, num_tokens: int, k_best: int, epsilon: float = 1e-3) -> list:
    """Propose single-token rewrites of ``sentence``.

    Both masked-language models score every token of the tokenized
    sentence. The positions whose current token has the lowest
    ``(p_positive + epsilon) / (p_negative + epsilon)`` ratio are selected,
    and each of them is substituted, one at a time, with the tokens that
    ``bert_mlm_positive`` rates most probable at that position.

    Args:
        sentence: Text to rewrite.
        num_tokens: Maximum number of token positions to substitute.
        k_best: Number of substitute tokens tried per selected position.
        epsilon: Smoothing term added to both probabilities before the
            division, which keeps the ratio finite for tiny denominators.

    Returns:
        The decoded candidate sentences, at most ``num_tokens * k_best`` of
        them. The list is empty when ``num_tokens`` or ``k_best`` is not
        positive.
    """
    # Slicing with a zero or negative count does not mean "nothing":
    # ``[-0:]`` selects the whole vocabulary and ``[:-1]`` drops only the
    # last position. Reject such counts up front.
    if num_tokens <= 0 or k_best <= 0:
        return []

    inputs = tokenizer(sentence, return_tensors='pt')
    tokens = inputs['input_ids'][0]
    positions = torch.arange(len(tokens))

    # Inference only: no need to record an autograd graph for the models.
    with torch.no_grad():
        # assumes logits[0] is (sequence, vocab) so dim=1 normalizes over
        # the vocabulary — TODO confirm against the model type
        vocab_probs_positive = F.softmax(bert_mlm_positive(**inputs)['logits'][0], dim=1)
        vocab_probs_negative = F.softmax(bert_mlm_negative(**inputs)['logits'][0], dim=1)

    # Probability each model assigns to the token actually present at
    # every position.
    probs_positive = vocab_probs_positive[positions, tokens]
    probs_negative = vocab_probs_negative[positions, tokens]

    ratio = (probs_positive + epsilon) / (probs_negative + epsilon)
    smallest_ratio_ids = torch.argsort(ratio)[:num_tokens]
    # NOTE(review): every position of input_ids is a candidate, including
    # any special tokens the tokenizer adds — confirm this is intended.

    replacements = []
    for idx in smallest_ratio_ids:
        # Ascending argsort, so the tail holds the k_best most probable ids.
        new_tokens = torch.argsort(vocab_probs_positive[idx])[-k_best:]
        for token in new_tokens:
            cur_replacement = tokens.clone()
            cur_replacement[idx] = token
            replacements.append(cur_replacement)

    return [
        tokenizer.decode(replacement, skip_special_tokens=True)
        for replacement in replacements
    ]
def modify_sentence(sentence, num_iters=3, num_tokens=3, k_best=5):
    """Iteratively rewrite ``sentence`` using the classifier as a judge.

    Each round asks ``get_replacements`` for candidate rewrites, scores
    them with ``bert_classifier`` and keeps the candidate whose logit in
    column 1 is highest (presumably the positive class — confirm against
    the classifier's label order).

    Args:
        sentence: Text to improve.
        num_iters: Number of rewrite rounds; ``0`` returns the input as is.
        num_tokens: Forwarded to ``get_replacements``: how many positions
            may be substituted per round.
        k_best: Forwarded to ``get_replacements``: how many substitute
            tokens are tried per position.

    Returns:
        The sentence after the last completed round.
    """
    for _ in range(num_iters):
        replacements = get_replacements(sentence, num_tokens=num_tokens, k_best=k_best)
        if not replacements:
            # No candidates to score; keep what we have.
            break

        classifier_inputs = tokenizer(replacements, padding=True, return_tensors='pt')
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            logits = bert_classifier(**classifier_inputs)['logits'][:, 1]

        # Plain int so the list is indexed without relying on tensor
        # __index__ conversion.
        best_idx = int(torch.argmax(logits))
        sentence = replacements[best_idx]
    return sentence
# Ask for a review and, once one has been entered, show the rewritten text.
user_input = st.text_input('Enter your review here and we will try to improve it:')
# Skip blank and whitespace-only submissions: there is nothing to rewrite,
# and a whitespace-only string is truthy, so a bare truthiness test would
# still send it through the models.
if user_input and user_input.strip():
    improved_review = modify_sentence(user_input)
    st.markdown('Here is your improved review:')
    st.markdown(improved_review)