Spaces:
Running
Running
File size: 17,568 Bytes
406f22d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 |
import warnings
from typing import Optional
import torch
from torch import nn
from torch.nn import functional as F
def make_enc_dec(
    fb_name,
    n_filters,
    kernel_size,
    stride=None,
    sample_rate=8000.0,
    who_is_pinv=None,
    padding=0,
    output_padding=0,
    **kwargs,
):
    """Creates congruent encoder and decoder from the same filterbank family.

    Args:
        fb_name (str, className): Filterbank family from which to make encoder
            and decoder. To choose among [``'free'``, ``'analytic_free'``,
            ``'param_sinc'``, ``'stft'``]. Can also be a class defined in a
            submodule in this subpackage (e.g. :class:`~.FreeFB`).
        n_filters (int): Number of filters.
        kernel_size (int): Length of the filters.
        stride (int, optional): Stride of the convolution.
            If None (default), set to ``kernel_size // 2``.
        sample_rate (float): Sample rate of the expected audio.
            Defaults to 8000.0.
        who_is_pinv (str, optional): If `None`, no pseudo-inverse filters will
            be used. If string (among [``'encoder'``, ``'decoder'``]), decides
            which of ``Encoder`` or ``Decoder`` will be the pseudo inverse of
            the other one.
        padding (int): Zero-padding added to both sides of the input.
            Passed to Encoder and Decoder.
        output_padding (int): Additional size added to one side of the output shape.
            Passed to Decoder.
        **kwargs: Arguments which will be passed to the filterbank class
            additionally to the usual `n_filters`, `kernel_size` and `stride`.
            Depends on the filterbank family.

    Returns:
        :class:`.Encoder`, :class:`.Decoder`
    """
    fb_class = get(fb_name)

    def _fresh_fb():
        # Every call builds an independent filterbank so filters are only
        # shared when one side is explicitly the pinv of the other.
        return fb_class(
            n_filters, kernel_size, stride=stride, sample_rate=sample_rate, **kwargs
        )

    if who_is_pinv in ["dec", "decoder"]:
        fb = _fresh_fb()
        enc = Encoder(fb, padding=padding)
        # Decoder filterbank is pseudo inverse of encoder filterbank.
        dec = Decoder.pinv_of(fb)
    elif who_is_pinv in ["enc", "encoder"]:
        fb = _fresh_fb()
        dec = Decoder(fb, padding=padding, output_padding=output_padding)
        # Encoder filterbank is pseudo inverse of decoder filterbank.
        enc = Encoder.pinv_of(fb)
    else:
        # Filters between encoder and decoder should not be shared.
        enc = Encoder(_fresh_fb(), padding=padding)
        dec = Decoder(_fresh_fb(), padding=padding, output_padding=output_padding)
    return enc, dec
def register_filterbank(custom_fb):
    """Register a custom filterbank, gettable with `filterbanks.get`.

    Args:
        custom_fb: Custom filterbank to register.

    Raises:
        ValueError: If a filterbank with the same name (or its lowercase
            alias) is already present in this module's namespace.
    """
    registry = globals()
    name = custom_fb.__name__
    if name in registry or name.lower() in registry:
        raise ValueError(
            f"Filterbank {custom_fb.__name__} already exists. Choose another name."
        )
    registry[name] = custom_fb
def get(identifier):
    """Returns a filterbank class from a string. Returns its input if it
    is callable (already a :class:`.Filterbank` for example).

    Args:
        identifier (str or Callable or None): the filterbank identifier.

    Returns:
        :class:`.Filterbank` or None

    Raises:
        ValueError: If ``identifier`` is an unknown string or an
            unsupported type.
    """
    # Guard clauses for the two pass-through cases.
    if identifier is None:
        return None
    if callable(identifier):
        return identifier
    if isinstance(identifier, str):
        registered = globals().get(identifier)
        if registered is not None:
            return registered
    # Unknown string or unsupported type: same error either way.
    raise ValueError(
        "Could not interpret filterbank identifier: " + str(identifier)
    )
