FrederikRautenberg committed
Commit 0f1d9a2 · 1 Parent(s): 273b181

Add VITS model and normalizing flow; Jupyter notebook as an example call

pvq_manipulation/Example_Notebook.ipynb ADDED
@@ -0,0 +1,331 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "f0e32cd2-4955-4140-8f48-9751a1a8c588",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import numpy as np \n",
11
+ "from pathlib import Path\n",
12
+ "import padertorch as pt\n",
13
+ "import paderbox as pb\n",
14
+ "import time\n",
15
+ "import torch\n",
16
+ "import torchaudio\n",
17
+ "import ipywidgets as widgets\n",
18
+ "from onnxruntime import InferenceSession\n",
19
+ "from pvq_manipulation.models.vits import Vits_NT\n",
20
+ "from pvq_manipulation.models.ffjord import FFJORD\n",
21
+ "from IPython.display import display, Audio, clear_output\n",
22
+ "from pvq_manipulation.models.hubert import HubertExtractor, SID_LARGE_LAYER\n",
23
+ "from paderbox.transform.module_resample import resample_sox\n",
24
+ "from pvq_manipulation.helper.vad import EnergyVAD\n",
25
+ "from train_tts_nt.helper.utils import rms_norm"
26
+ ]
27
+ },
28
+ {
29
+ "cell_type": "markdown",
30
+ "id": "d4df1db0-8439-4573-9dc2-5d578e8befa1",
31
+ "metadata": {},
32
+ "source": [
33
+ "# load TTS model"
34
+ ]
35
+ },
36
+ {
37
+ "cell_type": "code",
38
+ "execution_count": null,
39
+ "id": "e6691176-6119-4bf0-9dcf-44d657c76074",
40
+ "metadata": {},
41
+ "outputs": [],
42
+ "source": [
43
+ "storage_dir_tts = Path(\"./Saved_models/tts_model/\")\n",
44
+ "tts_model = Vits_NT.load_model(storage_dir_tts, checkpoint=\"checkpoint_390000.pth\")"
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "markdown",
49
+ "id": "c9c7541c-fab5-4d44-9b89-a26a34343e7c",
50
+ "metadata": {},
51
+ "source": [
52
+ "# load normalizing flow"
53
+ ]
54
+ },
55
+ {
56
+ "cell_type": "code",
57
+ "execution_count": null,
58
+ "id": "e4a55082-c6c6-4283-96ed-217553f33bcd",
59
+ "metadata": {},
60
+ "outputs": [],
61
+ "source": [
62
+ "storage_dir_normalizing_flow = Path(\"./Saved_models/norm_flow\")\n",
63
+ "config_norm_flow = pb.io.load_yaml(storage_dir_normalizing_flow / \"config.yaml\")\n",
64
+ "normalizing_flow = FFJORD.load_model(storage_dir_normalizing_flow, checkpoint=\"checkpoints/ckpt_best_loss.pth\")"
65
+ ]
66
+ },
67
+ {
68
+ "cell_type": "markdown",
69
+ "id": "deebed07-b28c-49de-b30f-d80b9e1c6899",
70
+ "metadata": {},
71
+ "source": [
72
+ "# load HuBERT feature model"
73
+ ]
74
+ },
75
+ {
76
+ "cell_type": "code",
77
+ "execution_count": null,
78
+ "id": "bc4627e1-bac7-4533-8cac-bbc296889855",
79
+ "metadata": {},
80
+ "outputs": [],
81
+ "source": [
82
+ "hubert_model = HubertExtractor(\n",
83
+ " layer=SID_LARGE_LAYER,\n",
84
+ " model_name=\"HUBERT_LARGE\",\n",
85
+ " backend=\"torchaudio\",\n",
86
+ " device='cpu', \n",
87
+ " storage_dir='/net/vol/rautenberg/storage/hubert'# target storage dir hubert model\n",
88
+ ")"
89
+ ]
90
+ },
91
+ {
92
+ "cell_type": "markdown",
93
+ "id": "c78fa11b-8617-4175-902c-8af0e4491201",
94
+ "metadata": {},
95
+ "source": [
96
+ "# Example Synthesis"
97
+ ]
98
+ },
99
+ {
100
+ "cell_type": "code",
101
+ "execution_count": null,
102
+ "id": "4e8afa1b-b02e-4a40-982d-36aa78f37a57",
103
+ "metadata": {},
104
+ "outputs": [],
105
+ "source": [
106
+ "speaker_id = 1034\n",
107
+ "example_id = \"1034_121119_000028_000001\"\n",
108
+ "\n",
109
+ "wav_1 = tts_model.synthesize_from_example({\n",
110
+ " 'text' : \"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.\", \n",
111
+ " 'd_vector_storage_root': f\"./Saved_models/Dataset/Embeddings/{speaker_id}/{example_id}.pth\"\n",
112
+ "})\n",
113
+ "display(Audio(wav_1, rate=24_000, normalize=True))"
114
+ ]
115
+ },
116
+ {
117
+ "cell_type": "markdown",
118
+ "id": "feeb1d62-69f2-45c1-a172-16fcfbecd0da",
119
+ "metadata": {},
120
+ "source": [
121
+ "# Manipulation Block"
122
+ ]
123
+ },
124
+ {
125
+ "cell_type": "code",
126
+ "execution_count": null,
127
+ "id": "625368d3-dd35-4da7-a358-7bbac448806c",
128
+ "metadata": {},
129
+ "outputs": [],
130
+ "source": [
131
+ "def get_manipulation(\n",
132
+ " example, \n",
133
+ " d_vector,\n",
134
+ " labels,\n",
135
+ " flow, \n",
136
+ " tts_model,\n",
137
+ " manipulation_idx=0,\n",
138
+ " manipulation_fkt=1,\n",
139
+ "):\n",
140
+ " labels_manipulated = labels.clone()\n",
141
+ " labels_manipulated[:,manipulation_idx] += manipulation_fkt\n",
142
+ " \n",
143
+ " output_forward = flow.forward((d_vector.float(), labels))[0]\n",
144
+ " sampled_class_manipulated = flow.sample((output_forward, labels_manipulated))[0]\n",
145
+ "\n",
146
+ " wav = tts_model.synthesize_from_example({\n",
147
+ " 'text': \"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.\",\n",
148
+ " 'd_vector': d_vector.detach().numpy(),\n",
149
+ " 'd_vector_man': sampled_class_manipulated.detach().numpy(),\n",
150
+ " }) \n",
151
+ " return wav\n",
152
+ "\n",
153
+ "def extract_speaker_embedding(example):\n",
154
+ " observation, sr = pb.io.load_audio(example['audio_path']['observation'], return_sample_rate=True)\n",
155
+ " observation = resample_sox(observation, in_rate=sr, out_rate=16_000)\n",
156
+ " \n",
157
+ " vad = EnergyVAD(sample_rate=16_000)\n",
158
+ " if observation.ndim == 1:\n",
159
+ " observation = observation[None, :]\n",
160
+ " \n",
161
+ " observation = vad({'audio_data': observation})['audio_data']\n",
162
+ " \n",
163
+ " with torch.no_grad():\n",
164
+ " example = tts_model.speaker_manager.prepare_example({'audio_data': {'observation': observation}, **example})\n",
165
+ " example = pt.data.utils.collate_fn([example])\n",
166
+ " example['features'] = torch.tensor(np.array(example['features']))\n",
167
+ " d_vector = tts_model.speaker_manager.forward(example)[0]\n",
168
+ " return d_vector"
169
+ ]
170
+ },
171
+ {
172
+ "cell_type": "code",
173
+ "execution_count": null,
174
+ "id": "b722e503-a8f4-4702-acce-20bcdd828846",
175
+ "metadata": {},
176
+ "outputs": [],
177
+ "source": [
178
+ "def load_speaker_labels(example, config_norm_flow, reg_stor_dir=Path('./Saved_models/pvq_extractor/')):\n",
179
+ " audio, _ = torchaudio.load(example['audio_path']['observation'])\n",
180
+ " num_samples = torch.tensor([audio.shape[-1]])\n",
181
+ "\n",
182
+ " if torch.cuda.is_available():\n",
183
+ " audio = audio.cuda()\n",
184
+ " num_samples = num_samples.cuda()\n",
185
+ " providers = [\"CPUExecutionProvider\"]\n",
186
+ "\n",
187
+ " with torch.no_grad():\n",
188
+ " features, seq_len = hubert_model(\n",
189
+ " audio, \n",
190
+ " 24_000, \n",
191
+ " sequence_lengths=num_samples,\n",
192
+ " )\n",
193
+ " features = np.mean(features.squeeze(0).detach().cpu().numpy(), axis=-1)\n",
194
+ "\n",
195
+ " pvqd_predictions = {}\n",
196
+ " for pvq in ['Breathiness', 'Loudness', 'Pitch', 'Resonance', 'Roughness', 'Strain', 'Weight']:\n",
197
+ " with open(reg_stor_dir / f\"{pvq}.onnx\", \"rb\") as fid:\n",
198
+ " onnx = fid.read()\n",
199
+ " sess = InferenceSession(onnx, providers=providers)\n",
200
+ " pred = sess.run(None, {\"X\": features[None]})[0].squeeze(1)\n",
201
+ " pvqd_predictions[pvq] = pred.tolist()[0]\n",
202
+ " labels = []\n",
203
+ " for key in config_norm_flow['speaker_conditioning']:\n",
204
+ " labels.append(pvqd_predictions[key]/100)\n",
205
+ " return torch.tensor(labels)"
206
+ ]
207
+ },
208
+ {
209
+ "cell_type": "markdown",
210
+ "id": "008035ba-6054-4e6e-ab16-1aaaf68f584a",
211
+ "metadata": {},
212
+ "source": [
213
+ "# Get example manipulation"
214
+ ]
215
+ },
216
+ {
217
+ "cell_type": "code",
218
+ "execution_count": null,
219
+ "id": "e921a3cd-1699-495c-b825-519fb706d89d",
220
+ "metadata": {},
221
+ "outputs": [],
222
+ "source": [
223
+ "example = {\n",
224
+ " 'audio_path': {'observation': \"./Saved_models/Dataset/Audio_files/1034_121119_000028_000001.wav\"},\n",
225
+ " 'speaker_id': 1034,\n",
226
+ " 'example_id': \"1034_121119_000028_000001\",\n",
227
+ "}\n",
228
+ "\n",
229
+ "d_vector = extract_speaker_embedding(example)\n",
230
+ "labels = load_speaker_labels(example, config_norm_flow)\n",
231
+ "\n",
232
+ "wav_manipulated = get_manipulation(\n",
233
+ " example=example, \n",
234
+ " d_vector=d_vector, \n",
235
+ " labels=labels[None, :], \n",
236
+ " flow=normalizing_flow,\n",
237
+ " tts_model=tts_model,\n",
238
+ " manipulation_idx=0,\n",
239
+ " manipulation_fkt=1,\n",
240
+ ")"
241
+ ]
242
+ },
243
+ {
244
+ "cell_type": "code",
245
+ "execution_count": null,
246
+ "id": "09a04e5b-c2ab-43e5-b9df-171028100ab6",
247
+ "metadata": {},
248
+ "outputs": [],
249
+ "source": [
250
+ "example = {\n",
251
+ " 'audio_path': {'observation': \"./Saved_models/Dataset/Audio_files/1034_121119_000028_000001.wav\"},\n",
252
+ " 'speaker_id': 1034,\n",
253
+ " 'example_id': \"1034_121119_000028_000001\",\n",
254
+ "}\n",
255
+ "\n",
256
+ "label_options = ['Weight', 'Resonance', 'Breathiness', 'Roughness', 'Loudness', 'Strain', 'Pitch']\n",
257
+ "\n",
258
+ "manipulation_idx_widget = widgets.Dropdown(\n",
259
+ " options=[(label, i) for i, label in enumerate(label_options)],\n",
260
+ " value=2, # default: Breathiness\n",
261
+ " description='Type:',\n",
262
+ " style={'description_width': 'initial'}\n",
263
+ ")\n",
264
+ "\n",
265
+ "manipulation_fkt_widget = widgets.FloatSlider(\n",
266
+ " value=1.0, min=-2.0, max=2.0, step=0.1,\n",
267
+ " description='Strength:',\n",
268
+ " style={'description_width': 'initial'}\n",
269
+ ")\n",
270
+ "\n",
271
+ "run_button = widgets.Button(description=\"Run Manipulation\")\n",
272
+ "\n",
273
+ "audio_output = widgets.Output()\n",
274
+ "\n",
275
+ "def update_manipulation(b):\n",
276
+ " manipulation_idx = manipulation_idx_widget.value\n",
277
+ " manipulation_fkt = manipulation_fkt_widget.value\n",
278
+ " \n",
279
+ " d_vector = extract_speaker_embedding(example)\n",
280
+ " labels = load_speaker_labels(example, config_norm_flow)\n",
281
+ "\n",
282
+ " with audio_output:\n",
283
+ " clear_output(wait=True)\n",
284
+ " display(widgets.Label(\"Processing...\"))\n",
285
+ " \n",
286
+ " time.sleep(1) \n",
287
+ " \n",
288
+ " wav_manipulated = get_manipulation(\n",
289
+ " example=example, \n",
290
+ " d_vector=d_vector, \n",
291
+ " labels=labels[None, :], \n",
292
+ " flow=normalizing_flow,\n",
293
+ " tts_model=tts_model,\n",
294
+ " manipulation_idx=manipulation_idx,\n",
295
+ " manipulation_fkt=manipulation_fkt,\n",
296
+ " )\n",
297
+ " \n",
298
+ " with audio_output:\n",
299
+ " clear_output(wait=True) \n",
300
+ " display(Audio(wav_manipulated, rate=24_000, normalize=True))\n",
301
+ " display(Audio(example['audio_path']['observation'], rate=24_000, normalize=True))\n",
302
+ "\n",
303
+ " print(f\"Manipulated {label_options[manipulation_idx]} with strength {manipulation_fkt}\")\n",
304
+ "\n",
305
+ "run_button.on_click(update_manipulation)\n",
306
+ "display(manipulation_idx_widget, manipulation_fkt_widget, run_button, audio_output)"
307
+ ]
308
+ }
309
+ ],
310
+ "metadata": {
311
+ "kernelspec": {
312
+ "display_name": "voice editing",
313
+ "language": "python",
314
+ "name": "voice_editing"
315
+ },
316
+ "language_info": {
317
+ "codemirror_mode": {
318
+ "name": "ipython",
319
+ "version": 3
320
+ },
321
+ "file_extension": ".py",
322
+ "mimetype": "text/x-python",
323
+ "name": "python",
324
+ "nbconvert_exporter": "python",
325
+ "pygments_lexer": "ipython3",
326
+ "version": "3.11.9"
327
+ }
328
+ },
329
+ "nbformat": 4,
330
+ "nbformat_minor": 5
331
+ }
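The manipulation performed by `get_manipulation` in the notebook above can be summarized compactly. Writing d for the extracted d-vector, a for the PVQ label vector, T(· | a) for the data-to-base integration of the conditional flow (`FFJORD.forward`) and T^{-1}(· | a) for the base-to-data integration (`FFJORD.sample`), the manipulated speaker embedding is

    w = T(d \mid a), \qquad a' = a + \Delta\, e_i, \qquad d' = T^{-1}(w \mid a'),

where e_i selects the manipulated voice quality (`manipulation_idx`) and \Delta is the manipulation strength (`manipulation_fkt`); the TTS model is then conditioned on d' via the `d_vector_man` key. The notation is chosen here for illustration and does not appear in the code.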
pvq_manipulation/helper/characters.yaml ADDED
@@ -0,0 +1,4 @@
1
+ Yourtts:
2
+ "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz\u00af\u00b7\u00df\u00e0\u00e1\u00e2\u00e3\u00e4\u00e6\u00e7\u00e8\u00e9\u00ea\u00eb\u00ec\u00ed\u00ee\u00ef\u00f1\u00f2\u00f3\u00f4\u00f5\u00f6\u00f9\u00fa\u00fb\u00fc\u00ff\u0101\u0105\u0107\u0113\u0119\u011b\u012b\u0131\u0142\u0144\u014d\u0151\u0153\u015b\u016b\u0171\u017a\u017c\u01ce\u01d0\u01d2\u01d4\u0430\u0431\u0432\u0433\u0434\u0435\u0436\u0437\u0438\u0439\u043a\u043b\u043c\u043d\u043e\u043f\u0440\u0441\u0442\u0443\u0444\u0445\u0446\u0447\u0448\u0449\u044a\u044b\u044c\u044d\u044e\u044f\u0451\u0454\u0456\u0457\u0491\u2013!'(),-.:;? "
3
+ German:
4
+ "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!'(),-.:;?\u00af\u2013\u00fc\u00f6\u00e4\u00df\u201a\u2018\u2019"
pvq_manipulation/helper/moving_batch_norm.py ADDED
@@ -0,0 +1,140 @@
1
+ """
2
+ This code is adapted from https://github.com/RameenAbdal/StyleFlow/blob/master/module/normalization.py
3
+ """
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch.nn import Parameter
7
+
8
+
9
+ class MovingBatchNormNd(nn.Module):
10
+ def __init__(self, num_features, eps=1e-4, decay=0.1, bn_lag=0., affine=True):
11
+ super(MovingBatchNormNd, self).__init__()
12
+ self.num_features = num_features
13
+ self.affine = affine
14
+ self.eps = eps
15
+ self.decay = decay
16
+ self.bn_lag = bn_lag
17
+ self.register_buffer('step', torch.zeros(1))
18
+ if self.affine:
19
+ self.weight = Parameter(torch.Tensor(num_features))
20
+ self.bias = Parameter(torch.Tensor(num_features))
21
+ else:
22
+ self.register_parameter('weight', None)
23
+ self.register_parameter('bias', None)
24
+ self.register_buffer('running_mean', torch.zeros(num_features))
25
+ self.register_buffer('running_var', torch.ones(num_features))
26
+ self.reset_parameters()
27
+
28
+ @property
29
+ def shape(self):
30
+ raise NotImplementedError
31
+
32
+ def reset_parameters(self):
33
+ self.running_mean.zero_()
34
+ self.running_var.fill_(1)
35
+ if self.affine:
36
+ self.weight.data.zero_()
37
+ self.bias.data.zero_()
38
+
39
+ def forward(self, x, c=None, logpx=None, reverse=False):
40
+ if reverse:
41
+ return self._reverse(x, logpx)
42
+ else:
43
+ return self._forward(x, logpx)
44
+
45
+ def _forward(self, x, logpx=None):
46
+ num_channels = x.size(-1)
47
+ used_mean = self.running_mean.clone().detach()
48
+ used_var = self.running_var.clone().detach()
49
+
50
+ if self.training:
51
+ # compute batch statistics
52
+ x_t = x.transpose(0, -1).reshape(num_channels, -1)
53
+ batch_mean = torch.mean(x_t, dim=1)
54
+
55
+ batch_var = torch.var(x_t, dim=1)
56
+
57
+ # moving average
58
+ if self.bn_lag > 0:
59
+ used_mean = batch_mean - (1 - self.bn_lag) * (batch_mean - used_mean.detach())
60
+ used_mean /= (1. - self.bn_lag**(self.step[0] + 1))
61
+ used_var = batch_var - (1 - self.bn_lag) * (batch_var - used_var.detach())
62
+ used_var /= (1. - self.bn_lag**(self.step[0] + 1))
63
+
64
+ # update running estimates
65
+ self.running_mean -= self.decay * (self.running_mean - batch_mean.data)
66
+ self.running_var -= self.decay * (self.running_var - batch_var.data)
67
+ self.step += 1
68
+
69
+ # perform normalization
70
+ used_mean = used_mean.view(*self.shape).expand_as(x)
71
+ used_var = used_var.view(*self.shape).expand_as(x)
72
+
73
+ y = (x - used_mean) * torch.exp(-0.5 * torch.log(used_var + self.eps))
74
+
75
+ if self.affine:
76
+ weight = self.weight.view(*self.shape).expand_as(x)
77
+ bias = self.bias.view(*self.shape).expand_as(x)
78
+ y = y * torch.exp(weight) + bias
79
+
80
+ if logpx is None:
81
+ return y
82
+ else:
83
+ #import ipdb
84
+ #ipdb.set_trace()
85
+ return y, logpx - self._logdetgrad(x, used_var).sum(-1, keepdim=True)
86
+
87
+ def _reverse(self, y, logpy=None):
88
+ used_mean = self.running_mean
89
+ used_var = self.running_var
90
+
91
+ if self.affine:
92
+ weight = self.weight.view(*self.shape).expand_as(y)
93
+ bias = self.bias.view(*self.shape).expand_as(y)
94
+ y = (y - bias) * torch.exp(-weight)
95
+
96
+ used_mean = used_mean.view(*self.shape).expand_as(y)
97
+ used_var = used_var.view(*self.shape).expand_as(y)
98
+ x = y * torch.exp(0.5 * torch.log(used_var + self.eps)) + used_mean
99
+
100
+ if logpy is None:
101
+ return x
102
+ else:
103
+ return x, logpy + self._logdetgrad(x, used_var).sum(-1, keepdim=True)
104
+
105
+ def _logdetgrad(self, x, used_var):
106
+ logdetgrad = -0.5 * torch.log(used_var + self.eps)
107
+ if self.affine:
108
+ weight = self.weight.view(*self.shape).expand(*x.size())
109
+ logdetgrad += weight
110
+ return logdetgrad
111
+
112
+ def __repr__(self):
113
+ return (
114
+ '{name}({num_features}, eps={eps}, decay={decay}, bn_lag={bn_lag},'
115
+ ' affine={affine})'.format(name=self.__class__.__name__, **self.__dict__)
116
+ )
117
+
118
+
119
+ def stable_var(x, mean=None, dim=1):
120
+ if mean is None:
121
+ mean = x.mean(dim, keepdim=True)
122
+ mean = mean.view(-1, 1)
123
+ res = torch.pow(x - mean, 2)
124
+ max_sqr = torch.max(res, dim, keepdim=True)[0]
125
+ var = torch.mean(res / max_sqr, 1, keepdim=True) * max_sqr
126
+ var = var.view(-1)
127
+ # change nan to zero
128
+ var[var != var] = 0
129
+ return var
130
+
131
+
132
+ class MovingBatchNorm1d(MovingBatchNormNd):
133
+ @property
134
+ def shape(self):
135
+ return [1, -1]
136
+
137
+ def forward(self, x, context=None, logpx=None, integration_times=None, reverse=False):
138
+ ret = super(MovingBatchNorm1d, self).forward(
139
+ x, context, logpx=logpx, reverse=reverse)
140
+ return ret
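In flow notation, `MovingBatchNorm1d` applies a per-feature affine normalization and contributes the matching log-determinant term computed by `_logdetgrad`. With running (or lagged batch) statistics \mu_i, \sigma_i^2 and, for `affine=True`, learnable weight w_i and bias b_i:

    y_i = e^{w_i} \frac{x_i - \mu_i}{\sqrt{\sigma_i^2 + \varepsilon}} + b_i,
    \qquad
    \log\left|\det \frac{\partial y}{\partial x}\right| = \sum_i \left( w_i - \tfrac{1}{2}\log(\sigma_i^2 + \varepsilon) \right),

which is the quantity subtracted from `logpx` in the forward direction and added back in the reverse direction above.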
pvq_manipulation/helper/utils.py ADDED
@@ -0,0 +1,228 @@
1
+ import paderbox as pb
2
+ import torch
3
+
4
+ from coqpit import Coqpit
5
+ from dataclasses import dataclass, field
6
+ from paderbox.transform.module_resample import resample_sox
7
+
8
+ from typing import List
9
+
10
+ from TTS.tts.models.vits import VitsAudioConfig, VitsArgs
11
+ from TTS.tts.configs.shared_configs import BaseTTSConfig
12
+
13
+
14
+ def load_audio(file_path, target_sr):
15
+ """Load the audio file normalized in [-1, 1]
16
+
17
+ Return Shapes:
18
+ - x: :math:`[1, T]`
19
+ """
20
+ if type(file_path) is dict:
21
+ if 'observation' in file_path:
22
+ file_path = file_path['observation']
23
+
24
+ x, sr = pb.io.load_audio(file_path, return_sample_rate=True)
25
+ if sr != target_sr:
26
+ x = resample_sox(x, in_rate=sr, out_rate=target_sr)
27
+ x = torch.tensor(x, dtype=torch.float32)[None, :]
28
+ x[x < -1] = -1
29
+ x[x > 1] = 1
30
+ assert (x > 1).sum() + (x < -1).sum() == 0
31
+ return x, target_sr
32
+
33
+
34
+ @dataclass
35
+ class VitsAudioConfig_NT(Coqpit):
36
+ fft_size: int = 1024
37
+ sample_rate: int = 16000
38
+ win_length: int = 1024
39
+ hop_length: int = 256
40
+ num_mels: int = 80
41
+ mel_fmin: int = 0
42
+ mel_fmax: int = None
43
+ fading: str = 'half'
44
+ window: str = 'hann'
45
+ pad: bool = True
46
+
47
+
48
+ @dataclass
49
+ class VitsConfig_NT(BaseTTSConfig):
50
+ """Defines parameters for VITS End2End TTS model.
51
+
52
+ Args:
53
+ model (str):
54
+ Model name. Do not change unless you know what you are doing.
55
+
56
+ model_args (VitsArgs):
57
+ Model architecture arguments. Defaults to `VitsArgs()`.
58
+
59
+ audio (VitsAudioConfig):
60
+ Audio processing configuration. Defaults to `VitsAudioConfig()`.
61
+
62
+ grad_clip (List):
63
+ Gradient clipping thresholds for each optimizer. Defaults to `[1000.0, 1000.0]`.
64
+
65
+ lr_gen (float):
66
+ Initial learning rate for the generator. Defaults to 0.0002.
67
+
68
+ lr_disc (float):
69
+ Initial learning rate for the discriminator. Defaults to 0.0002.
70
+
71
+ lr_scheduler_gen (str):
72
+ Name of the learning rate scheduler for the generator. One of the `torch.optim.lr_scheduler.*`. Defaults to
73
+ `ExponentialLR`.
74
+
75
+ lr_scheduler_gen_params (dict):
76
+ Parameters for the learning rate scheduler of the generator. Defaults to `{'gamma': 0.999875, "last_epoch":-1}`.
77
+
78
+ lr_scheduler_disc (str):
79
+ Name of the learning rate scheduler for the discriminator. One of the `torch.optim.lr_scheduler.*`. Defaults to
80
+ `ExponentialLR`.
81
+
82
+ lr_scheduler_disc_params (dict):
83
+ Parameters for the learning rate scheduler of the discriminator. Defaults to `{'gamma': 0.999875, "last_epoch":-1}`.
84
+
85
+ scheduler_after_epoch (bool):
86
+ If true, step the schedulers after each epoch else after each step. Defaults to `False`.
87
+
88
+ optimizer (str):
89
+ Name of the optimizer to use with both the generator and the discriminator networks. One of the
90
+ `torch.optim.*`. Defaults to `AdamW`.
91
+
92
+ kl_loss_alpha (float):
93
+ Loss weight for KL loss. Defaults to 1.0.
94
+
95
+ disc_loss_alpha (float):
96
+ Loss weight for the discriminator loss. Defaults to 1.0.
97
+
98
+ gen_loss_alpha (float):
99
+ Loss weight for the generator loss. Defaults to 1.0.
100
+
101
+ feat_loss_alpha (float):
102
+ Loss weight for the feature matching loss. Defaults to 1.0.
103
+
104
+ mel_loss_alpha (float):
105
+ Loss weight for the mel loss. Defaults to 45.0.
106
+
107
+ return_wav (bool):
108
+ If true, data loader returns the waveform as well as the other outputs. Do not change. Defaults to `True`.
109
+
110
+ compute_linear_spec (bool):
111
+ If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`.
112
+
113
+ use_weighted_sampler (bool):
114
+ If true, use weighted sampler with bucketing for balancing samples between datasets used in training. Defaults to `False`.
115
+
116
+ weighted_sampler_attrs (dict):
117
+ Key returned by the formatter to be used for weighted sampler. For example `{"root_path": 2.0, "speaker_name": 1.0}` sets sample probabilities
118
+ by overweighting `root_path` by 2.0. Defaults to `{}`.
119
+
120
+ weighted_sampler_multipliers (dict):
121
+ Weight each unique value of a key returned by the formatter for weighted sampling.
122
+ For example `{"root_path":{"/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-100/":1.0, "/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-360/": 0.5}}`.
123
+ It will sample instances from `train-clean-100` 2 times more than `train-clean-360`. Defaults to `{}`.
124
+
125
+ r (int):
126
+ Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`.
127
+
128
+ add_blank (bool):
129
+ If true, a blank token is added in between every character. Defaults to `True`.
130
+
131
+ test_sentences (List[List]):
132
+ List of sentences with speaker and language information to be used for testing.
133
+
134
+ language_ids_file (str):
135
+ Path to the language ids file.
136
+
137
+ use_language_embedding (bool):
138
+ If true, language embedding is used. Defaults to `False`.
139
+
140
+ Note:
141
+ Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters.
142
+
143
+ Example:
144
+
145
+ >>> from TTS.tts.configs.vits_config import VitsConfig
146
+ >>> config = VitsConfig()
147
+ """
148
+ model: str = "vits"
149
+ # model specific params
150
+ model_args: VitsArgs = field(default_factory=VitsArgs)
151
+ audio: VitsAudioConfig = field(default_factory=VitsAudioConfig)
152
+
153
+ # optimizer
154
+ grad_clip: List[float] = field(default_factory=lambda: [1000, 1000, 1000])
155
+ lr_gen: float = 0.0002
156
+ lr_disc: float = 0.0002
157
+ lr_scheduler_gen: str = "ExponentialLR"
158
+ lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1})
159
+ lr_scheduler_disc: str = "ExponentialLR"
160
+ lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1})
161
+ scheduler_after_epoch: bool = True
162
+ optimizer: str = "AdamW"
163
+ optimizer_params: dict = field(
164
+ default_factory=lambda: {"betas": [0.8, 0.99], "eps": 1e-9, "weight_decay": 0.01})
165
+
166
+ # loss params
167
+ kl_loss_alpha: float = 1.0
168
+ disc_loss_alpha: float = 1.0
169
+ gen_loss_alpha: float = 1.0
170
+ feat_loss_alpha: float = 1.0
171
+ mel_loss_alpha: float = 45.0
172
+ dur_loss_alpha: float = 1.0
173
+ speaker_encoder_loss_alpha: float = 1.0
174
+
175
+ # data loader params
176
+ return_wav: bool = True
177
+ compute_linear_spec: bool = True
178
+
179
+ # sampler params
180
+ use_weighted_sampler: bool = False # TODO: move it to the base config
181
+ weighted_sampler_attrs: dict = field(default_factory=lambda: {})
182
+ weighted_sampler_multipliers: dict = field(default_factory=lambda: {})
183
+
184
+ # overrides
185
+ r: int = 1 # DO NOT CHANGE
186
+ add_blank: bool = True
187
+
188
+ # testing
189
+ test_sentences: List[List] = field(
190
+ default_factory=lambda: [
191
+ ["It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent."],
192
+ ["Be a voice, not an echo."],
193
+ ["I'm sorry Dave. I'm afraid I can't do that."],
194
+ ["This cake is great. It's so delicious and moist."],
195
+ ["Prior to November 22, 1963."],
196
+ ]
197
+ )
198
+
199
+ # multi-speaker settings
200
+ # use speaker embedding layer
201
+ num_speakers: int = 0
202
+ use_speaker_embedding: bool = False
203
+ speakers_file: str = None
204
+ speaker_embedding_channels: int = 256
205
+ language_ids_file: str = None
206
+ use_language_embedding: bool = False
207
+
208
+ # use d-vectors
209
+ d_vectors_stor_file: bool = False
210
+ d_vector_model_file: str = None
211
+ d_vector_dim: int = None
212
+ d_vector_model: str = None
213
+ dataset_dict: dict = None
214
+ gan_speaker_conditioning: bool = True
215
+
216
+ sample_rate: int = 16_000
217
+ use_vad: bool = True
218
+ use_phone_labels: bool = False
219
+
220
+ CONFIG_SOLVER: str = ''
221
+
222
+ use_speaker_embedding_cond: bool = True
223
+
224
+ def __post_init__(self):
225
+ for key, val in self.model_args.items():
226
+ if hasattr(self, key):
227
+ self[key] = val
228
+
pvq_manipulation/helper/vad.py ADDED
@@ -0,0 +1,193 @@
1
+ import numpy as np
2
+ import paderbox as pb
3
+ import padertorch as pt
4
+ import typing
5
+
6
+ from dataclasses import dataclass
7
+
8
+
9
+ @pb.utils.functional.partial_decorator
10
+ def conv_smoothing(signal, window_length=7, threshold=3):
11
+ """
12
+
13
+ Boundary effects are visible at beginning and end of signal.
14
+
15
+ Examples:
16
+ >>> conv_smoothing(np.array([False, True, True, True, False, False, False, True]), 3, 2)
17
+ array([False, True, True, True, False, False, False, False])
18
+
19
+ Args:
20
+ signal:
21
+ window_length:
22
+ threshold:
23
+
24
+ Returns:
25
+
26
+ """
27
+ left_context = right_context = (window_length - 1) // 2
28
+ if window_length % 2 == 0:
29
+ right_context += 1
30
+ act_conv = np.sum(pb.array.segment_axis(
31
+ np.pad(signal, (left_context, right_context), mode='constant'),
32
+ length=window_length, shift=1, axis=0, end='cut'
33
+ ), axis=-1)
34
+ # act_conv = np.convolve(signal, np.ones(window_length), 's')
35
+ act = act_conv >= threshold
36
+ assert act.shape == signal.shape, (act.shape, signal.shape)
37
+ return act
38
+
39
+
40
+ @dataclass
41
+ class VAD(pt.Configurable):
42
+ smoothing: typing.Optional[typing.Callable] = None
43
+
44
+ def reset(self):
45
+ """Override for a stateful VAD"""
46
+ pass
47
+
48
+ def compute_vad(self, signal, time_resolution=True):
49
+ raise NotImplementedError()
50
+
51
+ def vad_to_time(self, vad, time_length):
52
+ raise NotImplementedError()
53
+
54
+ def __call__(self, signal, time_resolution=True, reset=True):
55
+ if reset:
56
+ self.reset()
57
+
58
+ vad = self.compute_vad(signal)
59
+
60
+ if self.smoothing is not None:
61
+ vad = pb.array.interval.ArrayInterval(self.smoothing(vad))
62
+
63
+ if time_resolution:
64
+ vad = self.vad_to_time(vad, time_length=signal.shape[-1])
65
+
66
+ return vad
67
+
68
+
69
+ class EnergyVAD(VAD):
70
+ def __init__(self, sample_rate, threshold=0.3):
71
+ self.sample_rate = sample_rate
72
+ self.threshold = threshold
73
+
74
+ @staticmethod
75
+ def remove_silence(signal, vad_mask):
76
+ return signal[vad_mask == 1]
77
+
78
+ def __call__(self, example):
79
+ signal = example['audio_data'] # B T
80
+ vad_mask = self.get_vad_mask(signal)
81
+ signal = self.remove_silence(signal, vad_mask)
82
+ example['audio_data'] = signal
83
+ example['vad_mask'] = vad_mask
84
+ return example
85
+
86
+ def get_vad_mask(self, signal):
87
+ window_size = int(0.1 * self.sample_rate + 1)
88
+
89
+ half_context = (window_size - 1) // 2
90
+ std = np.std(signal, axis=-1, keepdims=True)
91
+ signal = signal - np.mean(signal, axis=-1, keepdims=True)
92
+ signal = np.abs(signal)
93
+ zeros = np.zeros(
94
+ [
95
+ signal.shape[0],
96
+ half_context,
97
+ ]
98
+ )
99
+ signal = np.concatenate([zeros, signal, zeros], axis=1)
100
+ sliding_max = np.max(pb.array.segment_axis(
101
+ signal,
102
+ length=window_size, shift=1, axis=1, end='cut'
103
+ ), axis=-1)
104
+ return sliding_max > self.threshold * std
105
+
106
+
107
+ @dataclass
108
+ class ThresholdVAD(VAD):
109
+ """
110
+ Energy-based VAD for almost clean files. Tested on WSJ clean data by Lukas
111
+ Drude.
112
+
113
+ Attributes:
114
+ threshold: Fraction of total signal standard deviation. Use 0.3 for
115
+ (almost) clean files (SNR >= 20dB, think LibriTTS) and 0.7 for less
116
+ clean files (think LibriSpeech).
117
+ window_size: Size of sliding max window.
118
+ sample_rate: Sampling rate of audio data.
119
+ smoothing: Optional callable that uses a sliding window over the raw
120
+ decision to return a smoothed VAD.
121
+ """
122
+ threshold: float = 0.3
123
+ window_size: typing.Optional[int] = None
124
+ sample_rate: int = 16_000
125
+ smoothing: typing.Optional[typing.Callable] = None
126
+
127
+ @classmethod
128
+ def finalize_dogmatic_config(cls, config):
129
+ rate = config['sample_rate']
130
+ config['smoothing'] = {
131
+ 'partial': conv_smoothing,
132
+ 'window_length': int(0.3 * rate),
133
+ 'threshold': int(0.1 * rate),
134
+ }
135
+
136
+ def __post_init__(self):
137
+ if self.window_size is None:
138
+ self.window_size = int(0.1 * self.sample_rate + 1)
139
+
140
+ assert self.window_size % 2 == 1, self.window_size
141
+
142
+ def __call__(self, example):
143
+ if isinstance(example, dict):
144
+ signal = example['audio_data']
145
+ if signal.ndim == 2 and signal.shape[0] == 1:
146
+ signal = signal[0]
147
+ elif signal.ndim == 2 and signal.shape[0] != 1:
148
+ raise ValueError(
149
+ 'Only mono signals are supported but audio_data has shape '
150
+ f'{signal.shape}'
151
+ )
152
+ vad = super().__call__(signal)
153
+ intervals = np.asarray(vad.intervals)
154
+ start, stop = zip(*intervals)
155
+ example['vad'] = vad
156
+ example['vad_start_samples'] = start
157
+ example['vad_stop_samples'] = stop
158
+ else:
159
+ example = super().__call__(example)
160
+ return example
161
+
162
+ def _detect_voice_activity(self, signal):
163
+ assert signal.ndim == 1, signal.shape
164
+
165
+ half_context = (self.window_size - 1) // 2
166
+ std = np.std(signal)
167
+ signal = signal - np.mean(signal)
168
+ assert np.min(signal) < 0
169
+ assert np.max(signal) > 0
170
+ signal = np.abs(signal)
171
+
172
+ sliding_max = np.max(pb.array.segment_axis(
173
+ np.pad(signal, (half_context, half_context), mode='constant'),
174
+ length=self.window_size, shift=1, axis=0, end='cut'
175
+ ), axis=-1)
176
+
177
+ assert sliding_max.shape == signal.shape, (
178
+ sliding_max.shape, signal.shape
179
+ )
180
+
181
+ unconstrained = sliding_max > self.threshold * std
182
+
183
+ return unconstrained
184
+
185
+ def compute_vad(self, signal, time_resolution=True):
186
+ assert time_resolution
187
+ return pb.array.interval.ArrayInterval(
188
+ self._detect_voice_activity(signal)
189
+ )
190
+
191
+ def vad_to_time(self, vad, time_length):
192
+ assert time_length == vad.shape[-1], (time_length, vad.shape[-1])
193
+ return vad
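A minimal usage sketch for `EnergyVAD` as it is called from the notebook; the signal below is synthetic and only meant to show the expected input/output structure (batched mono audio under the 'audio_data' key):

    import numpy as np
    from pvq_manipulation.helper.vad import EnergyVAD

    vad = EnergyVAD(sample_rate=16_000)               # threshold defaults to 0.3 * std
    signal = 0.01 * np.random.randn(1, 16_000)        # (batch, samples), mostly "silence"
    signal[:, 4_000:8_000] += np.random.randn(4_000)  # louder region that should be kept
    out = vad({'audio_data': signal})
    # 'audio_data' now holds only the active samples, 'vad_mask' the per-sample decision
    print(out['audio_data'].shape, out['vad_mask'].mean())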
pvq_manipulation/models/ffjord.py ADDED
@@ -0,0 +1,247 @@
1
+ import torch
2
+ import paderbox as pb
3
+
4
+ from padertorch.base import Model
5
+ from torchdiffeq import odeint_adjoint as odeint
6
+ from pvq_manipulation.helper.moving_batch_norm import MovingBatchNorm1d
7
+
8
+
9
+ class ODEBlock(torch.nn.Module):
10
+ def __init__(
11
+ self,
12
+ ode_function,
13
+ train_flag=True,
14
+ reverse=False,
15
+ ):
16
+ super().__init__()
17
+ self.time_deriv_func = ode_function
18
+ self.noise = None
19
+ self.reverse = reverse
20
+ self.train_flag = train_flag
21
+
22
+ def forward(
23
+ self,
24
+ time: torch.Tensor,
25
+ states: tuple[torch.Tensor, torch.Tensor, torch.Tensor]
26
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
27
+ """
28
+ Helper function to use a neural network for dy(t)/dt = f_theta(t, y(t))
29
+
30
+ Hutchinson’s trace estimator, as proposed in the FFJORD Paper, was adapted from:
31
+ https://github.com/RameenAbdal/StyleFlow/blob/master/module/odefunc.py
32
+
33
+ Args:
34
+ time (torch.Tensor): Scalar tensor representing time
35
+ states (Tuple[torch.Tensor, torch.Tensor, torch.Tensor]):
36
+ - z (torch.Tensor): (batch_size, feature_dim) representing the input data.
37
+ - d_log_dz_dt (torch.Tensor): (batch_size, 1) representing the log derivative.
38
+ - labels (torch.Tensor): (batch_size, num_labeled_classes)
39
+
40
+ Returns:
41
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
42
+ - dz_dt (torch.Tensor): (batch_size, feature_dim) The derivative of z w.r.t. time
43
+ - d_log_dz_dt (torch.Tensor): (batch_size, 1) The negative log derivative
44
+ - labels (torch.Tensor): (batch_size, num_labeled_classes)
45
+ """
46
+
47
+ z, d_log_dz_dt, labels = states
48
+
49
+ if self.noise is None:
50
+ self.noise = self.sample_rademacher_like(z)
51
+
52
+ with torch.enable_grad():
53
+ z.requires_grad_(True)
54
+
55
+ dz_dt = self.time_deriv_func.forward(time, z, labels)
56
+ if self.train_flag:
57
+ d_log_dz_dt = self.divergence_approx(dz_dt, z, self.noise)
58
+ else:
59
+ d_log_dz_dt = torch.zeros_like(z[:, 0]).requires_grad_(True)
60
+
61
+ labels = torch.zeros_like(labels).requires_grad_(True)
62
+ return dz_dt, -d_log_dz_dt.view(z.shape[0], 1), labels
63
+
64
+ def divergence_approx(self, f, y, e=None):
65
+ e_dzdx = torch.autograd.grad(f, y, e, create_graph=True)[0]
66
+ e_dzdx_e = e_dzdx.mul(e)
67
+
68
+ cnt = 0
69
+ while not e_dzdx_e.requires_grad and cnt < 10:
70
+ e_dzdx = torch.autograd.grad(f, y, e, create_graph=True)[0]
71
+ e_dzdx_e = e_dzdx * e
72
+ cnt += 1
73
+
74
+ approx_tr_dzdx = e_dzdx_e.sum(dim=-1)
75
+ assert approx_tr_dzdx.requires_grad, \
76
+ "(failed to add node to graph) f=%s %s, y(rgrad)=%s, e_dzdx:%s, e:%s, e_dzdx_e:%s cnt:%s" \
77
+ % (
78
+ f.size(), f.requires_grad, y.requires_grad, e_dzdx.requires_grad, e.requires_grad,
79
+ e_dzdx_e.requires_grad, cnt)
80
+ return approx_tr_dzdx
81
+
82
+ def before_odeint(self, e=None):
83
+ self.noise = e
84
+
85
+ def sample_rademacher_like(self, z):
86
+ if not self.training:
87
+ torch.manual_seed(0)
88
+ return torch.randint(low=0, high=2, size=z.shape).to(z) * 2 - 1
89
+
90
+
91
+ class FFJORD(Model):
92
+ """
93
+ This class is an implementation of the FFJORD model as proposed in
94
+ https://arxiv.org/pdf/1810.01367
95
+ """
96
+ def __init__(self, ode_function, normalize=True):
97
+ super().__init__()
98
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
99
+ self.input_dim = ode_function.input_dim
100
+ self.time_deriv_func = ODEBlock(ode_function=ode_function)
101
+ self.latent_dist = torch.distributions.MultivariateNormal(
102
+ torch.zeros(self.input_dim, device=device),
103
+ torch.eye(self.input_dim, device=device),
104
+ )
105
+ self.normalize = normalize
106
+ if self.normalize:
107
+ self.input_norm = MovingBatchNorm1d(self.input_dim, bn_lag=0)
108
+ self.output_norm = MovingBatchNorm1d(self.input_dim, bn_lag=0)
109
+
110
+ @staticmethod
111
+ def load_model(model_path, checkpoint):
112
+ model_dict = pb.io.load_yaml(model_path / "config.yaml")
113
+ model = Model.from_config(model_dict['model'])
114
+ cp = torch.load(
115
+ model_path / checkpoint,
116
+ map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
117
+ )
118
+ model_weights = cp.copy()
119
+ model.load_state_dict(model_weights['model'])
120
+ model.eval()
121
+ return model
122
+
123
+ def forward(
124
+ self,
125
+ state: tuple[torch.Tensor, torch.Tensor],
126
+ integration_times: torch.Tensor = torch.tensor([0.0, 1.0]
127
+ )
128
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
129
+ """
130
+ Integration from t_1 (data distribution) to t_0 (base distribution).
131
+ (training step)
132
+
133
+ Args:
134
+ state (Tuple[torch.Tensor, torch.Tensor]):
135
+ - z (torch.Tensor): (batch_size, feature_dim) representing the input data.
136
+ - labels (torch.Tensor): (batch_size, num_labeled_classes)
137
+
138
+ integration_times (torch.Tensor, optional): A tensor of shape (2,)
139
+ specifying the start and end times for integration. Defaults to torch.tensor([0.0, 1.0]).
140
+
141
+ Returns:
142
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
143
+ - dz_dt (torch.Tensor): A tensor of shape (batch_size, feature_dim) representing the derivative of z w.r.t. time.
144
+ - -d_log_dz_dt (torch.Tensor): (batch_size, 1) representing the negative log derivative.
145
+ - labels (torch.Tensor): (batch_size, num_labeled_classes)
146
+ """
147
+ z_1, labels = state
148
+
149
+ if z_1.dim() == 3:
150
+ z_1 = z_1.squeeze(1)
151
+
152
+ delta_logpz = torch.zeros(z_1.shape[0], 1).to(z_1.device)
153
+
154
+ if self.normalize:
155
+ z_1, delta_logpz = self.input_norm(z_1, context=labels, logpx=delta_logpz)
156
+
157
+ self.time_deriv_func.before_odeint()
158
+ state = odeint(
159
+ self.time_deriv_func, # Calculates time derivatives.
160
+ (z_1, delta_logpz, labels), # Values to update. init states
161
+ integration_times.to(z_1.device), # When to evaluate.
162
+ method='dopri5', # Runge-Kutta
163
+ atol=1e-5, # Error tolerance
164
+ rtol=1e-5, # Error tolerance
165
+ )
166
+ if self.normalize:
167
+ dz_dt, d_delta_log_dz_t = self.output_norm(state[0], context=state[2], logpx=state[1])
168
+ else:
169
+ dz_dt, d_delta_log_dz_t = state[0], state[1]
170
+
171
+ state = (dz_dt, d_delta_log_dz_t, labels)
172
+
173
+ if len(integration_times) == 2:
174
+ state = tuple(s[1] if s.shape[0] > 1 else s[0] for s in state)
175
+ return state
176
+
177
+ def sample(
178
+ self,
179
+ state: tuple[torch.Tensor, torch.Tensor],
180
+ integration_times: torch.Tensor = torch.tensor([1.0, 0.0])
181
+ ) -> tuple[torch.Tensor, torch.Tensor]:
182
+ """
183
+ This is the sampling step. Integration from t_0 (base distribution) to t_1 (data distribution).
184
+
185
+ Args:
186
+ state (Tuple[torch.Tensor, torch.Tensor]):
187
+ - z_0 (torch.Tensor): (batch_size, feature_dim) representing the initial state from the base distribution
188
+ - labels (torch.Tensor): (batch_size, num_labeled_classes)
189
+
190
+ integration_times (torch.Tensor, optional): A tensor of shape (2,) specifying the start (t_0) and end (t_1) times for integration.
191
+ Defaults to torch.tensor([1.0, 0.0])
192
+
193
+ Returns:
194
+ Tuple[torch.Tensor, torch.Tensor]:
195
+ - z_t1 (torch.Tensor): (batch_size, feature_dim) representing the sampled data at time t_1 (data distribution).
196
+ - labels (torch.Tensor): (batch_size, num_labeled_classes)
197
+ """
198
+ z_0, labels = state
199
+ delta_logpz = torch.zeros(z_0.shape[0], 1).to(z_0.device)
200
+ if self.normalize:
201
+ z_0, delta_logpz = self.output_norm(
202
+ z_0,
203
+ context=labels,
204
+ logpx=delta_logpz,
205
+ reverse=True
206
+ )
207
+
208
+ state = odeint(
209
+ self.time_deriv_func, # Calculates time derivatives.
210
+ (z_0, delta_logpz, labels), # Values to update. init states
211
+ integration_times.to(z_0.device), # When to evaluate.
212
+ method='dopri5', # Runge-Kutta
213
+ atol=1e-5, # Error tolerance
214
+ rtol=1e-5, # Error tolerance
215
+ )
216
+ if self.normalize:
217
+ dz_dt, d_delta_log_dz_t = self.input_norm(
218
+ state[0],
219
+ context=state[2],
220
+ logpx=state[1],
221
+ reverse=True
222
+ )
223
+ else:
224
+ dz_dt, d_delta_log_dz_t = state[0], state[1]
225
+ state = (dz_dt, d_delta_log_dz_t, labels)
226
+
227
+ if len(integration_times) == 2:
228
+ state = tuple(s[1] if s.shape[0] > 1 else s[0] for s in state)
229
+ return state
230
+
231
+ def example_to_device(self, examples, device):
232
+ observations = [example['observation'] for example in examples]
233
+ labels = [example['speaker_conditioning'].tolist() for example in examples if 'speaker_conditioning' in example]
234
+ observations_tensor = torch.tensor(observations, device=device, dtype=torch.float)
235
+ labels_tensor = torch.tensor(labels, device=device, dtype=torch.float) if labels else None
236
+ return observations_tensor, labels_tensor
237
+
238
+ def review(self, example, outputs):
239
+ z_t0, delta_logpz, _ = outputs
240
+ logpz_t1 = self.latent_dist.log_prob(z_t0) - delta_logpz
241
+ loss = -torch.mean(logpz_t1)
242
+ return dict(loss=loss, scalars=dict(loss=loss))
243
+
244
+ def modify_summary(self, summary):
245
+ summary = super().modify_summary(summary)
246
+ return summary
247
+
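For reference, the loss computed in `review` is the standard FFJORD objective: the ODE transports a d-vector z_1 to the base distribution while accumulating the change in log-density, so that (up to the sign and integration-direction bookkeeping in `ODEBlock.forward`)

    \log p(z_1 \mid a) = \log p_{\mathcal{N}}(z_0) - \int \operatorname{Tr}\!\left( \frac{\partial f_\theta(t, z(t), a)}{\partial z(t)} \right) \mathrm{d}t,

and training minimizes the negative mean of this log-likelihood. `divergence_approx` replaces the exact trace with Hutchinson's estimator, \operatorname{Tr}(J) \approx \mathbb{E}_{\varepsilon}[\varepsilon^\top J \varepsilon], evaluated with a single Rademacher vector \varepsilon drawn once per `odeint` call (`before_odeint` resets it).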
pvq_manipulation/models/hubert.py ADDED
@@ -0,0 +1,207 @@
1
+ import os
2
+ from pathlib import Path
3
+ from contextlib import nullcontext
4
+ import typing as tp
5
+ from typing import List, Tuple, Optional
6
+
7
+ import einops
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn.functional as F
11
+ import torchaudio
12
+
13
+ import padertorch as pt
14
+ from padertorch.contrib.je.modules.conv_utils import (
15
+ compute_conv_output_sequence_lengths
16
+ )
17
+ from padertorch.utils import to_numpy
18
+ from transformers.models.hubert.modeling_hubert import HubertModel
19
+
20
+ # See https://ieeexplore.ieee.org/abstract/document/9814838, Fig. 2
21
+ PR_BASE_LAYER = 11
22
+ PR_LARGE_LAYER = 22
23
+ SID_BASE_LAYER = 4
24
+ SID_LARGE_LAYER = 6
25
+
26
+
27
+ def tuple_to_int(sequence) -> list:
28
+ return list(map(lambda t: t[0], sequence))
29
+
30
+
31
+ class HubertExtractor(pt.Module):
32
+ """Extract HuBERT features from raw waveform.
33
+
34
+ Args:
35
+ model_name (str): Name of the pretrained HuBERT model on huggingface.co.
36
+ Defaults to "facebook/hubert-large-ll60k".
37
+ layer (int): Index of the layer to extract features from. Defaults to
38
+ 22.
39
+ freeze (bool): If True, freeze the weights of the encoder
40
+ (i.e., no finetuning of Transformer layers). Defaults to True.
41
+ """
42
+
43
+ def __init__(
44
+ self,
45
+ model_name: str = "facebook/hubert-large-ll60k",
46
+ layer: tp.Union[int, str] = PR_LARGE_LAYER,
47
+ freeze: bool = True,
48
+ detach: bool = False,
49
+ device: str = "cpu",
50
+ backend: str = "torchaudio",
51
+ storage_dir: str = None,
52
+ ):
53
+ super().__init__()
54
+
55
+ if not freeze and detach:
56
+ raise ValueError(
57
+ 'detach=True only supported if freeze=True\n'
58
+ f'Got: freeze={freeze}, detach={detach}'
59
+ )
60
+ if backend == "torchaudio":
61
+ bundle = getattr(torchaudio.pipelines, model_name)
62
+ self.model = bundle.get_model(dl_kwargs={'model_dir': storage_dir}).eval().to(device)
63
+ self.sampling_rate = bundle.sample_rate
64
+ else:
65
+ raise ValueError(f'Unknown backend: {backend}')
66
+ self.backend = backend
67
+
68
+ if freeze:
69
+ for param in self.model.parameters():
70
+ param.requires_grad = False
71
+ else:
72
+ # Always freeze feature extractor and feature projection layers
73
+ for param in self.model.feature_extractor.parameters():
74
+ param.requires_grad = False
75
+ for param in self.model.feature_projection.parameters():
76
+ param.requires_grad = False
77
+
78
+ self.layer = layer
79
+ self.freeze = freeze
80
+ self.detach = detach
81
+
82
+ if self.layer == 'all':
83
+ self.weights = torch.nn.Parameter(
84
+ torch.ones(24), requires_grad=True
85
+ )
86
+
87
+ @property
88
+ def cache_dir(self):
89
+ return Path(os.environ['STORAGE_ROOT']) / 'huggingface' / 'hub'
90
+
91
+ @property
92
+ def context(self):
93
+ if self.detach:
94
+ return torch.no_grad()
95
+ else:
96
+ return nullcontext()
97
+
98
+ def compute_output_lengths(
99
+ self, input_lengths: Optional[List[int]]
100
+ ) -> Optional[List[int]]:
101
+ """Compute the number of time frames for each batch entry.
102
+
103
+ Args:
104
+ input_lengths: List with number of samples per batch entry.
105
+
106
+ Returns:
107
+ List with number of time frames per batch entry.
108
+ """
109
+ if input_lengths is None:
110
+ return input_lengths
111
+ output_lengths = np.asarray(input_lengths) + self.window_size - 1
112
+ for kernel_size, dilation, stride in zip(
113
+ self.kernel_sizes, self.dilations, self.strides,
114
+ ):
115
+ output_lengths = compute_conv_output_sequence_lengths(
116
+ output_lengths,
117
+ kernel_size=kernel_size,
118
+ dilation=dilation,
119
+ pad_type=None,
120
+ stride=stride,
121
+ )
122
+ return output_lengths.tolist()
123
+
124
+ def forward(
125
+ self,
126
+ time_signal: torch.Tensor,
127
+ sampling_rate: int,
128
+ sequence_lengths: Optional[List[int]] = None,
129
+ extract_features: bool = False,
130
+ other_inputs: Optional[dict] = None,
131
+ ) -> Tuple[torch.Tensor, Optional[List[int]]]:
132
+ """Extract HuBERT features from raw waveform.
133
+
134
+ Args:
135
+ time_signal: Time signal of shape (batch, 1, time) or (batch, time)
136
+ sampled at 16 kHz.
137
+ sequence_lengths: List with number of samples per batch entry.
138
+
139
+ Returns:
140
+ x (torch.Tensor): HuBERT features of shape
141
+ (batch, D, time frames).
142
+ sequence_lengths (List[int]): List with number of time frames per
143
+ batch entry.
144
+ """
145
+ del other_inputs
146
+
147
+ if time_signal.ndim == 3:
148
+ time_signal = einops.rearrange(time_signal, 'b c t -> (b c) t')
149
+
150
+ time_signal = torchaudio.functional.resample(
151
+ time_signal, sampling_rate, self.sampling_rate
152
+ )
153
+ if sequence_lengths is not None:
154
+ if isinstance(sequence_lengths, (list, tuple)):
155
+ sequence_lengths = torch.tensor(sequence_lengths).long() \
156
+ .to(time_signal.device)
157
+ sequence_lengths = (
158
+ sequence_lengths / sampling_rate * self.sampling_rate
159
+ ).long()
160
+
161
+ if self.freeze or self.detach:
162
+ self.model.eval()
163
+ with self.context:
164
+ if self.backend == "torchaudio":
165
+ self.model: torchaudio.models.Wav2Vec2Model
166
+ x, sequence_lengths = self.model.extract_features(
167
+ time_signal, lengths=sequence_lengths,
168
+ num_layers=self.layer,
169
+ )
170
+ if isinstance(self.layer, int):
171
+ x = x[-1].transpose(1, 2)
172
+ else:
173
+ raise NotImplementedError(self.layer)
174
+ return x, sequence_lengths
175
+
176
+ self.model: HubertModel
177
+ n_pad = self.window_size - 1
178
+ time_signal = F.pad(time_signal, (0, n_pad), value=0)
179
+ if extract_features:
180
+ features = self.model.feature_extractor(time_signal.float()) \
181
+ .transpose(1, 2)
182
+ x = self.model.feature_projection(features).transpose(1, 2)
183
+ else:
184
+ outputs = self.model(
185
+ time_signal.float(), output_hidden_states=True
186
+ )
187
+ if isinstance(self.layer, int):
188
+ x = outputs.hidden_states[self.layer].transpose(1, 2)
189
+ if self.detach:
190
+ x = x.detach()
191
+ elif self.layer == 'all':
192
+ hidden_states = []
193
+ for _, hidden_state in enumerate(outputs.hidden_states):
194
+ x = hidden_state.transpose(1, 2)
195
+ if self.detach:
196
+ x = x.detach()
197
+ hidden_states.append(x)
198
+ hidden_states = torch.stack(hidden_states) # (L, B, D, T)
199
+ x = (hidden_states * self.weights[:, None, None, None]) \
200
+ .sum(dim=0)
201
+ else:
202
+ raise ValueError(f'Unknown layer: {self.layer}')
203
+
204
+ sequence_lengths = to_numpy(sequence_lengths)
205
+ sequence_lengths = self.compute_output_lengths(sequence_lengths)
206
+
207
+ return x.unsqueeze(1), sequence_lengths
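A short sketch of how the extractor is called (mirroring the notebook); the waveform is a dummy and the storage directory is a placeholder path:

    import torch
    from pvq_manipulation.models.hubert import HubertExtractor, SID_LARGE_LAYER

    extractor = HubertExtractor(
        model_name="HUBERT_LARGE",      # torchaudio pipeline bundle
        layer=SID_LARGE_LAYER,          # layer 6: speaker-ID features (see constants above)
        backend="torchaudio",
        device="cpu",
        storage_dir="/path/to/hubert",  # placeholder download directory
    )
    audio = torch.randn(1, 24_000)      # dummy mono waveform, 1 s at 24 kHz
    features, out_lengths = extractor(
        audio, 24_000, sequence_lengths=torch.tensor([24_000])
    )
    # features: (batch, D, time frames); the input is resampled to 16 kHz internally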
pvq_manipulation/models/ode_functions.py ADDED
@@ -0,0 +1,96 @@
1
+ """
2
+ Implementation of Δz = f(t, z, labels)
3
+ f() is a neural network with the architecture defined in StyleFlow
4
+ StyleFlow: Attribute-conditioned Exploration of StyleGAN-Generated Images using Conditional Continuous Normalizing Flows
5
+ """
6
+ import torch
7
+
8
+
9
+ class CNFNN(torch.nn.Module):
10
+ def __init__(
11
+ self,
12
+ input_dim,
13
+ condition_dim,
14
+ hidden_channels,
15
+ ):
16
+ super().__init__()
17
+ self.layers = torch.nn.ModuleList()
18
+ hidden_dims = hidden_channels + [input_dim]
19
+ self.input_dim = input_dim
20
+
21
+ for idx, hidden_dim in enumerate(hidden_dims):
22
+ self.layers.append(CNFBlock(
23
+ input_dim=input_dim,
24
+ condition_dim=condition_dim,
25
+ output_dim=hidden_dim,
26
+ output_layer=False if idx < len(hidden_dims) - 1 else True,
27
+ ))
28
+ input_dim = hidden_dim
29
+
30
+ def forward(self, t, z, labels):
31
+ """
32
+ This function computes: Δz = f(t, z, labels)
33
+
34
+ Args:
35
+ t (torch.Tensor): () Time step of the ODE
36
+ z (torch.Tensor): (Batch_size, Input_dim) Intermediate value
37
+ labels (torch.Tensor): (Batch_size, condition_dim) Speaker attributes
38
+
39
+ Returns:
40
+ Δz (torch.Tensor): : (Batch_size, Input_dim) Computed delta
41
+ """
42
+ for layer in self.layers:
43
+ z = layer(t, z, labels)
44
+ return z
45
+
46
+
47
+ class CNFBlock(torch.nn.Module):
48
+ def __init__(
49
+ self,
50
+ input_dim,
51
+ output_dim,
52
+ condition_dim,
53
+ output_layer,
54
+ ):
55
+ super().__init__()
56
+ self._layer = torch.nn.Linear(input_dim, output_dim)
57
+ self._hyper_bias = torch.nn.Linear(
58
+ 1 + condition_dim,
59
+ output_dim,
60
+ bias=False
61
+ )
62
+ self._hyper_gate = torch.nn.Linear(
63
+ 1 + condition_dim,
64
+ output_dim
65
+ )
66
+ self.output_layer = output_layer
67
+
68
+ def forward(self, t, z, labels):
69
+ """
70
+ Args:
71
+ t (torch.Tensor): () Time step of the ODE
72
+ z (torch.Tensor): (Batch_size, Input_dim) Intermediate value
73
+ labels (torch.Tensor): (Batch_size, condition_dim) Speaker attributes
74
+
75
+ Returns:
76
+ z (torch.Tensor): : (Batch_size, Output_dim) Intermediate value
77
+ """
78
+ if labels.dim() == 1:
79
+ labels = labels[:, None]
80
+ elif labels.dim() == 3:
81
+ labels = labels.squeeze(1)
82
+
83
+ tz_cat = torch.cat((t.expand(z.shape[0], 1), labels), dim=1)
84
+
85
+ gate = torch.sigmoid(self._hyper_gate(tz_cat))
86
+ bias = self._hyper_bias(tz_cat)
87
+
88
+ if z.dim() == 3:
89
+ gate = gate.unsqueeze(1)
90
+ bias = bias.unsqueeze(1)
91
+
92
+ z = self._layer(z) * gate + bias
93
+
94
+ if not self.output_layer:
95
+ z = torch.tanh(z)
96
+ return z
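How the pieces added in this commit fit together, as a sketch: `CNFNN` implements f_\theta(t, z, a), `ODEBlock` wraps it for the ODE solver, and `FFJORD` integrates it in both directions. The dimensions below are illustrative assumptions, not values taken from the shipped configs:

    import torch
    from pvq_manipulation.models.ffjord import FFJORD
    from pvq_manipulation.models.ode_functions import CNFNN

    ode_fn = CNFNN(input_dim=512, condition_dim=7, hidden_channels=[512, 512])  # dims assumed
    flow = FFJORD(ode_function=ode_fn).eval()

    d_vector = torch.randn(1, 512)   # dummy speaker embedding (batch, feature_dim)
    labels = torch.rand(1, 7)        # dummy PVQ labels scaled to [0, 1]

    z0, delta_logp, _ = flow.forward((d_vector, labels))   # data -> base distribution
    d_back, _, _ = flow.sample((z0, labels))                # base -> data distribution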
pvq_manipulation/models/vits.py ADDED
@@ -0,0 +1,742 @@
1
+ """
2
+ This is a wrapper for the TTS VITS model.
3
+ TTS.tts.models.vits
4
+ https://github.com/coqui-ai/TTS/blob/dev/TTS/tts/models/vits.py
5
+ """
6
+ import os
7
+ import numpy as np
8
+ import paderbox as pb
9
+ import padertorch as pt
10
+ import torch
11
+
12
+ from coqpit import Coqpit
13
+ from padertorch.ops._stft import STFT
14
+ from pathlib import Path
15
+ from pvq_manipulation.helper.utils import VitsAudioConfig_NT, VitsConfig_NT, load_audio
16
+
17
+ from torch.utils.data import DataLoader
18
+ from torch.cuda.amp.autocast_mode import autocast
19
+ from TTS.tts.configs.shared_configs import CharactersConfig
20
+ from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
21
+ from TTS.tts.models.vits import Vits, VitsArgs, VitsDataset, spec_to_mel, wav_to_spec
22
+ from TTS.tts.utils.languages import LanguageManager
23
+ from TTS.tts.utils.speakers import SpeakerManager
24
+ from TTS.tts.utils.synthesis import embedding_to_torch, numpy_to_torch
25
+ from TTS.tts.utils.text.tokenizer import TTSTokenizer
26
+ from TTS.tts.utils.helpers import generate_path, rand_segments, segment, sequence_mask
27
+ from TTS.utils.audio import AudioProcessor
28
+ from TTS.vocoder.models.hifigan_generator import HifiganGenerator
29
+ from trainer.trainer import to_cuda
30
+ from typing import Dict, List, Union
31
+
32
+
33
+ STORAGE_ROOT = Path(os.getenv('STORAGE_ROOT')).expanduser()
34
+
35
+
36
+ class Vits_NT(Vits):
37
+ def __init__(
38
+ self,
39
+ config: Coqpit,
40
+ ap: "AudioProcessor" = None,
41
+ tokenizer: "TTSTokenizer" = None,
42
+ speaker_manager: SpeakerManager = None,
43
+ language_manager: LanguageManager = None,
44
+ sample_rate: int = None,
45
+ ):
46
+ super().__init__(
47
+ config,
48
+ ap,
49
+ tokenizer,
50
+ speaker_manager,
51
+ language_manager
52
+ )
53
+ self.sample_rate = sample_rate
54
+ self.embedded_speaker_dim = self.args.d_vector_dim
55
+ self.posterior_encoder = PosteriorEncoder(
56
+ self.args.out_channels,
57
+ self.args.hidden_channels,
58
+ self.args.hidden_channels,
59
+ kernel_size=self.args.kernel_size_posterior_encoder,
60
+ dilation_rate=self.args.dilation_rate_posterior_encoder,
61
+ num_layers=self.args.num_layers_posterior_encoder,
62
+ cond_channels=self.embedded_speaker_dim,
63
+ )
64
+
65
+ self.flow = ResidualCouplingBlocks(
66
+ self.args.hidden_channels,
67
+ self.args.hidden_channels,
68
+ kernel_size=self.args.kernel_size_flow,
69
+ dilation_rate=self.args.dilation_rate_flow,
70
+ num_layers=self.args.num_layers_flow,
71
+ cond_channels=self.embedded_speaker_dim,
72
+ )
73
+
74
+ self.text_encoder = TextEncoder(
75
+ self.args.num_chars,
76
+ self.args.hidden_channels,
77
+ self.args.hidden_channels,
78
+ self.args.hidden_channels_ffn_text_encoder,
79
+ self.args.num_heads_text_encoder,
80
+ self.args.num_layers_text_encoder,
81
+ self.args.kernel_size_text_encoder,
82
+ self.args.dropout_p_text_encoder,
83
+ language_emb_dim=self.embedded_language_dim,
84
+ )
85
+ self.waveform_decoder = HifiganGenerator(
86
+ self.args.hidden_channels,
87
+ 1,
88
+ self.args.resblock_type_decoder,
89
+ self.args.resblock_dilation_sizes_decoder,
90
+ self.args.resblock_kernel_sizes_decoder,
91
+ self.args.upsample_kernel_sizes_decoder,
92
+ self.args.upsample_initial_channel_decoder,
93
+ self.args.upsample_rates_decoder,
94
+ inference_padding=0,
95
+ cond_channels=self.embedded_speaker_dim if self.config.gan_speaker_conditioning else 0,
96
+ conv_pre_weight_norm=False,
97
+ conv_post_weight_norm=False,
98
+ conv_post_bias=False,
99
+ )
100
+ self.speaker_manager = self.speaker_manager
101
+ self.speaker_encoder = self.speaker_manager
102
+
103
+ self.speaker_manager.eval()
104
+
105
+ self.epoch = 0
106
+ self.num_epochs = config['epochs']
107
+ self.lr_lambda = 0
108
+ self.config_solver = config['CONFIG_SOLVER']
109
+ self.config = config
110
+
111
+ self.stft = STFT(
112
+ size=self.config.audio.win_length,
113
+ shift=self.config.audio.hop_length,
114
+ window_length=self.config.audio.win_length,
115
+ fading=self.config.audio.fading,
116
+ window=self.config.audio.window,
117
+ pad=self.config.audio.pad
118
+ )
119
+
120
+ def get_spectogram_nt(self, wav):
121
+ """
122
+ Extracts spectrogram from audio
123
+ Args:
124
+ wav (torch.Tensor): (Batch_size, Num_samples)
125
+ Returns:
126
+ spectrogram (torch.Tensor): (Batch_size, Frequency_bins, Time) spectrogram
127
+ """
128
+ wav = wav.squeeze(1)
129
+ stft_signal = self.stft(wav)
130
+ stft_signal = torch.einsum('btf-> bft', stft_signal)
131
+ spectrogram = stft_signal.real ** 2 + stft_signal.imag ** 2
132
+ spectrogram = torch.sqrt(spectrogram + 1e-6)
133
+ return spectrogram
134
+
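# A minimal sketch of the same magnitude computation, using torch.stft as a
# stand-in for the padertorch STFT configured above (window and padding
# details here are assumptions, not the exact settings from the config):
import torch

def magnitude_spectrogram(wav, n_fft=1024, hop_length=256):
    """wav: (B, T) float tensor -> (B, F, T_frames) magnitude spectrogram."""
    stft = torch.stft(
        wav, n_fft=n_fft, hop_length=hop_length, win_length=n_fft,
        window=torch.hann_window(n_fft), return_complex=True,
    )
    return torch.sqrt(stft.real ** 2 + stft.imag ** 2 + 1e-6)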
135
+ def get_aux_input_from_test_sentences(self, sentence_info):
136
+ """
137
+ Get aux input for the inference step from test sentences
138
+ Args:
139
+ sentence_info (dict): Expected keys:
140
+ - "d_vector_storage_root" (str)
141
+ - "d_vector" (torch.Tensor)
142
+ - "d_vector_man" (torch.Tensor) (optional)
143
+ Returns:
144
+ aux_input (dict): aux input for the inference step
145
+ """
146
+ if 'd_vector' not in sentence_info.keys():
147
+ d_vector_file = sentence_info['d_vector_storage_root']
148
+ d_vector = torch.load(d_vector_file)
149
+ return {"d_vector": d_vector, **sentence_info}
150
+ else:
151
+ return sentence_info
152
+
153
+ @staticmethod
154
+ def init_from_config(
155
+ config: "VitsConfig",
156
+ samples=None,
157
+ verbose=True
158
+ ):
159
+ """
160
+ Initialize the model from a config.
161
+ Args:
162
+ config (VitsConfig): Model config.
163
+ samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
164
+ Defaults to None.
165
+ Returns:
166
+ model (Vits): Initialized model.
167
+ """
168
+ upsample_rate = torch.prod(torch.as_tensor(config.model_args.upsample_rates_decoder)).item()
169
+ assert (upsample_rate == config.audio.hop_length), f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {config.audio.hop_length}"
170
+ ap = AudioProcessor.init_from_config(config, verbose=verbose)
171
+ tokenizer, new_config = TTSTokenizer.init_from_config(config)
172
+ language_manager = LanguageManager.init_from_config(config)
173
+ speaker_manager = pt.Module.from_storage_dir(
174
+ config['d_vector_model_file'],
175
+ checkpoint_name='ckpt_latest.pth',
176
+ consider_mpi=False,
177
+ config_name='config.json',
178
+ )
179
+ speaker_manager.num_speakers = config['num_speakers']
180
+ for param in speaker_manager.parameters():
181
+ param.requires_grad = False
182
+
183
+ return Vits_NT(
184
+ new_config,
185
+ ap,
186
+ tokenizer,
187
+ speaker_manager=speaker_manager,
188
+ language_manager=language_manager,
189
+ sample_rate=config['sample_rate'],
190
+ )
191
+
192
+ @torch.no_grad()
193
+ def inference(self, x, aux_input=None):
194
+ """
195
+ Note:
196
+ To run in batch mode, provide `x_lengths`; otherwise the model assumes a batch size of 1.
197
+
198
+ Args:
199
+ x (torch.Tensor): (batch_size, T_seq) or (T_seq) Input character sequence IDs
200
+ aux_input (dict): Expected keys:
201
+ - d_vector (torch.Tensor): (batch_size, Feature_dim) speaker_embedding
202
+ - x_lengths: (torch.Tensor): (batch_size) length of each text token
203
+
204
+ Returns:
205
+ - model_outputs (torch.Tensor): (batch_size, T_wav) Synthesized waveform
206
+ """
207
+ speaker_embedding = aux_input['d_vector'].detach()[:, :, None]
208
+ if aux_input['d_vector_man'] is not None:
209
+ speaker_embedding_man = aux_input['d_vector_man'].detach()[:, :, None]
210
+ else:
211
+ speaker_embedding_man = speaker_embedding
212
+ aux_input['tokens'] = x.clone()
213
+ x_lengths = self._set_x_lengths(x, aux_input)
214
+ x, m_p, logs_p, x_mask = self.text_encoder(
215
+ x,
216
+ x_lengths,
217
+ lang_emb=None
218
+ )
219
+ logw = self.duration_predictor(
220
+ x,
221
+ x_mask,
222
+ g=speaker_embedding,
223
+ lang_emb=None,
224
+ )
225
+
226
+ w = torch.exp(logw) * x_mask * self.length_scale
227
+ w_ceil = torch.ceil(w)
228
+ y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
229
+ y_mask = sequence_mask(y_lengths, None).to(x_mask.dtype).unsqueeze(1) # [B, 1, T_dec]
230
+
231
+ attn_mask = x_mask * y_mask.transpose(1, 2)
232
+ attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1).transpose(1, 2))
233
+ m_p = torch.matmul(attn.transpose(1, 2), m_p.transpose(1, 2)).transpose(1, 2)
234
+ logs_p = torch.matmul(attn.transpose(1, 2), logs_p.transpose(1, 2)).transpose(1, 2)
235
+
236
+ z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale
237
+
238
+ z = self.flow(z_p, y_mask, g=speaker_embedding_man, reverse=True)
239
+ z, _, _, y_mask = self.upsampling_z(
240
+ z,
241
+ y_lengths=y_lengths,
242
+ y_mask=y_mask
243
+ )
244
+ o = self.waveform_decoder(
245
+ (z * y_mask)[:, :, : self.max_inference_len],
246
+ g=speaker_embedding_man if self.config.gan_speaker_conditioning else None
247
+ )
248
+ return o
249
+
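# Sketch of a direct inference call on a loaded model (the variable name
# `tts_model`, the token ids and the 512-dimensional d-vector are
# placeholders; the real embedding size is given by args.d_vector_dim):
import torch

tokens = torch.randint(low=1, high=50, size=(1, 20), dtype=torch.long)  # (1, T_seq)
aux_input = {
    "x_lengths": torch.tensor([tokens.shape[1]]),
    "d_vector": torch.randn(1, 512),     # speaker embedding for duration/prior
    "d_vector_man": None,                # or a manipulated embedding of the same shape
}
wav = tts_model.inference(tokens, aux_input=aux_input)  # (1, 1, T_wav)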
250
+ def forward(self, x, x_lengths, y, y_lengths, aux_input, inference=False):
251
+ """
252
+ Forward pass of the model.
253
+
254
+ Args:
255
+ x (torch.tensor): (Batch, T_seq) Input character sequence IDs
256
+ x_lengths (torch.tensor): (Batch) Input character sequence lengths.
257
+ y (torch.tensor): (Batch_size, Frequency_bins, Time) Input spectrograms.
258
+ y_lengths (torch.tensor): (Batch) Input spectrogram lengths.
259
+ aux_input (dict, optional): Expected keys:
260
+ - d_vector (torch.Tensor): (batch_size, Feature_dim) speaker_embedding
261
+ - waveform: (torch.Tensor): (Batch_size, Num_samples) Target waveform
262
+ Returns:
263
+ Dict: model outputs keyed by the output name.
264
+ """
265
+ outputs = {}
266
+ speaker_embedding = aux_input['d_vector'].detach()[:, :, None]
267
+ x, m_p, logs_p, x_mask = self.text_encoder(
268
+ x,
269
+ x_lengths,
270
+ lang_emb=None
271
+ )
272
+ z, m_q, logs_q, y_mask = self.posterior_encoder(
273
+ y,
274
+ y_lengths,
275
+ g=speaker_embedding,
276
+ )
277
+ z_p = self.flow(z, y_mask, g=speaker_embedding)
278
+ outputs, attn = self.forward_mas(
279
+ outputs,
280
+ z_p,
281
+ m_p,
282
+ logs_p,
283
+ x,
284
+ x_mask,
285
+ y_mask,
286
+ g=speaker_embedding,
287
+ lang_emb=None,
288
+ )
289
+ m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
290
+ logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])
291
+
292
+ z_slice, slice_ids = rand_segments(
293
+ z,
294
+ y_lengths,
295
+ self.spec_segment_size,
296
+ let_short_samples=True,
297
+ pad_short=True
298
+ )
299
+ z_slice, spec_segment_size, slice_ids, _ = self.upsampling_z(
300
+ z_slice,
301
+ slice_ids=slice_ids,
302
+ )
303
+
304
+ wav_seg = segment(
305
+ aux_input['waveform'],
306
+ slice_ids * self.config.audio.hop_length,
307
+ spec_segment_size * self.config.audio.hop_length,
308
+ pad_short=True,
309
+ )
310
+ o = self.waveform_decoder(
311
+ z_slice,
312
+ g=speaker_embedding if self.config.gan_speaker_conditioning else None
313
+ )
314
+
315
+ if self.args.use_speaker_encoder_as_loss and self.speaker_manager.encoder is not None:
316
+ wavs_batch = torch.cat((wav_seg, o), dim=0)
317
+ if self.audio_transform is not None:
318
+ wavs_batch = self.audio_transform(wavs_batch)
319
+ with torch.no_grad():
320
+ pred_embs = self.speaker_manager.encoder.forward(wavs_batch, l2_norm=True)
321
+ gt_spk_emb, syn_spk_emb = torch.chunk(pred_embs, 2, dim=0)
322
+ else:
323
+ gt_spk_emb, syn_spk_emb = None, None
324
+
325
+ outputs.update(
326
+ {
327
+ "model_outputs": o,
328
+ "alignments": attn.squeeze(1),
329
+ "m_p": m_p,
330
+ "logs_p": logs_p,
331
+ "z": z,
332
+ "z_p": z_p,
333
+ "m_q": m_q,
334
+ "logs_q": logs_q,
335
+ "waveform_seg": wav_seg,
336
+ "gt_spk_emb": gt_spk_emb,
337
+ "syn_spk_emb": syn_spk_emb,
338
+ "slice_ids": slice_ids,
339
+ "z_slice": z_slice,
340
+ "speaker_embedding": speaker_embedding,
341
+ }
342
+ )
343
+ return outputs
344
+
345
+ @staticmethod
346
+ def load_model(model_path, checkpoint):
347
+ """
348
+ Load model from checkpoint
349
+
350
+ Args:
351
+ model_path (pathlib.Path): path to the model storage directory
352
+ checkpoint (str): checkpoint name
353
+
354
+ Returns:
355
+ model (pvq_manipulation.models.vits.Vits_NT): model
356
+ """
357
+ config = pb.io.load_json(model_path / "config.json")
358
+ model_args = VitsArgs(**config['model_args'])
359
+ audio_config = VitsAudioConfig_NT(**config['audio'])
360
+ characters_config = CharactersConfig(**config['characters'])
361
+ del config['audio']
362
+ del config['characters']
363
+ del config['model_args']
364
+
365
+ config = VitsConfig_NT(
366
+ model_args=model_args,
367
+ audio=audio_config,
368
+ characters=characters_config,
369
+ **config,
370
+ )
371
+ model = Vits_NT.init_from_config(config)
372
+ cp = torch.load(
373
+ model_path / checkpoint,
374
+ map_location=torch.device('cpu')
375
+ )
376
+ model_weights = cp['model'].copy()
377
+ model.load_state_dict(model_weights, strict=False)
378
+ model.eval()
379
+ return model
380
+
381
+ def synthesize_from_example(self, s_info):
382
+ """
383
+ Synthesize voice from example
384
+
385
+ Args:
386
+ s_info (dict): Expected keys:
387
+ - "speaker_id" (str),
388
+ - "example_id" (str),
389
+ - "audio_path" (str),
390
+ - "d_vector_storage_root" (str),
391
+ - "text" (str) specifying the text to synthesize
392
+ Returns:
393
+ wav (np.ndarray): synthesized waveform
394
+ """
395
+ aux_inputs = self.get_aux_input_from_test_sentences(s_info)
396
+ use_cuda = "cuda" in str(next(self.parameters()).device)
397
+
398
+ device = next(self.parameters()).device
399
+ if use_cuda:
400
+ device = "cuda"
401
+
402
+ text_inputs = np.asarray(
403
+ self.tokenizer.text_to_ids(aux_inputs["text"], language=None),
404
+ dtype=np.int32,
405
+ )
406
+ d_vector = embedding_to_torch(aux_inputs["d_vector"], device=device)
407
+
408
+ if "d_vector_man" in aux_inputs.keys():
409
+ d_vector_man = embedding_to_torch(aux_inputs["d_vector_man"], device=device)
410
+
411
+ text_inputs = numpy_to_torch(text_inputs, torch.long, device=device)
412
+ text_inputs = text_inputs.unsqueeze(0)
413
+
414
+ wav = self.inference(
415
+ text_inputs,
416
+ aux_input={
417
+ "x_lengths": torch.tensor(
418
+ text_inputs.shape[1:2]
419
+ ).to(text_inputs.device),
420
+ "d_vector": d_vector,
421
+ "d_vector_man": d_vector_man if "d_vector_man" in aux_inputs.keys() else None
422
+ }
423
+ )[0].data.cpu().numpy().squeeze()
424
+ return wav
425
+
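# Example call (the variable name `tts_model` and all paths are placeholders;
# the d-vector is loaded from `d_vector_storage_root` when no "d_vector"
# tensor is passed directly):
wav = tts_model.synthesize_from_example({
    "text": "An example sentence to synthesize.",
    "d_vector_storage_root": "/path/to/d_vectors/speaker_A/utt_0001.pth",
})
# `wav` is a 1-D numpy array at the model's sample rate.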
426
+ def format_batch_on_device(self, batch):
427
+ """Format batch on device"""
428
+ ac = self.config.audio
429
+
430
+ batch['waveform'] = to_cuda(batch['waveform'])
431
+ wav = batch["waveform"]
432
+
433
+ batch['spec'] = self.get_spectogram_nt(wav)
434
+
435
+ if self.args.encoder_sample_rate:
436
+ spec_mel = wav_to_spec(batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False)
437
+ if spec_mel.size(2) > int(batch["spec"].size(2) * self.interpolate_factor):
438
+ spec_mel = spec_mel[:, :, : int(batch["spec"].size(2) * self.interpolate_factor)]
439
+ else:
440
+ batch["spec"] = batch["spec"][:, :, : int(spec_mel.size(2) / self.interpolate_factor)]
441
+ else:
442
+ spec_mel = batch["spec"]
443
+
444
+ batch["mel"] = spec_to_mel(
445
+ spec=spec_mel,
446
+ n_fft=ac.fft_size,
447
+ num_mels=ac.num_mels,
448
+ sample_rate=ac.sample_rate,
449
+ fmin=ac.mel_fmin,
450
+ fmax=ac.mel_fmax,
451
+ )
452
+
453
+ if self.args.encoder_sample_rate:
454
+ assert batch["spec"].shape[2] == int(
455
+ batch["mel"].shape[2] / self.interpolate_factor
456
+ ), f"{batch['spec'].shape[2]}, {batch['mel'].shape[2]}"
457
+ else:
458
+ assert batch["spec"].shape[2] == batch["mel"].shape[2], f"{batch['spec'].shape[2]}, {batch['mel'].shape[2]}"
459
+
460
+ batch["spec_lens"] = (batch["spec"].shape[2] * batch["waveform_rel_lens"]).int()
461
+ batch["mel_lens"] = (batch["mel"].shape[2] * batch["waveform_rel_lens"]).int()
462
+
463
+ if self.args.encoder_sample_rate:
464
+ assert (batch["spec_lens"] - (batch["mel_lens"] / self.interpolate_factor).int()).sum() == 0
465
+ else:
466
+ assert (batch["spec_lens"] - batch["mel_lens"]).sum() == 0
467
+
468
+ batch["spec"] = batch["spec"] * sequence_mask(batch["spec_lens"]).unsqueeze(1)
469
+ batch["mel"] = batch["mel"] * sequence_mask(batch["mel_lens"]).unsqueeze(1)
470
+ return batch
471
+
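# Rough frame-count bookkeeping behind spec_lens / mel_lens (sample rate and
# hop length here are illustrative, not the values from the config):
sample_rate, hop_length = 22050, 256
num_samples = 2 * sample_rate                  # a 2-second clip
approx_frames = num_samples // hop_length      # ~172 STFT frames
# spec_lens scales this frame count by each example's relative waveform length.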
472
+ def train_step(
473
+ self,
474
+ batch: dict,
475
+ criterion: torch.nn.Module,
476
+ optimizer_idx: int,
477
+ ):
478
+ """
479
+ Perform a single training step. Run the model forward pass and compute losses.
480
+
481
+ Args:
482
+ batch (Dict): Input tensors.
483
+ criterion (nn.Module): Loss layer designed for the model.
484
+ optimizer_idx (int): Index of the optimizer to use. 0 runs the discriminator update and 1 the generator update.
485
+
486
+ Returns:
487
+ Tuple[Dict, Dict]: Model outputs and computed losses.
488
+ """
489
+ if optimizer_idx == 0:
490
+ # generator forward pass; outputs are cached for the generator update (optimizer_idx == 1)
491
+ outputs = self.forward(
492
+ batch["tokens"],
493
+ batch["token_lens"],
494
+ batch["spec"],
495
+ batch["spec_lens"],
496
+ aux_input={
497
+ **batch,
498
+ },
499
+ )
500
+ # cache tensors for the generator pass
501
+ self.model_outputs_cache = outputs # pylint: disable=attribute-defined-outside-init
502
+ scores_disc_fake, _, scores_disc_real, _ = self.disc(
503
+ outputs["model_outputs"].detach(),
504
+ outputs["waveform_seg"]
505
+ )
506
+ # compute loss
507
+ with autocast(enabled=False): # use float32 for the criterion
508
+ loss_dict = criterion[optimizer_idx](
509
+ scores_disc_real,
510
+ scores_disc_fake,
511
+ )
512
+ return outputs, loss_dict
513
+
514
+ if optimizer_idx == 1:
515
+ # compute melspec segment
516
+ with autocast(enabled=False):
517
+ if self.args.encoder_sample_rate:
518
+ spec_segment_size = self.spec_segment_size * int(self.interpolate_factor)
519
+ else:
520
+ spec_segment_size = self.spec_segment_size
521
+ mel_slice = segment(
522
+ batch["mel"].float(),
523
+ self.model_outputs_cache["slice_ids"],
524
+ spec_segment_size,
525
+ pad_short=True
526
+ )
527
+
528
+ spec = self.get_spectogram_nt(
529
+ self.model_outputs_cache["model_outputs"].float(),
530
+ )
531
+ mel_slice_hat = spec_to_mel(
532
+ spec=spec,
533
+ n_fft=self.config.audio.fft_size,
534
+ num_mels=self.config.audio.num_mels,
535
+ sample_rate=self.config.audio.sample_rate,
536
+ fmin=self.config.audio.mel_fmin,
537
+ fmax=self.config.audio.mel_fmax,
538
+ )
539
+
540
+ # compute discriminator scores and features
541
+ scores_disc_fake, feats_disc_fake, _, feats_disc_real = self.disc(
542
+ self.model_outputs_cache["model_outputs"],
543
+ self.model_outputs_cache["waveform_seg"],
544
+ )
545
+
546
+ # compute losses
547
+ with autocast(enabled=False): # use float32 for the criterion
548
+ loss_dict = criterion[optimizer_idx](
549
+ mel_slice_hat=mel_slice.float(),
550
+ mel_slice=mel_slice_hat.float(),
551
+ z_p=self.model_outputs_cache["z_p"].float(),
552
+ logs_q=self.model_outputs_cache["logs_q"].float(),
553
+ m_p=self.model_outputs_cache["m_p"].float(),
554
+ logs_p=self.model_outputs_cache["logs_p"].float(),
555
+ z_len=batch["spec_lens"],
556
+ scores_disc_fake=scores_disc_fake,
557
+ feats_disc_fake=feats_disc_fake,
558
+ feats_disc_real=feats_disc_real,
559
+ loss_duration=self.model_outputs_cache["loss_duration"],
560
+ use_speaker_encoder_as_loss=self.args.use_speaker_encoder_as_loss,
561
+ gt_spk_emb=self.model_outputs_cache["gt_spk_emb"],
562
+ syn_spk_emb=self.model_outputs_cache["syn_spk_emb"],
563
+ )
564
+ return self.model_outputs_cache, loss_dict
565
+ raise ValueError(" [!] Unexpected `optimizer_idx`.")
566
+
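# Sketch of how the two passes are alternated by a trainer (model, criterion,
# optimizers and the batch are placeholders; in practice the Coqui Trainer
# drives this loop):
batch = model.format_batch_on_device(batch)

_, disc_losses = model.train_step(batch, criterion, optimizer_idx=0)   # discriminator update
disc_losses["loss"].backward()
optimizers[0].step()
optimizers[0].zero_grad()

_, gen_losses = model.train_step(batch, criterion, optimizer_idx=1)    # generator update
gen_losses["loss"].backward()
optimizers[1].step()
optimizers[1].zero_grad()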
567
+ @torch.no_grad()
568
+ def test_run(self, assets):
569
+ """Generic test run for `tts` models used by `Trainer`.
570
+
571
+ You can override this for a different behaviour.
572
+
573
+ Returns:
574
+ Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
575
+ """
576
+ print(" | > Synthesizing test sentences.")
577
+ test_audios = {}
578
+ test_figures = {}
579
+ test_sentences = self.config.test_sentences
580
+ for idx, s_info in enumerate(test_sentences):
581
+ wav = self.synthesize_from_example(s_info)
582
+ test_audios["{}-audio".format(idx)] = wav
583
+ return {"figures": test_figures, "audios": test_audios}
584
+
585
+ def get_data_loader(
586
+ self,
587
+ config: Coqpit,
588
+ assets: Dict,
589
+ is_eval: bool,
590
+ samples: Union[List[Dict], List[List]],
591
+ verbose: bool,
592
+ num_gpus: int,
593
+ rank: int = None,
594
+ ) -> "DataLoader":
595
+ dataset = VitsDataset_NT(
596
+ model_args=self.args,
597
+ speaker_manager=self.speaker_manager,
598
+ config=self.config,
599
+ use_phone_labels=config.use_phone_labels,
600
+ sample_rate=self.sample_rate,
601
+ samples=samples,
602
+ batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
603
+ min_text_len=config.min_text_len,
604
+ max_text_len=config.max_text_len,
605
+ min_audio_len=config.min_audio_len,
606
+ max_audio_len=config.max_audio_len,
607
+ phoneme_cache_path=config.phoneme_cache_path,
608
+ precompute_num_workers=config.precompute_num_workers,
609
+ verbose=verbose,
610
+ tokenizer=self.tokenizer,
611
+ start_by_longest=config.start_by_longest,
612
+ )
613
+
614
+ # sort input sequences from short to long
615
+ dataset.preprocess_samples()
616
+
617
+ # get samplers
618
+ sampler = self.get_sampler(config, dataset, num_gpus)
619
+ loader = DataLoader(
620
+ dataset,
621
+ batch_sampler=sampler,
622
+ collate_fn=dataset.collate_fn,
623
+ num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
624
+ pin_memory=False,
625
+ )
626
+ return loader
627
+
628
+
629
+ class VitsDataset_NT(VitsDataset):
630
+ def __init__(
631
+ self,
632
+ model_args,
633
+ speaker_manager,
634
+ sample_rate,
635
+ config,
636
+ use_phone_labels,
637
+ *args,
638
+ **kwargs
639
+ ):
640
+ super().__init__(model_args, *args, **kwargs)
641
+ self.speaker_manager = speaker_manager
642
+ self.sample_rate = sample_rate
643
+ self.config = config
644
+ self.use_phone_labels = use_phone_labels
645
+
646
+ def __getitem__(self, idx):
647
+ example = self.samples[idx]
648
+ token_ids = self.get_token_ids(idx, example["text"])
649
+
650
+ wav, _ = load_audio(example["audio_file"], target_sr=self.sample_rate)
651
+
652
+ speaker_id = example['speaker_name']
653
+ example_id = example['example_id']
654
+ d_vector = None
655
+ for dataset_dict_sub in self.config.dataset_dict['datasets'].values():
656
+ d_vector_file = dataset_dict_sub['d_vector_storage_root']
657
+ if (Path(d_vector_file) / f'{speaker_id}/{example_id}.pth').is_file():
658
+ d_vector = torch.load(Path(d_vector_file) / f'{speaker_id}/{example_id}.pth')
659
+ break
660
+ if d_vector is None:
661
+ raise ValueError(f'Could not find d_vector for example {example_id}')
662
+
663
+ if d_vector.dim() == 1:
664
+ d_vector = d_vector[None, :]
665
+ return {
666
+ "raw_text": example['text'],
667
+ "token_ids": token_ids,
668
+ "token_len": len(token_ids),
669
+ "wav": wav,
670
+ "d_vector": d_vector,
671
+ "speaker_name": example["speaker_name"]
672
+ }
673
+
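# Sketch of how a d-vector file matching the lookup above could be created
# (the root path, speaker/example names and the embedding dimension are placeholders):
import torch
from pathlib import Path

root = Path("/path/to/d_vectors")                      # a d_vector_storage_root entry
target = root / "speaker_A" / "utt_0001.pth"           # <speaker_name>/<example_id>.pth
target.parent.mkdir(parents=True, exist_ok=True)
torch.save(torch.randn(1, 512), target)                # (1, D) speaker embedding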
674
+ def collate_fn(self, batch):
675
+ """
676
+ Collate a list of samples from a Dataset into a batch for VITS.
677
+
678
+ Args:
679
+ batch (dict): Expected keys per example (as produced by __getitem__):
680
+ - raw_text (str)
681
+ - token_ids (list)
682
+ - token_len (int)
683
+ - wav (torch.Tensor): (1, T)
684
+ - d_vector (torch.Tensor): (1, D)
685
+ - speaker_name (str)
686
+ Returns:
687
+ - tokens (torch.Tensor): (B, T)
688
+ - token_lens (torch.Tensor): (B)
689
+ - token_rel_lens (torch.Tensor): (B)
690
+ - waveform (torch.Tensor): (B, 1, T)
691
+ - waveform_lens (torch.Tensor): (B)
692
+ - waveform_rel_lens (torch.Tensor): (B)
693
+ - speaker_names (list): (B)
694
+ - raw_text (list): (B)
695
+ - d_vector (torch.Tensor): (B, D)
696
+ """
701
+ B = len(batch)
702
+ batch = {k: [dic[k] for dic in batch] for k in batch[0]}
703
+
704
+ _, ids_sorted_decreasing = torch.sort(
705
+ torch.LongTensor(
706
+ [
707
+ x.size(1) for x in batch["wav"]]
708
+ ),
709
+ dim=0,
710
+ descending=True
711
+ )
712
+
713
+ max_text_len = max([len(x) for x in batch["token_ids"]])
714
+ token_lens = torch.LongTensor(batch["token_len"])
715
+ token_rel_lens = token_lens / token_lens.max()
716
+
717
+ wav_lens = [w.shape[1] for w in batch["wav"]]
718
+ wav_lens = torch.LongTensor(wav_lens)
719
+ wav_lens_max = torch.max(wav_lens)
720
+ wav_rel_lens = wav_lens / wav_lens_max
721
+
722
+ token_padded = torch.LongTensor(B, max_text_len)
723
+ wav_padded = torch.FloatTensor(B, 1, wav_lens_max)
724
+ token_padded = token_padded.zero_() + self.pad_id
725
+ wav_padded = wav_padded.zero_() + self.pad_id
726
+ for i in range(len(ids_sorted_decreasing)):
727
+ token_ids = batch["token_ids"][i]
728
+ token_padded[i, : batch["token_len"][i]] = torch.LongTensor(token_ids)
729
+ wav = batch["wav"][i]
730
+ wav_padded[i, :, : wav.size(1)] = torch.FloatTensor(wav)
731
+
732
+ return {
733
+ "tokens": token_padded,
734
+ "token_lens": token_lens,
735
+ "token_rel_lens": token_rel_lens,
736
+ "waveform": wav_padded,
737
+ "waveform_lens": wav_lens,
738
+ "waveform_rel_lens": wav_rel_lens,
739
+ "speaker_names": batch["speaker_name"],
740
+ "raw_text": batch["raw_text"],
741
+ "d_vector": torch.concatenate(batch["d_vector"]) if 'd_vector' in batch.keys() else None,
742
+ "d_vector": torch.concatenate(batch["d_vector"]) if 'd_vector' in batch.keys() else None,
743
+ }
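# Sketch of feeding the dataset through a plain DataLoader with the collate_fn
# above (dataset constructor arguments are omitted; the model's get_data_loader
# additionally uses a batch sampler):
from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=8, collate_fn=dataset.collate_fn)
batch = next(iter(loader))
# batch["tokens"]:   (B, T_max_tokens)
# batch["waveform"]: (B, 1, T_max_samples)
# batch["d_vector"]: (B, D)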
setup.py ADDED
@@ -0,0 +1,13 @@
1
+ from distutils.core import setup
2
+
3
+ setup(
4
+ name='pvq_manipulation',
5
+ version='0.0.0',
6
+ author='Department of Communications Engineering, Paderborn University',
7
+ author_email='sek@nt.upb.de',
8
+ license='MIT',
9
+ keywords='audio speech',
10
+ install_requires=[
11
+ 'torchdiffeq',
12
+ ],
13
+ )
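# Minimal sanity check after installing (for example with `pip install -e .`
# run from the repository root, so that the package itself is importable);
# torchdiffeq is the single dependency declared above:
import torchdiffeq
import pvq_manipulation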