"""Similarity-based alignment losses for paired feature sets (e.g. image/text).

Mask convention used by the loss functions: an additive mask whose entries are
``0`` for valid items and ``-inf`` for padded items.
"""
import torch
import torch.nn.functional as F


def _best_match_sim(fea1, fea2, mask2):
    """For every item of ``fea1`` return its highest similarity to a valid item of ``fea2``.

    Args:
        fea1: (..., max_len1, dim) features, already L2-normalised by the caller.
        fea2: (..., max_len2, dim) features, broadcastable against ``fea1`` on
            the leading dimensions.
        mask2: (bs, max_len2) additive mask for ``fea2`` (0 valid, -inf padded).

    Returns:
        Tensor of shape (..., max_len1).
    """
    # Dot products of every item pair without materialising the
    # (..., max_len1, max_len2, dim) element-wise product.
    token_sim = torch.matmul(fea1, fea2.transpose(-1, -2))  # (..., len1, len2)
    # (bs, 1, len2) lines up with the sample axis of fea2 in both the 3-D
    # (bs, len1, len2) and the 4-D (bs, bs, len1, len2) case.
    token_sim = token_sim + mask2.unsqueeze(dim=1)
    best_sim, _ = torch.max(token_sim, dim=-1)
    return best_sim


def _masked_item_mean(item_sim, mask1):
    """Average ``item_sim`` over the items that ``mask1`` marks as valid (== 0).

    ``mask1`` must broadcast against ``item_sim``. A sample without any valid
    item yields NaN (division by zero), exactly as an unguarded mean would.
    """
    weights = (mask1 == 0).to(item_sim.dtype)
    return (item_sim * weights).sum(dim=-1) / weights.sum(dim=-1)


def _hardest_negative_hinge(sim_pos, sim_all, margin):
    """Mean hinge loss of each positive against its hardest in-batch negative.

    Args:
        sim_pos: (bs) similarity of the matching pairs.
        sim_all: (bs, bs) similarity of all pairs; the diagonal holds the
            matching pairs and is excluded from the negatives.
        margin: Required gap between positive and hardest negative.
    """
    diagonal = torch.eye(sim_all.size(0), device=sim_all.device).bool()
    sim_neg, _ = torch.max(sim_all.masked_fill(diagonal, float('-inf')), dim=1)
    loss = -sim_pos + sim_neg + margin
    return torch.maximum(loss, torch.zeros_like(loss)).mean()


def pairLoss(fea1, fea2, mask):
    """Return ``1 - mean cosine similarity`` of position-aligned valid items.

    Args:
        fea1: (bs, max_len, dim) features.
        fea2: (bs, max_len, dim) features.
        mask: (bs, max_len) additive mask (0 valid, -inf padded).

    Returns:
        Scalar tensor. NaN when no item is valid (mean of an empty selection).
    """
    fea1 = F.normalize(fea1, p=2, dim=-1)
    fea2 = F.normalize(fea2, p=2, dim=-1)
    fea_sim = (fea1 * fea2).sum(dim=-1)  # (bs, max_len)
    valid_sim = torch.masked_select(fea_sim, mask == 0)
    return 1.0 - torch.mean(valid_sim)


def SimpTripLoss(fea1, fea2, margin=0.2):
    """Bidirectional triplet loss with hardest in-batch negatives.

    Both inputs must be pooled per-sample features of shape (bs, dim); any
    pooling over a sequence (e.g. a masked mean) has to be done by the caller.

    Args:
        fea1: (bs, dim) features of the first modality.
        fea2: (bs, dim) features of the second modality; row ``i`` is the
            positive for row ``i`` of ``fea1``.
        margin: Hinge margin (default 0.2).

    Returns:
        Scalar tensor: sum of the fea1->fea2 and fea2->fea1 losses.
    """
    fea1 = F.normalize(fea1, p=2, dim=-1)
    fea2 = F.normalize(fea2, p=2, dim=-1)

    sim_pos = (fea1 * fea2).sum(dim=1)  # (bs)
    sim_all = torch.matmul(fea1, fea2.t())  # (bs, bs), [i, j] = fea1[i] . fea2[j]

    # match fea1 to fea2
    loss1 = _hardest_negative_hinge(sim_pos, sim_all, margin)
    # match fea2 to fea1: the same similarities, transposed
    loss2 = _hardest_negative_hinge(sim_pos, sim_all.t(), margin)
    return loss1 + loss2