class Filterbank(nn.Module):
    """Base Filterbank class.

    Each subclass has to implement a ``filters`` method.

    Args:
        n_filters (int): Number of filters.
        kernel_size (int): Length of the filters.
        stride (int, optional): Stride of the conv or transposed conv. (Hop size).
            If None (default), set to ``kernel_size // 2``.
        sample_rate (float): Sample rate of the expected audio.
            Defaults to 8000.

    Attributes:
        n_feats_out (int): Number of output filters.
    """

    def __init__(self, n_filters, kernel_size, stride=None, sample_rate=8000.0):
        super().__init__()
        self.n_filters = n_filters
        self.kernel_size = kernel_size
        # Falsy stride (None) falls back to half-overlap.
        self.stride = stride or kernel_size // 2
        # Unless a subclass overrides it, there is one output feature
        # per requested filter.
        self.n_feats_out = n_filters
        self.sample_rate = sample_rate

    def filters(self):
        """Abstract method for filters."""
        raise NotImplementedError

    def pre_analysis(self, wav: torch.Tensor):
        """Hook applied to the waveform before the encoder convolution."""
        return wav

    def post_analysis(self, spec: torch.Tensor):
        """Hook applied to the output of the encoder convolution."""
        return spec

    def pre_synthesis(self, spec: torch.Tensor):
        """Hook applied before the decoder transposed convolution."""
        return spec

    def post_synthesis(self, wav: torch.Tensor):
        """Hook applied after the decoder transposed convolution."""
        return wav

    def get_config(self):
        """Returns dictionary of arguments to re-instantiate the class.

        Needs to be subclassed if the filterbank takes additional arguments
        than ``n_filters`` ``kernel_size`` ``stride`` and ``sample_rate``.
        """
        return {
            "fb_name": self.__class__.__name__,
            "n_filters": self.n_filters,
            "kernel_size": self.kernel_size,
            "stride": self.stride,
            "sample_rate": self.sample_rate,
        }

    def forward(self, waveform):
        # A bare Filterbank is not runnable on its own.
        raise NotImplementedError(
            "Filterbanks must be wrapped with an Encoder or a Decoder."
        )
class _EncDec(nn.Module):
"""Base private class for Encoder and Decoder.
Common parameters and methods.
Args:
filterbank (:class:`Filterbank`): Filterbank instance. The filterbank
to use as an encoder or a decoder.
is_pinv (bool): Whether to be the pseudo inverse of filterbank.
Attributes:
filterbank (:class:`Filterbank`)
stride (int)
is_pinv (bool)
"""
def __init__(self, filterbank, is_pinv=False):
super(_EncDec, self).__init__()
self.filterbank = filterbank
self.sample_rate = getattr(filterbank, "sample_rate", None)
self.stride = self.filterbank.stride
self.is_pinv = is_pinv
def filters(self):
return self.filterbank.filters()
def compute_filter_pinv(self, filters):
"""Computes pseudo inverse filterbank of given filters."""
scale = self.filterbank.stride / self.filterbank.kernel_size
shape = filters.shape
ifilt = torch.pinverse(filters.squeeze()).transpose(-1, -2).view(shape)
# Compensate for the overlap-add.
return ifilt * scale
def get_filters(self):
"""Returns filters or pinv filters depending on `is_pinv` attribute"""
if self.is_pinv:
return self.compute_filter_pinv(self.filters())
else:
return self.filters()
def get_config(self):
"""Returns dictionary of arguments to re-instantiate the class."""
config = {"is_pinv": self.is_pinv}
base_config = self.filterbank.get_config()
return dict(list(base_config.items()) + list(config.items()))
class Encoder(_EncDec):
    r"""Encoder class.

    Add encoding methods to Filterbank classes.
    Not intended to be subclassed.

    Args:
        filterbank (:class:`Filterbank`): The filterbank to use
            as an encoder.
        is_pinv (bool): Whether to be the pseudo inverse of filterbank.
        as_conv1d (bool): Whether to behave like nn.Conv1d.
            If True (default), forwarding input with shape :math:`(batch, 1, time)`
            will output a tensor of shape :math:`(batch, freq, conv\_time)`.
            If False, will output a tensor of shape :math:`(batch, 1, freq, conv\_time)`.
        padding (int): Zero-padding added to both sides of the input.
    """

    def __init__(self, filterbank, is_pinv=False, as_conv1d=True, padding=0):
        super(Encoder, self).__init__(filterbank, is_pinv=is_pinv)
        self.as_conv1d = as_conv1d
        self.n_feats_out = self.filterbank.n_feats_out
        self.kernel_size = self.filterbank.kernel_size
        self.padding = padding

    @classmethod
    def pinv_of(cls, filterbank, **kwargs):
        """Returns an :class:`~.Encoder`, pseudo inverse of a
        :class:`~.Filterbank` or :class:`~.Decoder`.

        Raises:
            TypeError: If ``filterbank`` is neither a :class:`~.Filterbank`
                nor a :class:`~.Decoder`.
        """
        if isinstance(filterbank, Filterbank):
            return cls(filterbank, is_pinv=True, **kwargs)
        elif isinstance(filterbank, Decoder):
            return cls(filterbank.filterbank, is_pinv=True, **kwargs)
        # Previously fell through and silently returned None; fail loudly instead.
        raise TypeError(
            f"Expected `Filterbank` or `Decoder` instance, got {type(filterbank)}."
        )

    def forward(self, waveform):
        """Convolve input waveform with the filters from a filterbank.

        Args:
            waveform (:class:`torch.Tensor`): any tensor with samples along the
                last dimension. The waveform representation with and
                batch/channel etc.. dimension.

        Returns:
            :class:`torch.Tensor`: The corresponding TF domain signal.

        Shapes
            >>> (time, ) -> (freq, conv_time)
            >>> (batch, time) -> (batch, freq, conv_time)  # Avoid
            >>> if as_conv1d:
            >>>     (batch, 1, time) -> (batch, freq, conv_time)
            >>>     (batch, chan, time) -> (batch, chan, freq, conv_time)
            >>> else:
            >>>     (batch, chan, time) -> (batch, chan, freq, conv_time)
            >>> (batch, any, dim, time) -> (batch, any, dim, freq, conv_time)
        """
        # Pinv filters are recomputed on each forward when is_pinv is True.
        filters = self.get_filters()
        waveform = self.filterbank.pre_analysis(waveform)
        spec = multishape_conv1d(
            waveform,
            filters=filters,
            stride=self.stride,
            padding=self.padding,
            as_conv1d=self.as_conv1d,
        )
        return self.filterbank.post_analysis(spec)
