# FLARE/dust3r/croco/models/transformer_utils.py
# 聂如
# Add design file
# 91126af
import os
from collections import OrderedDict
from typing import Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import nn, einsum
from torch.utils.checkpoint import checkpoint
from einops import rearrange, repeat
from inspect import isfunction
try:
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func, flash_attn_varlen_qkvpacked_func
from flash_attn.bert_padding import unpad_input, pad_input
except:
flash_attn_qkvpacked_func, flash_attn_func, flash_attn_varlen_qkvpacked_func = None, None, None
unpad_input, pad_input = None, None
from .x_transformer import AttentionLayers, BasicEncoder
import math
def _init_weights(module):
    """
    Initialise an ``nn.Linear`` layer: weights ~ N(0, 0.02), bias = 0.

    Any other module type is left untouched, so the function can be handed
    to ``nn.Module.apply``.
    """
    if not isinstance(module, nn.Linear):
        return
    torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
    bias = module.bias
    if bias is not None:
        torch.nn.init.zeros_(bias)
def exists(val):
    """
    Tell whether *val* is anything other than ``None``.
    """
    return not (val is None)
def default(val, d):
    """
    Return *val* when it is set, otherwise the fallback *d*.

    If *d* is a plain Python function (as recognised by
    ``inspect.isfunction``) it is called without arguments to produce the
    fallback; any other object is returned unchanged.
    """
    if not exists(val):
        fallback = d() if isfunction(d) else d
        return fallback
    return val
def zero_module(module):
    """
    Set every parameter of *module* to zero in place, then return *module*.
    """
    parameters = module.parameters()
    for weight in parameters:
        # detach() shares storage with the parameter, so the in-place zeroing
        # reaches the parameter itself without being recorded by autograd.
        weight.detach().zero_()
    return module
# Copy from CLIP GitHub
class LayerNorm(nn.LayerNorm):
    """LayerNorm that computes in float32 and hands back the input dtype (fp16-safe)."""

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        # Normalise in full precision, then cast back to whatever came in.
        normed = super().forward(x.type(torch.float32))
        return normed.type(input_dtype)
