import torch
import torch.nn as nn
from functools import partial
from copy import deepcopy
from .dinov2.layers import Mlp
from ..utils.geometry import homogenize_points
from .layers.pos_embed import RoPE2D, PositionGetter
from .layers.block import BlockRope
from .layers.attention import FlashAttentionRope
from .layers.transformer_head import TransformerDecoder, LinearPts3d
from .layers.camera_head import CameraHead
from .dinov2.hub.backbones import dinov2_vitl14, dinov2_vitl14_reg
from huggingface_hub import PyTorchModelHubMixin
class Pi3(nn.Module, PyTorchModelHubMixin):
def __init__(
self,
pos_type='rope100',
decoder_size='large',
):
super().__init__()
# ----------------------
# Encoder
# ----------------------
self.encoder = dinov2_vitl14_reg(pretrained=False)
self.patch_size = 14
del self.encoder.mask_token
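        # DINOv2 ViT-L/14 with registers is used purely as a patch-feature extractor;
        # its mask_token is never used by this model, so it is removed from the module.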
# ----------------------
        # Positional Encoding
# ----------------------
self.pos_type = pos_type if pos_type is not None else 'none'
        self.rope = None
        if self.pos_type.startswith('rope'):  # e.g. 'rope100' -> 2D RoPE with base frequency 100.0
            if RoPE2D is None:
                raise ImportError("Cannot find cuRoPE2D, please install it following the README instructions")
freq = float(self.pos_type[len('rope'):])
self.rope = RoPE2D(freq=freq)
self.position_getter = PositionGetter()
else:
raise NotImplementedError
# ----------------------
# Decoder
# ----------------------
enc_embed_dim = self.encoder.blocks[0].attn.qkv.in_features # 1024
if decoder_size == 'small':
dec_embed_dim = 384
dec_num_heads = 6
mlp_ratio = 4
dec_depth = 24
elif decoder_size == 'base':
dec_embed_dim = 768
dec_num_heads = 12
mlp_ratio = 4
dec_depth = 24
elif decoder_size == 'large':
dec_embed_dim = 1024
dec_num_heads = 16
mlp_ratio = 4
dec_depth = 36
else:
raise NotImplementedError
self.decoder = nn.ModuleList([
BlockRope(
dim=dec_embed_dim,
num_heads=dec_num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=True,
proj_bias=True,
ffn_bias=True,
drop_path=0.0,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
act_layer=nn.GELU,
ffn_layer=Mlp,
init_values=0.01,
qk_norm=True,
attn_class=FlashAttentionRope,
rope=self.rope
) for _ in range(dec_depth)])
self.dec_embed_dim = dec_embed_dim
# ----------------------
        # Register tokens
# ----------------------
num_register_tokens = 5
self.patch_start_idx = num_register_tokens
self.register_token = nn.Parameter(torch.randn(1, 1, num_register_tokens, self.dec_embed_dim))
nn.init.normal_(self.register_token, std=1e-6)
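        # Learnable register tokens; decode() prepends them to every frame's patch
        # tokens, so patch tokens start at index patch_start_idx in each sequence.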
# ----------------------
# Local Points Decoder
# ----------------------
self.point_decoder = TransformerDecoder(
in_dim=2*self.dec_embed_dim,
dec_embed_dim=1024,
dec_num_heads=16,
out_dim=1024,
rope=self.rope,
)
self.point_head = LinearPts3d(patch_size=14, dec_embed_dim=1024, output_dim=3)
# ----------------------
# Conf Decoder
# ----------------------
self.conf_decoder = deepcopy(self.point_decoder)
self.conf_head = LinearPts3d(patch_size=14, dec_embed_dim=1024, output_dim=1)
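        # The confidence branch mirrors the point branch: the deepcopy gives it the same
        # architecture but independent weights, and its head predicts a single channel.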
# ----------------------
# Camera Pose Decoder
# ----------------------
self.camera_decoder = TransformerDecoder(
in_dim=2*self.dec_embed_dim,
dec_embed_dim=1024,
            dec_num_heads=16,
out_dim=512,
rope=self.rope,
use_checkpoint=False
)
self.camera_head = CameraHead(dim=512)
# For ImageNet Normalize
image_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
image_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
self.register_buffer("image_mean", image_mean)
self.register_buffer("image_std", image_std)
def decode(self, hidden, N, H, W):
BN, hw, _ = hidden.shape
B = BN // N
final_output = []
hidden = hidden.reshape(B*N, hw, -1)
register_token = self.register_token.repeat(B, N, 1, 1).reshape(B*N, *self.register_token.shape[-2:])
# Concatenate special tokens with patch tokens
hidden = torch.cat([register_token, hidden], dim=1)
hw = hidden.shape[1]
if self.pos_type.startswith('rope'):
pos = self.position_getter(B * N, H//self.patch_size, W//self.patch_size, hidden.device)
            if self.patch_start_idx > 0:
                # Register tokens carry no positional information: patch positions are
                # shifted by +1 so index 0 can be reserved for the special tokens, whose
                # positions are set to all zeros.
                pos = pos + 1
                pos_special = torch.zeros(B * N, self.patch_start_idx, 2).to(hidden.device).to(pos.dtype)
                pos = torch.cat([pos_special, pos], dim=1)
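        # Alternate between frame-wise and global attention: even-indexed blocks see
        # each frame's tokens separately (B*N sequences of hw tokens), odd-indexed
        # blocks see all frames of a batch element jointly (B sequences of N*hw tokens).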
for i in range(len(self.decoder)):
blk = self.decoder[i]
if i % 2 == 0:
pos = pos.reshape(B*N, hw, -1)
hidden = hidden.reshape(B*N, hw, -1)
else:
pos = pos.reshape(B, N*hw, -1)
hidden = hidden.reshape(B, N*hw, -1)
hidden = blk(hidden, xpos=pos)
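            # keep the outputs of the last two decoder blocks; they are concatenated
            # below and passed to the point / confidence / camera decoders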
if i+1 in [len(self.decoder)-1, len(self.decoder)]:
final_output.append(hidden.reshape(B*N, hw, -1))
return torch.cat([final_output[0], final_output[1]], dim=-1), pos.reshape(B*N, hw, -1)
def forward(self, imgs):
imgs = (imgs - self.image_mean) / self.image_std
        B, N, C, H, W = imgs.shape
        patch_h, patch_w = H // self.patch_size, W // self.patch_size
        # encode every frame independently with the DINOv2 backbone
        imgs = imgs.reshape(B*N, C, H, W)
hidden = self.encoder(imgs, is_training=True)
if isinstance(hidden, dict):
hidden = hidden["x_norm_patchtokens"]
hidden, pos = self.decode(hidden, N, H, W)
point_hidden = self.point_decoder(hidden, xpos=pos)
conf_hidden = self.conf_decoder(hidden, xpos=pos)
camera_hidden = self.camera_decoder(hidden, xpos=pos)
with torch.amp.autocast(device_type='cuda', enabled=False):
# local points
point_hidden = point_hidden.float()
ret = self.point_head([point_hidden[:, self.patch_start_idx:]], (H, W)).reshape(B, N, H, W, -1)
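            # The head effectively predicts (x/z, y/z, log z) per pixel: exponentiating
            # guarantees a positive depth, and multiplying xy by z recovers camera-frame XYZ.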
xy, z = ret.split([2, 1], dim=-1)
z = torch.exp(z)
local_points = torch.cat([xy * z, z], dim=-1)
# confidence
conf_hidden = conf_hidden.float()
conf = self.conf_head([conf_hidden[:, self.patch_start_idx:]], (H, W)).reshape(B, N, H, W, -1)
# camera
camera_hidden = camera_hidden.float()
camera_poses = self.camera_head(camera_hidden[:, self.patch_start_idx:], patch_h, patch_w).reshape(B, N, 4, 4)
            # unproject local points using camera poses: each 4x4 pose maps homogeneous
            # camera-frame points into a common reference frame; the homogeneous
            # coordinate is dropped afterwards.
            points = torch.einsum('bnij, bnhwj -> bnhwi', camera_poses, homogenize_points(local_points))[..., :3]
return dict(
points=points,
local_points=local_points,
conf=conf,
camera_poses=camera_poses,
)
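
# ----------------------------------------------------------------------
# Usage sketch (illustration only, not part of the original module): a minimal,
# hedged example of how Pi3 might be driven end to end. It assumes randomly
# initialized weights, input frames whose height and width are multiples of the
# 14-pixel patch size, and an environment where the RoPE / attention kernels used
# by the decoder blocks are available (a CUDA device may be required).
# ----------------------------------------------------------------------
if __name__ == "__main__":
    model = Pi3(pos_type='rope100', decoder_size='large').eval()
    # One batch element with N=2 frames of 3x224x224 RGB in [0, 1].
    imgs = torch.rand(1, 2, 3, 224, 224)
    with torch.no_grad():
        out = model(imgs)
    print(out['points'].shape)        # (1, 2, 224, 224, 3) points in the common frame
    print(out['local_points'].shape)  # (1, 2, 224, 224, 3) per-camera points
    print(out['conf'].shape)          # (1, 2, 224, 224, 1) confidence
    print(out['camera_poses'].shape)  # (1, 2, 4, 4)        camera poses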