File size: 1,238 Bytes
9b0d6c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
import torch.nn.functional
import torchaudio


class DataAugmentation:
    """Convert raw waveforms into normalized Kaldi-style filterbank features.

    Each waveform is mean-centered, passed through
    ``torchaudio.compliance.kaldi.fbank`` and then standardized with the
    configured mean and standard deviation. Instances are callable on a
    batch of waveforms.

    Args:
        data_mean: Value subtracted from the filterbank features.
        data_std: Spread used for scaling; features are divided by
            ``2 * data_std``.
        num_mel_bins: Number of mel bins requested from ``fbank``.
        sample_rate: Sample frequency handed to ``fbank``.

    NOTE(review): the default mean/std look like dataset-level statistics --
    confirm where they come from before reusing them on other data.
    """

    def __init__(self, data_mean=-4.2677393, data_std=4.5689974, num_mel_bins=128, sample_rate=16000):
        # Normalization statistics.
        self.data_mean, self.data_std = data_mean, data_std
        # Filterbank configuration.
        self.num_mel_bins, self.sample_rate = num_mel_bins, sample_rate

    def _wav2fbank(self, waveform):
        """Compute filterbank features for a single waveform.

        The waveform's overall mean is removed first, so a constant offset
        in the signal does not reach the filterbank.

        Args:
            waveform: Tensor accepted by ``torchaudio.compliance.kaldi.fbank``
                (presumably ``(1, num_samples)``, as produced by ``__call__``).

        Returns:
            The tensor returned by ``fbank``; per the torchaudio docs this
            should be ``(num_frames, num_mel_bins)`` -- TODO confirm.
        """
        centered = waveform - waveform.mean()
        # Fixed extraction settings; only the sample rate and the number of
        # mel bins come from the instance configuration.
        fbank_options = dict(
            htk_compat=True,
            sample_frequency=self.sample_rate,
            use_energy=False,
            window_type='hanning',
            num_mel_bins=self.num_mel_bins,
            dither=0.0,
            frame_shift=10,
        )
        return torchaudio.compliance.kaldi.fbank(centered, **fbank_options)

    def convert_waveform(self, waveform):
        """Return standardized filterbank features with a leading axis added.

        Args:
            waveform: Single waveform tensor, forwarded to ``_wav2fbank``.

        Returns:
            ``(features - data_mean) / (2 * data_std)`` with a new dimension
            of size 1 inserted at position 0.
        """
        features = self._wav2fbank(waveform)
        # The divisor is twice the configured standard deviation.
        scale = self.data_std * 2
        normalized = (features - self.data_mean) / scale
        return normalized.unsqueeze(0)

    def __call__(self, batch):
        """Convert every waveform of ``batch`` and stack the results.

        Each sample is flattened to a single row before conversion, so every
        sample must yield features of the same shape for the stack to succeed.

        Args:
            batch: Iterable of waveform tensors (e.g. a 2-D tensor with one
                waveform per row).

        Returns:
            A 4-D tensor whose last two axes are swapped relative to the
            per-sample features (presumably
            ``(batch, 1, num_mel_bins, num_frames)`` -- TODO confirm).
        """
        rows = (sample.reshape(1, -1) for sample in batch)
        stacked = torch.stack(list(map(self.convert_waveform, rows)))
        return stacked.permute(0, 1, 3, 2)