import inspect
import math
from typing import Optional, Tuple

import numpy as np
import torch
from torch import Tensor, nn
from torch.nn.functional import fold, unfold

from . import activations, normalizations
from .normalizations import gLN
def has_arg(fn, name):
    """Checks if a callable accepts a given keyword argument.

    Args:
        fn (callable): Callable to inspect.
        name (str): Check if ``fn`` can be called with ``name`` as a keyword
            argument.

    Returns:
        bool: whether ``fn`` accepts a ``name`` keyword argument.
    """
    signature = inspect.signature(fn)
    parameter = signature.parameters.get(name)
    if parameter is None:
        return False
    return parameter.kind in (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
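# Illustrative sketch (not part of the original module): ``nn.Softmax`` accepts a
# ``dim`` keyword while ``nn.ReLU`` does not, so ``has_arg`` can decide whether to
# pass it. Call this helper manually to sanity-check the behaviour.
def _example_has_arg():
    assert has_arg(nn.Softmax, "dim")
    assert not has_arg(nn.ReLU, "dim")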
class SingleRNN(nn.Module):
    """Module for an RNN block.

    Inspired by https://github.com/yluo42/TAC/blob/master/utility/models.py
    Licensed under CC BY-NC-SA 3.0 US.

    Args:
        rnn_type (str): Select from ``'RNN'``, ``'LSTM'``, ``'GRU'``. Can
            also be passed in lowercase letters.
        input_size (int): Dimension of the input feature. The input should have
            shape [batch, seq_len, input_size].
        hidden_size (int): Dimension of the hidden state.
        n_layers (int, optional): Number of layers used in the RNN. Default is 1.
        dropout (float, optional): Dropout ratio. Default is 0.
        bidirectional (bool, optional): Whether the RNN layers are
            bidirectional. Default is ``False``.
    """
    def __init__(
        self,
        rnn_type,
        input_size,
        hidden_size,
        n_layers=1,
        dropout=0,
        bidirectional=False,
    ):
        super(SingleRNN, self).__init__()
        assert rnn_type.upper() in ["RNN", "LSTM", "GRU"]
        rnn_type = rnn_type.upper()
        self.rnn_type = rnn_type
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.dropout = dropout
        self.bidirectional = bidirectional
        self.rnn = getattr(nn, rnn_type)(
            input_size,
            hidden_size,
            num_layers=n_layers,
            dropout=dropout,
            batch_first=True,
            bidirectional=bool(bidirectional),
        )
    @property
    def output_size(self):
        # Exposed as a property because callers access ``self.RNN.output_size``
        # without parentheses.
        return self.hidden_size * (2 if self.bidirectional else 1)
    def forward(self, inp):
        """Input shape [batch, seq, feats]."""
        self.rnn.flatten_parameters()  # Enables faster multi-GPU training.
        output = inp
        rnn_output, _ = self.rnn(output)
        return rnn_output
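# Illustrative usage sketch (not part of the original module). The tensor sizes are
# arbitrary example values: a bidirectional LSTM doubles the feature dimension of
# its output.
def _example_single_rnn():
    rnn = SingleRNN("LSTM", input_size=64, hidden_size=128, bidirectional=True)
    x = torch.randn(2, 50, 64)  # [batch, seq_len, input_size]
    out = rnn(x)                # [batch, seq_len, 2 * hidden_size]
    assert out.shape == (2, 50, rnn.output_size)
    return out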
class LSTMBlockTF(nn.Module):
    """RNN block applied along the time axis, followed by a linear projection,
    normalization and a residual connection."""

    def __init__(
        self,
        in_chan,
        hid_size,
        norm_type="gLN",
        bidirectional=True,
        rnn_type="LSTM",
        num_layers=1,
        dropout=0,
    ):
        super(LSTMBlockTF, self).__init__()
        self.RNN = SingleRNN(
            rnn_type,
            in_chan,
            hid_size,
            num_layers,
            dropout=dropout,
            bidirectional=bidirectional,
        )
        self.linear = nn.Linear(self.RNN.output_size, in_chan)
        self.norm = normalizations.get(norm_type)(in_chan)
    def forward(self, x):
        B, F, T = x.size()
        output = self.RNN(x.transpose(1, 2))  # [B, T, N]
        output = self.linear(output)
        output = output.transpose(1, -1)  # [B, N, T]
        output = self.norm(output)
        return output + x
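# Illustrative usage sketch (not part of the original module). Assumes the package-local
# ``normalizations.get("gLN")`` used as the default above is available. The block is
# residual, so the output shape matches the input shape [batch, feats, time].
def _example_lstm_block_tf():
    block = LSTMBlockTF(in_chan=64, hid_size=128)
    x = torch.randn(2, 64, 100)
    out = block(x)
    assert out.shape == x.shape
    return out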
# =================== Transformer ======================
class Linear(nn.Module):
    """
    Wrapper class of torch.nn.Linear.
    Weights are initialized with Xavier initialization and biases are initialized to zeros.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super(Linear, self).__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        nn.init.xavier_uniform_(self.linear.weight)
        if bias:
            nn.init.zeros_(self.linear.bias)

    def forward(self, x):
        return self.linear(x)
class Swish(nn.Module):
    """
    Swish is a smooth, non-monotonic activation that consistently matches or outperforms ReLU
    on deep networks across a variety of challenging domains such as image classification
    and machine translation.
    """

    def __init__(self):
        super(Swish, self).__init__()

    def forward(self, inputs):
        return inputs * inputs.sigmoid()
class Transpose(nn.Module):
    """Wrapper class of torch.transpose() for Sequential module."""

    def __init__(self, shape: tuple):
        super(Transpose, self).__init__()
        self.shape = shape

    def forward(self, x: Tensor) -> Tensor:
        return x.transpose(*self.shape)
class GLU(nn.Module):
    """
    The Gated Linear Unit (GLU) gating mechanism was first introduced for natural language
    processing in the paper "Language Modeling with Gated Convolutional Networks".
    It splits the input in two along ``dim`` and gates one half with the sigmoid of the other.
    """

    def __init__(self, dim: int) -> None:
        super(GLU, self).__init__()
        self.dim = dim

    def forward(self, inputs: Tensor) -> Tensor:
        outputs, gate = inputs.chunk(2, dim=self.dim)
        return outputs * gate.sigmoid()
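# Illustrative sketch (not part of the original module): the GLU halves the chosen
# dimension, here the channel dimension of a [batch, channels, time] tensor.
def _example_glu():
    glu = GLU(dim=1)
    x = torch.randn(2, 8, 100)
    out = glu(x)  # [2, 4, 100]
    assert out.shape == (2, 4, 100)
    return out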
class FeedForwardModule(nn.Module):
    """Conformer feed-forward module: LayerNorm -> Linear (x expansion_factor) -> Swish ->
    Dropout -> Linear -> Dropout."""

    def __init__(
        self, encoder_dim: int = 512, expansion_factor: int = 4, dropout_p: float = 0.1,
    ) -> None:
        super(FeedForwardModule, self).__init__()
        self.sequential = nn.Sequential(
            nn.LayerNorm(encoder_dim),
            Linear(encoder_dim, encoder_dim * expansion_factor, bias=True),
            Swish(),
            nn.Dropout(p=dropout_p),
            Linear(encoder_dim * expansion_factor, encoder_dim, bias=True),
            nn.Dropout(p=dropout_p),
        )

    def forward(self, inputs):
        return self.sequential(inputs)
class PositionalEncoding(nn.Module):
    """
    Positional Encoding proposed in "Attention Is All You Need".
    Since the Transformer contains no recurrence and no convolution, the model needs some
    positional information in order to make use of the order of the sequence.
    "Attention Is All You Need" uses sine and cosine functions of different frequencies:
        PE_(pos, 2i)   = sin(pos / power(10000, 2i / d_model))
        PE_(pos, 2i+1) = cos(pos / power(10000, 2i / d_model))
    """

    def __init__(self, d_model: int = 512, max_len: int = 10000) -> None:
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model, requires_grad=False)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, length: int) -> Tensor:
        return self.pe[:, :length]
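# Illustrative sketch (not part of the original module): the module returns the first
# ``length`` positional encodings with a leading broadcast dimension.
def _example_positional_encoding():
    pos_enc = PositionalEncoding(d_model=512)
    pe = pos_enc(100)  # [1, length, d_model]
    assert pe.shape == (1, 100, 512)
    return pe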
class RelativeMultiHeadAttention(nn.Module):
    """
    Multi-head attention with relative positional encoding.
    This concept was proposed in "Transformer-XL: Attentive Language Models Beyond a
    Fixed-Length Context".

    Args:
        d_model (int): The dimension of the model.
        num_heads (int): The number of attention heads.
        dropout_p (float): Probability of dropout.

    Inputs: query, key, value, pos_embedding, mask
        - **query** (batch, time, dim): Tensor containing the query vectors.
        - **key** (batch, time, dim): Tensor containing the key vectors.
        - **value** (batch, time, dim): Tensor containing the value vectors.
        - **pos_embedding** (batch, time, dim): Positional embedding tensor.
        - **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked.

    Returns:
        - **outputs**: Tensor produced by the relative multi-head attention module.
    """
    def __init__(
        self, d_model: int = 512, num_heads: int = 16, dropout_p: float = 0.1,
    ):
        super(RelativeMultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model % num_heads should be zero."
        self.d_model = d_model
        self.d_head = int(d_model / num_heads)
        self.num_heads = num_heads
        self.sqrt_dim = math.sqrt(d_model)
        self.query_proj = Linear(d_model, d_model)
        self.key_proj = Linear(d_model, d_model)
        self.value_proj = Linear(d_model, d_model)
        self.pos_proj = Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(p=dropout_p)
        self.u_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
        self.v_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
        torch.nn.init.xavier_uniform_(self.u_bias)
        torch.nn.init.xavier_uniform_(self.v_bias)
        self.out_proj = Linear(d_model, d_model)
    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        pos_embedding: Tensor,
        mask: Optional[Tensor] = None,
    ) -> Tensor:
        batch_size = value.size(0)
        query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)
        key = (
            self.key_proj(key)
            .view(batch_size, -1, self.num_heads, self.d_head)
            .permute(0, 2, 1, 3)
        )
        value = (
            self.value_proj(value)
            .view(batch_size, -1, self.num_heads, self.d_head)
            .permute(0, 2, 1, 3)
        )
        pos_embedding = self.pos_proj(pos_embedding).view(
            batch_size, -1, self.num_heads, self.d_head
        )
        # Content-based and position-based attention scores, [batch, heads, time, time].
        content_score = torch.matmul(
            (query + self.u_bias).transpose(1, 2), key.transpose(2, 3)
        )
        pos_score = torch.matmul(
            (query + self.v_bias).transpose(1, 2), pos_embedding.permute(0, 2, 3, 1)
        )
        pos_score = self._relative_shift(pos_score)
        score = (content_score + pos_score) / self.sqrt_dim
        if mask is not None:
            mask = mask.unsqueeze(1)
            score.masked_fill_(mask, -1e9)
        attn = torch.nn.functional.softmax(score, -1)
        attn = self.dropout(attn)
        context = torch.matmul(attn, value).transpose(1, 2)
        context = context.contiguous().view(batch_size, -1, self.d_model)
        return self.out_proj(context)
    def _relative_shift(self, pos_score: Tensor) -> Tensor:
        # Standard Transformer-XL trick: pad, reshape and slice so that each query
        # position attends to keys at the correct relative offsets.
        batch_size, num_heads, seq_length1, seq_length2 = pos_score.size()
        zeros = pos_score.new_zeros(batch_size, num_heads, seq_length1, 1)
        padded_pos_score = torch.cat([zeros, pos_score], dim=-1)
        padded_pos_score = padded_pos_score.view(
            batch_size, num_heads, seq_length2 + 1, seq_length1
        )
        pos_score = padded_pos_score[:, :, 1:].view_as(pos_score)
        return pos_score
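# Illustrative usage sketch (not part of the original module). The sizes are arbitrary;
# ``pos_embedding`` is expected to have the same [batch, time, d_model] shape as the
# query/key/value tensors (as produced by ``PositionalEncoding`` above).
def _example_relative_mhsa():
    attn = RelativeMultiHeadAttention(d_model=128, num_heads=4)
    q = k = v = torch.randn(2, 50, 128)
    pos = torch.randn(2, 50, 128)
    out = attn(q, k, v, pos_embedding=pos)
    assert out.shape == (2, 50, 128)
    return out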
class MultiHeadedSelfAttentionModule(nn.Module):
    """
    Conformer employs multi-headed self-attention (MHSA) while integrating an important
    technique from Transformer-XL, the relative sinusoidal positional encoding scheme.
    The relative positional encoding allows the self-attention module to generalize better
    to different input lengths, and the resulting encoder is more robust to variance in the
    utterance length. Conformer uses pre-norm residual units with dropout, which helps to
    train and regularize deeper models.

    Args:
        d_model (int): The dimension of the model.
        num_heads (int): The number of attention heads.
        dropout_p (float): Probability of dropout.
        is_casual (bool): If True, a causal (upper-triangular) attention mask is applied.

    Inputs: inputs
        - **inputs** (batch, time, dim): Tensor containing the input vectors.

    Returns:
        - **outputs** (batch, time, dim): Tensor produced by the relative multi-headed self-attention module.
    """
    def __init__(
        self, d_model: int, num_heads: int, dropout_p: float = 0.1, is_casual=True
    ):
        super(MultiHeadedSelfAttentionModule, self).__init__()
        self.positional_encoding = PositionalEncoding(d_model)
        self.layer_norm = nn.LayerNorm(d_model)
        self.attention = RelativeMultiHeadAttention(d_model, num_heads, dropout_p)
        self.dropout = nn.Dropout(p=dropout_p)
        self.is_casual = is_casual
    def forward(self, inputs: Tensor):
        batch_size, seq_length, _ = inputs.size()
        pos_embedding = self.positional_encoding(seq_length)
        pos_embedding = pos_embedding.repeat(batch_size, 1, 1)
        mask = None
        if self.is_casual:
            # Upper-triangular mask so that each step only attends to past and current frames.
            mask = torch.triu(
                torch.ones((seq_length, seq_length), dtype=torch.uint8).to(
                    inputs.device
                ),
                diagonal=1,
            )
            mask = mask.unsqueeze(0).expand(batch_size, -1, -1).bool()  # [B, L, L]
        inputs = self.layer_norm(inputs)
        outputs = self.attention(
            inputs, inputs, inputs, pos_embedding=pos_embedding, mask=mask
        )
        return self.dropout(outputs)
class ResidualConnectionModule(nn.Module):
    """
    Residual Connection Module.
    outputs = (module(inputs) x module_factor + inputs x input_factor)
    """

    def __init__(
        self, module: nn.Module, module_factor: float = 1.0, input_factor: float = 1.0
    ):
        super(ResidualConnectionModule, self).__init__()
        self.module = module
        self.module_factor = module_factor
        self.input_factor = input_factor

    def forward(self, inputs):
        return (self.module(inputs) * self.module_factor) + (inputs * self.input_factor)
class DepthwiseConv1d(nn.Module):
    """
    When groups == in_channels and out_channels == K * in_channels, where K is a positive
    integer, this operation is termed in the literature as depthwise convolution.

    Args:
        in_channels (int): Number of channels in the input.
        out_channels (int): Number of channels produced by the convolution.
        kernel_size (int or tuple): Size of the convolving kernel.
        stride (int, optional): Stride of the convolution. Default: 1
        padding (int, optional): Unused; the actual padding is derived from ``kernel_size``
            and ``is_casual`` (full left context when causal, 'SAME' otherwise).
        bias (bool, optional): If True, adds a learnable bias to the output. Default: False
        is_casual (bool, optional): If True, the convolution is causal in time. Default: True

    Inputs: inputs
        - **inputs** (batch, in_channels, time): Tensor containing the input vectors.

    Returns: outputs
        - **outputs** (batch, out_channels, time): Tensor produced by the depthwise 1-D convolution.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        bias: bool = False,
        is_casual: bool = True,
    ) -> None:
        super(DepthwiseConv1d, self).__init__()
        assert (
            out_channels % in_channels == 0
        ), "out_channels should be a constant multiple of in_channels"
        if is_casual:
            padding = kernel_size - 1
        else:
            padding = (kernel_size - 1) // 2
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            groups=in_channels,
            stride=stride,
            padding=padding,
            bias=bias,
        )
        self.is_casual = is_casual
        self.kernel_size = kernel_size

    def forward(self, inputs: Tensor) -> Tensor:
        if self.is_casual:
            # Drop the trailing frames introduced by the symmetric padding to keep the
            # convolution causal and the time length unchanged.
            return self.conv(inputs)[:, :, : -(self.kernel_size - 1)]
        return self.conv(inputs)
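# Illustrative sketch (not part of the original module): with the default causal setting,
# the padding/trimming scheme above keeps the time length unchanged.
def _example_depthwise_conv1d():
    conv = DepthwiseConv1d(in_channels=64, out_channels=64, kernel_size=31)
    x = torch.randn(2, 64, 100)  # [batch, channels, time]
    out = conv(x)
    assert out.shape == x.shape
    return out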
class PointwiseConv1d(nn.Module):
    """
    A 1-D convolution with kernel size 1, termed in the literature as pointwise convolution.
    This operation is often used to match dimensions.

    Args:
        in_channels (int): Number of channels in the input.
        out_channels (int): Number of channels produced by the convolution.
        stride (int, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        bias (bool, optional): If True, adds a learnable bias to the output. Default: True

    Inputs: inputs
        - **inputs** (batch, in_channels, time): Tensor containing the input vectors.

    Returns: outputs
        - **outputs** (batch, out_channels, time): Tensor produced by the pointwise 1-D convolution.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        padding: int = 0,
        bias: bool = True,
    ) -> None:
        super(PointwiseConv1d, self).__init__()
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=stride,
            padding=padding,
            bias=bias,
        )

    def forward(self, inputs: Tensor) -> Tensor:
        return self.conv(inputs)
class ConformerConvModule(nn.Module):
    """
    The Conformer convolution module starts with a pointwise convolution and a gated linear
    unit (GLU). This is followed by a single 1-D depthwise convolution layer. Batchnorm is
    deployed just after the convolution to aid training of deep models.

    Args:
        in_channels (int): Number of channels in the input.
        kernel_size (int or tuple, optional): Size of the convolving kernel. Default: 31
        expansion_factor (int, optional): Channel expansion of the first pointwise convolution. Default: 2
        dropout_p (float, optional): Probability of dropout.
        is_casual (bool, optional): If True, the depthwise convolution is causal. Default: True

    Inputs: inputs
        - **inputs** (batch, time, dim): Tensor containing the input sequences.

    Returns: outputs
        - **outputs** (batch, time, dim): Tensor produced by the Conformer convolution module.
    """
    def __init__(
        self,
        in_channels: int,
        kernel_size: int = 31,
        expansion_factor: int = 2,
        dropout_p: float = 0.1,
        is_casual: bool = True,
    ) -> None:
        super(ConformerConvModule, self).__init__()
        assert (
            kernel_size - 1
        ) % 2 == 0, "kernel_size should be an odd number for 'SAME' padding"
        assert expansion_factor == 2, "Currently, only expansion_factor 2 is supported"
        self.sequential = nn.Sequential(
            nn.LayerNorm(in_channels),
            Transpose(shape=(1, 2)),
            PointwiseConv1d(
                in_channels,
                in_channels * expansion_factor,
                stride=1,
                padding=0,
                bias=True,
            ),
            GLU(dim=1),
            DepthwiseConv1d(
                in_channels, in_channels, kernel_size, stride=1, is_casual=is_casual
            ),
            nn.BatchNorm1d(in_channels),
            Swish(),
            PointwiseConv1d(in_channels, in_channels, stride=1, padding=0, bias=True),
            nn.Dropout(p=dropout_p),
        )

    def forward(self, inputs: Tensor) -> Tensor:
        return self.sequential(inputs).transpose(1, 2)
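# Illustrative usage sketch (not part of the original module): the module operates on
# [batch, time, dim] tensors and preserves their shape.
def _example_conformer_conv_module():
    conv_module = ConformerConvModule(in_channels=64, kernel_size=31)
    x = torch.randn(2, 100, 64)
    out = conv_module(x)
    assert out.shape == x.shape
    return out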
class TransformerLayer(nn.Module):
    """A single Conformer-style block: half-step feed-forward, multi-headed self-attention,
    convolution module and another half-step feed-forward, followed by LayerNorm."""

    def __init__(
        self, in_chan=128, n_head=8, n_att=1, dropout=0.1, max_len=500, is_casual=True
    ):
        super(TransformerLayer, self).__init__()
        self.in_chan = in_chan
        self.n_head = n_head
        self.dropout = dropout
        self.max_len = max_len
        self.n_att = n_att
        self.seq = nn.Sequential(
            ResidualConnectionModule(
                FeedForwardModule(in_chan, expansion_factor=4, dropout_p=dropout),
                module_factor=0.5,
            ),
            ResidualConnectionModule(
                MultiHeadedSelfAttentionModule(in_chan, n_head, dropout, is_casual)
            ),
            ResidualConnectionModule(
                ConformerConvModule(in_chan, 31, 2, dropout, is_casual=is_casual)
            ),
            ResidualConnectionModule(
                FeedForwardModule(in_chan, expansion_factor=4, dropout_p=dropout),
                module_factor=0.5,
            ),
            nn.LayerNorm(in_chan),
        )

    def forward(self, x):
        return self.seq(x)
class TransformerBlockTF(nn.Module):
    def __init__(
        self,
        in_chan,
        n_head=8,
        n_att=1,
        dropout=0.1,
        max_len=500,
        norm_type="cLN",
        is_casual=True,
    ):
        super(TransformerBlockTF, self).__init__()
        self.transformer = TransformerLayer(
            in_chan, n_head, n_att, dropout, max_len, is_casual
        )
        self.norm = normalizations.get(norm_type)(in_chan)

    def forward(self, x):
        B, F, T = x.size()
        output = self.transformer(x.permute(0, 2, 1).contiguous())  # [B, T, N]
        output = output.permute(0, 2, 1).contiguous()  # [B, N, T]
        output = self.norm(output)
        return output + x
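# Illustrative usage sketch (not part of the original module). Assumes the package-local
# ``normalizations.get("cLN")`` used as the default above is available. The block is
# residual over [batch, feats, time] inputs.
def _example_transformer_block_tf():
    block = TransformerBlockTF(in_chan=64, n_head=8)
    x = torch.randn(2, 64, 100)
    out = block(x)
    assert out.shape == x.shape
    return out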
# ====================================================
class DPRNNBlock(nn.Module):
    """Dual-Path RNN block: an intra-chunk (always bidirectional) RNN followed by an
    inter-chunk RNN, each with a linear projection, normalization and a residual connection."""

    def __init__(
        self,
        in_chan,
        hid_size,
        norm_type="gLN",
        bidirectional=True,
        rnn_type="LSTM",
        num_layers=1,
        dropout=0,
    ):
        super(DPRNNBlock, self).__init__()
        self.intra_RNN = SingleRNN(
            rnn_type,
            in_chan,
            hid_size,
            num_layers,
            dropout=dropout,
            bidirectional=True,
        )
        self.inter_RNN = SingleRNN(
            rnn_type,
            in_chan,
            hid_size,
            num_layers,
            dropout=dropout,
            bidirectional=bidirectional,
        )
        self.intra_linear = nn.Linear(self.intra_RNN.output_size, in_chan)
        self.intra_norm = normalizations.get(norm_type)(in_chan)
        self.inter_linear = nn.Linear(self.inter_RNN.output_size, in_chan)
        self.inter_norm = normalizations.get(norm_type)(in_chan)
    def forward(self, x):
        """Input shape : [batch, feats, chunk_size, num_chunks]"""
        B, N, K, L = x.size()
        output = x  # for skip connection
        # Intra-chunk processing
        x = x.transpose(1, -1).reshape(B * L, K, N)
        x = self.intra_RNN(x)
        x = self.intra_linear(x)
        x = x.reshape(B, L, K, N).transpose(1, -1)
        x = self.intra_norm(x)
        output = output + x
        # Inter-chunk processing
        x = output.transpose(1, 2).transpose(2, -1).reshape(B * K, L, N)
        x = self.inter_RNN(x)
        x = self.inter_linear(x)
        x = x.reshape(B, K, L, N).transpose(1, -1).transpose(2, -1).contiguous()
        x = self.inter_norm(x)
        return output + x
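# Illustrative usage sketch (not part of the original module). Assumes the package-local
# ``normalizations.get("gLN")`` default handles the 4-D [batch, feats, chunk_size, num_chunks]
# input, as it is used in DPRNN below. The residual connections preserve the input shape.
def _example_dprnn_block():
    block = DPRNNBlock(in_chan=64, hid_size=128)
    x = torch.randn(2, 64, 100, 10)
    out = block(x)
    assert out.shape == x.shape
    return out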
class DPRNN(nn.Module):
    """Dual-Path RNN masking network.

    Args:
        in_chan (int): Number of input filters.
        n_src (int): Number of sources (masks) to estimate.
        out_chan (int, optional): Number of output channels per source. Defaults to ``in_chan``.
        bn_chan (int): Number of channels after the bottleneck.
        hid_size (int): Number of hidden units in the RNNs.
        chunk_size (int): Chunk size used to segment the time axis.
        hop_size (int, optional): Hop size between chunks. Defaults to ``chunk_size // 2``.
        n_repeats (int): Number of stacked DPRNN blocks.
        norm_type (str): Type of normalization layer.
        mask_act (str): Activation applied to the estimated masks.
        bidirectional (bool): Whether the inter-chunk RNN is bidirectional.
        rnn_type (str): One of ``'RNN'``, ``'LSTM'``, ``'GRU'``.
        num_layers (int): Number of stacked RNN layers.
        dropout (float): Dropout ratio.
    """

    def __init__(
        self,
        in_chan,
        n_src,
        out_chan=None,
        bn_chan=128,
        hid_size=128,
        chunk_size=100,
        hop_size=None,
        n_repeats=6,
        norm_type="gLN",
        mask_act="relu",
        bidirectional=True,
        rnn_type="LSTM",
        num_layers=1,
        dropout=0,
    ):
        super(DPRNN, self).__init__()
        self.in_chan = in_chan
        out_chan = out_chan if out_chan is not None else in_chan
        self.out_chan = out_chan
        self.bn_chan = bn_chan
        self.hid_size = hid_size
        self.chunk_size = chunk_size
        hop_size = hop_size if hop_size is not None else chunk_size // 2
        self.hop_size = hop_size
        self.n_repeats = n_repeats
        self.n_src = n_src
        self.norm_type = norm_type
        self.mask_act = mask_act
        self.bidirectional = bidirectional
        self.rnn_type = rnn_type
        self.num_layers = num_layers
        self.dropout = dropout
        layer_norm = normalizations.get(norm_type)(in_chan)
        bottleneck_conv = nn.Conv1d(in_chan, bn_chan, 1)
        self.bottleneck = nn.Sequential(layer_norm, bottleneck_conv)
        # Succession of DPRNNBlocks.
        net = []
        for _ in range(self.n_repeats):
            net += [
                DPRNNBlock(
                    bn_chan,
                    hid_size,
                    norm_type=norm_type,
                    bidirectional=bidirectional,
                    rnn_type=rnn_type,
                    num_layers=num_layers,
                    dropout=dropout,
                )
            ]
        self.net = nn.Sequential(*net)
        # Masking in 3D space
        net_out_conv = nn.Conv2d(bn_chan, n_src * bn_chan, 1)
        self.first_out = nn.Sequential(nn.PReLU(), net_out_conv)
        # Gating and masking in 2D space (after fold)
        self.net_out = nn.Sequential(nn.Conv1d(bn_chan, bn_chan, 1), nn.Tanh())
        self.net_gate = nn.Sequential(nn.Conv1d(bn_chan, bn_chan, 1), nn.Sigmoid())
        self.mask_net = nn.Conv1d(bn_chan, out_chan, 1, bias=False)
        # Get activation function.
        mask_nl_class = activations.get(mask_act)
        # For softmax, feed the source dimension.
        if has_arg(mask_nl_class, "dim"):
            self.output_act = mask_nl_class(dim=1)
        else:
            self.output_act = mask_nl_class()
    def forward(self, mixture_w):
        r"""Forward.

        Args:
            mixture_w (:class:`torch.Tensor`): Tensor of shape $(batch, nfilters, nframes)$

        Returns:
            :class:`torch.Tensor`: estimated mask of shape $(batch, nsrc, nfilters, nframes)$
        """
        batch, n_filters, n_frames = mixture_w.size()
        output = self.bottleneck(mixture_w)  # [batch, bn_chan, n_frames]
        # Segment the time axis into overlapping chunks.
        output = unfold(
            output.unsqueeze(-1),
            kernel_size=(self.chunk_size, 1),
            padding=(self.chunk_size, 0),
            stride=(self.hop_size, 1),
        )
        n_chunks = output.shape[-1]
        output = output.reshape(batch, self.bn_chan, self.chunk_size, n_chunks)
        # Apply stacked DPRNN Blocks sequentially
        output = self.net(output)
        # Map to sources with 2D masks
        output = self.first_out(output)
        output = output.reshape(
            batch * self.n_src, self.bn_chan, self.chunk_size, n_chunks
        )
        # Overlap and add:
        # [batch, out_chan, chunk_size, n_chunks] -> [batch, out_chan, n_frames]
        to_unfold = self.bn_chan * self.chunk_size
        output = fold(
            output.reshape(batch * self.n_src, to_unfold, n_chunks),
            (n_frames, 1),
            kernel_size=(self.chunk_size, 1),
            padding=(self.chunk_size, 0),
            stride=(self.hop_size, 1),
        )
        # Gating is skipped in this variant (kept here for reference).
        output = output.reshape(batch * self.n_src, self.bn_chan, -1)
        # output = self.net_out(output) * self.net_gate(output)
        # Compute mask
        score = self.mask_net(output)
        est_mask = self.output_act(score)
        est_mask = est_mask.view(batch, self.n_src, self.out_chan, n_frames)
        return est_mask
    def get_config(self):
        config = {
            "in_chan": self.in_chan,
            "out_chan": self.out_chan,
            "bn_chan": self.bn_chan,
            "hid_size": self.hid_size,
            "chunk_size": self.chunk_size,
            "hop_size": self.hop_size,
            "n_repeats": self.n_repeats,
            "n_src": self.n_src,
            "norm_type": self.norm_type,
            "mask_act": self.mask_act,
            "bidirectional": self.bidirectional,
            "rnn_type": self.rnn_type,
            "num_layers": self.num_layers,
            "dropout": self.dropout,
        }
        return config
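# Illustrative usage sketch (not part of the original module). Assumes the package-local
# ``normalizations.get("gLN")`` and ``activations.get("relu")`` defaults are available.
# ``n_repeats`` is reduced to keep the example light.
def _example_dprnn():
    masker = DPRNN(in_chan=64, n_src=2, n_repeats=2)
    mixture_w = torch.randn(2, 64, 500)  # [batch, n_filters, n_frames]
    est_mask = masker(mixture_w)         # [batch, n_src, out_chan, n_frames]
    assert est_mask.shape == (2, 2, 64, 500)
    return est_mask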
class DPRNNLinear(nn.Module):
    """Variant of :class:`DPRNN` that uses a linear layer (instead of a Conv1d + Tanh branch)
    in the output gating, and keeps the gating active in :meth:`forward`."""

    def __init__(
        self,
        in_chan,
        n_src,
        out_chan=None,
        bn_chan=128,
        hid_size=128,
        chunk_size=100,
        hop_size=None,
        n_repeats=6,
        norm_type="gLN",
        mask_act="relu",
        bidirectional=True,
        rnn_type="LSTM",
        num_layers=1,
        dropout=0,
    ):
        super(DPRNNLinear, self).__init__()
        self.in_chan = in_chan
        out_chan = out_chan if out_chan is not None else in_chan
        self.out_chan = out_chan
        self.bn_chan = bn_chan
        self.hid_size = hid_size
        self.chunk_size = chunk_size
        hop_size = hop_size if hop_size is not None else chunk_size // 2
        self.hop_size = hop_size
        self.n_repeats = n_repeats
        self.n_src = n_src
        self.norm_type = norm_type
        self.mask_act = mask_act
        self.bidirectional = bidirectional
        self.rnn_type = rnn_type
        self.num_layers = num_layers
        self.dropout = dropout
        layer_norm = normalizations.get(norm_type)(in_chan)
        bottleneck_conv = nn.Conv1d(in_chan, bn_chan, 1)
        self.bottleneck = nn.Sequential(layer_norm, bottleneck_conv)
        # Succession of DPRNNBlocks.
        net = []
        for _ in range(self.n_repeats):
            net += [
                DPRNNBlock(
                    bn_chan,
                    hid_size,
                    norm_type=norm_type,
                    bidirectional=bidirectional,
                    rnn_type=rnn_type,
                    num_layers=num_layers,
                    dropout=dropout,
                )
            ]
        self.net = nn.Sequential(*net)
        # Masking in 3D space
        net_out_conv = nn.Conv2d(bn_chan, n_src * bn_chan, 1)
        self.first_out = nn.Sequential(nn.PReLU(), net_out_conv)
        # Gating and masking in 2D space (after fold)
        # self.net_out = nn.Sequential(nn.Conv1d(bn_chan, bn_chan, 1), nn.Tanh())
        # Keep bn_chan output features so the product with ``net_gate`` and the input of
        # ``mask_net`` stay shape-consistent.
        self.net_out = nn.Linear(bn_chan, bn_chan)
        self.net_gate = nn.Sequential(nn.Conv1d(bn_chan, bn_chan, 1), nn.Sigmoid())
        self.mask_net = nn.Conv1d(bn_chan, out_chan, 1, bias=False)
        # Get activation function.
        mask_nl_class = activations.get(mask_act)
        # For softmax, feed the source dimension.
        if has_arg(mask_nl_class, "dim"):
            self.output_act = mask_nl_class(dim=1)
        else:
            self.output_act = mask_nl_class()
    def forward(self, mixture_w):
        r"""Forward.

        Args:
            mixture_w (:class:`torch.Tensor`): Tensor of shape $(batch, nfilters, nframes)$

        Returns:
            :class:`torch.Tensor`: estimated mask of shape $(batch, nsrc, nfilters, nframes)$
        """
        batch, n_filters, n_frames = mixture_w.size()
        output = self.bottleneck(mixture_w)  # [batch, bn_chan, n_frames]
        output = unfold(
            output.unsqueeze(-1),
            kernel_size=(self.chunk_size, 1),
            padding=(self.chunk_size, 0),
            stride=(self.hop_size, 1),
        )
        n_chunks = output.shape[-1]
        output = output.reshape(batch, self.bn_chan, self.chunk_size, n_chunks)
        # Apply stacked DPRNN Blocks sequentially
        output = self.net(output)
        # Map to sources with 2D masks
        output = self.first_out(output)
        output = output.reshape(
            batch * self.n_src, self.bn_chan, self.chunk_size, n_chunks
        )
        # Overlap and add:
        # [batch, out_chan, chunk_size, n_chunks] -> [batch, out_chan, n_frames]
        to_unfold = self.bn_chan * self.chunk_size
        output = fold(
            output.reshape(batch * self.n_src, to_unfold, n_chunks),
            (n_frames, 1),
            kernel_size=(self.chunk_size, 1),
            padding=(self.chunk_size, 0),
            stride=(self.hop_size, 1),
        )
        # Apply gating: the linear branch operates on [batch, time, bn_chan], hence the
        # transposes around ``net_out``.
        output = output.reshape(batch * self.n_src, self.bn_chan, -1)
        output = self.net_out(output.transpose(1, 2)).transpose(1, 2) * self.net_gate(
            output
        )
        # Compute mask
        score = self.mask_net(output)
        est_mask = self.output_act(score)
        est_mask = est_mask.view(batch, self.n_src, self.out_chan, n_frames)
        return est_mask
    def get_config(self):
        config = {
            "in_chan": self.in_chan,
            "out_chan": self.out_chan,
            "bn_chan": self.bn_chan,
            "hid_size": self.hid_size,
            "chunk_size": self.chunk_size,
            "hop_size": self.hop_size,
            "n_repeats": self.n_repeats,
            "n_src": self.n_src,
            "norm_type": self.norm_type,
            "mask_act": self.mask_act,
            "bidirectional": self.bidirectional,
            "rnn_type": self.rnn_type,
            "num_layers": self.num_layers,
            "dropout": self.dropout,
        }
        return config
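# Illustrative usage sketch (not part of the original module), mirroring the DPRNN example
# and assuming the gating fix above (``net_out`` mapping bn_chan -> bn_chan and the
# ``transpose(1, 2)`` in ``forward``). Package-local ``gLN``/``relu`` defaults are assumed.
def _example_dprnn_linear():
    masker = DPRNNLinear(in_chan=64, n_src=2, n_repeats=2)
    mixture_w = torch.randn(2, 64, 500)
    est_mask = masker(mixture_w)
    assert est_mask.shape == (2, 2, 64, 500)
    return est_mask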