dasheng-0.6B / modeling_dasheng.py
jimbozhang's picture
Upload 3 files
ed2a3f5 verified
# coding=utf-8
# Copyright 2023-2024 Xiaomi Corporation and HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Dasheng (Deep Audio-Signal Holistic Embeddings) model."""
import collections
import math
from functools import partial
from typing import Any, Optional, Tuple
import torch
import torch.utils.checkpoint
from einops.layers.torch import Rearrange
from torch import nn
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
)
from .configuration_dasheng import DashengConfig
# Identifiers consumed by the auto-generated docstrings (`add_code_sample_docstrings` below).
_CONFIG_FOR_DOC = "DashengConfig"
# Audio classification docstring
_SEQ_CLASS_CHECKPOINT = "mispeech/dasheng-base"

# Known Dasheng checkpoint identifiers.
# See all Dasheng models at https://huggingface.co/models?search=mispeech%2Fdasheng
DASHENG_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "mispeech/dasheng-base",
    "mispeech/dasheng-0.6B",
    "mispeech/dasheng-1.2B",
]

# Module logger obtained from the transformers logging utilities.
logger = logging.get_logger(__name__)
# The functions `trunc_normal_`, `_no_grad_trunc_normal_`, `drop_path` and the module `DropPath` are taken from timm.
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    """Fill ``tensor`` in place with values from a normal distribution truncated to ``[a, b]``.

    Args:
        tensor: Tensor that is overwritten in place.
        mean: Mean of the underlying normal distribution.
        std: Standard deviation of the underlying normal distribution.
        a: Lower truncation bound.
        b: Upper truncation bound.

    Returns:
        The same ``tensor`` object after it has been filled.
    """
    return _no_grad_trunc_normal_(tensor, mean=mean, std=std, a=a, b=b)
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    """Inverse-CDF sampler behind :func:`trunc_normal_`; fills ``tensor`` in place without autograd.

    Returns:
        The same ``tensor`` object, with values in ``[a, b]``.
    """
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    sqrt2 = math.sqrt(2.0)

    def norm_cdf(value):
        # Standard normal CDF written via the error function.
        return (1.0 + math.erf(value / sqrt2)) / 2.0

    with torch.no_grad():
        # CDF values of the standardised truncation bounds.
        lower_cdf = norm_cdf((a - mean) / std)
        upper_cdf = norm_cdf((b - mean) / std)
        # Draw uniformly from [2*lower_cdf - 1, 2*upper_cdf - 1]; erfinv then maps this onto
        # a standard normal restricted to the standardised bounds.
        tensor.uniform_(2 * lower_cdf - 1, 2 * upper_cdf - 1)
        tensor.erfinv_()
        # Move from the standard normal to N(mean, std^2).
        tensor.mul_(std * sqrt2)
        tensor.add_(mean)
        # Clamp away any numerical drift past the bounds.
        tensor.clamp_(min=a, max=b)
    return tensor
def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True):
    """Stochastic depth: zero out whole samples of a residual branch with probability ``drop_prob``.

    Args:
        x: Input tensor; its first dimension indexes samples.
        drop_prob: Probability of dropping each sample.
        training: When False, ``x`` is returned unchanged.
        scale_by_keep: Divide surviving samples by the keep probability so the expectation is unchanged.

    Returns:
        ``x`` itself when nothing can be dropped, otherwise ``x`` multiplied by a per-sample mask.
    """
    if not training or drop_prob == 0.0:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dimensions.
    mask_shape = (x.shape[0], *([1] * (x.ndim - 1)))
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if scale_by_keep and keep_prob > 0.0:
        mask.div_(keep_prob)
    return x * mask
class DropPath(nn.Module):
    """Module form of :func:`drop_path`; dropping happens only while the module is in training mode."""

    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        rounded_prob = round(self.drop_prob, 3)
        return f"drop_prob={rounded_prob:0.3f}"
def to_2tuple(x: Any) -> Tuple[Any, Any]:
    """Return ``x`` as a tuple, duplicating scalars into a pair.

    Non-string iterables (e.g. ``(64, 4)`` or ``[64, 4]``) are converted with ``tuple(x)``.
    Any other value ``v``, including ``str``/``bytes``, becomes ``(v, v)``. Strings are
    iterable, but they are never meant as a sequence of sizes.
    """
    if isinstance(x, collections.abc.Iterable) and not isinstance(x, (str, bytes)):
        return tuple(x)
    return (x, x)
