import os
import random
from typing import Optional, Tuple

import h5py
import numpy as np
import torch
import torchaudio
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset
from torchaudio.functional import apply_codec

def compute_mch_rms_dB(mch_wav, fs=16000, energy_thresh=-50):
    """Return the multi-channel signal power in dB.

    Note: no activity detection is applied here; `fs` and `energy_thresh`
    are currently unused and kept only for interface compatibility.
    """
    mean_square = max(1e-20, torch.mean(mch_wav ** 2))  # floor avoids log10(0) on silence
    return 10 * np.log10(mean_square)
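
# Quick sanity check (illustrative only, not part of the pipeline): a signal of
# all ones has mean square 1.0, i.e. 0 dB; scaling it by 0.1 lowers it by 20 dB.
#
#   wav = torch.ones(1, 16000)
#   compute_mch_rms_dB(wav)        # -> 0.0
#   compute_mch_rms_dB(0.1 * wav)  # -> -20.0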

def match2(x, d):
    """Estimate the circular lag of `d` relative to `x`, per channel, via
    phase-transform (GCC-PHAT style) cross-correlation."""
    assert x.dim() == 2, x.shape
    assert d.dim() == 2, d.shape
    # Truncate both signals to a common length so the FFTs line up.
    minlen = min(x.shape[-1], d.shape[-1])
    x, d = x[:, 0:minlen], d[:, 0:minlen]
    Fx = torch.fft.rfft(x, dim=-1)
    Fd = torch.fft.rfft(d, dim=-1)
    # Cross-power spectrum, whitened so only phase (i.e. delay) information remains.
    Phi = Fd * Fx.conj()
    Phi = Phi / (Phi.abs() + 1e-3)
    Phi[:, 0] = 0  # discard the DC bin
    # The peak of the inverse transform gives the dominant (circular) lag.
    tmp = torch.fft.irfft(Phi, dim=-1)
    tau = torch.argmax(tmp.abs(), dim=-1).tolist()
    return tau
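
# Illustrative check (assumes the shift is small relative to the signal length):
# delaying a random signal by 5 samples should be recovered as tau == [5].
#
#   x = torch.randn(1, 16000)
#   d = torch.roll(x, 5, -1)
#   match2(x, d)  # -> [5]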

def codec_simu(wav, sr=16000, options=None):
    """Round-trip `wav` through an MP3 encode/decode to simulate codec
    artifacts, then trim/pad and time-align the result to the input."""
    # Work on a copy so the caller's dict is never mutated; the previous version
    # overwrote options['bitrate'] in place, which froze the bitrate after the
    # first call whenever a shared dict (e.g. a dataset attribute) was passed in.
    if options is None:
        options = {'bitrate': 'random', 'compression': 'random',
                   'complexity': 'random', 'vbr': 'random'}
    options = dict(options)
    if options['bitrate'] == 'random':
        options['bitrate'] = random.choice([24000, 32000, 48000, 64000, 96000, 128000])
    # Only `bitrate` is consumed here; torchaudio's MP3 path takes it in kbps
    # through the `compression` argument.
    compression = int(options['bitrate'] // 1000)
    param = {'format': "mp3", "compression": compression}
    wav_encdec = apply_codec(wav, sr, **param)
    # MP3 framing changes the length: trim back, or pad with the input's tail.
    if wav_encdec.shape[-1] >= wav.shape[-1]:
        wav_encdec = wav_encdec[..., :wav.shape[-1]]
    else:
        wav_encdec = torch.cat([wav_encdec, wav[..., wav_encdec.shape[-1]:]], -1)
    # Undo the codec's fixed delay so the (clean, lossy) pair is sample-aligned.
    tau = match2(wav, wav_encdec)
    wav_encdec = torch.roll(wav_encdec, -tau[0], -1)

    return wav_encdec
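
# Sketch of intended use (the fixed bitrate value is an assumption for illustration):
#
#   wav = torch.randn(1, 16000)  # 1 s of audio at 16 kHz
#   lossy = codec_simu(wav, sr=16000, options={'bitrate': 64000, 'compression': 'random',
#                                              'complexity': 'random', 'vbr': 'random'})
#   assert lossy.shape == wav.shape  # aligned, same-length pair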

def get_wav_files(root_dir):
    """Recursively collect training wavs: MUSDB18-HQ stems (excluding the
    pre-mixed `mixture` files) and all MoisesDB stems."""
    wav_files = []
    for dirpath, dirnames, filenames in os.walk(root_dir):
        for filename in filenames:
            if filename.endswith('.wav'):
                if "musdb18hq" in dirpath and "mixture" not in filename:
                    wav_files.append(os.path.join(dirpath, filename))
                elif "moisesdb" in dirpath:
                    wav_files.append(os.path.join(dirpath, filename))
    return wav_files
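
# Example (root path is hypothetical): get_wav_files("/data/musdb_moisesdb")
# returns every stem wav under that tree, skipping MUSDB "mixture" files.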

class MusdbMoisesdbDataset(Dataset):
    """On-the-fly training pairs from MUSDB18-HQ / MoisesDB stems stored in HDF5.

    With probability 0.5 an excerpt is mixed from up to `num_stems` randomly
    chosen stems (each scaled by a random gain drawn from `snr_range`, in dB);
    otherwise a pre-mixed excerpt is used. `segments` is the excerpt length in
    seconds. Each item is a (clean, codec-degraded) waveform pair.
    """
    def __init__(
        self, 
        data_dir: str,
        codec_type: str,
        codec_options: dict,
        sr: int = 16000,
        segments: int = 10,
        num_stems: int = 4,
        snr_range: Tuple[int, int] = (-10, 10),
        num_samples: int = 1000,
    ) -> None:
        
        self.data_dir = data_dir
        self.codec_type = codec_type
        self.codec_options = codec_options
        self.segments = int(segments * sr)
        self.sr = sr
        self.num_stems = num_stems
        self.snr_range = snr_range
        self.num_samples = num_samples
        
        self.instruments = [
            "bass", 
            "bowed_strings", 
            "drums", 
            "guitar",
            "other", 
            "other_keys", 
            "other_plucked", 
            "percussion", 
            "piano", 
            "vocals", 
            "wind"
        ]

    def __len__(self) -> int:
        return self.num_samples
    
    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        if random.random() > 0.5:
            # Build a random mixture of 1..num_stems stems (sampled with
            # replacement, so the same instrument class can appear twice).
            num_select = random.randint(1, self.num_stems)
            select_stems = random.choices(self.instruments, k=num_select)
            ori_wav = []
            for stem in select_stems:
                h5path = random.choice(os.listdir(os.path.join(self.data_dir, stem)))
                with h5py.File(os.path.join(self.data_dir, stem, h5path), 'r') as f:
                    datas = f['data']
                    random_index = random.randint(0, datas.shape[0] - 1)
                    music_wav = torch.FloatTensor(datas[random_index])
                # Crop a random fixed-length segment.
                start = random.randint(0, music_wav.shape[-1] - self.segments)
                music_wav = music_wav[:, start:start + self.segments]

                # Apply a random per-stem gain (dB -> linear amplitude).
                rescale_snr = random.randint(self.snr_range[0], self.snr_range[1])
                music_wav = music_wav * np.sqrt(10 ** (rescale_snr / 10))
                ori_wav.append(music_wav)
            ori_wav = torch.stack(ori_wav).sum(0)
        else:
            # Use a pre-mixed excerpt instead of mixing stems.
            h5path = random.choice(os.listdir(os.path.join(self.data_dir, "mixture")))
            with h5py.File(os.path.join(self.data_dir, "mixture", h5path), 'r') as f:
                datas = f['data']
                random_index = random.randint(0, datas.shape[0] - 1)
                music_wav = torch.FloatTensor(datas[random_index])
            start = random.randint(0, music_wav.shape[-1] - self.segments)
            ori_wav = music_wav[:, start:start + self.segments]

        # Degrade a copy through the codec simulator, then peak-normalize the
        # pair jointly so their relative levels are preserved.
        codec_wav = codec_simu(ori_wav, sr=self.sr, options=self.codec_options)

        max_scale = max(ori_wav.abs().max(), codec_wav.abs().max())
        if max_scale > 0:
            ori_wav = ori_wav / max_scale
            codec_wav = codec_wav / max_scale

        return ori_wav, codec_wav
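
# Sketch of intended use (directory layout and option values are assumptions):
#
#   train_set = MusdbMoisesdbDataset(
#       data_dir="/data/h5_stems",  # hypothetical root with one sub-dir per stem
#       codec_type="mp3",
#       codec_options={'bitrate': 'random', 'compression': 'random',
#                      'complexity': 'random', 'vbr': 'random'},
#   )
#   clean, lossy = train_set[0]     # both shaped (channels, sr * segments)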
    

class MusdbMoisesdbEval(Dataset):
    """Fixed evaluation pairs: each sub-directory of `data_dir` is expected to
    hold a pre-rendered `ori_wav.wav` / `codec_wav.wav` pair."""
    def __init__(
        self,
        data_dir: str
    ) -> None:
        self.data_path = [os.path.join(data_dir, i) for i in os.listdir(data_dir)]

    def __len__(self) -> int:
        return len(self.data_path)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, str]:
        # torchaudio.load returns (waveform, sample_rate); only the waveform is kept.
        ori_wav = torchaudio.load(os.path.join(self.data_path[idx], "ori_wav.wav"))[0]
        codec_wav = torchaudio.load(os.path.join(self.data_path[idx], "codec_wav.wav"))[0]

        return ori_wav, codec_wav, self.data_path[idx]
    
class MusdbMoisesdbDataModule(LightningDataModule):
    """LightningDataModule wiring the on-the-fly training set and the
    pre-rendered evaluation set into train/val dataloaders."""
    def __init__(
        self,
        train_dir: str,
        eval_dir: str,
        codec_type: str,
        codec_options: dict,
        sr: int = 16000,
        segments: int = 10,
        num_stems: int = 4,
        snr_range: Tuple[int, int] = (-10, 10),
        num_samples: int = 1000,
        batch_size: int = 32,
        num_workers: int = 4,
    ) -> None:
        super().__init__()
        self.save_hyperparameters(logger=False)
        
        self.data_train: Optional[Dataset] = None
        self.data_val: Optional[Dataset] = None
        
    def setup(self, stage: Optional[str] = None) -> None:
        """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.

        This method is called by Lightning before `trainer.fit()`, `trainer.validate()`, `trainer.test()`, and
        `trainer.predict()`, so be careful not to execute things like random split twice! Also, it is called after
        `self.prepare_data()` and there is a barrier in between which ensures that all the processes proceed to
        `self.setup()` once the data is prepared and available for use.

        :param stage: The stage to setup. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. Defaults to ``None``.
        """
        # load and split datasets only if not loaded already
        if not self.data_train and not self.data_val:
            self.data_train = MusdbMoisesdbDataset(
                data_dir=self.hparams.train_dir,
                codec_type=self.hparams.codec_type,
                codec_options=self.hparams.codec_options,
                sr=self.hparams.sr,
                segments=self.hparams.segments,
                num_stems=self.hparams.num_stems,
                snr_range=self.hparams.snr_range,
                num_samples=self.hparams.num_samples,
            )
            
            self.data_val = MusdbMoisesdbEval(
                data_dir=self.hparams.eval_dir
            )
    
    def train_dataloader(self) -> DataLoader:
        return DataLoader(
            self.data_train,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            shuffle=True,
            pin_memory=True,
        )
        
    def val_dataloader(self) -> DataLoader:
        # Note: batching assumes all evaluation pairs share the same length;
        # use batch_size=1 if the eval clips vary in duration.
        return DataLoader(
            self.data_val,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            shuffle=False,
            pin_memory=True,
        )
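
# Minimal smoke-test sketch (paths and option values are assumptions, not part
# of the original training recipe):
#
# if __name__ == "__main__":
#     dm = MusdbMoisesdbDataModule(
#         train_dir="/data/h5_stems",   # hypothetical
#         eval_dir="/data/eval_pairs",  # hypothetical
#         codec_type="mp3",
#         codec_options={'bitrate': 'random', 'compression': 'random',
#                        'complexity': 'random', 'vbr': 'random'},
#         batch_size=4,
#         num_workers=0,
#     )
#     dm.setup("fit")
#     clean, lossy = next(iter(dm.train_dataloader()))
#     print(clean.shape, lossy.shape)   # (4, channels, sr * segments) each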