# coding=utf-8
# Copyright 2023-2024 Xiaomi Corporation and HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Dasheng (Deep Audio-Signal Holistic Embeddings) model."""

import collections
import collections.abc
import math
from functools import partial
from typing import Any, Optional, Tuple

import torch
import torch.utils.checkpoint
from einops.layers.torch import Rearrange
from torch import nn
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
)

from .configuration_dasheng import DashengConfig

logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "DashengConfig"

# Audio classification docstring
_SEQ_CLASS_CHECKPOINT = "mispeech/dasheng-base"

DASHENG_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "mispeech/dasheng-base",
    "mispeech/dasheng-0.6B",
    "mispeech/dasheng-1.2B",
    # See all Dasheng models at https://huggingface.co/models?search=mispeech%2Fdasheng
]


# The functions `trunc_normal_`, `_no_grad_trunc_normal_`, `drop_path` and the module `DropPath` are taken from timm
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    """Fill ``tensor`` in place with normal(mean, std) samples restricted to ``[a, b]``.

    Returns:
        The same ``tensor`` object, to allow chaining.
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        lower = norm_cdf((a - mean) / std)
        upper = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [lower, upper], then translate to
        # [2 * lower - 1, 2 * upper - 1].
        tensor.uniform_(2 * lower - 1, 2 * upper - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.0))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.

    Returns ``x`` unchanged when ``drop_prob`` is 0 or ``training`` is False. Otherwise each sample
    (leading axis) is zeroed with probability ``drop_prob``; survivors are divided by
    ``1 - drop_prob`` when ``scale_by_keep`` is set.
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)
    return x * random_tensor


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        return f"drop_prob={round(self.drop_prob,3):0.3f}"


def to_2tuple(x: Any) -> Tuple[Any, Any]:
    """Return ``x`` unchanged if it is iterable, otherwise the pair ``(x, x)``."""
    if isinstance(x, collections.abc.Iterable):
        return x
    return (x, x)


class AudioPatchEmbed(nn.Module):
    """Split a 2D (frequency x time) feature map into patches with a strided ``Conv2d``.

    ``input_size``, ``patch_size`` and ``patch_stride`` accept a scalar or a pair. ``grid_size``
    is ``input_size // patch_stride`` per axis, and ``num_patches`` is their product. With
    ``flatten=True`` the output is permuted to ``(batch, f * t, embed_dim)``; otherwise the
    ``Conv2d`` output layout is kept.
    """

    def __init__(
        self,
        input_size=224,
        patch_size=16,
        patch_stride=16,
        in_chans=1,
        embed_dim=768,
        norm_layer=None,
        flatten=False,
    ):
        super().__init__()
        input_size = to_2tuple(input_size)
        patch_size = to_2tuple(patch_size)
        patch_stride = to_2tuple(patch_stride)
        self.input_size = input_size
        self.patch_size = patch_size
        self.patch_stride = patch_stride
        self.grid_size = (
            input_size[0] // patch_stride[0],
            input_size[1] // patch_stride[1],
        )
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_stride)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        x = self.proj(x)
        if self.flatten:
            x = torch.permute(torch.flatten(x, 2, 3), (0, 2, 1))  # rearrange(x, "b c f t -> b (f t) c")
        x = self.norm(x)
        return x


class LayerScale(nn.Module):
    """Multiply the input by a learnable per-channel vector ``gamma`` initialised to ``init_values``.

    With ``inplace=True`` the input tensor is modified in place via ``mul_``.
    """

    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        return x.mul_(self.gamma) if self.inplace else x * self.gamma


class Mlp(nn.Module):
    """Two-layer feed-forward network: Linear -> activation -> dropout -> Linear -> dropout."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention over a ``(batch, tokens, dim)`` input."""

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        # (B, N, 3*C) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Module):
    """Pre-norm transformer block: ``x + attn(norm1(x))`` then ``x + mlp(norm2(x))``.

    Each residual branch passes through an optional ``LayerScale`` (when ``init_values`` is
    truthy) and an optional ``DropPath`` (when ``drop_path > 0``). ``attention_type`` names a
    module-level attention class, looked up via ``globals()``.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        drop=0.0,
        attn_drop=0.0,
        init_values=None,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        attention_type="Attention",
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        attn_type = globals()[attention_type]
        self.attn = attn_type(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x):
        x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return x


class AudioTransformerMAE_Encoder(nn.Module):
    """Transformer encoder over patch-embedded spectrograms.

    The input is optionally normalised by a BatchNorm over the mel axis (``init_bn``, default on),
    patch-embedded, and given separate learnable time and frequency position embeddings. It is
    then flattened to tokens and passed through ``depth`` transformer blocks. Inputs longer than
    ``target_length`` frames are processed in ``target_length`` chunks (see ``forward``).

    Extra keyword arguments read from ``kwargs``: ``n_mels`` (default 64), ``init_bn`` (default
    True) and ``pad_last`` (default True).
    """

    def __init__(
        self,
        patch_size=16,
        patch_stride=16,
        embed_dim=768,
        depth=12,
        num_heads=8,
        mlp_ratio=4.0,
        qkv_bias=True,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=None,
        act_layer=None,
        init_values=None,
        target_length=1012,
        pooling="mean",
        wavtransforms=None,
        spectransforms=None,
        time_patch_out: Optional[float] = None,
        freq_patch_out: Optional[float] = None,
        block_type="Block",
        attention_type="Attention",
        eval_avg="mean",
        **kwargs,
    ):
        super().__init__()
        assert pooling in ("mean", "token", "logit")
        self.pooling = pooling
        self.embed_dim = embed_dim
        self.patch_stride = patch_stride
        self.patch_size = patch_size
        self.n_mels = kwargs.get("n_mels", 64)
        init_bn = kwargs.get("init_bn", True)
        self.eval_avg = eval_avg
        self.time_patch_out = time_patch_out
        self.freq_patch_out = freq_patch_out
        self.pad_last = kwargs.get("pad_last", True)

        if init_bn:
            # BatchNorm2d over the mel axis: move the frequency axis into the channel slot and back.
            self.init_bn = nn.Sequential(
                Rearrange("b c f t -> b f c t"),
                torch.nn.BatchNorm2d(self.n_mels, momentum=0.01),
                Rearrange("b f c t -> b c f t"),
            )
        else:
            # Always define the attribute so `forward` can test it against None.
            self.init_bn = None

        self.target_length = target_length
        self.patch_embed = AudioPatchEmbed(
            input_size=(self.n_mels, target_length),
            embed_dim=self.embed_dim,
            patch_size=self.patch_size,
            flatten=False,
            patch_stride=self.patch_stride,
        )
        self.spectransforms = nn.Sequential() if spectransforms is None else spectransforms
        self.wavtransforms = nn.Sequential() if wavtransforms is None else wavtransforms
        self.num_patches = self.patch_embed.num_patches

        if pooling == "token":
            self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
            self.token_pos_embed = nn.Parameter(torch.randn(1, embed_dim) * 0.02)

        # Separate learnable position embeddings, broadcast over the other grid axis.
        self.time_pos_embed = nn.Parameter(torch.randn(1, embed_dim, 1, self.patch_embed.grid_size[1]) * 0.02)
        self.freq_pos_embed = nn.Parameter(torch.randn(1, embed_dim, self.patch_embed.grid_size[0], 1) * 0.02)

        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.pos_drop = nn.Dropout(p=drop_rate)
        block_function = globals()[block_type]
        self.blocks = nn.Sequential(
            *[
                block_function(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    init_values=init_values,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    act_layer=act_layer,
                    attention_type=attention_type,
                )
                for i in range(depth)
            ]
        )
        self.norm = norm_layer(embed_dim)

        if hasattr(self, "cls_token") and self.cls_token is not None:
            nn.init.normal_(self.cls_token, std=1e-6)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Names of parameters that should be excluded from weight decay."""
        return {"time_pos_embed", "cls_token", "freq_pos_embed", "token_pos_embed"}

    def forward_features(self, x):
        """Patch-embed ``x``, add position embeddings and run the transformer blocks.

        Returns a ``(batch, tokens, embed_dim)`` tensor. Tokens are ordered frequency-major
        (``f t``), and a cls token is prepended when ``pooling == "token"``.
        """
        x = self.patch_embed(x)
        b, c, f, t = x.shape
        # Time embedding is sliced so inputs shorter than target_length still broadcast.
        x = x + self.time_pos_embed[:, :, :, :t]
        x = x + self.freq_pos_embed[:, :, :, :]  # Just for sin pos embed
        x = torch.permute(torch.flatten(x, 2, 3), (0, 2, 1))  # rearrange(x, "b c f t -> b (f t) c")
        if self.pooling == "token":
            cls_token = self.cls_token.expand(x.shape[0], -1, -1)
            cls_token = cls_token + self.token_pos_embed[:, :]
            x = torch.cat((cls_token, x), dim=1)
        x = self.pos_drop(x)
        x = self.blocks(x)
        x = self.norm(x)
        return x

    def forward(self, x):
        """Encode a spectrogram, chunking it along time if it exceeds ``target_length``.

        Inputs longer than ``target_length`` frames (last axis) are cut into ``target_length``
        chunks. The short tail is zero-padded when ``pad_last`` is set and dropped otherwise.
        Chunks are encoded as one larger batch, and each batch item's chunk outputs are then
        concatenated in order along the token axis.

        Args:
            x: 4D input laid out as ``(batch, channels, freq, time)``, matching the
                ``"b c f t"`` pattern used by ``init_bn``.

        Returns:
            Tensor of shape ``(batch, tokens, embed_dim)``.
        """
        x = self.init_bn(x) if self.init_bn is not None else x
        # Number of leading tokens to keep when the last chunk was zero-padded (0 = keep all).
        padding_start = 0
        if x.shape[-1] > self.target_length:
            splits = list(x.split(self.target_length, -1))
            last_len = splits[-1].shape[-1]
            if last_len < self.target_length:
                if self.pad_last:
                    # new_zeros keeps x's dtype/device so torch.stack also works for fp16/bf16 inputs.
                    pad = x.new_zeros((*x.shape[:-1], self.target_length))
                    pad[..., :last_len] = splits[-1]
                    splits[-1] = pad
                    # NOTE(review): this counts time patches of the unpadded input; it equals the
                    # token count only when the frequency grid has one row and no cls token is
                    # prepended — confirm for other configurations.
                    padding_start = x.shape[-1] // self.patch_embed.patch_stride[-1]
                else:
                    splits = splits[:-1]
            # Every remaining chunk is full length; keep all of them, including the last one
            # when the input length is an exact multiple of target_length.
            stacked = torch.stack(splits, dim=0)
            n_splits = stacked.shape[0]
            x = torch.flatten(stacked, 0, 1)  # spl b c f t -> (spl b) c f t
        else:
            n_splits = 1

        x = self.forward_features(x)
        # (spl b) t d -> b (spl t) d: the leading axis is split-major, so separate it as
        # (spl, b) and move the split axis next to the tokens. This keeps each batch item's
        # chunks together instead of interleaving different batch items.
        x = x.reshape(n_splits, -1, x.shape[-2], x.shape[-1]).transpose(0, 1)
        x = x.reshape(x.shape[0], -1, x.shape[-1])

        if padding_start:
            x = x[:, :padding_start, :]
        return x


