import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.model_utils import RNNEncoder
from easydict import EasyDict as edict
					
						
cal_base_cfg = edict(
    visual_input_size=2048,
    textual_input_size=768,
    query_feat_size=768,
    visual_hidden_size=500,
    output_size=100,
    embedding_size=768,
    lstm_hidden_size=1000,
    margin=0.1,
    loss_type="hinge",
    inter_loss_weight=0.4,
    ctx_mode="video"
)
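
# Config notes (added commentary): `cal_base_cfg` collects the default
# hyperparameters used by the models below. `ctx_mode` is parsed by substring
# checks in CALWithSub, so e.g. "video_sub" (an illustrative value, not the
# default above) enables both the video and subtitle moment encoders. Since
# edict subclasses dict, a per-experiment override can be sketched as:
#   cfg = edict(dict(cal_base_cfg, ctx_mode="video_sub", loss_type="lse"))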
					
						
class CAL(nn.Module):
    def __init__(self, config):
        super(CAL, self).__init__()
        self.config = config

        self.moment_mlp = nn.Sequential(
            nn.Linear(config.visual_input_size, config.visual_hidden_size),
            nn.ReLU(True),
            nn.Linear(config.visual_hidden_size, config.output_size),
        )

        self.query_lstm = RNNEncoder(word_embedding_size=config.embedding_size,
                                     hidden_size=config.lstm_hidden_size,
                                     bidirectional=False,
                                     rnn_type="lstm",
                                     dropout_p=0,
                                     n_layers=1,
                                     return_outputs=False)

        self.query_linear = nn.Linear(config.lstm_hidden_size, config.output_size)

    def moment_encoder(self, moment_feat):
        """moment_feat: (N, L_clip, D_v)"""
        return F.normalize(self.moment_mlp(moment_feat), p=2, dim=-1)

    def query_encoder(self, query_feat, query_mask):
        """
        Args:
            query_feat: (N, L_q, D_q), torch.float32
            query_mask: (N, L_q), torch.float32, where 1 indicates a valid token, 0 indicates padding
        """
        _, hidden = self.query_lstm(query_feat, torch.sum(query_mask, dim=1).long())
        return F.normalize(self.query_linear(hidden), p=2, dim=-1)

    def compute_pdist(self, query_embedding, moment_feat, moment_mask):
        """Pairwise (row-aligned) squared L2 distance.
        Args:
            query_embedding: (N, D_o)
            moment_feat: (N, L_clip, D_v)
            moment_mask: (N, L_clip), torch.float32, where 1 indicates valid, 0 indicates padding
        """
        moment_embedding = self.moment_encoder(moment_feat)  # (N, L_clip, D_o)
        moment_clip_dist = torch.sum((moment_embedding - query_embedding.unsqueeze(1)) ** 2, dim=2)  # (N, L_clip)
        moment_dist = torch.sum(moment_clip_dist * moment_mask, dim=1) / moment_mask.sum(1)  # (N, )
        return moment_dist
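    # Worked example for the masked mean above (illustrative numbers): with
    # L_clip = 3, per-clip squared distances [0.2, 0.4, 0.0] and mask [1, 1, 0],
    # moment_dist = (0.2 + 0.4 + 0.0*0) / (1 + 1) = 0.3, i.e. a mean over valid
    # clips only; padded clips are excluded from numerator and divisor alike.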
					
						
    @classmethod
    def compute_cdist_inference(cls, query_embeddings, moment_embeddings, moment_mask):
        """Compute squared L2 distance for every possible (query, proposal) pair,
        unlike compute_pdist, which only computes distances for row-aligned pairs.
        Args:
            query_embeddings: (N_q, D_o)
            moment_embeddings: (N_prop, N_clips, D_o)
            moment_mask: (N_prop, N_clips)
        Returns:
            query_moment_dist: (N_q, N_prop)
        """
        # move precomputed moment embeddings to the query device if needed
        query_device = query_embeddings.device
        if moment_embeddings.device != query_device:
            moment_embeddings = moment_embeddings.to(query_device)
            moment_mask = moment_mask.to(query_device)

        n_query = query_embeddings.shape[0]
        n_prop, n_clips, d = moment_embeddings.shape
        query_clip_dist = torch.cdist(
            query_embeddings, moment_embeddings.reshape(-1, d), p=2) ** 2  # (N_q, N_prop * N_clips)
        query_clip_dist = query_clip_dist.reshape(n_query, n_prop, n_clips)
        query_moment_dist = torch.sum(
            query_clip_dist * moment_mask.unsqueeze(0), dim=2) / moment_mask.sum(1).unsqueeze(0)  # (N_q, N_prop)
        return query_moment_dist
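    # Consistency note (added): when N_q == N_prop and query i corresponds to
    # proposal i, the diagonal of the returned matrix matches compute_pdist
    # applied to the pre-encoded moments, up to floating-point error.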
					
						
    def forward(self, query_feat, query_mask, pos_moment_feat, pos_moment_mask,
                intra_neg_moment_feat, intra_neg_moment_mask,
                inter_neg_moment_feat, inter_neg_moment_mask):
        """
        Args:
            query_feat: (N, L, D_q)
            query_mask: (N, L)
            pos_moment_feat: (N, L_clip_1, D_v)
            pos_moment_mask: (N, L_clip_1)
            intra_neg_moment_feat: (N, L_clip_2, D_v)
            intra_neg_moment_mask: (N, L_clip_2)
            inter_neg_moment_feat: (N, L_clip_3, D_v)
            inter_neg_moment_mask: (N, L_clip_3)
        """
        query_embed = self.query_encoder(query_feat, query_mask)
        pos_dist = self.compute_pdist(query_embed, pos_moment_feat, pos_moment_mask)
        intra_neg_dist = self.compute_pdist(query_embed, intra_neg_moment_feat, intra_neg_moment_mask)
        if self.config.inter_loss_weight == 0:
            loss_inter = 0.
        else:
            inter_neg_dist = self.compute_pdist(query_embed, inter_neg_moment_feat, inter_neg_moment_mask)
            loss_inter = self.calc_loss(pos_dist, inter_neg_dist)

        loss = self.calc_loss(pos_dist, intra_neg_dist) + self.config.inter_loss_weight * loss_inter
        return loss

    def calc_loss(self, pos_dist, neg_dist):
        """Ranking loss that encourages the positive distance to be smaller than the negative distance.
        Args:
            pos_dist: (N, ), torch.float32
            neg_dist: (N, ), torch.float32
        """
        if self.config.loss_type == "hinge":
            return torch.clamp(self.config.margin + pos_dist - neg_dist, min=0).sum() / len(pos_dist)
        elif self.config.loss_type == "lse":
            return torch.log1p(torch.exp(pos_dist - neg_dist)).sum() / len(pos_dist)
        else:
            raise NotImplementedError("Only support 'hinge' and 'lse'")
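    # In math form (added note): hinge is mean(max(0, margin + d_pos - d_neg))
    # and lse is mean(log(1 + exp(d_pos - d_neg))), a smooth upper bound on the
    # zero-margin hinge; both push the positive distance below the negative one.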
					
						
class CALWithSub(nn.Module):
    def __init__(self, config):
        super(CALWithSub, self).__init__()
        self.config = config
        self.use_video = "video" in config.ctx_mode
        self.use_sub = "sub" in config.ctx_mode
        self.use_tef = "tef" in config.ctx_mode
        self.tef_only = self.use_tef and not self.use_video and not self.use_sub

        if self.use_video or self.tef_only:
            self.video_moment_mlp = nn.Sequential(
                nn.Linear(config.visual_input_size, config.visual_hidden_size),
                nn.ReLU(True),
                nn.Linear(config.visual_hidden_size, config.output_size),
            )

        if self.use_sub:
            self.sub_moment_mlp = nn.Sequential(
                nn.Linear(config.textual_input_size, config.visual_hidden_size),
                nn.ReLU(True),
                nn.Linear(config.visual_hidden_size, config.output_size),
            )

        self.query_lstm = RNNEncoder(word_embedding_size=config.query_feat_size,
                                     hidden_size=config.lstm_hidden_size,
                                     bidirectional=False,
                                     rnn_type="lstm",
                                     dropout_p=0,
                                     n_layers=1,
                                     return_outputs=False)

        self.query_linear = nn.Linear(config.lstm_hidden_size, config.output_size)

    def moment_encoder(self, moment_feat, module_name="video"):
        """moment_feat: (N, L_clip, D_v)"""
        if moment_feat is not None:
            encoder = getattr(self, module_name + "_moment_mlp")
            return F.normalize(encoder(moment_feat), p=2, dim=-1)
        else:
            return None

    def query_encoder(self, query_feat, query_mask):
        """
        Args:
            query_feat: (N, L_q, D_q), torch.float32
            query_mask: (N, L_q), torch.float32, where 1 indicates a valid token, 0 indicates padding
        """
        _, hidden = self.query_lstm(query_feat, torch.sum(query_mask, dim=1).long())
        return F.normalize(self.query_linear(hidden), p=2, dim=-1)

    def _compute_pdist(self, query_embedding, moment_feat, moment_mask, module_name="video"):
        """Pairwise (row-aligned) squared L2 distance for a single modality.
        Args:
            query_embedding: (N, D_o)
            moment_feat: (N, L_clip, D_v)
            moment_mask: (N, L_clip), torch.float32, where 1 indicates valid, 0 indicates padding
        """
        moment_embedding = self.moment_encoder(moment_feat, module_name=module_name)
        moment_clip_dist = torch.sum((moment_embedding - query_embedding.unsqueeze(1)) ** 2, dim=2)
        moment_dist = torch.sum(moment_clip_dist * moment_mask, dim=1) / moment_mask.sum(1)
        return moment_dist

    def compute_pdist(self, query_embedding, moment_video_feat, moment_sub_feat, moment_mask):
        """Pairwise (row-aligned) squared L2 distance, averaged over the enabled modalities.
        Args:
            query_embedding: (N, D_o)
            moment_video_feat: (N, L_clip, D_v)
            moment_sub_feat: (N, L_clip, D_t)
            moment_mask: (N, L_clip), torch.float32, where 1 indicates valid, 0 indicates padding
        """
        divisor = (self.use_video or self.tef_only) + self.use_sub  # number of active streams (bools sum as 0/1)
        video_moment_dist = self._compute_pdist(query_embedding, moment_video_feat, moment_mask, module_name="video") \
            if self.use_video or self.tef_only else 0
        sub_moment_dist = self._compute_pdist(query_embedding, moment_sub_feat, moment_mask, module_name="sub") \
            if self.use_sub else 0
        return (video_moment_dist + sub_moment_dist) / divisor
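    # e.g. with ctx_mode "video_sub" (illustrative), both streams are active:
    # divisor == True + True == 2, so the result is the arithmetic mean of the
    # video- and subtitle-based moment distances.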
					
						
    def _compute_cdist_inference(self, query_embeddings, moment_embeddings, moment_mask):
        """Compute squared L2 distance for every possible (query, proposal) pair,
        unlike compute_pdist, which only computes distances for row-aligned pairs.
        Args:
            query_embeddings: (N_q, D_o)
            moment_embeddings: (N_prop, N_clips, D_o)
            moment_mask: (N_prop, N_clips)
        Returns:
            query_moment_dist: (N_q, N_prop)
        """
        # move precomputed moment embeddings to the query device if needed
        query_device = query_embeddings.device
        if moment_embeddings.device != query_device:
            moment_embeddings = moment_embeddings.to(query_device)
            moment_mask = moment_mask.to(query_device)

        n_query = query_embeddings.shape[0]
        n_prop, n_clips, d = moment_embeddings.shape
        query_clip_dist = torch.cdist(
            query_embeddings, moment_embeddings.reshape(-1, d), p=2) ** 2
        query_clip_dist = query_clip_dist.reshape(n_query, n_prop, n_clips)
        query_moment_dist = torch.sum(
            query_clip_dist * moment_mask.unsqueeze(0), dim=2) / moment_mask.sum(1).unsqueeze(0)
        return query_moment_dist

    def compute_cdist_inference(self, query_embeddings, video_moment_embeddings, sub_moment_embeddings, moment_mask):
        """Query-proposal distances averaged over the enabled modalities; see _compute_cdist_inference."""
        divisor = (self.use_video or self.tef_only) + self.use_sub
        video_moment_dist = self._compute_cdist_inference(query_embeddings, video_moment_embeddings, moment_mask) \
            if self.use_video or self.tef_only else 0
        sub_moment_dist = self._compute_cdist_inference(query_embeddings, sub_moment_embeddings, moment_mask) \
            if self.use_sub else 0
        return (video_moment_dist + sub_moment_dist) / divisor
					
						
    def forward(self, query_feat, query_mask, pos_moment_video_feat, pos_moment_video_mask,
                intra_neg_moment_video_feat, intra_neg_moment_video_mask,
                inter_neg_moment_video_feat, inter_neg_moment_video_mask,
                pos_moment_sub_feat, pos_moment_sub_mask,
                intra_neg_moment_sub_feat, intra_neg_moment_sub_mask,
                inter_neg_moment_sub_feat, inter_neg_moment_sub_mask):
        """
        Args:
            query_feat: (N, L, D_q)
            query_mask: (N, L)
            pos_moment_video_feat: (N, L_clip_1, D_v)
            pos_moment_video_mask: (N, L_clip_1)
            intra_neg_moment_video_feat: (N, L_clip_2, D_v)
            intra_neg_moment_video_mask: (N, L_clip_2)
            inter_neg_moment_video_feat: (N, L_clip_3, D_v)
            inter_neg_moment_video_mask: (N, L_clip_3)
            pos_moment_sub_feat: (N, L_clip_1, D_t)
            pos_moment_sub_mask: (N, L_clip_1)
            intra_neg_moment_sub_feat: (N, L_clip_2, D_t)
            intra_neg_moment_sub_mask: (N, L_clip_2)
            inter_neg_moment_sub_feat: (N, L_clip_3, D_t)
            inter_neg_moment_sub_mask: (N, L_clip_3)
        """
        query_embed = self.query_encoder(query_feat, query_mask)
        pos_dist = self.compute_pdist(
            query_embed, pos_moment_video_feat, pos_moment_sub_feat,
            moment_mask=pos_moment_sub_mask if self.use_sub else pos_moment_video_mask)
        intra_neg_dist = self.compute_pdist(
            query_embed, intra_neg_moment_video_feat, intra_neg_moment_sub_feat,
            moment_mask=intra_neg_moment_sub_mask if self.use_sub else intra_neg_moment_video_mask)
        if self.config.inter_loss_weight == 0:
            loss_inter = 0.
        else:
            inter_neg_dist = self.compute_pdist(
                query_embed, inter_neg_moment_video_feat, inter_neg_moment_sub_feat,
                moment_mask=inter_neg_moment_sub_mask if self.use_sub else inter_neg_moment_video_mask)
            loss_inter = self.calc_loss(pos_dist, inter_neg_dist)

        loss = self.calc_loss(pos_dist, intra_neg_dist) + self.config.inter_loss_weight * loss_inter
        return loss

    def calc_loss(self, pos_dist, neg_dist):
        """Ranking loss that encourages the positive distance to be smaller than the negative distance.
        Args:
            pos_dist: (N, ), torch.float32
            neg_dist: (N, ), torch.float32
        """
        if self.config.loss_type == "hinge":
            return torch.clamp(self.config.margin + pos_dist - neg_dist, min=0).sum() / len(pos_dist)
        elif self.config.loss_type == "lse":
            return torch.log1p(torch.exp(pos_dist - neg_dist)).sum() / len(pos_dist)
        else:
            raise NotImplementedError("Only support 'hinge' and 'lse'")
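

if __name__ == "__main__":
    # Minimal smoke test (an added sketch, not part of the training pipeline).
    # CAL.compute_cdist_inference is a classmethod, so it can be exercised
    # without constructing the model (and thus without RNNEncoder); all tensor
    # values below are random placeholders.
    n_query, n_prop, n_clips, d_o = 4, 5, 3, cal_base_cfg.output_size
    query_embeddings = F.normalize(torch.randn(n_query, d_o), p=2, dim=-1)
    moment_embeddings = F.normalize(torch.randn(n_prop, n_clips, d_o), p=2, dim=-1)
    moment_mask = torch.ones(n_prop, n_clips)
    moment_mask[:, -1] = 0  # mark the last clip of each proposal as padding
    dist = CAL.compute_cdist_inference(query_embeddings, moment_embeddings, moment_mask)
    assert dist.shape == (n_query, n_prop)
    print("query-moment distance matrix shape:", tuple(dist.shape))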