File size: 17,568 Bytes
406f22d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
import warnings
from typing import Optional
import torch
from torch import nn
from torch.nn import functional as F


def make_enc_dec(
    fb_name,
    n_filters,
    kernel_size,
    stride=None,
    sample_rate=8000.0,
    who_is_pinv=None,
    padding=0,
    output_padding=0,
    **kwargs,
):
    """Creates congruent encoder and decoder from the same filterbank family.

    Args:
        fb_name (str, className): Filterbank family from which to make encoder
            and decoder. To choose among [``'free'``, ``'analytic_free'``,
            ``'param_sinc'``, ``'stft'``]. Can also be a class defined in a
            submodule in this subpackage (e.g. :class:`~.FreeFB`).
        n_filters (int): Number of filters.
        kernel_size (int): Length of the filters.
        stride (int, optional): Stride of the convolution.
            If None (default), set to ``kernel_size // 2``.
        sample_rate (float): Sample rate of the expected audio.
            Defaults to 8000.0.
        who_is_pinv (str, optional): If `None`, no pseudo-inverse filters will
            be used. If string (among [``'encoder'``, ``'decoder'``] or their
            short forms ``'enc'``/``'dec'``), decides which of ``Encoder`` or
            ``Decoder`` will be the pseudo inverse of the other one.
        padding (int): Zero-padding added to both sides of the input.
            Passed to Encoder and Decoder (including the pseudo-inverse one).
        output_padding (int): Additional size added to one side of the output shape.
            Passed to Decoder.
        **kwargs: Arguments which will be passed to the filterbank class
            additionally to the usual `n_filters`, `kernel_size` and `stride`.
            Depends on the filterbank family.

    Returns:
        :class:`.Encoder`, :class:`.Decoder`
    """
    fb_class = get(fb_name)

    def _new_fb():
        # Builds a fresh filterbank instance from the shared arguments.
        return fb_class(
            n_filters, kernel_size, stride=stride, sample_rate=sample_rate, **kwargs
        )

    if who_is_pinv in ["dec", "decoder"]:
        fb = _new_fb()
        enc = Encoder(fb, padding=padding)
        # Decoder filterbank is the pseudo inverse of the encoder filterbank.
        # Built directly (instead of ``Decoder.pinv_of(fb)``) so that the
        # documented ``padding``/``output_padding`` also reach this decoder.
        dec = Decoder(
            fb, is_pinv=True, padding=padding, output_padding=output_padding
        )
    elif who_is_pinv in ["enc", "encoder"]:
        fb = _new_fb()
        dec = Decoder(fb, padding=padding, output_padding=output_padding)
        # Encoder filterbank is the pseudo inverse of the decoder filterbank;
        # ``padding`` is forwarded so both sides stay symmetric.
        enc = Encoder(fb, is_pinv=True, padding=padding)
    else:
        enc = Encoder(_new_fb(), padding=padding)
        # Filters between encoder and decoder should not be shared, hence a
        # second, independent filterbank instance.
        dec = Decoder(_new_fb(), padding=padding, output_padding=output_padding)
    return enc, dec


def register_filterbank(custom_fb):
    """Make a custom filterbank class resolvable by name through `filterbanks.get`.

    The class is stored in this module's global namespace under its
    ``__name__``.

    Args:
        custom_fb: Custom filterbank class to register.

    Raises:
        ValueError: If a global with the same name, or with the lower-cased
            name, already exists in this module.
    """
    name = custom_fb.__name__
    namespace = globals()
    # Also reject a clash with the lower-case spelling, which is the
    # convention used for aliases such as ``free``.
    if name in namespace or name.lower() in namespace:
        raise ValueError(
            f"Filterbank {name} already exists. Choose another name."
        )
    namespace[name] = custom_fb


def get(identifier):
    """Resolve a filterbank identifier to a filterbank class.

    Args:
        identifier (str or Callable or None): ``None`` is returned unchanged;
            any callable (e.g. a :class:`.Filterbank` subclass) is returned
            as-is; a string is looked up among this module's globals.

    Returns:
        :class:`.Filterbank` or None

    Raises:
        ValueError: If a string does not name a global of this module, or if
            the identifier is neither ``None``, callable, nor a string.
    """
    if identifier is None:
        return None
    if callable(identifier):
        return identifier
    found = globals().get(identifier) if isinstance(identifier, str) else None
    if found is None:
        raise ValueError(
            "Could not interpret filterbank identifier: " + str(identifier)
        )
    return found