def NCELoss(fea1, fea2, temperature=0.07):
    """Bidirectional InfoNCE loss over in-batch negatives.

    Both inputs must be pooled per-sample features of shape (bs, dim); any
    pooling over a sequence has to be done by the caller.

    Args:
        fea1: (bs, dim) features of the first modality.
        fea2: (bs, dim) features of the second modality; row ``i`` is the
            positive for row ``i`` of ``fea1``.
        temperature: Softmax temperature dividing the cosine similarities
            (default 0.07).

    Returns:
        Scalar tensor: average of the fea1->fea2 and fea2->fea1 losses.
    """
    fea1 = F.normalize(fea1, p=2, dim=-1)
    fea2 = F.normalize(fea2, p=2, dim=-1)

    logits = torch.matmul(fea1, fea2.t()) / temperature  # (bs, bs)
    # The positive of row i sits on the diagonal, so the negative
    # log-softmax of the positive is a cross-entropy with target i.
    targets = torch.arange(logits.size(0), device=logits.device)

    loss1 = F.cross_entropy(logits, targets)  # match fea1 to fea2
    loss2 = F.cross_entropy(logits.t(), targets)  # match fea2 to fea1
    return (loss1 + loss2) / 2.0


def AlignTripLoss(fea1, fea2, mask1, mask2, margin=0.2):
    """Bidirectional triplet loss on item-aligned set similarities.

    The similarity of two samples is the mean, over the valid items of one
    sample, of the best cosine match among the valid items of the other
    (see ``cal_sim`` / ``cal_sim_all``).

    Args:
        fea1: (bs, max_len1, dim) features.
        fea2: (bs, max_len2, dim) features.
        mask1: (bs, max_len1) additive mask (0 valid, -inf padded).
        mask2: (bs, max_len2) additive mask (0 valid, -inf padded).
        margin: Hinge margin (default 0.2).

    Returns:
        Scalar tensor: sum of the fea1->fea2 and fea2->fea1 losses.
    """
    fea1 = F.normalize(fea1, p=2, dim=-1)
    fea2 = F.normalize(fea2, p=2, dim=-1)

    # match fea1 to fea2
    sim_pos1 = cal_sim(fea1, fea2, mask1, mask2)  # (bs)
    sim_all1 = cal_sim_all(fea1.unsqueeze(1), fea2.unsqueeze(0), mask1, mask2)  # (bs, bs)
    loss1 = _hardest_negative_hinge(sim_pos1, sim_all1, margin)

    # match fea2 to fea1
    sim_pos2 = cal_sim(fea2, fea1, mask2, mask1)  # (bs)
    sim_all2 = cal_sim_all(fea2.unsqueeze(1), fea1.unsqueeze(0), mask2, mask1)  # (bs, bs)
    loss2 = _hardest_negative_hinge(sim_pos2, sim_all2, margin)

    return loss1 + loss2


def cal_sim_all(fea1, fea2, mask1, mask2):
    """Set similarity of every ``fea1`` sample to every ``fea2`` sample.

    Args:
        fea1: (bs, 1, max_len1, dim) normalised features.
        fea2: (1, bs, max_len2, dim) normalised features.
        mask1: (bs, max_len1) additive mask (0 valid, -inf padded).
        mask2: (bs, max_len2) additive mask (0 valid, -inf padded).

    Returns:
        (bs, bs) tensor; entry ``[i, j]`` is the mean over the valid items of
        ``fea1[i]`` of their best match among the valid items of ``fea2[j]``.
    """
    best_sim = _best_match_sim(fea1, fea2, mask2)  # (bs, bs, max_len1)
    # (bs, 1, max_len1): the mask follows the fea1 sample axis.
    return _masked_item_mean(best_sim, mask1.unsqueeze(1))  # (bs, bs)


def cal_sim(fea1, fea2, mask1, mask2):
    """Set similarity of each ``fea1`` sample to its paired ``fea2`` sample.

    Args:
        fea1: (bs, max_len1, dim) normalised features.
        fea2: (bs, max_len2, dim) normalised features.
        mask1: (bs, max_len1) additive mask (0 valid, -inf padded).
        mask2: (bs, max_len2) additive mask (0 valid, -inf padded).

    Returns:
        (bs) tensor: mean over the valid items of ``fea1[i]`` of their best
        match among the valid items of ``fea2[i]``.
    """
    best_sim = _best_match_sim(fea1, fea2, mask2)  # (bs, max_len1)
    return _masked_item_mean(best_sim, mask1)  # (bs)


