Spaces:
Runtime error
Runtime error
from typing import List, Optional, Tuple | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from pointnet2_ops import pointnet2_utils | |
def build_shared_mlp(mlp_spec: List[int], bn: bool = True): | |
layers = [] | |
for i in range(1, len(mlp_spec)): | |
layers.append( | |
nn.Conv2d(mlp_spec[i - 1], mlp_spec[i], kernel_size=1, bias=not bn) | |
) | |
if bn: | |
layers.append(nn.BatchNorm2d(mlp_spec[i])) | |
layers.append(nn.ReLU(True)) | |
return nn.Sequential(*layers) | |
class _PointnetSAModuleBase(nn.Module): | |
def __init__(self): | |
super(_PointnetSAModuleBase, self).__init__() | |
self.npoint = None | |
self.groupers = None | |
self.mlps = None | |
def forward( | |
self, xyz: torch.Tensor, features: Optional[torch.Tensor] | |
) -> Tuple[torch.Tensor, torch.Tensor]: | |
r""" | |
Parameters | |
---------- | |
xyz : torch.Tensor | |
(B, N, 3) tensor of the xyz coordinates of the features | |
features : torch.Tensor | |
(B, C, N) tensor of the descriptors of the the features | |
Returns | |
------- | |
new_xyz : torch.Tensor | |
(B, npoint, 3) tensor of the new features' xyz | |
new_features : torch.Tensor | |
(B, \sum_k(mlps[k][-1]), npoint) tensor of the new_features descriptors | |
""" | |
new_features_list = [] | |
xyz_flipped = xyz.transpose(1, 2).contiguous() | |
new_xyz = ( | |
pointnet2_utils.gather_operation( | |
xyz_flipped, pointnet2_utils.furthest_point_sample(xyz, self.npoint) | |
) | |
.transpose(1, 2) | |
.contiguous() | |
if self.npoint is not None | |
else None | |
) | |
for i in range(len(self.groupers)): | |
new_features = self.groupers[i]( | |
xyz, new_xyz, features | |
) # (B, C, npoint, nsample) | |
new_features = self.mlps[i](new_features) # (B, mlp[-1], npoint, nsample) | |
new_features = F.max_pool2d( | |
new_features, kernel_size=[1, new_features.size(3)] | |
) # (B, mlp[-1], npoint, 1) | |
new_features = new_features.squeeze(-1) # (B, mlp[-1], npoint) | |
new_features_list.append(new_features) | |
return new_xyz, torch.cat(new_features_list, dim=1) | |
class PointnetSAModuleMSG(_PointnetSAModuleBase): | |
r"""Pointnet set abstrction layer with multiscale grouping | |
Parameters | |
---------- | |
npoint : int | |
Number of features | |
radii : list of float32 | |
list of radii to group with | |
nsamples : list of int32 | |
Number of samples in each ball query | |
mlps : list of list of int32 | |
Spec of the pointnet before the global max_pool for each scale | |
bn : bool | |
Use batchnorm | |
""" | |
def __init__(self, npoint, radii, nsamples, mlps, bn=True, use_xyz=True): | |
# type: (PointnetSAModuleMSG, int, List[float], List[int], List[List[int]], bool, bool) -> None | |
super(PointnetSAModuleMSG, self).__init__() | |
assert len(radii) == len(nsamples) == len(mlps) | |
self.npoint = npoint | |
self.groupers = nn.ModuleList() | |
self.mlps = nn.ModuleList() | |
for i in range(len(radii)): | |
radius = radii[i] | |
nsample = nsamples[i] | |
self.groupers.append( | |
pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz) | |
if npoint is not None | |
else pointnet2_utils.GroupAll(use_xyz) | |
) | |
mlp_spec = mlps[i] | |
if use_xyz: | |
mlp_spec[0] += 3 | |
self.mlps.append(build_shared_mlp(mlp_spec, bn)) | |
class PointnetSAModule(PointnetSAModuleMSG): | |
r"""Pointnet set abstrction layer | |
Parameters | |
---------- | |
npoint : int | |
Number of features | |
radius : float | |
Radius of ball | |
nsample : int | |
Number of samples in the ball query | |
mlp : list | |
Spec of the pointnet before the global max_pool | |
bn : bool | |
Use batchnorm | |
""" | |
def __init__( | |
self, mlp, npoint=None, radius=None, nsample=None, bn=True, use_xyz=True | |
): | |
# type: (PointnetSAModule, List[int], int, float, int, bool, bool) -> None | |
super(PointnetSAModule, self).__init__( | |
mlps=[mlp], | |
npoint=npoint, | |
radii=[radius], | |
nsamples=[nsample], | |
bn=bn, | |
use_xyz=use_xyz, | |
) | |
class PointnetFPModule(nn.Module): | |
r"""Propigates the features of one set to another | |
Parameters | |
---------- | |
mlp : list | |
Pointnet module parameters | |
bn : bool | |
Use batchnorm | |
""" | |
def __init__(self, mlp, bn=True): | |
# type: (PointnetFPModule, List[int], bool) -> None | |
super(PointnetFPModule, self).__init__() | |
self.mlp = build_shared_mlp(mlp, bn=bn) | |
def forward(self, unknown, known, unknow_feats, known_feats): | |
# type: (PointnetFPModule, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor | |
r""" | |
Parameters | |
---------- | |
unknown : torch.Tensor | |
(B, n, 3) tensor of the xyz positions of the unknown features | |
known : torch.Tensor | |
(B, m, 3) tensor of the xyz positions of the known features | |
unknow_feats : torch.Tensor | |
(B, C1, n) tensor of the features to be propigated to | |
known_feats : torch.Tensor | |
(B, C2, m) tensor of features to be propigated | |
Returns | |
------- | |
new_features : torch.Tensor | |
(B, mlp[-1], n) tensor of the features of the unknown features | |
""" | |
if known is not None: | |
dist, idx = pointnet2_utils.three_nn(unknown, known) | |
dist_recip = 1.0 / (dist + 1e-8) | |
norm = torch.sum(dist_recip, dim=2, keepdim=True) | |
weight = dist_recip / norm | |
interpolated_feats = pointnet2_utils.three_interpolate( | |
known_feats, idx, weight | |
) | |
else: | |
interpolated_feats = known_feats.expand( | |
*(known_feats.size()[0:2] + [unknown.size(1)]) | |
) | |
if unknow_feats is not None: | |
new_features = torch.cat( | |
[interpolated_feats, unknow_feats], dim=1 | |
) # (B, C2 + C1, n) | |
else: | |
new_features = interpolated_feats | |
new_features = new_features.unsqueeze(-1) | |
new_features = self.mlp(new_features) | |
return new_features.squeeze(-1) | |