class Filterbank(nn.Module):
    """Base class shared by every filterbank family.

    Subclasses must provide :meth:`filters`. The four ``pre_*``/``post_*``
    hooks are identity functions by default and may be overridden to
    transform tensors around the convolutions run by an Encoder/Decoder.

    Args:
        n_filters (int): Number of filters.
        kernel_size (int): Length of the filters.
        stride (int, optional): Stride of the conv or transposed conv (hop size).
            When falsy (``None`` by default), ``kernel_size // 2`` is used.
        sample_rate (float): Sample rate of the expected audio.
            Defaults to 8000.0.

    Attributes:
        n_feats_out (int): Number of output features. Equals ``n_filters``
            unless a subclass overrides it.
    """

    def __init__(self, n_filters, kernel_size, stride=None, sample_rate=8000.0):
        super().__init__()
        self.n_filters = n_filters
        self.kernel_size = kernel_size
        # ``or`` keeps the truthiness test: both None and 0 fall back to half
        # the kernel size.
        self.stride = stride or kernel_size // 2
        self.n_feats_out = n_filters
        self.sample_rate = sample_rate

    def filters(self):
        """Abstract method returning the filter weights."""
        raise NotImplementedError

    def pre_analysis(self, wav: torch.Tensor):
        """Hook applied to the waveform before the encoder convolution."""
        return wav

    def post_analysis(self, spec: torch.Tensor):
        """Hook applied to the output of the encoder convolution."""
        return spec

    def pre_synthesis(self, spec: torch.Tensor):
        """Hook applied before the decoder transposed convolution."""
        return spec

    def post_synthesis(self, wav: torch.Tensor):
        """Hook applied after the decoder transposed convolution."""
        return wav

    def get_config(self):
        """Return the keyword arguments needed to re-instantiate this class.

        Subclasses taking arguments beyond ``n_filters``, ``kernel_size``,
        ``stride`` and ``sample_rate`` must extend this.
        """
        return dict(
            fb_name=type(self).__name__,
            n_filters=self.n_filters,
            kernel_size=self.kernel_size,
            stride=self.stride,
            sample_rate=self.sample_rate,
        )

    def forward(self, waveform):
        # A bare filterbank is not meant to be called; it must be wrapped.
        raise NotImplementedError(
            "Filterbanks must be wrapped with an Encoder or a Decoder."
        )


class _EncDec(nn.Module):
    """Base private class for Encoder and Decoder.

    Holds the wrapped filterbank and the logic shared by both directions
    (fetching filters, optionally as their pseudo inverse, and config export).

    Args:
        filterbank (:class:`Filterbank`): Filterbank instance. The filterbank
            to use as an encoder or a decoder.
        is_pinv (bool): Whether to be the pseudo inverse of filterbank.

    Attributes:
        filterbank (:class:`Filterbank`)
        sample_rate (float or None): Copied from the filterbank if it has one.
        stride (int)
        is_pinv (bool)
    """

    def __init__(self, filterbank, is_pinv=False):
        super().__init__()
        self.filterbank = filterbank
        self.sample_rate = getattr(filterbank, "sample_rate", None)
        self.stride = self.filterbank.stride
        self.is_pinv = is_pinv

    def filters(self):
        """Return the wrapped filterbank's filters."""
        return self.filterbank.filters()

    def compute_filter_pinv(self, filters):
        """Compute the pseudo-inverse filterbank of ``filters``.

        Args:
            filters (torch.Tensor): Filters, typically shaped
                ``(n_filters, 1, kernel_size)``.

        Returns:
            torch.Tensor: Pseudo-inverse filters with the same shape as
            ``filters``, scaled by ``stride / kernel_size``.
        """
        scale = self.filterbank.stride / self.filterbank.kernel_size
        shape = filters.shape
        # Drop only the singleton channel axis: a bare ``squeeze()`` would also
        # collapse the filter axis when n_filters == 1 (or kernel_size == 1)
        # and hand a 1-D tensor to ``pinverse``, which requires >= 2 dims.
        ifilt = torch.pinverse(filters.squeeze(1)).transpose(-1, -2).reshape(shape)
        # Compensate for the overlap-add.
        return ifilt * scale

    def get_filters(self):
        """Returns filters or pinv filters depending on `is_pinv` attribute."""
        if self.is_pinv:
            return self.compute_filter_pinv(self.filters())
        return self.filters()

    def get_config(self):
        """Returns dictionary of arguments to re-instantiate the class.

        The filterbank's config comes first; ``is_pinv`` is appended last.
        """
        return {**self.filterbank.get_config(), "is_pinv": self.is_pinv}


