Spaces:
Running
Running
import warnings | |
from typing import Optional | |
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
def make_enc_dec(
    fb_name,
    n_filters,
    kernel_size,
    stride=None,
    sample_rate=8000.0,
    who_is_pinv=None,
    padding=0,
    output_padding=0,
    **kwargs,
):
    """Creates congruent encoder and decoder from the same filterbank family.

    Args:
        fb_name (str, className): Filterbank family from which to make encoder
            and decoder. To choose among [``'free'``, ``'analytic_free'``,
            ``'param_sinc'``, ``'stft'``]. Can also be a class defined in a
            submodule in this subpackage (e.g. :class:`~.FreeFB`).
            Resolved through :func:`get`.
        n_filters (int): Number of filters.
        kernel_size (int): Length of the filters.
        stride (int, optional): Stride of the convolution.
            If None (default), set to ``kernel_size // 2``.
        sample_rate (float): Sample rate of the expected audio.
            Defaults to 8000.0.
        who_is_pinv (str, optional): If `None`, no pseudo-inverse filters will
            be used and encoder and decoder each get their own filterbank
            instance. If ``'enc'``/``'encoder'`` or ``'dec'``/``'decoder'``,
            decides which of ``Encoder`` or ``Decoder`` will be the pseudo
            inverse of the other one; both then share a single filterbank.
            Any other value behaves like `None`.
        padding (int): Zero-padding added to both sides of the input.
            Passed to Encoder and Decoder (including the pseudo-inverse one).
        output_padding (int): Additional size added to one side of the output
            shape. Passed to Decoder (including when it is the pseudo inverse).
        **kwargs: Arguments which will be passed to the filterbank class
            additionally to the usual `n_filters`, `kernel_size` and `stride`.
            Depends on the filterbank family.

    Returns:
        :class:`.Encoder`, :class:`.Decoder`
    """
    fb_class = get(fb_name)

    def _new_filterbank():
        # Every call builds a fresh filterbank from the same configuration.
        return fb_class(
            n_filters, kernel_size, stride=stride, sample_rate=sample_rate, **kwargs
        )

    if who_is_pinv in ["dec", "decoder"]:
        fb = _new_filterbank()
        enc = Encoder(fb, padding=padding)
        # Decoder filterbank is pseudo inverse of encoder filterbank.
        # Built directly so that the padding arguments are not dropped.
        dec = Decoder(
            fb, is_pinv=True, padding=padding, output_padding=output_padding
        )
    elif who_is_pinv in ["enc", "encoder"]:
        fb = _new_filterbank()
        dec = Decoder(fb, padding=padding, output_padding=output_padding)
        # Encoder filterbank is pseudo inverse of decoder filterbank.
        enc = Encoder(fb, is_pinv=True, padding=padding)
    else:
        # Filters between encoder and decoder should not be shared:
        # instantiate the filterbank twice (encoder first, then decoder).
        enc = Encoder(_new_filterbank(), padding=padding)
        dec = Decoder(
            _new_filterbank(), padding=padding, output_padding=output_padding
        )
    return enc, dec
def register_filterbank(custom_fb):
    """Register a custom filterbank, gettable with `filterbanks.get`.

    The filterbank is stored in this module's global namespace under its
    ``__name__``.

    Args:
        custom_fb: Custom filterbank to register.

    Raises:
        ValueError: If the name of ``custom_fb``, or its lower-cased form,
            is already present in this module's globals.
    """
    namespace = globals()
    fb_name = custom_fb.__name__
    name_taken = fb_name in namespace or fb_name.lower() in namespace
    if name_taken:
        raise ValueError(
            f"Filterbank {custom_fb.__name__} already exists. Choose another name."
        )
    namespace[fb_name] = custom_fb
def get(identifier):
    """Returns a filterbank class from a string. Returns its input if it
    is callable (already a :class:`.Filterbank` for example).

    String identifiers are looked up in this module's global namespace.

    Args:
        identifier (str or Callable or None): the filterbank identifier.

    Returns:
        :class:`.Filterbank` or None

    Raises:
        ValueError: If ``identifier`` is a string that does not match any
            module-level name, or is neither None, a callable nor a string.
    """
    if identifier is None:
        return None
    if callable(identifier):
        return identifier
    if isinstance(identifier, str):
        found = globals().get(identifier)
        if found is not None:
            return found
    # Unknown string or unsupported identifier type.
    raise ValueError("Could not interpret filterbank identifier: " + str(identifier))
