# Commit 11b119e by MohamedRashad:
# "Add skin and tokenizer systems with parsing and tokenization functionalities"
import sys
import torch.nn as nn
import spconv.pytorch as spconv
from collections import OrderedDict
from .utils.structure import Point
class PointModule(nn.Module):
    r"""Marker base class for Point-aware modules.

    ``PointSequential.forward`` tests ``isinstance(module, PointModule)`` and,
    for such children, calls the module on the entire running input (e.g. a
    ``Point``) rather than on its feature tensor; the module's return value
    becomes the new running input. Subclass this for layers that work on the
    whole ``Point`` object.
    """

    def __init__(self, *positional, **keyword):
        # No extra state: forward everything to nn.Module's constructor.
        super(PointModule, self).__init__(*positional, **keyword)
class PointSequential(PointModule):
    r"""A sequential container of modules that understands ``Point`` inputs.

    Modules are registered and executed in the order they are passed to the
    constructor. Alternatively, a single ``OrderedDict`` mapping names to
    modules may be passed. Keyword-argument modules are registered after the
    positional ones, under their keyword names.

    During ``forward`` each child is dispatched on its kind:

    * ``PointModule`` children are called on the whole running input.
    * spconv modules are called on ``input.sparse_conv_feat`` when the input
      is a ``Point``, otherwise on the input itself.
    * any other module is applied to the feature tensor only (``Point.feat``
      or ``SparseConvTensor.features``), or to the input itself otherwise.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            for key, module in args[0].items():
                self.add_module(key, module)
        else:
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)
        # Keyword modules come after the positional ones (kwargs order is
        # preserved on all supported Pythons, PEP 468); their names must not
        # collide with names registered above.
        for name, module in kwargs.items():
            if name in self._modules:
                raise ValueError("name exists.")
            self.add_module(name, module)

    def __getitem__(self, idx):
        """Return the child module at position ``idx``.

        Negative integers count from the end. A ``slice`` returns a new
        ``PointSequential`` holding the selected children (the same module
        objects, not copies) under their original names.

        Raises:
            IndexError: if an integer ``idx`` is out of range.
        """
        if isinstance(idx, slice):
            return PointSequential(OrderedDict(list(self._modules.items())[idx]))
        if not (-len(self) <= idx < len(self)):
            raise IndexError("index {} is out of range".format(idx))
        return list(self._modules.values())[idx]

    def __iter__(self):
        """Iterate over the child modules in registration order."""
        # Without this, Python's fallback iteration protocol calls
        # __getitem__(0), __getitem__(1), ... each of which is O(n).
        return iter(self._modules.values())

    def __len__(self):
        return len(self._modules)

    def add(self, module, name=None):
        """Append ``module`` under ``name`` (default: its position as a str).

        Raises:
            KeyError: if ``name`` is already registered.
        """
        if name is None:
            name = str(len(self._modules))
        if name in self._modules:
            raise KeyError("name exists")
        self.add_module(name, module)

    def forward(self, input):
        """Run every child module in order on ``input``.

        Args:
            input: a ``Point``, a ``spconv.SparseConvTensor``, or any value
                the child modules accept.

        Returns:
            The transformed input. A ``Point`` input is updated in place
            (its ``feat`` / ``sparse_conv_feat`` attributes are reassigned)
            unless a ``PointModule`` child returns a different object.
        """
        for module in self._modules.values():
            if isinstance(module, PointModule):
                # Point-aware module: hand over the whole input object.
                input = module(input)
            elif spconv.modules.is_spconv_module(module):
                if isinstance(input, Point):
                    input.sparse_conv_feat = module(input.sparse_conv_feat)
                    # Keep the dense feature view in sync with the sparse tensor.
                    input.feat = input.sparse_conv_feat.features
                else:
                    input = module(input)
            else:
                # Plain PyTorch module: apply it to the features only.
                if isinstance(input, Point):
                    input.feat = module(input.feat)
                    if "sparse_conv_feat" in input.keys():
                        # Mirror the new features into the sparse tensor.
                        input.sparse_conv_feat = input.sparse_conv_feat.replace_feature(
                            input.feat
                        )
                elif isinstance(input, spconv.SparseConvTensor):
                    # Skipped when the tensor has no active sites (zero rows in
                    # `indices`); presumably so layers such as norms never see
                    # an empty batch — TODO confirm intent.
                    if input.indices.shape[0] != 0:
                        input = input.replace_feature(module(input.features))
                else:
                    input = module(input)
        return input