class Encoder(_EncDec):
    r"""Encoder class.

    Add encoding methods to Filterbank classes.
    Not intended to be subclassed.

    Args:
        filterbank (:class:`Filterbank`): The filterbank to use
            as an encoder.
        is_pinv (bool): Whether to be the pseudo inverse of filterbank.
        as_conv1d (bool): Whether to behave like nn.Conv1d.
            If True (default), forwarding input with shape :math:`(batch, 1, time)`
            will output a tensor of shape :math:`(batch, freq, conv\_time)`.
            If False, will output a tensor of shape :math:`(batch, 1, freq, conv\_time)`.
        padding (int): Zero-padding added to both sides of the input.
    """

    def __init__(self, filterbank, is_pinv=False, as_conv1d=True, padding=0):
        super().__init__(filterbank, is_pinv=is_pinv)
        self.as_conv1d = as_conv1d
        # Copied from the filterbank so they can be read directly on the Encoder.
        self.n_feats_out = self.filterbank.n_feats_out
        self.kernel_size = self.filterbank.kernel_size
        self.padding = padding

    @classmethod
    def pinv_of(cls, filterbank, **kwargs):
        """Returns an :class:`~.Encoder`, pseudo inverse of a
        :class:`~.Filterbank` or :class:`~.Decoder`.

        Args:
            filterbank (:class:`~.Filterbank` or :class:`~.Decoder`): Source
                of the filters to invert.
            **kwargs: Extra arguments passed to the Encoder constructor
                (e.g. ``as_conv1d``, ``padding``).

        Raises:
            TypeError: If ``filterbank`` is neither a Filterbank nor a Decoder
                (previously ``None`` was returned silently).
        """
        if isinstance(filterbank, Filterbank):
            return cls(filterbank, is_pinv=True, **kwargs)
        elif isinstance(filterbank, Decoder):
            return cls(filterbank.filterbank, is_pinv=True, **kwargs)
        raise TypeError(
            "Expected a Filterbank or Decoder instance, got "
            f"{type(filterbank).__name__}."
        )

    def forward(self, waveform):
        """Convolve input waveform with the filters from a filterbank.

        Args:
            waveform (:class:`torch.Tensor`): Any tensor with samples along the
                last dimension (optionally with leading batch/channel/...
                dimensions).

        Returns:
            :class:`torch.Tensor`: The corresponding TF domain signal.

        Shapes:
            >>> (time, ) -> (freq, conv_time)
            >>> (batch, time) -> (batch, freq, conv_time)  # Avoid
            >>> if as_conv1d:
            >>>     (batch, 1, time) -> (batch, freq, conv_time)
            >>>     (batch, chan, time) -> (batch, chan, freq, conv_time)
            >>> else:
            >>>     (batch, chan, time) -> (batch, chan, freq, conv_time)
            >>> (batch, any, dim, time) -> (batch, any, dim, freq, conv_time)
        """
        filters = self.get_filters()
        waveform = self.filterbank.pre_analysis(waveform)
        spec = multishape_conv1d(
            waveform,
            filters=filters,
            stride=self.stride,
            padding=self.padding,
            as_conv1d=self.as_conv1d,
        )
        return self.filterbank.post_analysis(spec)