class Filterbank(nn.Module):
    """Base Filterbank class.

    Each subclass has to implement a ``filters`` method. The ``pre_*`` and
    ``post_*`` hooks are identity functions here and can be overridden.

    Args:
        n_filters (int): Number of filters.
        kernel_size (int): Length of the filters.
        stride (int, optional): Stride of the conv or transposed conv. (Hop size).
            If None (default), set to ``kernel_size // 2``.
        sample_rate (float): Sample rate of the expected audio.
            Defaults to 8000.

    Attributes:
        n_feats_out (int): Number of output filters.
    """

    def __init__(self, n_filters, kernel_size, stride=None, sample_rate=8000.0):
        super().__init__()
        self.n_filters = n_filters
        self.kernel_size = kernel_size
        # Any falsy stride (None, 0) falls back to half the kernel size.
        self.stride = stride or self.kernel_size // 2
        # Unless a subclass overrides it in its own init, the number of
        # output features equals the number of requested filters.
        self.n_feats_out = n_filters
        self.sample_rate = sample_rate

    def filters(self):
        """Abstract method for filters."""
        raise NotImplementedError

    def pre_analysis(self, wav: torch.Tensor):
        """Hook applied before the encoder convolution (identity here)."""
        return wav

    def post_analysis(self, spec: torch.Tensor):
        """Hook applied to the encoder convolution output (identity here)."""
        return spec

    def pre_synthesis(self, spec: torch.Tensor):
        """Hook applied before the decoder transposed convolution (identity here)."""
        return spec

    def post_synthesis(self, wav: torch.Tensor):
        """Hook applied after the decoder transposed convolution (identity here)."""
        return wav

    def get_config(self):
        """Returns dictionary of arguments to re-instantiate the class.

        Needs to be overridden if the filterbank takes additional arguments
        than ``n_filters`` ``kernel_size`` ``stride`` and ``sample_rate``.
        """
        return dict(
            fb_name=self.__class__.__name__,
            n_filters=self.n_filters,
            kernel_size=self.kernel_size,
            stride=self.stride,
            sample_rate=self.sample_rate,
        )

    def forward(self, waveform):
        """Always raises: a bare filterbank cannot be called directly."""
        raise NotImplementedError(
            "Filterbanks must be wrapped with an Encoder or a Decoder."
        )
class _EncDec(nn.Module):
    """Base private class for Encoder and Decoder.

    Common parameters and methods.

    Args:
        filterbank (:class:`Filterbank`): Filterbank instance. The filterbank
            to use as an encoder or a decoder.
        is_pinv (bool): Whether to be the pseudo inverse of filterbank.

    Attributes:
        filterbank (:class:`Filterbank`)
        stride (int)
        is_pinv (bool)
        sample_rate: ``filterbank.sample_rate`` if the filterbank has such
            an attribute, else None.
    """

    def __init__(self, filterbank, is_pinv=False):
        super().__init__()
        self.filterbank = filterbank
        self.sample_rate = getattr(filterbank, "sample_rate", None)
        self.stride = self.filterbank.stride
        self.is_pinv = is_pinv

    def filters(self):
        """Returns the filters of the wrapped filterbank."""
        return self.filterbank.filters()

    def compute_filter_pinv(self, filters):
        """Computes pseudo inverse filterbank of given filters.

        Args:
            filters (:class:`torch.Tensor`): Filters to invert. Dimension 1
                is squeezed before inversion (assumes the
                ``(n_filters, 1, kernel_size)`` layout used by ``FreeFB``
                in this file -- TODO confirm for other filterbanks).

        Returns:
            :class:`torch.Tensor`: Pseudo-inverse filters with the same shape
            as ``filters``, scaled by ``stride / kernel_size``.
        """
        scale = self.filterbank.stride / self.filterbank.kernel_size
        shape = filters.shape
        # Only drop the singleton channel axis: a bare ``squeeze()`` would also
        # remove the filter (or kernel) axis when it has size 1 and hand a 1-D
        # tensor to ``torch.pinverse``.
        ifilt = torch.pinverse(filters.squeeze(1)).transpose(-1, -2).view(shape)
        # Compensate for the overlap-add.
        return ifilt * scale

    def get_filters(self):
        """Returns filters or pinv filters depending on `is_pinv` attribute"""
        if self.is_pinv:
            return self.compute_filter_pinv(self.filters())
        return self.filters()

    def get_config(self):
        """Returns dictionary of arguments to re-instantiate the class.

        The filterbank config is extended with ``is_pinv`` (which takes
        precedence over a same-named filterbank entry).
        """
        return {**self.filterbank.get_config(), "is_pinv": self.is_pinv}
