"""Streamlit demo: rewrite a phrase toward the target ("positive") text style.

Models (all pickled whole objects loaded with ``torch.load``):

* ``bert_mlm_positive`` / ``bert_mlm_negative`` -- masked-LM models whose
  per-token probabilities are compared to find the tokens to replace;
* ``bert_class`` -- a sequence classifier used to score candidate rewrites;
* ``tokenizer`` -- the tokenizer shared by the three models.

Candidates come from :func:`get_replacements` and are pruned by the beam search
in :func:`beam_translate`.
"""
import numpy as np
import streamlit as st
import torch

device = torch.device('cpu')

# NOTE: these files are pickled objects, not state dicts -- only load trusted
# files. On PyTorch >= 2.6 ``torch.load`` defaults to ``weights_only=True`` and
# may refuse them; pass ``weights_only=False`` there.
# NOTE(review): Streamlit re-executes this whole script on every interaction,
# so the models are reloaded each time; consider ``st.cache_resource`` if the
# installed Streamlit version provides it.
tokenizer = torch.load('./tokenizer.pth')
bert_mlm_positive = torch.load('./bert_mlm_positive.pth', map_location=device)
bert_mlm_negative = torch.load('./bert_mlm_negative.pth', map_location=device)
bert_class = torch.load('./bert_classification.pth', map_location=device)

# A pickled module keeps the train/eval flag it was saved with; force eval mode
# so dropout cannot make the scores non-deterministic.
for _model in (bert_mlm_positive, bert_mlm_negative, bert_class):
    _model.eval()


def _style_scores(sentences, positive_label=1):
    """Return ``bert_class``'s probability of ``positive_label`` for each sentence.

    :param sentences: list of candidate strings (must be non-empty).
    :param positive_label: output column used as the score.
    :return: 1-D numpy array, one score per sentence, in input order.
    """
    batch = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt").to(device)
    with torch.no_grad():
        logits = bert_class(**batch).logits
    # One row per sentence: pick the target-class column, not row 0.
    return logits.softmax(dim=-1)[:, positive_label].cpu().numpy()


def beam_translate(sent: str, num_iter: int = 4, M: int = 7, positive_label: int = 1) -> str:
    """Rewrite ``sent`` toward the target style with a beam search.

    Each round expands every beam entry with :func:`get_replacements`
    (2 positions x 5 words), scores all distinct candidates with ``bert_class``
    and keeps the ``M`` highest-scoring ones.

    :param sent: input sentence.
    :param num_iter: number of expand-and-prune rounds.
    :param M: beam width.
    :param positive_label: column of ``bert_class``'s output used as the score.
        NOTE(review): assumes label 1 == positive style -- confirm against how
        the classifier was trained.
    :return: the best-scoring sentence of the final beam, or ``sent`` itself
        when no candidate could be generated (e.g. ``num_iter == 0``).
    """
    beam = [sent]
    best_sentence = sent
    for _ in range(num_iter):
        candidates = []
        for s in beam:
            candidates += get_replacements(s, num_tokens=2, k_best=5)
        # Different beam entries can yield the same rewrite; keep one copy
        # (order-preserving) so duplicates do not occupy several beam slots.
        candidates = list(dict.fromkeys(candidates))
        if not candidates:
            break
        scores = _style_scores(candidates, positive_label)
        # argsort is ascending, so the last M indices are the best candidates
        # and the final one is the single best.
        top = np.argsort(scores)[-M:]
        beam = [candidates[i] for i in top]
        best_sentence = beam[-1]
    return best_sentence


def get_replacements(sentence: str, num_tokens, k_best, epsilon=1e-3):
    """Generate single-token rewrites of ``sentence``.

    - tokenize ``sentence`` with ``tokenizer``;
    - for every non-special token, compute the ratio
      ``(P_pos(token) + epsilon) / (P_neg(token) + epsilon)`` under the two MLMs;
    - take the ``num_tokens`` positions with the lowest ratio;
    - at each such position substitute, one at a time, the ``k_best`` tokens
      ``bert_mlm_positive`` ranks highest (special tokens excluded).

    Every returned sentence differs from the input in at most one position.

    :param sentence: input text.
    :param num_tokens: number of positions to try replacing.
    :param k_best: number of substitutes per position.
    :param epsilon: smoothing added to both probabilities in the ratio.
    :return: list of up to ``k_best * num_tokens`` decoded strings.
    """
    encoded = tokenizer(sentence, return_tensors='pt')
    encoded = {key: value.to(device) for key, value in encoded.items()}
    input_ids = encoded['input_ids'][0]
    positions = torch.arange(len(input_ids))

    with torch.no_grad():
        p_posit = bert_mlm_positive(**encoded).logits.softmax(dim=-1)[0]
        p_negat = bert_mlm_negative(**encoded).logits.softmax(dim=-1)[0]

    # Probability each model assigns to the token actually present at each position.
    p_tokens_positive = p_posit[positions, input_ids]
    p_tokens_negative = p_negat[positions, input_ids]
    ratio = ((p_tokens_positive + epsilon) / (p_tokens_negative + epsilon)).cpu().numpy()

    special_ids = sorted(set(tokenizer.all_special_ids))
    special_set = set(special_ids)
    ids_list = input_ids.tolist()
    # Lowest ratio first; never replace special tokens such as [CLS]/[SEP].
    replace_positions = [
        int(i) for i in np.argsort(ratio) if ids_list[int(i)] not in special_set
    ][:num_tokens]

    results = []
    for i in replace_positions:
        scores = p_posit[i].clone()
        if special_ids:
            # Do not propose special tokens as substitutes.
            scores[special_ids] = float('-inf')
        best_ids = torch.topk(scores, k_best).indices.tolist()
        for token_id in best_ids:
            # Clone so each candidate is a single edit of the ORIGINAL ids.
            new_ids = input_ids.clone()
            new_ids[i] = token_id
            results.append(tokenizer.decode(new_ids, skip_special_tokens=True))
    return results


st.set_page_config(page_title="change the text style!", layout="centered")
st.markdown("## change the text style!")
phrase = st.text_input("phrase =", value='good!')

answer = beam_translate(phrase)
st.markdown(answer)