# OpenSound's picture
# Upload 518 files
# dd9600d verified
import os
import pdb
import copy
import torch
import argparse
import loralib as lora
import transformers.models.wavlm.modeling_wavlm as wavlm
from speechbrain.nnet.normalization import LayerNorm
from speechbrain.lobes.models.huggingface_transformers.huggingface import make_padding_masks
from torch import nn
from torch.nn import functional as F
from transformers import Wav2Vec2FeatureExtractor
from transformers import WavLMModel
import sys
from pathlib import Path
sys.path.append(os.path.join(str(Path(os.path.realpath(__file__)).parents[1])))
from revgrad import RevGrad
class WavLMEncoderLayer(nn.Module):
    """Post-norm WavLM encoder layer with optional LoRA feed-forward projections.

    Order of operations: self-attention -> dropout -> residual add -> LayerNorm
    -> feed-forward residual add -> final LayerNorm.

    For layers with ``layer_idx > config.num_hidden_layers // 2``, when
    ``config.finetune_method`` is ``"lora"`` or ``"combined"``, the two
    feed-forward projections are replaced by ``loralib`` ``Linear`` layers of
    rank ``config.lora_rank``.

    Args:
        layer_idx: index of this layer within the encoder stack.
        config: WavLM model config. It must also carry ``finetune_method`` and
            ``lora_rank`` attributes.
        has_relative_position_bias: forwarded unchanged to ``WavLMAttention``.
    """

    def __init__(self, layer_idx, config, has_relative_position_bias: bool = True):
        super().__init__()
        attention_kwargs = dict(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            num_buckets=config.num_buckets,
            max_distance=config.max_bucket_distance,
            has_relative_position_bias=has_relative_position_bias,
        )
        # Submodules are created in a fixed order; parameter registration order
        # (and hence state-dict / optimizer ordering) depends on it.
        self.attention = wavlm.WavLMAttention(**attention_kwargs)
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.feed_forward = wavlm.WavLMFeedForward(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.config = config

        in_upper_half = layer_idx > config.num_hidden_layers // 2
        if in_upper_half and config.finetune_method in ("lora", "combined"):
            # Reuse the original attribute names so state-dict keys still line up
            # with the pretrained feed-forward weights.
            rank = config.lora_rank
            ff = self.feed_forward
            ff.intermediate_dense = lora.Linear(config.hidden_size, config.intermediate_size, r=rank)
            ff.output_dense = lora.Linear(config.intermediate_size, config.hidden_size, r=rank)

    def forward(self, hidden_states, attention_mask=None, position_bias=None, output_attentions=False, index=0):
        """Run one encoder layer.

        Returns:
            ``(hidden_states, position_bias)``, with the attention weights
            appended as a third element when ``output_attentions`` is true.
        """
        residual = hidden_states
        attended, attn_weights, position_bias = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            output_attentions=output_attentions,
            index=index,
        )
        hidden_states = self.layer_norm(residual + self.dropout(attended))

        ff_out = self.feed_forward(hidden_states)
        hidden_states = self.final_layer_norm(hidden_states + ff_out)

        if not output_attentions:
            return hidden_states, position_bias
        return hidden_states, position_bias, attn_weights
