'''
Image-Caption Retrieval with the COCO dataset.
'''
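# The task directory handed to ImageCaptionRetrievalEval is expected to hold
# three pickles -- train.pkl, valid.pkl and test.pkl -- pairing precomputed
# image features with at least five cleaned captions per image (see loadFile).
# run() encodes the captions with the user-supplied batcher, trains an
# ImageSentenceRankingPytorch model on top of the fixed embeddings, and
# reports r1/r5/r10/medr (recall@k and median rank) for both image-to-text
# and text-to-image retrieval.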
from __future__ import absolute_import, division, unicode_literals

import os
import sys
import logging

import numpy as np

try:
    import cPickle as pickle
except ImportError:
    import pickle

from senteval.tools.ranking import ImageSentenceRankingPytorch


class ImageCaptionRetrievalEval(object):
    def __init__(self, task_path, seed=1111):
        logging.debug('***** Transfer task: Image Caption Retrieval *****\n\n')

        self.seed = seed
        train, dev, test = self.loadFile(task_path)
        self.coco_data = {'train': train, 'dev': dev, 'test': test}

    def do_prepare(self, params, prepare):
        samples = self.coco_data['train']['sent'] + \
                  self.coco_data['dev']['sent'] + \
                  self.coco_data['test']['sent']
        prepare(params, samples)

    def loadFile(self, fpath):
        coco = {}

        for split in ['train', 'valid', 'test']:
            list_sent = []
            list_img_feat = []
            if sys.version_info < (3, 0):
                with open(os.path.join(fpath, split + '.pkl')) as f:
                    cocodata = pickle.load(f)
            else:
                # Under Python 3, decode Python-2 pickles with latin1.
                with open(os.path.join(fpath, split + '.pkl'), 'rb') as f:
                    cocodata = pickle.load(f, encoding='latin1')

            for imgkey in range(len(cocodata['features'])):
                assert len(cocodata['image_to_caption_ids'][imgkey]) >= 5, \
                    cocodata['image_to_caption_ids'][imgkey]
                # Keep the first five captions of each image.
                for captkey in cocodata['image_to_caption_ids'][imgkey][0:5]:
                    sent = cocodata['captions'][captkey]['cleaned_caption']
                    sent += ' .'
                    # Whitespace-tokenize; tokens are stored as bytes.
                    list_sent.append(sent.encode('utf-8').split())
                    list_img_feat.append(cocodata['features'][imgkey])
            assert len(list_sent) == len(list_img_feat) and \
                len(list_sent) % 5 == 0
            list_img_feat = np.array(list_img_feat).astype('float32')
            coco[split] = {'sent': list_sent, 'imgfeat': list_img_feat}
        return coco['train'], coco['valid'], coco['test']
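    # For reference, loadFile assumes each split pickle roughly follows the
    # layout sketched below (inferred from the keys it reads; not an
    # authoritative schema):
    #
    #   {
    #       'features': [feat_0, feat_1, ...],                  # one feature vector per image
    #       'image_to_caption_ids': [[c0, c1, c2, c3, c4, ...],
    #                                ...],                       # >= 5 caption ids per image
    #       'captions': {c0: {'cleaned_caption': '...'}, ...},
    #   }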
    def run(self, params, batcher):
        coco_embed = {'train': {'sentfeat': [], 'imgfeat': []},
                      'dev': {'sentfeat': [], 'imgfeat': []},
                      'test': {'sentfeat': [], 'imgfeat': []}}

        for key in self.coco_data:
            logging.info('Computing embedding for {0}'.format(key))

            # Sort the sentences (keeping the permutation) so that similar
            # sentences are batched together; idx_unsort restores the original
            # order below, keeping captions aligned with their image features.
            # dtype=object: the token lists are ragged.
            self.coco_data[key]['sent'] = np.array(self.coco_data[key]['sent'],
                                                   dtype=object)
            self.coco_data[key]['sent'], idx_sort = \
                np.sort(self.coco_data[key]['sent']), \
                np.argsort(self.coco_data[key]['sent'])
            idx_unsort = np.argsort(idx_sort)

            nsent = len(self.coco_data[key]['sent'])
            for ii in range(0, nsent, params.batch_size):
                batch = self.coco_data[key]['sent'][ii:ii + params.batch_size]
                embeddings = batcher(params, batch)
                coco_embed[key]['sentfeat'].append(embeddings)
            coco_embed[key]['sentfeat'] = \
                np.vstack(coco_embed[key]['sentfeat'])[idx_unsort]
            coco_embed[key]['imgfeat'] = np.array(self.coco_data[key]['imgfeat'])
            logging.info('Computed {0} embeddings'.format(key))

        # Train a margin-based ranking model on top of the fixed sentence
        # and image embeddings.
        config = {'seed': self.seed, 'projdim': 1000, 'margin': 0.2}
        clf = ImageSentenceRankingPytorch(train=coco_embed['train'],
                                          valid=coco_embed['dev'],
                                          test=coco_embed['test'],
                                          config=config)

        bestdevscore, r1_i2t, r5_i2t, r10_i2t, medr_i2t, \
            r1_t2i, r5_t2i, r10_t2i, medr_t2i = clf.run()

        logging.debug("\nTest scores | Image to text: {0}, {1}, {2}, {3}"
                      .format(r1_i2t, r5_i2t, r10_i2t, medr_i2t))
        logging.debug("Test scores | Text to image: {0}, {1}, {2}, {3}\n"
                      .format(r1_t2i, r5_t2i, r10_t2i, medr_t2i))

        return {'devacc': bestdevscore,
                'acc': [(r1_i2t, r5_i2t, r10_i2t, medr_i2t),
                        (r1_t2i, r5_t2i, r10_t2i, medr_t2i)],
                'ndev': len(coco_embed['dev']['sentfeat']),
                'ntest': len(coco_embed['test']['sentfeat'])}
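# Minimal smoke-test driver. This is only an illustrative sketch: PATH_TO_COCO,
# the random 300-d batcher and the _Params holder are assumptions made for
# demonstration, not part of SentEval's API, and it only runs end to end if the
# COCO pickles and the PyTorch dependency of senteval.tools.ranking are present.
if __name__ == '__main__':
    logging.basicConfig(level=logging.DEBUG)

    PATH_TO_COCO = '/path/to/coco'  # directory with train.pkl / valid.pkl / test.pkl

    class _Params(object):
        batch_size = 128  # run() only reads params.batch_size

    def _prepare(params, samples):
        # The random batcher below needs no vocabulary, so nothing to prepare.
        pass

    def _batcher(params, batch):
        # Stand-in sentence encoder: one random 300-d vector per sentence.
        return np.random.randn(len(batch), 300).astype('float32')

    evaluation = ImageCaptionRetrievalEval(PATH_TO_COCO, seed=1111)
    evaluation.do_prepare(_Params(), _prepare)
    print(evaluation.run(_Params(), _batcher))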
|