import torch
import torch.nn as nn
from functools import partial
from copy import deepcopy

from .dinov2.layers import Mlp
from ..utils.geometry import homogenize_points
from .layers.pos_embed import RoPE2D, PositionGetter
from .layers.block import BlockRope
from .layers.attention import FlashAttentionRope
from .layers.transformer_head import TransformerDecoder, LinearPts3d
from .layers.camera_head import CameraHead
from .dinov2.hub.backbones import dinov2_vitl14, dinov2_vitl14_reg
from huggingface_hub import PyTorchModelHubMixin


class Pi3(nn.Module, PyTorchModelHubMixin):
    def __init__(
        self,
        pos_type='rope100',
        decoder_size='large',
    ):
        super().__init__()

        # ----------------------
        # Encoder
        # ----------------------
        self.encoder = dinov2_vitl14_reg(pretrained=False)
        self.patch_size = 14
        del self.encoder.mask_token

        # ----------------------
        # Positional Encoding
        # ----------------------
        self.pos_type = pos_type if pos_type is not None else 'none'
        self.rope = None
        if self.pos_type.startswith('rope'):  # e.g. rope100
            if RoPE2D is None:
                raise ImportError("Cannot find cuRoPE2D, please install it following the README instructions")
            freq = float(self.pos_type[len('rope'):])
            self.rope = RoPE2D(freq=freq)
            self.position_getter = PositionGetter()
        else:
            raise NotImplementedError

        # ----------------------
        # Decoder
        # ----------------------
        enc_embed_dim = self.encoder.blocks[0].attn.qkv.in_features  # 1024 for ViT-L
        if decoder_size == 'small':
            dec_embed_dim = 384
            dec_num_heads = 6
            mlp_ratio = 4
            dec_depth = 24
        elif decoder_size == 'base':
            dec_embed_dim = 768
            dec_num_heads = 12
            mlp_ratio = 4
            dec_depth = 24
        elif decoder_size == 'large':
            dec_embed_dim = 1024
            dec_num_heads = 16
            mlp_ratio = 4
            dec_depth = 36
        else:
            raise NotImplementedError
        self.decoder = nn.ModuleList([
            BlockRope(
                dim=dec_embed_dim,
                num_heads=dec_num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=True,
                proj_bias=True,
                ffn_bias=True,
                drop_path=0.0,
                norm_layer=partial(nn.LayerNorm, eps=1e-6),
                act_layer=nn.GELU,
                ffn_layer=Mlp,
                init_values=0.01,
                qk_norm=True,
                attn_class=FlashAttentionRope,
                rope=self.rope
            ) for _ in range(dec_depth)])
        self.dec_embed_dim = dec_embed_dim

        # ----------------------
        # Register tokens
        # ----------------------
        num_register_tokens = 5
        self.patch_start_idx = num_register_tokens
        self.register_token = nn.Parameter(torch.randn(1, 1, num_register_tokens, self.dec_embed_dim))
        nn.init.normal_(self.register_token, std=1e-6)

        # ----------------------
        # Local Points Decoder
        # ----------------------
        self.point_decoder = TransformerDecoder(
            in_dim=2 * self.dec_embed_dim,
            dec_embed_dim=1024,
            dec_num_heads=16,
            out_dim=1024,
            rope=self.rope,
        )
        self.point_head = LinearPts3d(patch_size=14, dec_embed_dim=1024, output_dim=3)

        # ----------------------
        # Conf Decoder
        # ----------------------
        self.conf_decoder = deepcopy(self.point_decoder)
        self.conf_head = LinearPts3d(patch_size=14, dec_embed_dim=1024, output_dim=1)

        # ----------------------
        # Camera Pose Decoder
        # ----------------------
        self.camera_decoder = TransformerDecoder(
            in_dim=2 * self.dec_embed_dim,
            dec_embed_dim=1024,
            dec_num_heads=16,
            out_dim=512,
            rope=self.rope,
            use_checkpoint=False
        )
        self.camera_head = CameraHead(dim=512)

        # For ImageNet normalization
        image_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        image_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
        self.register_buffer("image_mean", image_mean)
        self.register_buffer("image_std", image_std)
    def decode(self, hidden, N, H, W):
        BN, hw, _ = hidden.shape
        B = BN // N

        final_output = []

        hidden = hidden.reshape(B * N, hw, -1)

        register_token = self.register_token.repeat(B, N, 1, 1).reshape(B * N, *self.register_token.shape[-2:])

        # Concatenate register tokens with patch tokens
        hidden = torch.cat([register_token, hidden], dim=1)
        hw = hidden.shape[1]

        if self.pos_type.startswith('rope'):
            pos = self.position_getter(B * N, H // self.patch_size, W // self.patch_size, hidden.device)

        if self.patch_start_idx > 0:
            # Do not apply positional embedding to the special (register) tokens:
            # shift all patch positions by +1 and reserve position 0 for them.
            pos = pos + 1
            pos_special = torch.zeros(B * N, self.patch_start_idx, 2).to(hidden.device).to(pos.dtype)
            pos = torch.cat([pos_special, pos], dim=1)

        for i in range(len(self.decoder)):
            blk = self.decoder[i]

            if i % 2 == 0:
                # Frame-wise attention: tokens attend only within their own view.
                pos = pos.reshape(B * N, hw, -1)
                hidden = hidden.reshape(B * N, hw, -1)
            else:
                # Global attention: tokens attend across all N views.
                pos = pos.reshape(B, N * hw, -1)
                hidden = hidden.reshape(B, N * hw, -1)

            hidden = blk(hidden, xpos=pos)

            # Keep the outputs of the last two blocks for the downstream heads.
            if i + 1 in [len(self.decoder) - 1, len(self.decoder)]:
                final_output.append(hidden.reshape(B * N, hw, -1))

        return torch.cat([final_output[0], final_output[1]], dim=-1), pos.reshape(B * N, hw, -1)
    def forward(self, imgs):
        imgs = (imgs - self.image_mean) / self.image_std

        B, N, C, H, W = imgs.shape
        patch_h, patch_w = H // 14, W // 14

        # Encode all frames with DINOv2
        imgs = imgs.reshape(B * N, C, H, W)
        hidden = self.encoder(imgs, is_training=True)

        if isinstance(hidden, dict):
            hidden = hidden["x_norm_patchtokens"]

        hidden, pos = self.decode(hidden, N, H, W)

        point_hidden = self.point_decoder(hidden, xpos=pos)
        conf_hidden = self.conf_decoder(hidden, xpos=pos)
        camera_hidden = self.camera_decoder(hidden, xpos=pos)

        # Run the heads in full precision even under autocast.
        with torch.amp.autocast(device_type='cuda', enabled=False):
            # Local points: predict image-plane (x, y) and log-depth, then
            # lift to camera-space 3D points.
            point_hidden = point_hidden.float()
            ret = self.point_head([point_hidden[:, self.patch_start_idx:]], (H, W)).reshape(B, N, H, W, -1)
            xy, z = ret.split([2, 1], dim=-1)
            z = torch.exp(z)
            local_points = torch.cat([xy * z, z], dim=-1)

            # Per-pixel confidence
            conf_hidden = conf_hidden.float()
            conf = self.conf_head([conf_hidden[:, self.patch_start_idx:]], (H, W)).reshape(B, N, H, W, -1)

            # Per-view 4x4 camera poses
            camera_hidden = camera_hidden.float()
            camera_poses = self.camera_head(camera_hidden[:, self.patch_start_idx:], patch_h, patch_w).reshape(B, N, 4, 4)

            # Unproject camera-space points into world coordinates using the poses.
            points = torch.einsum('bnij, bnhwj -> bnhwi', camera_poses, homogenize_points(local_points))[..., :3]

        return dict(
            points=points,
            local_points=local_points,
            conf=conf,
            camera_poses=camera_poses,
        )
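

# ----------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the model definition).
# Assumptions: inputs are RGB in [0, 1] with H and W multiples of the 14-pixel
# patch size (ImageNet normalization happens inside forward()), a CUDA device
# with the cuRoPE2D extension is available, and this module is run as part of
# its package (e.g. `python -m <package>.<module>`) since it uses relative
# imports. The Hub repo id is a placeholder; with published weights you could
# instead call Pi3.from_pretrained("<hub-repo-id>") via PyTorchModelHubMixin.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    # Randomly initialized weights; swap in from_pretrained(...) for real use.
    model = Pi3(pos_type='rope100', decoder_size='large').cuda().eval()

    # Dummy clip: batch of 1 with N=2 views at 224x224 (multiples of 14).
    imgs = torch.rand(1, 2, 3, 224, 224, device='cuda')

    with torch.no_grad():
        out = model(imgs)

    print(out['points'].shape)        # (1, 2, 224, 224, 3)  world-space points
    print(out['local_points'].shape)  # (1, 2, 224, 224, 3)  camera-space points
    print(out['conf'].shape)          # (1, 2, 224, 224, 1)  per-pixel confidence
    print(out['camera_poses'].shape)  # (1, 2, 4, 4)         per-view poses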