import torch
import torch.nn as nn
from .fakeTransformer import FakeTransformer
from .bert import Bert
from ..utils import pairLoss, alignmentLoss, attAlignmentLoss, AlignTripLoss, SimpTripLoss, NCELoss
import torch.nn.functional as F
import timm
import numpy as np
class ImgLearnableEncoder(nn.Module):
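    """Image encoder: a timm CNN backbone with RoI grid pooling, followed by a
    Transformer over region features and a projection head (FakeTransformer)."""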
def __init__(self, model_cfg):
super(ImgLearnableEncoder, self).__init__()
self.backbone = timm.create_model(model_cfg.CNN, pretrained=True)
self.model_cfg = model_cfg
self.learnable = nn.ModuleDict()
self.learnable['imgFC'] = FakeTransformer(model_cfg.IMG_FEATURE_DIM, model_cfg.IMG_FEATURE_DIM, model_cfg.IMG_FEATURE_DIM)
img_encoder_layer = nn.TransformerEncoderLayer(d_model=model_cfg.IMG_FEATURE_DIM, nhead=model_cfg.IMG_TRANSFORMER_HEAD)
self.learnable['imgAtt'] = nn.TransformerEncoder(img_encoder_layer, num_layers=model_cfg.IMG_TRANSFORMER_LAYER)
self.learnable['max_pool'] = nn.Sequential(
nn.Conv2d(model_cfg.IMG_FEATURE_DIM, model_cfg.IMG_FEATURE_DIM, kernel_size=1),
nn.AvgPool2d(model_cfg.GRID_SIZE, stride=1)
)
self.init_param()
    def init_param(self):
        # Fine-tune only backbone stages 3-6; freeze everything else.
        for name, param in self.backbone.named_parameters():
            param.requires_grad = any(f'blocks.{i}' in name for i in (3, 4, 5, 6))
def roi_grid_pool(self, spatial_features_2d, rois):
"""
Args:
rois: (B, num_rois, 4)
spatial_features_2d: (B, C, H, W)
Returns:
pooled_features : (B, num_rois, C)
"""
batch_size = spatial_features_2d.size(0)
rois = rois.detach()
        height, width = spatial_features_2d.size(2), spatial_features_2d.size(3)  # feature-map height and width
        down_sample_ratio = self.model_cfg.IMG_SIZE / height
pooled_features_list = []
        # presumably disabled here to work around cuDNN issues with affine_grid/grid_sample
        torch.backends.cudnn.enabled = False
for b_id in range(batch_size):
            # TODO: a coordinate-system conversion is needed here
# Map global boxes coordinates to feature map coordinates
x1 = rois[b_id, :, 0] / down_sample_ratio
y1 = rois[b_id, :, 1] / down_sample_ratio
x2 = rois[b_id, :, 2] / down_sample_ratio
y2 = rois[b_id, :, 3] / down_sample_ratio
            angle = torch.zeros(1, device=spatial_features_2d.device)  # boxes are axis-aligned, so rotation is zero
cosa = torch.cos(angle)
sina = torch.sin(angle)
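            # Build one 2x3 affine matrix per RoI for F.affine_grid: it maps the
            # normalized [-1, 1] output grid onto the RoI region of the feature map
            # (scale = box extent / feature extent, translation = box-centre offset);
            # with angle = 0 the cos/sin terms reduce to an axis-aligned crop.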
theta = torch.stack((
(x2 - x1) / (width - 1) * cosa, (x2 - x1) / (width - 1) * (-sina), (x1 + x2 - width + 1) / (width - 1),
(y2 - y1) / (height - 1) * sina, (y2 - y1) / (height - 1) * cosa, (y1 + y2 - height + 1) / (height - 1)
), dim=1).view(-1, 2, 3).float()
grid_size = self.model_cfg.GRID_SIZE
grid = nn.functional.affine_grid(
theta,
torch.Size((rois.size(1), spatial_features_2d.size(1), grid_size, grid_size))
)
pooled_features = nn.functional.grid_sample(
spatial_features_2d[b_id].unsqueeze(0).expand(rois.size(1), spatial_features_2d.size(1), height, width),
grid
)
            # despite its key name, 'max_pool' applies a 1x1 conv followed by average pooling
            pooled_features = self.learnable['max_pool'](pooled_features)
            pooled_features_list.append(pooled_features.squeeze(-1).squeeze(-1))  # (num_rois, C)
torch.backends.cudnn.enabled = True
pooled_features = torch.stack(pooled_features_list, dim=0)
return pooled_features
def forward(self, imgFea, maskImages, image_boxs):
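        """
        Args:
            imgFea: batch of input images for the CNN backbone
            maskImages: (B, num_rois) 0/1 validity mask over region slots
            image_boxs: (B, num_rois, 4) region boxes in pixel coordinates
        Returns:
            (B, IMG_FEATURE_DIM) image embedding
        """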
feature_map = self.backbone.forward_features(imgFea)
imgFea = self.roi_grid_pool(feature_map, image_boxs)
imgFea = F.normalize(imgFea, p=2, dim=-1)
        imgFea = self.learnable['imgAtt'](imgFea.transpose(0, 1), src_key_padding_mask=(maskImages == 0)).transpose(0, 1)
        # masked mean pooling over region slots
        tmpMask = (maskImages == 1).float()
        imgFea = (imgFea * tmpMask.unsqueeze(-1)).sum(dim=1) / tmpMask.sum(dim=1).unsqueeze(-1)  # (bs, dim)
imgFea = self.learnable['imgFC'](imgFea)
return imgFea
class TextLearnableEncoder(nn.Module):
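    """Text encoder: a BERT backbone followed by a Transformer over token
    features and a projection (FakeTransformer) into the shared embedding space."""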
def __init__(self, model_cfg):
super(TextLearnableEncoder, self).__init__()
self.backbone = Bert(model_cfg)
self.model_cfg = model_cfg
self.learnable = nn.ModuleDict()
self.learnable['textFC'] = FakeTransformer(model_cfg.TEXT_FEATURE_DIM, model_cfg.IMG_FEATURE_DIM, model_cfg.IMG_FEATURE_DIM)
text_encoder_layer = nn.TransformerEncoderLayer(d_model=model_cfg.TEXT_FEATURE_DIM, nhead=model_cfg.TEXT_TRANSFORMER_HEAD)
self.learnable['textAtt'] = nn.TransformerEncoder(text_encoder_layer, num_layers=model_cfg.TEXT_TRANSFORMER_LAYER)
self.init_param()
    def init_param(self):
        # Fine-tune only the top four BERT layers; freeze everything else.
        if 'large' not in self.model_cfg.ENCODER:
            trainable_layers = ('layer.8', 'layer.9', 'layer.10', 'layer.11')
        else:
            trainable_layers = ('layer.20', 'layer.21', 'layer.22', 'layer.23')
        for name, param in self.backbone.named_parameters():
            param.requires_grad = any(layer in name for layer in trainable_layers)
def forward(self, textFea, maskTexts):
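        """
        Args:
            textFea: token inputs consumed by the BERT backbone
            maskTexts: (B, seq_len) 0/1 attention mask
        Returns:
            (B, IMG_FEATURE_DIM) text embedding
        """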
textFea = self.backbone(textFea)
textFea = F.normalize(textFea, p=2, dim=-1)
        textFea = self.learnable['textAtt'](textFea.transpose(0, 1), src_key_padding_mask=(maskTexts == 0)).transpose(0, 1)
        # masked mean pooling over tokens
        tmpMask = (maskTexts == 1).float()
        textFea = (textFea * tmpMask.unsqueeze(-1)).sum(dim=1) / tmpMask.sum(dim=1).unsqueeze(-1)  # (bs, dim)
        textFea = self.learnable['textFC'](textFea)
        return textFea
class VL_model(nn.Module):
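    """MoCo-style cross-modal contrastive model: a query encoder and a momentum
    (key) encoder per modality, plus a fixed-size feature queue per modality
    that provides negatives for the bidirectional InfoNCE losses."""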
def __init__(self, model_cfg):
super(VL_model, self).__init__()
self.model_cfg = model_cfg
self.learnable = nn.ModuleDict()
self.learnable['imgencoder'] = ImgLearnableEncoder(model_cfg)
self.learnable['imgencoder_mom'] = ImgLearnableEncoder(model_cfg)
self.learnable['textencoder'] = TextLearnableEncoder(model_cfg)
self.learnable['textencoder_mom'] = TextLearnableEncoder(model_cfg)
        # self.generator = Generator(model_cfg)  # caption generator, currently disabled (see forward)
        # MoCo hyper-parameters, set in the .yml config file
        self.K = model_cfg.QUEUE_SIZE    # queue size, e.g. 6400
        self.m = model_cfg.MOMENTUM      # key-encoder momentum, e.g. 0.9
        self.T = model_cfg.TEMPERATURE   # softmax temperature, e.g. 0.07
        self.topk = model_cfg.TOPK       # top-k neighbours for multi-label positives, e.g. 5
        self.multi_label = False
        # initialize the momentum (key) encoders from the query encoders
        self.init_param()
        # create the image queue
        self.register_buffer("img_queue", torch.randn(model_cfg.IMG_FEATURE_DIM, self.K))
        self.img_queue = nn.functional.normalize(self.img_queue, dim=0)
        self.register_buffer("img_queue_ptr", torch.zeros(1, dtype=torch.long))  # image queue pointer
        # create the text queue
        self.register_buffer("text_queue", torch.randn(model_cfg.IMG_FEATURE_DIM, self.K))
        self.text_queue = nn.functional.normalize(self.text_queue, dim=0)
        self.register_buffer("text_queue_ptr", torch.zeros(1, dtype=torch.long))  # text queue pointer
def init_param(self):
for param_q, param_k in zip(self.learnable['imgencoder'].parameters(), self.learnable['imgencoder_mom'].parameters()):
param_k.data.copy_(param_q.data) # initialize
param_k.requires_grad = False # not update by gradient
for param_q, param_k in zip(self.learnable['textencoder'].parameters(), self.learnable['textencoder_mom'].parameters()):
param_k.data.copy_(param_q.data) # initialize
param_k.requires_grad = False # not update by gradient
@torch.no_grad()
def _momentum_update_key_encoder(self):
"""
Momentum update of the key encoder for image modal
"""
for param_q, param_k in zip(self.learnable['imgencoder'].parameters(), self.learnable['imgencoder_mom'].parameters()):
param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
for param_q, param_k in zip(self.learnable['textencoder'].parameters(), self.learnable['textencoder_mom'].parameters()):
param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
@torch.no_grad()
def _dequeue_and_enqueue(self, keys, option='img'):
        # option in {'img', 'text'} selects which queue to update
# gather keys before updating queue
keys = concat_all_gather(keys)
batch_size = keys.shape[0]
if option == 'img':
ptr = int(self.img_queue_ptr)
assert self.K % batch_size == 0 # for simplicity
# replace the keys at ptr (dequeue and enqueue)
self.img_queue[:, ptr:ptr + batch_size] = keys.T
ptr = (ptr + batch_size) % self.K # move pointer
self.img_queue_ptr[0] = ptr
else:
ptr = int(self.text_queue_ptr)
assert self.K % batch_size == 0 # for simplicity
# replace the keys at ptr (dequeue and enqueue)
self.text_queue[:, ptr:ptr + batch_size] = keys.T
ptr = (ptr + batch_size) % self.K # move pointer
self.text_queue_ptr[0] = ptr
@torch.no_grad()
def _batch_shuffle_ddp(self, x, x_mask):
"""
Batch shuffle, for making use of BatchNorm.
*** Only support DistributedDataParallel (DDP) model. ***
"""
# gather from all gpus
batch_size_this = x.shape[0]
x_gather = concat_all_gather(x)
x_mask_gather = concat_all_gather(x_mask)
batch_size_all = x_gather.shape[0]
num_gpus = batch_size_all // batch_size_this
# random shuffle index
idx_shuffle = torch.randperm(batch_size_all).cuda()
# broadcast to all gpus
torch.distributed.broadcast(idx_shuffle, src=0)
# index for restoring
idx_unshuffle = torch.argsort(idx_shuffle)
# shuffled index for this gpu
gpu_idx = torch.distributed.get_rank()
idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]
return x_gather[idx_this], x_mask_gather[idx_this], idx_unshuffle
@torch.no_grad()
def _batch_unshuffle_ddp(self, x, x_mask, idx_unshuffle):
"""
Undo batch shuffle.
*** Only support DistributedDataParallel (DDP) model. ***
"""
# gather from all gpus
batch_size_this = x.shape[0]
x_gather = concat_all_gather(x)
x_mask_gather = concat_all_gather(x_mask)
batch_size_all = x_gather.shape[0]
num_gpus = batch_size_all // batch_size_this
# restored index for this gpu
gpu_idx = torch.distributed.get_rank()
idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]
return x_gather[idx_this], x_mask_gather[idx_this]
def forward(self, imgFea, texts, maskImages, maskTexts, text_lens, image_boxs, is_training=True):
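        """
        Training forward pass: computes the bidirectional (image->text and
        text->image) InfoNCE losses. Queries come from the online encoders,
        keys from the momentum encoders; keys are then enqueued as negatives
        for later batches. With IS_EXTRACT set, returns embeddings instead.
        """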
if self.model_cfg.IS_EXTRACT:
return self.extract(imgFea, texts, maskImages, maskTexts, image_boxs)
batch_size = imgFea.size(0)
imgFea_q = self.learnable['imgencoder'](imgFea, maskImages, image_boxs) # <bsz, img_dim>
imgFea_q = F.normalize(imgFea_q, p=2, dim=-1)
textFea_q = self.learnable['textencoder'](texts, maskTexts) # <bsz, img_dim>
textFea_q = F.normalize(textFea_q, p=2, dim=-1)
# compute key features
with torch.no_grad(): # no gradient to keys
self._momentum_update_key_encoder() # update the key encoder
# shuffle for making use of BN
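            # NOTE: maskImages is not shuffled along with imgFea and image_boxs,
            # which is only safe if the image mask is identical across samples.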
imgFea, image_boxs, idx_unshuffle = self._batch_shuffle_ddp(imgFea, image_boxs)
imgFea_k = self.learnable['imgencoder_mom'](imgFea, maskImages, image_boxs) # <bsz, img_dim>
imgFea_k = F.normalize(imgFea_k, p=2, dim=-1)
# undo shuffle
imgFea_k, image_boxs = self._batch_unshuffle_ddp(imgFea_k, image_boxs, idx_unshuffle)
# shuffle for making use of BN
texts, maskTexts, idx_unshuffle = self._batch_shuffle_ddp(texts, maskTexts)
textFea_k = self.learnable['textencoder_mom'](texts, maskTexts) # <bsz, img_dim>
textFea_k = F.normalize(textFea_k, p=2, dim=-1)
# undo shuffle
textFea_k, maskTexts = self._batch_unshuffle_ddp(textFea_k, maskTexts, idx_unshuffle)
# compute logits for image -> text
# positive logits: Nx1
i2t_l_pos = torch.einsum('nc,nc->n', [imgFea_q, textFea_k]).unsqueeze(-1)
# negative logits: NxK
i2t_l_neg = torch.einsum('nc,ck->nk', [imgFea_q, self.text_queue.clone().detach()])
# logits: Nx(1+K)
i2t_logits = torch.cat([i2t_l_pos, i2t_l_neg], dim=-1)
i2t_logits /= self.T
# compute logits for text -> image
# positive logits: Nx1
t2i_l_pos = torch.einsum('nc,nc->n', [textFea_q, imgFea_k]).unsqueeze(-1)
# negative logits: NxK
t2i_l_neg = torch.einsum('nc,ck->nk', [textFea_q, self.img_queue.clone().detach()])
# logits: Nx(1+K)
t2i_logits = torch.cat([t2i_l_pos, t2i_l_neg], dim=-1)
t2i_logits /= self.T
### multi-label
        mask = torch.zeros(batch_size, self.K, dtype=torch.bool, device=imgFea_q.device)  # <B, K>
if self.multi_label:
mask_sim_txt = textFea_k.matmul(self.text_queue.clone().detach()) # <B, dim> <dim, K> -> <B, K>
mask_sim_img = imgFea_k.matmul(self.img_queue.clone().detach())
_, topkidx_txt = torch.topk(mask_sim_txt, self.topk, dim=1) # <B, topk>
_, topkidx_img = torch.topk(mask_sim_img, self.topk, dim=1) # <B, topk>
topk_onehot_txt = torch.zeros_like(mask_sim_txt) # <B, K>
topk_onehot_txt.scatter_(1, topkidx_txt, 1) # one hot vector
topk_onehot_img = torch.zeros_like(mask_sim_img) # <B, K>
topk_onehot_img.scatter_(1, topkidx_img, 1) # one hot vector
mask[topk_onehot_txt.bool() & topk_onehot_img.bool()] = True # <B, K>
mask = torch.cat([torch.ones((batch_size, 1), dtype=torch.long, device=mask.device).bool(),
mask], dim=1) # <B, K+1>
### multi-label
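        # Column 0 of `mask` marks the positive key for each query; with
        # multi_label enabled, queue entries that are top-k neighbours in both
        # modalities also count as positives.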
        t2i_loss = -1 * F.log_softmax(t2i_logits, dim=1)  # <B, 1+K>
        t2i_loss = torch.masked_select(t2i_loss, mask).sum() / batch_size  # masked_select returns a 1-D tensor
        i2t_loss = -1 * F.log_softmax(i2t_logits, dim=1)
        i2t_loss = torch.masked_select(i2t_loss, mask).sum() / batch_size
loss = t2i_loss + i2t_loss
## enqueue and dequeue
self._dequeue_and_enqueue(imgFea_k, option='img')
self._dequeue_and_enqueue(textFea_k, option='text')
        # ----------caption-------------
        # TODO: re-enable the caption generator
        '''
        if is_training:
            caption = None
            caption_loss = self.generator(imgFea_q, texts, text_lens, maskTexts, is_training)
        else:
            caption_loss, caption = self.generator(imgFea_q, texts, text_lens, maskTexts, is_training)
        '''
        return loss  # caption_loss and caption are omitted while the generator is disabled
def extract(self, imgFea, texts, maskImages, maskTexts, image_boxs):
imgFea = self.learnable['imgencoder'](imgFea, maskImages, image_boxs) # <bsz, img_dim>
textFea = self.learnable['textencoder'](texts, maskTexts) # <bsz, img_dim>
imgFea = F.normalize(imgFea, p=2, dim=-1)
textFea = F.normalize(textFea, p=2, dim=-1)
retrieval_feat_group = {}
retrieval_feat_group['img_text'] = (imgFea, textFea)
return retrieval_feat_group
# utils
@torch.no_grad()
def concat_all_gather(tensor):
"""
Performs all_gather operation on the provided tensors.
*** Warning ***: torch.distributed.all_gather has no gradient.
"""
tensors_gather = [torch.ones_like(tensor)
for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
output = torch.cat(tensors_gather, dim=0)
return output
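# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The config field values below are
# assumptions; a real run needs the project's .yml-driven config object, the
# relative imports above (FakeTransformer, Bert, utils), and, for training, an
# initialized distributed process group, since forward() reaches
# concat_all_gather through the shuffle/queue logic.
#
#   from types import SimpleNamespace
#   cfg = SimpleNamespace(
#       CNN='efficientnet_b3', ENCODER='bert-base', IS_EXTRACT=True,
#       IMG_FEATURE_DIM=768, TEXT_FEATURE_DIM=768,
#       IMG_TRANSFORMER_HEAD=8, IMG_TRANSFORMER_LAYER=2,
#       TEXT_TRANSFORMER_HEAD=8, TEXT_TRANSFORMER_LAYER=2,
#       GRID_SIZE=7, IMG_SIZE=224,
#       QUEUE_SIZE=6400, MOMENTUM=0.9, TEMPERATURE=0.07, TOPK=5,
#   )
#   model = VL_model(cfg).eval()
#   out = model(imgs, texts, mask_imgs, mask_texts, text_lens, boxes)
#   img_emb, text_emb = out['img_text']   # IS_EXTRACT=True -> embedding pair
# ---------------------------------------------------------------------------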