File size: 2,975 Bytes
2a97bef
 
26273c8
372ac16
7ef8f9c
6ae1130
 
 
80ab3f4
372ac16
80ab3f4
372ac16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8f23970
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import streamlit as st
import torch
import numpy as np

# Load the serialized tokenizer and three fine-tuned BERT models:
#  - bert_mlm_positive / bert_mlm_negative: masked-LM heads trained on
#    positive / negative text (used to score and propose token replacements)
#  - bert_class: sequence classifier used to rank candidate sentences
# NOTE(review): torch.load unpickles arbitrary objects — only safe if these
# .pth files are trusted artifacts shipped with the app; confirm provenance.
tokenizer = torch.load('./tokenizer.pth')
bert_mlm_positive = torch.load('./bert_mlm_positive.pth', map_location=torch.device('cpu'))
bert_mlm_negative = torch.load('./bert_mlm_negative.pth', map_location=torch.device('cpu'))
bert_class = torch.load('./bert_classification.pth', map_location=torch.device('cpu'))
# Minimal Streamlit UI: one text input for the phrase to rewrite.
st.set_page_config(page_title="change the text style!", layout="centered")
st.markdown("## change the text style!")
phrase = st.text_input("phrase =", value='good!')
    
# All inference runs on CPU (models were loaded with map_location='cpu').
device = torch.device('cpu')    
def beam_translate(sent: str, num_iter: int = 4, M: int = 7):
    """Beam-search rewrite of `sent` toward a more positive style.

    Each of `num_iter` rounds expands every sentence in the beam via
    `get_replacements` (2 tokens replaced, 5 candidates each), scores all
    candidates with `bert_class`, and keeps the best `M`.

    :param sent: source sentence to rewrite
    :param num_iter: number of expand-and-prune rounds
    :param M: beam width (candidates kept per round)
    :return: the single highest-scoring sentence from the final beam
    """
    sentence_buf = [sent]
    for _ in range(num_iter):
        candidates = []
        for s in sentence_buf:
            candidates += get_replacements(s, num_tokens=2, k_best=5)

        batch = tokenizer(candidates, padding=True, truncation=True, return_tensors="pt").to(device)
        logits = bert_class(**batch).logits
        # BUG FIX: the original used softmax(...)[0], i.e. the class
        # distribution of the FIRST candidate only, and then argsort-ed over
        # class indices — candidates were never actually ranked. Score every
        # candidate by its positive-class probability instead.
        # NOTE(review): assumes label index 1 == "positive" — confirm against
        # the classifier's training config.
        scores = logits.softmax(dim=-1)[:, 1].detach().cpu().numpy()
        best_M = np.argsort(scores)[-M:]
        sentence_buf = [candidates[i] for i in best_M]

    # Final pick: score the surviving beam itself.
    # BUG FIX: the original re-scored the last round's full candidate list
    # but indexed into `sentence_buf` — mismatched index spaces.
    batch = tokenizer(sentence_buf, padding=True, truncation=True, return_tensors="pt").to(device)
    scores = bert_class(**batch).logits.softmax(dim=-1)[:, 1].detach().cpu().numpy()
    return sentence_buf[int(np.argmax(scores))]
    
   
def get_replacements(sentence: str, num_tokens, k_best, epsilon=1e-3):
    """Propose single-token edits that make `sentence` sound more positive.

    - tokenize the sentence with the BERT tokenizer
    - find the :num_tokens: positions with the LOWEST positive/negative
      probability ratio (the most "negative-sounding" tokens)
    - replace each with the :k_best: most likely words under
      bert_mlm_positive
    :param epsilon: smoothing constant to avoid division by ~zero probs
    :return: a list of candidate strings (up to k_best * num_tokens)
    """
    encoded = tokenizer(sentence, return_tensors='pt')
    encoded = {key: value.to(device) for key, value in encoded.items()}
    token_ids = encoded['input_ids'][0]
    length = len(token_ids)

    # Per-position vocabulary distributions under each style-specific MLM.
    p_posit = bert_mlm_positive(**encoded).logits.softmax(dim=-1)[0]
    p_negat = bert_mlm_negative(**encoded).logits.softmax(dim=-1)[0]

    # Probability each model assigns to the token actually present.
    p_tokens_positive = p_posit[torch.arange(length), token_ids]
    p_tokens_negative = p_negat[torch.arange(length), token_ids]

    # Low ratio == token looks far more plausible to the "negative" model:
    # these are the positions worth rewriting.
    prob = ((p_tokens_positive + epsilon) / (p_tokens_negative + epsilon)).detach().cpu().numpy()
    replace_items = np.argsort(prob)[:num_tokens]
    # NOTE(review): argsort runs over ALL positions including [CLS]/[SEP];
    # special tokens may be selected for replacement — confirm intended.
    results = []
    for i in replace_items:
        k_best_pos = np.argsort(p_posit[i].detach().cpu().numpy())[::-1][:k_best]
        for item in k_best_pos:
            # BUG FIX: the original aliased the input_ids tensor, so every
            # assignment mutated the shared tensor and edits accumulated
            # across candidates. Clone so each candidate differs from the
            # source in exactly one position.
            new_ids = token_ids.clone()
            new_ids[i] = item
            # Drop the first/last decoded pieces (presumably [CLS]/[SEP]).
            results.append(' '.join(tokenizer.decode(new_ids).split(' ')[1:-1]))
    return results
    
# Rewrite the user's phrase and render the result in the Streamlit page.
# NOTE(review): this runs the full beam search on every rerun of the script
# (every widget interaction) — consider st.cache if latency matters.
answer = beam_translate(phrase)
st.markdown(answer)