def multishape_conv1d(
    waveform: torch.Tensor,
    filters: torch.Tensor,
    stride: int,
    padding: int = 0,
    as_conv1d: bool = True,
) -> torch.Tensor:
    """Apply a 1D convolution to inputs of any number of dimensions.

    Args:
        waveform (torch.Tensor): Input with samples along the last dimension.
        filters (torch.Tensor): ``nn.Conv1d``-style filter tensor, expected
            shape ``(freq, 1, kernel_size)``.
        stride (int): Stride of the convolution.
        padding (int): Zero-padding added to both sides of the input.
        as_conv1d (bool): If True, a ``(batch, 1, time)`` input behaves like
            ``nn.Conv1d`` and yields ``(batch, freq, conv_time)``.

    Returns:
        torch.Tensor: The TF-domain signal (see shape table in
        :meth:`Encoder.forward`).
    """
    if waveform.ndim == 1:
        # Assumes 1D input with shape (time,)
        # Output will be (freq, conv_time)
        return F.conv1d(
            waveform[None, None], filters, stride=stride, padding=padding
        ).squeeze()
    elif waveform.ndim == 2:
        # Assume 2D input with shape (batch or channels, time)
        # Output will be (batch or channels, freq, conv_time)
        warnings.warn(
            "Input tensor was 2D. Applying the corresponding "
            "Decoder to the current output will result in a 3D "
            "tensor. This behavior was introduced to match "
            "Conv1D and ConvTranspose1D, please use 3D inputs "
            "to avoid it. For example, this can be done with "
            "input_tensor.unsqueeze(1)."
        )
        return F.conv1d(waveform.unsqueeze(1), filters, stride=stride, padding=padding)
    elif waveform.ndim == 3:
        batch, channels, time_len = waveform.shape
        if channels == 1 and as_conv1d:
            # That's the common single channel case (batch, 1, time)
            # Output will be (batch, freq, stft_time), behaves as Conv1D
            return F.conv1d(waveform, filters, stride=stride, padding=padding)
        else:
            # Return batched convolution, input is (batch, 3, time), output will be
            # (b, 3, f, conv_t). Useful for multichannel transforms. If as_conv1d is
            # false, (batch, 1, time) will output (batch, 1, freq, conv_time), useful for
            # consistency.
            return batch_packed_1d_conv(
                waveform, filters, stride=stride, padding=padding
            )
    else:  # waveform.ndim > 3
        # This is to compute "multi"multichannel convolution.
        # Input can be (*, time), output will be (*, freq, conv_time)
        return batch_packed_1d_conv(waveform, filters, stride=stride, padding=padding)
def batch_packed_1d_conv(
    inp: torch.Tensor, filters: torch.Tensor, stride: int = 1, padding: int = 0
):
    """Multichannel / multi-source convolution: (*, time) -> (*, freq, conv_time)."""
    # Fold all leading dimensions into the batch axis, run a plain conv1d,
    # then unfold the leading dimensions back onto the result.
    flat = inp.view(-1, 1, inp.shape[-1])
    conv_out = F.conv1d(flat, filters, stride=stride, padding=padding)
    restored_shape = inp.shape[:-1] + conv_out.shape[-2:]
    return conv_out.view(restored_shape)