class AudioPatchEmbed(nn.Module):
    """Embed spectrogram patches with a strided ``nn.Conv2d``.

    Args:
        input_size: Expected (freq, time) input size; an int is duplicated.
        patch_size: Convolution kernel size; an int is duplicated.
        patch_stride: Convolution stride; an int is duplicated.
        in_chans: Number of input channels.
        embed_dim: Number of output channels (embedding size).
        norm_layer: Optional norm class, instantiated as ``norm_layer(embed_dim)``.
        flatten: When True, ``forward`` returns ``(b, f*t, c)`` tokens instead of a ``(b, c, f, t)`` grid.
    """

    def __init__(
        self,
        input_size=224,
        patch_size=16,
        patch_stride=16,
        in_chans=1,
        embed_dim=768,
        norm_layer=None,
        flatten=False,
    ):
        super().__init__()
        self.input_size = to_2tuple(input_size)
        self.patch_size = to_2tuple(patch_size)
        self.patch_stride = to_2tuple(patch_stride)
        # Grid size is derived from input size and stride only (kernel size is not considered).
        freq_cells = self.input_size[0] // self.patch_stride[0]
        time_cells = self.input_size[1] // self.patch_stride[1]
        self.grid_size = (freq_cells, time_cells)
        self.num_patches = freq_cells * time_cells
        self.flatten = flatten
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_stride)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        patches = self.proj(x)
        if self.flatten:
            # b c f t -> b (f t) c
            patches = patches.flatten(2, 3).permute(0, 2, 1)
        return self.norm(patches)
class LayerScale(nn.Module):
    """Scale the last dimension by a learnable vector ``gamma`` initialised to ``init_values``.

    With ``inplace=True`` the input tensor is multiplied in place (``mul_``) and returned.
    """

    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
class Mlp(nn.Module):
    """Feed-forward layer: Linear -> activation -> dropout -> Linear -> dropout.

    ``hidden_features`` and ``out_features`` fall back to ``in_features`` when falsy. A single
    ``nn.Dropout(drop)`` module is reused after both linear layers.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
        super().__init__()
        hidden_dim = hidden_features or in_features
        output_dim = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_dim)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
class Attention(nn.Module):
    """Multi-head self-attention over a ``(batch, tokens, dim)`` sequence.

    Args:
        dim: Token embedding size; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the fused query/key/value projection has a bias.
        attn_drop: Dropout probability on the attention weights.
        proj_drop: Dropout probability after the output projection.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads
        # (B, N, 3C) -> (3, B, heads, N, head_dim); unbind keeps TorchScript happy.
        fused = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, head_dim)
        query, key, value = fused.permute(2, 0, 3, 1, 4).unbind(0)
        weights = (query @ key.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(weights.softmax(dim=-1))
        mixed = (weights @ value).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(mixed))
