import logging
import sys
from dataclasses import dataclass, field

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from fairseq.data.data_utils import compute_mask_indices
from fairseq.models import BaseFairseqModel, register_model
from fairseq.models.wav2vec import (
Wav2Vec2Config,
TransformerEncoder,
)
# Debug print to show where Wav2Vec2Config is defined
print(f"Wav2Vec2Config is imported from: {Wav2Vec2Config.__module__}")
print(f"Full path: {sys.modules[Wav2Vec2Config.__module__].__file__}")
from fairseq.modules import (
LayerNorm,
)
logger = logging.getLogger(__name__)
@dataclass
class SignHubertConfig(Wav2Vec2Config):
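    """SignHubert configuration: extends Wav2Vec2Config with per-channel embedding
    sizes, the k-means codebook size used as prediction targets, and the masking
    strategy applied over the time/channel grid."""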
# pos_conv_kernel: int = field(default=32)
conv_pos: int = field(default=32)
discrete: bool = field(default=False)
codebook_size: int = field(default=256)
channels_embed_dim: int = field(default=384)
channels_pose_embed_dim: int = field(default=14)
intermediate_dim: int = field(default=1024) # This will be overridden if needed
mask_strategy: str = field(default="random")
channels: str = field(default="face,left_hand,right_hand,body_posture")
@register_model("signhubert_onlyhands", dataclass=SignHubertConfig)
class SignHubertModel(BaseFairseqModel):
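    """HuBERT-style masked-prediction model over multi-channel sign-language features.

    Each configured channel (face, left_hand, right_hand, body_posture) is
    layer-normalized, projected to intermediate_dim // num_channels, and the
    projections are concatenated into a BxTxCxD grid. Part of that grid is masked,
    the flattened features are passed through a wav2vec 2.0 TransformerEncoder, and
    per-channel linear heads predict the k-means codebook label of each masked
    position.
    """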
def __init__(self, cfg: SignHubertConfig):
super().__init__()
self.cfg = cfg
        self.discrete = cfg.discrete  # HuBERT-style training always uses discrete targets
        self.embed = cfg.encoder_embed_dim  # encoder width: small (384), base (768), large, etc.
        self.channel_embed = cfg.channels_embed_dim  # embedding dim for face, left_hand and right_hand (default: 384)
        self.channel_pose_embed = cfg.channels_pose_embed_dim  # embedding dim for body posture (default: 14)
        self.intermediate_dim = cfg.intermediate_dim  # width of the concatenated channel features before projection to encoder_embed_dim (default: 1024)
        self.channels = cfg.channels.split(",")
        # projects the concatenated channel features (intermediate_dim) down to encoder_embed_dim
        self.post_extract_proj = nn.Linear(cfg.intermediate_dim, cfg.encoder_embed_dim)
self.mask_prob = cfg.mask_prob
self.mask_selection = cfg.mask_selection
self.mask_strategy = cfg.mask_strategy
self.mask_other = cfg.mask_other
self.mask_length = cfg.mask_length
self.no_mask_overlap = cfg.no_mask_overlap
self.mask_min_space = cfg.mask_min_space
self.mask_channel_prob = cfg.mask_channel_prob
self.mask_channel_before = cfg.mask_channel_before
self.mask_channel_selection = cfg.mask_channel_selection
self.mask_channel_other = cfg.mask_channel_other
self.mask_channel_length = cfg.mask_channel_length
self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
self.mask_channel_min_space = cfg.mask_channel_min_space
self.dropout_input = nn.Dropout(cfg.dropout_input)
self.dropout_features = nn.Dropout(cfg.dropout_features)
self.feature_grad_mult = cfg.feature_grad_mult
self.mask_emb = nn.Parameter(
torch.FloatTensor(1, 1, 1, cfg.intermediate_dim // len(self.channels)).uniform_()
)
self.encoder = TransformerEncoder(cfg)
self.layer_norm = LayerNorm(self.channel_embed * len(self.channels))
if "face" in self.channels:
self.layer_norm_face = LayerNorm(self.channel_embed)
self.face_proj = nn.Linear(self.channel_embed, cfg.intermediate_dim // len(self.channels))
if "left_hand" in self.channels:
self.layer_norm_lhand = LayerNorm(self.channel_embed)
self.left_hand_proj = nn.Linear(self.channel_embed, cfg.intermediate_dim // len(self.channels))
if "right_hand" in self.channels:
self.layer_norm_rhand = LayerNorm(self.channel_embed)
self.right_hand_proj = nn.Linear(self.channel_embed, cfg.intermediate_dim // len(self.channels))
if "body_posture" in self.channels:
self.layer_norm_body = LayerNorm(self.channel_pose_embed)
self.body_posture_proj = nn.Linear(self.channel_pose_embed, cfg.intermediate_dim // len(self.channels))
self.codebook_size = cfg.codebook_size # number of codebook vectors
        # one linear prediction head per channel, each scoring the codebook entries
        self.heads = nn.ModuleList(
            nn.Linear(cfg.encoder_embed_dim, cfg.codebook_size)
            for _ in range(len(self.channels))
        )
    def state_dict(self, destination=None, prefix="", keep_vars=False):
        # pass arguments by keyword for compatibility with newer torch.nn.Module signatures
        return super().state_dict(
            destination=destination, prefix=prefix, keep_vars=keep_vars
        )
@classmethod
def build_model(cls, cfg: SignHubertConfig, task=None):
"""Build a new model instance."""
return cls(cfg)
def apply_mask(
self,
x,
padding_mask,
mask_indices=None,
mask_channel_indices=None,
):
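        """Mask part of the BxTxCxD feature grid according to cfg.mask_strategy.

        - "channel": mask a fixed fraction of whole channels per sample, across all time steps.
        - "gloss":   mask time spans (computed once per sample) identically in every channel.
        - "random":  mask spans computed over the flattened time x channel axis, so masks
                     need not align across channels.

        Masked positions are replaced by the learned mask_emb. Returns the masked
        features and a keep-mask of the same shape, where 1 marks kept values and
        0 marks masked values.
        """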
B, T, C, D = x.shape
# Initialize a mask vector with ones (same shape as x)
mask = torch.ones_like(x)
# channel masking
if self.mask_prob > 0 and self.mask_strategy == "channel":
if mask_indices is None:
mask_indices = torch.zeros_like(x[:,:,:,0], dtype=bool)
num_channels_to_mask = int(C * self.mask_prob)
num_channels_to_mask = max(1, num_channels_to_mask)
for i in range(B):
channels_to_mask = np.random.choice(C, num_channels_to_mask, replace=False)
mask_indices[i, :, channels_to_mask] = True
mask[mask_indices.unsqueeze(-1).expand(-1, -1, -1, D)] = 0
# gloss/time masking
elif self.mask_prob > 0 and self.mask_strategy == "gloss":
if mask_indices is None:
mask_indices_channel = compute_mask_indices(
(B, T),
padding_mask,
self.mask_prob,
self.mask_length,
self.mask_selection,
self.mask_other,
min_masks=1,
                    no_overlap=self.no_mask_overlap,
min_space=self.mask_min_space,
require_same_masks=self.cfg.require_same_masks,
mask_dropout=self.cfg.mask_dropout,
)
mask_indices_channel = torch.from_numpy(mask_indices_channel).to(x.device)
# Apply the same mask to all channels
mask_indices = mask_indices_channel.unsqueeze(2).expand(-1, -1, C)
mask_indices = mask_indices.unsqueeze(3).expand(-1, -1, -1, D)
mask[mask_indices] = 0
# random masking
elif self.mask_prob > 0 and self.mask_strategy == "random":
if mask_indices is None:
mask_indices = compute_mask_indices(
(B, T*C), # Note: T*C instead of T
padding_mask,
self.mask_prob,
self.mask_length,
self.mask_selection,
self.mask_other,
min_masks=1,
                    no_overlap=self.no_mask_overlap,
min_space=self.mask_min_space,
require_same_masks=self.cfg.require_same_masks,
mask_dropout=self.cfg.mask_dropout,
)
mask_indices = torch.from_numpy(mask_indices).to(x.device)
mask_indices = mask_indices.view(B, T, C)
mask_indices = mask_indices.unsqueeze(3).expand(-1, -1, -1, D)
mask[mask_indices] = 0
else:
raise ValueError(f"unknown mask strategy {self.mask_strategy}")
# Apply the mask to x and return the masked tensor with the same shape as x
# x = x * mask
x = x * mask + self.mask_emb * (1 - mask)
return x, mask
def forward(
self,
source,
padding_mask=None,
mask=True,
features_only=False,
layer=None,
mask_indices=None,
mask_channel_indices=None,
padding_count=None,
kmeans_labels=None,
):
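        """Run a forward pass.

        `source` is a list of per-sample dicts holding the channel features
        ("face", "left_hand", "right_hand" as Tx384, "body_posture" as Tx14) and the
        corresponding k-means targets ("label_face", ...). With features_only=True the
        encoder output is returned; otherwise the masked-position cross-entropy loss
        over the per-channel codebook predictions is returned under
        result["losses"]["kmeans_loss"].
        """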
        # keep only the configured channels that are actually present in the batch
        channels_to_use = [c for c in self.channels if c in source[0]]
for c in channels_to_use:
if c == "face":
face_features_list = []
label_face_features_list = []
elif c == "left_hand":
left_hand_features_list = []
label_left_hand_features_list = []
elif c == "right_hand":
right_hand_features_list = []
label_right_hand_features_list = []
elif c == "body_posture":
body_posture_features_list = []
label_body_posture_features_list = []
for sample in source:
for c in channels_to_use:
if c == "face":
face_features_list.append(sample["face"]) # Tx384
label_face_features_list.append(sample["label_face"]) # Tx1
elif c == "left_hand":
left_hand_features_list.append(sample["left_hand"]) # Tx384
label_left_hand_features_list.append(sample["label_left_hand"]) # Tx1
elif c == "right_hand":
right_hand_features_list.append(sample["right_hand"]) # Tx384
label_right_hand_features_list.append(sample["label_right_hand"]) # Tx1
elif c == "body_posture":
body_posture_features_list.append(sample["body_posture"]) # Tx14
label_body_posture_features_list.append(sample["label_body_posture"]) # Tx1
features_list = []
labels_list = []
for c in channels_to_use:
if c == "face":
face_features = torch.stack(face_features_list) # BxTx384
face_labels = torch.stack(label_face_features_list) # BxTx1
face_features = self.layer_norm_face(face_features) # BxTx384
face_features = self.face_proj(face_features) # BxTx256
features_list.append(face_features)
labels_list.append(face_labels)
elif c == "left_hand":
left_hand_features = torch.stack(left_hand_features_list) # BxTx384
left_hand_labels = torch.stack(label_left_hand_features_list) # BxTx1
left_hand_features = self.layer_norm_lhand(left_hand_features) # BxTx384
left_hand_features = self.left_hand_proj(left_hand_features) # BxTx256
features_list.append(left_hand_features)
labels_list.append(left_hand_labels)
elif c == "right_hand":
right_hand_features = torch.stack(right_hand_features_list) # BxTx384
right_hand_labels = torch.stack(label_right_hand_features_list) # BxTx1
right_hand_features = self.layer_norm_rhand(right_hand_features) # BxTx384
right_hand_features = self.right_hand_proj(right_hand_features) # BxTx256
features_list.append(right_hand_features)
labels_list.append(right_hand_labels)
elif c == "body_posture":
body_posture_features = torch.stack(body_posture_features_list) # BxTx14
body_posture_labels = torch.stack(label_body_posture_features_list) # BxTx1
body_posture_features = self.layer_norm_body(body_posture_features) # BxTx14
body_posture_features = self.body_posture_proj(body_posture_features) # BxTx256
features_list.append(body_posture_features)
labels_list.append(body_posture_labels)
        # concatenate the projected features into a BxTxCxD tensor, where C = len(channels) and D = intermediate_dim // C
features = torch.stack(features_list, dim=2) # BxTx4x256
if mask:
x, mask_indices = self.apply_mask(
features,
padding_mask,
mask_indices=mask_indices,
mask_channel_indices=mask_channel_indices,
)
# mask_indices is a tensor of shape BxTx4x256 where 0 means the value is masked and 1 means the value is not masked
else:
x = features
mask_indices = None
x = self.dropout_input(x) # BxTx4x256
x = x.view(x.size(0), x.size(1), -1) # BxTx1024
if self.post_extract_proj is not None:
x = self.post_extract_proj(x) # BxTx768
x, layer_results = self.encoder(
x,
padding_mask=padding_mask,
layer=layer,
)
if features_only:
return {
"x": x,
"padding_mask": padding_mask,
"layer_results": layer_results,
}
result = {
"losses": {},
}
# use linear heads to compute the discrete prediction for each channel and make it into a single tensor of shape BxTxCxcodebook_size
predictions = []
for i, head in enumerate(self.heads):
channel_pred = head(x) # BxTxcodebook_size
predictions.append(channel_pred)
predictions = torch.stack(predictions, dim=2) # BxTx4xcodebook_size
labels = torch.stack(labels_list, dim=2) # BxTx4x1
# print(f"predictions shape: {predictions.shape} and labels shape: {labels.shape}")
predictions_flat = predictions.view(-1, self.codebook_size) # Shape: (B * T * C, codebook_size)
labels_flat = labels.view(-1) # Shape: (B * T * C)
# Ensure labels are of correct shape
labels_flat = labels_flat.squeeze(-1) # Remove the last dimension if it's size 1
        # mask_indices is the keep-mask returned by apply_mask (1 = kept, 0 = masked),
        # so collapse the feature dimension and flatten it to index the flattened predictions
        mask_indices_reduced = mask_indices.any(dim=-1)  # (B, T, C): True where the position was kept
        mask_indices_flat = mask_indices_reduced.view(-1)  # (B * T * C)
        # compute the cross-entropy loss only over the masked positions (keep-mask == 0)
masked_loss = F.cross_entropy(
predictions_flat[mask_indices_flat == 0],
labels_flat[mask_indices_flat == 0],
reduction='none'
)
# Store the result
result['losses']['kmeans_loss'] = masked_loss
if "sample_size" not in result:
result['sample_size'] = masked_loss.numel()
return result
@staticmethod
def compute_var(y):
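        """Mean per-dimension standard deviation of y; sums are all-reduced across
        workers when torch.distributed is initialized."""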
y = y.view(-1, y.size(-1))
if dist.is_initialized():
zc = torch.tensor(y.size(0)).cuda()
zs = y.sum(dim=0)
zss = (y ** 2).sum(dim=0)
dist.all_reduce(zc)
dist.all_reduce(zs)
dist.all_reduce(zss)
var = zss / (zc - 1) - (zs ** 2) / (zc * (zc - 1))
return torch.sqrt(var + 1e-6).mean()
else:
return torch.sqrt(y.var(dim=0) + 1e-6).mean()
def extract_features(
self, source, padding_mask, kmeans_labels, mask=False, layer=None
):
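        """Convenience wrapper around forward() that returns encoder features only."""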
res = self.forward(
source,
padding_mask,
mask=mask,
features_only=True,
layer=layer,
kmeans_labels=kmeans_labels,
)
return res
def remove_pretraining_modules(self, last_layer=None):
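        """Drop the pre-training prediction heads and, if last_layer is given,
        truncate the Transformer encoder to layers 0..last_layer for fine-tuning."""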
self.heads = None
self.final_proj = None
if last_layer is not None:
self.encoder.layers = nn.ModuleList(
l for i, l in enumerate(self.encoder.layers) if i <= last_layer
)
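

if __name__ == "__main__":
    # Minimal smoke-test sketch (an illustrative addition, not part of the training
    # pipeline): builds the model with the default SignHubertConfig and runs one
    # forward pass on random dummy features. Shapes follow the defaults declared
    # above (384-dim face/hand features, 14-dim body posture, codebook of 256).
    cfg = SignHubertConfig()
    model = SignHubertModel.build_model(cfg)

    batch_size, num_frames = 2, 50
    source = []
    for _ in range(batch_size):
        sample = {}
        for c in cfg.channels.split(","):
            dim = (
                cfg.channels_pose_embed_dim
                if c == "body_posture"
                else cfg.channels_embed_dim
            )
            sample[c] = torch.randn(num_frames, dim)
            sample[f"label_{c}"] = torch.randint(0, cfg.codebook_size, (num_frames, 1))
        source.append(sample)

    out = model(source, padding_mask=None, mask=True)
    print(
        f"masked predictions: {out['sample_size']}, "
        f"mean kmeans loss: {out['losses']['kmeans_loss'].mean().item():.4f}"
    )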