Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,568 Bytes
853528a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
import torch
import torch.nn as nn
from copy import deepcopy
import torch.nn.functional as F
# code adapted from 'https://github.com/nianticlabs/marepo/blob/9a45e2bb07e5bb8cb997620088d352b439b13e0e/transformer/transformer.py#L172'
class ResConvBlock(nn.Module):
"""
1x1 convolution residual block
"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.head_skip = nn.Identity() if self.in_channels == self.out_channels else nn.Conv2d(self.in_channels, self.out_channels, 1, 1, 0)
# self.res_conv1 = nn.Conv2d(self.in_channels, self.out_channels, 1, 1, 0)
# self.res_conv2 = nn.Conv2d(self.out_channels, self.out_channels, 1, 1, 0)
# self.res_conv3 = nn.Conv2d(self.out_channels, self.out_channels, 1, 1, 0)
# change 1x1 convolution to linear
self.res_conv1 = nn.Linear(self.in_channels, self.out_channels)
self.res_conv2 = nn.Linear(self.out_channels, self.out_channels)
self.res_conv3 = nn.Linear(self.out_channels, self.out_channels)
def forward(self, res):
x = F.relu(self.res_conv1(res))
x = F.relu(self.res_conv2(x))
x = F.relu(self.res_conv3(x))
res = self.head_skip(res) + x
return res
class CameraHead(nn.Module):
def __init__(self, dim=512):
super().__init__()
output_dim = dim
self.res_conv = nn.ModuleList([deepcopy(ResConvBlock(output_dim, output_dim))
for _ in range(2)])
self.avgpool = nn.AdaptiveAvgPool2d(1)
self.more_mlps = nn.Sequential(
nn.Linear(output_dim,output_dim),
nn.ReLU(),
nn.Linear(output_dim,output_dim),
nn.ReLU()
)
self.fc_t = nn.Linear(output_dim, 3)
self.fc_rot = nn.Linear(output_dim, 9)
def forward(self, feat, patch_h, patch_w):
BN, hw, c = feat.shape
for i in range(2):
feat = self.res_conv[i](feat)
# feat = self.avgpool(feat)
feat = self.avgpool(feat.permute(0, 2, 1).reshape(BN, -1, patch_h, patch_w).contiguous()) ##########
feat = feat.view(feat.size(0), -1)
feat = self.more_mlps(feat) # [B, D_]
with torch.amp.autocast(device_type='cuda', enabled=False):
out_t = self.fc_t(feat.float()) # [B,3]
out_r = self.fc_rot(feat.float()) # [B,9]
pose = self.convert_pose_to_4x4(BN, out_r, out_t, feat.device)
return pose
def convert_pose_to_4x4(self, B, out_r, out_t, device):
out_r = self.svd_orthogonalize(out_r) # [N,3,3]
pose = torch.zeros((B, 4, 4), device=device)
pose[:, :3, :3] = out_r
pose[:, :3, 3] = out_t
pose[:, 3, 3] = 1.
return pose
def svd_orthogonalize(self, m):
"""Convert 9D representation to SO(3) using SVD orthogonalization.
Args:
m: [BATCH, 3, 3] 3x3 matrices.
Returns:
[BATCH, 3, 3] SO(3) rotation matrices.
"""
if m.dim() < 3:
m = m.reshape((-1, 3, 3))
m_transpose = torch.transpose(torch.nn.functional.normalize(m, p=2, dim=-1), dim0=-1, dim1=-2)
u, s, v = torch.svd(m_transpose)
det = torch.det(torch.matmul(v, u.transpose(-2, -1)))
# Check orientation reflection.
r = torch.matmul(
torch.cat([v[:, :, :-1], v[:, :, -1:] * det.view(-1, 1, 1)], dim=2),
u.transpose(-2, -1)
)
return r |