from functools import partial

import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

from .attention import FlashAttentionRope
from .block import BlockRope
from ..dinov2.layers import Mlp


class TransformerDecoder(nn.Module):
    """Stack of RoPE transformer blocks with an input projection and a linear output head."""

    def __init__(
        self,
        in_dim,
        out_dim,
        dec_embed_dim=512,
        depth=5,
        dec_num_heads=8,
        mlp_ratio=4,
        rope=None,
        need_project=True,
        use_checkpoint=False,
    ):
        super().__init__()
        self.projects = nn.Linear(in_dim, dec_embed_dim) if need_project else nn.Identity()
        self.use_checkpoint = use_checkpoint
        self.blocks = nn.ModuleList([
            BlockRope(
                dim=dec_embed_dim,
                num_heads=dec_num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=True,
                proj_bias=True,
                ffn_bias=True,
                drop_path=0.0,
                norm_layer=partial(nn.LayerNorm, eps=1e-6),
                act_layer=nn.GELU,
                ffn_layer=Mlp,
                init_values=None,
                qk_norm=False,
                # attn_class=MemEffAttentionRope,
                attn_class=FlashAttentionRope,
                rope=rope,
            ) for _ in range(depth)])
        self.linear_out = nn.Linear(dec_embed_dim, out_dim)

    def forward(self, hidden, xpos=None):
        hidden = self.projects(hidden)
        for blk in self.blocks:
            if self.use_checkpoint and self.training:
                # Trade compute for memory: recompute activations in the backward pass.
                hidden = checkpoint(blk, hidden, xpos=xpos, use_reentrant=False)
            else:
                hidden = blk(hidden, xpos=xpos)
        out = self.linear_out(hidden)
        return out
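
# Usage sketch (illustrative, not from this repo's tests): the shapes below are
# assumptions, and `rope=None` presumes BlockRope tolerates a missing rotary
# module -- in practice the model passes its RoPE instance here.
#
#   import torch
#   decoder = TransformerDecoder(in_dim=1024, out_dim=768, dec_embed_dim=512,
#                                depth=5, dec_num_heads=8, rope=None)
#   tokens = torch.randn(2, 196, 1024)   # (batch, num_tokens, in_dim)
#   out = decoder(tokens)                # -> (2, 196, 768)
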
class LinearPts3d(nn.Module):
    """
    Linear head for DUSt3R-style decoders.
    Each token is decoded to a patch_size x patch_size grid of 3D points
    (output_dim can include extra channels, e.g. confidence).
    """

    def __init__(self, patch_size, dec_embed_dim, output_dim=3):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Linear(dec_embed_dim, output_dim * self.patch_size**2)

    def forward(self, decout, img_shape):
        H, W = img_shape
        tokens = decout[-1]
        B, S, D = tokens.shape

        # project each token to a flattened patch of 3D points
        feat = self.proj(tokens)  # (B, S, output_dim * patch_size**2)
        feat = feat.transpose(-1, -2).view(B, -1, H // self.patch_size, W // self.patch_size)
        feat = F.pixel_shuffle(feat, self.patch_size)  # (B, output_dim, H, W)

        # channels-last layout for downstream use
        return feat.permute(0, 2, 3, 1)  # (B, H, W, output_dim)
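
# Smoke test (a minimal sketch; the 224x224 input, patch_size=16, and
# dec_embed_dim=512 are assumed values, not fixed anywhere in this file).
# Run as a module (e.g. `python -m pkg.heads`, path hypothetical) so the
# relative imports at the top resolve.
if __name__ == "__main__":
    import torch

    head = LinearPts3d(patch_size=16, dec_embed_dim=512, output_dim=3)
    B, H, W = 2, 224, 224
    S = (H // 16) * (W // 16)               # 14 * 14 = 196 tokens
    tokens = torch.randn(B, S, 512)         # (B, S, dec_embed_dim)
    pts = head([tokens], img_shape=(H, W))  # forward takes a list of decoder outputs
    assert pts.shape == (B, H, W, 3)        # (B, H, W, output_dim)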