class WavLMEncoderLayerStableLayerNorm(nn.Module):
    """Pre-norm ("stable layer norm") WavLM encoder layer with optional LoRA.

    Order of operations: LayerNorm -> self-attention -> dropout -> residual add,
    then final LayerNorm -> feed-forward -> residual add. There is no
    normalisation after the residual adds.

    For layers with ``layer_idx > config.num_hidden_layers // 2``, when
    ``config.finetune_method`` is ``"lora"`` or ``"combined"``, the two
    feed-forward projections are replaced by ``loralib`` ``Linear`` layers of
    rank ``config.lora_rank``.

    Args:
        layer_idx: index of this layer within the encoder stack.
        config: WavLM model config. It must also carry ``finetune_method`` and
            ``lora_rank`` attributes.
        has_relative_position_bias: forwarded unchanged to ``WavLMAttention``.
    """

    def __init__(self, layer_idx, config, has_relative_position_bias: bool = True):
        super().__init__()
        attention_kwargs = dict(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            num_buckets=config.num_buckets,
            max_distance=config.max_bucket_distance,
            has_relative_position_bias=has_relative_position_bias,
        )
        # Submodules are created in a fixed order; parameter registration order
        # (and hence state-dict / optimizer ordering) depends on it.
        self.attention = wavlm.WavLMAttention(**attention_kwargs)
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.feed_forward = wavlm.WavLMFeedForward(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.config = config

        in_upper_half = layer_idx > config.num_hidden_layers // 2
        if in_upper_half and config.finetune_method in ("lora", "combined"):
            # Reuse the original attribute names so state-dict keys still line up
            # with the pretrained feed-forward weights.
            rank = config.lora_rank
            ff = self.feed_forward
            ff.intermediate_dense = lora.Linear(config.hidden_size, config.intermediate_size, r=rank)
            ff.output_dense = lora.Linear(config.intermediate_size, config.hidden_size, r=rank)

    def forward(self, hidden_states, attention_mask=None, position_bias=None, output_attentions=False):
        """Run one pre-norm encoder layer.

        Returns:
            ``(hidden_states, position_bias)``, with the attention weights
            appended as a third element when ``output_attentions`` is true.
        """
        residual = hidden_states
        attended, attn_weights, position_bias = self.attention(
            self.layer_norm(hidden_states),
            attention_mask=attention_mask,
            position_bias=position_bias,
            output_attentions=output_attentions,
        )
        hidden_states = residual + self.dropout(attended)

        normed = self.final_layer_norm(hidden_states)
        hidden_states = hidden_states + self.feed_forward(normed)

        if not output_attentions:
            return hidden_states, position_bias
        return hidden_states, position_bias, attn_weights
class WavLMWrapper(nn.Module):
    """WavLM backbone with optional LoRA fine-tuning and a small classification head.

    Pipeline:
        raw waveform
        -> WavLM (all hidden states)
        -> softmax-weighted sum over layers
        -> three point-wise Conv1d layers
        -> mean over valid frames
        -> MLP classifier.

    Optionally, a gradient-reversal branch predicts the source dataset from the
    same pooled features.

    Args:
        pretrain_model: ``"wavlm"`` (microsoft/wavlm-base-plus) or
            ``"wavlm_large"`` (microsoft/wavlm-large).
        hidden_dim: channel width of the downstream Conv1d and MLP layers.
        finetune_method: stored on the backbone config. ``"lora"`` and
            ``"combined"`` make the upper encoder layers use LoRA feed-forward
            projections.
        lora_rank: rank ``r`` of those LoRA layers.
        freeze_params: freeze the backbone. With ``"lora"``, parameters that are
            absent from the pretrained checkpoint (the new LoRA weights) stay
            trainable.
        output_class_num: number of output logits.
        use_conv_output: include hidden state 0 (input embeddings) in the
            weighted layer sum.
        apply_gradient_reversal: add the dataset-prediction branch behind
            ``RevGrad``.
        num_dataset: number of dataset logits for that branch.

    Raises:
        ValueError: if ``pretrain_model`` is not a supported name.
    """

    def __init__(
        self,
        pretrain_model="wavlm_large",
        hidden_dim=256,
        finetune_method="lora",
        lora_rank=16,
        freeze_params=True,
        output_class_num=4,
        use_conv_output=True,
        apply_gradient_reversal=False,
        num_dataset=4,
    ):
        super(WavLMWrapper, self).__init__()
        if pretrain_model not in ("wavlm", "wavlm_large"):
            # Fail early. Otherwise no backbone is created and the error only
            # surfaces later as an unrelated AttributeError.
            raise ValueError(
                f"Unsupported pretrain_model {pretrain_model!r}; expected 'wavlm' or 'wavlm_large'"
            )

        # 1. Load the pretrained backbone (and, for wavlm_large, its feature extractor).
        if pretrain_model == "wavlm":
            self.backbone_model = WavLMModel.from_pretrained(
                "microsoft/wavlm-base-plus",
                output_hidden_states=True,
            )
        else:
            self.processor = Wav2Vec2FeatureExtractor.from_pretrained('microsoft/wavlm-large')
            self.backbone_model = WavLMModel.from_pretrained(
                "microsoft/wavlm-large",
                output_hidden_states=True,
            )
        self.pretrain_model = pretrain_model
        self.finetune_method = finetune_method
        self.apply_gradient_reversal = apply_gradient_reversal
        self.use_conv_output = use_conv_output
        # Keep the pretrained weights so they can be loaded into the rebuilt layers.
        state_dict = self.backbone_model.state_dict()

        # 2. Extend the config with the fields read by the custom encoder layers.
        self.model_config = self.backbone_model.config
        self.model_config.finetune_method = finetune_method
        self.model_config.lora_rank = lora_rank

        # 3. Rebuild the encoder layers. The base model uses the post-norm layer;
        #    the large model uses the stable-layer-norm (pre-norm) layer.
        layer_cls = WavLMEncoderLayer if pretrain_model == "wavlm" else WavLMEncoderLayerStableLayerNorm
        self.backbone_model.encoder.layers = nn.ModuleList(
            [
                layer_cls(i, self.model_config, has_relative_position_bias=(i == 0))
                for i in range(self.model_config.num_hidden_layers)
            ]
        )

        # 4. Load the pretrained weights back. Keys the checkpoint lacks (the new
        #    LoRA parameters) are reported in msg.missing_keys.
        msg = self.backbone_model.load_state_dict(state_dict, strict=False)

        # 5. Decide which backbone parameters are trainable.
        self.freeze_params = freeze_params
        new_param_names = set(msg.missing_keys)
        for name, p in self.backbone_model.named_parameters():
            if not freeze_params:
                p.requires_grad = True
            elif finetune_method == "lora":
                p.requires_grad = name in new_param_names
            else:
                # NOTE(review): this branch also covers "combined", which freezes
                # the LoRA layers created in step 3 -- confirm intended.
                p.requires_grad = False

        # 6. Downstream models.
        self.model_seq = nn.Sequential(
            nn.Conv1d(self.model_config.hidden_size, hidden_dim, 1, padding=0),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Conv1d(hidden_dim, hidden_dim, 1, padding=0),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Conv1d(hidden_dim, hidden_dim, 1, padding=0)
        )
        if self.use_conv_output:
            num_layers = self.model_config.num_hidden_layers + 1  # transformer layers + input embeddings
            self.weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        else:
            num_layers = self.model_config.num_hidden_layers
            self.weights = nn.Parameter(torch.zeros(num_layers))
        if apply_gradient_reversal:
            self.dataset_layer = nn.Sequential(
                RevGrad(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, num_dataset),
            )
        self.out_layer = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_class_num),
        )

    def forward(self, x, length=None, return_feature=False):
        """Classify a batch of waveforms.

        Args:
            x: batch of raw waveforms, indexed per utterance as ``x[i]``.
                Presumably (B, num_samples) sampled at 16 kHz -- TODO confirm.
            length: optional per-utterance valid lengths in samples (presumably
                an integer tensor of shape (B,)). Used for the padding mask and
                to pool only over valid frames.
            return_feature: also return the pooled features.

        Returns:
            ``predicted`` logits. With gradient reversal, returns
            ``(predicted, dataset_predicted)`` instead. When ``return_feature``
            is true, the pooled ``features`` are appended.
        """
        device = x.device
        # 1. Feature-extractor preprocessing and padding mask (wavlm_large only).
        if self.pretrain_model == "wavlm_large":
            signal, attention_mask = self._prepare_large_inputs(x, length)

        # 2. Convert sample lengths to frame lengths. Use the input's device
        #    rather than a hard-coded CUDA device so CPU inference also works.
        if length is not None:
            length = self.get_feat_extract_output_lengths(length.detach().cpu()).to(device)

        # 3. Hidden states from every backbone layer.
        if self.pretrain_model == "wavlm":
            hidden_states = self.backbone_model(x, output_hidden_states=True).hidden_states
        else:
            hidden_states = self.backbone_model(
                signal,
                attention_mask=attention_mask,
                output_hidden_states=True
            ).hidden_states

        # 4. Learned weighted average over layers -> B x T x D.
        features = self._weighted_layer_sum(hidden_states)

        # 5. Point-wise Conv1d works on B x D x T, so transpose around it.
        features = self.model_seq(features.transpose(1, 2)).transpose(1, 2)

        # 6. Mean pooling over (valid) frames -> B x D.
        features = self._temporal_mean(features, length)

        # 7. Output predictions.
        predicted = self.out_layer(features)
        if self.apply_gradient_reversal:
            dataset_predicted = self.dataset_layer(features)
            if return_feature:
                return predicted, dataset_predicted, features
            return predicted, dataset_predicted
        if return_feature:
            return predicted, features
        return predicted

    def _prepare_large_inputs(self, x, length):
        """Run the feature extractor per utterance and build the padding mask.

        Returns:
            ``(signal, attention_mask)``, both on ``x.device``.
        """
        with torch.no_grad():
            if length is not None:
                wav_len = length / length.max()
            else:
                # One relative length per batch item (all full length), so the
                # mask has one row per utterance. A single-element tensor would
                # give a 1-row mask that breaks batches with B > 1.
                wav_len = torch.ones(x.shape[0], device=x.device)
            attention_mask = make_padding_masks(x, wav_len=wav_len).to(x.device)
            # NOTE(review): the extractor sees each padded row as a whole, so any
            # normalisation it applies includes padded samples -- confirm intended.
            signal = torch.stack([
                self.processor(wav, sampling_rate=16_000, return_tensors="pt", padding=True)["input_values"][0].to(x.device)
                for wav in x
            ])
        return signal, attention_mask

    def _weighted_layer_sum(self, hidden_states):
        """Softmax(self.weights)-weighted sum of the stacked hidden states.

        Hidden state 0 (input embeddings) is included only when
        ``use_conv_output`` is true.
        """
        stacked = torch.stack(hidden_states, dim=0)  # num_layers x B x T x D
        if not self.use_conv_output:
            stacked = stacked[1:]
        num_layers, *feature_shape = stacked.shape
        norm_weights = F.softmax(self.weights, dim=-1)
        weighted = (norm_weights.unsqueeze(-1) * stacked.reshape(num_layers, -1)).sum(dim=0)
        return weighted.view(*feature_shape)

    @staticmethod
    def _temporal_mean(features, length=None):
        """Average ``features`` over dim 1, using only the first ``length[i]`` frames per row."""
        if length is None:
            return torch.mean(features, dim=1)
        # Slice per utterance so padded frames do not contribute to the mean.
        return torch.stack([features[i, :n].mean(dim=0) for i, n in enumerate(length)])

    # From huggingface
    def get_feat_extract_output_lengths(self, input_length):
        """
        Computes the output length of the convolutional layers

        Applies the Conv1d length formula once per (kernel, stride) pair in the
        backbone config's ``conv_kernel`` / ``conv_stride``, mapping waveform
        lengths in samples to frame counts.
        """
        def _conv_out_length(input_length, kernel_size, stride):
            # 1D convolutional layer output length formula taken
            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            return (input_length - kernel_size) // stride + 1

        config = self.backbone_model.config
        for kernel_size, stride in zip(config.conv_kernel, config.conv_stride):
            input_length = _conv_out_length(input_length, kernel_size, stride)
        return input_length
def prepare_mask(length, shape, dtype):
    """Build a boolean mask that is True for positions before each row's length.

    Modified from huggingface. Assumes a 2-D ``shape`` of (batch, time) -- TODO
    confirm against callers.

    Args:
        length: integer tensor with one valid length per row.
        shape: shape of the mask to build.
        dtype: dtype of the scratch tensor used before conversion to bool.

    Returns:
        Bool tensor of ``shape`` where ``mask[i, t]`` is True iff ``t < length[i]``.
        Rows with ``length[i] == 0`` are entirely False.
    """
    mask = torch.zeros(shape, dtype=dtype)
    length = length.cpu()
    rows = torch.arange(mask.shape[0])
    # Mark the last valid index of each row. Rows with no valid frames are
    # skipped: indexing them at `0 - 1 == -1` would mark the LAST column and
    # turn the whole row True.
    has_frames = length > 0
    mask[(rows[has_frames], length[has_frames] - 1)] = 1
    # Reverse cumulative sum spreads each marker to every earlier position in the row.
    mask = mask.flip([-1]).cumsum(-1).flip([-1]).bool()
    return mask