# File size: 11,335 Bytes | fe64bad
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from .kan import KANLayer
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding followed by LayerNorm and dropout.

    A table of sin/cos encodings is precomputed once for ``max_len``
    positions and stored as the non-trainable buffer ``pe`` with shape
    ``(max_len, 1, d_model)``. In ``forward`` the first ``seq_len`` rows are
    added to the input, the sum is layer-normalized and dropout is applied.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        """
        Args:
            d_model: Size of the feature (last) dimension. Odd values are
                supported: the table then has one fewer cosine column than
                sine columns.
            dropout: Dropout probability applied to the normalized output.
            max_len: Number of positions to precompute; inputs to
                ``forward`` must not be longer than this.
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.d_model = d_model

        # Precompute positional encodings (PE) using sinusoidal functions.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # There are only d_model // 2 odd columns; for odd d_model that is one
        # fewer than len(div_term), so trim to avoid a shape mismatch.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(1)  # (max_len, 1, d_model)
        self.register_buffer("pe", pe)

        self.norm_pe = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional encodings to ``x``, then normalize and apply dropout.

        Args:
            x: Input tensor of shape (seq_len, batch_size, d_model).

        Returns:
            Tensor of the same shape with positional encodings added and
            LayerNorm + dropout applied.
        """
        seq_len = x.size(0)
        # pe is (max_len, 1, d_model); the middle axis broadcasts over batch.
        x2 = x + self.pe[:seq_len, :]
        x2 = self.norm_pe(x2)
        return self.dropout(x2)