DASHENG_START_DOCSTRING = r"""

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`DashengConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

DASHENG_INPUTS_DOCSTRING = r"""
    Args:
        input_values (`torch.FloatTensor` of shape `(batch_size, n_mels, sequence_length)`):
            The sequence of audio features extracted from the audio signal. Can be obtained from a raw audio waveform
            using `~transformers.DashengFeatureExtractor.__call__`.
"""


class DashengPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = DashengConfig
    base_model_prefix = "dasheng"
    main_input_name = "input_values"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Xavier-uniform Linear weights with zero bias; LayerNorm weight 1 and bias 0."""
        if isinstance(module, nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            # LayerNorm(elementwise_affine=False) has no weight/bias to initialise.
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
            if module.weight is not None:
                nn.init.constant_(module.weight, 1.0)


@add_start_docstrings(
    "The Dasheng Model transformer with an optional linear layer on top of the pooled output.",
    DASHENG_START_DOCSTRING,
)
class DashengModel(DashengPreTrainedModel):
    def __init__(self, config: DashengConfig, outputdim: Optional[int] = None) -> None:
        r"""
        Initializes the model.

        Args:
            config (DashengConfig): Configuration class for the model.
            outputdim (int, optional): Dimension of the output layer. If None, no projection is added (the head
                is an identity) and `outputdim` is set to the encoder's `embed_dim`; the pooling and sigmoid in
                `forward_head` still apply. Defaults to None.
        """
        super().__init__(config)
        self.config = config
        self.name = config.name

        self.encoder = AudioTransformerMAE_Encoder(**config.encoder_kwargs)

        # Classifier head
        if outputdim is not None:
            self.outputlayer = nn.Sequential(
                nn.LayerNorm(config.encoder_kwargs["embed_dim"]),
                nn.Linear(config.encoder_kwargs["embed_dim"], outputdim),
            )
        else:
            self.outputlayer = nn.Identity()
            outputdim = config.encoder_kwargs["embed_dim"]
        self.outputdim = outputdim

        # Initialize weights and apply final processing
        self.post_init()

    def forward_head(self, x: torch.Tensor) -> torch.Tensor:
        """Pool encoder tokens and apply the output layer.

        ``"token"`` pooling takes the first token and ``"mean"`` averages over tokens; both apply a
        sigmoid. ``"logit"`` averages over tokens and returns raw outputs.
        """
        if self.encoder.pooling == "token":
            x = x[:, 0]
            return self.outputlayer(x).sigmoid()
        elif self.encoder.pooling == "mean":
            x = x.mean(1)
            return self.outputlayer(x).sigmoid()
        elif self.encoder.pooling == "logit":
            x = x.mean(1)
            return self.outputlayer(x)
        else:
            raise NotImplementedError(f"Pooling {self.encoder.pooling} not implemented.")

    def freeze_encoder(self) -> None:
        """Disable gradients for every encoder parameter."""
        for param in self.encoder.parameters():
            param.requires_grad = False
        self._requires_grad = False

    @add_start_docstrings_to_model_forward(DASHENG_INPUTS_DOCSTRING.format("batch_size, n_mels, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_SEQ_CLASS_CHECKPOINT,
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="audio",
        model_cls="DashengModel",
    )
    def forward(self, input_values: torch.Tensor, labels: Optional[torch.Tensor] = None) -> SequenceClassifierOutput:
        """
        Runs a forward pass of the Dasheng model with audio features. The model returns logits and hidden states. If
        labels are provided, the model also returns the loss.

        `labels` are one-hot encoded with `num_classes=self.outputdim` before being passed to the loss class named by
        `config.loss` (looked up in `torch.nn.modules.loss`), so they must be integer class indices.
        """
        # Add a channel axis: (batch, n_mels, time) -> (batch, 1, n_mels, time).
        x = torch.unsqueeze(input_values, 1)
        last_hidden_states = self.encoder(x)
        logits = self.forward_head(last_hidden_states)

        if labels is not None:
            try:
                loss_fct = getattr(nn.modules.loss, self.config.loss)()
            except AttributeError as err:
                raise NotImplementedError(f"Loss {self.config.loss} not implemented.") from err

            labels = nn.functional.one_hot(labels, num_classes=self.outputdim).float()
            loss = loss_fct(logits, labels)
        else:
            loss = None

        return SequenceClassifierOutput(logits=logits, loss=loss, hidden_states=last_hidden_states)