# File size: 2,975 Bytes
# Commits: 2a97bef 26273c8 372ac16 7ef8f9c 6ae1130 80ab3f4 372ac16 80ab3f4 372ac16 8f23970
# (Line-number gutter 1-66 from the web file viewer omitted.)
import streamlit as st
import torch
import numpy as np
# All inference in this script runs on the CPU; the same device object is used
# both to remap checkpoint tensors at load time and to place tokenizer output.
device = torch.device('cpu')

# NOTE(security): torch.load unpickles arbitrary Python objects. Only load
# checkpoint files that come from a trusted source.
tokenizer = torch.load('./tokenizer.pth')
bert_mlm_positive = torch.load('./bert_mlm_positive.pth', map_location=device)
bert_mlm_negative = torch.load('./bert_mlm_negative.pth', map_location=device)
bert_class = torch.load('./bert_classification.pth', map_location=device)

# The models are only used for inference. A pickled module keeps whatever
# train/eval flag it had when saved, so switch to eval mode explicitly to make
# sure dropout is disabled and predictions are deterministic.
for _model in (bert_mlm_positive, bert_mlm_negative, bert_class):
    _model.eval()

# set_page_config has to stay the first Streamlit call in the script.
st.set_page_config(page_title="change the text style!", layout="centered")
st.markdown("## change the text style!")
phrase = st.text_input("phrase =", value='good!')
def _style_scores(sentences, target_class):
    """Return one ``target_class`` probability per sentence from ``bert_class``."""
    batch = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt").to(device)
    with torch.no_grad():
        logits = bert_class(**batch).logits
    # Assumes logits is (batch, num_classes), the usual shape for a sequence
    # classifier. Select the class column for EVERY row; indexing row 0 would
    # return the class distribution of the first sentence only.
    return logits.softmax(dim=-1)[:, target_class].cpu().numpy()


def beam_translate(sent: str, num_iter: int = 4, M: int = 7, target_class: int = 1):
    """Rewrite ``sent`` with a beam search over single-token replacements.

    Every round expands each sentence in the beam with ``get_replacements``
    (2 positions, 5 substitutes per position), scores all candidates with
    ``bert_class`` and keeps the ``M`` candidates with the highest
    probability of ``target_class``.

    Args:
        sent: sentence to rewrite.
        num_iter: number of expand-and-prune rounds.
        M: beam width, i.e. how many candidates survive each round.
        target_class: index of the classifier output to maximise.
            NOTE(review): 1 is assumed to be the target-style label; confirm
            against how the classification checkpoint was trained.

    Returns:
        The best-scoring sentence of the last round that produced candidates,
        or ``sent`` unchanged when no round produced any.

    Raises:
        ValueError: if ``M`` is smaller than 1.
    """
    if M < 1:
        # A slice like [-0:] would silently keep every candidate.
        raise ValueError("M must be at least 1")

    beam = [sent]
    for _ in range(num_iter):
        candidates = []
        for current in beam:
            candidates.extend(get_replacements(current, num_tokens=2, k_best=5))
        # Different beam entries can yield the same rewrite; drop duplicates
        # (keeping first-seen order) so they do not crowd out the beam.
        candidates = list(dict.fromkeys(candidates))
        if not candidates:
            break

        scores = _style_scores(candidates, target_class)
        # argsort is ascending, so the slice holds the M best candidates with
        # the overall best one in the last slot.
        keep = np.argsort(scores)[-M:]
        beam = [candidates[i] for i in keep]

    # The beam is ordered worst-to-best after every prune, so no extra
    # classifier pass is needed to pick the winner.
    return beam[-1]
def get_replacements(sentence: str, num_tokens, k_best, epsilon=1e-3):
    """
    Propose single-token rewrites of ``sentence``.

    - tokenize the sentence with the module-level ``tokenizer``
    - score each token by (p_positive + epsilon) / (p_negative + epsilon),
      where each p is the probability the corresponding masked-LM assigns to
      the token that is already at that position
    - pick the :num_tokens: positions with the lowest ratio, never the first
      or last position
    - for each picked position, substitute the :k_best: most probable tokens
      according to bert_mlm_positive
    :return: a list of decoded strings (up to k_best * num_tokens), each
        differing from the input in exactly one token; empty when the
        sentence has no replaceable tokens
    """
    encoded = tokenizer(sentence, truncation=True, return_tensors='pt')
    encoded = {key: value.to(device) for key, value in encoded.items()}
    input_ids = encoded['input_ids'][0]
    length = len(input_ids)

    # NOTE(review): assumes the tokenizer wraps the text in one leading and one
    # trailing special token (BERT-style [CLS] ... [SEP]); confirm for the
    # tokenizer actually stored in tokenizer.pth.
    if length <= 2:
        return []

    with torch.no_grad():
        p_posit = bert_mlm_positive(**encoded).logits.softmax(dim=-1)[0]
        p_negat = bert_mlm_negative(**encoded).logits.softmax(dim=-1)[0]

    positions = torch.arange(length, device=input_ids.device)
    p_tokens_positive = p_posit[positions, input_ids]
    p_tokens_negative = p_negat[positions, input_ids]
    ratio = ((p_tokens_positive + epsilon) / (p_tokens_negative + epsilon)).cpu().numpy()

    # Rank only the inner positions; replacing a wrapping special token would
    # just be stripped again on decode and waste a candidate slot.
    replace_positions = np.argsort(ratio[1:-1])[:num_tokens] + 1

    results = []
    for position in replace_positions:
        position = int(position)
        best_ids = np.argsort(p_posit[position].cpu().numpy())[::-1][:k_best]
        for token_id in best_ids:
            # Work on a copy: indexing a tensor returns a view, so writing into
            # it would leak this substitution into every later candidate.
            candidate = input_ids.clone()
            candidate[position] = int(token_id)
            results.append(tokenizer.decode(candidate[1:-1]))
    return results
# Run the style transfer on the submitted phrase and render the result.
# Blank input has nothing to rewrite, so skip the model calls entirely.
if phrase.strip():
    answer = beam_translate(phrase)
    st.markdown(answer)