File size: 19,934 Bytes
ceeabec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
import logging
from dataclasses import dataclass, field
from typing import Optional


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

import numpy as np
import random
import os
import sys

from fairseq.data.data_utils import compute_mask_indices
from fairseq.models import BaseFairseqModel, register_model
from fairseq.models.wav2vec import (
    Wav2Vec2Config,  
    TransformerEncoder,
)

# Record which module/file Wav2Vec2Config was resolved from (useful when several
# fairseq checkouts are importable). Logged at DEBUG level rather than printed so
# that importing this module does not write to stdout.
logging.getLogger(__name__).debug(
    "Wav2Vec2Config is imported from: %s (full path: %s)",
    Wav2Vec2Config.__module__,
    sys.modules[Wav2Vec2Config.__module__].__file__,
)

from fairseq.modules import (
    LayerNorm,
)

logger = logging.getLogger(__name__)


@dataclass
class SignHubertConfig(Wav2Vec2Config):
    """Wav2Vec2Config extended with the multi-channel sign-language settings of SignHubertModel."""

    # Overrides the Wav2Vec2Config default for the encoder's convolutional
    # positional embedding.
    conv_pos: int = field(
        default=32,
        metadata={
            "help": "convolutional positional embedding setting for the encoder "
            "(overrides the Wav2Vec2Config default)"
        },
    )
    discrete: bool = field(
        default=False,
        metadata={"help": "stored on SignHubertModel as `discrete`; the model does not branch on it"},
    )
    codebook_size: int = field(
        default=256,
        metadata={"help": "number of discrete target classes predicted per channel"},
    )
    channels_embed_dim: int = field(
        default=384,
        metadata={"help": "input feature dimension of the face, left_hand and right_hand channels"},
    )
    channels_pose_embed_dim: int = field(
        default=14,
        metadata={"help": "input feature dimension of the body_posture channel"},
    )
    intermediate_dim: int = field(
        default=1024,
        metadata={
            "help": "total size of the concatenated per-channel projections fed to "
            "post_extract_proj; must be divisible by the number of channels"
        },
    )
    mask_strategy: str = field(
        default="random",
        metadata={
            "help": "masking strategy: 'random' (individual frame/channel positions), "
            "'channel' (whole channels for all frames) or 'gloss' (time spans across all channels)"
        },
    )
    channels: str = field(
        default="face,left_hand,right_hand,body_posture",
        metadata={
            "help": "comma-separated input channels; any of face, left_hand, right_hand, body_posture"
        },
    )
    

