|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
''' |
|
|
|
Generic wrapper that dispatches to the individual sentence evaluation task scripts.
|
|
|
''' |
|
from __future__ import absolute_import, division, unicode_literals |
|
|
|
from senteval import utils |
|
from senteval.binary import CREval, MREval, MPQAEval, SUBJEval |
|
from senteval.snli import SNLIEval |
|
from senteval.trec import TRECEval |
|
from senteval.sick import SICKEntailmentEval, SICKEval |
|
from senteval.mrpc import MRPCEval |
|
from senteval.sts import STS12Eval, STS13Eval, STS14Eval, STS15Eval, STS16Eval, STSBenchmarkEval, SICKRelatednessEval, STSBenchmarkFinetune |
|
from senteval.sst import SSTEval |
|
from senteval.rank import ImageCaptionRetrievalEval |
|
from senteval.probing import * |
|
|
|
class SE(object):
    """Run sentence-evaluation tasks for a user-supplied encoder.

    The caller provides a ``batcher`` callable (handed to each evaluator's
    ``run``) and an optional ``prepare`` callable (handed to each evaluator's
    ``do_prepare``). This class fills in default parameters and routes a task
    name to the evaluator that implements it.
    """

    def __init__(self, params, batcher, prepare=None):
        """Store the configuration and callbacks shared by every task.

        Args:
            params: mapping of configuration values. It is wrapped in
                ``utils.dotdict``. Missing ``usepytorch``, ``seed``,
                ``batch_size``, ``nhid`` and ``kfold`` keys get defaults, and a
                missing or empty ``classifier`` becomes ``{'nhid': 0}``.
                ``eval`` later reads ``task_path`` from it.
            batcher: callable passed to each evaluator's ``run``.
            prepare: optional callable passed to each evaluator's
                ``do_prepare``. Defaults to a two-argument no-op.

        Raises:
            AssertionError: if the classifier config has no ``'nhid'`` key.
        """
        params = utils.dotdict(params)
        params.usepytorch = True if 'usepytorch' not in params else params.usepytorch
        params.seed = 1111 if 'seed' not in params else params.seed

        params.batch_size = 128 if 'batch_size' not in params else params.batch_size
        params.nhid = 0 if 'nhid' not in params else params.nhid
        params.kfold = 5 if 'kfold' not in params else params.kfold

        if 'classifier' not in params or not params['classifier']:
            params.classifier = {'nhid': 0}

        # Raised explicitly rather than via ``assert`` so the check survives
        # ``python -O``; AssertionError is kept for backward compatibility.
        if 'nhid' not in params.classifier:
            raise AssertionError('Set number of hidden units in classifier config!!')

        self.params = params

        self.batcher = batcher
        self.prepare = prepare if prepare else lambda x, y: None

        self.list_tasks = ['CR', 'MR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC',
                           'SICKRelatedness', 'SICKEntailment', 'STSBenchmark',
                           'SNLI', 'ImageCaptionRetrieval', 'STS12', 'STS13',
                           'STS14', 'STS15', 'STS16',
                           'Length', 'WordContent', 'Depth', 'TopConstituents',
                           'BigramShift', 'Tense', 'SubjNumber', 'ObjNumber',
                           'OddManOut', 'CoordinationInversion', 'SICKRelatedness-finetune',
                           'STSBenchmark-finetune', 'STSBenchmark-fix']

    def eval(self, name):
        """Run one task, or each task in a list, and return the results.

        Args:
            name: a task name from ``self.list_tasks``, or a list of them.

        Returns:
            For a single name, whatever the evaluator's ``run`` returns. For a
            list, a dict mapping each name to its result. The returned value
            is also stored on ``self.results``.

        Raises:
            AssertionError: if a task name is not in ``self.list_tasks``.
        """
        if isinstance(name, list):
            self.results = {x: self.eval(x) for x in name}
            return self.results

        tpath = self.params.task_path
        # Raised explicitly (not via ``assert``). Under ``python -O`` an unknown
        # name would otherwise fall through and silently reuse the previous
        # ``self.evaluation``.
        if name not in self.list_tasks:
            raise AssertionError(str(name) + ' not in ' + str(self.list_tasks))

        self.evaluation = self._build_evaluation(name, tpath)

        self.params.current_task = name
        self.evaluation.do_prepare(self.params, self.prepare)

        self.results = self.evaluation.run(self.params, self.batcher)
        return self.results

    def _build_evaluation(self, name, tpath):
        """Construct the evaluator for task ``name`` rooted at ``tpath``.

        Downstream tasks are read from ``<tpath>/downstream/...`` and probing
        tasks from ``<tpath>/probing``. Raises KeyError when ``name`` has no
        registered evaluator, e.g. if ``self.list_tasks`` was extended without
        updating this table.
        """
        seed = self.params.seed
        downstream = tpath + '/downstream/'
        probing = tpath + '/probing'

        # Zero-argument factories: each evaluator class is looked up only when
        # its entry is called. This replaces the former if/elif chain and the
        # builtin ``eval(name + 'Eval')`` lookup for STS12-STS16.
        factories = {
            'CR': lambda: CREval(downstream + 'CR', seed=seed),
            'MR': lambda: MREval(downstream + 'MR', seed=seed),
            'MPQA': lambda: MPQAEval(downstream + 'MPQA', seed=seed),
            'SUBJ': lambda: SUBJEval(downstream + 'SUBJ', seed=seed),
            'SST2': lambda: SSTEval(downstream + 'SST/binary', nclasses=2, seed=seed),
            'SST5': lambda: SSTEval(downstream + 'SST/fine', nclasses=5, seed=seed),
            'TREC': lambda: TRECEval(downstream + 'TREC', seed=seed),
            'MRPC': lambda: MRPCEval(downstream + 'MRPC', seed=seed),
            'SICKRelatedness': lambda: SICKRelatednessEval(downstream + 'SICK', seed=seed),
            'STSBenchmark': lambda: STSBenchmarkEval(downstream + 'STS/STSBenchmark', seed=seed),
            'STSBenchmark-fix': lambda: STSBenchmarkEval(downstream + 'STS/STSBenchmark-fix', seed=seed),
            'STSBenchmark-finetune': lambda: STSBenchmarkFinetune(downstream + 'STS/STSBenchmark', seed=seed),
            'SICKRelatedness-finetune': lambda: SICKEval(downstream + 'SICK', seed=seed),
            'SICKEntailment': lambda: SICKEntailmentEval(downstream + 'SICK', seed=seed),
            'SNLI': lambda: SNLIEval(downstream + 'SNLI', seed=seed),
            'STS12': lambda: STS12Eval(downstream + 'STS/STS12-en-test', seed=seed),
            'STS13': lambda: STS13Eval(downstream + 'STS/STS13-en-test', seed=seed),
            'STS14': lambda: STS14Eval(downstream + 'STS/STS14-en-test', seed=seed),
            'STS15': lambda: STS15Eval(downstream + 'STS/STS15-en-test', seed=seed),
            'STS16': lambda: STS16Eval(downstream + 'STS/STS16-en-test', seed=seed),
            'ImageCaptionRetrieval': lambda: ImageCaptionRetrievalEval(downstream + 'COCO', seed=seed),
            # Probing tasks all share the same data directory.
            'Length': lambda: LengthEval(probing, seed=seed),
            'WordContent': lambda: WordContentEval(probing, seed=seed),
            'Depth': lambda: DepthEval(probing, seed=seed),
            'TopConstituents': lambda: TopConstituentsEval(probing, seed=seed),
            'BigramShift': lambda: BigramShiftEval(probing, seed=seed),
            'Tense': lambda: TenseEval(probing, seed=seed),
            'SubjNumber': lambda: SubjNumberEval(probing, seed=seed),
            'ObjNumber': lambda: ObjNumberEval(probing, seed=seed),
            'OddManOut': lambda: OddManOutEval(probing, seed=seed),
            'CoordinationInversion': lambda: CoordinationInversionEval(probing, seed=seed),
        }
        return factories[name]()
|
|