|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import absolute_import, division, unicode_literals |
|
|
|
import sys |
|
import io |
|
import numpy as np |
|
import logging |
|
|
|
|
|
|
|
# Paths to the SentEval toolkit checkout, its downloaded evaluation data,
# and the pretrained fastText vectors (crawl-300d-2M, 300-dim).
# Adjust these for your local layout.
PATH_TO_SENTEVAL = '../'

PATH_TO_DATA = '../data'



PATH_TO_VEC = 'fasttext/crawl-300d-2M.vec'





# Make the local SentEval checkout importable before importing it.
sys.path.insert(0, PATH_TO_SENTEVAL)

import senteval
|
|
|
|
|
|
|
def create_dictionary(sentences, threshold=0):
    """Build a vocabulary from tokenized sentences.

    Args:
        sentences: iterable of token lists.
        threshold: if > 0, drop tokens seen fewer than `threshold` times.

    Returns:
        (id2word, word2id): a list mapping id -> token (most frequent
        first, with '<s>', '</s>', '<p>' always at ids 0-2) and the
        inverse dict mapping token -> id.
    """
    counts = {}
    for sentence in sentences:
        for token in sentence:
            counts[token] = counts.get(token, 0) + 1

    if threshold > 0:
        counts = {tok: n for tok, n in counts.items() if n >= threshold}

    # Reserved symbols get huge pseudo-counts so they always rank first.
    counts['<s>'] = 1e9 + 4
    counts['</s>'] = 1e9 + 3
    counts['<p>'] = 1e9 + 2

    # Stable sort: ties keep their first-seen order.
    ranked = sorted(counts.items(), key=lambda item: item[1], reverse=True)
    id2word = [tok for tok, _ in ranked]
    word2id = {tok: idx for idx, (tok, _) in enumerate(ranked)}

    return id2word, word2id
|
|
|
|
|
def get_wordvec(path_to_vec, word2id):
    """Load pretrained word vectors for the vocabulary in `word2id`.

    Reads a whitespace-separated text embedding file (GloVe / fastText
    ``.vec`` format: one token followed by its float components per line)
    and keeps only tokens present in `word2id`.

    Args:
        path_to_vec: path to the embedding text file (UTF-8).
        word2id: dict of token -> id; only these tokens are loaded.

    Returns:
        dict mapping token -> 1-D np.float64 vector.
    """
    word_vec = {}

    with io.open(path_to_vec, 'r', encoding='utf-8') as f:
        # NOTE(review): fastText .vec files start with a "count dim" header
        # line; it is skipped implicitly because its first token is a number
        # that will not appear in word2id.
        for line in f:
            word, vec = line.split(' ', 1)
            if word in word2id:
                # np.fromstring(text, sep=' ') is deprecated in NumPy;
                # split the text and convert explicitly instead.
                word_vec[word] = np.array(vec.split(), dtype=np.float64)

    # Lazy %-style args so formatting is skipped when INFO is disabled.
    logging.info('Found %s words with word vectors, out of %s words',
                 len(word_vec), len(word2id))
    return word_vec
|
|
|
|
|
|
|
def prepare(params, samples):
    """SentEval `prepare` hook: build the vocabulary and load vectors.

    Called once per task with all of that task's tokenized sentences.
    Stores `word2id`, `word_vec` and `wvec_dim` on `params` for use by
    `batcher`.
    """
    _, params.word2id = create_dictionary(samples)
    params.word_vec = get_wordvec(PATH_TO_VEC, params.word2id)
    # Derive the embedding size from the loaded vectors rather than
    # hard-coding it, so other .vec files work unchanged; fall back to
    # 300 (crawl-300d-2M) when no vector was found.
    if params.word_vec:
        params.wvec_dim = len(next(iter(params.word_vec.values())))
    else:
        params.wvec_dim = 300
    return
|
|
|
def batcher(params, batch):
    """SentEval `batcher` hook: embed sentences as mean word vectors.

    Args:
        params: SentEval params with `word_vec` (token -> vector) and
            `wvec_dim` set by `prepare`.
        batch: list of tokenized sentences (lists of tokens).

    Returns:
        np.ndarray of shape (len(batch), wvec_dim), one row per sentence.
    """
    embeddings = []

    for sent in batch:
        # Keep the original `!= []` check (not truthiness) byte-for-byte
        # in spirit: only a literal empty list becomes ['.'].
        tokens = sent if sent != [] else ['.']
        vectors = [params.word_vec[tok] for tok in tokens
                   if tok in params.word_vec]
        if not vectors:
            # No known token in the sentence: fall back to a zero vector.
            vectors = [np.zeros(params.wvec_dim)]
        embeddings.append(np.mean(vectors, 0))

    return np.vstack(embeddings)
|
|
|
|
|
|
|
# SentEval configuration: task data location, whether the probing
# classifiers use PyTorch, and the k-fold cross-validation setting.
params_senteval = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 5}

# Downstream classifier: nhid=0 means a plain logistic regression,
# trained with rmsprop; tenacity/epoch_size control early stopping.
params_senteval['classifier'] = {'nhid': 0, 'optim': 'rmsprop', 'batch_size': 128,

                                 'tenacity': 3, 'epoch_size': 2}




# Timestamped DEBUG-level logging so SentEval progress is visible.
logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG)
|
|
|
if __name__ == "__main__":

    # Build the evaluation engine around our bag-of-words prepare/batcher.
    se = senteval.engine.SE(params_senteval, batcher, prepare)

    # Transfer tasks (STS*, classification, SICK, MRPC, STSBenchmark)
    # followed by the linguistic probing tasks (Length ... CoordinationInversion).
    transfer_tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16',

                      'MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC',

                      'SICKEntailment', 'SICKRelatedness', 'STSBenchmark',

                      'Length', 'WordContent', 'Depth', 'TopConstituents',

                      'BigramShift', 'Tense', 'SubjNumber', 'ObjNumber',

                      'OddManOut', 'CoordinationInversion']

    results = se.eval(transfer_tasks)

    print(results)
|
|