import warnings
from typing import Optional
import torch
from torch import nn
from torch.nn import functional as F
def make_enc_dec(
fb_name,
n_filters,
kernel_size,
stride=None,
sample_rate=8000.0,
who_is_pinv=None,
padding=0,
output_padding=0,
**kwargs,
):
"""Creates congruent encoder and decoder from the same filterbank family.
Args:
fb_name (str, className): Filterbank family from which to make encoder
and decoder. To choose among [``'free'``, ``'analytic_free'``,
``'param_sinc'``, ``'stft'``]. Can also be a class defined in a
submodule in this subpackage (e.g. :class:`~.FreeFB`).
n_filters (int): Number of filters.
kernel_size (int): Length of the filters.
stride (int, optional): Stride of the convolution.
If None (default), set to ``kernel_size // 2``.
sample_rate (float): Sample rate of the expected audio.
Defaults to 8000.0.
who_is_pinv (str, optional): If `None`, no pseudo-inverse filters will
be used. If string (among [``'encoder'``, ``'decoder'``]), decides
which of ``Encoder`` or ``Decoder`` will be the pseudo inverse of
the other one.
padding (int): Zero-padding added to both sides of the input.
Passed to Encoder and Decoder.
output_padding (int): Additional size added to one side of the output shape.
Passed to Decoder.
**kwargs: Arguments which will be passed to the filterbank class
additionally to the usual `n_filters`, `kernel_size` and `stride`.
Depends on the filterbank family.
Returns:
:class:`.Encoder`, :class:`.Decoder`
"""
fb_class = get(fb_name)
if who_is_pinv in ["dec", "decoder"]:
fb = fb_class(
n_filters, kernel_size, stride=stride, sample_rate=sample_rate, **kwargs
)
enc = Encoder(fb, padding=padding)
# Decoder filterbank is pseudo inverse of encoder filterbank.
dec = Decoder.pinv_of(fb)
elif who_is_pinv in ["enc", "encoder"]:
fb = fb_class(
n_filters, kernel_size, stride=stride, sample_rate=sample_rate, **kwargs
)
dec = Decoder(fb, padding=padding, output_padding=output_padding)
# Encoder filterbank is pseudo inverse of decoder filterbank.
enc = Encoder.pinv_of(fb)
else:
fb = fb_class(
n_filters, kernel_size, stride=stride, sample_rate=sample_rate, **kwargs
)
enc = Encoder(fb, padding=padding)
# Filters between encoder and decoder should not be shared.
fb = fb_class(
n_filters, kernel_size, stride=stride, sample_rate=sample_rate, **kwargs
)
dec = Decoder(fb, padding=padding, output_padding=output_padding)
return enc, dec
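# Illustrative usage sketch: build a matched encoder/decoder pair from the
# ``free`` filterbank family and run a waveform through analysis and synthesis.
# Shapes assume a 1-second mono batch at 8 kHz.
#
#   enc, dec = make_enc_dec("free", n_filters=512, kernel_size=16, stride=8)
#   wav = torch.randn(4, 1, 8000)             # (batch, 1, time)
#   tf_rep = enc(wav)                          # (batch, 512, conv_time)
#   rec = dec(tf_rep, length=wav.shape[-1])    # trimmed back to (batch, 1, 8000)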
def register_filterbank(custom_fb):
"""Register a custom filterbank, gettable with `filterbanks.get`.
Args:
custom_fb: Custom filterbank to register.
"""
if (
custom_fb.__name__ in globals().keys()
or custom_fb.__name__.lower() in globals().keys()
):
raise ValueError(
f"Filterbank {custom_fb.__name__} already exists. Choose another name."
)
globals().update({custom_fb.__name__: custom_fb})
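# Illustrative sketch: registering a user-defined filterbank so that it becomes
# reachable by name through ``get`` and ``make_enc_dec``. ``MyFB`` is a
# hypothetical :class:`Filterbank` subclass defined elsewhere.
#
#   register_filterbank(MyFB)
#   enc, dec = make_enc_dec("MyFB", n_filters=256, kernel_size=32)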
def get(identifier):
"""Returns a filterbank class from a string. Returns its input if it
is callable (already a :class:`.Filterbank` for example).
Args:
identifier (str or Callable or None): the filterbank identifier.
Returns:
:class:`.Filterbank` or None
"""
if identifier is None:
return None
elif callable(identifier):
return identifier
elif isinstance(identifier, str):
cls = globals().get(identifier)
if cls is None:
raise ValueError(
"Could not interpret filterbank identifier: " + str(identifier)
)
return cls
else:
raise ValueError(
"Could not interpret filterbank identifier: " + str(identifier)
)
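# Illustrative sketch: ``get`` resolves registered names to classes and passes
# callables through unchanged.
#
#   assert get("free") is FreeFB        # registered alias, see bottom of file
#   assert get(FreeFB) is FreeFB        # callables are returned as-is
#   get("not_a_filterbank")             # raises ValueError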
class Filterbank(nn.Module):
"""Base Filterbank class.
Each subclass has to implement a ``filters`` method.
Args:
n_filters (int): Number of filters.
kernel_size (int): Length of the filters.
stride (int, optional): Stride of the conv or transposed conv. (Hop size).
If None (default), set to ``kernel_size // 2``.
sample_rate (float): Sample rate of the expected audio.
Defaults to 8000.
Attributes:
n_feats_out (int): Number of output filters.
"""
def __init__(self, n_filters, kernel_size, stride=None, sample_rate=8000.0):
super(Filterbank, self).__init__()
self.n_filters = n_filters
self.kernel_size = kernel_size
self.stride = stride if stride else self.kernel_size // 2
# If not specified otherwise in the filterbank's init, output
# number of features is equal to number of required filters.
self.n_feats_out = n_filters
self.sample_rate = sample_rate
def filters(self):
"""Abstract method for filters."""
raise NotImplementedError
def pre_analysis(self, wav: torch.Tensor):
"""Apply transform before encoder convolution."""
return wav
def post_analysis(self, spec: torch.Tensor):
"""Apply transform to encoder convolution."""
return spec
def pre_synthesis(self, spec: torch.Tensor):
"""Apply transform before decoder transposed convolution."""
return spec
def post_synthesis(self, wav: torch.Tensor):
"""Apply transform after decoder transposed convolution."""
return wav
def get_config(self):
"""Returns dictionary of arguments to re-instantiate the class.
Needs to be subclassed if the filterbank takes arguments in addition
to ``n_filters``, ``kernel_size``, ``stride`` and ``sample_rate``.
"""
config = {
"fb_name": self.__class__.__name__,
"n_filters": self.n_filters,
"kernel_size": self.kernel_size,
"stride": self.stride,
"sample_rate": self.sample_rate,
}
return config
def forward(self, waveform):
raise NotImplementedError(
"Filterbanks must be wrapped with an Encoder or a Decoder."
)
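# Illustrative sketch of the subclassing contract: a concrete filterbank only
# needs to return its filters as a (n_filters, 1, kernel_size) tensor.
# ``RandomFixedFB`` is a hypothetical example with a frozen random basis.
#
#   class RandomFixedFB(Filterbank):
#       def __init__(self, n_filters, kernel_size, stride=None, sample_rate=8000.0):
#           super().__init__(n_filters, kernel_size, stride=stride, sample_rate=sample_rate)
#           self.register_buffer("basis", torch.randn(n_filters, 1, kernel_size))
#
#       def filters(self):
#           return self.basis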
class _EncDec(nn.Module):
"""Base private class for Encoder and Decoder.
Common parameters and methods.
Args:
filterbank (:class:`Filterbank`): Filterbank instance. The filterbank
to use as an encoder or a decoder.
is_pinv (bool): Whether to be the pseudo inverse of filterbank.
Attributes:
filterbank (:class:`Filterbank`)
stride (int)
is_pinv (bool)
"""
def __init__(self, filterbank, is_pinv=False):
super(_EncDec, self).__init__()
self.filterbank = filterbank
self.sample_rate = getattr(filterbank, "sample_rate", None)
self.stride = self.filterbank.stride
self.is_pinv = is_pinv
def filters(self):
return self.filterbank.filters()
def compute_filter_pinv(self, filters):
"""Computes pseudo inverse filterbank of given filters."""
scale = self.filterbank.stride / self.filterbank.kernel_size
shape = filters.shape
ifilt = torch.pinverse(filters.squeeze()).transpose(-1, -2).view(shape)
# Compensate for the overlap-add.
return ifilt * scale
def get_filters(self):
"""Returns filters or pinv filters depending on `is_pinv` attribute"""
if self.is_pinv:
return self.compute_filter_pinv(self.filters())
else:
return self.filters()
def get_config(self):
"""Returns dictionary of arguments to re-instantiate the class."""
config = {"is_pinv": self.is_pinv}
base_config = self.filterbank.get_config()
return dict(list(base_config.items()) + list(config.items()))
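# Illustrative sketch: ``compute_filter_pinv`` operates on the squeezed
# (n_filters, kernel_size) matrix and rescales by stride / kernel_size to
# compensate for the overlap-add gain of the transposed convolution.
#
#   fb = FreeFB(n_filters=64, kernel_size=32, stride=16)
#   enc = Encoder(fb, is_pinv=False)
#   pinv_filters = enc.compute_filter_pinv(enc.filters())   # same shape as enc.filters()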
class Encoder(_EncDec):
r"""Encoder class.
Add encoding methods to Filterbank classes.
Not intended to be subclassed.
Args:
filterbank (:class:`Filterbank`): The filterbank to use
as an encoder.
is_pinv (bool): Whether to be the pseudo inverse of filterbank.
as_conv1d (bool): Whether to behave like nn.Conv1d.
If True (default), forwarding input with shape :math:`(batch, 1, time)`
will output a tensor of shape :math:`(batch, freq, conv\_time)`.
If False, will output a tensor of shape :math:`(batch, 1, freq, conv\_time)`.
padding (int): Zero-padding added to both sides of the input.
"""
def __init__(self, filterbank, is_pinv=False, as_conv1d=True, padding=0):
super(Encoder, self).__init__(filterbank, is_pinv=is_pinv)
self.as_conv1d = as_conv1d
self.n_feats_out = self.filterbank.n_feats_out
self.kernel_size = self.filterbank.kernel_size
self.padding = padding
@classmethod
def pinv_of(cls, filterbank, **kwargs):
"""Returns an :class:`~.Encoder`, pseudo inverse of a
:class:`~.Filterbank` or :class:`~.Decoder`."""
if isinstance(filterbank, Filterbank):
return cls(filterbank, is_pinv=True, **kwargs)
elif isinstance(filterbank, Decoder):
return cls(filterbank.filterbank, is_pinv=True, **kwargs)
def forward(self, waveform):
"""Convolve input waveform with the filters from a filterbank.
Args:
waveform (:class:`torch.Tensor`): any tensor with samples along the
last dimension. The waveform representation, with any leading
batch/channel dimensions.
Returns:
:class:`torch.Tensor`: The corresponding TF domain signal.
Shapes
>>> (time, ) -> (freq, conv_time)
>>> (batch, time) -> (batch, freq, conv_time) # Avoid
>>> if as_conv1d:
>>> (batch, 1, time) -> (batch, freq, conv_time)
>>> (batch, chan, time) -> (batch, chan, freq, conv_time)
>>> else:
>>> (batch, chan, time) -> (batch, chan, freq, conv_time)
>>> (batch, any, dim, time) -> (batch, any, dim, freq, conv_time)
"""
filters = self.get_filters()
waveform = self.filterbank.pre_analysis(waveform)
spec = multishape_conv1d(
waveform,
filters=filters,
stride=self.stride,
padding=self.padding,
as_conv1d=self.as_conv1d,
)
return self.filterbank.post_analysis(spec)
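# Illustrative sketch of the shape conventions documented in Encoder.forward:
# with ``as_conv1d=True`` a (batch, 1, time) input behaves like nn.Conv1d,
# while true multichannel inputs keep their channel dimension.
#
#   enc = Encoder(FreeFB(128, 16, stride=8))
#   enc(torch.randn(2, 1, 8000)).shape    # torch.Size([2, 128, 999])
#   enc(torch.randn(2, 3, 8000)).shape    # torch.Size([2, 3, 128, 999])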
def multishape_conv1d(
waveform: torch.Tensor,
filters: torch.Tensor,
stride: int,
padding: int = 0,
as_conv1d: bool = True,
) -> torch.Tensor:
if waveform.ndim == 1:
# Assumes 1D input with shape (time,)
# Output will be (freq, conv_time)
return F.conv1d(
waveform[None, None], filters, stride=stride, padding=padding
).squeeze()
elif waveform.ndim == 2:
# Assume 2D input with shape (batch or channels, time)
# Output will be (batch or channels, freq, conv_time)
warnings.warn(
"Input tensor was 2D. Applying the corresponding "
"Decoder to the current output will result in a 3D "
"tensor. This behaviours was introduced to match "
"Conv1D and ConvTranspose1D, please use 3D inputs "
"to avoid it. For example, this can be done with "
"input_tensor.unsqueeze(1)."
)
return F.conv1d(waveform.unsqueeze(1), filters, stride=stride, padding=padding)
elif waveform.ndim == 3:
batch, channels, time_len = waveform.shape
if channels == 1 and as_conv1d:
# That's the common single-channel case (batch, 1, time).
# Output will be (batch, freq, conv_time); behaves like nn.Conv1d.
return F.conv1d(waveform, filters, stride=stride, padding=padding)
else:
# Return batched convolution: a (batch, chan, time) input gives a
# (batch, chan, freq, conv_time) output, useful for multichannel
# transforms. If as_conv1d is False, (batch, 1, time) will output
# (batch, 1, freq, conv_time), useful for consistency.
return batch_packed_1d_conv(
waveform, filters, stride=stride, padding=padding
)
else: # waveform.ndim > 3
# Handles arbitrarily nested multichannel ("multi"-multichannel) convolution.
# Input can be (*, time), output will be (*, freq, conv_time).
return batch_packed_1d_conv(waveform, filters, stride=stride, padding=padding)
def batch_packed_1d_conv(
inp: torch.Tensor, filters: torch.Tensor, stride: int = 1, padding: int = 0
):
# Here we perform multichannel / multi-source convolution.
# Output should be (batch, channels, freq, conv_time)
batched_conv = F.conv1d(
inp.view(-1, 1, inp.shape[-1]), filters, stride=stride, padding=padding
)
output_shape = inp.shape[:-1] + batched_conv.shape[-2:]
return batched_conv.view(output_shape)
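# Illustrative sketch: all leading dimensions are folded into the batch axis
# before the 1D convolution and restored afterwards, so any (*, time) input
# yields a (*, freq, conv_time) output.
#
#   filters = torch.randn(128, 1, 16)
#   out = batch_packed_1d_conv(torch.randn(2, 3, 4, 8000), filters, stride=8)
#   out.shape   # torch.Size([2, 3, 4, 128, 999])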
class Decoder(_EncDec):
"""Decoder class.
Add decoding methods to Filterbank classes.
Not intended to be subclassed.
Args:
filterbank (:class:`Filterbank`): The filterbank to use as a decoder.
is_pinv (bool): Whether to be the pseudo inverse of filterbank.
padding (int): Zero-padding added to both sides of the input.
output_padding (int): Additional size added to one side of the
output shape.
.. note::
``padding`` and ``output_padding`` arguments are directly passed to
``F.conv_transpose1d``.
"""
def __init__(self, filterbank, is_pinv=False, padding=0, output_padding=0):
super().__init__(filterbank, is_pinv=is_pinv)
self.padding = padding
self.output_padding = output_padding
@classmethod
def pinv_of(cls, filterbank):
"""Returns an Decoder, pseudo inverse of a filterbank or Encoder."""
if isinstance(filterbank, Filterbank):
return cls(filterbank, is_pinv=True)
elif isinstance(filterbank, Encoder):
return cls(filterbank.filterbank, is_pinv=True)
def forward(self, spec, length: Optional[int] = None) -> torch.Tensor:
"""Applies transposed convolution to a TF representation.
This is equivalent to overlap-add.
Args:
spec (:class:`torch.Tensor`): 3D or 4D Tensor. The TF
representation. (Output of :func:`Encoder.forward`).
length (int, optional): Desired output length (in samples).
Returns:
:class:`torch.Tensor`: The corresponding time domain signal.
"""
filters = self.get_filters()
spec = self.filterbank.pre_synthesis(spec)
wav = multishape_conv_transpose1d(
spec,
filters,
stride=self.stride,
padding=self.padding,
output_padding=self.output_padding,
)
wav = self.filterbank.post_synthesis(wav)
if length is not None:
length = min(length, wav.shape[-1])
return wav[..., :length]
return wav
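# Illustrative sketch: pairing an Encoder with the pseudo-inverse Decoder gives
# an (approximate) analysis/synthesis round trip, with ``length`` trimming the
# transposed-convolution output back to the original number of samples.
#
#   fb = FreeFB(512, 16, stride=8)
#   enc, dec = Encoder(fb), Decoder.pinv_of(fb)
#   wav = torch.randn(1, 1, 16000)
#   rec = dec(enc(wav), length=wav.shape[-1])   # (1, 1, 16000)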
def multishape_conv_transpose1d(
spec: torch.Tensor,
filters: torch.Tensor,
stride: int = 1,
padding: int = 0,
output_padding: int = 0,
) -> torch.Tensor:
if spec.ndim == 2:
# Input is (freq, conv_time), output is (time,)
return F.conv_transpose1d(
spec.unsqueeze(0),
filters,
stride=stride,
padding=padding,
output_padding=output_padding,
).squeeze()
if spec.ndim == 3:
# Input is (batch, freq, conv_time), output is (batch, 1, time)
return F.conv_transpose1d(
spec,
filters,
stride=stride,
padding=padding,
output_padding=output_padding,
)
else:
# Multiply all the left dimensions together and group them in the
# batch. Make the convolution and restore.
view_as = (-1,) + spec.shape[-2:]
out = F.conv_transpose1d(
spec.reshape(view_as),
filters,
stride=stride,
padding=padding,
output_padding=output_padding,
)
return out.view(spec.shape[:-2] + (-1,))
class FreeFB(Filterbank):
"""Free filterbank without any constraints. Equivalent to
:class:`nn.Conv1d`.
Args:
n_filters (int): Number of filters.
kernel_size (int): Length of the filters.
stride (int, optional): Stride of the convolution.
If None (default), set to ``kernel_size // 2``.
sample_rate (float): Sample rate of the expected audio.
Defaults to 8000.
Attributes:
n_feats_out (int): Number of output filters.
References
[1] : "Filterbank design for end-to-end speech separation". ICASSP 2020.
Manuel Pariente, Samuele Cornell, Antoine Deleforge, Emmanuel Vincent.
"""
def __init__(
self, n_filters, kernel_size, stride=None, sample_rate=8000.0, **kwargs
):
super().__init__(n_filters, kernel_size, stride=stride, sample_rate=sample_rate)
self._filters = nn.Parameter(torch.ones(n_filters, 1, kernel_size))
for p in self.parameters():
nn.init.xavier_normal_(p)
def filters(self):
return self._filters
free = FreeFB
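# Illustrative sketch: ``FreeFB`` learns its filters jointly with the rest of a
# model; the ``free`` alias above makes it reachable through
# ``make_enc_dec("free", ...)``.
#
#   fb = FreeFB(n_filters=512, kernel_size=16, stride=8)
#   fb.filters().shape                          # torch.Size([512, 1, 16])
#   sum(p.numel() for p in fb.parameters())     # 8192 learnable coefficients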