# sohamc10 / gradio app @ 9b0d6c2
import torch
import torch.nn as nn
import torchaudio
sz_float = 4  # size of a float in bytes (4 == float32)
epsilon = 1e-7  # fudge factor for normalization; NOTE(review): originally written as 10e-8 (== 1e-7) — confirm 1e-8 was not intended
class AugmentMelSTFT(nn.Module):
    """Log-mel spectrogram front end with optional training-time augmentation.

    Converts a batch of waveforms into log-mel spectrograms, optionally for
    several (n_fft, win_length) configurations at once. In training mode the
    mel filterbank edges (``fmin``/``fmax``) are randomly jittered and
    frequency/time masking is applied; in eval mode the output is
    deterministic and the global torch RNG is not consumed.

    Args:
        n_mels: Number of mel bins.
        sr: Sample rate of the input audio in Hz.
        win_length: STFT window length(s). ``None`` uses ``n_fft``. A list or
            tuple selects multiple configurations and must be paired with an
            ``n_fft`` list/tuple of the same length.
        hopsize: STFT hop length in samples, shared by all configurations.
        n_fft: FFT size(s); see ``win_length``.
        freqm: Maximum frequency-mask width; 0 disables frequency masking.
        timem: Maximum time-mask width; 0 disables time masking.
        htk: Stored on the module; not used by the code in this class.
        fmin: Lower mel filterbank edge in Hz.
        fmax: Upper mel filterbank edge in Hz. ``None`` defaults to
            ``sr // 2 - fmax_aug_range // 2`` so the jittered edge never
            exceeds ``sr // 2``.
        norm: Stored on the module; not used by the code in this class.
        fmin_aug_range: Size of the random range added to ``fmin`` during
            training; must be >= 1, and 1 means no augmentation.
        fmax_aug_range: Size of the random range around ``fmax`` during
            training; must be >= 1, and 1 means no augmentation.
        fast_norm: If True, the output is mapped through
            ``(mels + 4.5) / 5.0``.
        preamp: If True, apply a pre-emphasis filter (coefficient 0.97)
            before the STFT.
        padding: ``"center"`` uses ``torch.stft(center=True)``; ``"same"``
            reflect-pads by ``(win_length - hopsize) // 2`` samples on each
            side and calls ``torch.stft(center=False)``.
        periodic_window: Passed as ``periodic=`` to ``torch.hann_window``.

    Raises:
        ValueError: If ``padding`` is not ``"center"`` or ``"same"``.
        AssertionError: If an aug range is < 1, or if list/tuple
            ``win_length`` and ``n_fft`` are mismatched.
    """

    def __init__(
        self,
        n_mels=128,
        sr=32000,
        win_length=None,
        hopsize=320,
        n_fft=1024,
        freqm=0,
        timem=0,
        htk=False,
        fmin=0.0,
        fmax=None,
        norm=1,
        fmin_aug_range=1,
        fmax_aug_range=1,
        fast_norm=False,
        preamp=True,
        padding="center",
        periodic_window=True,
    ):
        super().__init__()
        # adapted from: https://github.com/CPJKU/kagglebirds2020/commit/70f8308b39011b09d41eb0f4ace5aa7d2b0e806e
        # Similar config to the spectrograms used in AST: https://github.com/YuanGongND/ast

        # Validate arguments before building any state.
        if padding not in ["center", "same"]:
            raise ValueError("Padding must be 'center' or 'same'.")
        assert (
            fmin_aug_range >= 1
        ), f"fmin_aug_range={fmin_aug_range} should be >=1; 1 means no augmentation"
        # Check fmax_aug_range itself: a value < 1 would otherwise only fail
        # later inside torch.randint during training.
        assert (
            fmax_aug_range >= 1
        ), f"fmax_aug_range={fmax_aug_range} should be >=1; 1 means no augmentation"

        if win_length is None:
            win_length = n_fft
        if isinstance(win_length, (list, tuple)):
            assert isinstance(n_fft, (list, tuple))
            assert len(win_length) == len(n_fft)
        else:
            # Single configuration: wrap both so forward() can always iterate.
            win_length = [win_length]
            n_fft = [n_fft]

        self.win_length = win_length
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.sr = sr
        self.htk = htk
        self.fmin = fmin
        if fmax is None:
            # Leave headroom: forward() may add up to fmax_aug_range // 2 in
            # training, so the jittered edge stays at or below sr // 2.
            fmax = sr // 2 - fmax_aug_range // 2
        self.fmax = fmax
        self.norm = norm
        self.hopsize = hopsize
        self.preamp = preamp
        self.fmin_aug_range = fmin_aug_range
        self.fmax_aug_range = fmax_aug_range
        self.fast_norm = fast_norm
        self.padding = padding

        # One Hann window per window length, looked up by name in forward().
        # Non-persistent: moves with .to(device) but is not in the state_dict.
        for win_l in self.win_length:
            self.register_buffer(
                f"window_{win_l}",
                torch.hann_window(win_l, periodic=periodic_window),
                persistent=False,
            )
        # conv1d with kernel [-0.97, 1] computes y[t] = x[t + 1] - 0.97 * x[t],
        # i.e. pre-emphasis with coefficient 0.97.
        self.register_buffer(
            "preemphasis_coefficient", torch.as_tensor([[[-0.97, 1]]]), persistent=False
        )

        # Masking is only applied in training mode (see forward); 0 disables it.
        if freqm == 0:
            self.freqm = torch.nn.Identity()
        else:
            self.freqm = torchaudio.transforms.FrequencyMasking(freqm, iid_masks=False)
        if timem == 0:
            self.timem = torch.nn.Identity()
        else:
            self.timem = torchaudio.transforms.TimeMasking(timem, iid_masks=False)

        # No-op module called in forward(); per the original "printing"
        # comments, presumably a hook point for inspecting intermediate signals.
        self.iden = nn.Identity()

    def forward(self, x):
        """Compute (augmented) log-mel spectrograms.

        Args:
            x: Waveform batch; presumably shaped (batch, time) — the
                pre-emphasis conv1d path requires a 2-D input.

        Returns:
            Tensor of shape (batch, n_configs, n_mels, frames) holding
            ``log(max(mel_power, 1e-7))``, where n_configs is the number of
            (n_fft, win_length) pairs. Fast-normalized if ``fast_norm``.
        """
        if self.preamp:
            # Output is one sample shorter than the input (valid convolution).
            x = nn.functional.conv1d(x.unsqueeze(1), self.preemphasis_coefficient)
            x = x.squeeze(1)

        if self.training:
            # Jitter the filterbank edges:
            #   fmin in [fmin, fmin + fmin_aug_range - 1]
            #   fmax in [fmax + r // 2 - (r - 1), fmax + r // 2], r = fmax_aug_range
            fmin = self.fmin + torch.randint(self.fmin_aug_range, (1,)).item()
            fmax = (
                self.fmax
                + self.fmax_aug_range // 2
                - torch.randint(self.fmax_aug_range, (1,)).item()
            )
        else:
            # Don't augment eval data, and don't consume the global RNG either.
            fmin = self.fmin
            fmax = self.fmax

        mels = []
        for n_fft, win_length in zip(self.n_fft, self.win_length):
            x_temp = x
            if self.padding == "same":
                # pad // 2 on each side: an odd pad loses one sample in total.
                pad = win_length - self.hopsize
                self.iden(x_temp)  # printing
                x_temp = torch.nn.functional.pad(
                    x_temp, (pad // 2, pad // 2), mode="reflect"
                )
                self.iden(x_temp)  # printing
            x_temp = torch.stft(
                x_temp,
                n_fft,
                hop_length=self.hopsize,
                win_length=win_length,
                center=self.padding == "center",
                normalized=False,
                window=getattr(self, f"window_{win_length}"),
                return_complex=True,
            )
            # Power spectrum: real^2 + imag^2.
            x_temp = torch.view_as_real(x_temp)
            x_temp = (x_temp ** 2).sum(dim=-1)

            # Kaldi-style mel filterbank; vtln_warp_factor=1.0 means no VTLN
            # warping. It is rebuilt per call because fmin/fmax may be jittered.
            mel_basis, _ = torchaudio.compliance.kaldi.get_mel_banks(
                self.n_mels,
                n_fft,
                self.sr,
                fmin,
                fmax,
                vtln_low=100.0,
                vtln_high=-500.0,
                vtln_warp_factor=1.0,
            )
            # get_mel_banks yields n_fft // 2 columns; append a zero column for
            # the Nyquist bin so it matches the n_fft // 2 + 1 STFT bins.
            mel_basis = torch.as_tensor(
                torch.nn.functional.pad(mel_basis, (0, 1), mode="constant", value=0),
                device=x.device,
            )
            # Keep the projection and log out of mixed precision.
            with torch.cuda.amp.autocast(enabled=False):
                x_temp = torch.matmul(mel_basis, x_temp)
                x_temp = torch.log(torch.clip(x_temp, min=1e-7))
            mels.append(x_temp)
        mels = torch.stack(mels, dim=1)

        if self.training:
            mels = self.freqm(mels)
            mels = self.timem(mels)
        if self.fast_norm:
            mels = (mels + 4.5) / 5.0  # fast normalization
        return mels

    def extra_repr(self):
        """Show window sizes and hop size in the module repr."""
        return "winsize={}, hopsize={}".format(self.win_length, self.hopsize)