from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
from torch import nn

from .until_module import PreTrainedModel
from .module_cross import CrossModel, CrossConfig
from .module_decoder import DecoderModel, DecoderConfig
from utils.module_clip import CLIP, convert_weights
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence


def update_attr(target_name, target_config, target_attr_name, source_config, source_attr_name, default_value=None):
    if hasattr(source_config, source_attr_name):
        if default_value is None or getattr(source_config, source_attr_name) != default_value:
            setattr(target_config, target_attr_name, getattr(source_config, source_attr_name))
    return target_config


class CLIP4IDCPreTrainedModel(PreTrainedModel, nn.Module):
    """An abstract class to handle weights initialization and
    a simple interface for downloading and loading pretrained models.
    """
    def __init__(self, cross_config, decoder_config, *inputs, **kwargs):
        super(CLIP4IDCPreTrainedModel, self).__init__(cross_config, decoder_config)
        self.cross_config = cross_config
        self.decoder_config = decoder_config
        self.clip = None
        self.cross = None

    @classmethod
    def from_pretrained(cls, cross_model_name, decoder_model_name, state_dict=None, cache_dir=None,
                        type_vocab_size=2, *inputs, **kwargs):
        if state_dict is None:
            state_dict = {}

        pretrained_clip_name = "ViT-B/16"
        clip_state_dict = CLIP.get_config(pretrained_clip_name=pretrained_clip_name)

        # Copy the pretrained CLIP weights into the state dict under the "clip." prefix,
        # unless the caller already provided a value for that key.
        for key, val in clip_state_dict.items():
            new_key = "clip." + key
            if new_key not in state_dict:
                state_dict[new_key] = val.clone()

        cross_config, _ = CrossConfig.get_config(cross_model_name, cache_dir, type_vocab_size, state_dict=None)
        decoder_config, _ = DecoderConfig.get_config(decoder_model_name, cache_dir, type_vocab_size, state_dict=None)

        model = cls(cross_config, decoder_config, clip_state_dict, *inputs, **kwargs)

        # ===> Initialization trick [HARD CODE]
        if model.linear_patch == "3d":
            contain_conv2 = False
            for key in state_dict.keys():
                if key.find("visual.conv2.weight") > -1:
                    contain_conv2 = True
                    break
            if contain_conv2 is False and hasattr(model.clip.visual, "conv2"):
                # Inflate the 2D patch-embedding kernel (conv1) into the 3D kernel (conv2)
                # by zero-padding along the temporal dimension.
                cp_weight = state_dict["clip.visual.conv1.weight"].clone()
                kernel_size = model.clip.visual.conv2.weight.size(2)
                conv2_size = model.clip.visual.conv2.weight.size()
                conv2_size = list(conv2_size)

                left_conv2_size = conv2_size.copy()
                right_conv2_size = conv2_size.copy()
                left_conv2_size[2] = (kernel_size - 1) // 2
                right_conv2_size[2] = kernel_size - 1 - left_conv2_size[2]

                left_zeros, right_zeros = None, None
                if left_conv2_size[2] > 0:
                    left_zeros = torch.zeros(*tuple(left_conv2_size), dtype=cp_weight.dtype, device=cp_weight.device)
                if right_conv2_size[2] > 0:
                    right_zeros = torch.zeros(*tuple(right_conv2_size), dtype=cp_weight.dtype, device=cp_weight.device)

                cat_list = []
                if left_zeros is not None:
                    cat_list.append(left_zeros)
                cat_list.append(cp_weight.unsqueeze(2))
                if right_zeros is not None:
                    cat_list.append(right_zeros)
                cp_weight = torch.cat(cat_list, dim=2)

                state_dict["clip.visual.conv2.weight"] = cp_weight
        # <=== End of initialization trick

        if state_dict is not None:
            model = cls.init_preweight(model, state_dict)

        return model


class CLIP4IDC(CLIP4IDCPreTrainedModel):
    def __init__(self, cross_config, decoder_config, clip_state_dict):
        super(CLIP4IDC, self).__init__(cross_config, decoder_config)
        self.ignore_video_index = -1

        # assert self.task_config.max_words <= cross_config.max_position_embeddings

        # CLIP Encoders: From OpenAI: CLIP [https://github.com/openai/CLIP] ===>
        # Derive the visual-encoder hyper-parameters from the CLIP checkpoint keys.
        vit = "visual.proj" in clip_state_dict
        assert vit
        if vit:
            vision_width = clip_state_dict["visual.conv1.weight"].shape[0]
            vision_layers = len(
                [k for k in clip_state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
            vision_patch_size = clip_state_dict["visual.conv1.weight"].shape[-1]
            grid_size = round((clip_state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
            image_resolution = vision_patch_size * grid_size
        else:
            counts: list = [len(set(k.split(".")[2] for k in clip_state_dict if k.startswith(f"visual.layer{b}")))
                            for b in [1, 2, 3, 4]]
            vision_layers = tuple(counts)
            vision_width = clip_state_dict["visual.layer1.0.conv1.weight"].shape[0]
            output_width = round((clip_state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
            vision_patch_size = None
            assert output_width ** 2 + 1 == clip_state_dict["visual.attnpool.positional_embedding"].shape[0]
            image_resolution = output_width * 32

        # Derive the text-encoder hyper-parameters from the same checkpoint.
        embed_dim = clip_state_dict["text_projection"].shape[1]
        context_length = clip_state_dict["positional_embedding"].shape[0]
        vocab_size = clip_state_dict["token_embedding.weight"].shape[0]
        transformer_width = clip_state_dict["ln_final.weight"].shape[0]
        transformer_heads = transformer_width // 64
        transformer_layers = len(set(k.split(".")[2] for k in clip_state_dict if k.startswith("transformer.resblocks")))

        self.linear_patch = '2d'

        # use .float() to avoid overflow/underflow from fp16 weight. https://github.com/openai/CLIP/issues/40
        cut_top_layer = 0
        self.clip = CLIP(
            embed_dim,
            image_resolution, vision_layers - cut_top_layer, vision_width, vision_patch_size,
            context_length, vocab_size, transformer_width, transformer_heads, transformer_layers - cut_top_layer,
            linear_patch=self.linear_patch, intra_layers=9
        ).float()

        # Reuse CLIP's token and positional embeddings to initialize the caption decoder.
        bert_word_embeddings_weight = self.clip.token_embedding.weight
        bert_position_embeddings_weight = self.clip.positional_embedding

        for key in ["input_resolution", "context_length", "vocab_size"]:
            if key in clip_state_dict:
                del clip_state_dict[key]

        convert_weights(self.clip)
        # <=== End of CLIP Encoders

        self.decoder = DecoderModel(decoder_config, bert_word_embeddings_weight, bert_position_embeddings_weight)

        self.apply(self.init_weights)

    def get_visual_output(self, video, visual_mask, left_gt_map, right_gt_map, shaped=False, video_frame=-1):
        bs_pair = visual_mask.size(0)
        visual_hidden, visual_output, left_map, right_map = self.clip.encode_image(
            video, left_gt_map, right_gt_map, video_frame=video_frame, return_hidden=True)
        visual_hidden = visual_hidden.float()
        visual_output = visual_output.float()
        visual_hidden = visual_hidden.view(bs_pair, -1, visual_hidden.size(-1))
        left_map = left_map.float()
        right_map = right_map.float()
        return visual_hidden, visual_output, left_map, right_map

    def get_sequence_visual_output(self, video, visual_mask, left_gt_map, right_gt_map, shaped=False, video_frame=-1):
        if shaped is False:
            # Flatten the (batch, pair) dimensions so each image in the pair is encoded separately.
            visual_mask = visual_mask.view(-1, visual_mask.shape[-1])
            video = torch.as_tensor(video).float()
            b, pair, channel, h, w = video.shape
            video = video.view(b * pair, channel, h, w)
            video_frame = pair

        _, visual_hidden, left_map, right_map = self.get_visual_output(video, visual_mask, left_gt_map, right_gt_map,
                                                                       shaped=True, video_frame=video_frame)

        return visual_hidden, left_map, right_map

    def _get_decoder_score(self, visual_output, visual_mask, input_caption_ids, decoder_mask):
        res_tuples = ()
        decoder_scores = self.decoder(input_caption_ids, encoder_outs=visual_output,
                                      answer_mask=decoder_mask, encoder_mask=visual_mask)
        return decoder_scores, res_tuples
    def decoder_caption(self, visual_output, visual_mask, input_caption_ids, decoder_mask, get_logits=False):
        decoder_scores, _ = self._get_decoder_score(visual_output, visual_mask,
                                                    input_caption_ids, decoder_mask)

        if get_logits:
            return decoder_scores

        # Greedy choice: return the argmax token id at each decoding position.
        _, decoder_scores_result = torch.max(decoder_scores, -1)

        return decoder_scores_result


def init_model(model_path, device):
    model_state_dict = torch.load(model_path, map_location='cpu')

    # Prepare model
    cache_dir = ""
    model = CLIP4IDC.from_pretrained("cross-base", "decoder-base",
                                     cache_dir=cache_dir, state_dict=model_state_dict)
    model.to(device)

    return model
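

# --- Usage sketch (illustrative only; names and batch layout are assumptions) ---
# A minimal sketch of how the pieces above could be wired together for inference:
# load a fine-tuned checkpoint with `init_model`, encode a "before"/"after" image
# pair with `get_sequence_visual_output`, then score caption tokens (teacher
# forcing) with `decoder_caption`. `_example_inference`, `checkpoint_path`, and
# the dataloader yielding `batch` are hypothetical and not part of this
# repository; the tensors must match what `clip.encode_image` and the decoder
# expect (`video` is reshaped as (batch, pair, channel, height, width) above).
def _example_inference(dataloader, checkpoint_path, device="cuda"):
    # Hypothetical helper, shown only to illustrate the call order of the API above.
    model = init_model(checkpoint_path, device)
    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            # Assumed batch layout; the real dataloader in this repo may differ.
            video, visual_mask, left_gt_map, right_gt_map, caption_ids, caption_mask = batch
            visual_output, left_map, right_map = model.get_sequence_visual_output(
                video, visual_mask, left_gt_map, right_gt_map)
            token_ids = model.decoder_caption(visual_output, visual_mask, caption_ids, caption_mask)
            yield token_ids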