class Encoder_TRANSFORMER(nn.Module):
    """
    Encoder module using a Transformer architecture with a KAN embedding layer.

    Key components:
    - KANLayer (imported from ``.kan``) embeds the flattened per-frame joint
      features into the latent dimension; per the original author it replaces
      a linear projection with learnable 1D splines;
    - two learnable query tokens (``muQuery``, ``sigmaQuery``) are prepended
      to the frame sequence;
    - a Transformer encoder processes the sequence, and the outputs at the
      two query positions are read as ``mu`` and ``logvar``.
    """

    def __init__(
        self,
        modeltype,
        njoints: int,
        nfeats: int,
        num_frames: int,
        num_classes: int,
        translation,
        pose_rep,
        glob,
        glob_rot,
        latent_dim: int = 256,
        ff_size: int = 1024,
        num_layers: int = 4,
        num_heads: int = 4,
        dropout: float = 0.1,
        activation: str = "gelu",
        **kargs
    ):
        """
        Args:
            modeltype: Accepted for interface compatibility; not stored.
            njoints: Number of joints per frame.
            nfeats: Number of features per joint.
            num_frames: Exact number of frames ``forward`` expects.
            num_classes: Number of action classes (stored only).
            translation, pose_rep, glob, glob_rot: Stored on the instance;
                not otherwise used by this class.
            latent_dim: Transformer model / latent dimensionality.
            ff_size: Feed-forward width of each Transformer layer.
            num_layers: Number of Transformer encoder layers.
            num_heads: Number of attention heads.
            dropout: Dropout probability (positional encoder and Transformer).
            activation: Activation name passed to the Transformer layer.
            **kargs: Extra keyword arguments are accepted and ignored.
        """
        super().__init__()
        self.njoints = njoints
        self.nfeats = nfeats
        self.num_frames = num_frames
        self.num_classes = num_classes

        self.pose_rep = pose_rep
        self.glob = glob
        self.glob_rot = glob_rot
        self.translation = translation

        self.latent_dim = latent_dim  # Latent space dimensionality
        self.ff_size = ff_size  # Feedforward network size
        self.num_layers = num_layers  # Transformer layers
        self.num_heads = num_heads  # Multi-head attention heads
        self.dropout = dropout
        self.activation = activation

        self.input_feats = self.njoints * self.nfeats  # Input feature dimension

        # Learnable query tokens whose encoder outputs become mu and logvar.
        self.muQuery = nn.Parameter(torch.randn(1, self.latent_dim))
        self.sigmaQuery = nn.Parameter(torch.randn(1, self.latent_dim))

        # Skeleton embedding: njoints * nfeats -> latent_dim.
        self.skelEmbedding = KANLayer(self.input_feats, self.latent_dim)

        # Positional encoding for temporal alignment.
        self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout)

        # Transformer encoder with multi-head attention (sequence-first layout).
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.latent_dim,
            nhead=self.num_heads,
            dim_feedforward=self.ff_size,
            dropout=self.dropout,
            activation=self.activation,
        )
        self.seqTransEncoder = nn.TransformerEncoder(
            encoder_layer, num_layers=self.num_layers
        )
        self.encoder_norm = nn.LayerNorm(self.latent_dim)  # Final normalization

    def forward(self, batch: dict) -> dict:
        """Encode a motion batch into ``mu``, ``logvar`` and a sampled ``z``.

        Args:
            batch: Dictionary with
                ``batch["x"]``: (batch, njoints, nfeats, nframes) motion tensor;
                ``batch["mask"]``: (batch, nframes) mask, truthy for real
                frames. Non-boolean masks are converted with ``.bool()``.
                Class labels (``batch["y"]``) are not consumed by the encoder.

        Returns:
            Dictionary with ``"mu"``, ``"logvar"`` and ``"z"``, each of shape
            (batch, latent_dim). ``z`` is drawn with the reparameterization
            trick, so it is stochastic.

        Raises:
            AssertionError: If the frame axis of ``x`` differs from
                ``self.num_frames``.
        """
        x, mask = batch["x"], batch["mask"]
        bs, _, _, nframes = x.shape
        assert nframes == self.num_frames, "Frame dimension mismatch"

        # The padding mask is inverted with `~` below, which is only a logical
        # NOT for boolean tensors; coerce here, as the decoder does.
        if mask.dtype != torch.bool:
            mask = mask.bool()

        # Reshape input: (nframes, batch, njoints * nfeats)
        x_seq = x.permute(3, 0, 1, 2).reshape(self.num_frames, bs, self.input_feats)
        x_emb = self.skelEmbedding(x_seq)  # (nframes, batch, latent_dim)

        # Broadcast the query tokens over the batch: (1, batch, latent_dim)
        mu_init = self.muQuery.expand(bs, -1).unsqueeze(0)
        sigma_init = self.sigmaQuery.expand(bs, -1).unsqueeze(0)

        # Transformer input: [mu, sigma, frames] -> (2 + nframes, batch, latent_dim)
        xcat = torch.cat((mu_init, sigma_init, x_emb), dim=0)

        # The two query tokens are always valid positions.
        mu_sigma_mask = torch.ones((bs, 2), dtype=torch.bool, device=x.device)
        mask_seq = torch.cat((mu_sigma_mask, mask), dim=1)  # (batch, 2 + nframes)

        xcat_pe = self.sequence_pos_encoder(xcat)

        encoded = self.seqTransEncoder(
            xcat_pe,
            src_key_padding_mask=~mask_seq,  # True = padding to ignore
        )  # (2 + nframes, batch, latent_dim)
        encoded = self.encoder_norm(encoded)

        # The outputs at the two query positions are mu and logvar.
        mu = encoded[0]  # (batch, latent_dim)
        logvar = encoded[1]  # (batch, latent_dim)

        # Reparameterization: z = mu + eps * std, with eps ~ N(0, I)
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = mu + eps * std

        return {"mu": mu, "logvar": logvar, "z": z}
