|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
Image Annotation/Search for COCO with PyTorch
|
""" |
|
from __future__ import absolute_import, division, unicode_literals |
|
|
|
import logging |
|
import copy |
|
import numpy as np |
|
|
|
import torch |
|
from torch import nn |
|
from torch.autograd import Variable |
|
import torch.optim as optim |
|
|
|
|
|
class COCOProjNet(nn.Module):
    """Two-branch projection network for COCO image-sentence ranking.

    Projects image features and sentence features into a shared
    ``projdim``-dimensional space and L2-normalizes every projection, so
    cross-modal similarity is a plain dot product (cosine similarity).
    """

    def __init__(self, config):
        """config: dict with integer keys 'imgdim', 'sentdim', 'projdim'."""
        super(COCOProjNet, self).__init__()
        self.imgdim = config['imgdim']
        self.sentdim = config['sentdim']
        self.projdim = config['projdim']
        # One linear layer per modality; kept inside Sequential so the
        # parameter names (imgproj.0.weight, ...) match existing checkpoints.
        self.imgproj = nn.Sequential(
            nn.Linear(self.imgdim, self.projdim),
        )
        self.sentproj = nn.Sequential(
            nn.Linear(self.sentdim, self.projdim),
        )

    @staticmethod
    def _l2norm(x):
        """Scale each row of a 2-D tensor to unit Euclidean norm.

        Replaces six copy-pasted occurrences of the same expression; the
        original trailing ``.expand_as(x)`` is dropped because broadcasting
        of the (N, 1) denominator is equivalent.
        """
        return x / torch.sqrt(torch.pow(x, 2).sum(1, keepdim=True))

    def forward(self, img, sent, imgc, sentc):
        """Score anchor pairs against their contrastive samples.

        img:   (batch, imgdim) anchor image features
        sent:  (batch, sentdim) anchor sentence features
        imgc:  (batch, ncontrast, imgdim) contrastive image features
        sentc: (batch, ncontrast, sentdim) contrastive sentence features

        Returns four (batch * ncontrast,) similarity vectors:
        anchor1 and anchor2 (both img·sent, identical by construction —
        kept separate to preserve the original interface), img_sentc
        (anchor image vs contrastive sentences) and sent_imgc (anchor
        sentence vs contrastive images).
        """
        # Repeat each anchor once per contrastive sample, then flatten so
        # every (anchor, contrast) pair occupies its own row.
        img = img.unsqueeze(1).expand_as(imgc).contiguous()
        img = img.view(-1, self.imgdim)
        imgc = imgc.view(-1, self.imgdim)
        sent = sent.unsqueeze(1).expand_as(sentc).contiguous()
        sent = sent.view(-1, self.sentdim)
        sentc = sentc.view(-1, self.sentdim)

        imgproj = self._l2norm(self.imgproj(img))
        imgcproj = self._l2norm(self.imgproj(imgc))
        sentproj = self._l2norm(self.sentproj(sent))
        sentcproj = self._l2norm(self.sentproj(sentc))

        # Row-wise dot products between unit vectors => cosine similarities.
        anchor1 = torch.sum((imgproj * sentproj), 1)
        anchor2 = torch.sum((sentproj * imgproj), 1)
        img_sentc = torch.sum((imgproj * sentcproj), 1)
        sent_imgc = torch.sum((sentproj * imgcproj), 1)

        return anchor1, anchor2, img_sentc, sent_imgc

    def proj_sentence(self, sent):
        """Project (N, sentdim) sentence features to unit-norm (N, projdim)."""
        return self._l2norm(self.sentproj(sent))

    def proj_image(self, img):
        """Project (N, imgdim) image features to unit-norm (N, projdim)."""
        return self._l2norm(self.imgproj(img))
|
|
|
|
|
class PairwiseRankingLoss(nn.Module):
    """
    Pairwise ranking (hinge) loss for image-sentence retrieval.

    A contrastive pair is penalized whenever its similarity comes within
    `margin` of the matching anchor pair's similarity; the hinge sums for
    the two retrieval directions are added together.
    """
    def __init__(self, margin):
        super(PairwiseRankingLoss, self).__init__()
        self.margin = margin

    def forward(self, anchor1, anchor2, img_sentc, sent_imgc):
        # Per-pair hinge: max(0, margin - s(anchor) + s(contrast)),
        # once per retrieval direction.
        sent_hinge = (self.margin - anchor1 + img_sentc).clamp(min=0.0)
        img_hinge = (self.margin - anchor2 + sent_imgc).clamp(min=0.0)
        return sent_hinge.sum() + img_hinge.sum()
|
|
|
|
|
class ImageSentenceRankingPytorch(object):
    """Trainer/evaluator for COCO image-sentence ranking with COCOProjNet.

    train/valid/test are dicts with 'imgfeat' and 'sentfeat' feature
    matrices. The evaluation code assumes the standard COCO layout of
    5 captions per image, with rows 5k..5k+4 describing image k.
    Requires a CUDA device: the model, the loss and the dev/test tensors
    are all moved to the GPU.
    """

    def __init__(self, train, valid, test, config):
        """config: dict with 'seed', 'projdim' and 'margin'."""

        # Seed numpy and torch (CPU + CUDA) so batching order and
        # parameter init are reproducible.
        self.seed = config['seed']
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        torch.cuda.manual_seed(self.seed)

        self.train = train
        self.valid = valid
        self.test = test

        # Feature dimensionalities are inferred from the first training row.
        self.imgdim = len(train['imgfeat'][0])
        self.sentdim = len(train['sentfeat'][0])
        self.projdim = config['projdim']
        self.margin = config['margin']

        # Fixed training hyper-parameters.
        self.batch_size = 128
        self.ncontrast = 30   # contrastive samples drawn per anchor
        self.maxepoch = 20
        self.early_stop = True

        config_model = {'imgdim': self.imgdim,'sentdim': self.sentdim,
                        'projdim': self.projdim}
        self.model = COCOProjNet(config_model).cuda()

        self.loss_fn = PairwiseRankingLoss(margin=self.margin).cuda()

        self.optimizer = optim.Adam(self.model.parameters())

    def prepare_data(self, trainTxt, trainImg, devTxt, devImg,
                     testTxt, testImg):
        """Convert raw feature matrices to FloatTensors.

        Training tensors stay on the CPU (batches are moved to the GPU on
        the fly in trainepoch); dev/test tensors are moved to the GPU once
        since they are reused at every evaluation.
        """
        trainTxt = torch.FloatTensor(trainTxt)
        trainImg = torch.FloatTensor(trainImg)
        devTxt = torch.FloatTensor(devTxt).cuda()
        devImg = torch.FloatTensor(devImg).cuda()
        testTxt = torch.FloatTensor(testTxt).cuda()
        testImg = torch.FloatTensor(testImg).cuda()

        return trainTxt, trainImg, devTxt, devImg, testTxt, testImg

    def run(self):
        """Train with early stopping on dev recall, then score on test.

        Returns (bestdevscore, i2t r1/r5/r10/medr, t2i r1/r5/r10/medr),
        each recall averaged over 5 folds of 5000 rows.
        """
        self.nepoch = 0
        bestdevscore = -1
        early_stop_count = 0
        stop_train = False

        logging.info('prepare data')
        trainTxt, trainImg, devTxt, devImg, testTxt, testImg = \
            self.prepare_data(self.train['sentfeat'], self.train['imgfeat'],
                              self.valid['sentfeat'], self.valid['imgfeat'],
                              self.test['sentfeat'], self.test['imgfeat'])

        # One epoch per iteration; stop on early-stop or epoch budget.
        while not stop_train and self.nepoch <= self.maxepoch:
            logging.info('start epoch')
            self.trainepoch(trainTxt, trainImg, devTxt, devImg, nepoches=1)
            logging.info('Epoch {0} finished'.format(self.nepoch))

            results = {'i2t': {'r1': 0, 'r5': 0, 'r10': 0, 'medr': 0},
                       't2i': {'r1': 0, 'r5': 0, 'r10': 0, 'medr': 0},
                       'dev': bestdevscore}
            score = 0
            # Evaluate on 5 folds of 5000 rows and average the metrics.
            # NOTE(review): assumes the dev set has exactly 25000 rows
            # (5 x 5000) — confirm against the caller.
            for i in range(5):
                devTxt_i = devTxt[i*5000:(i+1)*5000]
                devImg_i = devImg[i*5000:(i+1)*5000]

                r1_i2t, r5_i2t, r10_i2t, medr_i2t = self.i2t(devImg_i,
                                                             devTxt_i)
                results['i2t']['r1'] += r1_i2t / 5
                results['i2t']['r5'] += r5_i2t / 5
                results['i2t']['r10'] += r10_i2t / 5
                results['i2t']['medr'] += medr_i2t / 5
                logging.info("Image to text: {0}, {1}, {2}, {3}"
                             .format(r1_i2t, r5_i2t, r10_i2t, medr_i2t))

                r1_t2i, r5_t2i, r10_t2i, medr_t2i = self.t2i(devImg_i,
                                                             devTxt_i)
                results['t2i']['r1'] += r1_t2i / 5
                results['t2i']['r5'] += r5_t2i / 5
                results['t2i']['r10'] += r10_t2i / 5
                results['t2i']['medr'] += medr_t2i / 5
                logging.info("Text to Image: {0}, {1}, {2}, {3}"
                             .format(r1_t2i, r5_t2i, r10_t2i, medr_t2i))
                # Model-selection score: sum of all six recalls, fold-averaged.
                score += (r1_i2t + r5_i2t + r10_i2t +
                          r1_t2i + r5_t2i + r10_t2i) / 5

            logging.info("Dev mean Text to Image: {0}, {1}, {2}, {3}".format(
                results['t2i']['r1'], results['t2i']['r5'],
                results['t2i']['r10'], results['t2i']['medr']))
            logging.info("Dev mean Image to text: {0}, {1}, {2}, {3}".format(
                results['i2t']['r1'], results['i2t']['r5'],
                results['i2t']['r10'], results['i2t']['medr']))

            # Keep a deep copy of the best model; stop after the counter
            # reaches 3 (i.e. several consecutive epochs without improvement).
            if score > bestdevscore:
                bestdevscore = score
                bestmodel = copy.deepcopy(self.model)
            elif self.early_stop:
                if early_stop_count >= 3:
                    stop_train = True
                early_stop_count += 1
        # Restore the best-on-dev weights before the final test evaluation.
        self.model = bestmodel

        # Final test-set evaluation, same 5-fold averaging as on dev.
        results = {'i2t': {'r1': 0, 'r5': 0, 'r10': 0, 'medr': 0},
                   't2i': {'r1': 0, 'r5': 0, 'r10': 0, 'medr': 0},
                   'dev': bestdevscore}
        for i in range(5):
            testTxt_i = testTxt[i*5000:(i+1)*5000]
            testImg_i = testImg[i*5000:(i+1)*5000]

            r1_i2t, r5_i2t, r10_i2t, medr_i2t = self.i2t(testImg_i, testTxt_i)
            results['i2t']['r1'] += r1_i2t / 5
            results['i2t']['r5'] += r5_i2t / 5
            results['i2t']['r10'] += r10_i2t / 5
            results['i2t']['medr'] += medr_i2t / 5

            r1_t2i, r5_t2i, r10_t2i, medr_t2i = self.t2i(testImg_i, testTxt_i)
            results['t2i']['r1'] += r1_t2i / 5
            results['t2i']['r5'] += r5_t2i / 5
            results['t2i']['r10'] += r10_t2i / 5
            results['t2i']['medr'] += medr_t2i / 5

        return bestdevscore, results['i2t']['r1'], results['i2t']['r5'], \
            results['i2t']['r10'], results['i2t']['medr'], \
            results['t2i']['r1'], results['t2i']['r5'], \
            results['t2i']['r10'], results['t2i']['medr']

    def trainepoch(self, trainTxt, trainImg, devTxt, devImg, nepoches=1):
        """Run `nepoches` passes over the training data.

        Each batch pairs every anchor with `ncontrast` contrastive images
        and sentences sampled (with replacement) from outside the batch.
        Logs intermediate dev metrics every batch_size*500 samples.
        Increments self.nepoch as a side effect.
        """
        self.model.train()
        for _ in range(self.nepoch, self.nepoch + nepoches):
            # Fresh shuffle of training indices each epoch.
            permutation = list(np.random.permutation(len(trainTxt)))
            all_costs = []
            for i in range(0, len(trainTxt), self.batch_size):

                # Periodic progress logging with a full dev evaluation.
                if i % (self.batch_size*500) == 0 and i > 0:
                    logging.info('samples : {0}'.format(i))
                    r1_i2t, r5_i2t, r10_i2t, medr_i2t = self.i2t(devImg,
                                                                 devTxt)
                    logging.info("Image to text: {0}, {1}, {2}, {3}".format(
                        r1_i2t, r5_i2t, r10_i2t, medr_i2t))

                    r1_t2i, r5_t2i, r10_t2i, medr_t2i = self.t2i(devImg,
                                                                 devTxt)
                    logging.info("Text to Image: {0}, {1}, {2}, {3}".format(
                        r1_t2i, r5_t2i, r10_t2i, medr_t2i))
                # Anchor batch: the next batch_size shuffled indices.
                idx = torch.LongTensor(permutation[i:i + self.batch_size])
                imgbatch = Variable(trainImg.index_select(0, idx)).cuda()
                sentbatch = Variable(trainTxt.index_select(0, idx)).cuda()

                # Contrastive indices are sampled only from OUTSIDE the
                # current batch, ncontrast per anchor.
                idximgc = np.random.choice(permutation[:i] +
                                           permutation[i + self.batch_size:],
                                           self.ncontrast*idx.size(0))
                idxsentc = np.random.choice(permutation[:i] +
                                            permutation[i + self.batch_size:],
                                            self.ncontrast*idx.size(0))
                idximgc = torch.LongTensor(idximgc)
                idxsentc = torch.LongTensor(idxsentc)

                # Reshape contrastive features to (batch, ncontrast, dim).
                imgcbatch = Variable(trainImg.index_select(0, idximgc)).view(
                    -1, self.ncontrast, self.imgdim).cuda()
                sentcbatch = Variable(trainTxt.index_select(0, idxsentc)).view(
                    -1, self.ncontrast, self.sentdim).cuda()

                anchor1, anchor2, img_sentc, sent_imgc = self.model(
                    imgbatch, sentbatch, imgcbatch, sentcbatch)

                loss = self.loss_fn(anchor1, anchor2, img_sentc, sent_imgc)
                all_costs.append(loss.data.item())

                self.optimizer.zero_grad()
                loss.backward()

                self.optimizer.step()
        self.nepoch += nepoches

    def t2i(self, images, captions):
        """
        Images: (5N, imgdim) matrix of images
        Captions: (5N, sentdim) matrix of captions

        Text-to-image retrieval. Returns (r1, r5, r10, medr): recall@k in
        percent and the (floored, 1-based) median rank of the correct
        image over all 5N caption queries.
        """
        with torch.no_grad():

            # Embed both modalities in batches, then concatenate.
            img_embed, sent_embed = [], []
            for i in range(0, len(images), self.batch_size):
                img_embed.append(self.model.proj_image(
                    Variable(images[i:i + self.batch_size])))
                sent_embed.append(self.model.proj_sentence(
                    Variable(captions[i:i + self.batch_size])))
            img_embed = torch.cat(img_embed, 0).data
            sent_embed = torch.cat(sent_embed, 0).data

            # Deduplicate images: every 5th row (5 captions per image).
            npts = int(img_embed.size(0) / 5)
            idxs = torch.cuda.LongTensor(range(0, len(img_embed), 5))
            ims = img_embed.index_select(0, idxs)

            ranks = np.zeros(5 * npts)
            for index in range(npts):

                # The 5 captions describing image `index`.
                queries = sent_embed[5*index: 5*index + 5]

                # Cosine scores of each query against all unique images
                # (embeddings are unit-norm, so mm == cosine).
                scores = torch.mm(queries, ims.transpose(0, 1)).cpu().numpy()
                inds = np.zeros(scores.shape)
                for i in range(len(inds)):
                    # Rank of the correct image in the descending score list.
                    inds[i] = np.argsort(scores[i])[::-1]
                    ranks[5 * index + i] = np.where(inds[i] == index)[0][0]

            # Recall@k = fraction of queries whose correct image ranks < k.
            r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
            r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
            r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
            medr = np.floor(np.median(ranks)) + 1
            return (r1, r5, r10, medr)

    def i2t(self, images, captions):
        """
        Images: (5N, imgdim) matrix of images
        Captions: (5N, sentdim) matrix of captions

        Image-to-text retrieval. Returns (r1, r5, r10, medr): recall@k in
        percent and the (floored, 1-based) median rank, where each image's
        rank is the best rank among its 5 ground-truth captions.
        """
        with torch.no_grad():

            # Embed both modalities in batches, then concatenate.
            img_embed, sent_embed = [], []
            for i in range(0, len(images), self.batch_size):
                img_embed.append(self.model.proj_image(
                    Variable(images[i:i + self.batch_size])))
                sent_embed.append(self.model.proj_sentence(
                    Variable(captions[i:i + self.batch_size])))
            img_embed = torch.cat(img_embed, 0).data
            sent_embed = torch.cat(sent_embed, 0).data

            npts = int(img_embed.size(0) / 5)
            index_list = []  # top-1 caption index per image (collected, unused)

            ranks = np.zeros(npts)
            for index in range(npts):

                # One query image (rows 5k..5k+4 are duplicates of image k).
                query_img = img_embed[5 * index]

                # Scores against all captions, sorted descending.
                scores = torch.mm(query_img.view(1, -1),
                                  sent_embed.transpose(0, 1)).view(-1)
                scores = scores.cpu().numpy()
                inds = np.argsort(scores)[::-1]
                index_list.append(inds[0])

                # Best (smallest) rank among the image's 5 true captions.
                rank = 1e20
                for i in range(5*index, 5*index + 5, 1):
                    tmp = np.where(inds == i)[0][0]
                    if tmp < rank:
                        rank = tmp
                ranks[index] = rank

            # Recall@k = fraction of images with best caption rank < k.
            r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
            r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
            r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
            medr = np.floor(np.median(ranks)) + 1
            return (r1, r5, r10, medr)
|
|