import logging
from dataclasses import dataclass, field

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import numpy as np
import sys

from fairseq.data.data_utils import compute_mask_indices
from fairseq.models import BaseFairseqModel, register_model
from fairseq.models.wav2vec import (
    Wav2Vec2Config,
    TransformerEncoder,
)
from fairseq.modules import LayerNorm

# Debug prints showing where Wav2Vec2Config is defined
print(f"Wav2Vec2Config is imported from: {Wav2Vec2Config.__module__}")
print(f"Full path: {sys.modules[Wav2Vec2Config.__module__].__file__}")

logger = logging.getLogger(__name__)
@dataclass
class SignHubertConfig(Wav2Vec2Config):
    conv_pos: int = field(default=32)
    discrete: bool = field(default=False)
    codebook_size: int = field(default=256)
    channels_embed_dim: int = field(default=384)
    channels_pose_embed_dim: int = field(default=14)
    intermediate_dim: int = field(default=1024)  # may be overridden as needed
    mask_strategy: str = field(default="random")
    channels: str = field(default="face,left_hand,right_hand,body_posture")
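# Sizing note: each channel is projected to intermediate_dim // len(channels)
# features (1024 // 4 = 256 with the defaults above), so intermediate_dim
# should stay divisible by the number of channels in use.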
@register_model("sign_hubert", dataclass=SignHubertConfig)  # registry name assumed
class SignHubertModel(BaseFairseqModel):
    def __init__(self, cfg: SignHubertConfig):
        super().__init__()
        self.cfg = cfg
        self.discrete = cfg.discrete  # HuBERT-style training is always discrete
        self.embed = cfg.encoder_embed_dim  # model size: small (384), base (768), large, etc.
        self.channel_embed = cfg.channels_embed_dim  # embedding dim for face, left_hand, right_hand (default: 384)
        self.channel_pose_embed = cfg.channels_pose_embed_dim  # embedding dim for body posture (default: 14)
        self.intermediate_dim = cfg.intermediate_dim  # concatenated dim before projecting to encoder_embed_dim (default: 1024)
        self.channels = cfg.channels.split(",")
        self.post_extract_proj = nn.Linear(cfg.intermediate_dim, cfg.encoder_embed_dim)  # all channels concatenated

        self.mask_prob = cfg.mask_prob
        self.mask_selection = cfg.mask_selection
        self.mask_strategy = cfg.mask_strategy
        self.mask_other = cfg.mask_other
        self.mask_length = cfg.mask_length
        self.no_mask_overlap = cfg.no_mask_overlap
        self.mask_min_space = cfg.mask_min_space
        self.mask_channel_prob = cfg.mask_channel_prob
        self.mask_channel_before = cfg.mask_channel_before
        self.mask_channel_selection = cfg.mask_channel_selection
        self.mask_channel_other = cfg.mask_channel_other
        self.mask_channel_length = cfg.mask_channel_length
        self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
        self.mask_channel_min_space = cfg.mask_channel_min_space

        self.dropout_input = nn.Dropout(cfg.dropout_input)
        self.dropout_features = nn.Dropout(cfg.dropout_features)
        self.feature_grad_mult = cfg.feature_grad_mult

        # Learned embedding that replaces masked per-channel features.
        self.mask_emb = nn.Parameter(
            torch.FloatTensor(1, 1, 1, cfg.intermediate_dim // len(self.channels)).uniform_()
        )

        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.channel_embed * len(self.channels))

        # Per-channel LayerNorm + projection to intermediate_dim // len(channels).
        if "face" in self.channels:
            self.layer_norm_face = LayerNorm(self.channel_embed)
            self.face_proj = nn.Linear(self.channel_embed, cfg.intermediate_dim // len(self.channels))
        if "left_hand" in self.channels:
            self.layer_norm_lhand = LayerNorm(self.channel_embed)
            self.left_hand_proj = nn.Linear(self.channel_embed, cfg.intermediate_dim // len(self.channels))
        if "right_hand" in self.channels:
            self.layer_norm_rhand = LayerNorm(self.channel_embed)
            self.right_hand_proj = nn.Linear(self.channel_embed, cfg.intermediate_dim // len(self.channels))
        if "body_posture" in self.channels:
            self.layer_norm_body = LayerNorm(self.channel_pose_embed)
            self.body_posture_proj = nn.Linear(self.channel_pose_embed, cfg.intermediate_dim // len(self.channels))

        self.codebook_size = cfg.codebook_size  # number of codebook vectors
        # One linear prediction head per channel, mapping encoder output to codebook logits.
        self.heads = nn.ModuleList(
            nn.Linear(cfg.encoder_embed_dim, cfg.codebook_size) for _ in self.channels
        )
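    # Dataflow per batch: each channel (B, T, 384 or 14) -> LayerNorm -> Linear
    # -> (B, T, 256); channels stacked -> (B, T, C, 256); masked; flattened to
    # (B, T, 1024); projected to (B, T, encoder_embed_dim); Transformer encoder;
    # per-channel heads -> (B, T, C, codebook_size) logits.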
    def state_dict(self, destination=None, prefix="", keep_vars=False):
        state = super().state_dict(destination, prefix, keep_vars)
        return state
    @classmethod
    def build_model(cls, cfg: SignHubertConfig, task=None):
        """Build a new model instance."""
        return cls(cfg)
    def apply_mask(
        self,
        x,
        padding_mask,
        mask_indices=None,
        mask_channel_indices=None,
    ):
        B, T, C, D = x.shape
        # Multiplicative mask with the same shape as x: 1 = keep, 0 = masked.
        mask = torch.ones_like(x)
        if self.mask_prob > 0 and self.mask_strategy == "channel":
            # Channel masking: drop whole channels across all timesteps.
            if mask_indices is None:
                mask_indices = torch.zeros_like(x[:, :, :, 0], dtype=torch.bool)
                num_channels_to_mask = max(1, int(C * self.mask_prob))
                for i in range(B):
                    channels_to_mask = np.random.choice(C, num_channels_to_mask, replace=False)
                    mask_indices[i, :, channels_to_mask] = True
            mask[mask_indices.unsqueeze(-1).expand(-1, -1, -1, D)] = 0
        elif self.mask_prob > 0 and self.mask_strategy == "gloss":
            # Gloss/time masking: mask time spans, shared across all channels.
            if mask_indices is None:
                mask_indices_channel = compute_mask_indices(
                    (B, T),
                    padding_mask,
                    self.mask_prob,
                    self.mask_length,
                    self.mask_selection,
                    self.mask_other,
                    min_masks=1,
                    no_overlap=self.no_mask_overlap,
                    min_space=self.mask_min_space,
                    require_same_masks=self.cfg.require_same_masks,
                    mask_dropout=self.cfg.mask_dropout,
                )
                mask_indices_channel = torch.from_numpy(mask_indices_channel).to(x.device)
                # Apply the same time mask to every channel.
                mask_indices = mask_indices_channel.unsqueeze(2).expand(-1, -1, C)
            mask_indices = mask_indices.unsqueeze(3).expand(-1, -1, -1, D)
            mask[mask_indices] = 0
        elif self.mask_prob > 0 and self.mask_strategy == "random":
            # Random masking: treat (T, C) as a single flattened axis of length T*C.
            if mask_indices is None:
                mask_indices = compute_mask_indices(
                    (B, T * C),  # note: T*C instead of T
                    padding_mask,
                    self.mask_prob,
                    self.mask_length,
                    self.mask_selection,
                    self.mask_other,
                    min_masks=1,
                    no_overlap=self.no_mask_overlap,
                    min_space=self.mask_min_space,
                    require_same_masks=self.cfg.require_same_masks,
                    mask_dropout=self.cfg.mask_dropout,
                )
                mask_indices = torch.from_numpy(mask_indices).to(x.device)
                mask_indices = mask_indices.view(B, T, C)
            mask_indices = mask_indices.unsqueeze(3).expand(-1, -1, -1, D)
            mask[mask_indices] = 0
        else:
            raise ValueError(f"unknown mask strategy {self.mask_strategy}")
        # Replace masked positions with the learned mask embedding.
        x = x * mask + self.mask_emb * (1 - mask)
        # mask has shape BxTxCxD, where 0 means masked and 1 means kept.
        return x, mask
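    # Shape example with the defaults: for B=2, T=50, C=4 channels, and D=256
    # features per channel, apply_mask returns x and mask, both (2, 50, 4, 256);
    # each (t, c) slice where mask == 0 now holds the learned mask embedding.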
    def forward(
        self,
        source,
        padding_mask=None,
        mask=True,
        features_only=False,
        layer=None,
        mask_indices=None,
        mask_channel_indices=None,
        padding_count=None,
        kmeans_labels=None,
    ):
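        """Run a forward pass.

        `source` is a list of per-sample dicts with per-channel features
        ("face": Tx384, "left_hand": Tx384, "right_hand": Tx384,
        "body_posture": Tx14) and matching k-means target labels
        ("label_face", "label_left_hand", ..., each Tx1). `kmeans_labels` is
        accepted for interface compatibility; the targets used here are the
        per-channel labels carried inside `source`.
        """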
        # Use only the configured channels that are present in this batch.
        channels_to_use = [c for c in self.channels if c in source[0]]
        # Per-channel accumulators for features and their k-means labels.
        if "face" in channels_to_use:
            face_features_list, label_face_features_list = [], []
        if "left_hand" in channels_to_use:
            left_hand_features_list, label_left_hand_features_list = [], []
        if "right_hand" in channels_to_use:
            right_hand_features_list, label_right_hand_features_list = [], []
        if "body_posture" in channels_to_use:
            body_posture_features_list, label_body_posture_features_list = [], []
        # Gather per-channel features and labels across the batch.
        for sample in source:
            for c in channels_to_use:
                if c == "face":
                    face_features_list.append(sample["face"])  # Tx384
                    label_face_features_list.append(sample["label_face"])  # Tx1
                elif c == "left_hand":
                    left_hand_features_list.append(sample["left_hand"])  # Tx384
                    label_left_hand_features_list.append(sample["label_left_hand"])  # Tx1
                elif c == "right_hand":
                    right_hand_features_list.append(sample["right_hand"])  # Tx384
                    label_right_hand_features_list.append(sample["label_right_hand"])  # Tx1
                elif c == "body_posture":
                    body_posture_features_list.append(sample["body_posture"])  # Tx14
                    label_body_posture_features_list.append(sample["label_body_posture"])  # Tx1
        features_list = []
        labels_list = []
        # Stack each channel over the batch, normalize, and project it to
        # D = intermediate_dim // len(channels) features.
        for c in channels_to_use:
            if c == "face":
                face_features = torch.stack(face_features_list)  # BxTx384
                face_labels = torch.stack(label_face_features_list)  # BxTx1
                face_features = self.layer_norm_face(face_features)  # BxTx384
                face_features = self.face_proj(face_features)  # BxTx256
                features_list.append(face_features)
                labels_list.append(face_labels)
            elif c == "left_hand":
                left_hand_features = torch.stack(left_hand_features_list)  # BxTx384
                left_hand_labels = torch.stack(label_left_hand_features_list)  # BxTx1
                left_hand_features = self.layer_norm_lhand(left_hand_features)  # BxTx384
                left_hand_features = self.left_hand_proj(left_hand_features)  # BxTx256
                features_list.append(left_hand_features)
                labels_list.append(left_hand_labels)
            elif c == "right_hand":
                right_hand_features = torch.stack(right_hand_features_list)  # BxTx384
                right_hand_labels = torch.stack(label_right_hand_features_list)  # BxTx1
                right_hand_features = self.layer_norm_rhand(right_hand_features)  # BxTx384
                right_hand_features = self.right_hand_proj(right_hand_features)  # BxTx256
                features_list.append(right_hand_features)
                labels_list.append(right_hand_labels)
            elif c == "body_posture":
                body_posture_features = torch.stack(body_posture_features_list)  # BxTx14
                body_posture_labels = torch.stack(label_body_posture_features_list)  # BxTx1
                body_posture_features = self.layer_norm_body(body_posture_features)  # BxTx14
                body_posture_features = self.body_posture_proj(body_posture_features)  # BxTx256
                features_list.append(body_posture_features)
                labels_list.append(body_posture_labels)
        # Stack the projected channels into BxTxCxD (C=4, D=256 with defaults).
        features = torch.stack(features_list, dim=2)
        if mask:
            x, mask_indices = self.apply_mask(
                features,
                padding_mask,
                mask_indices=mask_indices,
                mask_channel_indices=mask_channel_indices,
            )
            # Here mask_indices is the returned multiplicative mask of shape
            # BxTxCxD, where 0 means masked and 1 means kept.
        else:
            x = features
            mask_indices = None

        x = self.dropout_input(x)  # BxTxCxD
        x = x.view(x.size(0), x.size(1), -1)  # BxTx1024: concatenate the channels
        if self.post_extract_proj is not None:
            x = self.post_extract_proj(x)  # BxTx768
        x, layer_results = self.encoder(
            x,
            padding_mask=padding_mask,
            layer=layer,
        )

        if features_only:
            return {
                "x": x,
                "padding_mask": padding_mask,
                "layer_results": layer_results,
            }

        result = {
            "losses": {},
        }
        # Per-channel linear heads produce discrete predictions over the
        # codebook, stacked into a single BxTxCxcodebook_size tensor.
        predictions = []
        for head in self.heads:
            predictions.append(head(x))  # BxTxcodebook_size
        predictions = torch.stack(predictions, dim=2)  # BxTx4xcodebook_size
        labels = torch.stack(labels_list, dim=2)  # BxTx4x1

        predictions_flat = predictions.view(-1, self.codebook_size)  # (B*T*C, codebook_size)
        labels_flat = labels.view(-1)  # (B*T*C,)

        # Reduce the multiplicative mask to (B, T, C): a position counts as
        # masked only when all D of its feature values were zeroed out.
        mask_indices_reduced = mask_indices.any(dim=-1)
        mask_indices_flat = mask_indices_reduced.view(-1)  # (B*T*C,)

        # Cross-entropy over the masked positions only (mask_indices_flat == 0).
        masked_loss = F.cross_entropy(
            predictions_flat[mask_indices_flat == 0],
            labels_flat[mask_indices_flat == 0],
            reduction="none",
        )

        result["losses"]["kmeans_loss"] = masked_loss
        if "sample_size" not in result:
            result["sample_size"] = masked_loss.numel()
        return result
    @staticmethod
    def compute_var(y):
        """Mean per-dimension standard deviation of y, synced across workers."""
        y = y.view(-1, y.size(-1))
        if dist.is_initialized():
            # Pool the count, sum, and sum of squares across workers, then
            # compute the unbiased variance from the pooled statistics.
            zc = torch.tensor(y.size(0)).cuda()
            zs = y.sum(dim=0)
            zss = (y ** 2).sum(dim=0)
            dist.all_reduce(zc)
            dist.all_reduce(zs)
            dist.all_reduce(zss)
            var = zss / (zc - 1) - (zs ** 2) / (zc * (zc - 1))
            return torch.sqrt(var + 1e-6).mean()
        else:
            return torch.sqrt(y.var(dim=0) + 1e-6).mean()
    def extract_features(
        self, source, padding_mask, kmeans_labels, mask=False, layer=None
    ):
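        """Return encoder features only (no loss), via forward(features_only=True).

        The result dict holds "x" (B x T x encoder_embed_dim), "padding_mask",
        and "layer_results".
        """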
        res = self.forward(
            source,
            padding_mask,
            mask=mask,
            features_only=True,
            layer=layer,
            kmeans_labels=kmeans_labels,
        )
        return res
    def remove_pretraining_modules(self, last_layer=None):
        # Drop the pretraining prediction heads and, optionally, truncate the
        # encoder to keep only layers up to `last_layer`.
        self.heads = None
        self.final_proj = None
        if last_layer is not None:
            self.encoder.layers = nn.ModuleList(
                l for i, l in enumerate(self.encoder.layers) if i <= last_layer
            )
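

# Minimal smoke test: a sketch assuming fairseq (with Wav2Vec2Config defaults
# that include require_same_masks and mask_dropout) is installed. The batch
# size, sequence length, and random dummy features/labels below are purely
# illustrative values, not part of any real pipeline.
if __name__ == "__main__":
    torch.manual_seed(0)
    cfg = SignHubertConfig()
    model = SignHubertModel(cfg)
    B, T = 2, 50
    source = []
    for _ in range(B):
        sample = {
            "face": torch.randn(T, cfg.channels_embed_dim),
            "left_hand": torch.randn(T, cfg.channels_embed_dim),
            "right_hand": torch.randn(T, cfg.channels_embed_dim),
            "body_posture": torch.randn(T, cfg.channels_pose_embed_dim),
        }
        for c in ("face", "left_hand", "right_hand", "body_posture"):
            sample[f"label_{c}"] = torch.randint(0, cfg.codebook_size, (T, 1))
        source.append(sample)
    out = model(source, padding_mask=None, mask=True)
    print("kmeans_loss mean:", out["losses"]["kmeans_loss"].mean().item())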