@register_model("signhubert_onlyhands", dataclass=SignHubertConfig)
class SignHubertModel(BaseFairseqModel):
    """HuBERT-style masked prediction over multi-channel sign-language features.

    Each sample supplies one feature sequence per configured channel (``face``,
    ``left_hand``, ``right_hand``, ``body_posture``) and, for training, a
    matching ``label_<channel>`` sequence of discrete targets. Each channel is
    layer-normalised and projected to ``intermediate_dim // num_channels``.
    The per-channel features are then masked (see :meth:`apply_mask`),
    concatenated, projected to ``encoder_embed_dim`` and encoded by a wav2vec
    2.0 ``TransformerEncoder``. One linear head per channel predicts that
    channel's ``codebook_size``-way label. The loss is the per-position
    cross-entropy over masked, non-padded positions.
    """

    # Channel name -> (LayerNorm attribute, projection attribute). These
    # attribute names are the parameter names in the state_dict, so they must
    # not change. Dict order is also the module creation order.
    _CHANNEL_MODULES = {
        "face": ("layer_norm_face", "face_proj"),
        "left_hand": ("layer_norm_lhand", "left_hand_proj"),
        "right_hand": ("layer_norm_rhand", "right_hand_proj"),
        "body_posture": ("layer_norm_body", "body_posture_proj"),
    }
    # Masking strategies understood by apply_mask().
    _MASK_STRATEGIES = ("random", "channel", "gloss")

    def __init__(self, cfg: SignHubertConfig):
        super().__init__()
        self.cfg = cfg
        self.discrete = cfg.discrete  # stored only; no code path in this class branches on it

        self.embed = cfg.encoder_embed_dim  # transformer width
        self.channel_embed = cfg.channels_embed_dim  # input dim of face / hand channels
        self.channel_pose_embed = cfg.channels_pose_embed_dim  # input dim of body_posture
        self.intermediate_dim = cfg.intermediate_dim  # size of the concatenated channel projections

        # Accept lists like "left_hand, right_hand": strip whitespace, drop empty items.
        self.channels = [c.strip() for c in cfg.channels.split(",") if c.strip()]
        if not self.channels:
            raise ValueError("SignHubertConfig.channels must name at least one channel")
        unknown = [c for c in self.channels if c not in self._CHANNEL_MODULES]
        if unknown:
            raise ValueError(
                f"unknown channel(s) {unknown}; expected any of {list(self._CHANNEL_MODULES)}"
            )
        # forward() concatenates len(channels) projections of this size and feeds
        # them to post_extract_proj, which expects exactly intermediate_dim inputs.
        if cfg.intermediate_dim % len(self.channels) != 0:
            raise ValueError(
                f"intermediate_dim ({cfg.intermediate_dim}) must be divisible by the "
                f"number of channels ({len(self.channels)})"
            )
        channel_dim = cfg.intermediate_dim // len(self.channels)

        self.post_extract_proj = nn.Linear(cfg.intermediate_dim, cfg.encoder_embed_dim)

        self.mask_prob = cfg.mask_prob
        self.mask_selection = cfg.mask_selection
        self.mask_strategy = cfg.mask_strategy
        self.mask_other = cfg.mask_other
        self.mask_length = cfg.mask_length
        self.no_mask_overlap = cfg.no_mask_overlap
        self.mask_min_space = cfg.mask_min_space

        self.mask_channel_prob = cfg.mask_channel_prob
        self.mask_channel_before = cfg.mask_channel_before
        self.mask_channel_selection = cfg.mask_channel_selection
        self.mask_channel_other = cfg.mask_channel_other
        self.mask_channel_length = cfg.mask_channel_length
        self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
        self.mask_channel_min_space = cfg.mask_channel_min_space

        self.dropout_input = nn.Dropout(cfg.dropout_input)
        self.dropout_features = nn.Dropout(cfg.dropout_features)

        self.feature_grad_mult = cfg.feature_grad_mult

        # Learned vector that replaces masked per-channel features.
        self.mask_emb = nn.Parameter(torch.FloatTensor(1, 1, 1, channel_dim).uniform_())

        self.encoder = TransformerEncoder(cfg)
        # Not used in forward(); kept so the state_dict layout (and existing
        # checkpoints) stays unchanged.
        self.layer_norm = LayerNorm(self.channel_embed * len(self.channels))

        # Per-channel LayerNorm + projection, created in the fixed order of
        # _CHANNEL_MODULES so parameter initialisation order is unchanged.
        for name, (norm_attr, proj_attr) in self._CHANNEL_MODULES.items():
            if name not in self.channels:
                continue
            in_dim = self.channel_pose_embed if name == "body_posture" else self.channel_embed
            setattr(self, norm_attr, LayerNorm(in_dim))
            setattr(self, proj_attr, nn.Linear(in_dim, channel_dim))

        self.codebook_size = cfg.codebook_size  # number of target classes per channel

        # One prediction head per configured channel, in self.channels order.
        self.heads = nn.ModuleList(
            nn.Linear(cfg.encoder_embed_dim, cfg.codebook_size) for _ in self.channels
        )

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        """Return the module state dict (thin pass-through to ``nn.Module``).

        Arguments are forwarded as keywords because recent PyTorch versions
        deprecate positional arguments to ``state_dict``.
        """
        return super().state_dict(
            destination=destination, prefix=prefix, keep_vars=keep_vars
        )

    @classmethod
    def build_model(cls, cfg: SignHubertConfig, task=None):
        """Build a new model instance."""
        return cls(cfg)

    def apply_mask(
        self,
        x,
        padding_mask,
        mask_indices=None,
        mask_channel_indices=None,
    ):
        """Replace selected (frame, channel) positions of ``x`` with ``mask_emb``.

        Args:
            x: tensor unpacked as ``(B, T, C, D)``: batch, frames, channels and
                per-channel feature dim.
            padding_mask: optional ``(B, T)`` mask, True at padded frames
                (fairseq convention). It is passed to ``compute_mask_indices``
                so that masked spans stay within each sample's real length.
            mask_indices: optional precomputed boolean mask. Its shape depends on
                the strategy: ``(B, T, C)`` for "channel", ``(B, T)`` for
                "gloss", and ``(B, T * C)`` or ``(B, T, C)`` for "random".
            mask_channel_indices: unused; kept for interface compatibility.

        Returns:
            ``(x, mask)``. ``mask`` has the shape of ``x``, with 1 at kept entries
            and 0 at entries replaced by ``mask_emb``.

        Raises:
            ValueError: if ``self.mask_strategy`` is not a known strategy.
        """
        B, T, C, D = x.shape
        if self.mask_strategy not in self._MASK_STRATEGIES:
            raise ValueError(f"unknown mask strategy {self.mask_strategy}")

        mask = torch.ones_like(x)
        if self.mask_prob <= 0:
            # Nothing to mask. This case used to fall through to the
            # "unknown mask strategy" error.
            return x, mask

        if self.mask_strategy == "channel":
            if mask_indices is None:
                # Mask whole channels across every frame, chosen independently per sample.
                num_masked = min(C, max(1, int(C * self.mask_prob)))
                mask_indices = torch.zeros((B, T, C), dtype=torch.bool, device=x.device)
                for b in range(B):
                    chosen = np.random.choice(C, num_masked, replace=False)
                    mask_indices[b, :, chosen] = True
            masked = torch.as_tensor(mask_indices, dtype=torch.bool, device=x.device)

        elif self.mask_strategy == "gloss":
            if mask_indices is None:
                mask_indices = compute_mask_indices(
                    (B, T),
                    padding_mask,
                    self.mask_prob,
                    self.mask_length,
                    self.mask_selection,
                    self.mask_other,
                    min_masks=1,
                    no_overlap=self.no_mask_overlap,
                    min_space=self.mask_min_space,
                    require_same_masks=self.cfg.require_same_masks,
                    mask_dropout=self.cfg.mask_dropout,
                )
            time_mask = torch.as_tensor(mask_indices, dtype=torch.bool, device=x.device)
            # The same time spans are masked in every channel.
            masked = time_mask.unsqueeze(-1).expand(-1, -1, C)

        else:  # "random": spans over the flattened (frame, channel) sequence
            if mask_indices is None:
                # Flattening is frame-major (index = t * C + c), so a padded frame
                # covers C consecutive flattened positions.
                flat_padding = None
                if padding_mask is not None:
                    flat_padding = (
                        padding_mask.unsqueeze(-1).expand(-1, -1, C).reshape(B, T * C)
                    )
                mask_indices = compute_mask_indices(
                    (B, T * C),
                    flat_padding,
                    self.mask_prob,
                    self.mask_length,
                    self.mask_selection,
                    self.mask_other,
                    min_masks=1,
                    no_overlap=self.no_mask_overlap,
                    min_space=self.mask_min_space,
                    require_same_masks=self.cfg.require_same_masks,
                    mask_dropout=self.cfg.mask_dropout,
                )
            masked = torch.as_tensor(
                mask_indices, dtype=torch.bool, device=x.device
            ).reshape(B, T, C)

        mask = mask.masked_fill(masked.unsqueeze(-1), 0)
        x = x * mask + self.mask_emb * (1 - mask)
        return x, mask

    def _check_channels(self, source):
        """Raise ValueError unless ``source`` is non-empty and provides every configured channel."""
        if not source:
            raise ValueError("SignHubertModel.forward expects a non-empty list of samples")
        missing = [c for c in self.channels if c not in source[0]]
        if missing:
            raise ValueError(f"source samples are missing configured channel(s): {missing}")

    def _embed_channel(self, source, channel):
        """Stack one channel across the batch, then layer-normalise and project it."""
        norm_attr, proj_attr = self._CHANNEL_MODULES[channel]
        feats = torch.stack([sample[channel] for sample in source])
        return getattr(self, proj_attr)(getattr(self, norm_attr)(feats))

    def _masked_prediction_loss(self, predictions, labels, keep_mask, padding_mask):
        """Per-position cross-entropy over the prediction targets.

        Targets are the (frame, channel) positions whose features were replaced
        by ``mask_emb`` (``keep_mask`` is zero across the feature dim). When no
        mask was applied (``keep_mask`` is None), every position is a target.
        Padded frames are never targets.
        """
        if keep_mask is None:
            targets = torch.ones(
                predictions.shape[:3], dtype=torch.bool, device=predictions.device
            )
        else:
            targets = (keep_mask == 0).all(dim=-1)  # (B, T, C)
        if padding_mask is not None:
            targets = targets & ~padding_mask.bool().unsqueeze(-1)
        targets = targets.reshape(-1)

        logits = predictions.reshape(-1, self.codebook_size)  # (B * T * C, codebook_size)
        flat_labels = labels.reshape(-1).long()  # (B * T * C,)
        return F.cross_entropy(logits[targets], flat_labels[targets], reduction="none")

    def forward(
        self,
        source,
        padding_mask=None,
        mask=True,
        features_only=False,
        layer=None,
        mask_indices=None,
        mask_channel_indices=None,
        padding_count=None,
        kmeans_labels=None,
    ):
        """Encode a batch and, unless ``features_only``, compute the masked-prediction loss.

        Args:
            source: non-empty list of per-sample dicts. Each holds one tensor per
                configured channel (e.g. ``"left_hand"``). When the loss is
                computed, each also holds a matching ``"label_<channel>"`` tensor
                of class ids.
            padding_mask: optional ``(B, T)`` mask, True at padded frames.
            mask: whether to apply :meth:`apply_mask` to the channel features.
            features_only: return encoder outputs instead of the loss.
            layer: forwarded to the transformer encoder.
            mask_indices, mask_channel_indices: forwarded to :meth:`apply_mask`.
            padding_count, kmeans_labels: unused; accepted for interface
                compatibility (targets are read from ``source``).

        Returns:
            With ``features_only``: ``{"x", "padding_mask", "layer_results"}``.
            Otherwise: ``{"losses": {"kmeans_loss": per-target cross-entropy},
            "sample_size": number of targets}``.

        Raises:
            ValueError: if ``source`` is empty or lacks a configured channel.
        """
        self._check_channels(source)

        features_list = [self._embed_channel(source, c) for c in self.channels]
        features = torch.stack(features_list, dim=2)  # (B, T, C, intermediate_dim // C)

        if mask:
            x, keep_mask = self.apply_mask(
                features,
                padding_mask,
                mask_indices=mask_indices,
                mask_channel_indices=mask_channel_indices,
            )
        else:
            x, keep_mask = features, None

        x = self.dropout_input(x)
        x = x.reshape(x.size(0), x.size(1), -1)  # (B, T, intermediate_dim)
        if self.post_extract_proj is not None:
            x = self.post_extract_proj(x)  # (B, T, encoder_embed_dim)

        x, layer_results = self.encoder(
            x,
            padding_mask=padding_mask,
            layer=layer,
        )

        if features_only:
            return {
                "x": x,
                "padding_mask": padding_mask,
                "layer_results": layer_results,
            }

        # Labels are only needed for the loss, so feature extraction works without them.
        labels = torch.stack(
            [torch.stack([s["label_" + c] for s in source]) for c in self.channels],
            dim=2,
        )
        # One head per channel -> (B, T, C, codebook_size).
        predictions = torch.stack([head(x) for head in self.heads], dim=2)

        loss = self._masked_prediction_loss(predictions, labels, keep_mask, padding_mask)
        return {
            "losses": {"kmeans_loss": loss},
            "sample_size": loss.numel(),
        }

    @staticmethod
    def compute_var(y):
        """Return the mean per-dimension standard deviation of ``y``.

        ``y`` is flattened to ``(N, y.size(-1))``. When torch.distributed is
        initialised, the count, sum and sum of squares are all-reduced, so the
        unbiased variance covers all workers. ``1e-6`` is added before the
        square root for numerical stability.
        """
        y = y.reshape(-1, y.size(-1))
        if dist.is_initialized():
            # Keep the count on y's device; a hard-coded .cuda() broke CPU (gloo) runs.
            zc = torch.tensor(y.size(0), device=y.device)
            zs = y.sum(dim=0)
            zss = (y ** 2).sum(dim=0)

            dist.all_reduce(zc)
            dist.all_reduce(zs)
            dist.all_reduce(zss)

            var = zss / (zc - 1) - (zs ** 2) / (zc * (zc - 1))
            return torch.sqrt(var + 1e-6).mean()
        return torch.sqrt(y.var(dim=0) + 1e-6).mean()

    def extract_features(
        self, source, padding_mask, kmeans_labels, mask=False, layer=None
    ):
        """Run the encoder only (``features_only=True``) and return its outputs dict."""
        return self.forward(
            source,
            padding_mask,
            mask=mask,
            features_only=True,
            layer=layer,
            kmeans_labels=kmeans_labels,
        )

    def remove_pretraining_modules(self, last_layer=None):
        """Drop the pre-training heads and optionally keep only encoder layers ``0..last_layer``."""
        self.heads = None
        self.final_proj = None
        if last_layer is not None:
            self.encoder.layers = nn.ModuleList(
                l for i, l in enumerate(self.encoder.layers) if i <= last_layer
            )