|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
This file contains the definition of encoders used in https://arxiv.org/pdf/1705.02364.pdf |
|
""" |
|
|
|
import numpy as np |
|
import time |
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
|
|
class InferSent(nn.Module):
    """BiLSTM sentence encoder with mean or max pooling over time.

    Word vectors are read from a text file (one ``word v1 v2 ...`` entry per
    line) whose path is set with ``set_w2v_path``. Each sentence is encoded
    into a vector of size ``2 * enc_lstm_dim`` (forward and backward LSTM
    states concatenated).
    """

    def __init__(self, config):
        """Build the encoder from a configuration mapping.

        Args:
            config: mapping with keys ``bsize``, ``word_emb_dim``,
                ``enc_lstm_dim``, ``pool_type`` (``"mean"`` or ``"max"``),
                ``dpout_model`` and optionally ``version`` (1 or 2; default 1).
        """
        super(InferSent, self).__init__()
        self.bsize = config['bsize']
        self.word_emb_dim = config['word_emb_dim']
        self.enc_lstm_dim = config['enc_lstm_dim']
        self.pool_type = config['pool_type']
        self.dpout_model = config['dpout_model']
        self.version = 1 if 'version' not in config else config['version']

        # NOTE: PyTorch applies LSTM dropout only between stacked layers, so
        # with a single layer dpout_model has no effect here.
        self.enc_lstm = nn.LSTM(self.word_emb_dim, self.enc_lstm_dim, 1,
                                bidirectional=True, dropout=self.dpout_model)

        assert self.version in [1, 2]
        if self.version == 1:
            self.bos = '<s>'
            self.eos = '</s>'
            self.max_pad = True
            self.moses_tok = False
        elif self.version == 2:
            self.bos = '<p>'
            self.eos = '</p>'
            self.max_pad = False
            self.moses_tok = True

    def is_cuda(self):
        """Return True if the LSTM parameters live on a CUDA device."""
        return self.enc_lstm.bias_hh_l0.data.is_cuda

    def forward(self, sent_tuple):
        """Encode a padded batch of word-vector sequences.

        Args:
            sent_tuple: ``(sent, sent_len)``. ``sent`` is a float tensor laid
                out as ``(max_len, batch, word_emb_dim)`` (the layout built by
                ``get_batch``). ``sent_len`` holds the true length of each
                sentence, as a NumPy array or a sequence of ints.

        Returns:
            Tensor of shape ``(batch, 2 * enc_lstm_dim)`` whose rows follow
            the input batch order.

        Raises:
            ValueError: if ``pool_type`` is neither ``"mean"`` nor ``"max"``.
        """
        sent, sent_len = sent_tuple
        sent_len = np.asarray(sent_len)

        # pack_padded_sequence expects decreasing lengths: sort the batch and
        # keep the inverse permutation to restore the caller's order.
        idx_sort = np.argsort(-sent_len)
        sent_len_sorted = sent_len[idx_sort]
        idx_unsort = np.argsort(idx_sort)

        sent = sent.index_select(1, torch.from_numpy(idx_sort).to(sent.device))

        sent_packed = nn.utils.rnn.pack_padded_sequence(sent, sent_len_sorted)
        sent_output = self.enc_lstm(sent_packed)[0]
        # Time steps past each sentence's length are zero-filled
        # (pad_packed_sequence's default padding_value).
        sent_output = nn.utils.rnn.pad_packed_sequence(sent_output)[0]

        sent_output = sent_output.index_select(
            1, torch.from_numpy(idx_unsort).to(sent_output.device))

        if self.pool_type == "mean":
            # Lengths as a (batch, 1) column with the output's dtype/device,
            # so this works on CPU and GPU and for a batch of one sentence.
            lengths = torch.from_numpy(np.ascontiguousarray(sent_len))
            lengths = lengths.to(sent_output).unsqueeze(1)
            # Padded steps are zeros, so summing over time only adds real steps.
            emb = torch.sum(sent_output, 0) / lengths
        elif self.pool_type == "max":
            if not self.max_pad:
                # Keep zero padding from winning the max. Note this also
                # replaces any genuine activation that is exactly 0.
                sent_output[sent_output == 0] = -1e9
            emb = torch.max(sent_output, 0)[0]
            if emb.ndimension() == 3:
                emb = emb.squeeze(0)
        else:
            raise ValueError('Unknown pool_type: %r' % (self.pool_type,))
        assert emb.ndimension() == 2
        return emb

    def set_w2v_path(self, w2v_path):
        """Set the path of the text file holding the word vectors."""
        self.w2v_path = w2v_path

    def get_word_dict(self, sentences, tokenize=True):
        """Collect the distinct words of ``sentences`` plus bos/eos.

        Returns:
            dict mapping each word to ``''``. Only the keys are meaningful.
        """
        word_dict = {}
        for s in sentences:
            words = self.tokenize(s) if tokenize else s.split()
            for word in words:
                if word not in word_dict:
                    word_dict[word] = ''
        word_dict[self.bos] = ''
        word_dict[self.eos] = ''
        return word_dict

    def get_w2v(self, word_dict):
        """Load vectors from the w2v file for every word present in ``word_dict``.

        Returns:
            dict mapping word -> 1-D NumPy float array.
        """
        assert hasattr(self, 'w2v_path'), 'w2v path not set'
        word_vec = {}
        with open(self.w2v_path, encoding='utf-8') as f:
            for line in f:
                word, vec = line.split(' ', 1)
                if word in word_dict:
                    word_vec[word] = np.fromstring(vec, sep=' ')
        print('Found %s(/%s) words with w2v vectors' % (len(word_vec), len(word_dict)))
        return word_vec

    def get_w2v_k(self, K):
        """Load the first ``K`` vectors of the w2v file, plus bos/eos.

        The bos/eos vectors are still looked up past the first ``K`` lines.
        Reading stops once ``K`` words are loaded and both bos and eos have
        been found.
        """
        assert hasattr(self, 'w2v_path'), 'w2v path not set'
        specials = (self.bos, self.eos)
        n_loaded = 0
        word_vec = {}
        with open(self.w2v_path, encoding='utf-8') as f:
            for line in f:
                word, vec = line.split(' ', 1)
                if n_loaded < K:
                    word_vec[word] = np.fromstring(vec, sep=' ')
                    n_loaded += 1
                elif word in specials:
                    word_vec[word] = np.fromstring(vec, sep=' ')

                if n_loaded >= K and all(w in word_vec for w in specials):
                    break
        return word_vec

    def build_vocab(self, sentences, tokenize=True):
        """Set ``self.word_vec`` to the vectors of the words found in ``sentences``."""
        assert hasattr(self, 'w2v_path'), 'w2v path not set'
        word_dict = self.get_word_dict(sentences, tokenize)
        self.word_vec = self.get_w2v(word_dict)
        print('Vocab size : %s' % (len(self.word_vec)))

    def build_vocab_k_words(self, K):
        """Set ``self.word_vec`` to the first ``K`` w2v vectors (plus bos/eos)."""
        assert hasattr(self, 'w2v_path'), 'w2v path not set'
        self.word_vec = self.get_w2v_k(K)
        print('Vocab size : %s' % (len(self.word_vec)))

    def update_vocab(self, sentences, tokenize=True):
        """Add vectors for words of ``sentences`` missing from ``self.word_vec``."""
        assert hasattr(self, 'w2v_path'), 'warning : w2v path not set'
        assert hasattr(self, 'word_vec'), 'build_vocab before updating it'
        word_dict = self.get_word_dict(sentences, tokenize)

        # Only look up words not already in the vocabulary.
        word_dict = {w: v for w, v in word_dict.items() if w not in self.word_vec}

        if word_dict:
            new_word_vec = self.get_w2v(word_dict)
            self.word_vec.update(new_word_vec)
        else:
            new_word_vec = []
        print('New vocab size : %s (added %s words)' % (len(self.word_vec), len(new_word_vec)))

    def get_batch(self, batch):
        """Stack the word vectors of ``batch`` into a zero-padded tensor.

        Args:
            batch: sequence of token lists, every token present in
                ``self.word_vec``.

        Returns:
            float32 tensor of shape ``(max_len, len(batch), word_emb_dim)``.
        """
        max_len = max(len(s) for s in batch)
        embed = np.zeros((max_len, len(batch), self.word_emb_dim))
        for i, sent in enumerate(batch):
            for j, word in enumerate(sent):
                embed[j, i, :] = self.word_vec[word]
        return torch.from_numpy(embed).float()

    def tokenize(self, s):
        """Tokenize ``s`` with NLTK's ``word_tokenize``.

        When ``moses_tok`` is set (version 2), the tokens are re-joined and
        every ``" n't "`` is rewritten as ``"n 't "`` before splitting on
        whitespace.
        """
        from nltk.tokenize import word_tokenize
        if self.moses_tok:
            s = ' '.join(word_tokenize(s))
            s = s.replace(" n't ", "n 't ")
            return s.split()
        else:
            return word_tokenize(s)

    def prepare_samples(self, sentences, bsize, tokenize, verbose):
        """Tokenize, filter to known words, and sort sentences by length.

        Each sentence is wrapped in bos/eos, and tokens without a vector are
        dropped. A sentence left empty is replaced by ``[eos]``.

        Returns:
            ``(sentences, lengths, idx_sort)``. ``sentences`` is a list of
            token lists sorted by decreasing length. ``lengths`` is a NumPy
            array of their lengths. ``idx_sort`` maps sorted position to
            original index. ``bsize`` is unused.
        """
        import warnings
        sentences = [[self.bos] + (self.tokenize(s) if tokenize else s.split()) + [self.eos]
                     for s in sentences]
        n_w = np.sum([len(x) for x in sentences])

        # Drop words without a vector.
        for i, sent in enumerate(sentences):
            kept = [word for word in sent if word in self.word_vec]
            if not kept:
                warnings.warn('No words in "%s" (idx=%s) have w2v vectors. '
                              'Replacing by "%s"..' % (sent, i, self.eos))
                kept = [self.eos]
            sentences[i] = kept

        lengths = np.array([len(s) for s in sentences])
        n_wk = np.sum(lengths)
        if verbose and n_w:
            print('Nb words kept : %s/%s (%.1f%s)' % (
                n_wk, n_w, 100.0 * n_wk / n_w, '%'))

        # Sort by decreasing length. A plain list is used because NumPy
        # refuses to build arrays from ragged lists (NumPy >= 1.24).
        idx_sort = np.argsort(-lengths)
        lengths = lengths[idx_sort]
        sentences = [sentences[i] for i in idx_sort]

        return sentences, lengths, idx_sort

    def encode(self, sentences, bsize=64, tokenize=True, verbose=False):
        """Encode raw sentences into a NumPy array of embeddings.

        Returns:
            array of shape ``(len(sentences), 2 * enc_lstm_dim)``, with rows
            in the input order. Empty input yields a ``(0, 2 * enc_lstm_dim)``
            array.
        """
        tic = time.time()
        sentences, lengths, idx_sort = self.prepare_samples(
            sentences, bsize, tokenize, verbose)
        if len(sentences) == 0:
            return np.zeros((0, 2 * self.enc_lstm_dim), dtype=np.float32)

        device = self.enc_lstm.bias_hh_l0.device
        embeddings = []
        for stidx in range(0, len(sentences), bsize):
            batch = self.get_batch(sentences[stidx:stidx + bsize]).to(device)
            with torch.no_grad():
                batch = self.forward((batch, lengths[stidx:stidx + bsize])).data.cpu().numpy()
            embeddings.append(batch)
        embeddings = np.vstack(embeddings)

        # Undo the length sort done in prepare_samples.
        idx_unsort = np.argsort(idx_sort)
        embeddings = embeddings[idx_unsort]

        if verbose:
            print('Speed : %.1f sentences/s (%s mode, bsize=%s)' % (
                len(embeddings) / (time.time() - tic),
                'gpu' if self.is_cuda() else 'cpu', bsize))
        return embeddings

    def visualize(self, sent, tokenize=True):
        """Plot how often each word supplies the max over time, per LSTM unit.

        Returns:
            ``(output, idxs)``: the per-unit max over time of the LSTM outputs
            and the time step (word position) at which each max occurs.
        """
        sent = sent.split() if not tokenize else self.tokenize(sent)
        sent = [[self.bos] + [word for word in sent if word in self.word_vec] + [self.eos]]

        if ' '.join(sent[0]) == '%s %s' % (self.bos, self.eos):
            import warnings
            warnings.warn('No words in "%s" have w2v vectors. Replacing '
                          'by "%s %s"..' % (sent, self.bos, self.eos))
        batch = self.get_batch(sent).to(self.enc_lstm.bias_hh_l0.device)

        output = self.enc_lstm(batch)[0]
        output, idxs = torch.max(output, 0)

        idxs = idxs.data.cpu().numpy()
        argmaxs = [np.sum((idxs == k)) for k in range(len(sent[0]))]

        import matplotlib.pyplot as plt
        x = range(len(sent[0]))
        y = [100.0 * n / np.sum(argmaxs) for n in argmaxs]
        plt.xticks(x, sent[0], rotation=45)
        plt.bar(x, y)
        plt.ylabel('%')
        plt.title('Visualisation of words importance')
        plt.show()

        return output, idxs
|
|