|
import streamlit as st |
|
import torch |
|
import numpy as np |
|
|
|
# All inference runs on CPU; GPU-saved checkpoints are remapped onto it at load time.
device = torch.device('cpu')

# Streamlit expects set_page_config to be the first Streamlit command of the script.
st.set_page_config(page_title="change the text style!", layout="centered")

# NOTE(review): torch.load unpickles arbitrary objects — only ever point it at trusted local files.
# NOTE(review): Streamlit reruns this whole script on every widget interaction, so these checkpoints
# are reloaded each time; consider caching them (e.g. st.cache_resource) if the installed
# Streamlit version provides it.
tokenizer = torch.load('./tokenizer.pth')

bert_mlm_positive = torch.load('./bert_mlm_positive.pth', map_location=device)

bert_mlm_negative = torch.load('./bert_mlm_negative.pth', map_location=device)

bert_class = torch.load('./bert_classification.pth', map_location=device)

# A pickled nn.Module keeps the train/eval mode it was saved in; force eval mode so dropout is
# disabled and the models give deterministic predictions.
bert_mlm_positive.eval()
bert_mlm_negative.eval()
bert_class.eval()

st.markdown("## change the text style!")

phrase = st.text_input("phrase =", value='good!')
|
def _target_class_probs(sentences, target_class):
    """
    Score a list of sentences with bert_class.

    :param sentences: list of strings to classify (must be non-empty).
    :param target_class: index of the classifier output whose probability is returned.
    :return: 1-D numpy array holding, for each sentence in input order, the softmax
        probability of `target_class`.
    """
    batch = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt").to(device)
    with torch.no_grad():
        logits = bert_class(**batch).logits
    # assumes logits are (batch, num_classes), as for HF sequence-classification heads;
    # take the target column so there is one score per sentence.
    return logits.softmax(dim=-1)[:, target_class].cpu().numpy()


def beam_translate(sent: str, num_iter: int = 4, M: int = 7, target_class: int = 1,
                   num_tokens: int = 2, k_best: int = 5):
    """
    Rewrite `sent` towards the target style with a beam search over single-token replacements.

    Each round expands every sentence in the beam with get_replacements, scores all distinct
    candidates with bert_class, and keeps the M candidates with the highest probability of
    `target_class`. The best sentence of the final beam is returned.

    :param sent: input sentence.
    :param num_iter: number of expansion rounds; 0 returns `sent` unchanged.
    :param M: beam width (number of sentences kept per round).
    :param target_class: classifier output index to maximize.
        NOTE(review): defaults to 1, assumed to be bert_class's positive label — confirm
        against the classifier's training labels.
    :param num_tokens: positions to try replacing per sentence (forwarded to get_replacements).
    :param k_best: candidate tokens per position (forwarded to get_replacements).
    :return: the highest-scoring sentence found.
    """
    beam = [sent]

    for _ in range(num_iter):
        # Different edits often decode to the same string; dict.fromkeys removes duplicates while
        # preserving order so the beam keeps up to M distinct sentences.
        candidates = list(dict.fromkeys(
            candidate
            for s in beam
            for candidate in get_replacements(s, num_tokens=num_tokens, k_best=k_best)
        ))
        if not candidates:
            # Nothing replaceable (e.g. empty input): keep the current beam.
            break
        probs = _target_class_probs(candidates, target_class)
        beam = [candidates[i] for i in np.argsort(probs)[-M:]]

    # Score the final beam itself so the argmax index refers to `beam`.
    probs = _target_class_probs(beam, target_class)
    return beam[int(np.argmax(probs))]
|
|
|
|
|
def get_replacements(sentence: str, num_tokens, k_best, epsilon=1e-3):
    """
    Propose single-token edits of `sentence` that push it towards the positive style.

    - tokenize the sentence with the BERT tokenizer;
    - for every non-special token, compute (p_pos + epsilon) / (p_neg + epsilon), where p_pos and
      p_neg are the probabilities bert_mlm_positive and bert_mlm_negative assign to that same token
      at its position; the :num_tokens: tokens with the LOWEST ratio (the ones the negative model
      favours most relative to the positive one) are selected for replacement;
    - replace each selected token, one position at a time, with each of the :k_best: most likely
      non-special tokens under bert_mlm_positive at that position.

    :param sentence: text to rewrite.
    :param num_tokens: number of positions to try (capped at the number of non-special tokens).
    :param k_best: number of candidate tokens to try at each position.
    :param epsilon: smoothing added to both probabilities so the ratio never divides by zero.
    :return: a list of up to ``num_tokens * k_best`` decoded strings (special tokens stripped),
        each differing from the tokenized `sentence` in at most one token position.
    """
    encoded = tokenizer(sentence, truncation=True, return_tensors='pt')
    encoded = {key: value.to(device) for key, value in encoded.items()}
    input_ids = encoded['input_ids'][0]
    positions = torch.arange(input_ids.shape[0], device=input_ids.device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        p_posit = bert_mlm_positive(**encoded).logits.softmax(dim=-1)[0]
        p_negat = bert_mlm_negative(**encoded).logits.softmax(dim=-1)[0]

    # Probability each MLM assigns to the token actually present at each position.
    p_tokens_positive = p_posit[positions, input_ids]
    p_tokens_negative = p_negat[positions, input_ids]
    ratio = (p_tokens_positive + epsilon) / (p_tokens_negative + epsilon)

    # Never select special-token positions ([CLS], [SEP], ...) for replacement.
    special_mask = torch.tensor(
        tokenizer.get_special_tokens_mask(input_ids.tolist(), already_has_special_tokens=True),
        dtype=torch.bool,
        device=ratio.device,
    )
    ratio = ratio.masked_fill(special_mask, float('inf'))
    n_replace = min(num_tokens, int((~special_mask).sum()))
    replace_positions = torch.argsort(ratio)[:n_replace].tolist()

    # Special tokens are never proposed as replacement words.
    special_ids = torch.tensor(tokenizer.all_special_ids, dtype=torch.long, device=p_posit.device)

    results = []
    for pos in replace_positions:
        scores = p_posit[pos].clone()
        scores[special_ids] = float('-inf')
        top_ids = torch.topk(scores, min(k_best, scores.shape[0])).indices.tolist()
        for token_id in top_ids:
            # Clone so every candidate is one edit of the ORIGINAL ids; editing input_ids in
            # place would make the replacements accumulate across candidates.
            new_ids = input_ids.clone()
            new_ids[pos] = token_id
            results.append(tokenizer.decode(new_ids, skip_special_tokens=True))
    return results
|
|
|
# Only run the (slow) beam search when there is actual text to rewrite.
if phrase.strip():
    # Show progress while the models run; each run does several full forward passes.
    with st.spinner("rewriting..."):
        answer = beam_translate(phrase)

    st.markdown(answer)