# pi3/models/pi3.py
import torch
import torch.nn as nn
from functools import partial
from copy import deepcopy
from .dinov2.layers import Mlp
from ..utils.geometry import homogenize_points
from .layers.pos_embed import RoPE2D, PositionGetter
from .layers.block import BlockRope
from .layers.attention import FlashAttentionRope
from .layers.transformer_head import TransformerDecoder, LinearPts3d
from .layers.camera_head import CameraHead
from .dinov2.hub.backbones import dinov2_vitl14, dinov2_vitl14_reg
from huggingface_hub import PyTorchModelHubMixin


class Pi3(nn.Module, PyTorchModelHubMixin):
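    """Multi-view 3D geometry network.

    A DINOv2 ViT-L/14 encoder (with registers) embeds every view into patch
    tokens; a stack of RoPE transformer blocks decodes them, alternating
    between frame-wise and cross-view attention; transformer heads then
    regress per-pixel local 3D points, per-pixel confidence, and one 4x4
    camera pose per view.
    """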
def __init__(
self,
pos_type='rope100',
decoder_size='large',
):
super().__init__()
# ----------------------
# Encoder
# ----------------------
        self.encoder = dinov2_vitl14_reg(pretrained=False)  # DINOv2 ViT-L/14 with registers; no backbone weights fetched here
self.patch_size = 14
del self.encoder.mask_token
# ----------------------
        # Positional Encoding
# ----------------------
self.pos_type = pos_type if pos_type is not None else 'none'
        self.rope = None
        if self.pos_type.startswith('rope'):  # e.g. 'rope100'
            if RoPE2D is None:
                raise ImportError("Cannot find cuRoPE2D, please install it following the README instructions")
freq = float(self.pos_type[len('rope'):])
self.rope = RoPE2D(freq=freq)
self.position_getter = PositionGetter()
else:
raise NotImplementedError
# ----------------------
# Decoder
# ----------------------
enc_embed_dim = self.encoder.blocks[0].attn.qkv.in_features # 1024
if decoder_size == 'small':
dec_embed_dim = 384
dec_num_heads = 6
mlp_ratio = 4
dec_depth = 24
elif decoder_size == 'base':
dec_embed_dim = 768
dec_num_heads = 12
mlp_ratio = 4
dec_depth = 24
elif decoder_size == 'large':
dec_embed_dim = 1024
dec_num_heads = 16
mlp_ratio = 4
dec_depth = 36
else:
raise NotImplementedError
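        # even-indexed blocks attend within each view, odd-indexed blocks
        # attend across all views; see decode() for the reshaping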
self.decoder = nn.ModuleList([
BlockRope(
dim=dec_embed_dim,
num_heads=dec_num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=True,
proj_bias=True,
ffn_bias=True,
drop_path=0.0,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
act_layer=nn.GELU,
ffn_layer=Mlp,
init_values=0.01,
qk_norm=True,
attn_class=FlashAttentionRope,
rope=self.rope
) for _ in range(dec_depth)])
self.dec_embed_dim = dec_embed_dim
# ----------------------
        # Register tokens
        # ----------------------
        num_register_tokens = 5
        self.patch_start_idx = num_register_tokens
        # learnable tokens prepended to each view's patch tokens in decode()
        self.register_token = nn.Parameter(torch.randn(1, 1, num_register_tokens, self.dec_embed_dim))
        nn.init.normal_(self.register_token, std=1e-6)
# ----------------------
# Local Points Decoder
# ----------------------
self.point_decoder = TransformerDecoder(
            in_dim=2*self.dec_embed_dim,  # concat of the last two decoder blocks' features
dec_embed_dim=1024,
dec_num_heads=16,
out_dim=1024,
rope=self.rope,
)
self.point_head = LinearPts3d(patch_size=14, dec_embed_dim=1024, output_dim=3)
# ----------------------
# Conf Decoder
# ----------------------
self.conf_decoder = deepcopy(self.point_decoder)
self.conf_head = LinearPts3d(patch_size=14, dec_embed_dim=1024, output_dim=1)
# ----------------------
# Camera Pose Decoder
# ----------------------
self.camera_decoder = TransformerDecoder(
in_dim=2*self.dec_embed_dim,
dec_embed_dim=1024,
            dec_num_heads=16,
out_dim=512,
rope=self.rope,
use_checkpoint=False
)
self.camera_head = CameraHead(dim=512)
        # ImageNet normalization statistics
image_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
image_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
self.register_buffer("image_mean", image_mean)
self.register_buffer("image_std", image_std)
def decode(self, hidden, N, H, W):
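        """Run the alternating-attention decoder over encoder tokens.

        Args:
            hidden: (B*N, hw, C) patch tokens for all views.
            N: number of views per sample.
            H, W: input image size in pixels.

        Returns:
            (tokens, pos): tokens of shape (B*N, hw', 2*C), the channel-wise
            concatenation of the last two blocks' outputs, and RoPE positions
            of shape (B*N, hw', 2); hw' = hw + number of register tokens.
        """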
BN, hw, _ = hidden.shape
B = BN // N
final_output = []
hidden = hidden.reshape(B*N, hw, -1)
register_token = self.register_token.repeat(B, N, 1, 1).reshape(B*N, *self.register_token.shape[-2:])
# Concatenate special tokens with patch tokens
hidden = torch.cat([register_token, hidden], dim=1)
hw = hidden.shape[1]
if self.pos_type.startswith('rope'):
pos = self.position_getter(B * N, H//self.patch_size, W//self.patch_size, hidden.device)
            if self.patch_start_idx > 0:
                # reserve position 0 for the special (register) tokens: shift all
                # patch positions up by 1, then prepend zeros for the specials
                pos = pos + 1
                pos_special = torch.zeros(B * N, self.patch_start_idx, 2).to(hidden.device).to(pos.dtype)
                pos = torch.cat([pos_special, pos], dim=1)
        for i in range(len(self.decoder)):
            blk = self.decoder[i]
            if i % 2 == 0:
                # even blocks: frame-wise attention, each view attends to itself
                pos = pos.reshape(B*N, hw, -1)
                hidden = hidden.reshape(B*N, hw, -1)
            else:
                # odd blocks: global attention across all N views of the sample
                pos = pos.reshape(B, N*hw, -1)
                hidden = hidden.reshape(B, N*hw, -1)
            hidden = blk(hidden, xpos=pos)
            # keep the outputs of the last two blocks; they are concatenated
            # channel-wise below as input to the prediction heads
            if i+1 in [len(self.decoder)-1, len(self.decoder)]:
                final_output.append(hidden.reshape(B*N, hw, -1))
        return torch.cat([final_output[0], final_output[1]], dim=-1), pos.reshape(B*N, hw, -1)

    def forward(self, imgs):
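        """Predict world points, local points, confidence and camera poses.

        Args:
            imgs: (B, N, 3, H, W) RGB views in [0, 1]; H and W must be
                multiples of the 14-pixel patch size.

        Returns:
            dict with
                points:       (B, N, H, W, 3) points in the shared world frame,
                local_points: (B, N, H, W, 3) points in each camera's frame,
                conf:         (B, N, H, W, 1) raw confidence values,
                camera_poses: (B, N, 4, 4) per-view pose matrices.
        """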
        imgs = (imgs - self.image_mean) / self.image_std
        B, N, C, H, W = imgs.shape
        patch_h, patch_w = H // self.patch_size, W // self.patch_size
        # encode each view independently with DINOv2
        imgs = imgs.reshape(B*N, C, H, W)
hidden = self.encoder(imgs, is_training=True)
if isinstance(hidden, dict):
hidden = hidden["x_norm_patchtokens"]
hidden, pos = self.decode(hidden, N, H, W)
point_hidden = self.point_decoder(hidden, xpos=pos)
conf_hidden = self.conf_decoder(hidden, xpos=pos)
camera_hidden = self.camera_decoder(hidden, xpos=pos)
        # run the heads in fp32, independent of any surrounding autocast context
        with torch.amp.autocast(device_type='cuda', enabled=False):
            # local points: the head outputs a 2D component (xy) and log-depth;
            # exp() keeps depth positive, and points are recovered as (xy*z, z)
            point_hidden = point_hidden.float()
            ret = self.point_head([point_hidden[:, self.patch_start_idx:]], (H, W)).reshape(B, N, H, W, -1)
            xy, z = ret.split([2, 1], dim=-1)
            z = torch.exp(z)
            local_points = torch.cat([xy * z, z], dim=-1)
# confidence
conf_hidden = conf_hidden.float()
conf = self.conf_head([conf_hidden[:, self.patch_start_idx:]], (H, W)).reshape(B, N, H, W, -1)
# camera
camera_hidden = camera_hidden.float()
camera_poses = self.camera_head(camera_hidden[:, self.patch_start_idx:], patch_h, patch_w).reshape(B, N, 4, 4)
            # lift each view's local points into the shared frame with its pose
            points = torch.einsum('bnij, bnhwj -> bnhwi', camera_poses, homogenize_points(local_points))[..., :3]
return dict(
points=points,
local_points=local_points,
conf=conf,
camera_poses=camera_poses,
)
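

# ----------------------
# Minimal usage sketch (illustration only, not part of the model definition).
# Assumptions: flash-attn and cuRoPE2D are installed, a CUDA device is
# available, and H, W are multiples of the 14-pixel patch size; the Hub
# repo id below is a placeholder, not a real checkpoint name.
# ----------------------
if __name__ == '__main__':
    model = Pi3().cuda().eval()  # or Pi3.from_pretrained('<hub_repo_id>')
    imgs = torch.rand(1, 4, 3, 224, 224, device='cuda')  # (B, N, C, H, W) in [0, 1]
    with torch.no_grad():
        out = model(imgs)
    print(out['points'].shape)        # (1, 4, 224, 224, 3) world-frame points
    print(out['camera_poses'].shape)  # (1, 4, 4, 4)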