import math
from functools import partial
from typing import Union

import spconv.pytorch as spconv
import torch
import torch.nn as nn
import torch_scatter
from addict import Dict
from einops import rearrange
from timm.models.layers import DropPath

try:
    import flash_attn
except ImportError:
    flash_attn = None

from .utils.misc import offset2bincount
from .utils.structure import Point
from .modules import PointModule, PointSequential
class RPE(torch.nn.Module):
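    """Relative positional encoding over serialized point patches.

    Relative grid offsets within a patch are clamped to [-pos_bnd, pos_bnd],
    shifted into per-axis indices of a learnable (3 * rpe_num, num_heads)
    table, and the x/y/z lookups are summed into an additive attention bias
    of shape (N, H, K, K).
    """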
def __init__(self, patch_size, num_heads):
super().__init__()
self.patch_size = patch_size
self.num_heads = num_heads
self.pos_bnd = int((4 * patch_size) ** (1 / 3) * 2)
self.rpe_num = 2 * self.pos_bnd + 1
self.rpe_table = torch.nn.Parameter(torch.zeros(3 * self.rpe_num, num_heads))
torch.nn.init.trunc_normal_(self.rpe_table, std=0.02)
def forward(self, coord):
idx = (
coord.clamp(-self.pos_bnd, self.pos_bnd) # clamp into bnd
+ self.pos_bnd # relative position to positive index
+ torch.arange(3, device=coord.device) * self.rpe_num # x, y, z stride
)
out = self.rpe_table.index_select(0, idx.reshape(-1))
out = out.view(idx.shape + (-1,)).sum(3)
out = out.permute(0, 3, 1, 2) # (N, K, K, H) -> (N, H, K, K)
return out
class QueryKeyNorm(nn.Module):
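    """QK-norm: a LayerNorm without affine parameters applied to the query and
    key slices of a packed qkv tensor; values pass through unchanged. This
    bounds the attention logits and stabilizes training.
    """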
def __init__(self, channels, num_heads):
        super().__init__()
self.num_heads = num_heads
self.norm = nn.LayerNorm(channels // num_heads, elementwise_affine=False)
def forward(self, qkv):
H = self.num_heads
        # unpack packed qkv: (N, 3*C) -> (3, N, H, C // H)
        qkv = rearrange(qkv, 'N (S H Ch) -> S N H Ch', H=H, S=3)
q, k, v = qkv.unbind(dim=0)
# q, k, v: [N, H, C // H]
q_norm = self.norm(q)
k_norm = self.norm(k)
# qkv_norm: [3, N, H, C // H]
qkv_norm = torch.stack([q_norm, k_norm, v])
qkv_norm = rearrange(qkv_norm, 'S N H Ch -> N (S H Ch)')
return qkv_norm
class SerializedAttention(PointModule):
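    """Multi-head self-attention over fixed-size patches of serialized points.

    Points are ordered along a space-filling curve and grouped into patches of
    `patch_size`; attention runs either through FlashAttention's varlen
    qkv-packed kernel or a vanilla softmax path with optional RPE.
    """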
def __init__(
self,
channels,
num_heads,
patch_size,
qkv_bias=True,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0,
order_index=0,
enable_rpe=False,
enable_flash=True,
upcast_attention=True,
upcast_softmax=True,
enable_qknorm=False,
):
super().__init__()
assert channels % num_heads == 0, f"channels {channels} must be divisible by num_heads {num_heads}"
self.channels = channels
self.num_heads = num_heads
self.scale = qk_scale or (channels // num_heads) ** -0.5
self.order_index = order_index
self.upcast_attention = upcast_attention
self.upcast_softmax = upcast_softmax
self.enable_rpe = enable_rpe
self.enable_flash = enable_flash
self.enable_qknorm = enable_qknorm
if enable_qknorm:
self.qknorm = QueryKeyNorm(channels, num_heads)
else:
print("WARNING: enable_qknorm is False in PTv3Object and training may be fragile")
if enable_flash:
            assert (
                enable_rpe is False
            ), "Set enable_rpe to False when enabling Flash Attention"
            assert (
                upcast_attention is False
            ), "Set upcast_attention to False when enabling Flash Attention"
            assert (
                upcast_softmax is False
            ), "Set upcast_softmax to False when enabling Flash Attention"
assert flash_attn is not None, "Make sure flash_attn is installed."
self.patch_size = patch_size
self.attn_drop = attn_drop
else:
            # when flash attention is disabled we still avoid attention masks;
            # the effective patch size is instead set at runtime to
            # min(patch_size_max, number of points in the smallest batch entry)
self.patch_size_max = patch_size
self.patch_size = 0
self.attn_drop = torch.nn.Dropout(attn_drop)
self.qkv = torch.nn.Linear(channels, channels * 3, bias=qkv_bias)
self.proj = torch.nn.Linear(channels, channels)
self.proj_drop = torch.nn.Dropout(proj_drop)
self.softmax = torch.nn.Softmax(dim=-1)
self.rpe = RPE(patch_size, num_heads) if self.enable_rpe else None
@torch.no_grad()
def get_rel_pos(self, point, order):
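        """Compute (and cache on the Point) the relative grid coordinates
        between all point pairs within each K-point patch, for RPE lookup."""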
K = self.patch_size
rel_pos_key = f"rel_pos_{self.order_index}"
if rel_pos_key not in point.keys():
grid_coord = point.grid_coord[order]
grid_coord = grid_coord.reshape(-1, K, 3)
point[rel_pos_key] = grid_coord.unsqueeze(2) - grid_coord.unsqueeze(1)
return point[rel_pos_key]
@torch.no_grad()
def get_padding_and_inverse(self, point):
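        """Build (and cache on the Point) the padding bookkeeping tensors.

        Returns (pad, unpad, cu_seqlens): `pad` gathers original point indices
        into a layout where every batch entry is a multiple of patch_size
        (partial tail patches are filled by repeating points from the previous
        patch), `unpad` inverts that mapping, and `cu_seqlens` marks patch
        boundaries for flash attention.
        """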
pad_key = "pad"
unpad_key = "unpad"
cu_seqlens_key = "cu_seqlens_key"
if (
pad_key not in point.keys()
or unpad_key not in point.keys()
or cu_seqlens_key not in point.keys()
):
offset = point.offset
bincount = offset2bincount(offset)
bincount_pad = (
torch.div(
bincount + self.patch_size - 1,
self.patch_size,
rounding_mode="trunc",
)
* self.patch_size
)
            # only pad a batch entry when its point count exceeds patch_size
mask_pad = bincount > self.patch_size
bincount_pad = ~mask_pad * bincount + mask_pad * bincount_pad
_offset = nn.functional.pad(offset, (1, 0))
_offset_pad = nn.functional.pad(torch.cumsum(bincount_pad, dim=0), (1, 0))
pad = torch.arange(_offset_pad[-1], device=offset.device)
unpad = torch.arange(_offset[-1], device=offset.device)
cu_seqlens = []
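            # for each batch entry: shift indices between the padded and
            # unpadded layouts, duplicate points from the preceding patch into
            # the padded tail, and record patch start offsets (cu_seqlens)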
for i in range(len(offset)):
unpad[_offset[i] : _offset[i + 1]] += _offset_pad[i] - _offset[i]
if bincount[i] != bincount_pad[i]:
pad[
_offset_pad[i + 1]
- self.patch_size
+ (bincount[i] % self.patch_size) : _offset_pad[i + 1]
] = pad[
_offset_pad[i + 1]
- 2 * self.patch_size
+ (bincount[i] % self.patch_size) : _offset_pad[i + 1]
- self.patch_size
]
pad[_offset_pad[i] : _offset_pad[i + 1]] -= _offset_pad[i] - _offset[i]
cu_seqlens.append(
torch.arange(
_offset_pad[i],
_offset_pad[i + 1],
step=self.patch_size,
dtype=torch.int32,
device=offset.device,
)
)
point[pad_key] = pad
point[unpad_key] = unpad
point[cu_seqlens_key] = nn.functional.pad(
torch.concat(cu_seqlens), (0, 1), value=_offset_pad[-1]
)
return point[pad_key], point[unpad_key], point[cu_seqlens_key]
def forward(self, point):
if not self.enable_flash:
self.patch_size = min(
                offset2bincount(point.offset).min().item(), self.patch_size_max
)
H = self.num_heads
K = self.patch_size
C = self.channels
pad, unpad, cu_seqlens = self.get_padding_and_inverse(point)
order = point.serialized_order[self.order_index][pad]
inverse = unpad[point.serialized_inverse[self.order_index]]
# padding and reshape feat and batch for serialized point patch
qkv = self.qkv(point.feat)[order]
if self.enable_qknorm:
qkv = self.qknorm(qkv)
if not self.enable_flash:
# encode and reshape qkv: (N', K, 3, H, C') => (3, N', H, K, C')
q, k, v = (
qkv.reshape(-1, K, 3, H, C // H).permute(2, 0, 3, 1, 4).unbind(dim=0)
)
# attn
if self.upcast_attention:
q = q.float()
k = k.float()
attn = (q * self.scale) @ k.transpose(-2, -1) # (N', H, K, K)
if self.enable_rpe:
attn = attn + self.rpe(self.get_rel_pos(point, order))
if self.upcast_softmax:
attn = attn.float()
attn = self.softmax(attn)
attn = self.attn_drop(attn).to(qkv.dtype)
feat = (attn @ v).transpose(1, 2).reshape(-1, C)
else:
feat = flash_attn.flash_attn_varlen_qkvpacked_func(
qkv.half().reshape(-1, 3, H, C // H),
cu_seqlens,
max_seqlen=self.patch_size,
dropout_p=self.attn_drop if self.training else 0,
softmax_scale=self.scale,
).reshape(-1, C)
feat = feat.to(qkv.dtype)
feat = feat[inverse]
# ffn
feat = self.proj(feat)
feat = self.proj_drop(feat)
point.feat = feat
return point
class MLP(nn.Module):
def __init__(
self,
in_channels,
hidden_channels=None,
out_channels=None,
act_layer=nn.GELU,
drop=0.0,
):
super().__init__()
out_channels = out_channels or in_channels
hidden_channels = hidden_channels or in_channels
self.fc1 = nn.Linear(in_channels, hidden_channels)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_channels, out_channels)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class Block(PointModule):
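    """Transformer block on a Point structure: a sparse-conv positional
    encoding branch (cpe), SerializedAttention, and an MLP, each wrapped in a
    residual connection with pre- or post-normalization.
    """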
def __init__(
self,
channels,
num_heads,
patch_size=48,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0,
drop_path=0.0,
norm_layer=nn.LayerNorm,
act_layer=nn.GELU,
pre_norm=True,
order_index=0,
cpe_indice_key=None,
enable_rpe=False,
enable_flash=True,
upcast_attention=True,
upcast_softmax=True,
enable_qknorm=False,
):
super().__init__()
self.channels = channels
self.pre_norm = pre_norm
self.cpe = PointSequential(
spconv.SubMConv3d(
channels,
channels,
kernel_size=3,
bias=True,
indice_key=cpe_indice_key,
),
nn.Linear(channels, channels),
norm_layer(channels),
)
self.norm1 = PointSequential(norm_layer(channels))
self.attn = SerializedAttention(
channels=channels,
patch_size=patch_size,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=proj_drop,
order_index=order_index,
enable_rpe=enable_rpe,
enable_flash=enable_flash,
upcast_attention=upcast_attention,
upcast_softmax=upcast_softmax,
enable_qknorm=enable_qknorm,
)
self.norm2 = PointSequential(norm_layer(channels))
self.mlp = PointSequential(
MLP(
in_channels=channels,
hidden_channels=int(channels * mlp_ratio),
out_channels=channels,
act_layer=act_layer,
drop=proj_drop,
)
)
self.drop_path = PointSequential(
DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
)
def forward(self, point: Point):
shortcut = point.feat
point = self.cpe(point)
point.feat = shortcut + point.feat
shortcut = point.feat
if self.pre_norm:
point = self.norm1(point)
point = self.drop_path(self.attn(point))
point.feat = shortcut + point.feat
if not self.pre_norm:
point = self.norm1(point)
shortcut = point.feat
if self.pre_norm:
point = self.norm2(point)
point = self.drop_path(self.mlp(point))
point.feat = shortcut + point.feat
if not self.pre_norm:
point = self.norm2(point)
        # replace_feature returns a new SparseConvTensor, so reassign it
        point.sparse_conv_feat = point.sparse_conv_feat.replace_feature(point.feat)
return point
class SerializedPooling(PointModule):
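    """Grid pooling on serialized codes.

    Points sharing a serialized-code prefix (coarsened by pooling_depth bits
    per axis) are merged into one point: features are reduced with
    torch_scatter.segment_csr, and order/inverse are re-derived for the
    coarser serialization depth.
    """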
def __init__(
self,
in_channels,
out_channels,
stride=2,
norm_layer=None,
act_layer=None,
reduce="max",
shuffle_orders=True,
traceable=True, # record parent and cluster
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
assert stride == 2 ** (math.ceil(stride) - 1).bit_length() # 2, 4, 8
# TODO: add support to grid pool (any stride)
self.stride = stride
assert reduce in ["sum", "mean", "min", "max"]
self.reduce = reduce
self.shuffle_orders = shuffle_orders
self.traceable = traceable
self.proj = nn.Linear(in_channels, out_channels)
        # default to None so the `is not None` checks in forward cannot raise
        if norm_layer is not None:
            self.norm = PointSequential(norm_layer(out_channels))
        else:
            self.norm = None
        if act_layer is not None:
            self.act = PointSequential(act_layer())
        else:
            self.act = None
def forward(self, point: Point):
pooling_depth = (math.ceil(self.stride) - 1).bit_length()
if pooling_depth > point.serialized_depth:
pooling_depth = 0
        assert {
            "serialized_code",
            "serialized_order",
            "serialized_inverse",
            "serialized_depth",
        }.issubset(
            point.keys()
        ), "Run point.serialization() on the point cloud before SerializedPooling"
code = point.serialized_code >> pooling_depth * 3
        _, cluster, counts = torch.unique(
code[0],
sorted=True,
return_inverse=True,
return_counts=True,
)
# indices of point sorted by cluster, for torch_scatter.segment_csr
_, indices = torch.sort(cluster)
# index pointer for sorted point, for torch_scatter.segment_csr
idx_ptr = torch.cat([counts.new_zeros(1), torch.cumsum(counts, dim=0)])
# head_indices of each cluster, for reduce attr e.g. code, batch
head_indices = indices[idx_ptr[:-1]]
# generate down code, order, inverse
code = code[:, head_indices]
order = torch.argsort(code)
inverse = torch.zeros_like(order).scatter_(
dim=1,
index=order,
src=torch.arange(0, code.shape[1], device=order.device).repeat(
code.shape[0], 1
),
)
if self.shuffle_orders:
perm = torch.randperm(code.shape[0])
code = code[perm]
order = order[perm]
inverse = inverse[perm]
# collect information
point_dict = Dict(
feat=torch_scatter.segment_csr(
self.proj(point.feat)[indices], idx_ptr, reduce=self.reduce
),
coord=torch_scatter.segment_csr(
point.coord[indices], idx_ptr, reduce="mean"
),
grid_coord=point.grid_coord[head_indices] >> pooling_depth,
serialized_code=code,
serialized_order=order,
serialized_inverse=inverse,
serialized_depth=point.serialized_depth - pooling_depth,
batch=point.batch[head_indices],
)
if "condition" in point.keys():
point_dict["condition"] = point.condition
if "context" in point.keys():
point_dict["context"] = point.context
if self.traceable:
point_dict["pooling_inverse"] = cluster
point_dict["pooling_parent"] = point
point = Point(point_dict)
if self.norm is not None:
point = self.norm(point)
if self.act is not None:
point = self.act(point)
point.sparsify()
return point
class SerializedUnpooling(PointModule):
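    """Inverse of SerializedPooling: projects pooled and skip features to a
    common width, then scatters pooled features back to the parent resolution
    through the recorded pooling_inverse mapping.
    """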
def __init__(
self,
in_channels,
skip_channels,
out_channels,
norm_layer=None,
act_layer=None,
traceable=False, # record parent and cluster
):
super().__init__()
self.proj = PointSequential(nn.Linear(in_channels, out_channels))
self.proj_skip = PointSequential(nn.Linear(skip_channels, out_channels))
if norm_layer is not None:
self.proj.add(norm_layer(out_channels))
self.proj_skip.add(norm_layer(out_channels))
if act_layer is not None:
self.proj.add(act_layer())
self.proj_skip.add(act_layer())
self.traceable = traceable
def forward(self, point):
assert "pooling_parent" in point.keys()
assert "pooling_inverse" in point.keys()
parent = point.pop("pooling_parent")
inverse = point.pop("pooling_inverse")
point = self.proj(point)
parent = self.proj_skip(parent)
parent.feat = parent.feat + point.feat[inverse]
if self.traceable:
parent["unpooling_parent"] = point
return parent
class Embedding(PointModule):
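    """Sparse-convolution stem that lifts input features to embed_channels,
    optionally with a parallel linear residual branch (res_linear).
    """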
def __init__(
self,
in_channels,
embed_channels,
norm_layer=None,
act_layer=None,
res_linear=False,
):
super().__init__()
self.in_channels = in_channels
self.embed_channels = embed_channels
# TODO: check remove spconv
self.stem = PointSequential(
conv=spconv.SubMConv3d(
in_channels,
embed_channels,
kernel_size=5,
padding=1,
bias=False,
indice_key="stem",
)
)
if norm_layer is not None:
self.stem.add(norm_layer(embed_channels), name="norm")
if act_layer is not None:
self.stem.add(act_layer(), name="act")
if res_linear:
self.res_linear = nn.Linear(in_channels, embed_channels)
else:
self.res_linear = None
def forward(self, point: Point):
        if self.res_linear is not None:
res_feature = self.res_linear(point.feat)
point = self.stem(point)
        if self.res_linear is not None:
point.feat = point.feat + res_feature
point.sparse_conv_feat = point.sparse_conv_feat.replace_feature(point.feat)
return point
class PointTransformerV3Object(PointModule):
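    """Encoder-only Point Transformer V3 for object-level point clouds.

    The input is serialized along the given space-filling-curve orders,
    embedded by a sparse-conv stem, and processed by stacked Blocks per stage.
    Unlike the original PTv3, stages are bridged by a plain Linear channel
    projection rather than serialized pooling, so resolution is preserved.
    """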
def __init__(
self,
in_channels=9,
order=("z", "z-trans", "hilbert", "hilbert-trans"),
stride=(),
enc_depths=(3, 3, 3, 6, 16),
enc_channels=(32, 64, 128, 256, 384),
enc_num_head=(2, 4, 8, 16, 24),
enc_patch_size=(1024, 1024, 1024, 1024, 1024),
mlp_ratio=4,
qkv_bias=True,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0,
drop_path=0.0,
pre_norm=True,
shuffle_orders=True,
enable_rpe=False,
enable_flash=True,
upcast_attention=False,
upcast_softmax=False,
cls_mode=False,
enable_qknorm=False,
layer_norm=False,
res_linear=True,
):
super().__init__()
self.num_stages = len(enc_depths)
self.order = [order] if isinstance(order, str) else order
self.cls_mode = cls_mode
self.shuffle_orders = shuffle_orders
# norm layers
        if layer_norm:
            bn_layer = nn.LayerNorm
        else:
            print("WARNING: using BatchNorm1d in PointTransformerV3Object")
            bn_layer = partial(nn.BatchNorm1d, eps=1e-3, momentum=0.01)
        ln_layer = nn.LayerNorm
# activation layers
act_layer = nn.GELU
self.embedding = Embedding(
in_channels=in_channels,
embed_channels=enc_channels[0],
norm_layer=bn_layer,
act_layer=act_layer,
res_linear=res_linear,
)
# encoder
enc_drop_path = [
x.item() for x in torch.linspace(0, drop_path, sum(enc_depths))
]
self.enc = PointSequential()
for s in range(self.num_stages):
enc_drop_path_ = enc_drop_path[
sum(enc_depths[:s]) : sum(enc_depths[: s + 1])
]
enc = PointSequential()
if s > 0:
enc.add(nn.Linear(enc_channels[s - 1], enc_channels[s]))
for i in range(enc_depths[s]):
enc.add(
Block(
channels=enc_channels[s],
num_heads=enc_num_head[s],
patch_size=enc_patch_size[s],
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=proj_drop,
drop_path=enc_drop_path_[i],
norm_layer=ln_layer,
act_layer=act_layer,
pre_norm=pre_norm,
order_index=i % len(self.order),
cpe_indice_key=f"stage{s}",
enable_rpe=enable_rpe,
enable_flash=enable_flash,
upcast_attention=upcast_attention,
upcast_softmax=upcast_softmax,
enable_qknorm=enable_qknorm,
),
name=f"block{i}",
)
if len(enc) != 0:
self.enc.add(module=enc, name=f"enc{s}")
def forward(self, data_dict, min_coord=None):
point = Point(data_dict)
point.serialization(order=self.order, shuffle_orders=self.shuffle_orders, min_coord=min_coord)
point.sparsify()
point = self.embedding(point)
point = self.enc(point)
return point
def get_encoder(pretrained_path: Union[str, None] = None, freeze_encoder: bool = False, **kwargs) -> PointTransformerV3Object:
    point_encoder = PointTransformerV3Object(**kwargs)
    if pretrained_path is not None:
        checkpoint = torch.load(pretrained_path)
        state_dict = checkpoint["state_dict"]
        # strip any DataParallel/DDP "module." prefix from parameter names
        state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
        point_encoder.load_state_dict(state_dict, strict=False)
    if freeze_encoder:
        # keep only the residual stem projection and QK-norm trainable
        for name, param in point_encoder.named_parameters():
            if 'res_linear' not in name and 'qknorm' not in name:
                param.requires_grad = False
    return point_encoder
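

# A minimal smoke-test sketch of how this encoder might be driven. The input
# dict keys ("coord", "grid_coord", "feat", "offset") follow the Pointcept
# Point convention assumed by `Point(data_dict)`; the grid size of 0.01 and
# the point count are illustrative assumptions, not values from this repo.
# Note that spconv kernels generally expect a CUDA device.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # fall back to vanilla attention when flash_attn is not installed
    encoder = get_encoder(enable_flash=flash_attn is not None).to(device)
    n = 4096
    coord = torch.rand(n, 3, device=device)
    data_dict = {
        "coord": coord,                               # float xyz positions
        "grid_coord": (coord / 0.01).int(),           # non-negative voxel coords
        "feat": torch.rand(n, 9, device=device),      # matches in_channels=9
        "offset": torch.tensor([n], device=device),   # cumulative batch sizes
    }
    point = encoder(data_dict)
    print(point.feat.shape)  # expected: (n, enc_channels[-1])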