smog / src /models /architectures /transformer.py
vonexel's picture
add: src
fe64bad verified
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from .kan import KANLayer
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding followed by LayerNorm and dropout.

    A fixed table of sin/cos values is precomputed once for ``max_len``
    positions and stored as a (non-trainable) buffer named ``pe``. On every
    forward pass the first ``seq_len`` rows are added to the input, the sum is
    layer-normalized and dropout is applied.

    Args:
        d_model: Size of the feature (last) dimension of the input.
        dropout: Dropout probability applied to the normalized output.
        max_len: Number of positions to precompute; inputs longer than this
            cannot be encoded.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.d_model = d_model

        # Precompute positional encodings (PE) using sinusoidal functions.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd-indexed column than
        # even-indexed ones, so trim the frequencies to avoid a shape mismatch.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(1)  # (max_len, 1, d_model) - broadcasts over batch
        self.register_buffer("pe", pe)

        self.norm_pe = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional encodings to ``x``, then normalize and apply dropout.

        Args:
            x: Input tensor of shape (seq_len, batch_size, d_model);
                ``seq_len`` must not exceed the ``max_len`` given at
                construction.

        Returns:
            Tensor of the same shape as ``x``.
        """
        seq_len = x.size(0)
        with_pe = x + self.pe[:seq_len]  # (seq_len, 1, d_model) broadcast
        return self.dropout(self.norm_pe(with_pe))
class Encoder_TRANSFORMER(nn.Module):
    """
    Encoder module using Transformer architecture with KAN layers.

    Key components:
    - KANLayer, which replaces the linear input projection (described in this
      file as a matrix of learnable 1D B-splines);
    - Transformer Encoder processing temporal dependencies.

    Two learnable query tokens (``muQuery`` and ``sigmaQuery``) are prepended
    to the embedded frame sequence; after encoding, their output positions are
    read back as the mean and log-variance of the latent distribution.

    Args:
        modeltype: Accepted for interface compatibility; not used here.
        njoints: Number of joints per frame.
        nfeats: Number of features per joint.
        num_frames: Expected number of frames in every input sequence.
        num_classes: Number of action classes (stored on the module).
        translation, pose_rep, glob, glob_rot: Stored on the module as-is.
        latent_dim: Latent space / Transformer model dimensionality.
        ff_size: Feedforward network size inside each Transformer layer.
        num_layers: Number of Transformer encoder layers.
        num_heads: Number of attention heads.
        dropout: Dropout probability.
        activation: Activation name passed to ``nn.TransformerEncoderLayer``.
        **kargs: Extra keyword arguments are accepted and ignored.
    """

    def __init__(
        self,
        modeltype,
        njoints: int,
        nfeats: int,
        num_frames: int,
        num_classes: int,
        translation,
        pose_rep,
        glob,
        glob_rot,
        latent_dim: int = 256,
        ff_size: int = 1024,
        num_layers: int = 4,
        num_heads: int = 4,
        dropout: float = 0.1,
        activation: str = "gelu",
        **kargs
    ):
        super().__init__()
        self.njoints = njoints
        self.nfeats = nfeats
        self.num_frames = num_frames
        self.num_classes = num_classes
        self.pose_rep = pose_rep
        self.glob = glob
        self.glob_rot = glob_rot
        self.translation = translation
        self.latent_dim = latent_dim  # Latent space dimensionality
        self.ff_size = ff_size  # Feedforward network size
        self.num_layers = num_layers  # Transformer layers
        self.num_heads = num_heads  # Multi-head attention heads
        self.dropout = dropout
        self.activation = activation
        self.input_feats = self.njoints * self.nfeats  # Input feature dimension

        # Learnable query tokens for mu and sigma (variational posterior)
        self.muQuery = nn.Parameter(torch.randn(1, self.latent_dim))
        self.sigmaQuery = nn.Parameter(torch.randn(1, self.latent_dim))

        # KANLayer for skeleton embedding:
        #   Input:  njoints * nfeats (flattened joint features)
        #   Output: latent_dim (compressed representation)
        self.skelEmbedding = KANLayer(self.input_feats, self.latent_dim)

        # Positional encoding for temporal alignment
        self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout)

        # Transformer encoder with multi-head attention
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.latent_dim,
            nhead=self.num_heads,
            dim_feedforward=self.ff_size,
            dropout=self.dropout,
            activation=self.activation,
        )
        self.seqTransEncoder = nn.TransformerEncoder(
            encoder_layer, num_layers=self.num_layers
        )
        self.encoder_norm = nn.LayerNorm(self.latent_dim)  # Final normalization

    def forward(self, batch: dict) -> dict:
        """Encode a motion batch into a latent distribution and a sample.

        Args:
            batch: Dictionary with
                ``batch["x"]``: (batch, njoints, nfeats, nframes) motion tensor;
                ``batch["mask"]``: (batch, nframes) mask, truthy for actual
                frames. Non-boolean masks are converted with ``.bool()``.
                Class labels (``batch["y"]``) are not consulted: the query
                tokens are shared across classes.

        Returns:
            Dictionary with ``"mu"``, ``"logvar"`` and the reparameterized
            sample ``"z"``, each of shape (batch, latent_dim).

        Raises:
            AssertionError: If the frame dimension of ``x`` differs from
                ``num_frames``.
        """
        x, mask = batch["x"], batch["mask"]
        bs, _, _, nframes = x.shape
        assert nframes == self.num_frames, "Frame dimension mismatch"

        # `~mask` below is only a logical NOT for boolean tensors; on integer
        # masks it would be a bitwise NOT and yield an invalid padding mask.
        if mask.dtype != torch.bool:
            mask = mask.bool()

        # Reshape input: (nframes, batch, njoints * nfeats)
        x_seq = x.permute(3, 0, 1, 2).reshape(nframes, bs, self.input_feats)
        x_emb = self.skelEmbedding(x_seq)  # (nframes, batch, latent_dim)

        # Broadcast the two query tokens over the batch: (1, batch, latent_dim)
        mu_token = self.muQuery.expand(bs, -1).unsqueeze(0)
        sigma_token = self.sigmaQuery.expand(bs, -1).unsqueeze(0)

        # Sequence fed to the Transformer: [mu, sigma, frames...]
        xcat = torch.cat((mu_token, sigma_token, x_emb), dim=0)  # (2 + nframes, batch, latent_dim)

        # The two query positions are always valid (never padding).
        token_mask = torch.ones((bs, 2), dtype=torch.bool, device=x.device)
        mask_seq = torch.cat((token_mask, mask), dim=1)  # (batch, 2 + nframes)

        xcat_pe = self.sequence_pos_encoder(xcat)  # (2 + nframes, batch, latent_dim)

        encoded = self.seqTransEncoder(
            xcat_pe,
            src_key_padding_mask=~mask_seq,  # True = padding position
        )  # (2 + nframes, batch, latent_dim)
        encoded = self.encoder_norm(encoded)

        # The first two output positions hold mu and logvar respectively.
        mu = encoded[0]  # (batch, latent_dim)
        logvar = encoded[1]  # (batch, latent_dim)

        # Reparameterization: z = mu + eps * std, with std = exp(logvar / 2)
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = mu + eps * std  # (batch, latent_dim)

        return {"mu": mu, "logvar": logvar, "z": z}
class Decoder_TRANSFORMER(nn.Module):
    """
    Decoder module using Transformer architecture with a KAN layer:
    - Transformer Decoder: all frames are decoded in a single pass from
      positional "time queries" that attend to the latent code (no causal
      mask is applied, so generation is not autoregressive);
    - KANLayer: final projection layer for skeleton reconstruction.

    Args:
        modeltype: Accepted for interface compatibility; not used here.
        njoints: Number of joints per frame.
        nfeats: Number of features per joint.
        num_frames: Stored on the module; the number of generated frames is
            taken from the mask passed to ``forward``.
        num_classes: Number of action classes (stored on the module).
        translation, pose_rep, glob, glob_rot: Stored on the module as-is.
        latent_dim: Latent space / Transformer model dimensionality.
        ff_size: Feedforward network size inside each Transformer layer.
        num_layers: Number of Transformer decoder layers.
        num_heads: Number of attention heads.
        dropout: Dropout probability.
        activation: Activation name passed to ``nn.TransformerDecoderLayer``.
        **kargs: Extra keyword arguments are accepted and ignored.
    """

    def __init__(
        self,
        modeltype,
        njoints: int,
        nfeats: int,
        num_frames: int,
        num_classes: int,
        translation,
        pose_rep,
        glob,
        glob_rot,
        latent_dim: int = 256,
        ff_size: int = 1024,
        num_layers: int = 4,
        num_heads: int = 4,
        dropout: float = 0.1,
        activation: str = "gelu",
        **kargs
    ):
        super().__init__()
        self.njoints = njoints
        self.nfeats = nfeats
        self.num_frames = num_frames
        self.num_classes = num_classes
        self.pose_rep = pose_rep
        self.glob = glob
        self.glob_rot = glob_rot
        self.translation = translation
        self.latent_dim = latent_dim
        self.ff_size = ff_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dropout = dropout
        self.activation = activation
        self.input_feats = self.njoints * self.nfeats

        # Bias parameters for action-specific generation.
        # NOTE: registered on the module but not read by forward().
        self.actionBiases = nn.Parameter(torch.randn(1, self.latent_dim))

        # Positional encoding for temporal queries
        self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout)

        # Transformer decoder
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=self.latent_dim,
            nhead=self.num_heads,
            dim_feedforward=self.ff_size,
            dropout=self.dropout,
            activation=self.activation,
        )
        self.seqTransDecoder = nn.TransformerDecoder(
            decoder_layer, num_layers=self.num_layers
        )
        self.decoder_norm = nn.LayerNorm(self.latent_dim)  # Final normalization

        # Final KANLayer for skeleton reconstruction:
        #   Input:  latent_dim
        #   Output: input_feats (reconstructed joint features)
        self.finallayer = KANLayer(self.latent_dim, self.input_feats)

    def forward(self, batch: dict, use_text_emb: bool = False) -> dict:
        """Decode a latent code (or text embedding) into a motion sequence.

        Args:
            batch: Dictionary containing
                ``batch["mask"]``: (batch, nframes) mask, truthy for actual
                frames (converted to bool if needed);
                ``batch["z"]``: (batch, latent_dim) latent code, required when
                ``use_text_emb`` is False;
                ``batch["clip_text_emb"]``: (batch, latent_dim) text embedding,
                required when ``use_text_emb`` is True.
            use_text_emb: Whether to decode from the text embedding instead of
                the latent code.

        Returns:
            The same ``batch`` dictionary, with the generated tensor of shape
            (batch, njoints, nfeats, nframes) stored under ``"txt_output"``
            when ``use_text_emb`` is True and under ``"output"`` otherwise.
            Padded frames are zeroed.
        """
        mask = batch["mask"]  # (batch, nframes)
        # `~mask` must be a logical NOT, which requires a boolean tensor.
        if mask.dtype != torch.bool:
            mask = mask.bool()
        bs, nframes = mask.shape

        # Read only the conditioning tensor that is actually used, so text
        # generation does not require a "z" entry (and vice versa).
        z = batch["clip_text_emb"] if use_text_emb else batch["z"]

        # Normalize the code; it becomes the length-1 memory of the decoder.
        z = F.layer_norm(z, (self.latent_dim,))  # (batch, latent_dim)
        z = z.unsqueeze(0)  # (1, batch, latent_dim)

        # Time queries carry no content of their own: zeros + positional
        # encodings. Match z's dtype so reduced-precision models do not mix
        # dtypes.
        timequeries = torch.zeros(
            nframes, bs, self.latent_dim, device=z.device, dtype=z.dtype
        )
        timequeries_pe = self.sequence_pos_encoder(timequeries)

        dec_out = self.seqTransDecoder(
            tgt=timequeries_pe,
            memory=z,
            tgt_key_padding_mask=~mask,  # True = padding position
        )  # (nframes, batch, latent_dim)
        dec_out = self.decoder_norm(dec_out)

        # Project back to skeleton features and split into joints.
        skel_feats = self.finallayer(dec_out)  # (nframes, batch, input_feats)
        skel_feats = skel_feats.reshape(nframes, bs, self.njoints, self.nfeats)

        # Zero out padded frames out-of-place, so the final layer's output is
        # not modified in place.
        padding = ~mask.T  # (nframes, batch)
        skel_feats = skel_feats.masked_fill(padding[:, :, None, None], 0.0)

        # Final output format: (batch, njoints, nfeats, nframes)
        output = skel_feats.permute(1, 2, 3, 0).contiguous()

        if use_text_emb:
            batch["txt_output"] = output
        else:
            batch["output"] = output
        return batch