class Block(nn.Module):
    """Pre-norm transformer layer: attention sub-block followed by MLP sub-block, each residual.

    Each residual branch is ``DropPath(LayerScale(sublayer(norm(x))))``. LayerScale is used only
    when ``init_values`` is truthy, and DropPath only when ``drop_path > 0``; otherwise those
    stages are ``nn.Identity``. ``attention_type`` names an attention class looked up in this
    module's globals.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        drop=0.0,
        attn_drop=0.0,
        init_values=None,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        attention_type="Attention",
    ):
        super().__init__()
        attention_cls = globals()[attention_type]
        use_layer_scale = bool(init_values)
        use_drop_path = drop_path > 0.0
        mlp_hidden = int(dim * mlp_ratio)

        # Attention sub-block.
        self.norm1 = norm_layer(dim)
        self.attn = attention_cls(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        self.ls1 = LayerScale(dim, init_values=init_values) if use_layer_scale else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if use_drop_path else nn.Identity()

        # MLP sub-block.
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden, act_layer=act_layer, drop=drop)
        self.ls2 = LayerScale(dim, init_values=init_values) if use_layer_scale else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if use_drop_path else nn.Identity()

    def forward(self, x):
        attn_branch = self.drop_path1(self.ls1(self.attn(self.norm1(x))))
        x = x + attn_branch
        mlp_branch = self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return x + mlp_branch
class AudioTransformerMAE_Encoder(nn.Module):
    """Transformer encoder over patches of a ``(batch, channel, freq, time)`` spectrogram input.

    Inputs longer than ``target_length`` frames are cut into ``target_length``-frame chunks. Each
    chunk is encoded independently, and the chunks' token sequences are concatenated per batch item.

    Args:
        patch_size, patch_stride: Kernel size / stride of the patch embedding.
        embed_dim: Token embedding size.
        depth: Number of transformer blocks.
        num_heads: Attention heads per block.
        mlp_ratio: MLP hidden size relative to ``embed_dim``.
        qkv_bias: Whether the attention q/k/v projection has a bias.
        drop_rate: Dropout after the positional embeddings and inside every block.
        attn_drop_rate: Dropout on attention weights.
        drop_path_rate: Final stochastic-depth rate; it grows linearly from 0 over the blocks.
        norm_layer: Norm class (default ``LayerNorm(eps=1e-6)``).
        act_layer: MLP activation class (default ``nn.GELU``).
        init_values: Passed to each block (for ``Block``, a falsy value disables LayerScale).
        target_length: Frames per chunk; also fixes the time extent of the patch grid.
        pooling: ``"mean"``, ``"token"`` or ``"logit"``. ``"token"`` prepends a learned CLS token.
        wavtransforms, spectransforms: Stored as submodules; not applied in ``forward``.
        time_patch_out, freq_patch_out, eval_avg: Stored as attributes; not used by this class.
        block_type, attention_type: Class names looked up in this module's globals.
        **kwargs: ``n_mels`` (default 64); ``init_bn`` (default True) adds a BatchNorm over the
            frequency axis before patching; ``pad_last`` (default True) zero-pads a trailing
            partial chunk instead of dropping it.
    """

    def __init__(
        self,
        patch_size=16,
        patch_stride=16,
        embed_dim=768,
        depth=12,
        num_heads=8,
        mlp_ratio=4.0,
        qkv_bias=True,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=None,
        act_layer=None,
        init_values=None,
        target_length=1012,
        pooling="mean",
        wavtransforms=None,
        spectransforms=None,
        time_patch_out: Optional[float] = None,
        freq_patch_out: Optional[float] = None,
        block_type="Block",
        attention_type="Attention",
        eval_avg="mean",
        **kwargs,
    ):
        super().__init__()
        assert pooling in ("mean", "token", "logit")
        self.pooling = pooling
        self.embed_dim = embed_dim
        self.patch_stride = patch_stride
        self.patch_size = patch_size
        self.n_mels = kwargs.get("n_mels", 64)
        init_bn = kwargs.get("init_bn", True)
        self.eval_avg = eval_avg
        self.time_patch_out = time_patch_out
        self.freq_patch_out = freq_patch_out
        self.pad_last = kwargs.get("pad_last", True)
        if init_bn:
            # Normalise each frequency bin: move freq into the channel slot for BatchNorm2d, then back.
            self.init_bn = nn.Sequential(
                Rearrange("b c f t -> b f c t"),
                torch.nn.BatchNorm2d(self.n_mels, momentum=0.01),
                Rearrange("b f c t -> b c f t"),
            )
        else:
            # Must always exist: forward() tests `self.init_bn is not None`.
            self.init_bn = None
        self.target_length = target_length
        self.patch_embed = AudioPatchEmbed(
            input_size=(self.n_mels, target_length),
            embed_dim=self.embed_dim,
            patch_size=self.patch_size,
            flatten=False,
            patch_stride=self.patch_stride,
        )
        self.spectransforms = nn.Sequential() if spectransforms is None else spectransforms
        self.wavtransforms = nn.Sequential() if wavtransforms is None else wavtransforms
        self.num_patches = self.patch_embed.num_patches
        # Creation order below determines the global RNG draws at init; keep it stable.
        if pooling == "token":
            self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
            self.token_pos_embed = nn.Parameter(torch.randn(1, embed_dim) * 0.02)
        # Separable learned positional embeddings: one per time patch and one per frequency patch.
        self.time_pos_embed = nn.Parameter(torch.randn(1, embed_dim, 1, self.patch_embed.grid_size[1]) * 0.02)
        self.freq_pos_embed = nn.Parameter(torch.randn(1, embed_dim, self.patch_embed.grid_size[0], 1) * 0.02)
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU
        # Stochastic depth decay rule: drop-path rate increases linearly across blocks.
        dpr = [rate.item() for rate in torch.linspace(0, drop_path_rate, depth)]
        self.pos_drop = nn.Dropout(p=drop_rate)
        block_cls = globals()[block_type]
        self.blocks = nn.Sequential(
            *[
                block_cls(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    init_values=init_values,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    act_layer=act_layer,
                    attention_type=attention_type,
                )
                for i in range(depth)
            ]
        )
        self.norm = norm_layer(embed_dim)
        if pooling == "token":
            nn.init.normal_(self.cls_token, std=1e-6)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Names of parameters that should be excluded from weight decay."""
        return {"time_pos_embed", "cls_token", "freq_pos_embed", "token_pos_embed"}

    def forward_features(self, x):
        """Encode one chunk: patchify, add positional embeddings, run the blocks and the final norm.

        Returns tokens of shape ``(b, f*t [+1 CLS], embed_dim)``.
        """
        x = self.patch_embed(x)
        t = x.shape[-1]
        # The time embedding is cropped to the actual number of time patches (short inputs).
        x = x + self.time_pos_embed[:, :, :, :t]
        x = x + self.freq_pos_embed
        x = torch.permute(torch.flatten(x, 2, 3), (0, 2, 1))  # b c f t -> b (f t) c
        if self.pooling == "token":
            cls_token = self.cls_token.expand(x.shape[0], -1, -1) + self.token_pos_embed
            x = torch.cat((cls_token, x), dim=1)
        x = self.pos_drop(x)
        x = self.blocks(x)
        return self.norm(x)

    def forward(self, x):
        """Encode ``x`` of shape ``(b, c, f, t)``; returns ``(b, n_tokens, embed_dim)``.

        Inputs longer than ``target_length`` frames are split into chunks. Every full chunk is kept.
        A trailing partial chunk is zero-padded when ``pad_last`` is set (tokens covering the padding
        are cut off at the end) and dropped otherwise.
        """
        x = self.init_bn(x) if self.init_bn is not None else x
        # Number of output tokens to keep when the last chunk was padded; 0 keeps everything.
        padding_start = 0
        if x.shape[-1] > self.target_length:
            chunks = list(x.split(self.target_length, -1))
            last_len = chunks[-1].shape[-1]
            if last_len < self.target_length:
                if self.pad_last:
                    # Match x's dtype and device so torch.stack below does not mix dtypes.
                    pad = x.new_zeros((*x.shape[:-1], self.target_length))
                    pad[..., :last_len] = chunks[-1]
                    chunks[-1] = pad
                    # NOTE(review): assumes one token per time patch (a single frequency patch row)
                    # and no CLS tokens; with pooling="token" the cut position is off by n_splits.
                    padding_start = x.shape[-1] // self.patch_embed.patch_stride[-1]
                else:
                    chunks = chunks[:-1]
            stacked = torch.stack(chunks, dim=0)  # spl b c f t
            n_splits = stacked.shape[0]
            x = torch.flatten(stacked, 0, 1)  # spl b c f t -> (spl b) c f t
        else:
            n_splits = 1
        x = self.forward_features(x)
        # (spl b) n d -> b (spl n) d. The chunk axis must be moved behind the batch axis
        # before merging, otherwise tokens of different batch items get interleaved.
        n_tokens, dim = x.shape[-2], x.shape[-1]
        x = x.reshape(n_splits, -1, n_tokens, dim).transpose(0, 1).reshape(-1, n_splits * n_tokens, dim)
        if padding_start:
            x = x[:, :padding_start, :]
        return x
