# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

'''
Generic sentence evaluation scripts wrapper
'''
from __future__ import absolute_import, division, unicode_literals

from senteval import utils
from senteval.binary import CREval, MREval, MPQAEval, SUBJEval
from senteval.snli import SNLIEval
from senteval.trec import TRECEval
from senteval.sick import SICKEntailmentEval, SICKEval
from senteval.mrpc import MRPCEval
from senteval.sts import STS12Eval, STS13Eval, STS14Eval, STS15Eval, \
    STS16Eval, STSBenchmarkEval, SICKRelatednessEval, STSBenchmarkFinetune
from senteval.sst import SSTEval
from senteval.rank import ImageCaptionRetrievalEval
from senteval.probing import *


class SE(object):
    def __init__(self, params, batcher, prepare=None):
        # parameters: fill in defaults for anything the caller did not set
        params = utils.dotdict(params)
        params.usepytorch = True if 'usepytorch' not in params else params.usepytorch
        params.seed = 1111 if 'seed' not in params else params.seed

        params.batch_size = 128 if 'batch_size' not in params else params.batch_size
        params.nhid = 0 if 'nhid' not in params else params.nhid
        params.kfold = 5 if 'kfold' not in params else params.kfold

        if 'classifier' not in params or not params['classifier']:
            params.classifier = {'nhid': 0}

        assert 'nhid' in params.classifier, 'Set number of hidden units in classifier config!!'

        self.params = params

        # batcher encodes a batch of sentences; prepare (optional) sees the
        # whole dataset before evaluation, e.g. to build a vocabulary
        self.batcher = batcher
        self.prepare = prepare if prepare else lambda x, y: None

        self.list_tasks = ['CR', 'MR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC',
                           'MRPC', 'SICKRelatedness', 'SICKEntailment',
                           'STSBenchmark', 'SNLI', 'ImageCaptionRetrieval',
                           'STS12', 'STS13', 'STS14', 'STS15', 'STS16',
                           'Length', 'WordContent', 'Depth', 'TopConstituents',
                           'BigramShift', 'Tense', 'SubjNumber', 'ObjNumber',
                           'OddManOut', 'CoordinationInversion',
                           'SICKRelatedness-finetune', 'STSBenchmark-finetune',
                           'STSBenchmark-fix']

    def eval(self, name):
        # evaluate on task [name]; accepts a single task name or a list of names
        if isinstance(name, list):
            self.results = {x: self.eval(x) for x in name}
            return self.results

        tpath = self.params.task_path
        assert name in self.list_tasks, str(name) + ' not in ' + str(self.list_tasks)

        # Original SentEval tasks
        if name == 'CR':
            self.evaluation = CREval(tpath + '/downstream/CR', seed=self.params.seed)
        elif name == 'MR':
            self.evaluation = MREval(tpath + '/downstream/MR', seed=self.params.seed)
        elif name == 'MPQA':
            self.evaluation = MPQAEval(tpath + '/downstream/MPQA', seed=self.params.seed)
        elif name == 'SUBJ':
            self.evaluation = SUBJEval(tpath + '/downstream/SUBJ', seed=self.params.seed)
        elif name == 'SST2':
            self.evaluation = SSTEval(tpath + '/downstream/SST/binary', nclasses=2, seed=self.params.seed)
        elif name == 'SST5':
            self.evaluation = SSTEval(tpath + '/downstream/SST/fine', nclasses=5, seed=self.params.seed)
        elif name == 'TREC':
            self.evaluation = TRECEval(tpath + '/downstream/TREC', seed=self.params.seed)
        elif name == 'MRPC':
            self.evaluation = MRPCEval(tpath + '/downstream/MRPC', seed=self.params.seed)
        elif name == 'SICKRelatedness':
            self.evaluation = SICKRelatednessEval(tpath + '/downstream/SICK', seed=self.params.seed)
        elif name == 'STSBenchmark':
            self.evaluation = STSBenchmarkEval(tpath + '/downstream/STS/STSBenchmark', seed=self.params.seed)
        elif name == 'STSBenchmark-fix':
            self.evaluation = STSBenchmarkEval(tpath + '/downstream/STS/STSBenchmark-fix', seed=self.params.seed)
        elif name == 'STSBenchmark-finetune':
            self.evaluation = STSBenchmarkFinetune(tpath + '/downstream/STS/STSBenchmark', seed=self.params.seed)
        elif name == 'SICKRelatedness-finetune':
            self.evaluation = SICKEval(tpath + '/downstream/SICK', seed=self.params.seed)
        elif name == 'SICKEntailment':
            self.evaluation = SICKEntailmentEval(tpath + '/downstream/SICK', seed=self.params.seed)
        elif name == 'SNLI':
            self.evaluation = SNLIEval(tpath + '/downstream/SNLI', seed=self.params.seed)
        elif name in ['STS12', 'STS13', 'STS14', 'STS15', 'STS16']:
            fpath = name + '-en-test'
            # resolve e.g. 'STS12' to the STS12Eval class imported above
            self.evaluation = eval(name + 'Eval')(tpath + '/downstream/STS/' + fpath, seed=self.params.seed)
        elif name == 'ImageCaptionRetrieval':
            self.evaluation = ImageCaptionRetrievalEval(tpath + '/downstream/COCO', seed=self.params.seed)

        # Probing Tasks
        elif name == 'Length':
            self.evaluation = LengthEval(tpath + '/probing', seed=self.params.seed)
        elif name == 'WordContent':
            self.evaluation = WordContentEval(tpath + '/probing', seed=self.params.seed)
        elif name == 'Depth':
            self.evaluation = DepthEval(tpath + '/probing', seed=self.params.seed)
        elif name == 'TopConstituents':
            self.evaluation = TopConstituentsEval(tpath + '/probing', seed=self.params.seed)
        elif name == 'BigramShift':
            self.evaluation = BigramShiftEval(tpath + '/probing', seed=self.params.seed)
        elif name == 'Tense':
            self.evaluation = TenseEval(tpath + '/probing', seed=self.params.seed)
        elif name == 'SubjNumber':
            self.evaluation = SubjNumberEval(tpath + '/probing', seed=self.params.seed)
        elif name == 'ObjNumber':
            self.evaluation = ObjNumberEval(tpath + '/probing', seed=self.params.seed)
        elif name == 'OddManOut':
            self.evaluation = OddManOutEval(tpath + '/probing', seed=self.params.seed)
        elif name == 'CoordinationInversion':
            self.evaluation = CoordinationInversionEval(tpath + '/probing', seed=self.params.seed)

        self.params.current_task = name
        self.evaluation.do_prepare(self.params, self.prepare)

        self.results = self.evaluation.run(self.params, self.batcher)

        return self.results
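

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal, hypothetical
# example of driving SE with a toy bag-of-words batcher. PATH_TO_DATA and the
# toy encoder below are placeholders standing in for a real data directory
# and a real sentence encoder; only the SE(params, batcher, prepare)
# interface and the batcher contract (a list of tokenized sentences in, one
# fixed-size vector per sentence out as a 2D numpy array) come from this
# module.
if __name__ == '__main__':
    import numpy as np

    def prepare(params, samples):
        # A real encoder might build its vocabulary from [samples] here;
        # this toy example needs no preparation.
        return

    def batcher(params, batch):
        # Hash each token into a fixed-size bag-of-words vector, one row per
        # sentence. Empty sentences fall back to a single '.' token.
        dim = 128
        embeddings = np.zeros((len(batch), dim), dtype=np.float32)
        for i, sent in enumerate(batch):
            for word in (sent if sent else ['.']):
                embeddings[i, hash(word) % dim] += 1.0
        return embeddings

    # 'PATH_TO_DATA' is a placeholder for the SentEval data directory;
    # usepytorch=False selects the scikit-learn classifiers.
    params = {'task_path': 'PATH_TO_DATA', 'usepytorch': False, 'kfold': 5,
              'classifier': {'nhid': 0}}
    se = SE(params, batcher, prepare)
    print(se.eval(['MR', 'STS14']))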