class Decoder_TRANSFORMER(nn.Module):
    """
    Decoder module using a Transformer architecture with a KAN output layer.

    - Transformer decoder: all frames are decoded in one pass from zero
      "time queries" plus positional encodings, attending to the latent code
      as a single-token memory (no causal mask is applied);
    - KANLayer (imported from ``.kan``): final projection from the latent
      dimension back to flattened joint features.
    """

    def __init__(
        self,
        modeltype,
        njoints: int,
        nfeats: int,
        num_frames: int,
        num_classes: int,
        translation,
        pose_rep,
        glob,
        glob_rot,
        latent_dim: int = 256,
        ff_size: int = 1024,
        num_layers: int = 4,
        num_heads: int = 4,
        dropout: float = 0.1,
        activation: str = "gelu",
        **kargs
    ):
        """
        Args:
            modeltype: Accepted for interface compatibility; not stored.
            njoints: Number of joints per frame.
            nfeats: Number of features per joint.
            num_frames: Stored only; the decoded length comes from the mask.
            num_classes: Number of action classes (stored only).
            translation, pose_rep, glob, glob_rot: Stored on the instance;
                not otherwise used by this class.
            latent_dim: Transformer model / latent dimensionality.
            ff_size: Feed-forward width of each Transformer layer.
            num_layers: Number of Transformer decoder layers.
            num_heads: Number of attention heads.
            dropout: Dropout probability (positional encoder and Transformer).
            activation: Activation name passed to the Transformer layer.
            **kargs: Extra keyword arguments are accepted and ignored.
        """
        super().__init__()
        self.njoints = njoints
        self.nfeats = nfeats
        self.num_frames = num_frames
        self.num_classes = num_classes

        self.pose_rep = pose_rep
        self.glob = glob
        self.glob_rot = glob_rot
        self.translation = translation

        self.latent_dim = latent_dim
        self.ff_size = ff_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dropout = dropout
        self.activation = activation

        self.input_feats = self.njoints * self.nfeats

        # Bias parameter for action-specific generation. It is registered but
        # not read in forward(); kept so existing checkpoints still load.
        self.actionBiases = nn.Parameter(torch.randn(1, self.latent_dim))

        # Positional encoding for the temporal queries.
        self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout)

        # Transformer decoder (sequence-first layout).
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=self.latent_dim,
            nhead=self.num_heads,
            dim_feedforward=self.ff_size,
            dropout=self.dropout,
            activation=self.activation,
        )
        self.seqTransDecoder = nn.TransformerDecoder(
            decoder_layer, num_layers=self.num_layers
        )
        self.decoder_norm = nn.LayerNorm(self.latent_dim)  # Final normalization

        # Final projection: latent_dim -> njoints * nfeats.
        self.finallayer = KANLayer(self.latent_dim, self.input_feats)

    def forward(self, batch: dict, use_text_emb: bool = False) -> dict:
        """Decode a latent code into a motion sequence.

        Args:
            batch: Dictionary with
                ``batch["mask"]``: (batch, nframes) mask, truthy for real
                frames (converted to bool if needed);
                ``batch["z"]``: (batch, latent_dim) latent code, required
                when ``use_text_emb`` is False;
                ``batch["clip_text_emb"]``: (batch, latent_dim) embedding,
                required when ``use_text_emb`` is True.
            use_text_emb: Decode from ``batch["clip_text_emb"]`` instead of
                ``batch["z"]``.

        Returns:
            The same ``batch`` dictionary, mutated in place, with the result
            of shape (batch, njoints, nfeats, nframes) stored under
            ``"txt_output"`` (text mode) or ``"output"`` (latent mode).
            Padded frames are zeroed.
        """
        mask = batch["mask"]  # (batch, nframes)
        bs, nframes = mask.shape
        nj, nf = self.njoints, self.nfeats

        # Read only the conditioning that is actually used, so text-only
        # batches do not need to carry a "z" entry.
        z = batch["clip_text_emb"] if use_text_emb else batch["z"]

        # Parameter-free layer norm of the code, then add a length-1 sequence
        # axis so it acts as the decoder memory: (1, batch, latent_dim).
        z = F.layer_norm(z, (self.latent_dim,)).unsqueeze(0)

        # Time queries are zeros, so they carry only positional information.
        timequeries = torch.zeros(nframes, bs, self.latent_dim, device=z.device)
        timequeries_pe = self.sequence_pos_encoder(timequeries)

        # `~mask` below is only a logical NOT for boolean tensors.
        if mask.dtype != torch.bool:
            mask = mask.bool()

        dec_out = self.seqTransDecoder(
            tgt=timequeries_pe,
            memory=z,
            tgt_key_padding_mask=~mask,  # True = padding to ignore
        )  # (nframes, batch, latent_dim)
        dec_out = self.decoder_norm(dec_out)

        # Project to skeleton features and unflatten the joint axis.
        skel_feats = self.finallayer(dec_out)  # (nframes, batch, input_feats)
        skel_feats = skel_feats.reshape(nframes, bs, nj, nf)

        # Zero out padded frames. masked_fill is out-of-place, so the
        # projection output is not modified under autograd.
        padding = ~mask.T  # (nframes, batch)
        skel_feats = skel_feats.masked_fill(padding[:, :, None, None], 0.0)

        # Final output format: (batch, njoints, nfeats, nframes)
        output = skel_feats.permute(1, 2, 3, 0).contiguous()

        if use_text_emb:
            batch["txt_output"] = output
        else:
            batch["output"] = output
        return batch