class Encoder(_EncDec):
    r"""Encoder class.

    Add encoding methods to Filterbank classes.
    Not intended to be subclassed.

    Args:
        filterbank (:class:`Filterbank`): The filterbank to use
            as an encoder.
        is_pinv (bool): Whether to be the pseudo inverse of filterbank.
        as_conv1d (bool): Whether to behave like nn.Conv1d.
            If True (default), forwarding input with shape :math:`(batch, 1, time)`
            will output a tensor of shape :math:`(batch, freq, conv\_time)`.
            If False, will output a tensor of shape :math:`(batch, 1, freq, conv\_time)`.
        padding (int): Zero-padding added to both sides of the input.
    """

    def __init__(self, filterbank, is_pinv=False, as_conv1d=True, padding=0):
        super().__init__(filterbank, is_pinv=is_pinv)
        self.as_conv1d = as_conv1d
        self.n_feats_out = self.filterbank.n_feats_out
        self.kernel_size = self.filterbank.kernel_size
        self.padding = padding

    @classmethod
    def pinv_of(cls, filterbank, **kwargs):
        """Returns an :class:`~.Encoder`, pseudo inverse of a
        :class:`~.Filterbank` or :class:`~.Decoder`.

        Args:
            filterbank (:class:`Filterbank` or :class:`Decoder`): Object to
                take the pseudo inverse of. For a decoder, its underlying
                filterbank is reused.
            **kwargs: Extra keyword arguments forwarded to the constructor
                (e.g. ``as_conv1d``, ``padding``).

        Raises:
            TypeError: If ``filterbank`` is neither a :class:`Filterbank`
                nor a :class:`Decoder`.
        """
        if isinstance(filterbank, Filterbank):
            return cls(filterbank, is_pinv=True, **kwargs)
        if isinstance(filterbank, Decoder):
            return cls(filterbank.filterbank, is_pinv=True, **kwargs)
        # Fail loudly instead of silently returning None.
        raise TypeError(
            "Expected a Filterbank or a Decoder, got "
            f"{type(filterbank).__name__}."
        )

    def forward(self, waveform):
        """Convolve input waveform with the filters from a filterbank.

        Args:
            waveform (:class:`torch.Tensor`): any tensor with samples along the
                last dimension. The waveform representation with any
                batch/channel etc.. dimension.

        Returns:
            :class:`torch.Tensor`: The corresponding TF domain signal.

        Shapes
            >>> (time, ) -> (freq, conv_time)
            >>> (batch, time) -> (batch, freq, conv_time)  # Avoid
            >>> if as_conv1d:
            >>>     (batch, 1, time) -> (batch, freq, conv_time)
            >>>     (batch, chan, time) -> (batch, chan, freq, conv_time)
            >>> else:
            >>>     (batch, chan, time) -> (batch, chan, freq, conv_time)
            >>> (batch, any, dim, time) -> (batch, any, dim, freq, conv_time)
        """
        # Filters are fetched on every call (pinv or plain, per `is_pinv`).
        filters = self.get_filters()
        waveform = self.filterbank.pre_analysis(waveform)
        spec = multishape_conv1d(
            waveform,
            filters=filters,
            stride=self.stride,
            padding=self.padding,
            as_conv1d=self.as_conv1d,
        )
        return self.filterbank.post_analysis(spec)
