Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
import torchaudio | |
# Module-level constants. Neither is referenced by the code visible in this
# section of the file.
sz_float = 4 # size of a float (presumably in bytes, i.e. float32 -- confirm)
epsilon = 10e-8 # fudge factor for normalization; note 10e-8 == 1e-7
class AugmentMelSTFT(nn.Module):
    """Log-mel spectrogram front end with optional training-time augmentation.

    For every ``(n_fft, win_length)`` pair the module computes a power STFT,
    projects it onto a Kaldi-style mel filter bank, and takes the clipped
    natural log. The per-resolution results are stacked along a new
    dimension 1.

    In training mode the mel band edges are randomly shifted on each call and
    frequency/time masking is applied; in eval mode the fixed ``fmin``/``fmax``
    are used and no masking is applied.

    Args:
        n_mels: Number of mel bins.
        sr: Sample rate passed to the mel filter bank construction.
        win_length: STFT window length, or a list/tuple of lengths for a
            multi-resolution front end. Defaults to ``n_fft``.
        hopsize: STFT hop length in samples.
        n_fft: FFT size, or a list/tuple matching ``win_length``.
        freqm: Maximum frequency-mask width; ``0`` disables frequency masking.
        timem: Maximum time-mask width; ``0`` disables time masking.
        htk: Stored on the instance; not used by the code in this class.
        fmin: Lower edge of the mel filter bank.
        fmax: Upper edge of the mel filter bank. Defaults to
            ``sr // 2 - fmax_aug_range // 2``.
        norm: Stored on the instance; not used by the code in this class.
        fmin_aug_range: Size of the random offset range added to ``fmin``
            during training. Must be ``>= 1``; ``1`` means no augmentation.
        fmax_aug_range: Size of the random offset range applied to ``fmax``
            during training. Must be ``>= 1``; ``1`` means no augmentation.
        fast_norm: If true, the output is rescaled as ``(mels + 4.5) / 5.0``.
        preamp: If true, a pre-emphasis filter is applied to the waveform.
        padding: ``"center"`` (``torch.stft`` centering) or ``"same"``
            (manual reflect padding by ``win_length - hopsize``).
        periodic_window: Passed as ``periodic`` to ``torch.hann_window``.

    Raises:
        AssertionError: If ``win_length`` is a list/tuple but ``n_fft`` is not
            a list/tuple of the same length, or if either augmentation range
            is smaller than 1.
        ValueError: If ``padding`` is not ``"center"`` or ``"same"``.
    """

    def __init__(
        self,
        n_mels=128,
        sr=32000,
        win_length=None,
        hopsize=320,
        n_fft=1024,
        freqm=0,
        timem=0,
        htk=False,
        fmin=0.0,
        fmax=None,
        norm=1,
        fmin_aug_range=1,
        fmax_aug_range=1,
        fast_norm=False,
        preamp=True,
        padding="center",
        periodic_window=True,
    ):
        super().__init__()
        # adapted from: https://github.com/CPJKU/kagglebirds2020/commit/70f8308b39011b09d41eb0f4ace5aa7d2b0e806e
        # Similar config to the spectrograms used in AST: https://github.com/YuanGongND/ast

        if win_length is None:
            win_length = n_fft

        # Normalize to parallel sequences so forward() can always zip over them.
        if isinstance(win_length, (list, tuple)):
            assert isinstance(n_fft, (list, tuple))
            assert len(win_length) == len(n_fft)
        else:
            win_length = [win_length]
            n_fft = [n_fft]

        self.win_length = win_length
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.sr = sr
        self.htk = htk
        self.fmin = fmin
        if fmax is None:
            # Leave headroom so the upward half of the fmax augmentation range
            # applied in forward() cannot exceed sr // 2.
            fmax = sr // 2 - fmax_aug_range // 2
        self.fmax = fmax
        self.norm = norm
        self.hopsize = hopsize
        self.preamp = preamp

        # One Hann window per window length; non-persistent buffers follow the
        # module across devices but are kept out of the state dict.
        for win_l in self.win_length:
            self.register_buffer(
                f"window_{win_l}",
                torch.hann_window(win_l, periodic=periodic_window),
                persistent=False,
            )

        # Both ranges are used as the exclusive upper bound of torch.randint
        # in forward(), which requires a value of at least 1.
        assert (
            fmin_aug_range >= 1
        ), f"fmin_aug_range={fmin_aug_range} should be >=1; 1 means no augmentation"
        # Fixed: this check previously re-tested fmin_aug_range, so an invalid
        # fmax_aug_range was only detected as a failure inside forward().
        assert (
            fmax_aug_range >= 1
        ), f"fmax_aug_range={fmax_aug_range} should be >=1; 1 means no augmentation"
        self.fmin_aug_range = fmin_aug_range
        self.fmax_aug_range = fmax_aug_range

        # Kernel for conv1d in forward(): y[t] = x[t + 1] - 0.97 * x[t].
        self.register_buffer(
            "preemphasis_coefficient", torch.as_tensor([[[-0.97, 1]]]), persistent=False
        )

        if freqm == 0:
            self.freqm = torch.nn.Identity()
        else:
            self.freqm = torchaudio.transforms.FrequencyMasking(freqm, iid_masks=False)
        if timem == 0:
            self.timem = torch.nn.Identity()
        else:
            self.timem = torchaudio.transforms.TimeMasking(timem, iid_masks=False)

        self.fast_norm = fast_norm
        self.padding = padding
        if padding not in ["center", "same"]:
            raise ValueError("Padding must be 'center' or 'same'.")
        # No-op module called around the "same" padding step in forward();
        # presumably a hook point for inspecting tensors while debugging.
        self.iden = nn.Identity()

    def forward(self, x):
        """Compute (optionally augmented) log-mel spectrograms.

        Args:
            x: Waveform tensor. With ``preamp`` enabled it is unsqueezed at
                dim 1 and fed to ``conv1d``, so it is expected to be
                ``(batch, samples)`` -- confirm against callers.

        Returns:
            Tensor of log-mel spectrograms with one entry per
            ``(n_fft, win_length)`` pair stacked along dim 1; presumably
            ``(batch, n_resolutions, n_mels, frames)`` for 2-D input.
        """
        if self.preamp:
            # Pre-emphasis via a 2-tap convolution; shortens the signal by one sample.
            x = nn.functional.conv1d(x.unsqueeze(1), self.preemphasis_coefficient)
            x = x.squeeze(1)

        # The random offsets are drawn unconditionally (as in the original
        # implementation) so the global RNG stream is consumed identically in
        # train and eval mode; eval simply discards them below.
        fmin = self.fmin + torch.randint(self.fmin_aug_range, (1,)).item()
        fmax = (
            self.fmax
            + self.fmax_aug_range // 2
            - torch.randint(self.fmax_aug_range, (1,)).item()
        )
        # don't augment eval data
        if not self.training:
            fmin = self.fmin
            fmax = self.fmax

        mels = []
        for n_fft, win_length in zip(self.n_fft, self.win_length):
            x_temp = x
            if self.padding == "same":
                pad = win_length - self.hopsize
                self.iden(x_temp)  # printing
                x_temp = torch.nn.functional.pad(
                    x_temp, (pad // 2, pad // 2), mode="reflect"
                )
                self.iden(x_temp)  # printing

            x_temp = torch.stft(
                x_temp,
                n_fft,
                hop_length=self.hopsize,
                win_length=win_length,
                center=self.padding == "center",
                normalized=False,
                window=getattr(self, f"window_{win_length}"),
                return_complex=True,
            )
            x_temp = torch.view_as_real(x_temp)
            x_temp = (x_temp ** 2).sum(dim=-1)  # power mag

            mel_basis, _ = torchaudio.compliance.kaldi.get_mel_banks(
                self.n_mels,
                n_fft,
                self.sr,
                fmin,
                fmax,
                vtln_low=100.0,
                vtln_high=-500.0,
                vtln_warp_factor=1.0,
            )
            # Append one zero column on the last (frequency) axis; presumably
            # the Kaldi filter bank has n_fft // 2 columns while the STFT
            # yields n_fft // 2 + 1 bins -- confirm.
            mel_basis = torch.as_tensor(
                torch.nn.functional.pad(mel_basis, (0, 1), mode="constant", value=0),
                device=x.device,
            )
            # Run the mel projection with autocast disabled.
            with torch.cuda.amp.autocast(enabled=False):
                x_temp = torch.matmul(mel_basis, x_temp)

            # Clip before the log so zero-energy bins do not produce -inf.
            x_temp = torch.log(torch.clip(x_temp, min=1e-7))
            mels.append(x_temp)

        mels = torch.stack(mels, dim=1)

        if self.training:
            mels = self.freqm(mels)
            mels = self.timem(mels)
        if self.fast_norm:
            mels = (mels + 4.5) / 5.0  # fast normalization

        return mels

    def extra_repr(self):
        """Return the window and hop sizes for the module's printed repr."""
        return f"winsize={self.win_length}, hopsize={self.hopsize}"