def multishape_conv1d(
    waveform: torch.Tensor,
    filters: torch.Tensor,
    stride: int,
    padding: int = 0,
    as_conv1d: bool = True,
) -> torch.Tensor:
    """Convolve ``waveform`` with ``filters``, choosing the layout from its rank.

    Args:
        waveform (torch.Tensor): Signal with samples along the last dimension.
        filters (torch.Tensor): Weight tensor given to ``F.conv1d``
            (e.g. ``(n_filters, 1, kernel_size)`` as built by ``FreeFB``).
        stride (int): Convolution stride (hop size).
        padding (int): Zero-padding added to both sides of the input.
        as_conv1d (bool): For 3D single-channel input, whether to mimic
            ``nn.Conv1d`` and return ``(batch, freq, conv_time)`` rather than
            ``(batch, 1, freq, conv_time)``.

    Returns:
        torch.Tensor:
        ``(time,) -> (freq, conv_time)``;
        ``(batch, time) -> (batch, freq, conv_time)`` (emits a warning);
        ``(batch, 1, time) -> (batch, freq, conv_time)`` when ``as_conv1d``;
        otherwise ``(*, time) -> (*, freq, conv_time)``.
    """
    if waveform.ndim == 1:
        # (time,) -> (freq, conv_time). Only remove the batch axis added here:
        # a bare ``squeeze()`` would also drop ``freq`` or ``conv_time``
        # whenever one of them has size 1.
        return F.conv1d(
            waveform[None, None], filters, stride=stride, padding=padding
        ).squeeze(0)
    if waveform.ndim == 2:
        # (batch or channels, time) -> (batch or channels, freq, conv_time)
        warnings.warn(
            "Input tensor was 2D. Applying the corresponding "
            "Decoder to the current output will result in a 3D "
            "tensor. This behaviours was introduced to match "
            "Conv1D and ConvTranspose1D, please use 3D inputs "
            "to avoid it. For example, this can be done with "
            "input_tensor.unsqueeze(1)."
        )
        return F.conv1d(waveform.unsqueeze(1), filters, stride=stride, padding=padding)
    if waveform.ndim == 3 and waveform.shape[1] == 1 and as_conv1d:
        # Common single-channel case (batch, 1, time) -> (batch, freq, conv_time),
        # behaving exactly like Conv1D.
        return F.conv1d(waveform, filters, stride=stride, padding=padding)
    # Multichannel (batch, chan, time), single channel with as_conv1d=False
    # (kept 4D for consistency), or any higher rank: (*, time) -> (*, freq, conv_time).
    return batch_packed_1d_conv(waveform, filters, stride=stride, padding=padding)


def batch_packed_1d_conv(
    inp: torch.Tensor, filters: torch.Tensor, stride: int = 1, padding: int = 0
):
    """Convolve every 1D signal of ``inp`` independently with ``filters``.

    All leading dimensions are folded into the batch axis of a single
    ``F.conv1d`` call and restored afterwards.

    Args:
        inp (torch.Tensor): ``(*, time)`` tensor; may be non-contiguous.
        filters (torch.Tensor): Weight tensor given to ``F.conv1d``.
        stride (int): Convolution stride.
        padding (int): Zero-padding added to both sides of each signal.

    Returns:
        torch.Tensor: ``(*, freq, conv_time)``.
    """
    # ``reshape`` (not ``view``) so that non-contiguous inputs, e.g. after a
    # transpose, are accepted instead of raising.
    batched_conv = F.conv1d(
        inp.reshape(-1, 1, inp.shape[-1]), filters, stride=stride, padding=padding
    )
    output_shape = inp.shape[:-1] + batched_conv.shape[-2:]
    return batched_conv.view(output_shape)


class Decoder(_EncDec):
    """Decoder class.

    Add decoding methods to Filterbank classes.
    Not intended to be subclassed.

    Args:
        filterbank (:class:`Filterbank`): The filterbank to use as a decoder.
        is_pinv (bool): Whether to be the pseudo inverse of filterbank.
        padding (int): ``padding`` of the transposed convolution; it trims
            this many samples from both ends of the output (the counterpart of
            the encoder's zero-padding).
        output_padding (int): Additional size added to one side of the
            output shape.

    .. note::
        ``padding`` and ``output_padding`` arguments are directly passed to
        ``F.conv_transpose1d``.
    """

    def __init__(self, filterbank, is_pinv=False, padding=0, output_padding=0):
        super().__init__(filterbank, is_pinv=is_pinv)
        self.padding = padding
        self.output_padding = output_padding

    @classmethod
    def pinv_of(cls, filterbank, **kwargs):
        """Returns a :class:`~.Decoder`, pseudo inverse of a
        :class:`~.Filterbank` or :class:`~.Encoder`.

        Args:
            filterbank (:class:`~.Filterbank` or :class:`~.Encoder`): Source
                of the filters to invert.
            **kwargs: Extra arguments passed to the Decoder constructor
                (e.g. ``padding``, ``output_padding``), mirroring
                ``Encoder.pinv_of``.

        Raises:
            TypeError: If ``filterbank`` is neither a Filterbank nor an Encoder
                (previously ``None`` was returned silently).
        """
        if isinstance(filterbank, Filterbank):
            return cls(filterbank, is_pinv=True, **kwargs)
        elif isinstance(filterbank, Encoder):
            return cls(filterbank.filterbank, is_pinv=True, **kwargs)
        raise TypeError(
            "Expected a Filterbank or Encoder instance, got "
            f"{type(filterbank).__name__}."
        )

    def forward(self, spec, length: Optional[int] = None) -> torch.Tensor:
        """Applies transposed convolution to a TF representation.

        This is equivalent to overlap-add.

        Args:
            spec (:class:`torch.Tensor`): 2D, 3D or higher-rank Tensor. The TF
                representation. (Output of :func:`Encoder.forward`).
            length: Desired output length. The output is truncated to at most
                this many samples along the last axis (it is never padded).

        Returns:
            :class:`torch.Tensor`: The corresponding time domain signal.
        """
        filters = self.get_filters()
        spec = self.filterbank.pre_synthesis(spec)
        wav = multishape_conv_transpose1d(
            spec,
            filters,
            stride=self.stride,
            padding=self.padding,
            output_padding=self.output_padding,
        )
        wav = self.filterbank.post_synthesis(wav)
        if length is not None:
            length = min(length, wav.shape[-1])
            return wav[..., :length]
        return wav