def multishape_conv1d(
    waveform: torch.Tensor,
    filters: torch.Tensor,
    stride: int,
    padding: int = 0,
    as_conv1d: bool = True,
) -> torch.Tensor:
    """Convolve ``waveform`` with ``filters``, whatever its number of dimensions.

    Args:
        waveform (:class:`torch.Tensor`): Input with samples along the last
            dimension.
        filters (:class:`torch.Tensor`): Convolution weights, passed as-is
            to ``F.conv1d``.
        stride (int): Stride of the convolution.
        padding (int): Zero-padding passed to ``F.conv1d``.
        as_conv1d (bool): If True, a 3D input with a single channel is
            convolved like ``nn.Conv1d`` (no extra channel axis in the output).

    Returns:
        :class:`torch.Tensor`: The convolved signal.

    Shapes
        >>> (time, ) -> (freq, conv_time)
        >>> (batch, time) -> (batch, freq, conv_time)  # Emits a warning
        >>> (batch, 1, time) -> (batch, freq, conv_time)  # if as_conv1d
        >>> (batch, chan, time) -> (batch, chan, freq, conv_time)
        >>> (*, time) -> (*, freq, conv_time)  # more than 3 dims
    """
    if waveform.ndim == 1:
        # Assumes 1D input with shape (time,)
        # Output will be (freq, conv_time). Only the batch axis added below is
        # removed, so a single filter or a single output frame keeps its axis
        # (a bare ``squeeze()`` would collapse it).
        return F.conv1d(
            waveform[None, None], filters, stride=stride, padding=padding
        ).squeeze(0)
    elif waveform.ndim == 2:
        # Assume 2D input with shape (batch or channels, time)
        # Output will be (batch or channels, freq, conv_time)
        warnings.warn(
            "Input tensor was 2D. Applying the corresponding "
            "Decoder to the current output will result in a 3D "
            "tensor. This behaviours was introduced to match "
            "Conv1D and ConvTranspose1D, please use 3D inputs "
            "to avoid it. For example, this can be done with "
            "input_tensor.unsqueeze(1)."
        )
        return F.conv1d(waveform.unsqueeze(1), filters, stride=stride, padding=padding)
    elif waveform.ndim == 3:
        channels = waveform.shape[1]
        if channels == 1 and as_conv1d:
            # That's the common single channel case (batch, 1, time)
            # Output will be (batch, freq, stft_time), behaves as Conv1D
            return F.conv1d(waveform, filters, stride=stride, padding=padding)
        else:
            # Return batched convolution, input is (batch, 3, time), output will be
            # (b, 3, f, conv_t). Useful for multichannel transforms. If as_conv1d is
            # false, (batch, 1, time) will output (batch, 1, freq, conv_time), useful for
            # consistency.
            return batch_packed_1d_conv(
                waveform, filters, stride=stride, padding=padding
            )
    else:  # waveform.ndim > 3
        # This is to compute "multi"multichannel convolution.
        # Input can be (*, time), output will be (*, freq, conv_time)
        return batch_packed_1d_conv(waveform, filters, stride=stride, padding=padding)
def batch_packed_1d_conv(
    inp: torch.Tensor, filters: torch.Tensor, stride: int = 1, padding: int = 0
):
    """Multichannel / multi-source 1D convolution.

    All leading dimensions of ``inp`` are packed into the batch axis, each
    resulting row is convolved as a single-channel signal, and the leading
    dimensions are then restored.

    Args:
        inp (:class:`torch.Tensor`): Input of shape (*, time). Does not need
            to be contiguous.
        filters (:class:`torch.Tensor`): Convolution weights passed to
            ``F.conv1d`` (applied to single-channel inputs).
        stride (int): Stride of the convolution.
        padding (int): Zero-padding passed to ``F.conv1d``.

    Returns:
        :class:`torch.Tensor`: Output of shape (*, freq, conv_time), e.g.
        (batch, channels, freq, conv_time) for a 3D input.
    """
    # ``reshape`` rather than ``view``: ``view`` raises on non-contiguous
    # inputs (e.g. after a transpose), ``reshape`` copies only when needed.
    packed = inp.reshape(-1, 1, inp.shape[-1])
    batched_conv = F.conv1d(packed, filters, stride=stride, padding=padding)
    output_shape = inp.shape[:-1] + batched_conv.shape[-2:]
    return batched_conv.reshape(output_shape)
class Decoder(_EncDec):
    """Decoder class.

    Add decoding methods to Filterbank classes.
    Not intended to be subclassed.

    Args:
        filterbank (:class:`Filterbank`): The filterbank to use as a decoder.
        is_pinv (bool): Whether to be the pseudo inverse of filterbank.
        padding (int): Zero-padding added to both sides of the input.
        output_padding (int): Additional size added to one side of the
            output shape.

    .. note::
        ``padding`` and ``output_padding`` arguments are directly passed to
        ``F.conv_transpose1d``.
    """

    def __init__(self, filterbank, is_pinv=False, padding=0, output_padding=0):
        super().__init__(filterbank, is_pinv=is_pinv)
        self.padding = padding
        self.output_padding = output_padding

    @classmethod
    def pinv_of(cls, filterbank, **kwargs):
        """Returns a Decoder, pseudo inverse of a filterbank or Encoder.

        Args:
            filterbank (:class:`Filterbank` or :class:`Encoder`): Object to
                take the pseudo inverse of. For an encoder, its underlying
                filterbank is reused.
            **kwargs: Extra keyword arguments forwarded to the constructor
                (e.g. ``padding``, ``output_padding``).

        Raises:
            TypeError: If ``filterbank`` is neither a :class:`Filterbank`
                nor an :class:`Encoder`.
        """
        if isinstance(filterbank, Filterbank):
            return cls(filterbank, is_pinv=True, **kwargs)
        if isinstance(filterbank, Encoder):
            return cls(filterbank.filterbank, is_pinv=True, **kwargs)
        # Fail loudly instead of silently returning None.
        raise TypeError(
            "Expected a Filterbank or an Encoder, got "
            f"{type(filterbank).__name__}."
        )

    def forward(self, spec, length: Optional[int] = None) -> torch.Tensor:
        """Applies transposed convolution to a TF representation.

        This is equivalent to overlap-add.

        Args:
            spec (:class:`torch.Tensor`): 3D or 4D Tensor. The TF
                representation. (Output of :func:`Encoder.forward`).
            length: desired output length. If given, the output is truncated
                to at most ``length`` samples along the last dimension; it
                is never padded.

        Returns:
            :class:`torch.Tensor`: The corresponding time domain signal.
        """
        # Filters are fetched on every call (pinv or plain, per `is_pinv`).
        filters = self.get_filters()
        spec = self.filterbank.pre_synthesis(spec)
        wav = multishape_conv_transpose1d(
            spec,
            filters,
            stride=self.stride,
            padding=self.padding,
            output_padding=self.output_padding,
        )
        wav = self.filterbank.post_synthesis(wav)
        if length is not None:
            # Cannot return more samples than were synthesised.
            length = min(length, wav.shape[-1])
            return wav[..., :length]
        return wav
