# pi3/models/layers/transformer_head.py
from functools import partial

import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

from ..dinov2.layers import Mlp
from .attention import FlashAttentionRope
from .block import BlockRope


class TransformerDecoder(nn.Module):
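    """Lightweight transformer decoder: a linear input projection, `depth`
    BlockRope attention blocks (rotary position embedding supplied via `rope`),
    and a final linear projection to `out_dim`. Gradient checkpointing is
    applied per block when `use_checkpoint` is set and the module is training.
    """
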
def __init__(
self,
in_dim,
out_dim,
dec_embed_dim=512,
depth=5,
dec_num_heads=8,
mlp_ratio=4,
rope=None,
need_project=True,
use_checkpoint=False,
):
super().__init__()
self.projects = nn.Linear(in_dim, dec_embed_dim) if need_project else nn.Identity()
self.use_checkpoint = use_checkpoint
self.blocks = nn.ModuleList([
BlockRope(
dim=dec_embed_dim,
num_heads=dec_num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=True,
proj_bias=True,
ffn_bias=True,
drop_path=0.0,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
act_layer=nn.GELU,
ffn_layer=Mlp,
init_values=None,
qk_norm=False,
# attn_class=MemEffAttentionRope,
attn_class=FlashAttentionRope,
rope=rope
) for _ in range(depth)])
self.linear_out = nn.Linear(dec_embed_dim, out_dim)

    def forward(self, hidden, xpos=None):
hidden = self.projects(hidden)
        for blk in self.blocks:
            if self.use_checkpoint and self.training:
                # Trade compute for memory: recompute this block in the backward pass.
                hidden = checkpoint(blk, hidden, xpos=xpos, use_reentrant=False)
else:
hidden = blk(hidden, xpos=xpos)
out = self.linear_out(hidden)
return out
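

# Usage sketch (illustrative, not part of the original file): the decoder keeps
# the token count unchanged and only maps the per-token feature width from
# in_dim to out_dim. The sizes below are arbitrary, and rope=None assumes
# BlockRope tolerates a missing rotary embedding:
#
#   decoder = TransformerDecoder(in_dim=768, out_dim=512, rope=None)
#   out = decoder(torch.randn(2, 196, 768))   # -> (2, 196, 512)
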
class LinearPts3d(nn.Module):
    """
    Linear head in the DUSt3R style: each token is projected to a
    patch_size x patch_size grid of 3D points (plus extra channels such as
    confidence when output_dim > 3).
    """

    def __init__(self, patch_size, dec_embed_dim, output_dim=3):
super().__init__()
self.patch_size = patch_size
        self.proj = nn.Linear(dec_embed_dim, output_dim * self.patch_size**2)

def forward(self, decout, img_shape):
H, W = img_shape
tokens = decout[-1]
B, S, D = tokens.shape
# extract 3D points
        feat = self.proj(tokens)  # (B, S, output_dim * patch_size**2)
        feat = feat.transpose(-1, -2).view(B, -1, H // self.patch_size, W // self.patch_size)
        feat = F.pixel_shuffle(feat, self.patch_size)  # (B, output_dim, H, W)

        # channels-last output: (B, H, W, output_dim)
        return feat.permute(0, 2, 3, 1)
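

# Minimal shape check (illustrative sketch, not part of the original file).
# Assumptions: BlockRope/FlashAttentionRope accept rope=None and xpos=None and
# run on the current device; all dimensions below are arbitrary examples.
if __name__ == "__main__":
    import torch

    patch_size = 16
    H, W = 224, 224
    num_tokens = (H // patch_size) * (W // patch_size)  # 14 * 14 = 196

    decoder = TransformerDecoder(in_dim=1024, out_dim=256, dec_embed_dim=512, rope=None)
    head = LinearPts3d(patch_size=patch_size, dec_embed_dim=256, output_dim=3)

    tokens = torch.randn(2, num_tokens, 1024)
    dec_out = decoder(tokens)      # (2, 196, 256)
    pts = head([dec_out], (H, W))  # (2, 224, 224, 3)
    print(dec_out.shape, pts.shape)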