import streamlit as st
import torch
import torch.nn.functional as F

st.set_page_config(page_title='Review improver', layout='centered')
st.markdown('Hello!')


def _cache_resource(func):
    """Wrap *func* in whichever Streamlit resource cache this version offers.

    Newer Streamlit releases expose ``st.cache_resource``; older ones only have
    ``st.cache``, where ``allow_output_mutation=True`` skips hashing the
    returned objects.
    """
    if hasattr(st, 'cache_resource'):
        return st.cache_resource(show_spinner=False)(func)
    return st.cache(allow_output_mutation=True, show_spinner=False)(func)


@_cache_resource
def _load_models():
    """Load the objects stored in ``models.pt`` onto the CPU.

    Streamlit re-executes the whole script on every interaction, so the load
    is cached to avoid reading the file again on each submitted review.
    """
    # SECURITY: torch.load unpickles arbitrary Python objects -- only ship a
    # trusted 'models.pt' with this app.
    # NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which rejects
    # pickled model objects; pass weights_only=False there if loading fails.
    return torch.load('models.pt', map_location='cpu')


# Load the necessary models.
# NOTE(review): presumably two masked-LMs (positive / negative reviews), a
# sentiment classifier and their shared tokenizer -- confirm against the code
# that produced models.pt. If they are nn.Module instances saved in training
# mode, call .eval() on them so dropout is disabled at inference time.
bert_mlm_positive, bert_mlm_negative, bert_classifier, tokenizer = _load_models()


@torch.no_grad()  # inference only: do not build autograd graphs
def get_replacements(sentence: str, num_tokens, k_best, epsilon=1e-3):
    """Generate candidate rewrites of *sentence* with one token swapped.

    Each token of the tokenized sentence is scored by the ratio of the
    probability that ``bert_mlm_positive`` and ``bert_mlm_negative`` assign to
    it. The ``num_tokens`` positions with the smallest ratio are selected, and
    each is replaced in turn by the ``k_best`` tokens that
    ``bert_mlm_positive`` considers most probable at that position.

    Args:
        sentence: Text to rewrite.
        num_tokens: Number of positions to try replacing.
        k_best: Number of replacement tokens to try per position.
        epsilon: Added to both probabilities to keep the ratio stable when
            they are close to zero.

    Returns:
        A list of decoded candidate sentences (special tokens skipped), up to
        ``num_tokens * k_best`` long. Empty when ``num_tokens`` or ``k_best``
        is not positive.
    """
    # Guard explicitly: with k_best == 0 a slice like [-k_best:] would select
    # the *entire* vocabulary instead of nothing.
    if num_tokens <= 0 or k_best <= 0:
        return []

    inputs = tokenizer(sentence, return_tensors='pt')
    tokens = inputs['input_ids'][0]
    positions = torch.arange(len(tokens))

    vocab_logits_positive = bert_mlm_positive(**inputs)['logits'][0]
    vocab_probs_positive = F.softmax(vocab_logits_positive, dim=1)
    probs_positive = vocab_probs_positive[positions, tokens]

    vocab_logits_negative = bert_mlm_negative(**inputs)['logits'][0]
    vocab_probs_negative = F.softmax(vocab_logits_negative, dim=1)
    probs_negative = vocab_probs_negative[positions, tokens]

    # A small ratio means the token looks much more typical for the negative
    # model than for the positive one.
    ratio = (probs_positive + epsilon) / (probs_negative + epsilon)
    # NOTE(review): the positions include whatever special tokens the tokenizer
    # adds, so those can be picked for replacement as well -- confirm that this
    # is intended.
    smallest_ratio_ids = torch.argsort(ratio)[:num_tokens]

    k = min(k_best, vocab_probs_positive.shape[1])
    candidates = []
    for idx in smallest_ratio_ids:
        # topk avoids sorting the whole vocabulary; flip to keep the ascending
        # order that the previous argsort-based slice produced.
        new_tokens = torch.topk(vocab_probs_positive[idx], k).indices.flip(0)
        for token in new_tokens:
            candidate = tokens.clone()
            candidate[idx] = token
            candidates.append(candidate)

    return [
        tokenizer.decode(candidate, skip_special_tokens=True)
        for candidate in candidates
    ]


@torch.no_grad()  # inference only: do not build autograd graphs
def modify_sentence(sentence, num_iters=3):
    """Iteratively rewrite *sentence*, keeping the best-scoring candidate.

    On every iteration the candidates from ``get_replacements`` are scored by
    ``bert_classifier`` and the one with the highest logit in column 1 becomes
    the new sentence.

    Args:
        sentence: Text to improve.
        num_iters: Number of single-token rewrite rounds.

    Returns:
        The sentence after the final round.
    """
    for _ in range(num_iters):
        replacements = get_replacements(sentence, num_tokens=3, k_best=5)
        if not replacements:
            # Nothing to score; keep what we have instead of feeding an empty
            # batch to the tokenizer.
            break
        classifier_inputs = tokenizer(replacements, padding=True, return_tensors='pt')
        # Column 1 is presumably the "positive" class -- TODO confirm.
        logits = bert_classifier(**classifier_inputs)['logits'][:, 1]
        best_idx = int(torch.argmax(logits))
        sentence = replacements[best_idx]
    return sentence


# Here we try to improve the review.
user_input = st.text_input('Enter your review here and we will try to improve it:')
if user_input:
    improved_review = modify_sentence(user_input)
    st.markdown('Here is your improved review:')
    st.markdown(improved_review)