def multishape_conv_transpose1d(
    spec: torch.Tensor,
    filters: torch.Tensor,
    stride: int = 1,
    padding: int = 0,
    output_padding: int = 0,
) -> torch.Tensor:
    """Apply a transposed 1D convolution to ``spec``, whatever its number of dimensions.

    Args:
        spec (:class:`torch.Tensor`): Input of shape (freq, conv_time),
            (batch, freq, conv_time) or (*, freq, conv_time).
        filters (:class:`torch.Tensor`): Weights passed as-is to
            ``F.conv_transpose1d``.
        stride (int): Stride of the transposed convolution.
        padding (int): Passed to ``F.conv_transpose1d``.
        output_padding (int): Passed to ``F.conv_transpose1d``.

    Returns:
        :class:`torch.Tensor`: The synthesised signal.

    Shapes
        >>> (freq, conv_time) -> (time, )
        >>> (batch, freq, conv_time) -> (batch, 1, time)
        >>> (*, freq, conv_time) -> (*, time)  # more than 3 dims
    """
    if spec.ndim == 2:
        # Input is (freq, conv_time), output is (time)
        out = F.conv_transpose1d(
            spec.unsqueeze(0),
            filters,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
        )
        # Drop only the leading batch and channel axes so that a one-sample
        # output stays 1D (a bare ``squeeze()`` would return a 0-dim tensor).
        return out.squeeze(0).squeeze(0)
    if spec.ndim == 3:
        # Input is (batch, freq, conv_time), output is (batch, 1, time)
        return F.conv_transpose1d(
            spec,
            filters,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
        )
    else:
        # Multiply all the left dimensions together and group them in the
        # batch. Make the convolution and restore.
        view_as = (-1,) + spec.shape[-2:]
        out = F.conv_transpose1d(
            spec.reshape(view_as),
            filters,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
        )
        return out.view(spec.shape[:-2] + (-1,))
class FreeFB(Filterbank):
    """Free filterbank without any constraints. Equivalent to
    :class:`nn.Conv1d`.

    The filters are a single learnable parameter of shape
    ``(n_filters, 1, kernel_size)`` initialised with Xavier normal.

    Args:
        n_filters (int): Number of filters.
        kernel_size (int): Length of the filters.
        stride (int, optional): Stride of the convolution.
            If None (default), set to ``kernel_size // 2``.
        sample_rate (float): Sample rate of the expected audio.
            Defaults to 8000.
        **kwargs: Accepted for signature compatibility and ignored.

    Attributes:
        n_feats_out (int): Number of output filters.

    References
        [1] : "Filterbank design for end-to-end speech separation". ICASSP 2020.
        Manuel Pariente, Samuele Cornell, Antoine Deleforge, Emmanuel Vincent.
    """

    def __init__(
        self, n_filters, kernel_size, stride=None, sample_rate=8000.0, **kwargs
    ):
        super().__init__(n_filters, kernel_size, stride=stride, sample_rate=sample_rate)
        initial_weights = torch.ones(n_filters, 1, kernel_size)
        self._filters = nn.Parameter(initial_weights)
        # ``_filters`` is the only parameter registered at this point.
        nn.init.xavier_normal_(self._filters)

    def filters(self):
        """Returns the learnable filters."""
        return self._filters


# Lower-case alias, so that the string identifier 'free' resolves to FreeFB.
free = FreeFB