class QuickGELU(nn.Module):
    """Activation computing ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(1.702 * x)
        return x * gate
def modulate(x, shift, scale):
    """
    Apply an affine modulation ``x * (1 + scale) + shift``.

    ``shift`` and ``scale`` each get a new leading axis so they broadcast
    over the first dimension of ``x``.
    """
    # from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
    scale_b = scale.unsqueeze(0)
    shift_b = shift.unsqueeze(0)
    return x * (1 + scale_b) + shift_b
class MultiheadAttentionFlashV2(nn.Module):
    """Multi-head attention built on the ``flash_attn`` kernels.

    Queries, keys and values go through separate linear projections and are
    split into ``n_head`` heads of size ``embed_dim // n_head``. There is no
    output projection. Two code paths exist:

    * ``qkv_packed=True``: q/k/v are stacked and sent through
      ``flash_attn_varlen_qkvpacked_func`` with sequence boundaries derived
      from ``shift_group``.
    * otherwise: ``flash_attn_func`` is called, optionally after regrouping
      the sequence by ``shift_group`` and/or with a local ``window_size``.

    The ``flash_attn`` symbols are imported optionally at file level and are
    ``None`` when the package is missing, in which case ``forward`` fails.
    """

    def __init__(self, embed_dim, n_head, bias=False, shift_group=None, qkv_packed=False, window_size=None):
        super().__init__()
        self.head_dim = embed_dim// n_head
        self.embed_dim = embed_dim
        self.n_head = n_head
        self.to_q = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.to_k = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.to_v = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.shift_group = shift_group
        self.qkv_packed = qkv_packed
        self.window_size = window_size

    def forward(self, q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, need_weights=False, attn_mask=None):
        """Run attention and return a 1-tuple ``(output,)``.

        The first two axes of ``q``/``k``/``v`` are swapped on entry and the
        result is swapped back before returning, so the output layout matches
        the input layout. ``need_weights`` and ``attn_mask`` are accepted but
        never read. ``dropout_p``, ``softmax_scale`` and ``causal`` are only
        honoured on the non-packed path; the packed path passes fixed values.
        """
        # Swap the first two axes; the rearrange patterns below name the
        # result "b n ..." (presumably batch, tokens -- TODO confirm).
        q = q.permute(1, 0, 2)
        k = k.permute(1, 0, 2)
        v = v.permute(1, 0, 2)
        h = self.n_head
        q = self.to_q(q)
        k = self.to_k(k)
        v = self.to_v(v)
        # Split the channel axis into heads: (b, n, h*d) -> (b, n, h, d).
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q, k, v))
        # print(q.dtype, k.dtype, v.dtype)
        if self.qkv_packed:
            bsz, q_len, heads, head_dim = q.shape
            group_size = self.shift_group
            nheads = self.n_head
            qkv = torch.stack([q,k,v], dim=2)
            # Split the heads into two halves and fold the halves into the
            # batch axis: (bsz, q_len, 3, nheads, d) -> (bsz*2, q_len, 3, nheads/2, d).
            qkv = qkv.reshape(bsz, q_len, 3, 2, nheads // 2, self.head_dim).permute(0, 3, 1, 2, 4, 5).reshape(bsz * 2,
                                                                                                          q_len, 3,
                                                                                                          nheads // 2,
                                                                                                          self.head_dim)
            x = rearrange(qkv, "b s three h d -> b s (three h d)")
            # All-ones mask: every position is treated as valid.
            key_padding_mask = torch.ones(x.shape[0], x.shape[1], device=x.device, dtype=x.dtype)
            x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
            # Build two sets of group boundaries: one at multiples of
            # group_size and one offset by group_size // 2.
            cu_q_len_tmp = torch.arange(0, max_s, group_size, device=key_padding_mask.device, dtype=cu_q_lens.dtype)
            cu_q_len_tmp2 = cu_q_len_tmp + group_size // 2
            # Out-of-range boundaries are set to the dtype minimum so the
            # ">= 0" filter below drops them.
            cu_q_len_tmp2[cu_q_len_tmp2 >= max_s] = torch.iinfo(cu_q_len_tmp2.dtype).min
            cu_q_len_tmp = torch.stack([cu_q_len_tmp, cu_q_len_tmp2]).repeat(bsz, 1) + cu_q_lens[:-1].unsqueeze(-1)
            cu_q_lens = torch.cat([cu_q_len_tmp, cu_q_lens[1:].unsqueeze(-1)], dim=-1).view(-1)
            cu_q_lens = cu_q_lens[cu_q_lens >= 0]
            x_unpad = rearrange(
                x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads // 2
            )
            # Positional arguments: (qkv, cumulative lengths, group_size, 0.0);
            # the dropout/scale/causal arguments of forward() are not used here.
            output_unpad = flash_attn_varlen_qkvpacked_func(
                x_unpad, cu_q_lens, group_size, 0.0, softmax_scale=None, causal=False,
            )
            output = rearrange(
                pad_input(
                    rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz * 2, q_len
                ),
                "b s (h d) -> b s h d",
                h=nheads // 2,
            )
            # Undo the head-half folding: (bsz*2, q_len, nheads/2, d) -> (bsz, q_len, nheads, d).
            r_out = output.reshape(bsz, 2, q_len, nheads // 2, self.head_dim).transpose(1, 2).reshape(bsz, q_len, nheads,
                                                                                                      self.head_dim)
        else:
            if self.shift_group is not None:
                bsz, q_len, heads, head_dim = q.shape
                assert q_len % self.shift_group == 0

                def shift(qkv, bsz, q_len, group_size, num_heads, head_dim):
                    # NOTE(review): qkv arrives as (b, n, h, d) here, so the slice
                    # on dim 1 selects tokens and the roll on dim 2 moves heads.
                    # This looks like it was written for a (b, h, n, d) layout --
                    # confirm the intended axes. The slice assignment also
                    # modifies the input tensor in place.
                    qkv[:, num_heads // 2:] = qkv[:, num_heads // 2:].roll(-group_size // 2, dims=2)
                    qkv = qkv.transpose(1, 2).reshape(bsz * (q_len // group_size), group_size, num_heads, head_dim).transpose(1, 2)
                    return qkv

                q = shift(q, bsz, q_len, self.shift_group, h, self.head_dim)
                k = shift(k, bsz, q_len, self.shift_group, h, self.head_dim)
                v = shift(v, bsz, q_len, self.shift_group, h, self.head_dim)
            if self.window_size:
                # Symmetric local window of window_size // 2 on each side.
                out = flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=softmax_scale, causal=causal, window_size=(self.window_size // 2, self.window_size // 2))
            else:
                out = flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=softmax_scale, causal=causal)
            if self.shift_group is not None:
                out = out.transpose(1, 2).contiguous()
                out = rearrange(out, '(b l) g h d -> b (l g) h d', l=q_len // self.shift_group)
                r_out = out.clone()
                # NOTE(review): the forward shift rolls by group_size // 2 but this
                # rolls back by h // 2 (half the head count) -- confirm whether
                # self.shift_group // 2 was intended.
                r_out[:, :, h//2:] = r_out[:, :, h//2:].roll(h//2, dims=1)
            else:
                r_out = out
        # Merge the heads back into the channel axis and restore the input layout.
        r_out = rearrange(r_out, 'b n h d -> b n (h d)')
        r_out = r_out.permute(1, 0, 2)
        return (r_out,)
class PSUpsamplerBlock(nn.Module):
    """Upsample a square token grid with a linear layer plus pixel shuffle.

    A token sequence ``(batch, L, d_model)`` is projected to
    ``d_model_out * scale_factor**2`` channels, laid out as a
    ``sqrt(L) x sqrt(L)`` grid, pixel-shuffled by ``scale_factor`` and
    flattened again, giving ``(batch, L * scale_factor**2, d_model_out)``.
    """

    def __init__(self, d_model: int, d_model_out: int, scale_factor: int):
        super().__init__()
        self.scale_factor = scale_factor
        self.d_model_out = d_model_out
        expanded_channels = d_model_out * (scale_factor ** 2)
        self.residual_fc = nn.Linear(d_model, expanded_channels)
        self.pixelshuffle = nn.PixelShuffle(scale_factor)

    def forward(self, x: torch.Tensor):
        tokens = self.residual_fc(x)
        batch, num_tokens, channels = tokens.shape
        # The token count is taken to be a perfect square (side * side).
        side = int(np.sqrt(num_tokens))
        grid = tokens.permute(0, 2, 1).reshape(batch, channels, side, side)
        grid = self.pixelshuffle(grid)
        up_side = side * self.scale_factor
        flat = grid.reshape(batch, self.d_model_out, up_side * up_side)
        return flat.permute(0, 2, 1)
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: self-attention, optional cross-attention, MLP.

    Args:
        d_model: channel width of the tokens.
        n_head: number of attention heads.
        attn_mask: optional attention mask; it is cropped to the sequence
            length and moved to the input's dtype/device on every call.
        modulate_feature_size: when given, two modulation heads
            (LayerNorm -> activation -> Linear to ``3 * d_model``) produce
            shift/scale/gate terms for the attention and MLP branches.
        modulate_act_type: ``'gelu'`` (QuickGELU) or ``'silu'``.
        cross_att: when truthy, a ``CrossAttention`` layer and its LayerNorm
            are inserted between self-attention and the MLP.
        flash_v2: use ``MultiheadAttentionFlashV2`` instead of
            ``nn.MultiheadAttention``.
        qkv_packed, shift_group: forwarded to ``MultiheadAttentionFlashV2``.
        window_size: when truthy, self-attention runs inside non-overlapping
            windows of this many tokens; odd-indexed layers shift the
            sequence by half a window first. ``None``, ``False`` and ``0``
            all disable windowing.

    Inputs to ``forward`` are sequence-first: ``attention`` reads the length
    from ``x.shape[0]`` and swaps the first two axes before windowing.
    """

    def __init__(self, d_model: int,
                 n_head: int,
                 attn_mask: torch.Tensor = None,
                 modulate_feature_size: int = None,
                 modulate_act_type: str = 'gelu',
                 cross_att: bool = None,
                 flash_v2: bool = None,
                 qkv_packed: bool = None,
                 shift_group: int = None,
                 window_size: int = None,):
        super().__init__()
        self.flash_v2 = flash_v2
        self.window_size = window_size
        if self.flash_v2:
            self.attn = MultiheadAttentionFlashV2(d_model, n_head, shift_group=shift_group, qkv_packed=qkv_packed, window_size=window_size)
        else:
            self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

        if modulate_feature_size is not None:
            act_dict = {'gelu': QuickGELU,
                        'silu': nn.SiLU}
            self.modulation_fn = nn.Sequential(
                LayerNorm(modulate_feature_size),
                act_dict[modulate_act_type](),
                nn.Linear(modulate_feature_size, 3 * d_model, bias=True)
            )
            self.mlp_modulation_fn = nn.Sequential(
                LayerNorm(modulate_feature_size),
                act_dict[modulate_act_type](),
                nn.Linear(modulate_feature_size, 3 * d_model, bias=True)
            )
        else:
            self.modulation_fn = None
            self.mlp_modulation_fn = None

        self.cross_att = cross_att
        if self.cross_att:
            # NOTE(review): CrossAttention is not defined or imported in the
            # visible part of this file -- confirm where it comes from.
            self.cross_att = CrossAttention(query_dim=d_model, context_dim=d_model,
                                            heads=n_head, dim_head=int(d_model//n_head), dropout=0)
            self.ln_1_5 = LayerNorm(d_model)

    def attention(self, x: torch.Tensor, index):
        """Self-attention over ``x``, optionally inside (shifted) windows.

        Args:
            x: sequence-first tokens.
            index: layer index; odd values select the shifted-window variant.
                ``None`` is treated as an unshifted layer.

        Returns:
            Tensor with the same layout as ``x``.
        """
        if self.attn_mask is not None:
            self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device)
            length = x.shape[0]
            attn_mask = self.attn_mask[:length, :length]
        else:
            attn_mask = None

        # Truthiness rather than "is not None": callers in this file pass
        # window_size=False to mean "no windowing", which previously reached
        # the modulo below and raised ZeroDivisionError.
        if not self.window_size:
            return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]

        window = self.window_size
        half_window = window // 2
        shifted = index is not None and index % 2 != 0

        x = x.permute(1, 0, 2)  # (l, b, c) -> (b, l, c)
        b, l, c = x.shape
        assert l % window == 0
        if shifted:
            x = torch.roll(x, shifts=half_window, dims=1)
        # Fold the windows into the batch axis and attend within each window.
        x = rearrange(x, 'b (p w) c -> (b p) w c', w=window)
        x = x.permute(1, 0, 2)  # w, bp, c
        x = self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
        x = x.permute(1, 0, 2)  # bp, w, c
        x = rearrange(x, '(b l) w c -> b (l w) c', l=l // window, w=window)
        if shifted:
            # Exact inverse of the roll above. The previous "-window//2"
            # parsed as "(-window)//2", which is one step too far for odd
            # window sizes.
            x = torch.roll(x, shifts=-half_window, dims=1)
        return x.permute(1, 0, 2)  # (b, l, c) -> (l, b, c)

    def forward(self, x: torch.Tensor, modulation: torch.Tensor = None, context: torch.Tensor = None, index=None):
        """Apply the block and return the updated tokens.

        Args:
            x: sequence-first tokens.
            modulation: features fed to the modulation heads; only read when
                the block was built with ``modulate_feature_size``.
            context: passed to the cross-attention layer when present.
            index: layer index forwarded to :meth:`attention`.
        """
        # self attention block
        y = self.ln_1(x)
        if self.modulation_fn is not None:
            shift, scale, gate = self.modulation_fn(modulation).chunk(3, dim=1)
            y = modulate(y, shift, scale)
        y = self.attention(y, index)
        # If we have modulation func for mlp as well, we will just use the gate for the attention
        if self.modulation_fn is not None and self.mlp_modulation_fn is not None:
            y = y * gate.unsqueeze(0)
        x = x + y

        # cross attention block
        if self.cross_att:
            y = self.cross_att(self.ln_1_5(x), context=context)
            x = x + y

        # mlp block
        y = self.ln_2(x)
        if self.mlp_modulation_fn is not None:
            shift, scale, gate = self.mlp_modulation_fn(modulation).chunk(3, dim=1)
            y = modulate(y, shift, scale)
        y = self.mlp(y)
        # The gate is the one produced most recently: the MLP head's gate when
        # it exists, otherwise the attention head's gate.
        if self.modulation_fn is not None:
            y = y * gate.unsqueeze(0)
        x = x + y
        return x
class Transformer(nn.Module):
    """Stack of ``ResidualAttentionBlock`` layers.

    The last ``cross_att_layers`` blocks are built with cross-attention
    enabled. ``forward`` passes each block its layer index and can
    optionally add per-layer residuals and return every layer's output.
    """

    def __init__(self,
                 width: int,
                 layers: int,
                 heads: int,
                 attn_mask: torch.Tensor = None,
                 modulate_feature_size: int = None,
                 modulate_act_type: str = 'gelu',
                 cross_att_layers: int = 0,
                 return_all_layers=False,
                 flash_v2=True,
                 qkv_packed=False,
                 shift_group=None,
                 window_size=None):
        super().__init__()
        self.width = width
        self.layers = layers
        # Blocks at or beyond this index get a cross-attention layer.
        first_cross_att_layer = layers - cross_att_layers
        self.resblocks = nn.Sequential(*[
            ResidualAttentionBlock(
                width,
                heads,
                attn_mask,
                modulate_feature_size=modulate_feature_size,
                modulate_act_type=modulate_act_type,
                cross_att=layer_idx >= first_cross_att_layer,
                flash_v2=flash_v2,
                qkv_packed=qkv_packed,
                shift_group=shift_group,
                window_size=window_size,
            )
            for layer_idx in range(layers)
        ])
        self.grad_checkpointing = False
        self.return_all_layers = return_all_layers
        self.flash_v2 = flash_v2

    def set_grad_checkpointing(self, flag=True):
        """Switch activation checkpointing of the blocks on or off."""
        self.grad_checkpointing = flag

    def forward(self,
                x: torch.Tensor,
                modulation: torch.Tensor = None,
                context: torch.Tensor = None,
                additional_residuals = None):
        """Run all blocks in order.

        ``additional_residuals``, when given, must hold one tensor per layer;
        each is added in place to the leading ``add_res.shape[1]`` entries
        along dim 1 of that layer's output. Returns the list of per-layer
        outputs when ``return_all_layers`` is set, else the final output.
        """
        use_residuals = additional_residuals is not None
        if use_residuals:
            assert len(additional_residuals) == self.layers

        layer_outputs = []
        for res_i, module in enumerate(self.resblocks):
            if not self.grad_checkpointing:
                x = module(x, modulation, context, res_i)
            else:
                x = checkpoint(module, x, modulation, context, res_i)
            if use_residuals:
                add_res = additional_residuals[res_i]
                span = add_res.shape[1]
                x[:, :span] = x[:, :span] + add_res
            layer_outputs.append(x)

        return layer_outputs if self.return_all_layers else x
class GaussianUpsampler(nn.Module):
    """Token upsampler made of ``log2(up_ratio)`` stages.

    Each stage is a ``PSUpsamplerBlock`` (x2 per side) followed by a
    two-layer ``Transformer``; a final ``LayerNorm`` is applied at the end.

    Args:
        width: channel width of the incoming tokens.
        up_ratio: total per-side upsampling factor; ``int(log2(up_ratio))``
            stages are built.
        ch_decay: integer divisor applied to the width at every stage. The
            width never drops below 64 and is rounded down to a multiple
            of 64.
        low_channels: stored on the module; not otherwise used here.
        window_size: window length forwarded to the stage transformers. Any
            falsy value (including the default ``False``) disables windowing.
        with_additional_inputs: when set, ``forward`` concatenates one extra
            feature tensor of width ``width`` to the tokens before each stage.

    Attributes:
        out_channels: channel width of the returned tokens.
    """

    def __init__(self, width,
                 up_ratio,
                 ch_decay=1,
                 low_channels=64,
                 window_size=False,
                 with_additional_inputs=False):
        super().__init__()
        self.up_ratio = up_ratio
        self.low_channels = low_channels
        self.window_size = window_size
        self.base_width = width
        self.with_additional_inputs = with_additional_inputs
        # Forward a falsy window size (the default False) as None so that
        # windowed attention stays disabled downstream instead of computing
        # "length % False".
        block_window_size = window_size if window_size else None
        for res_log2 in range(int(np.log2(up_ratio))):
            in_width = width
            width = max(width // ch_decay, 64)
            heads = int(width / 64)
            width = heads * 64
            if self.with_additional_inputs:
                # The extra features are concatenated on the channel axis.
                in_width = in_width + self.base_width
            self.add_module(f'upsampler_{res_log2}', PSUpsamplerBlock(in_width, width, 2))
            encoder = Transformer(width, 2, heads,
                                  modulate_feature_size=None,
                                  modulate_act_type=None,
                                  cross_att_layers=0,
                                  return_all_layers=False,
                                  flash_v2=False,
                                  qkv_packed=False,
                                  shift_group=False,
                                  window_size=block_window_size)
            self.add_module(f'attention_{res_log2}', encoder)
        self.out_channels = width
        self.ln_post = LayerNorm(width)

    def forward(self, x, additional_inputs=None):
        """Upsample ``x`` through every stage and normalise the result.

        Args:
            x: batch-first tokens; dim 1 is the token axis and dim 2 the
                channel axis.
            additional_inputs: one tensor per stage, required when the module
                was built with ``with_additional_inputs``. Each is repeated
                along dim 1 to match the current token count.
        """
        num_stages = int(np.log2(self.up_ratio))
        if self.with_additional_inputs:
            assert len(additional_inputs) == num_stages
        for res_log2 in range(num_stages):
            if self.with_additional_inputs:
                add_input = additional_inputs[res_log2]
                scale = x.shape[1] // add_input.shape[1]
                add_input = add_input.repeat_interleave(scale, 1)
                x = torch.cat([x, add_input], dim=2)
            x = getattr(self, f'upsampler_{res_log2}')(x)
            # The transformer works sequence-first.
            x = x.permute(1, 0, 2)
            x = getattr(self, f'attention_{res_log2}')(x)
            x = x.permute(1, 0, 2)
        x = self.ln_post(x)
        return x
class HyperGaussianUpsampler(nn.Module):
    """Upsampler with one attention stage followed by convolutional stages.

    Stage 0 is a ``PSUpsamplerBlock`` plus a two-layer ``Transformer`` on the
    token sequence; its output is reshaped to a square feature map. Every
    later stage is an ``UpsamplerLayers_conv`` module. A ``LayerNorm`` over
    ``(resolution, resolution, width)`` is applied to the final map.

    Args:
        width: channel width of the incoming tokens; it is rounded down to a
            multiple of 64 for the stages.
        resolution: side length of the incoming token grid. The stored
            ``self.resolution`` is doubled once per stage during construction.
        up_ratio: total per-side upsampling factor; ``int(log2(up_ratio))``
            stages are built.
        ch_decay: accepted for signature compatibility; not used.
        window_size: window length forwarded to the stage-0 transformer. Any
            falsy value (including the default ``False``) disables windowing.
        with_additional_inputs: when set, ``forward`` concatenates an extra
            feature tensor of width ``width`` before stage 0.
        upsampler_kwargs: extra keyword arguments for every
            ``UpsamplerLayers_conv``; ``None`` or ``{}`` means none.
    """

    def __init__(self, width,
                 resolution,
                 up_ratio,
                 ch_decay=1,
                 window_size=False,
                 with_additional_inputs=False,
                 upsampler_kwargs=None):
        super().__init__()
        self.up_ratio = up_ratio
        self.window_size = window_size
        self.base_width = width
        self.with_additional_inputs = with_additional_inputs
        self.resolution = resolution
        # A None default replaces the former shared mutable "{}".
        conv_kwargs = {} if upsampler_kwargs is None else upsampler_kwargs
        # Forward a falsy window size (the default False) as None so that
        # windowed attention stays disabled downstream instead of computing
        # "length % False".
        block_window_size = window_size if window_size else None
        for res_log2 in range(int(np.log2(up_ratio))):
            if res_log2 == 0:
                in_width = width
                heads = int(width / 64)
                width = heads * 64
                if self.with_additional_inputs:
                    # The extra features are concatenated on the channel axis.
                    in_width = in_width + self.base_width
                self.add_module(f'upsampler_{res_log2}', PSUpsamplerBlock(in_width, width, 2))
                encoder = Transformer(width, 2, heads,
                                      modulate_feature_size=None,
                                      modulate_act_type=None,
                                      cross_att_layers=0,
                                      return_all_layers=False,
                                      flash_v2=False,
                                      qkv_packed=False,
                                      shift_group=False,
                                      window_size=block_window_size)
                self.add_module(f'attention_{res_log2}', encoder)
                self.resolution = self.resolution * 2
            else:
                self.resolution = self.resolution * 2
                # NOTE(review): UpsamplerLayers_conv is not defined or imported
                # in the visible part of this file -- confirm where it comes from.
                self.add_module(f'upsample_{res_log2}',
                                UpsamplerLayers_conv(in_channels=width,
                                                     out_channels=width,
                                                     resolution=self.resolution,
                                                     conv_block_type = 'convnext',
                                                     **conv_kwargs))
        self.out_channels = width
        # self.ln_post = LayerNorm(width)
        self.ln_post = LayerNorm([self.resolution, self.resolution, width])

    def forward(self, x, additional_inputs=None):
        """Upsample ``x`` and return a channel-last, layer-normalised map.

        Args:
            x: batch-first tokens; dim 1 is the token axis (a perfect square
                is expected) and dim 2 the channel axis.
            additional_inputs: required when the module was built with
                ``with_additional_inputs``. Its length must equal the number
                of stages, although only the first entry is read.
        """
        num_stages = int(np.log2(self.up_ratio))
        if self.with_additional_inputs:
            assert len(additional_inputs) == num_stages
        for res_log2 in range(num_stages):
            if res_log2 == 0:
                if self.with_additional_inputs:
                    add_input = additional_inputs[res_log2]
                    scale = x.shape[1] // add_input.shape[1]
                    add_input = add_input.repeat_interleave(scale, 1)
                    x = torch.cat([x, add_input], dim=2)
                x = getattr(self, f'upsampler_{res_log2}')(x)
                # The transformer works sequence-first.
                x = x.permute(1, 0, 2)
                x = getattr(self, f'attention_{res_log2}')(x)
                x = x.permute(1, 0, 2)
                # Tokens -> square, channel-first feature map for the conv stages.
                side = int(math.sqrt(x.shape[1]))
                x = x.reshape(x.shape[0], side, side, -1).permute(0, 3, 1, 2)
            else:
                x = getattr(self, f'upsample_{res_log2}')(x)
        x = self.ln_post(x.permute(0, 2, 3, 1))
        return x
class VisionTransformer(nn.Module):
    """Multi-view ViT: conv patch embedding, positional term, transformer stack.

    ``forward`` patchifies every view with ``conv1``, applies ``ln_pre``,
    adds the output of ``rope`` evaluated on a tensor of ones, optionally
    runs a per-view ``encoder``, then runs ``transformer`` and (unless
    disabled) ``ln_post``.

    NOTE(review): block nesting in this class was reconstructed from a copy
    of the file whose indentation was lost; places where the nesting is
    ambiguous are marked below.
    """

    def __init__(self,
                 # transformer params
                 in_channels: int,
                 patch_size: int,
                 width: int,
                 layers: int,
                 heads: int,
                 weight: str = None,
                 encode_layers: int = 0,
                 shift_group = False,
                 flash_v2 = False,
                 qkv_packed = False,
                 window_size = False,
                 use_pe = False,
                 # modulation params
                 modulate_feature_size: int = None,
                 modulate_act_type: str = 'gelu',
                 # camera condition
                 camera_condition: str = 'plucker',
                 # init params
                 disable_dino=False,
                 error_weight_init_mode='mean',
                 # other params
                 add_zero_conv=False,
                 return_all_layers=False,
                 disable_post_ln=False,
                 rope=None):
        super().__init__()
        self.patch_size = patch_size
        # Non-overlapping patch embedding (kernel == stride == patch_size).
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
        self.use_pe = use_pe
        # Callable invoked as rope(tensor, pos) in forward(); the default None
        # makes forward() fail.
        self.rope = rope
        self.disable_dino = disable_dino
        # The learned class/positional embeddings below are disabled, so this
        # module has no ``positional_embedding`` attribute.
        # if not self.disable_dino:
        # scale = width ** -0.5
        # self.class_embedding = nn.Parameter(scale * torch.randn(width))
        # self.positional_embedding = nn.Parameter(scale * torch.randn((input_res// patch_size) ** 2 + 1, width))
        # else:
        # if self.use_pe:
        # self.positional_embedding = nn.Parameter(torch.zeros(1, (input_res// patch_size) ** 2, width))
        # nn.init.trunc_normal_(self.positional_embedding, std=0.02)
        self.ln_pre = LayerNorm(width)
        self.add_zero_conv = add_zero_conv
        self.return_all_layers = return_all_layers
        self.disable_post_ln = disable_post_ln
        self.flash_v2 = flash_v2
        self.qkv_packed = qkv_packed
        self.camera_condition = camera_condition
        # With 'plucker' conditioning the modulation tensor is concatenated to
        # the image channels in forward(), so modulation heads must be off.
        if self.camera_condition == 'plucker': assert modulate_feature_size is None
        if self.add_zero_conv:
            assert self.return_all_layers
            # One zero-initialised 1x1 conv per layer output.
            self.zero_convs = nn.ModuleList([zero_module(nn.Conv1d(in_channels=width, out_channels=width, kernel_size=1, stride=1, bias=True)) for _ in range(layers)])
        self.encode_layers = encode_layers
        if self.encode_layers > 0:
            self.encoder = Transformer(width, encode_layers, heads,
                                       modulate_feature_size=modulate_feature_size,
                                       modulate_act_type=modulate_act_type,
                                       cross_att_layers=0,
                                       return_all_layers=return_all_layers,
                                       flash_v2=flash_v2,
                                       qkv_packed=qkv_packed,
                                       shift_group=shift_group,
                                       window_size=window_size)
        # The remaining layers form the main transformer.
        self.transformer = Transformer(width, layers-encode_layers, heads,
                                       modulate_feature_size=modulate_feature_size,
                                       modulate_act_type=modulate_act_type,
                                       cross_att_layers=0,
                                       return_all_layers=return_all_layers,
                                       flash_v2=flash_v2,
                                       qkv_packed=qkv_packed,
                                       shift_group=shift_group,
                                       window_size=window_size)
        if not self.disable_post_ln:
            self.ln_post = LayerNorm(width)
        # NOTE(review): the ``else`` below is paired with ``if weight is not
        # None`` here; it may originally have belonged to ``if not
        # self.disable_dino`` -- confirm against the upstream file.
        if weight is not None:
            if not self.disable_dino:
                if "clip" in weight:
                    raise NotImplementedError()
                elif weight.startswith("vit_b_16"):
                    load_timm_to_clip(self, config_name=weight, init_mode=error_weight_init_mode)
                elif weight.startswith("vit_b_8"):
                    load_timm_to_clip(self, config_name=weight, init_mode=error_weight_init_mode)
                else:
                    raise NotImplementedError()
        else:
            self.apply(_init_weights)
        # Init the weight and bias of modulation_fn to zero
        # (only the main transformer's blocks are visited, not the encoder's).
        if modulate_feature_size != 0:
            for block in self.transformer.resblocks:
                if block.modulation_fn is not None:
                    block.modulation_fn[2].weight.data.zero_()
                    block.modulation_fn[2].bias.data.zero_()
                if block.mlp_modulation_fn is not None:
                    block.mlp_modulation_fn[2].weight.data.zero_()
                    block.mlp_modulation_fn[2].bias.data.zero_()
        # Zero the output projection of every cross-attention layer.
        for block in self.transformer.resblocks:
            if block.cross_att:
                zero_module(block.cross_att.to_out)

    def set_grad_checkpointing(self, flag=True):
        """Toggle activation checkpointing on ``self.transformer`` (the encoder is not changed)."""
        self.transformer.set_grad_checkpointing(flag)

    def forward(self,
                x: torch.Tensor,
                modulation: torch.Tensor = None,
                context: torch.Tensor = None,
                additional_residuals=None,
                abla_crossview=False,
                pos=None):
        """Encode a batch of multi-view images.

        Args:
            x: 5-D tensor unpacked as ``b v c h w``.
            modulation: with ``camera_condition == 'plucker'`` a 5-D
                ``b v c h w`` tensor concatenated to ``x`` on the channel
                axis; otherwise forwarded to the transformer blocks.
            context: forwarded to the transformer blocks.
            additional_residuals: forwarded to the transformer(s).
            abla_crossview: when set (and ``disable_dino`` is on), views are
                kept as separate batch items instead of being concatenated
                along the token axis.
            pos: 4-D tensor unpacked as ``b v c h`` and handed to ``rope``.
                It is rearranged unconditionally, so ``None`` is not usable.

        Returns:
            A ``b v n d`` tensor, or a list of per-layer tensors (plus the
            post-norm output) when ``return_all_layers`` is set.
        """
        # image tokenization
        bs, vs = x.shape[:2]
        x = rearrange(x, 'b v c h w -> (b v) c h w')
        pos = rearrange(pos, 'b v c h -> (b v) c h')
        if self.camera_condition == 'plucker' and modulation is not None:
            modulation = rearrange(modulation, 'b v c h w -> (b v) c h w')
            x = torch.cat([x, modulation], dim=1)
            # Consumed as input channels; nothing is left for the blocks.
            modulation = None
        x = self.conv1(x) # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
        # pre-normalization
        x = self.ln_pre(x)
        B, N, C = x.shape
        # Split channels into chunks of 64 (hard-coded; presumably the head
        # dimension -- TODO confirm) and move the chunk axis before tokens.
        x = x.reshape(B, N, -1, 64)
        x = x.permute(0, 2, 1, 3)
        # print('pre x mean: ', x.mean().item())
        # print('pre x var: ', x.var().item())
        # rope is evaluated on a tensor of ones and its output is added to x.
        x = x + self.rope(torch.ones_like(x).to(x), pos)
        # print('x mean: ', x.mean().item())
        # print('x var: ', x.var().item())
        x = x.permute(0, 2, 1, 3)
        x = x.reshape(B, N, -1)
        # use encode to extract features
        if self.encode_layers > 0:
            x = x.permute(1, 0, 2) # NLD -> LND
            x = self.encoder(x, modulation, context, additional_residuals=additional_residuals)
            x = x.permute(1, 0, 2) # LND -> NLD
        if not self.disable_dino:
            x = x.permute(1, 0, 2) # NLD -> LND
        else:
            if not abla_crossview:
                # flatten x along the video dimension
                x = rearrange(x, '(b v) n d -> b (v n) d', v=vs)
                # print(x.shape)
                x = x.permute(1, 0, 2) # NLD -> LND
            else:
                x = x.permute(1, 0, 2)
        x = self.transformer(x, modulation, context, additional_residuals=additional_residuals)
        if self.add_zero_conv:
            assert isinstance(x, (list, tuple))
            assert len(x) == len(self.zero_convs)
            new_x = []
            for sub_x, sub_zero_conv in zip(x, self.zero_convs):
                # Conv1d wants (N, D, L); convert from and back to LND.
                sub_x_out = sub_zero_conv(sub_x.permute(1, 2, 0))
                new_x.append(sub_x_out.permute(2, 0, 1))
            x = new_x
        if self.return_all_layers:
            assert isinstance(x, (list, tuple))
            # NOTE(review): as nested here, x_final is only bound when the
            # post-LN is enabled, yet it is appended unconditionally below --
            # confirm the intended nesting.
            if not self.disable_post_ln:
                x_final = x[-1].permute(1, 0, 2) # LND -> NLD
                x_final = self.ln_post(x_final)
                x_final = rearrange(x_final, 'b (v n) d -> b v n d', v=vs)
            x = [s.permute(1, 0, 2) for s in x]
            x.append(x_final)
            return x
        if not self.disable_post_ln:
            x = x.permute(1, 0, 2) # LND -> NLD
            x = self.ln_post(x)
        if not self.disable_dino:
            x = rearrange(x, '(b v) n d -> b v n d', b=bs, v=vs)
        else:
            if not abla_crossview:
                # reshape x back to video dimension
                x = rearrange(x, 'b (v n) d -> b v n d', v=vs)
            else:
                x = rearrange(x, '(b v) n d -> b v n d', v=vs)
        return x

    def extra_repr(self) -> str:
        # No extra representation; this returns None despite the annotation.
        pass
class VisionTransformer_fusion(nn.Module):
    """Variant of ``VisionTransformer`` that takes already-tokenised views.

    There is no patch-embedding convolution: ``forward`` receives tokens,
    adds the output of ``rope`` evaluated on a tensor of ones, optionally
    runs a per-view ``encoder``, then runs ``transformer`` and (unless
    disabled) ``ln_post``. ``ln_pre`` is created but not applied in
    ``forward``.

    NOTE(review): block nesting in this class was reconstructed from a copy
    of the file whose indentation was lost; places where the nesting is
    ambiguous are marked below.
    """

    def __init__(self,
                 # transformer params
                 in_channels: int,
                 patch_size: int,
                 width: int,
                 layers: int,
                 heads: int,
                 weight: str = None,
                 encode_layers: int = 0,
                 shift_group = False,
                 flash_v2 = False,
                 qkv_packed = False,
                 window_size = False,
                 use_pe = False,
                 # modulation params
                 modulate_feature_size: int = None,
                 modulate_act_type: str = 'gelu',
                 # camera condition
                 camera_condition: str = 'plucker',
                 # init params
                 disable_dino=False,
                 error_weight_init_mode='mean',
                 # other params
                 add_zero_conv=False,
                 return_all_layers=False,
                 disable_post_ln=False,
                 rope=None):
        super().__init__()
        self.patch_size = patch_size
        self.use_pe = use_pe
        # Callable invoked as rope(tensor, pos) in forward(); the default None
        # makes forward() fail.
        self.rope = rope
        self.disable_dino = disable_dino
        # The learned class/positional embeddings below are disabled, so this
        # module has no ``positional_embedding`` attribute.
        # if not self.disable_dino:
        # scale = width ** -0.5
        # self.class_embedding = nn.Parameter(scale * torch.randn(width))
        # self.positional_embedding = nn.Parameter(scale * torch.randn((input_res// patch_size) ** 2 + 1, width))
        # else:
        # if self.use_pe:
        # self.positional_embedding = nn.Parameter(torch.zeros(1, (input_res// patch_size) ** 2, width))
        # nn.init.trunc_normal_(self.positional_embedding, std=0.02)
        self.ln_pre = LayerNorm(width)
        self.add_zero_conv = add_zero_conv
        self.return_all_layers = return_all_layers
        self.disable_post_ln = disable_post_ln
        self.flash_v2 = flash_v2
        self.qkv_packed = qkv_packed
        self.camera_condition = camera_condition
        if self.camera_condition == 'plucker': assert modulate_feature_size is None
        if self.add_zero_conv:
            assert self.return_all_layers
            # One zero-initialised 1x1 conv per layer output.
            self.zero_convs = nn.ModuleList([zero_module(nn.Conv1d(in_channels=width, out_channels=width, kernel_size=1, stride=1, bias=True)) for _ in range(layers)])
        self.encode_layers = encode_layers
        if self.encode_layers > 0:
            self.encoder = Transformer(width, encode_layers, heads,
                                       modulate_feature_size=modulate_feature_size,
                                       modulate_act_type=modulate_act_type,
                                       cross_att_layers=0,
                                       return_all_layers=return_all_layers,
                                       flash_v2=flash_v2,
                                       qkv_packed=qkv_packed,
                                       shift_group=shift_group,
                                       window_size=window_size)
        # The remaining layers form the main transformer.
        self.transformer = Transformer(width, layers-encode_layers, heads,
                                       modulate_feature_size=modulate_feature_size,
                                       modulate_act_type=modulate_act_type,
                                       cross_att_layers=0,
                                       return_all_layers=return_all_layers,
                                       flash_v2=flash_v2,
                                       qkv_packed=qkv_packed,
                                       shift_group=shift_group,
                                       window_size=window_size)
        if not self.disable_post_ln:
            self.ln_post = LayerNorm(width)
        # NOTE(review): the ``else`` below is paired with ``if weight is not
        # None`` here; it may originally have belonged to ``if not
        # self.disable_dino`` -- confirm against the upstream file.
        # NOTE(review): load_timm_to_clip assigns ``module.conv1.bias``, but
        # this class creates no ``conv1`` -- confirm this path is ever taken.
        if weight is not None:
            if not self.disable_dino:
                if "clip" in weight:
                    raise NotImplementedError()
                elif weight.startswith("vit_b_16"):
                    load_timm_to_clip(self, config_name=weight, init_mode=error_weight_init_mode)
                elif weight.startswith("vit_b_8"):
                    load_timm_to_clip(self, config_name=weight, init_mode=error_weight_init_mode)
                else:
                    raise NotImplementedError()
        else:
            self.apply(_init_weights)
        # Init the weight and bias of modulation_fn to zero
        # (only the main transformer's blocks are visited, not the encoder's).
        if modulate_feature_size != 0:
            for block in self.transformer.resblocks:
                if block.modulation_fn is not None:
                    block.modulation_fn[2].weight.data.zero_()
                    block.modulation_fn[2].bias.data.zero_()
                if block.mlp_modulation_fn is not None:
                    block.mlp_modulation_fn[2].weight.data.zero_()
                    block.mlp_modulation_fn[2].bias.data.zero_()
        # Zero the output projection of every cross-attention layer.
        for block in self.transformer.resblocks:
            if block.cross_att:
                zero_module(block.cross_att.to_out)

    def set_grad_checkpointing(self, flag=True):
        """Toggle activation checkpointing on ``self.transformer`` (the encoder is not changed)."""
        self.transformer.set_grad_checkpointing(flag)

    def forward(self,
                x: torch.Tensor,
                modulation: torch.Tensor = None,
                context: torch.Tensor = None,
                additional_residuals=None,
                abla_crossview=False,
                pos=None):
        """Fuse tokens from multiple views.

        Args:
            x: 4-D tensor unpacked as ``b v h g`` (per the inline comment,
                tokens then width).
            modulation: forwarded to the transformer blocks.
            context: forwarded to the transformer blocks.
            additional_residuals: forwarded to the transformer(s).
            abla_crossview: when set (and ``disable_dino`` is on), views are
                kept as separate batch items instead of being concatenated
                along the token axis.
            pos: 4-D tensor unpacked as ``b v c h`` and handed to ``rope``.
                It is rearranged unconditionally, so ``None`` is not usable.

        Returns:
            A ``b v n d`` tensor, or a list of per-layer tensors (plus the
            post-norm output) when ``return_all_layers`` is set.
        """
        # image tokenization
        bs, vs = x.shape[:2]
        x = rearrange(x, 'b v h g -> (b v) h g') # shape = [*, grid ** 2, width]
        pos = rearrange(pos, 'b v c h -> (b v) c h')
        # pre-normalization (self.ln_pre is not applied in this variant)
        B, N, C = x.shape
        # Split channels into chunks of 64 (hard-coded; presumably the head
        # dimension -- TODO confirm) and move the chunk axis before tokens.
        x = x.reshape(B, N, -1, 64)
        x = x.permute(0, 2, 1, 3)
        # print('pre x mean: ', x.mean().item())
        # print('pre x var: ', x.var().item())
        # rope is evaluated on a tensor of ones and its output is added to x.
        x = x + self.rope(torch.ones_like(x).to(x), pos)
        # print('x mean: ', x.mean().item())
        # print('x var: ', x.var().item())
        x = x.permute(0, 2, 1, 3)
        x = x.reshape(B, N, -1)
        # use encode to extract features
        if self.encode_layers > 0:
            x = x.permute(1, 0, 2) # NLD -> LND
            x = self.encoder(x, modulation, context, additional_residuals=additional_residuals)
            x = x.permute(1, 0, 2) # LND -> NLD
        if not self.disable_dino:
            x = x.permute(1, 0, 2) # NLD -> LND
        else:
            if not abla_crossview:
                # flatten x along the video dimension
                x = rearrange(x, '(b v) n d -> b (v n) d', v=vs)
                # print(x.shape)
                x = x.permute(1, 0, 2) # NLD -> LND
            else:
                x = x.permute(1, 0, 2)
        x = self.transformer(x, modulation, context, additional_residuals=additional_residuals)
        if self.add_zero_conv:
            assert isinstance(x, (list, tuple))
            assert len(x) == len(self.zero_convs)
            new_x = []
            for sub_x, sub_zero_conv in zip(x, self.zero_convs):
                # Conv1d wants (N, D, L); convert from and back to LND.
                sub_x_out = sub_zero_conv(sub_x.permute(1, 2, 0))
                new_x.append(sub_x_out.permute(2, 0, 1))
            x = new_x
        if self.return_all_layers:
            assert isinstance(x, (list, tuple))
            # NOTE(review): as nested here, x_final is only bound when the
            # post-LN is enabled, yet it is appended unconditionally below --
            # confirm the intended nesting.
            if not self.disable_post_ln:
                x_final = x[-1].permute(1, 0, 2) # LND -> NLD
                x_final = self.ln_post(x_final)
                x_final = rearrange(x_final, 'b (v n) d -> b v n d', v=vs)
            x = [s.permute(1, 0, 2) for s in x]
            x.append(x_final)
            return x
        if not self.disable_post_ln:
            x = x.permute(1, 0, 2) # LND -> NLD
            x = self.ln_post(x)
        if not self.disable_dino:
            x = rearrange(x, '(b v) n d -> b v n d', b=bs, v=vs)
        else:
            if not abla_crossview:
                # reshape x back to video dimension
                x = rearrange(x, 'b (v n) d -> b v n d', v=vs)
            else:
                x = rearrange(x, '(b v) n d -> b v n d', v=vs)
        return x

    def extra_repr(self) -> str:
        # No extra representation; this returns None despite the annotation.
        pass
def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic'):
    """
    Interpolate ``state_dict['positional_embedding']`` to the model's grid.

    The entry is rewritten in place (implementation follows google/simclr
    and open_clip). Nothing happens when the key is absent or when the
    stored grid already matches ``model.positional_embedding``. Leading
    extra tokens (e.g. a class token) are carried over unchanged; only the
    image-grid part is resampled.
    """
    key = 'positional_embedding'
    stored = state_dict.get(key, None)
    if stored is None:
        return

    # Grid sides are the rounded square roots of the embedding lengths.
    stored_len = stored.shape[0]
    stored_side = round(stored_len ** 0.5)
    target_side = round((model.positional_embedding.shape[0]) ** 0.5)
    if stored_side == target_side:
        return

    # Whatever does not fit the square grid is treated as extra tokens.
    num_extra = stored_len - (stored_side ** 2)
    if num_extra:
        token_part = stored[:num_extra]
        grid_part = stored[num_extra:]
    else:
        token_part = None
        grid_part = stored

    # (L, D) -> (1, D, side, side), resample, then back to (L_new, D).
    grid_part = grid_part.reshape(1, stored_side, stored_side, -1).permute(0, 3, 1, 2)
    grid_part = F.interpolate(
        grid_part,
        size=target_side,
        mode=interpolation,
        align_corners=True,
    )
    grid_part = grid_part.permute(0, 2, 3, 1).reshape(1, target_side * target_side, -1)[0]

    if token_part is None:
        resized = grid_part
    else:
        resized = torch.cat([token_part, grid_part], dim=0)
    state_dict[key] = resized
# Maps this project's short checkpoint names to timm model names.
# In this file only the keys are consulted (load_timm_to_clip validates
# ``config_name`` against them).
myname2timmname = {
    "vit_b_16_mae": None,
    "vit_b_16_in": "vit_base_patch16_224",
    "vit_b_16_in21k": 'vit_base_patch16_224_in21k',
    "vit_b_16_sam": 'vit_base_patch16_224_sam',
    "vit_b_16_dino": 'vit_base_patch16_224_dino',
    "vit_b_16_mill_in21k": 'vit_base_patch16_224_miil_in21k',
    "vit_b_16_mill": 'vit_base_patch16_224_miil',
    # NOTE(review): a patch-8 key mapped to a patch-16 timm name; possibly
    # 'vit_base_patch8_224_dino' was intended -- confirm.
    "vit_b_8_dino": 'vit_base_patch16_224_dino',
}
def load_timm_to_clip(module, config_name="vit_b_16_mae", init_mode='zero'):
    """
    Load a dumped timm ViT checkpoint into a CLIP-style visual ``module``.

    The timm-to-CLIP key mapping is read from ``timm2clip_vit_b_16.json``
    next to this file and the checkpoint ``{config_name}.pth`` is searched
    for in a fixed list of directories. After loading, ``module.conv1.bias``
    is replaced by the checkpoint's bias and ``module.ln_pre`` is reset.

    Args:
        module: model to load into; needs ``conv1`` and ``ln_pre``.
        config_name: one of the keys of ``myname2timmname``.
        init_mode: how ``conv1.weight`` is widened when the first load fails
            (e.g. because the model has more input channels than the
            checkpoint): ``'zero'`` doubles the input channels and zero-fills
            everything past the first three; ``'mean'`` tiles the weight
            three times along the input-channel axis and divides by 3.

    Raises:
        AssertionError: ``config_name`` is not a known name.
        FileNotFoundError: the checkpoint could not be loaded from any of
            the candidate locations.
        ValueError: the fallback is needed and ``init_mode`` is unknown.
    """
    from torch import nn
    from clip.model import LayerNorm as CLIPLayerNorm
    from clip.model import QuickGELU
    from torch.nn import GELU
    from torch.nn import LayerNorm
    import json

    now_dir = os.path.abspath(os.path.dirname(__file__))
    # Context manager so the mapping file is closed again.
    with open(f"{now_dir}/timm2clip_vit_b_16.json") as mapping_file:
        timm2clip = json.load(mapping_file)

    assert config_name in myname2timmname, f"The name {config_name} is not one of {list(myname2timmname.keys())}"

    # Site-specific checkpoint locations, tried in this order.
    candidate_paths = (
        f"/sensei-fs/users/hatan/model/{config_name}.pth",
        f"/input/yhxu/models/dino_weights/{config_name}.pth",
        f"/home/yhxu/models/dino_weights/{config_name}.pth",
        f"/nas2/zifan/checkpoint/dino_weights/{config_name}.pth",
    )
    timm_weight = None
    last_error = None
    for checkpoint_path in candidate_paths:
        try:
            # NOTE: torch.load unpickles the file; only point this at
            # trusted checkpoints.
            timm_weight = torch.load(checkpoint_path)["model"]
        except Exception as err:  # missing file, unreadable file, no "model" entry
            last_error = err
            continue
        break
    if timm_weight is None:
        raise FileNotFoundError(
            f"Could not load {config_name}.pth from any of {list(candidate_paths)}. "
            "Please download weight with support/dump_timm_weights.py. \n"
            "If using mae weight, please check https://github.com/facebookresearch/mae,"
            "and download the weight as vit_b_16_mae.pth"
        ) from last_error

    # Build model's state dict
    clipname2timmweight = {}
    for timm_key, clip_key in timm2clip.items():
        timm_value = timm_weight[timm_key]
        # Strip the "visual." prefix used by the CLIP key names.
        clipname2timmweight[clip_key[len("visual."):]] = timm_value.squeeze()

    # Resize positional embedding
    resize_pos_embed(clipname2timmweight, module)

    # Load weight to model.
    try:
        module.load_state_dict(clipname2timmweight, strict=False)
    except RuntimeError:
        # load_state_dict reports shape mismatches as RuntimeError; rebuild
        # the patch-embedding weight and retry.
        print('conv.weight has error!')
        conv_weight = clipname2timmweight['conv1.weight']
        if init_mode == 'zero':
            new_weight = torch.zeros_like(conv_weight).repeat(1, 2, 1, 1)
            new_weight[:, :3] = conv_weight
        elif init_mode == 'mean':
            new_weight = conv_weight.repeat(1, 3, 1, 1) / 3
        else:
            raise ValueError(f"Unknown init_mode {init_mode!r}; expected 'zero' or 'mean'")
        clipname2timmweight['conv1.weight'] = new_weight
        module.load_state_dict(clipname2timmweight, strict=False)

    # Since timm model has bias, we add it back here.
    module.conv1.bias = nn.Parameter(clipname2timmweight['conv1.bias'])

    # Reinit the visual weights that not covered by timm
    module.ln_pre.reset_parameters()

    # NOTE(review): this helper is defined but not called in the visible
    # code; it is kept nested here because it relies on the local imports
    # above.
    def convert_clip_to_timm(module):
        """Copy from detectron2, frozen BN"""
        res = module
        if isinstance(module, CLIPLayerNorm):
            # Timm uses eps=1e-6 while CLIP uses eps=1e-5
            res = LayerNorm(module.normalized_shape, eps=1e-6, elementwise_affine=module.elementwise_affine)
            if module.elementwise_affine:
                res.weight.data = module.weight.data.clone().detach()
                res.bias.data = module.bias.data.clone().detach()
        elif isinstance(module, QuickGELU):
            # Timm uses GELU while CLIP uses QuickGELU
            res = GELU()
        else:
            for name, child in module.named_children():
                new_child = convert_clip_to_timm(child)
                if new_child is not child:
                    res.add_module(name, new_child)
        return res