import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from .kan import KANLayer


class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        """

        Standard positional encoding with Sin/Cos functions + LayerNorm to preserve 

        temporal relationships between frames throughtout sequence-modeling.

        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.d_model = d_model

        # Precompute positional encodings (PE) using sinusoidal functions
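        #   PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
        #   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))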
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(1)  # (max_len, 1, d_model)
        self.register_buffer("pe", pe)
        self.norm_pe = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """

        Args:

            x: Input tensor of shape (seq_len, batch_size, d_model)

        Returns:

            Tensor with positional encodings added and normalized

        """
        seq_len = x.size(0)
        x2 = x + self.pe[:seq_len, :]  # Add positional encodings
        x2 = self.norm_pe(x2)          # Normalize
        return self.dropout(x2)
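

# Usage sketch: PositionalEncoding expects sequence-first input of shape
# (seq_len, batch, d_model); the shapes below are illustrative only:
#
#     pos_enc = PositionalEncoding(d_model=256)
#     out = pos_enc(torch.zeros(60, 8, 256))  # -> (60, 8, 256)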


class Encoder_TRANSFORMER(nn.Module):
    """

    Encoder module using Transformer architecture with KAN layers.

    Key components:

    - KANLayer which eplaces linear projections with learnable 1D splines;

    - Transformer Encoder processing temporal dependencies.

        """
    def __init__(
        self,
        modeltype,
        njoints: int,
        nfeats: int,
        num_frames: int,
        num_classes: int,
        translation,
        pose_rep,
        glob,
        glob_rot,
        latent_dim: int = 256,
        ff_size: int = 1024,
        num_layers: int = 4,
        num_heads: int = 4,
        dropout: float = 0.1,
        activation: str = "gelu",
        **kargs
    ):
        super().__init__()
        self.njoints = njoints
        self.nfeats = nfeats
        self.num_frames = num_frames
        self.num_classes = num_classes
        self.pose_rep = pose_rep
        self.glob = glob
        self.glob_rot = glob_rot
        self.translation = translation

        self.latent_dim = latent_dim                            # Latent space dimensionality
        self.ff_size = ff_size                                  # Feedforward network size
        self.num_layers = num_layers                            # Transformer layers
        self.num_heads = num_heads                              # Multi-head attention heads
        self.dropout = dropout
        self.activation = activation

        self.input_feats = self.njoints * self.nfeats           # Input feature dimension

        # Learnable parameters for μ and σ (variational posterior)
        self.muQuery = nn.Parameter(torch.randn(1, self.latent_dim))
        self.sigmaQuery = nn.Parameter(torch.randn(1, self.latent_dim))

        # KANLayer for skeleton embedding:
        # Input: njoints * nfeats (flattened joint features)
        # Output: latent_dim (compressed representation)
        # KANLayer replaces linear projections with a matrix of 1D B-splines
        self.skelEmbedding = KANLayer(self.input_feats, self.latent_dim)

        # Positional Encoding for temporal alignment
        self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout)

        # Transformer Encoder with multi-head attention
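        # (nn.TransformerEncoderLayer defaults to batch_first=False, matching
        # the sequence-first (seq, batch, dim) layout used in this module)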
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.latent_dim,
            nhead=self.num_heads,
            dim_feedforward=self.ff_size,
            dropout=self.dropout,
            activation=self.activation
        )
        self.seqTransEncoder = nn.TransformerEncoder(encoder_layer, num_layers=self.num_layers)
        self.encoder_norm = nn.LayerNorm(self.latent_dim)  # Final normalization

    def forward(self, batch: dict) -> dict:
        """

        batch["x"]: (batch, njoints, nfeats, nframes)

        batch["y"]: (batch,)  — classes (if none, then == 0)

        batch["mask"]: (batch, nframes) — bool-mask of actual frames

        """
        x, y, mask = batch["x"], batch["y"], batch["mask"]
        bs, njoints, nfeats, nframes = x.shape
        assert nframes == self.num_frames, "Frame dimension mismatch"

        # Reshape input: (nframes, batch, njoints*nfeats)
        x_seq = x.permute(3, 0, 1, 2).reshape(self.num_frames, bs, self.input_feats)

        # Applies learnable 1D splines to input features
        x_emb = self.skelEmbedding(x_seq)  # (nframes, batch, latent_dim)

        # Handle class labels (y)
        if y is None:
            y = torch.zeros(bs, dtype=torch.long, device=x.device)
        else:
            y = y.clamp(0, self.num_classes - 1)

        # Initialize mu and sigma queries:
        mu_init = self.muQuery.expand(bs, -1)       # (batch, latent_dim)
        sigma_init = self.sigmaQuery.expand(bs, -1) # (batch, latent_dim)

        # Concatenate [mu, sigma, x_emb] for Transformer input
        mu_init = mu_init.unsqueeze(0)       # (1, batch, latent_dim)
        sigma_init = sigma_init.unsqueeze(0) # (1, batch, latent_dim)
        xcat = torch.cat((mu_init, sigma_init, x_emb), dim=0)  # (2 + nframes, batch, latent_dim)

        # Update mask for mu/sigma (ensure the mask is boolean first)
        if mask.dtype != torch.bool:
            mask = mask.bool()
        mu_sigma_mask = torch.ones((bs, 2), dtype=torch.bool, device=x.device)
        mask_seq = torch.cat((mu_sigma_mask, mask), dim=1)  # (batch, 2 + nframes)

        # Add positional encodings
        xcat_pe = self.sequence_pos_encoder(xcat)  # (2 + nframes, batch, latent_dim)

        # Transformer Encoder
        encoded = self.seqTransEncoder(
            xcat_pe,
            src_key_padding_mask=~mask_seq  # True = mask padding
        )  # (2 + nframes, batch, latent_dim)

        # Final normalization
        encoded = self.encoder_norm(encoded)

        # Extract mu and logvar from the first two encoded tokens
        mu = encoded[0]      # (batch, latent_dim)
        logvar = encoded[1]  # (batch, latent_dim)

        # Reparameterization
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = mu + eps * std  # (batch, latent_dim)

        return {"mu": mu, "logvar": logvar, "z": z}


class Decoder_TRANSFORMER(nn.Module):
    """

    Decoder module using Transformer architecture with KAN-layer:

    - KANLayer: Final projection layer for skeleton reconstruction

    - Transformer Decoder: Autoregressive generation of sequences

    """
    def __init__(
        self,
        modeltype,
        njoints: int,
        nfeats: int,
        num_frames: int,
        num_classes: int,
        translation,
        pose_rep,
        glob,
        glob_rot,
        latent_dim: int = 256,
        ff_size: int = 1024,
        num_layers: int = 4,
        num_heads: int = 4,
        dropout: float = 0.1,
        activation: str = "gelu",
        **kargs
    ):
        super().__init__()

        self.njoints = njoints
        self.nfeats = nfeats
        self.num_frames = num_frames
        self.num_classes = num_classes
        self.pose_rep = pose_rep
        self.glob = glob
        self.glob_rot = glob_rot
        self.translation = translation

        self.latent_dim = latent_dim
        self.ff_size = ff_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dropout = dropout
        self.activation = activation

        self.input_feats = self.njoints * self.nfeats

        # Bias parameters for action-specific generation
        # (declared here; not applied in this forward pass)
        self.actionBiases = nn.Parameter(torch.randn(1, self.latent_dim))

        # Positional Encoding for temporal queries
        self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout)

        # Transformer Decoder
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=self.latent_dim,
            nhead=self.num_heads,
            dim_feedforward=self.ff_size,
            dropout=self.dropout,
            activation=self.activation
        )
        self.seqTransDecoder = nn.TransformerDecoder(decoder_layer, num_layers=self.num_layers)
        self.decoder_norm = nn.LayerNorm(self.latent_dim)  # Final normalization

        # Final KANLayer for skeleton reconstruction:
        # Input: latent_dim
        # Output: input_feats (reconstructed joint features)
        self.finallayer = KANLayer(self.latent_dim, self.input_feats)

    def forward(self, batch: dict, use_text_emb: bool = False) -> dict:
        """

        Forward pass for the decoder.

        Args:

            batch: Dictionary containing latent codes and metadata

            use_text_emb: Whether to use text embeddings instead of latent codes

        Returns:

            Dictionary with generated output

        """
        z = batch["z"]  # Latent code: (batch, latent_dim)
        y = batch["y"]
        mask = batch["mask"]  # (batch, nframes)
        lengths = batch.get("lengths", None)
        bs, nframes = mask.shape
        nj, nf = self.njoints, self.nfeats

        # Use text embeddings if specified
        if use_text_emb:
            z = batch["clip_text_emb"]  # (batch, latent_dim)

        # Normalize latent code
        z = F.layer_norm(z, (self.latent_dim,))  # (batch, latent_dim)
        z = z.unsqueeze(0)  # (1, batch, latent_dim) — memory for decoder

        # Generate time queries: (nframes, batch, latent_dim)
        timequeries = torch.zeros(nframes, bs, self.latent_dim, device=z.device)

        # Add positional encodings
        timequeries_pe = self.sequence_pos_encoder(timequeries)

        # Ensure mask is boolean
        if mask.dtype != torch.bool:
            mask = mask.bool()

        # Transformer Decoder
        dec_out = self.seqTransDecoder(
            tgt=timequeries_pe,
            memory=z,
            tgt_key_padding_mask=~mask  # True = mask padding
        )  # (nframes, batch, latent_dim)

        # Final normalization of the decoder output
        dec_out = self.decoder_norm(dec_out)  # (nframes, batch, latent_dim)

        # Project decoder output back to skeletal features via the KANLayer
        skel_feats = self.finallayer(dec_out)  # (nframes, batch, input_feats)
        skel_feats = skel_feats.view(nframes, bs, nj, nf)  # (nframes, batch, njoints, nfeats)

        # Apply mask to zero out padding
        mask_t = mask.T  # (nframes, batch)
        skel_feats[~mask_t] = 0.0

        # Final output format: (batch, njoints, nfeats, nframes)
        output = skel_feats.permute(1, 2, 3, 0).contiguous()

        if use_text_emb:
            batch["txt_output"] = output
        else:
            batch["output"] = output

        return batch
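

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the model. It assumes that
    # KANLayer(in_features, out_features) maps the last dimension and
    # preserves all leading dimensions, which matches how it is used above.
    # The modeltype/pose_rep values below are placeholders for illustration.
    bs, njoints, nfeats, nframes = 4, 25, 6, 60
    common = dict(
        modeltype="kan_transformer", njoints=njoints, nfeats=nfeats,
        num_frames=nframes, num_classes=12, translation=True,
        pose_rep="rot6d", glob=True, glob_rot=None,
    )
    encoder = Encoder_TRANSFORMER(**common)
    decoder = Decoder_TRANSFORMER(**common)

    batch = {
        "x": torch.randn(bs, njoints, nfeats, nframes),
        "y": torch.randint(0, 12, (bs,)),
        "mask": torch.ones(bs, nframes, dtype=torch.bool),
    }
    batch.update(encoder(batch))   # adds "mu", "logvar", "z"
    batch = decoder(batch)         # adds "output"
    print(batch["z"].shape)        # expected: torch.Size([4, 256])
    print(batch["output"].shape)   # expected: torch.Size([4, 25, 6, 60])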