class Decoder(_EncDec):
    """Decoder class.

    Add decoding methods to Filterbank classes.
    Not intended to be subclassed.

    Args:
        filterbank (:class:`Filterbank`): The filterbank to use as an decoder.
        is_pinv (bool): Whether to be the pseudo inverse of filterbank.
        padding (int): Zero-padding added to both sides of the input.
        output_padding (int): Additional size added to one side of the
            output shape.

    .. note::
        ``padding`` and ``output_padding`` arguments are directly passed to
        ``F.conv_transpose1d``.
    """

    def __init__(self, filterbank, is_pinv=False, padding=0, output_padding=0):
        super().__init__(filterbank, is_pinv=is_pinv)
        self.padding = padding
        self.output_padding = output_padding

    @classmethod
    def pinv_of(cls, filterbank):
        """Returns an Decoder, pseudo inverse of a filterbank or Encoder.

        Raises:
            TypeError: If ``filterbank`` is neither a :class:`~.Filterbank`
                nor an :class:`~.Encoder`.
        """
        if isinstance(filterbank, Filterbank):
            return cls(filterbank, is_pinv=True)
        elif isinstance(filterbank, Encoder):
            return cls(filterbank.filterbank, is_pinv=True)
        # Previously fell through and silently returned None; fail loudly instead.
        raise TypeError(
            f"Expected `Filterbank` or `Encoder` instance, got {type(filterbank)}."
        )

    def forward(self, spec, length: Optional[int] = None) -> torch.Tensor:
        """Applies transposed convolution to a TF representation.

        This is equivalent to overlap-add.

        Args:
            spec (:class:`torch.Tensor`): 3D or 4D Tensor. The TF
                representation. (Output of :func:`Encoder.forward`).
            length: desired output length.

        Returns:
            :class:`torch.Tensor`: The corresponding time domain signal.
        """
        filters = self.get_filters()
        spec = self.filterbank.pre_synthesis(spec)
        wav = multishape_conv_transpose1d(
            spec,
            filters,
            stride=self.stride,
            padding=self.padding,
            output_padding=self.output_padding,
        )
        wav = self.filterbank.post_synthesis(wav)
        if length is not None:
            # Only trim, never pad: cap the requested length at what we have.
            length = min(length, wav.shape[-1])
            return wav[..., :length]
        return wav
def multishape_conv_transpose1d(
    spec: torch.Tensor,
    filters: torch.Tensor,
    stride: int = 1,
    padding: int = 0,
    output_padding: int = 0,
) -> torch.Tensor:
    """Overlap-add synthesis via transposed convolution for 2D to ND inputs."""
    if spec.ndim == 2:
        # (freq, conv_time) -> (time,): add a batch dim, synthesize, drop it.
        wav = F.conv_transpose1d(
            spec.unsqueeze(0),
            filters,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
        )
        return wav.squeeze()
    if spec.ndim == 3:
        # (batch, freq, conv_time) -> (batch, 1, time)
        return F.conv_transpose1d(
            spec,
            filters,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
        )
    # Multiply all the left dimensions together and group them in the
    # batch. Make the convolution and restore.
    flat = spec.reshape((-1,) + spec.shape[-2:])
    wav = F.conv_transpose1d(
        flat,
        filters,
        stride=stride,
        padding=padding,
        output_padding=output_padding,
    )
    return wav.view(spec.shape[:-2] + (-1,))
class FreeFB(Filterbank):
    """Free filterbank without any constraints. Equivalent to
    :class:`nn.Conv1d`.

    Args:
        n_filters (int): Number of filters.
        kernel_size (int): Length of the filters.
        stride (int, optional): Stride of the convolution.
            If None (default), set to ``kernel_size // 2``.
        sample_rate (float): Sample rate of the expected audio.
            Defaults to 8000.

    Attributes:
        n_feats_out (int): Number of output filters.

    References
        [1] : "Filterbank design for end-to-end speech separation". ICASSP 2020.
        Manuel Pariente, Samuele Cornell, Antoine Deleforge, Emmanuel Vincent.
    """

    def __init__(
        self, n_filters, kernel_size, stride=None, sample_rate=8000.0, **kwargs
    ):
        super().__init__(n_filters, kernel_size, stride=stride, sample_rate=sample_rate)
        # One learnable FIR filter per output feature, Xavier-initialized.
        weight = torch.ones(n_filters, 1, kernel_size)
        nn.init.xavier_normal_(weight)
        self._filters = nn.Parameter(weight)

    def filters(self):
        return self._filters


# Lowercase alias so `get("free")` resolves to this class.
free = FreeFB
|