# Shared class-level docstring, prepended by `add_start_docstrings`.
DASHENG_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning
    heads etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`DashengConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# Forward-argument docstring. `DashengModel.forward` calls `.format(...)` on this string,
# so any literal braces added here must be doubled.
DASHENG_INPUTS_DOCSTRING = r"""
    Args:
        input_values (`torch.FloatTensor` of shape `(batch_size, n_mels, sequence_length)`):
            The sequence of audio features extracted from the audio signal. Can be obtained from a raw audio waveform
            using `~transformers.DashengFeatureExtractor.__call__`.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Class indices used to compute the loss. They are one-hot encoded to the model's output dimension and
            passed, together with the logits, to the loss class named by `config.loss` (looked up in
            `torch.nn.modules.loss`).
"""
class DashengPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = DashengConfig
    base_model_prefix = "dasheng"
    main_input_name = "input_values"
    # NOTE(review): nothing in this file reads a `gradient_checkpointing` flag, so enabling gradient
    # checkpointing may have no effect or be rejected by transformers — confirm.
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialise ``nn.Linear`` and ``nn.LayerNorm`` modules; other modules are left untouched.

        ``nn.Linear`` gets a Xavier-uniform weight and a zero bias. ``nn.LayerNorm`` gets weight 1 and
        bias 0. Affine parameters that are absent (``None``, e.g. ``LayerNorm(elementwise_affine=False)``)
        are skipped instead of raising.
        """
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
            if module.weight is not None:
                nn.init.constant_(module.weight, 1.0)
@add_start_docstrings(
    "The Dasheng Model transformer with an optional linear layer on top of the pooled output.",
    DASHENG_START_DOCSTRING,
)
class DashengModel(DashengPreTrainedModel):
    def __init__(self, config: DashengConfig, outputdim: Optional[int] = None) -> None:
        r"""
        Initializes the model.

        Args:
            config (DashengConfig): Configuration class for the model. `config.encoder_kwargs` is passed to the
                encoder, `config.name` is stored as `self.name`, and `config.loss` names the loss used by `forward`.
            outputdim (int, optional): Dimension of the output layer. If None, no classifier head is built (an
                identity is used), so `logits` are the pooled encoder embeddings and `self.outputdim` is the
                encoder's `embed_dim`. Defaults to None.
        """
        super().__init__(config)
        self.config = config
        self.name = config.name
        self.encoder = AudioTransformerMAE_Encoder(**config.encoder_kwargs)
        embed_dim = config.encoder_kwargs["embed_dim"]
        # Classifier head: LayerNorm + Linear on the pooled embedding.
        if outputdim is not None:
            self.outputlayer = nn.Sequential(
                nn.LayerNorm(embed_dim),
                nn.Linear(embed_dim, outputdim),
            )
        else:
            self.outputlayer = nn.Identity()
            outputdim = embed_dim
        self.outputdim = outputdim
        # Initialize weights and apply final processing
        self.post_init()

    def forward_head(self, x: torch.Tensor) -> torch.Tensor:
        """Pool the encoder tokens and apply the output layer.

        ``"token"`` pooling takes the first token; ``"mean"`` and ``"logit"`` average over the token axis.
        ``"token"`` and ``"mean"`` apply a sigmoid to the result, while ``"logit"`` returns it unchanged.

        Raises:
            NotImplementedError: for any other pooling mode.
        """
        pooling = self.encoder.pooling
        if pooling == "token":
            pooled = x[:, 0]
        elif pooling in ("mean", "logit"):
            pooled = x.mean(1)
        else:
            raise NotImplementedError(f"Pooling {self.encoder.pooling} not implemented.")
        out = self.outputlayer(pooled)
        return out if pooling == "logit" else out.sigmoid()

    def freeze_encoder(self) -> None:
        """Disable gradients for all encoder parameters; the output layer stays trainable."""
        for param in self.encoder.parameters():
            param.requires_grad = False
        # NOTE(review): this flag is not read anywhere in this file.
        self._requires_grad = False

    @add_start_docstrings_to_model_forward(DASHENG_INPUTS_DOCSTRING.format("batch_size, n_mels, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_SEQ_CLASS_CHECKPOINT,
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="audio",
        model_cls="DashengModel",
    )
    def forward(self, input_values: torch.Tensor, labels: Optional[torch.Tensor] = None) -> SequenceClassifierOutput:
        """
        Runs a forward pass of the Dasheng model with audio features. The model returns logits and hidden states.
        If labels are provided, the model also returns the loss.

        `hidden_states` is the encoder output tensor of shape `(batch_size, num_tokens, embed_dim)`.
        """
        # (batch, n_mels, time) -> (batch, 1, n_mels, time): the encoder expects a channel axis.
        x = torch.unsqueeze(input_values, 1)
        last_hidden_states = self.encoder(x)
        logits = self.forward_head(last_hidden_states)
        loss = None if labels is None else self._compute_classification_loss(logits, labels)
        return SequenceClassifierOutput(logits=logits, loss=loss, hidden_states=last_hidden_states)

    def _compute_classification_loss(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Compute the loss named by `config.loss` (a class in `torch.nn.modules.loss`).

        Integer labels with one dimension fewer than `logits` are class indices and are one-hot encoded to
        `self.outputdim` classes. Other labels (e.g. multi-hot targets of the same shape as `logits`) are used
        directly. Labels are cast to the dtype of `logits` so half- or double-precision logits work.

        Raises:
            NotImplementedError: if `config.loss` does not name an attribute of `torch.nn.modules.loss`.
        """
        loss_name = getattr(self.config, "loss", None)
        loss_cls = getattr(nn.modules.loss, loss_name, None) if isinstance(loss_name, str) else None
        if loss_cls is None:
            raise NotImplementedError(f"Loss {loss_name} not implemented.")
        # Instantiated outside any try block so constructor errors are not misreported.
        loss_fct = loss_cls()
        if not torch.is_floating_point(labels) and labels.dim() == logits.dim() - 1:
            labels = nn.functional.one_hot(labels.long(), num_classes=self.outputdim)
        return loss_fct(logits, labels.to(dtype=logits.dtype))