def multishape_conv_transpose1d(
    spec: torch.Tensor,
    filters: torch.Tensor,
    stride: int = 1,
    padding: int = 0,
    output_padding: int = 0,
) -> torch.Tensor:
    """Transposed-convolve (overlap-add) ``spec`` with ``filters`` for any rank.

    Args:
        spec (torch.Tensor): TF representation with ``(freq, conv_time)`` as
            its last two dimensions.
        filters (torch.Tensor): Weight tensor given to ``F.conv_transpose1d``.
        stride (int): Transposed convolution stride (hop size).
        padding (int): ``padding`` of ``F.conv_transpose1d``.
        output_padding (int): ``output_padding`` of ``F.conv_transpose1d``.

    Returns:
        torch.Tensor: With single-output-channel filters,
        ``(freq, conv_time) -> (time,)``,
        ``(batch, freq, conv_time) -> (batch, 1, time)`` and
        ``(*, freq, conv_time) -> (*, time)`` for higher ranks.
    """
    if spec.ndim == 2:
        # (freq, conv_time) -> (time,). Remove the batch and channel axes one
        # at a time: a bare ``squeeze()`` would turn a length-1 output into a
        # 0-d tensor.
        return (
            F.conv_transpose1d(
                spec.unsqueeze(0),
                filters,
                stride=stride,
                padding=padding,
                output_padding=output_padding,
            )
            .squeeze(0)
            .squeeze(0)
        )
    elif spec.ndim == 3:
        # (batch, freq, conv_time) -> (batch, 1, time)
        return F.conv_transpose1d(
            spec,
            filters,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
        )
    else:
        # Multiply all the leading dimensions together and group them in the
        # batch. Make the convolution and restore.
        view_as = (-1,) + spec.shape[-2:]
        out = F.conv_transpose1d(
            spec.reshape(view_as),
            filters,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
        )
        return out.view(spec.shape[:-2] + (-1,))


class FreeFB(Filterbank):
    """Unconstrained, fully learnable filterbank; equivalent to :class:`nn.Conv1d`.

    The filters are a single ``nn.Parameter`` of shape
    ``(n_filters, 1, kernel_size)`` initialised with Xavier-normal values.

    Args:
        n_filters (int): Number of filters.
        kernel_size (int): Length of the filters.
        stride (int, optional): Stride of the convolution.
            If None (default), set to ``kernel_size // 2``.
        sample_rate (float): Sample rate of the expected audio.
            Defaults to 8000.
        **kwargs: Accepted for interface compatibility and ignored.

    Attributes:
        n_feats_out (int): Number of output filters.

    References
        [1] : "Filterbank design for end-to-end speech separation". ICASSP 2020.
        Manuel Pariente, Samuele Cornell, Antoine Deleforge, Emmanuel Vincent.
    """

    def __init__(
        self, n_filters, kernel_size, stride=None, sample_rate=8000.0, **kwargs
    ):
        super().__init__(n_filters, kernel_size, stride=stride, sample_rate=sample_rate)
        initial_weights = torch.ones(n_filters, 1, kernel_size)
        self._filters = nn.Parameter(initial_weights)
        # ``_filters`` is the only parameter registered at this point, so it is
        # initialised directly.
        nn.init.xavier_normal_(self._filters)

    def filters(self):
        """Return the learnable filter tensor."""
        return self._filters


# Lower-case alias so that ``get("free")`` (and therefore
# ``make_enc_dec("free", ...)``) resolves to FreeFB, since ``get`` looks
# string identifiers up in this module's globals.
free = FreeFB