""" Adapted from: https://github.com/openai/CLIP/blob/main/clip/clip.py """ import warnings from collections import OrderedDict from typing import Tuple, Union, Optional import hashlib import os import urllib import warnings from tqdm import tqdm import torch import torch.nn.functional as F from torch import Tensor from torch.nn.modules.linear import NonDynamicallyQuantizableLinear from torch.nn.init import xavier_uniform_ from torch.nn.init import constant_ from torch.nn.init import xavier_normal_ from torch.nn.parameter import Parameter from torch.nn.modules.module import Module from .module_gated_attention import gated_coattention from torch import nn _MODELS = { "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", } _PT_NAME = { "RN50": "RN50.pt", "RN101": "RN101.pt", "RN50x4": "RN50x4.pt", "RN50x16": "RN50x16.pt", "ViT-B/32": "ViT-B-32.pt", "ViT-B/16": "ViT-B-16.pt", } def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): os.makedirs(root, exist_ok=True) filename = os.path.basename(url) expected_sha256 = url.split("/")[-2] download_target = os.path.join(root, filename) if os.path.exists(download_target) and not os.path.isfile(download_target): raise RuntimeError(f"{download_target} exists and is not a regular file") if os.path.isfile(download_target): if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: return download_target else: warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: while True: buffer = source.read(8192) if not buffer: break output.write(buffer) loop.update(len(buffer)) if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") return download_target def available_models(): """Returns the names of available CLIP models""" return list(_MODELS.keys()) # ============================= class TABAttention(Module): r"""Allows the model to jointly attend to information from different representation subspaces. See `Attention Is All You Need `_ .. math:: \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`. Args: embed_dim: total dimension of the model. num_heads: parallel attention heads. dropout: a Dropout layer on attn_output_weights. Default: 0.0. bias: add bias as module parameter. Default: True. add_bias_kv: add bias to the key and value sequences at dim=0. add_zero_attn: add a new batch of zeros to the key and value sequences at dim=1. kdim: total number of features in key. Default: None. vdim: total number of features in value. Default: None. Note that if :attr:`kdim` and :attr:`vdim` are None, they will be set to :attr:`embed_dim` such that query, key, and value have the same number of features. Examples:: >>> multihead_attn = TABAttention(embed_dim, num_heads) >>> attn_output, attn_output_weights = multihead_attn(query, key, value) This is a version of multihead attention written to comply with the defintion of TAB!!! """ bias_k: Optional[torch.Tensor] bias_v: Optional[torch.Tensor] def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None): super(TABAttention, self).__init__() self.embed_dim = embed_dim self.kdim = kdim if kdim is not None else embed_dim self.vdim = vdim if vdim is not None else embed_dim self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim self.num_heads = num_heads self.dropout = dropout self.head_dim = embed_dim // num_heads assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" if self._qkv_same_embed_dim is False: self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim)) self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim)) self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim)) self.register_parameter('in_proj_weight', None) else: self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim)) self.register_parameter('q_proj_weight', None) self.register_parameter('k_proj_weight', None) self.register_parameter('v_proj_weight', None) if bias: self.in_proj_bias = Parameter(torch.empty(3 * embed_dim)) else: self.register_parameter('in_proj_bias', None) self.out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias) if add_bias_kv: self.bias_k = Parameter(torch.empty(1, 1, embed_dim)) self.bias_v = Parameter(torch.empty(1, 1, embed_dim)) else: self.bias_k = self.bias_v = None self.add_zero_attn = add_zero_attn self._reset_parameters() def _reset_parameters(self): if self._qkv_same_embed_dim: xavier_uniform_(self.in_proj_weight) else: xavier_uniform_(self.q_proj_weight) xavier_uniform_(self.k_proj_weight) xavier_uniform_(self.v_proj_weight) if self.in_proj_bias is not None: constant_(self.in_proj_bias, 0.) constant_(self.out_proj.bias, 0.) if self.bias_k is not None: xavier_normal_(self.bias_k) if self.bias_v is not None: xavier_normal_(self.bias_v) def __setstate__(self, state): # Support loading old TABAttention checkpoints generated by v1.1.0 if '_qkv_same_embed_dim' not in state: state['_qkv_same_embed_dim'] = True super(TABAttention, self).__setstate__(state) def forward(self, query: Tensor, key: Tensor, value: Tensor, gt_attention_map: Optional[Tensor] = None, key_padding_mask: Optional[Tensor] = None, need_weights: bool = True, attn_mask: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]: r""" Args: query, key, value: map a query and a set of key-value pairs to an output. See "Attention Is All You Need" for more details. key_padding_mask: if provided, specified padding elements in the key will be ignored by the attention. When given a binary mask and a value is True, the corresponding value on the attention layer will be ignored. When given a byte mask and a value is non-zero, the corresponding value on the attention layer will be ignored need_weights: output attn_output_weights. attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all the batches while a 3D mask allows to specify a different mask for the entries of each batch. Shapes for inputs: - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is the embedding dimension. - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is the embedding dimension. - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is the embedding dimension. - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. If a ByteTensor is provided, the non-zero positions will be ignored while the position with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. - attn_mask: if a 2D mask: :math:`(L, S)` where L is the target sequence length, S is the source sequence length. If a 3D mask: :math:`(N\cdot\text{num\_heads}, L, S)` where N is the batch size, L is the target sequence length, S is the source sequence length. ``attn_mask`` ensure that position i is allowed to attend the unmasked positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor is provided, it will be added to the attention weight. Shapes for outputs: - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is the embedding dimension. - attn_output_weights: :math:`(N, L, S)` where N is the batch size, L is the target sequence length, S is the source sequence length. """ if not self._qkv_same_embed_dim: return gated_coattention( query, key, value, self.embed_dim, self.num_heads, self.in_proj_weight.half(), self.in_proj_bias.half(), self.bias_k, self.bias_v, self.add_zero_attn, self.dropout, self.out_proj.weight.half(), self.out_proj.bias.half(), training=self.training, gt_attention_map=gt_attention_map, key_padding_mask=key_padding_mask, need_weights=need_weights, attn_mask=attn_mask, use_separate_proj_weight=True, q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight, v_proj_weight=self.v_proj_weight) else: return gated_coattention( query, key, value, self.embed_dim, self.num_heads, self.in_proj_weight.half(), self.in_proj_bias.half(), self.bias_k, self.bias_v, self.add_zero_attn, self.dropout, self.out_proj.weight.half(), self.out_proj.bias.half(), training=self.training, gt_attention_map=gt_attention_map, key_padding_mask=key_padding_mask, need_weights=need_weights, attn_mask=attn_mask) class LayerNorm(nn.LayerNorm): """Subclass torch's LayerNorm to handle fp16.""" def forward(self, x: torch.Tensor): orig_type = x.dtype ret = super().forward(x.type(torch.float32)) return ret.type(orig_type) class QuickGELU(nn.Module): def forward(self, x: torch.Tensor): return x * torch.sigmoid(1.702 * x) class ResidualAttentionBlock(nn.Module): def __init__(self, d_model: int, n_head: int, attn_mask=None): super().__init__() self.attn = nn.MultiheadAttention(d_model, n_head) self.ln_1 = LayerNorm(d_model) self.mlp = nn.Sequential(OrderedDict([ ("c_fc", nn.Linear(d_model, d_model * 4)), ("gelu", QuickGELU()), ("c_proj", nn.Linear(d_model * 4, d_model)) ])) self.ln_2 = LayerNorm(d_model) self.attn_mask = attn_mask def attention(self, x: torch.Tensor): attn_mask_ = self.attn_mask if self.attn_mask is not None and hasattr(self.attn_mask, '__call__'): attn_mask_ = self.attn_mask(x.size(0)) # LND attn_mask_ = attn_mask_.to(dtype=x.dtype, device=x.device) if attn_mask_ is not None else None return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask_)[0] def forward(self, x_tuple:tuple): x, video_frame = x_tuple x = x + self.attention(self.ln_1(x)) x = x + self.mlp(self.ln_2(x)) return (x, video_frame) def visualize_attention(self, x: torch.Tensor): attn_outputs, attn_weights = self.attn(x, x, x, need_weights=True, attn_mask=None) return attn_outputs, attn_weights def visualize_forward(self, x_tuple:tuple): x, video_frame = x_tuple attn_outputs, attn_weights = self.visualize_attention(self.ln_1(x)) x = x + attn_outputs x = x + self.mlp(self.ln_2(x)) return (x, video_frame, attn_weights) class TABLayer(nn.Module): def __init__(self, d_model: int, n_head: int, attn_mask=None): super().__init__() self.attn = TABAttention(d_model, n_head) self.ln_1 = LayerNorm(d_model) self.mlp = nn.Sequential(OrderedDict([ ("c_fc", nn.Linear(d_model, d_model * 4)), ("gelu", QuickGELU()), ("c_proj", nn.Linear(d_model * 4, d_model)) ])) self.ln_2 = LayerNorm(d_model) self.attn_mask = attn_mask def attention(self, x: torch.Tensor, y: torch.Tensor): attn_mask_ = self.attn_mask if self.attn_mask is not None and hasattr(self.attn_mask, '__call__'): attn_mask_ = self.attn_mask(x.size(0)) # LND attn_mask_ = attn_mask_.to(dtype=x.dtype, device=x.device) if attn_mask_ is not None else None return self.attn(x, y, y, need_weights=False, attn_mask=attn_mask_)[0] def forward(self, x: torch.Tensor, y: torch.Tensor): x = self.attention(self.ln_1(x), self.ln_1(y)) x = x + self.mlp(self.ln_2(x)) return x def visualize_attention(self, x: torch.Tensor, y: torch.Tensor, gt_attention_map): attn_outputs, attn_weights = self.attn(x, y, y, gt_attention_map=gt_attention_map, need_weights=True, attn_mask=None) return attn_outputs, attn_weights def visualize_forward(self, x: torch.Tensor, y: torch.Tensor, gt_attention_map): attn_outputs, attn_weights = self.visualize_attention(self.ln_1(x), self.ln_1(y), gt_attention_map) x = attn_outputs x = x + self.mlp(self.ln_2(x)) return (x, attn_weights) class visionTransformer(nn.Module): def __init__(self, width: int, layers: int, heads: int, attn_mask = None): super().__init__() self.width = width self.layers = layers self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) if i < (layers - 1) else TABLayer(width, 1, attn_mask) for i in range(layers)]) def forward(self, x: torch.Tensor, video_frame=-1): return self.resblocks((x, video_frame))[0] class Transformer(nn.Module): def __init__(self, width: int, layers: int, heads: int, attn_mask = None): super().__init__() self.width = width self.layers = layers self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) def forward(self, x: torch.Tensor, video_frame=-1): return self.resblocks((x, video_frame))[0] class VisualTransformer(nn.Module): def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int, linear_patch: str = '2d', intra_layers: int = 9): super().__init__() self.input_resolution = input_resolution self.output_dim = output_dim self.intra_layers = intra_layers self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) scale = width ** -0.5 self.class_embedding = nn.Parameter(scale * torch.randn(width)) self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) self.ln_pre = LayerNorm(width) self.joint_positional_embedding = nn.Parameter(scale * torch.randn(2 * ((input_resolution // patch_size) ** 2 + 1), width)) self.bef_embedding = nn.Parameter(scale * torch.randn(width)) self.aft_embedding = nn.Parameter(scale * torch.randn(width)) self.ln_mid = LayerNorm(width) self.transformer = visionTransformer(width, layers, heads) self.ln_post = LayerNorm(width) self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) # For 3D assert linear_patch in ['2d', '3d'] self.linear_patch = linear_patch if self.linear_patch == '3d': self.conv2 = nn.Conv3d(in_channels=3, out_channels=width, kernel_size=(3, patch_size, patch_size), stride=(1, patch_size, patch_size), padding=(1, 0, 0), bias=False) def forward(self, x: torch.Tensor, left_gt_map, right_gt_map, video_frame=-1, visualize=False): if self.linear_patch == '3d': assert video_frame != -1 x_3d = x.reshape(-1, video_frame, x.shape[-3], x.shape[-2], x.shape[-1]) x_3d = x_3d.permute(0, 2, 1, 3, 4) x_3d = self.conv2(x_3d) # shape = [*, width, frame, grid, grid] x_3d = x_3d.permute(0, 2, 1, 3, 4) # shape = [*, frame, width, grid, grid] x = x_3d.reshape(-1, x_3d.shape[-3], x_3d.shape[-2], x_3d.shape[-1]).contiguous() # shape = [*, width, grid, grid] else: x = self.conv1(x) # shape = [*, width, grid, grid] x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] x = x + self.positional_embedding.to(x.dtype) x = self.ln_pre(x) x = x.permute(1, 0, 2) # NLD -> LND if visualize is True: all_attn_weights = [] for i in range(self.intra_layers): x, _, attn_weights = self.transformer.resblocks[i].visualize_forward((x, video_frame)) attn_weights = attn_weights.view(x.size(1) // video_frame, -1, attn_weights.size(-2), attn_weights.size(-1)) all_attn_weights.append(attn_weights) else: for i in range(self.intra_layers): x = self.transformer.resblocks[i]((x, video_frame))[0] x = x.permute(1, 0, 2) # LND -> NLD bs = x.size(0) // video_frame x = x.view(bs, video_frame, x.size(-2), x.size(-1)) x = torch.cat([x[:, 0] + self.bef_embedding.to(x.dtype), x[:, 1] + self.aft_embedding.to(x.dtype)], dim=1) x = x + self.joint_positional_embedding.to(x.dtype) x = self.ln_mid(x) x = x.permute(1, 0, 2) # NLD -> LND if visualize is True: for i in range(self.intra_layers, self.transformer.layers - 1): x, _, attn_weights = self.transformer.resblocks[i].visualize_forward((x, video_frame)) all_attn_weights.append(attn_weights) cls_index = int(x.size(0) / 2) left_features, left_attn_weights = self.transformer.resblocks[-1].visualize_forward(x[:cls_index, :, :], x[cls_index:, :, :], right_gt_map) right_features, right_attn_weights = self.transformer.resblocks[-1].visualize_forward(x[cls_index:, :, :], x[:cls_index, :, :], left_gt_map) all_attn_weights.append(left_attn_weights) all_attn_weights.append(right_attn_weights) else: for i in range(self.intra_layers, self.transformer.layers - 1): x = self.transformer.resblocks[i]((x, video_frame))[0] cls_index = int(x.size(0) / 2) left_features, left_attn_weights = self.transformer.resblocks[-1].visualize_forward(x[:cls_index, :, :], x[cls_index:, :, :], right_gt_map) right_features, right_attn_weights = self.transformer.resblocks[-1].visualize_forward(x[cls_index:, :, :], x[:cls_index, :, :], left_gt_map) left_features = left_features.permute(1, 0, 2) # LND -> NLD right_features = right_features.permute(1, 0, 2) # LND -> NLD x = torch.cat([left_features, right_features], 1) # Move the three lines below to `encode_image` for entire hidden sequence # x = self.ln_post(x[:, 0, :]) # if self.proj is not None: # x = x @ self.proj if visualize is True: return x, all_attn_weights return x, left_attn_weights, right_attn_weights class CLIP(nn.Module): def __init__(self, embed_dim: int, # vision image_resolution: int, vision_layers: Union[Tuple[int, int, int, int], int], vision_width: int, vision_patch_size: int, # text context_length: int, vocab_size: int, transformer_width: int, transformer_heads: int, transformer_layers: int, # vision linear of patch linear_patch: str = '2d', intra_layers: int = 9, ): super().__init__() self.context_length = context_length vision_heads = vision_width // 64 self.visual = VisualTransformer( input_resolution=image_resolution, patch_size=vision_patch_size, width=vision_width, layers=vision_layers, heads=vision_heads, output_dim=embed_dim, linear_patch=linear_patch, intra_layers=intra_layers, ) self.transformer = Transformer( width=transformer_width, layers=transformer_layers, heads=transformer_heads, attn_mask=self.build_attention_mask ) self.vocab_size = vocab_size self.token_embedding = nn.Embedding(vocab_size, transformer_width) self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) self.ln_final = LayerNorm(transformer_width) self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) self.logit_scale = nn.Parameter(torch.ones([])) self.initialize_parameters() def initialize_parameters(self): nn.init.normal_(self.token_embedding.weight, std=0.02) nn.init.normal_(self.positional_embedding, std=0.01) proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) attn_std = self.transformer.width ** -0.5 fc_std = (2 * self.transformer.width) ** -0.5 for block in self.transformer.resblocks: nn.init.normal_(block.attn.in_proj_weight, std=attn_std) nn.init.normal_(block.attn.out_proj.weight, std=proj_std) nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) if self.text_projection is not None: nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) @staticmethod def get_config(pretrained_clip_name="ViT-B/32"): model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "ViT-B-32.pt") if pretrained_clip_name in _MODELS and pretrained_clip_name in _PT_NAME: model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), _PT_NAME[pretrained_clip_name]) if pretrained_clip_name in ["ViT-B/32", "ViT-B/16"] and os.path.exists(model_path): pass else: if pretrained_clip_name in _MODELS: model_path = _download(_MODELS[pretrained_clip_name]) elif os.path.isfile(pretrained_clip_name): model_path = pretrained_clip_name else: raise RuntimeError(f"Model {pretrained_clip_name} not found; available models = {available_models()}") try: # loading JIT archive model = torch.jit.load(model_path, map_location="cpu").eval() state_dict = model.state_dict() except RuntimeError: state_dict = torch.load(model_path, map_location="cpu") return state_dict def build_attention_mask(self, context_length): # lazily create causal attention mask, with full attention between the vision tokens # pytorch uses additive attention mask; fill with -inf mask = torch.zeros(context_length, context_length) mask.fill_(float("-inf")) mask.triu_(1) # zero out the lower diagonal return mask @property def dtype(self): return self.visual.conv1.weight.dtype def encode_image(self, image, left_gt_map, right_gt_map, return_hidden=False, video_frame=-1): hidden, left_map, right_map = self.visual(image.type(self.dtype), left_gt_map, right_gt_map, video_frame=video_frame) hidden = self.visual.ln_post(hidden) @ self.visual.proj cls_index = int(hidden.size(1) / 2) hidden2 = torch.cat([hidden[:, 0, :].unsqueeze(1), hidden[:, cls_index, :].unsqueeze(1)], 1) x = torch.mean(hidden2, 1) if return_hidden: return x, hidden2, left_map, right_map return x, left_map, right_map def encode_text(self, text, return_hidden=False): x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] pos_emd = self.positional_embedding[:x.size(1), :].type(self.dtype) x = x + pos_emd x = x.permute(1, 0, 2) # NLD -> LND x = self.transformer(x) x = x.permute(1, 0, 2) # LND -> NLD hidden = self.ln_final(x).type(self.dtype) @ self.text_projection # x.shape = [batch_size, n_ctx, transformer.width] # take features from the eot embedding (eot_token is the highest number in each sequence) x = hidden[torch.arange(hidden.shape[0]), text.argmax(dim=-1)] if return_hidden: return x, hidden return x def forward(self, image, text): image_features = self.encode_image(image) text_features = self.encode_text(text) # normalized features image_features = image_features / image_features.norm(dim=-1, keepdim=True) text_features = text_features / text_features.norm(dim=-1, keepdim=True) # cosine similarity as logits logit_scale = self.logit_scale.exp() logits_per_image = logit_scale * image_features @ text_features.t() logits_per_text = logit_scale * text_features @ image_features.t() # shape = [global_batch_size, global_batch_size] return logits_per_image, logits_per_text def convert_weights(model: nn.Module): """Convert applicable model parameters to fp16""" def _convert_weights_to_fp16(l): if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)): l.weight.data = l.weight.data.half() if l.bias is not None: l.bias.data = l.bias.data.half() if isinstance(l, nn.MultiheadAttention): for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: tensor = getattr(l, attr) if tensor is not None: tensor.data = tensor.data.half() for name in ["text_projection", "proj"]: if hasattr(l, name): attr = getattr(l, name) if attr is not None: attr.data = attr.data.half() model.apply(_convert_weights_to_fp16) def build_model(state_dict: dict): vit = "visual.proj" in state_dict if vit: vision_width = state_dict["visual.conv1.weight"].shape[0] vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) image_resolution = vision_patch_size * grid_size else: counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] vision_layers = tuple(counts) vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) vision_patch_size = None assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] image_resolution = output_width * 32 embed_dim = state_dict["text_projection"].shape[1] context_length = state_dict["positional_embedding"].shape[0] vocab_size = state_dict["token_embedding.weight"].shape[0] transformer_width = state_dict["ln_final.weight"].shape[0] transformer_heads = transformer_width // 64 transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) model = CLIP( embed_dim, image_resolution, vision_layers, vision_width, vision_patch_size, context_length, vocab_size, transformer_width, transformer_heads, transformer_layers ) for key in ["input_resolution", "context_length", "vocab_size"]: if key in state_dict: del state_dict[key] convert_weights(model) model.load_state_dict(state_dict) return model.eval()