pere's picture
added SentEval
cd5fcb4
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
'''
Image-Caption Retrieval with COCO dataset
'''
from __future__ import absolute_import, division, unicode_literals
import os
import sys
import logging
import numpy as np
try:
import cPickle as pickle
except ImportError:
import pickle
from senteval.tools.ranking import ImageSentenceRankingPytorch
class ImageCaptionRetrievalEval(object):
    """Image-caption retrieval transfer task on COCO.

    Loads pickled captions and image features for the ``train``, ``valid``
    and ``test`` splits, embeds the captions with a user-supplied ``batcher``
    and trains/evaluates an ``ImageSentenceRankingPytorch`` model on the
    resulting sentence/image feature pairs.
    """

    def __init__(self, task_path, seed=1111):
        """Load the three splits found under ``task_path``.

        Args:
            task_path: directory containing ``train.pkl``, ``valid.pkl`` and
                ``test.pkl``.
            seed: value forwarded to the ranking model configuration.
        """
        logging.debug('***** Transfer task: Image Caption Retrieval *****\n\n')
        # Get captions and image features
        self.seed = seed
        train, dev, test = self.loadFile(task_path)
        self.coco_data = {'train': train, 'dev': dev, 'test': test}

    def do_prepare(self, params, prepare):
        """Call ``prepare(params, samples)`` with the captions of all splits.

        ``samples`` is the concatenation of the train, dev and test caption
        lists (each caption being a list of tokens).
        """
        samples = self.coco_data['train']['sent'] + \
            self.coco_data['dev']['sent'] + \
            self.coco_data['test']['sent']
        prepare(params, samples)

    def loadFile(self, fpath):
        """Read ``train.pkl``, ``valid.pkl`` and ``test.pkl`` from ``fpath``.

        Each pickle is expected to be a mapping with the keys ``features``,
        ``image_to_caption_ids`` and ``captions`` (each caption exposing a
        ``cleaned_caption`` entry). For every image the first five captions
        are kept, and the image feature is repeated once per caption so that
        captions and features stay aligned index by index.

        Returns:
            A ``(train, valid, test)`` tuple of dicts with keys ``sent``
            (list of token lists) and ``imgfeat`` (float32 numpy array).

        Raises:
            AssertionError: if an image has fewer than five captions.
        """
        coco = {}

        for split in ['train', 'valid', 'test']:
            list_sent = []
            list_img_feat = []
            # NOTE(security): pickle.load executes arbitrary code from the
            # file; only point task_path at trusted data.
            # Pickles are binary, so open in 'rb' on every Python version.
            with open(os.path.join(fpath, split + '.pkl'), 'rb') as f:
                if sys.version_info < (3, 0):
                    cocodata = pickle.load(f)
                else:
                    cocodata = pickle.load(f, encoding='latin1')

            for imgkey, img_feat in enumerate(cocodata['features']):
                capt_ids = cocodata['image_to_caption_ids'][imgkey]
                assert len(capt_ids) >= 5, capt_ids
                for captkey in capt_ids[0:5]:
                    sent = cocodata['captions'][captkey]['cleaned_caption']
                    sent += ' .'  # add punctuation to end of sentence in COCO
                    # NOTE(review): on Python 3 encode() makes the tokens
                    # bytes rather than str; kept as-is because prepare and
                    # batcher callbacks may rely on it -- confirm.
                    list_sent.append(sent.encode('utf-8').split())
                    list_img_feat.append(img_feat)

            assert len(list_sent) == len(list_img_feat) and \
                len(list_sent) % 5 == 0
            list_img_feat = np.array(list_img_feat).astype('float32')
            coco[split] = {'sent': list_sent, 'imgfeat': list_img_feat}

        return coco['train'], coco['valid'], coco['test']

    def _embed_sentences(self, params, sentences, batcher):
        """Embed ``sentences`` in length-sorted batches, in original order.

        Sentences are batched by increasing token count (to reduce padding
        inside ``batcher``), and the stacked embeddings are permuted back so
        that row ``i`` corresponds to ``sentences[i]``. ``sentences`` itself
        is never modified.
        """
        nsent = len(sentences)
        # sorted() is stable, so equal-length sentences keep their order.
        order = sorted(range(nsent), key=lambda i: len(sentences[i]))
        # Inverse permutation: position of each original index in `order`.
        restore = np.argsort(order)

        chunks = []
        for start in range(0, nsent, params.batch_size):
            batch_idx = order[start:start + params.batch_size]
            batch = [sentences[i] for i in batch_idx]
            chunks.append(batcher(params, batch))
        return np.vstack(chunks)[restore]

    def run(self, params, batcher):
        """Embed every split and train/evaluate the ranking model.

        Args:
            params: object exposing ``batch_size``; passed through to
                ``batcher``.
            batcher: callable ``batcher(params, batch)`` returning one
                embedding row per sentence of ``batch``.

        Returns:
            A dict with ``devacc`` (best dev score), ``acc`` (two 4-tuples:
            image-to-text then text-to-image scores as returned by the
            ranking model), ``ndev`` and ``ntest`` (number of embedded dev
            and test sentences).
        """
        coco_embed = {'train': {'sentfeat': [], 'imgfeat': []},
                      'dev': {'sentfeat': [], 'imgfeat': []},
                      'test': {'sentfeat': [], 'imgfeat': []}}

        for key in self.coco_data:
            logging.info('Computing embedding for {0}'.format(key))
            # Not used in this file; retained because the dict is handed to
            # ImageSentenceRankingPytorch unchanged.
            coco_embed[key]['X'] = []
            # Sort to reduce padding (done on indices only, so the stored
            # captions stay aligned with their image features).
            coco_embed[key]['sentfeat'] = self._embed_sentences(
                params, list(self.coco_data[key]['sent']), batcher)
            coco_embed[key]['imgfeat'] = np.array(
                self.coco_data[key]['imgfeat'])
            logging.info('Computed {0} embeddings'.format(key))

        config = {'seed': self.seed, 'projdim': 1000, 'margin': 0.2}
        clf = ImageSentenceRankingPytorch(train=coco_embed['train'],
                                          valid=coco_embed['dev'],
                                          test=coco_embed['test'],
                                          config=config)

        bestdevscore, r1_i2t, r5_i2t, r10_i2t, medr_i2t, \
            r1_t2i, r5_t2i, r10_t2i, medr_t2i = clf.run()

        logging.debug("\nTest scores | Image to text: "
                      "{0}, {1}, {2}, {3}".format(r1_i2t, r5_i2t,
                                                  r10_i2t, medr_i2t))
        logging.debug("Test scores | Text to image: "
                      "{0}, {1}, {2}, {3}\n".format(r1_t2i, r5_t2i,
                                                    r10_t2i, medr_t2i))

        return {'devacc': bestdevscore,
                'acc': [(r1_i2t, r5_i2t, r10_i2t, medr_i2t),
                        (r1_t2i, r5_t2i, r10_t2i, medr_t2i)],
                'ndev': len(coco_embed['dev']['sentfeat']),
                'ntest': len(coco_embed['test']['sentfeat'])}