def alignmentLoss(fea1, fea2, mask1, mask2):
    """Symmetric best-match alignment loss.

    Args:
        fea1: (bs, max_len1, dim) features.
        fea2: (bs, max_len2, dim) features.
        mask1: (bs, max_len1) additive mask (0 valid, -inf padded).
        mask2: (bs, max_len2) additive mask (0 valid, -inf padded).

    Returns:
        Scalar tensor: average of ``alignSingleLoss`` in both directions.
    """
    fea1 = F.normalize(fea1, p=2, dim=-1)
    fea2 = F.normalize(fea2, p=2, dim=-1)

    loss1 = alignSingleLoss(fea1, fea2, mask1, mask2)
    loss2 = alignSingleLoss(fea2, fea1, mask2, mask1)
    return (loss1 + loss2) / 2.0


def attAlignmentLoss(fea1, fea2, mask1, mask2, attFc):
    """Attention-weighted alignment loss over all item pairs.

    Args:
        fea1: (bs, max_len1, dim) features.
        fea2: (bs, max_len2, dim) features.
        mask1: (bs, max_len1) mask; accepted for signature symmetry but NOT
            used, so every ``fea1`` item is treated as valid.
            NOTE(review): presumably fea1 always has a fixed number of valid
            items (the original comment mentions max_len1=49) -- confirm.
        mask2: (bs, max_len2) additive mask (0 valid, -inf padded).
        attFc: Callable mapping the (bs, max_len1, max_len2, dim) element-wise
            product to a (bs, max_len1, max_len2, 1) attention score.

    Returns:
        Scalar tensor: ``1 - mean(similarity * attention)`` over the pairs
        whose ``fea2`` item is valid.
    """
    fea1 = F.normalize(fea1, p=2, dim=-1)
    fea2 = F.normalize(fea2, p=2, dim=-1)

    # The element-wise product is kept because attFc consumes it per pair.
    pair_prod = fea1.unsqueeze(2) * fea2.unsqueeze(1)  # (bs, len1, len2, dim)
    att_sim = attFc(pair_prod).squeeze(-1)  # (bs, len1, len2)
    fea_sim = pair_prod.sum(dim=-1) * att_sim  # (bs, len1, len2)

    valid_sim = torch.masked_select(fea_sim, (mask2 == 0).unsqueeze(1))
    return 1.0 - valid_sim.mean()


def alignSingleLoss(fea1, fea2, mask1, mask2):
    """One-directional best-match alignment loss (fea1 -> fea2).

    Args:
        fea1: (bs, max_len1, dim) normalised features.
        fea2: (bs, max_len2, dim) normalised features.
        mask1: (bs, max_len1) additive mask (0 valid, -inf padded).
        mask2: (bs, max_len2) additive mask (0 valid, -inf padded).

    Returns:
        Scalar tensor: ``1 - mean`` of the best-match similarity, taken over
        all valid ``fea1`` items in the batch.
    """
    best_sim = _best_match_sim(fea1, fea2, mask2)  # (bs, max_len1)
    valid_sim = torch.masked_select(best_sim, mask1 == 0)
    return 1.0 - torch.mean(valid_sim)


def getLanMask(seq_lens, max_len):
    """Build a 1/0 validity mask from sequence lengths.

    Note that this returns ``1`` for valid and ``0`` for padded positions,
    which is NOT the additive ``0`` / ``-inf`` convention the loss functions
    in this module expect.

    Args:
        seq_lens: (bs) tensor of sequence lengths.
        max_len: Number of positions per row of the mask.

    Returns:
        (bs, max_len) float tensor on the same device as ``seq_lens``.
    """
    # Create the positions on seq_lens' device so CUDA inputs do not hit a
    # device mismatch.
    positions = torch.arange(max_len, device=seq_lens.device).unsqueeze(dim=0)  # (1, max_len)
    valid = positions < seq_lens.unsqueeze(-1)  # (bs, max_len)
    return valid.to(torch.get_default_dtype())