NSynth-5K-pretrained-sed / models /asit /data_transformations.py
sohamc10's picture
gradio app
9b0d6c2
import torch
import torch.nn.functional
import torchaudio
class DataAugmentation:
    """Convert raw waveforms into normalised filterbank features.

    Despite the name, nothing random happens here: each waveform is
    mean-centred, passed through ``torchaudio.compliance.kaldi.fbank`` and
    then shifted/scaled with the configured statistics.

    Args:
        data_mean: Value subtracted from every filterbank coefficient
            (presumably a dataset-level mean -- confirm against training).
        data_std: Spread used for scaling; coefficients are divided by
            ``2 * data_std``.
        num_mel_bins: Number of mel bins requested from the fbank routine.
        sample_rate: Sampling frequency (Hz) reported to the fbank routine.
    """

    def __init__(self, data_mean=-4.2677393, data_std=4.5689974, num_mel_bins=128, sample_rate=16000):
        self.sample_rate = sample_rate
        self.num_mel_bins = num_mel_bins
        self.data_std = data_std
        self.data_mean = data_mean

    def _wav2fbank(self, waveform):
        """Return Kaldi-compatible filterbank features for one waveform.

        The waveform's own mean is removed first so a DC offset does not
        leak into the features.  Dithering is disabled (``dither=0.0``) and
        a 10 ms frame shift with a Hanning window is used.
        """
        centred = waveform - waveform.mean()
        fbank_options = dict(
            sample_frequency=self.sample_rate,
            num_mel_bins=self.num_mel_bins,
            frame_shift=10,
            window_type='hanning',
            htk_compat=True,
            use_energy=False,
            dither=0.0,
        )
        return torchaudio.compliance.kaldi.fbank(centred, **fbank_options)

    def convert_waveform(self, waveform):
        """Compute normalised filterbank features with a leading axis.

        Returns ``(fbank - data_mean) / (data_std * 2)`` with one extra
        dimension inserted at position 0.
        """
        features = self._wav2fbank(waveform)
        scale = self.data_std * 2
        normalised = (features - self.data_mean) / scale
        return normalised.unsqueeze(0)

    def __call__(self, batch):
        """Convert every sample in ``batch`` and stack the results.

        Each sample is flattened to a single row before conversion.  The
        stacked 4-D tensor then has its last two axes swapped; assuming
        fbank yields ``(frames, num_mel_bins)`` this gives
        ``(batch, 1, num_mel_bins, frames)`` -- TODO confirm with caller.
        All samples must yield equally sized features for the stack to work.
        """
        converted = []
        for sample in batch:
            row = sample.reshape(1, -1)
            converted.append(self.convert_waveform(row))
        stacked = torch.stack(converted)
        return stacked.permute(0, 1, 3, 2)