File size: 5,184 Bytes
9b0d6c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import torch
import torch.nn as nn
import torchaudio

sz_float = 4  # size of a float
epsilon = 10e-8  # fudge factor for normalization


class AugmentMelSTFT(nn.Module):
    """Log-mel spectrogram front end with training-time augmentation.

    Converts raw waveforms into (optionally multi-resolution) log-mel
    spectrograms. During training, the mel filterbank edge frequencies
    (fmin/fmax) are randomly jittered and SpecAugment-style frequency/time
    masking is applied; at eval time the deterministic fmin/fmax are used
    and no masking occurs.

    Adapted from: https://github.com/CPJKU/kagglebirds2020/commit/70f8308b39011b09d41eb0f4ace5aa7d2b0e806e
    Similar config to the spectrograms used in AST: https://github.com/YuanGongND/ast
    """

    def __init__(
            self,
            n_mels=128,
            sr=32000,
            win_length=None,
            hopsize=320,
            n_fft=1024,
            freqm=0,
            timem=0,
            htk=False,
            fmin=0.0,
            fmax=None,
            norm=1,
            fmin_aug_range=1,
            fmax_aug_range=1,
            fast_norm=False,
            preamp=True,
            padding="center",
            periodic_window=True,
    ):
        """
        Args:
            n_mels: number of mel filterbank bins.
            sr: sample rate of the input waveforms in Hz.
            win_length: STFT window length(s) in samples; defaults to
                ``n_fft``. May be a list/tuple (paired with a matching
                list/tuple of ``n_fft``) for multi-resolution spectrograms.
            hopsize: STFT hop size in samples (shared across resolutions).
            n_fft: FFT size(s); when a list/tuple, its length must match
                ``win_length``.
            freqm: maximum width of the training frequency mask (0 disables).
            timem: maximum width of the training time mask (0 disables).
            htk: stored on the instance but not used by the kaldi-style mel
                banks below — presumably kept for config compatibility.
            fmin: lowest mel filter edge in Hz.
            fmax: highest mel filter edge in Hz; if None, defaults to
                Nyquist minus half the fmax augmentation range.
            norm: stored on the instance but unused here — presumably kept
                for config compatibility.
            fmin_aug_range: training-time jitter range (in Hz) added to
                fmin; must be >= 1, where 1 means no augmentation.
            fmax_aug_range: training-time jitter range (in Hz) around fmax;
                must be >= 1, where 1 means no augmentation.
            fast_norm: if True, apply the cheap affine normalization
                ``(mel + 4.5) / 5`` to the output.
            preamp: if True, apply a pre-emphasis FIR filter before the STFT.
            padding: "center" (let ``torch.stft`` center-pad frames) or
                "same" (explicit reflect padding so that the number of
                frames is roughly input_length / hopsize).
            periodic_window: passed to ``torch.hann_window``.

        Raises:
            AssertionError: if an augmentation range is < 1 or the
                win_length/n_fft lists are inconsistent.
            ValueError: if ``padding`` is not "center" or "same".
        """
        super().__init__()

        if win_length is None:
            win_length = n_fft

        # Normalize win_length / n_fft to parallel lists so that forward()
        # can always iterate over (n_fft, win_length) pairs.
        if isinstance(win_length, (list, tuple)):
            assert isinstance(n_fft, (list, tuple))
            assert len(win_length) == len(n_fft)
        else:
            win_length = [win_length]
            n_fft = [n_fft]

        self.win_length = win_length
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.sr = sr
        self.htk = htk
        self.fmin = fmin
        if fmax is None:
            # Leave headroom below Nyquist for the fmax jitter applied in forward().
            fmax = sr // 2 - fmax_aug_range // 2
        self.fmax = fmax
        self.norm = norm
        self.hopsize = hopsize
        self.preamp = preamp
        # One Hann window buffer per resolution; non-persistent so windows
        # are rebuilt from config rather than stored in checkpoints.
        for win_l in self.win_length:
            self.register_buffer(
                f"window_{win_l}",
                torch.hann_window(win_l, periodic=periodic_window),
                persistent=False,
            )
        assert (
                fmin_aug_range >= 1
        ), f"fmin_aug_range={fmin_aug_range} should be >=1; 1 means no augmentation"
        # Fixed: this previously re-checked fmin_aug_range, letting an
        # invalid fmax_aug_range slip through validation.
        assert (
                fmax_aug_range >= 1
        ), f"fmax_aug_range={fmax_aug_range} should be >=1; 1 means no augmentation"
        self.fmin_aug_range = fmin_aug_range
        self.fmax_aug_range = fmax_aug_range

        # Pre-emphasis kernel y[t] = x[t] - 0.97 * x[t-1], as a conv1d weight.
        self.register_buffer(
            "preemphasis_coefficient", torch.as_tensor([[[-0.97, 1]]]), persistent=False
        )
        if freqm == 0:
            self.freqm = torch.nn.Identity()
        else:
            self.freqm = torchaudio.transforms.FrequencyMasking(freqm, iid_masks=False)
        if timem == 0:
            self.timem = torch.nn.Identity()
        else:
            self.timem = torchaudio.transforms.TimeMasking(timem, iid_masks=False)
        self.fast_norm = fast_norm
        self.padding = padding
        if padding not in ["center", "same"]:
            raise ValueError("Padding must be 'center' or 'same'.")
        # Identity hook point: lets forward hooks observe the tensor before
        # and after "same" padding (e.g. for debug printing).
        self.iden = nn.Identity()

    def forward(self, x):
        """Compute log-mel spectrograms for a batch of waveforms.

        Args:
            x: waveform batch — assumed shape (batch, time); TODO confirm
                against callers.

        Returns:
            Tensor of shape (batch, num_resolutions, n_mels, frames) with
            log-power mel values (clipped at 1e-7 before the log).
        """
        if self.preamp:
            x = nn.functional.conv1d(x.unsqueeze(1), self.preemphasis_coefficient)
        # NOTE(review): squeeze(1) also runs when preamp is False; it is a
        # no-op for (batch, time) input but would drop a singleton dim 1.
        x = x.squeeze(1)

        # Jitter the mel filterbank edges (training-time augmentation).
        fmin = self.fmin + torch.randint(self.fmin_aug_range, (1,)).item()
        fmax = self.fmax + self.fmax_aug_range // 2 - torch.randint(self.fmax_aug_range, (1,)).item()

        # don't augment eval data
        if not self.training:
            fmin = self.fmin
            fmax = self.fmax

        mels = []
        for n_fft, win_length in zip(self.n_fft, self.win_length):
            x_temp = x
            if self.padding == "same":
                pad = win_length - self.hopsize
                self.iden(x_temp)  # hook point for debug printing
                x_temp = torch.nn.functional.pad(x_temp, (pad // 2, pad // 2), mode="reflect")
                self.iden(x_temp)  # hook point for debug printing

            x_temp = torch.stft(
                x_temp,
                n_fft,
                hop_length=self.hopsize,
                win_length=win_length,
                center=self.padding == "center",
                normalized=False,
                window=getattr(self, f"window_{win_length}"),
                return_complex=True
            )
            x_temp = torch.view_as_real(x_temp)
            x_temp = (x_temp ** 2).sum(dim=-1)  # power magnitude: re^2 + im^2

            # Kaldi-style triangular mel banks, rebuilt per call because
            # fmin/fmax change under augmentation. vtln_warp_factor=1.0
            # disables VTLN, so vtln_low/vtln_high are inert.
            mel_basis, _ = torchaudio.compliance.kaldi.get_mel_banks(self.n_mels, n_fft, self.sr,
                                                                     fmin, fmax, vtln_low=100.0, vtln_high=-500.,
                                                                     vtln_warp_factor=1.0)
            # get_mel_banks returns n_fft//2 columns; pad one zero column to
            # match the n_fft//2 + 1 STFT bins.
            mel_basis = torch.as_tensor(torch.nn.functional.pad(mel_basis, (0, 1), mode='constant', value=0),
                                        device=x.device)

            # Matmul in full precision even under AMP to avoid fp16 overflow
            # in the power-spectrogram projection.
            with torch.cuda.amp.autocast(enabled=False):
                x_temp = torch.matmul(mel_basis, x_temp)

            # Clip before log to avoid log(0) = -inf.
            x_temp = torch.log(torch.clip(x_temp, min=1e-7))

            mels.append(x_temp)

        # Stack resolutions as channels: (batch, num_resolutions, n_mels, frames).
        mels = torch.stack(mels, dim=1)

        if self.training:
            mels = self.freqm(mels)
            mels = self.timem(mels)
        if self.fast_norm:
            mels = (mels + 4.5) / 5.0  # fast normalization

        return mels

    def extra_repr(self):
        """Shown inside repr(module): window size(s) and hop size."""
        return "winsize={}, hopsize={}".format(self.win_length, self.hopsize)