import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .kan import KANLayer


class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        """
        Standard positional encoding with sin/cos functions + LayerNorm to preserve
        temporal relationships between frames throughout sequence modeling.
        """
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.d_model = d_model

        # Precompute the sinusoidal table once; shape (max_len, 1, d_model) so it
        # broadcasts over the batch dimension of (seq_len, batch, d_model) inputs.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(1)
        self.register_buffer("pe", pe)
        self.norm_pe = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input tensor of shape (seq_len, batch_size, d_model)
        Returns:
            Tensor with positional encodings added and normalized
        """
        seq_len = x.size(0)
        x2 = x + self.pe[:seq_len, :]
        x2 = self.norm_pe(x2)
        return self.dropout(x2)


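# Illustrative usage sketch (not part of the original module): PositionalEncoding
# expects inputs shaped (seq_len, batch, d_model), for example
#   pos_enc = PositionalEncoding(d_model=256, dropout=0.1)
#   out = pos_enc(torch.zeros(60, 8, 256))  # -> (60, 8, 256)
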
class Encoder_TRANSFORMER(nn.Module):
    """
    Encoder module using Transformer architecture with KAN layers.
    Key components:
    - KANLayer, which replaces linear projections with learnable 1D splines;
    - Transformer encoder processing temporal dependencies.
    """

    def __init__(
        self,
        modeltype,
        njoints: int,
        nfeats: int,
        num_frames: int,
        num_classes: int,
        translation,
        pose_rep,
        glob,
        glob_rot,
        latent_dim: int = 256,
        ff_size: int = 1024,
        num_layers: int = 4,
        num_heads: int = 4,
        dropout: float = 0.1,
        activation: str = "gelu",
        **kargs
    ):
        super().__init__()
        self.njoints = njoints
        self.nfeats = nfeats
        self.num_frames = num_frames
        self.num_classes = num_classes
        self.pose_rep = pose_rep
        self.glob = glob
        self.glob_rot = glob_rot
        self.translation = translation

        self.latent_dim = latent_dim
        self.ff_size = ff_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dropout = dropout
        self.activation = activation

        self.input_feats = self.njoints * self.nfeats

        # Learnable query tokens prepended to the sequence; their encoded outputs
        # are read off as the distribution parameters (mu, logvar).
        self.muQuery = nn.Parameter(torch.randn(1, self.latent_dim))
        self.sigmaQuery = nn.Parameter(torch.randn(1, self.latent_dim))

        # KAN-based embedding of flattened skeleton features into the latent space.
        self.skelEmbedding = KANLayer(self.input_feats, self.latent_dim)

        self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.latent_dim,
            nhead=self.num_heads,
            dim_feedforward=self.ff_size,
            dropout=self.dropout,
            activation=self.activation
        )
        self.seqTransEncoder = nn.TransformerEncoder(encoder_layer, num_layers=self.num_layers)
        self.encoder_norm = nn.LayerNorm(self.latent_dim)

    def forward(self, batch: dict) -> dict:
        """
        batch["x"]:    (batch, njoints, nfeats, nframes)
        batch["y"]:    (batch,) class labels; if None, zeros are used
        batch["mask"]: (batch, nframes) boolean mask of valid frames
        """
        x, y, mask = batch["x"], batch["y"], batch["mask"]
        bs, nj, nf, nf2 = x.shape
        assert nf2 == self.num_frames, "Frame dimension mismatch"

        # (batch, njoints, nfeats, nframes) -> (nframes, batch, njoints * nfeats)
        x_seq = x.permute(3, 0, 1, 2).reshape(self.num_frames, bs, self.input_feats)

        x_emb = self.skelEmbedding(x_seq)

        if y is None:
            y = torch.zeros(bs, dtype=torch.long, device=x.device)
        else:
            y = y.clamp(0, self.num_classes - 1)

        # Prepend the distribution queries: (1, bs, latent_dim) each.
        mu_init = self.muQuery.expand(bs, -1).unsqueeze(0)
        sigma_init = self.sigmaQuery.expand(bs, -1).unsqueeze(0)
        xcat = torch.cat((mu_init, sigma_init, x_emb), dim=0)

        # Extend the padding mask so the two query tokens are always attended to.
        if mask.dtype != torch.bool:
            mask = mask.bool()
        mu_sigma_mask = torch.ones((bs, 2), dtype=torch.bool, device=x.device)
        mask_seq = torch.cat((mu_sigma_mask, mask), dim=1)

        xcat_pe = self.sequence_pos_encoder(xcat)

        encoded = self.seqTransEncoder(
            xcat_pe,
            src_key_padding_mask=~mask_seq
        )
        encoded = self.encoder_norm(encoded)

        # The first two output tokens carry the distribution parameters.
        mu = encoded[0]
        logvar = encoded[1]

        # Reparameterization trick: z = mu + eps * sigma, with sigma = exp(0.5 * logvar).
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = mu + eps * std

        return {"mu": mu, "logvar": logvar, "z": z}


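# Illustrative usage sketch for Encoder_TRANSFORMER (not part of the original module);
# shapes follow its forward() docstring, constructor metadata values are dummy placeholders:
#   enc = Encoder_TRANSFORMER("cvae", njoints=25, nfeats=6, num_frames=60, num_classes=12,
#                             translation=True, pose_rep="rot6d", glob=True, glob_rot=None)
#   batch = {"x": torch.randn(8, 25, 6, 60), "y": torch.zeros(8, dtype=torch.long),
#            "mask": torch.ones(8, 60, dtype=torch.bool)}
#   out = enc(batch)  # out["mu"], out["logvar"], out["z"]: (8, 256)
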
class Decoder_TRANSFORMER(nn.Module):
    """
    Decoder module using Transformer architecture with a KAN layer:
    - KANLayer: final projection layer for skeleton reconstruction;
    - Transformer decoder: parallel generation of all frames from time queries
      conditioned on the latent code.
    """

    def __init__(
        self,
        modeltype,
        njoints: int,
        nfeats: int,
        num_frames: int,
        num_classes: int,
        translation,
        pose_rep,
        glob,
        glob_rot,
        latent_dim: int = 256,
        ff_size: int = 1024,
        num_layers: int = 4,
        num_heads: int = 4,
        dropout: float = 0.1,
        activation: str = "gelu",
        **kargs
    ):
        super().__init__()

        self.njoints = njoints
        self.nfeats = nfeats
        self.num_frames = num_frames
        self.num_classes = num_classes
        self.pose_rep = pose_rep
        self.glob = glob
        self.glob_rot = glob_rot
        self.translation = translation

        self.latent_dim = latent_dim
        self.ff_size = ff_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dropout = dropout
        self.activation = activation

        self.input_feats = self.njoints * self.nfeats

        self.actionBiases = nn.Parameter(torch.randn(1, self.latent_dim))

        self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout)

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=self.latent_dim,
            nhead=self.num_heads,
            dim_feedforward=self.ff_size,
            dropout=self.dropout,
            activation=self.activation
        )
        self.seqTransDecoder = nn.TransformerDecoder(decoder_layer, num_layers=self.num_layers)
        self.decoder_norm = nn.LayerNorm(self.latent_dim)

        # KAN-based projection from the latent space back to flattened skeleton features.
        self.finallayer = KANLayer(self.latent_dim, self.input_feats)

    def forward(self, batch: dict, use_text_emb: bool = False) -> dict:
        """
        Forward pass for the decoder.
        Args:
            batch: Dictionary containing latent codes and metadata
            use_text_emb: Whether to use text embeddings instead of latent codes
        Returns:
            Dictionary with the generated output added under "output" (or "txt_output")
        """
        z = batch["z"]
        y = batch["y"]
        mask = batch["mask"]
        lengths = batch.get("lengths", None)
        bs, nframes = mask.shape
        nj, nf = self.njoints, self.nfeats

        if use_text_emb:
            z = batch["clip_text_emb"]

        # Normalize the latent code and treat it as a single memory token: (1, bs, latent_dim).
        z = F.layer_norm(z, (self.latent_dim,))
        z = z.unsqueeze(0)

        # One zero-initialized time query per output frame, plus positional encoding.
        timequeries = torch.zeros(nframes, bs, self.latent_dim, device=z.device)
        timequeries_pe = self.sequence_pos_encoder(timequeries)

        if mask.dtype != torch.bool:
            mask = mask.bool()

        dec_out = self.seqTransDecoder(
            tgt=timequeries_pe,
            memory=z,
            tgt_key_padding_mask=~mask
        )
        dec_out = self.decoder_norm(dec_out)

        # Project back to skeleton features and zero out padded frames.
        skel_feats = self.finallayer(dec_out)
        skel_feats = skel_feats.view(nframes, bs, nj, nf)

        mask_t = mask.T
        skel_feats[~mask_t] = 0.0

        # (nframes, bs, njoints, nfeats) -> (bs, njoints, nfeats, nframes)
        output = skel_feats.permute(1, 2, 3, 0).contiguous()

        if use_text_emb:
            batch["txt_output"] = output
        else:
            batch["output"] = output

        return batch
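

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (illustrative, not part of the original module).
# Assumptions: KANLayer maps (..., in_features) -> (..., out_features) like
# nn.Linear (as its usage above implies), and the constructor metadata values
# below (modeltype, pose_rep, glob, glob_rot, translation) are dummy placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    bs, njoints, nfeats, nframes, nclasses = 4, 25, 6, 60, 12

    encoder = Encoder_TRANSFORMER("cvae", njoints, nfeats, nframes, nclasses,
                                  translation=True, pose_rep="rot6d",
                                  glob=True, glob_rot=None)
    decoder = Decoder_TRANSFORMER("cvae", njoints, nfeats, nframes, nclasses,
                                  translation=True, pose_rep="rot6d",
                                  glob=True, glob_rot=None)

    batch = {
        "x": torch.randn(bs, njoints, nfeats, nframes),
        "y": torch.zeros(bs, dtype=torch.long),
        "mask": torch.ones(bs, nframes, dtype=torch.bool),
    }

    batch.update(encoder(batch))   # adds "mu", "logvar", "z"
    batch = decoder(batch)         # adds "output"
    print(batch["z"].shape)        # expected: (4, 256)
    print(batch["output"].shape)   # expected: (4, 25, 6, 60)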