Commit: 0f1d9a2
Parent(s): 273b181
Add vits model and normalizing flow. Jupyter Notebook as example call
Files changed:
- pvq_manipulation/Example_Notebook.ipynb +331 -0
- pvq_manipulation/helper/characters.yaml +4 -0
- pvq_manipulation/helper/moving_batch_norm.py +140 -0
- pvq_manipulation/helper/utils.py +228 -0
- pvq_manipulation/helper/vad.py +193 -0
- pvq_manipulation/models/ffjord.py +247 -0
- pvq_manipulation/models/hubert.py +207 -0
- pvq_manipulation/models/ode_functions.py +96 -0
- pvq_manipulation/models/vits.py +742 -0
- setup.py +13 -0
pvq_manipulation/Example_Notebook.ipynb
ADDED
@@ -0,0 +1,331 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0e32cd2-4955-4140-8f48-9751a1a8c588",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from pathlib import Path\n",
    "import padertorch as pt\n",
    "import paderbox as pb\n",
    "import time\n",
    "import torch\n",
    "import torchaudio\n",
    "import ipywidgets as widgets\n",
    "from onnxruntime import InferenceSession\n",
    "from pvq_manipulation.models.vits import Vits_NT\n",
    "from pvq_manipulation.models.ffjord import FFJORD\n",
    "from IPython.display import display, Audio, clear_output\n",
    "from pvq_manipulation.models.hubert import HubertExtractor, SID_LARGE_LAYER\n",
    "from paderbox.transform.module_resample import resample_sox\n",
    "from pvq_manipulation.helper.vad import EnergyVAD\n",
    "from train_tts_nt.helper.utils import rms_norm"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d4df1db0-8439-4573-9dc2-5d578e8befa1",
   "metadata": {},
   "source": [
    "# load TTS model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6691176-6119-4bf0-9dcf-44d657c76074",
   "metadata": {},
   "outputs": [],
   "source": [
    "storage_dir_tts = Path(\"./Saved_models/tts_model/\")\n",
    "tts_model = Vits_NT.load_model(storage_dir_tts, checkpoint=\"checkpoint_390000.pth\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c9c7541c-fab5-4d44-9b89-a26a34343e7c",
   "metadata": {},
   "source": [
    "# load normalizing flow"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4a55082-c6c6-4283-96ed-217553f33bcd",
   "metadata": {},
   "outputs": [],
   "source": [
    "storage_dir_normalizing_flow = Path(\"./Saved_models/norm_flow\")\n",
    "config_norm_flow = pb.io.load_yaml(storage_dir_normalizing_flow / \"config.yaml\")\n",
    "normalizing_flow = FFJORD.load_model(storage_dir_normalizing_flow, checkpoint=\"checkpoints/ckpt_best_loss.pth\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "deebed07-b28c-49de-b30f-d80b9e1c6899",
   "metadata": {},
   "source": [
    "# load hubert features model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc4627e1-bac7-4533-8cac-bbc296889855",
   "metadata": {},
   "outputs": [],
   "source": [
    "hubert_model = HubertExtractor(\n",
    "    layer=SID_LARGE_LAYER,\n",
    "    model_name=\"HUBERT_LARGE\",\n",
    "    backend=\"torchaudio\",\n",
    "    device='cpu',\n",
    "    storage_dir='/net/vol/rautenberg/storage/hubert'  # target storage dir for the hubert model\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c78fa11b-8617-4175-902c-8af0e4491201",
   "metadata": {},
   "source": [
    "# Example Synthesis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e8afa1b-b02e-4a40-982d-36aa78f37a57",
   "metadata": {},
   "outputs": [],
   "source": [
    "speaker_id = 1034\n",
    "example_id = \"1034_121119_000028_000001\"\n",
    "\n",
    "wav_1 = tts_model.synthesize_from_example({\n",
    "    'text': \"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.\",\n",
    "    'd_vector_storage_root': f\"./Saved_models/Dataset/Embeddings/{speaker_id}/{example_id}.pth\"\n",
    "})\n",
    "display(Audio(wav_1, rate=24_000, normalize=True))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "feeb1d62-69f2-45c1-a172-16fcfbecd0da",
   "metadata": {},
   "source": [
    "# Manipulation Block"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "625368d3-dd35-4da7-a358-7bbac448806c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_manipulation(\n",
    "    example,\n",
    "    d_vector,\n",
    "    labels,\n",
    "    flow,\n",
    "    tts_model,\n",
    "    manipulation_idx=0,\n",
    "    manipulation_fkt=1,\n",
    "):\n",
    "    labels_manipulated = labels.clone()\n",
    "    labels_manipulated[:, manipulation_idx] += manipulation_fkt\n",
    "\n",
    "    output_forward = flow.forward((d_vector.float(), labels))[0]\n",
    "    sampled_class_manipulated = flow.sample((output_forward, labels_manipulated))[0]\n",
    "\n",
    "    wav = tts_model.synthesize_from_example({\n",
    "        'text': \"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.\",\n",
    "        'd_vector': d_vector.detach().numpy(),\n",
    "        'd_vector_man': sampled_class_manipulated.detach().numpy(),\n",
    "    })\n",
    "    return wav\n",
    "\n",
    "def extract_speaker_embedding(example):\n",
    "    observation, sr = pb.io.load_audio(example['audio_path']['observation'], return_sample_rate=True)\n",
    "    observation = resample_sox(observation, in_rate=sr, out_rate=16_000)\n",
    "\n",
    "    vad = EnergyVAD(sample_rate=16_000)\n",
    "    if observation.ndim == 1:\n",
    "        observation = observation[None, :]\n",
    "\n",
    "    observation = vad({'audio_data': observation})['audio_data']\n",
    "\n",
    "    with torch.no_grad():\n",
    "        example = tts_model.speaker_manager.prepare_example({'audio_data': {'observation': observation}, **example})\n",
    "        example = pt.data.utils.collate_fn([example])\n",
    "        example['features'] = torch.tensor(np.array(example['features']))\n",
    "        d_vector = tts_model.speaker_manager.forward(example)[0]\n",
    "    return d_vector"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b722e503-a8f4-4702-acce-20bcdd828846",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_speaker_labels(example, config_norm_flow, reg_stor_dir=Path('./Saved_models/pvq_extractor/')):\n",
    "    audio, _ = torchaudio.load(example['audio_path']['observation'])\n",
    "    num_samples = torch.tensor([audio.shape[-1]])\n",
    "\n",
    "    if torch.cuda.is_available():\n",
    "        audio = audio.cuda()\n",
    "        num_samples = num_samples.cuda()\n",
    "    providers = [\"CPUExecutionProvider\"]\n",
    "\n",
    "    with torch.no_grad():\n",
    "        features, seq_len = hubert_model(\n",
    "            audio,\n",
    "            24_000,\n",
    "            sequence_lengths=num_samples,\n",
    "        )\n",
    "    features = np.mean(features.squeeze(0).detach().cpu().numpy(), axis=-1)\n",
    "\n",
    "    pvqd_predictions = {}\n",
    "    for pvq in ['Breathiness', 'Loudness', 'Pitch', 'Resonance', 'Roughness', 'Strain', 'Weight']:\n",
    "        with open(reg_stor_dir / f\"{pvq}.onnx\", \"rb\") as fid:\n",
    "            onnx = fid.read()\n",
    "        sess = InferenceSession(onnx, providers=providers)\n",
    "        pred = sess.run(None, {\"X\": features[None]})[0].squeeze(1)\n",
    "        pvqd_predictions[pvq] = pred.tolist()[0]\n",
    "    labels = []\n",
    "    for key in config_norm_flow['speaker_conditioning']:\n",
    "        labels.append(pvqd_predictions[key] / 100)\n",
    "    return torch.tensor(labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "008035ba-6054-4e6e-ab16-1aaaf68f584a",
   "metadata": {},
   "source": [
    "# Get example manipulation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e921a3cd-1699-495c-b825-519fb706d89d",
   "metadata": {},
   "outputs": [],
   "source": [
    "example = {\n",
    "    'audio_path': {'observation': \"./Saved_models/Dataset/Audio_files/1034_121119_000028_000001.wav\"},\n",
    "    'speaker_id': 1034,\n",
    "    'example_id': \"1034_121119_000028_000001\",\n",
    "}\n",
    "\n",
    "d_vector = extract_speaker_embedding(example)\n",
    "labels = load_speaker_labels(example, config_norm_flow)\n",
    "\n",
    "wav_manipulated = get_manipulation(\n",
    "    example=example,\n",
    "    d_vector=d_vector,\n",
    "    labels=labels[None, :],\n",
    "    flow=normalizing_flow,\n",
    "    tts_model=tts_model,\n",
    "    manipulation_idx=0,\n",
    "    manipulation_fkt=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09a04e5b-c2ab-43e5-b9df-171028100ab6",
   "metadata": {},
   "outputs": [],
   "source": [
    "example = {\n",
    "    'audio_path': {'observation': \"./Saved_models/Dataset/Audio_files/1034_121119_000028_000001.wav\"},\n",
    "    'speaker_id': 1034,\n",
    "    'example_id': \"1034_121119_000028_000001\",\n",
    "}\n",
    "\n",
    "label_options = ['Weight', 'Resonance', 'Breathiness', 'Roughness', 'Loudness', 'Strain', 'Pitch']\n",
    "\n",
    "manipulation_idx_widget = widgets.Dropdown(\n",
    "    options=[(label, i) for i, label in enumerate(label_options)],\n",
    "    value=2,  # default value: Breathiness\n",
    "    description='Type:',\n",
    "    style={'description_width': 'initial'}\n",
    ")\n",
    "\n",
    "manipulation_fkt_widget = widgets.FloatSlider(\n",
    "    value=1.0, min=-2.0, max=2.0, step=0.1,\n",
    "    description='Strength:',\n",
    "    style={'description_width': 'initial'}\n",
    ")\n",
    "\n",
    "run_button = widgets.Button(description=\"Run Manipulation\")\n",
    "\n",
    "audio_output = widgets.Output()\n",
    "\n",
    "def update_manipulation(b):\n",
    "    manipulation_idx = manipulation_idx_widget.value\n",
    "    manipulation_fkt = manipulation_fkt_widget.value\n",
    "\n",
    "    d_vector = extract_speaker_embedding(example)\n",
    "    labels = load_speaker_labels(example, config_norm_flow)\n",
    "\n",
    "    with audio_output:\n",
    "        clear_output(wait=True)\n",
    "        display(widgets.Label(\"Processing...\"))\n",
    "\n",
    "    time.sleep(1)\n",
    "\n",
    "    wav_manipulated = get_manipulation(\n",
    "        example=example,\n",
    "        d_vector=d_vector,\n",
    "        labels=labels[None, :],\n",
    "        flow=normalizing_flow,\n",
    "        tts_model=tts_model,\n",
    "        manipulation_idx=manipulation_idx,\n",
    "        manipulation_fkt=manipulation_fkt,\n",
    "    )\n",
    "\n",
    "    with audio_output:\n",
    "        clear_output(wait=True)\n",
    "        display(Audio(wav_manipulated, rate=24_000, normalize=True))\n",
    "        display(Audio(example['audio_path']['observation'], rate=24_000, normalize=True))\n",
    "\n",
    "    print(f\"Manipulated {label_options[manipulation_idx]} with strength {manipulation_fkt}\")\n",
    "\n",
    "run_button.on_click(update_manipulation)\n",
    "display(manipulation_idx_widget, manipulation_fkt_widget, run_button, audio_output)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "voice editing",
   "language": "python",
   "name": "voice_editing"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
pvq_manipulation/helper/characters.yaml
ADDED
@@ -0,0 +1,4 @@
Yourtts:
  "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz\u00af\u00b7\u00df\u00e0\u00e1\u00e2\u00e3\u00e4\u00e6\u00e7\u00e8\u00e9\u00ea\u00eb\u00ec\u00ed\u00ee\u00ef\u00f1\u00f2\u00f3\u00f4\u00f5\u00f6\u00f9\u00fa\u00fb\u00fc\u00ff\u0101\u0105\u0107\u0113\u0119\u011b\u012b\u0131\u0142\u0144\u014d\u0151\u0153\u015b\u016b\u0171\u017a\u017c\u01ce\u01d0\u01d2\u01d4\u0430\u0431\u0432\u0433\u0434\u0435\u0436\u0437\u0438\u0439\u043a\u043b\u043c\u043d\u043e\u043f\u0440\u0441\u0442\u0443\u0444\u0445\u0446\u0447\u0448\u0449\u044a\u044b\u044c\u044d\u044e\u044f\u0451\u0454\u0456\u0457\u0491\u2013!'(),-.:;? "
German:
  "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!'(),-.:;?\u00af\u2013\u00fc\u00f6\u00e4\u00df\u201a\u2018\u2019"
pvq_manipulation/helper/moving_batch_norm.py
ADDED
@@ -0,0 +1,140 @@
"""
This Code is adapted from https://github.com/RameenAbdal/StyleFlow/blob/master/module/normalization.py
"""
import torch
import torch.nn as nn
from torch.nn import Parameter


class MovingBatchNormNd(nn.Module):
    def __init__(self, num_features, eps=1e-4, decay=0.1, bn_lag=0., affine=True):
        super(MovingBatchNormNd, self).__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps
        self.decay = decay
        self.bn_lag = bn_lag
        self.register_buffer('step', torch.zeros(1))
        if self.affine:
            self.weight = Parameter(torch.Tensor(num_features))
            self.bias = Parameter(torch.Tensor(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.reset_parameters()

    @property
    def shape(self):
        raise NotImplementedError

    def reset_parameters(self):
        self.running_mean.zero_()
        self.running_var.fill_(1)
        if self.affine:
            self.weight.data.zero_()
            self.bias.data.zero_()

    def forward(self, x, c=None, logpx=None, reverse=False):
        if reverse:
            return self._reverse(x, logpx)
        else:
            return self._forward(x, logpx)

    def _forward(self, x, logpx=None):
        num_channels = x.size(-1)
        used_mean = self.running_mean.clone().detach()
        used_var = self.running_var.clone().detach()

        if self.training:
            # compute batch statistics
            x_t = x.transpose(0, -1).reshape(num_channels, -1)
            batch_mean = torch.mean(x_t, dim=1)

            batch_var = torch.var(x_t, dim=1)

            # moving average
            if self.bn_lag > 0:
                used_mean = batch_mean - (1 - self.bn_lag) * (batch_mean - used_mean.detach())
                used_mean /= (1. - self.bn_lag**(self.step[0] + 1))
                used_var = batch_var - (1 - self.bn_lag) * (batch_var - used_var.detach())
                used_var /= (1. - self.bn_lag**(self.step[0] + 1))

            # update running estimates
            self.running_mean -= self.decay * (self.running_mean - batch_mean.data)
            self.running_var -= self.decay * (self.running_var - batch_var.data)
            self.step += 1

        # perform normalization
        used_mean = used_mean.view(*self.shape).expand_as(x)
        used_var = used_var.view(*self.shape).expand_as(x)

        y = (x - used_mean) * torch.exp(-0.5 * torch.log(used_var + self.eps))

        if self.affine:
            weight = self.weight.view(*self.shape).expand_as(x)
            bias = self.bias.view(*self.shape).expand_as(x)
            y = y * torch.exp(weight) + bias

        if logpx is None:
            return y
        else:
            #import ipdb
            #ipdb.set_trace()
            return y, logpx - self._logdetgrad(x, used_var).sum(-1, keepdim=True)

    def _reverse(self, y, logpy=None):
        used_mean = self.running_mean
        used_var = self.running_var

        if self.affine:
            weight = self.weight.view(*self.shape).expand_as(y)
            bias = self.bias.view(*self.shape).expand_as(y)
            y = (y - bias) * torch.exp(-weight)

        used_mean = used_mean.view(*self.shape).expand_as(y)
        used_var = used_var.view(*self.shape).expand_as(y)
        x = y * torch.exp(0.5 * torch.log(used_var + self.eps)) + used_mean

        if logpy is None:
            return x
        else:
            return x, logpy + self._logdetgrad(x, used_var).sum(-1, keepdim=True)

    def _logdetgrad(self, x, used_var):
        logdetgrad = -0.5 * torch.log(used_var + self.eps)
        if self.affine:
            weight = self.weight.view(*self.shape).expand(*x.size())
            logdetgrad += weight
        return logdetgrad

    def __repr__(self):
        return (
            '{name}({num_features}, eps={eps}, decay={decay}, bn_lag={bn_lag},'
            ' affine={affine})'.format(name=self.__class__.__name__, **self.__dict__)
        )


def stable_var(x, mean=None, dim=1):
    if mean is None:
        mean = x.mean(dim, keepdim=True)
    mean = mean.view(-1, 1)
    res = torch.pow(x - mean, 2)
    max_sqr = torch.max(res, dim, keepdim=True)[0]
    var = torch.mean(res / max_sqr, 1, keepdim=True) * max_sqr
    var = var.view(-1)
    # change nan to zero
    var[var != var] = 0
    return var


class MovingBatchNorm1d(MovingBatchNormNd):
    @property
    def shape(self):
        return [1, -1]

    def forward(self, x, context=None, logpx=None, integration_times=None, reverse=False):
        ret = super(MovingBatchNorm1d, self).forward(
            x, context, logpx=logpx, reverse=reverse)
        return ret
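For orientation, here is a minimal usage sketch (not part of this commit) of how MovingBatchNorm1d keeps the log-density bookkeeping consistent: normalizing subtracts the log-determinant of the transformation, and reverse=True undoes both the mapping and that correction. Shapes and values below are illustrative assumptions.

import torch
from pvq_manipulation.helper.moving_batch_norm import MovingBatchNorm1d

bn = MovingBatchNorm1d(num_features=8)
bn.eval()  # use the running statistics, as during sampling

x = torch.randn(4, 8)       # (batch, features), e.g. a batch of embeddings
logpx = torch.zeros(4, 1)   # accumulated log-density correction

y, logpy = bn(x, logpx=logpx)                        # normalize, subtract log|det J|
x_rec, logpx_rec = bn(y, logpx=logpy, reverse=True)  # invert, add log|det J| back

assert torch.allclose(x, x_rec, atol=1e-4)
assert torch.allclose(logpx, logpx_rec, atol=1e-4)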
pvq_manipulation/helper/utils.py
ADDED
@@ -0,0 +1,228 @@
import paderbox as pb
import torch

from coqpit import Coqpit
from dataclasses import dataclass, field
from paderbox.transform.module_resample import resample_sox

from typing import List

from TTS.tts.models.vits import VitsAudioConfig, VitsArgs
from TTS.tts.configs.shared_configs import BaseTTSConfig


def load_audio(file_path, target_sr):
    """Load the audio file normalized in [-1, 1]

    Return Shapes:
        - x: :math:`[1, T]`
    """
    if type(file_path) is dict:
        if 'observation' in file_path:
            file_path = file_path['observation']

    x, sr = pb.io.load_audio(file_path, return_sample_rate=True)
    if sr != target_sr:
        x = resample_sox(x, in_rate=sr, out_rate=target_sr)
    x = torch.tensor(x, dtype=torch.float32)[None, :]
    x[x < -1] = -1
    x[x > 1] = 1
    assert (x > 1).sum() + (x < -1).sum() == 0
    return x, target_sr


@dataclass
class VitsAudioConfig_NT(Coqpit):
    fft_size: int = 1024
    sample_rate: int = 16000
    win_length: int = 1024
    hop_length: int = 256
    num_mels: int = 80
    mel_fmin: int = 0
    mel_fmax: int = None
    fading: str = 'half'
    window: str = 'hann'
    pad: bool = True


@dataclass
class VitsConfig_NT(BaseTTSConfig):
    """Defines parameters for VITS End2End TTS model.

    Args:
        model (str):
            Model name. Do not change unless you know what you are doing.

        model_args (VitsArgs):
            Model architecture arguments. Defaults to `VitsArgs()`.

        audio (VitsAudioConfig):
            Audio processing configuration. Defaults to `VitsAudioConfig()`.

        grad_clip (List):
            Gradient clipping thresholds for each optimizer. Defaults to `[1000.0, 1000.0]`.

        lr_gen (float):
            Initial learning rate for the generator. Defaults to 0.0002.

        lr_disc (float):
            Initial learning rate for the discriminator. Defaults to 0.0002.

        lr_scheduler_gen (str):
            Name of the learning rate scheduler for the generator. One of the `torch.optim.lr_scheduler.*`. Defaults to
            `ExponentialLR`.

        lr_scheduler_gen_params (dict):
            Parameters for the learning rate scheduler of the generator. Defaults to `{'gamma': 0.999875, "last_epoch":-1}`.

        lr_scheduler_disc (str):
            Name of the learning rate scheduler for the discriminator. One of the `torch.optim.lr_scheduler.*`. Defaults to
            `ExponentialLR`.

        lr_scheduler_disc_params (dict):
            Parameters for the learning rate scheduler of the discriminator. Defaults to `{'gamma': 0.999875, "last_epoch":-1}`.

        scheduler_after_epoch (bool):
            If true, step the schedulers after each epoch else after each step. Defaults to `False`.

        optimizer (str):
            Name of the optimizer to use with both the generator and the discriminator networks. One of the
            `torch.optim.*`. Defaults to `AdamW`.

        kl_loss_alpha (float):
            Loss weight for KL loss. Defaults to 1.0.

        disc_loss_alpha (float):
            Loss weight for the discriminator loss. Defaults to 1.0.

        gen_loss_alpha (float):
            Loss weight for the generator loss. Defaults to 1.0.

        feat_loss_alpha (float):
            Loss weight for the feature matching loss. Defaults to 1.0.

        mel_loss_alpha (float):
            Loss weight for the mel loss. Defaults to 45.0.

        return_wav (bool):
            If true, data loader returns the waveform as well as the other outputs. Do not change. Defaults to `True`.

        compute_linear_spec (bool):
            If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`.

        use_weighted_sampler (bool):
            If true, use weighted sampler with bucketing for balancing samples between datasets used in training. Defaults to `False`.

        weighted_sampler_attrs (dict):
            Key returned by the formatter to be used for weighted sampler. For example `{"root_path": 2.0, "speaker_name": 1.0}` sets sample probabilities
            by overweighting `root_path` by 2.0. Defaults to `{}`.

        weighted_sampler_multipliers (dict):
            Weight each unique value of a key returned by the formatter for weighted sampling.
            For example `{"root_path":{"/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-100/":1.0, "/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-360/": 0.5}`.
            It will sample instances from `train-clean-100` 2 times more than `train-clean-360`. Defaults to `{}`.

        r (int):
            Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`.

        add_blank (bool):
            If true, a blank token is added in between every character. Defaults to `True`.

        test_sentences (List[List]):
            List of sentences with speaker and language information to be used for testing.

        language_ids_file (str):
            Path to the language ids file.

        use_language_embedding (bool):
            If true, language embedding is used. Defaults to `False`.

    Note:
        Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters.

    Example:

        >>> from TTS.tts.configs.vits_config import VitsConfig
        >>> config = VitsConfig()
    """
    model: str = "vits"
    # model specific params
    model_args: VitsArgs = field(default_factory=VitsArgs)
    audio: VitsAudioConfig = field(default_factory=VitsAudioConfig)

    # optimizer
    grad_clip: List[float] = field(default_factory=lambda: [1000, 1000, 1000])
    lr_gen: float = 0.0002
    lr_disc: float = 0.0002
    lr_scheduler_gen: str = "ExponentialLR"
    lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1})
    lr_scheduler_disc: str = "ExponentialLR"
    lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1})
    scheduler_after_epoch: bool = True
    optimizer: str = "AdamW"
    optimizer_params: dict = field(
        default_factory=lambda: {"betas": [0.8, 0.99], "eps": 1e-9, "weight_decay": 0.01})

    # loss params
    kl_loss_alpha: float = 1.0
    disc_loss_alpha: float = 1.0
    gen_loss_alpha: float = 1.0
    feat_loss_alpha: float = 1.0
    mel_loss_alpha: float = 45.0
    dur_loss_alpha: float = 1.0
    speaker_encoder_loss_alpha: float = 1.0

    # data loader params
    return_wav: bool = True
    compute_linear_spec: bool = True

    # sampler params
    use_weighted_sampler: bool = False  # TODO: move it to the base config
    weighted_sampler_attrs: dict = field(default_factory=lambda: {})
    weighted_sampler_multipliers: dict = field(default_factory=lambda: {})

    # overrides
    r: int = 1  # DO NOT CHANGE
    add_blank: bool = True

    # testing
    test_sentences: List[List] = field(
        default_factory=lambda: [
            ["It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent."],
            ["Be a voice, not an echo."],
            ["I'm sorry Dave. I'm afraid I can't do that."],
            ["This cake is great. It's so delicious and moist."],
            ["Prior to November 22, 1963."],
        ]
    )

    # multi-speaker settings
    # use speaker embedding layer
    num_speakers: int = 0
    use_speaker_embedding: bool = False
    speakers_file: str = None
    speaker_embedding_channels: int = 256
    language_ids_file: str = None
    use_language_embedding: bool = False

    # use d-vectors
    d_vectors_stor_file: bool = False
    d_vector_model_file: str = None
    d_vector_dim: int = None
    d_vector_model: str = None
    dataset_dict: dict = None
    gan_speaker_conditioning: bool = True

    sample_rate: int = 16_000
    use_vad: bool = True
    use_phone_labels: bool = False

    CONFIG_SOLVER: str = ''

    use_speaker_embedding_cond: bool = True

    def __post_init__(self):
        for key, val in self.model_args.items():
            if hasattr(self, key):
                self[key] = val
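As a quick illustration (not from the repository), load_audio can be called like this; the file name is a placeholder:

from pvq_manipulation.helper.utils import load_audio

# "speech.wav" is a placeholder path, not a file shipped with this commit.
x, sr = load_audio("speech.wav", target_sr=16_000)
print(x.shape)  # torch.Size([1, T]), float32, clipped to [-1, 1]
print(sr)       # 16000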
pvq_manipulation/helper/vad.py
ADDED
@@ -0,0 +1,193 @@
import numpy as np
import paderbox as pb
import padertorch as pt
import typing

from dataclasses import dataclass


@pb.utils.functional.partial_decorator
def conv_smoothing(signal, window_length=7, threshold=3):
    """

    Boundary effects are visible at beginning and end of signal.

    Examples:
        >>> conv_smoothing(np.array([False, True, True, True, False, False, False, True]), 3, 2)
        array([False,  True,  True,  True, False, False, False, False])

    Args:
        signal:
        window_length:
        threshold:

    Returns:

    """
    left_context = right_context = (window_length - 1) // 2
    if window_length % 2 == 0:
        right_context += 1
    act_conv = np.sum(pb.array.segment_axis(
        np.pad(signal, (left_context, right_context), mode='constant'),
        length=window_length, shift=1, axis=0, end='cut'
    ), axis=-1)
    # act_conv = np.convolve(signal, np.ones(window_length), 's')
    act = act_conv >= threshold
    assert act.shape == signal.shape, (act.shape, signal.shape)
    return act


@dataclass
class VAD(pt.Configurable):
    smoothing: typing.Optional[typing.Callable] = None

    def reset(self):
        """Override for a stateful VAD"""
        pass

    def compute_vad(self, signal, time_resolution=True):
        raise NotImplementedError()

    def vad_to_time(self, vad, time_length):
        raise NotImplementedError()

    def __call__(self, signal, time_resolution=True, reset=True):
        if reset:
            self.reset()

        vad = self.compute_vad(signal)

        if self.smoothing is not None:
            vad = pb.array.interval.ArrayInterval(self.smoothing(vad))

        if time_resolution:
            vad = self.vad_to_time(vad, time_length=signal.shape[-1])

        return vad


class EnergyVAD(VAD):
    def __init__(self, sample_rate, threshold=0.3):
        self.sample_rate = sample_rate
        self.threshold = threshold

    @staticmethod
    def remove_silence(signal, vad_mask):
        return signal[vad_mask == 1]

    def __call__(self, example):
        signal = example['audio_data']  # B T
        vad_mask = self.get_vad_mask(signal)
        signal = self.remove_silence(signal, vad_mask)
        example['audio_data'] = signal
        example['vad_mask'] = vad_mask
        return example

    def get_vad_mask(self, signal):
        window_size = int(0.1 * self.sample_rate + 1)

        half_context = (window_size - 1) // 2
        std = np.std(signal, axis=-1, keepdims=True)
        signal = signal - np.mean(signal, axis=-1, keepdims=True)
        signal = np.abs(signal)
        zeros = np.zeros(
            [
                signal.shape[0],
                half_context,
            ]
        )
        signal = np.concatenate([zeros, signal, zeros], axis=1)
        sliding_max = np.max(pb.array.segment_axis(
            signal,
            length=window_size, shift=1, axis=1, end='cut'
        ), axis=-1)
        return sliding_max > self.threshold * std


@dataclass
class ThresholdVAD(VAD):
    """
    Energy-based VAD for almost clean files. Tested on WSJ clean data by Lukas
    Drude.

    Attributes:
        threshold: Fraction of total signal standard deviation. Use 0.3 for
            (almost) clean files (SNR >= 20dB, think LibriTTS) and 0.7 for less
            clean files (think LibriSpeech).
        window_size: Size of sliding max window.
        sample_rate: Sampling rate of audio data.
        smoothing: Optional callable that uses a sliding window over the raw
            decision to return a smoothed VAD.
    """
    threshold: float = 0.3
    window_size: typing.Optional[int] = None
    sample_rate: int = 16_000
    smoothing: typing.Optional[typing.Callable] = None

    @classmethod
    def finalize_dogmatic_config(cls, config):
        rate = config['sample_rate']
        config['smoothing'] = {
            'partial': conv_smoothing,
            'window_length': int(0.3 * rate),
            'threshold': int(0.1 * rate),
        }

    def __post_init__(self):
        if self.window_size is None:
            self.window_size = int(0.1 * self.sample_rate + 1)

        assert self.window_size % 2 == 1, self.window_size

    def __call__(self, example):
        if isinstance(example, dict):
            signal = example['audio_data']
            if signal.ndim == 2 and signal.shape[0] == 1:
                signal = signal[0]
            elif signal.ndim == 2 and signal.shape[0] != 1:
                raise ValueError(
                    'Only mono signals are supported but audio_data has shape '
                    f'{signal.shape}'
                )
            vad = super().__call__(signal)
            intervals = np.asarray(vad.intervals)
            start, stop = zip(*intervals)
            example['vad'] = vad
            example['vad_start_samples'] = start
            example['vad_stop_samples'] = stop
        else:
            example = super().__call__(example)
        return example

    def _detect_voice_activity(self, signal):
        assert signal.ndim == 1, signal.shape

        half_context = (self.window_size - 1) // 2
        std = np.std(signal)
        signal = signal - np.mean(signal)
        assert np.min(signal) < 0
        assert np.max(signal) > 0
        signal = np.abs(signal)

        sliding_max = np.max(pb.array.segment_axis(
            np.pad(signal, (half_context, half_context), mode='constant'),
            length=self.window_size, shift=1, axis=0, end='cut'
        ), axis=-1)

        assert sliding_max.shape == signal.shape, (
            sliding_max.shape, signal.shape
        )

        unconstrained = sliding_max > self.threshold * std

        return unconstrained

    def compute_vad(self, signal, time_resolution=True):
        assert time_resolution
        return pb.array.interval.ArrayInterval(
            self._detect_voice_activity(signal)
        )

    def vad_to_time(self, vad, time_length):
        assert time_length == vad.shape[-1], (time_length, vad.shape[-1])
        return vad
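A small usage sketch (assumed, not part of the commit) of the EnergyVAD call convention that the notebook relies on; the random signal stands in for real audio:

import numpy as np
from pvq_manipulation.helper.vad import EnergyVAD

vad = EnergyVAD(sample_rate=16_000)
example = {'audio_data': np.random.randn(1, 16_000)}  # (batch, samples), mono audio

out = vad(example)
print(out['vad_mask'].shape)    # (1, 16000) boolean speech/non-speech decision per sample
print(out['audio_data'].shape)  # flattened array containing only the samples kept as speech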
pvq_manipulation/models/ffjord.py
ADDED
@@ -0,0 +1,247 @@
import torch
import paderbox as pb

from padertorch.base import Model
from torchdiffeq import odeint_adjoint as odeint
from pvq_manipulation.helper.moving_batch_norm import MovingBatchNorm1d


class ODEBlock(torch.nn.Module):
    def __init__(
        self,
        ode_function,
        train_flag=True,
        reverse=False,
    ):
        super().__init__()
        self.time_deriv_func = ode_function
        self.noise = None
        self.reverse = reverse
        self.train_flag = train_flag

    def forward(
        self,
        time: torch.Tensor,
        states: tuple[torch.Tensor, torch.Tensor, torch.Tensor]
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Helper function to use a neural network for dy(t)/dt = f_theta(t, y(t))

        Hutchinson's trace estimator, as proposed in the FFJORD paper, was adapted from:
        https://github.com/RameenAbdal/StyleFlow/blob/master/module/odefunc.py

        Args:
            time (torch.Tensor): Scalar tensor representing time
            states (Tuple[torch.Tensor, torch.Tensor, torch.Tensor]):
                - z (torch.Tensor): (batch_size, feature_dim) representing the input data.
                - d_log_dz_dt (torch.Tensor): (batch_size, 1) representing the log derivative.
                - labels (torch.Tensor): (batch_size, num_labeled_classes)

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                - dz_dt (torch.Tensor): (batch_size, feature_dim) The derivative of z w.r.t. time
                - d_log_dz_dt (torch.Tensor): (batch_size, 1) The negative log derivative
                - labels (torch.Tensor): (batch_size, num_labeled_classes)
        """

        z, d_log_dz_dt, labels = states

        if self.noise is None:
            self.noise = self.sample_rademacher_like(z)

        with torch.enable_grad():
            z.requires_grad_(True)

            dz_dt = self.time_deriv_func.forward(time, z, labels)
            if self.train_flag:
                d_log_dz_dt = self.divergence_approx(dz_dt, z, self.noise)
            else:
                d_log_dz_dt = torch.zeros_like(z[:, 0]).requires_grad_(True)

        labels = torch.zeros_like(labels).requires_grad_(True)
        return dz_dt, -d_log_dz_dt.view(z.shape[0], 1), labels

    def divergence_approx(self, f, y, e=None):
        e_dzdx = torch.autograd.grad(f, y, e, create_graph=True)[0]
        e_dzdx_e = e_dzdx.mul(e)

        cnt = 0
        while not e_dzdx_e.requires_grad and cnt < 10:
            e_dzdx = torch.autograd.grad(f, y, e, create_graph=True)[0]
            e_dzdx_e = e_dzdx * e
            cnt += 1

        approx_tr_dzdx = e_dzdx_e.sum(dim=-1)
        assert approx_tr_dzdx.requires_grad, \
            "(failed to add node to graph) f=%s %s, y(rgrad)=%s, e_dzdx:%s, e:%s, e_dzdx_e:%s cnt:%s" \
            % (
                f.size(), f.requires_grad, y.requires_grad, e_dzdx.requires_grad, e.requires_grad,
                e_dzdx_e.requires_grad, cnt)
        return approx_tr_dzdx

    def before_odeint(self, e=None):
        self.noise = e

    def sample_rademacher_like(self, z):
        if not self.training:
            torch.manual_seed(0)
        return torch.randint(low=0, high=2, size=z.shape).to(z) * 2 - 1


class FFJORD(Model):
    """
    This class is an implementation of the FFJORD model as proposed in
    https://arxiv.org/pdf/1810.01367
    """
    def __init__(self, ode_function, normalize=True):
        super().__init__()
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.input_dim = ode_function.input_dim
        self.time_deriv_func = ODEBlock(ode_function=ode_function)
        self.latent_dist = torch.distributions.MultivariateNormal(
            torch.zeros(self.input_dim, device=device),
            torch.eye(self.input_dim, device=device),
        )
        self.normalize = normalize
        if self.normalize:
            self.input_norm = MovingBatchNorm1d(self.input_dim, bn_lag=0)
            self.output_norm = MovingBatchNorm1d(self.input_dim, bn_lag=0)

    @staticmethod
    def load_model(model_path, checkpoint):
        model_dict = pb.io.load_yaml(model_path / "config.yaml")
        model = Model.from_config(model_dict['model'])
        cp = torch.load(
            model_path / checkpoint,
            map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        )
        model_weights = cp.copy()
        model.load_state_dict(model_weights['model'])
        model.eval()
        return model

    def forward(
        self,
        state: tuple[torch.Tensor, torch.Tensor],
        integration_times: torch.Tensor = torch.tensor([0.0, 1.0]),
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Integration from t_1 (data distribution) to t_0 (base distribution).
        (training step)

        Args:
            state (Tuple[torch.Tensor, torch.Tensor]):
                - z (torch.Tensor): (batch_size, feature_dim) representing the input data.
                - labels (torch.Tensor): (batch_size, num_labeled_classes)

            integration_times (torch.Tensor, optional): A tensor of shape (2,)
                specifying the start and end times for integration. Defaults to torch.tensor([0.0, 1.0]).

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                - dz_dt (torch.Tensor): A tensor of shape (batch_size, feature_dim) representing the derivative of z w.r.t. time.
                - -d_log_dz_dt (torch.Tensor): (batch_size, 1) representing the negative log derivative.
                - labels (torch.Tensor): (batch_size, num_labeled_classes)
        """
        z_1, labels = state

        if z_1.dim() == 3:
            z_1 = z_1.squeeze(1)

        delta_logpz = torch.zeros(z_1.shape[0], 1).to(z_1.device)

        if self.normalize:
            z_1, delta_logpz = self.input_norm(z_1, context=labels, logpx=delta_logpz)

        self.time_deriv_func.before_odeint()
        state = odeint(
            self.time_deriv_func,  # Calculates time derivatives.
            (z_1, delta_logpz, labels),  # Values to update. init states
            integration_times.to(z_1.device),  # When to evaluate.
            method='dopri5',  # Runge-Kutta
            atol=1e-5,  # Error tolerance
            rtol=1e-5,  # Error tolerance
        )
        if self.normalize:
            dz_dt, d_delta_log_dz_t = self.output_norm(state[0], context=state[2], logpx=state[1])
        else:
            dz_dt, d_delta_log_dz_t = state[0], state[1]

        state = (dz_dt, d_delta_log_dz_t, labels)

        if len(integration_times) == 2:
            state = tuple(s[1] if s.shape[0] > 1 else s[0] for s in state)
        return state

    def sample(
        self,
        state: tuple[torch.Tensor, torch.Tensor],
        integration_times: torch.Tensor = torch.tensor([1.0, 0.0])
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        This is the sampling step. Integration from t_0 (base distribution) to t_1 (data distribution).

        Args:
            state (Tuple[torch.Tensor, torch.Tensor]):
                - z_0 (torch.Tensor): (batch_size, feature_dim) representing the initial state from the base distribution
                - labels (torch.Tensor): (batch_size, num_labeled_classes)

            integration_times (torch.Tensor, optional): A tensor of shape (2,) specifying the start (t_0) and end (t_1) times for integration.
                Defaults to torch.tensor([1.0, 0.0])

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                - z_t1 (torch.Tensor): (batch_size, feature_dim) representing the sampled data at time t_1 (data distribution).
                - labels (torch.Tensor): (batch_size, num_labeled_classes)
        """
        z_0, labels = state
        delta_logpz = torch.zeros(z_0.shape[0], 1).to(z_0.device)
        if self.normalize:
            z_0, delta_logpz = self.output_norm(
                z_0,
                context=labels,
                logpx=delta_logpz,
                reverse=True
            )

        state = odeint(
            self.time_deriv_func,  # Calculates time derivatives.
            (z_0, delta_logpz, labels),  # Values to update. init states
            integration_times.to(z_0.device),  # When to evaluate.
            method='dopri5',  # Runge-Kutta
            atol=1e-5,  # Error tolerance
            rtol=1e-5,  # Error tolerance
        )
        if self.normalize:
            dz_dt, d_delta_log_dz_t = self.input_norm(
                state[0],
                context=state[2],
                logpx=state[1],
                reverse=True
            )
        else:
            dz_dt, d_delta_log_dz_t = state[0], state[1]
        state = (dz_dt, d_delta_log_dz_t, labels)

        if len(integration_times) == 2:
            state = tuple(s[1] if s.shape[0] > 1 else s[0] for s in state)
        return state

    def example_to_device(self, examples, device):
        observations = [example['observation'] for example in examples]
        labels = [example['speaker_conditioning'].tolist() for example in examples if 'speaker_conditioning' in example]
        observations_tensor = torch.tensor(observations, device=device, dtype=torch.float)
        labels_tensor = torch.tensor(labels, device=device, dtype=torch.float) if labels else None
        return observations_tensor, labels_tensor

    def review(self, example, outputs):
        z_t0, delta_logpz, _ = outputs
        logpz_t1 = self.latent_dist.log_prob(z_t0) - delta_logpz
        loss = -torch.mean(logpz_t1)
        return dict(loss=loss, scalars=dict(loss=loss))

    def modify_summary(self, summary):
        summary = super().modify_summary(summary)
        return summary
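To make the intended use of forward and sample concrete, here is a hedged sketch of the manipulation round trip the example notebook builds on: encode a speaker embedding into the base distribution under its current PVQ labels, then integrate back with shifted labels. The helper name and shapes are illustrative, not part of the commit.

import torch
from pvq_manipulation.models.ffjord import FFJORD

def manipulate_embedding(flow: FFJORD, d_vector: torch.Tensor, labels: torch.Tensor,
                         manipulation_idx: int, manipulation_fkt: float) -> torch.Tensor:
    """Shift one conditioning label and map the embedding through the flow and back."""
    labels_manipulated = labels.clone()
    labels_manipulated[:, manipulation_idx] += manipulation_fkt
    z_0 = flow.forward((d_vector.float(), labels))[0]    # data -> base distribution
    return flow.sample((z_0, labels_manipulated))[0]     # base -> data with shifted labels

# Usage with the objects loaded in the notebook above (illustrative):
# d_vector_man = manipulate_embedding(normalizing_flow, d_vector, labels[None, :], 2, 0.5)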
pvq_manipulation/models/hubert.py
ADDED
@@ -0,0 +1,207 @@
1 |
+
import os
|
2 |
+
from pathlib import Path
|
3 |
+
from contextlib import nullcontext
|
4 |
+
import typing as tp
|
5 |
+
from typing import List, Tuple, Optional
|
6 |
+
|
7 |
+
import einops
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
import torch.nn.functional as F
|
11 |
+
import torchaudio
|
12 |
+
|
13 |
+
import padertorch as pt
|
14 |
+
from padertorch.contrib.je.modules.conv_utils import (
|
15 |
+
compute_conv_output_sequence_lengths
|
16 |
+
)
|
17 |
+
from padertorch.utils import to_numpy
|
18 |
+
from transformers.models.hubert.modeling_hubert import HubertModel
|
19 |
+
|
20 |
+
# See https://ieeexplore.ieee.org/abstract/document/9814838, Fig. 2
|
21 |
+
PR_BASE_LAYER = 11
|
22 |
+
PR_LARGE_LAYER = 22
|
23 |
+
SID_BASE_LAYER = 4
|
24 |
+
SID_LARGE_LAYER = 6
|
25 |
+
|
26 |
+
|
27 |
+
def tuple_to_int(sequence) -> list:
|
28 |
+
return list(map(lambda t: t[0], sequence))
|
29 |
+
|
30 |
+
|
31 |
+
class HubertExtractor(pt.Module):
|
32 |
+
"""Extract HuBERT features from raw waveform.
|
33 |
+
|
34 |
+
Args:
|
35 |
+
model_name (str): Name of the pretrained HuBERT model on huggingface.co.
|
36 |
+
Defaults to "facebook/hubert-large-ll60k".
|
37 |
+
layer (int): Index of the layer to extract features from. Defaults to
|
38 |
+
22.
|
39 |
+
freeze (bool): If True, freeze the weights of the encoder
|
40 |
+
(i.e., no finetuning of Transformer layers). Defaults to True.
|
41 |
+
"""
|
42 |
+
|
43 |
+
def __init__(
|
44 |
+
self,
|
45 |
+
model_name: str = "facebook/hubert-large-ll60k",
|
46 |
+
layer: tp.Union[int, str] = PR_LARGE_LAYER,
|
47 |
+
freeze: bool = True,
|
48 |
+
detach: bool = False,
|
49 |
+
device: str = "cpu",
|
50 |
+
backend: str = "torchaudio",
|
51 |
+
storage_dir: str = None,
|
52 |
+
):
|
53 |
+
super().__init__()
|
54 |
+
|
55 |
+
if not freeze and detach:
|
56 |
+
raise ValueError(
|
57 |
+
'detach=True only supported if freeze=True\n'
|
58 |
+
f'Got: freeze={freeze}, detach={detach}'
|
59 |
+
)
|
60 |
+
if backend == "torchaudio":
|
61 |
+
bundle = getattr(torchaudio.pipelines, model_name)
|
62 |
+
self.model = bundle.get_model(dl_kwargs={'model_dir': storage_dir}).eval().to(device)
|
63 |
+
self.sampling_rate = bundle.sample_rate
|
64 |
+
else:
|
            raise ValueError(f'Unknown backend: {backend}')
        self.backend = backend

        if freeze:
            for param in self.model.parameters():
                param.requires_grad = False
        else:
            # Always freeze feature extractor and feature projection layers
            for param in self.model.feature_extractor.parameters():
                param.requires_grad = False
            for param in self.model.feature_projection.parameters():
                param.requires_grad = False

        self.layer = layer
        self.freeze = freeze
        self.detach = detach

        if self.layer == 'all':
            self.weights = torch.nn.Parameter(
                torch.ones(24), requires_grad=True
            )

    @property
    def cache_dir(self):
        return Path(os.environ['STORAGE_ROOT']) / 'huggingface' / 'hub'

    @property
    def context(self):
        if self.detach:
            return torch.no_grad()
        else:
            return nullcontext()

    def compute_output_lengths(
            self, input_lengths: Optional[List[int]]
    ) -> Optional[List[int]]:
        """Compute the number of time frames for each batch entry.

        Args:
            input_lengths: List with number of samples per batch entry.

        Returns:
            List with number of time frames per batch entry.
        """
        if input_lengths is None:
            return input_lengths
        output_lengths = np.asarray(input_lengths) + self.window_size - 1
        for kernel_size, dilation, stride in zip(
                self.kernel_sizes, self.dilations, self.strides,
        ):
            output_lengths = compute_conv_output_sequence_lengths(
                output_lengths,
                kernel_size=kernel_size,
                dilation=dilation,
                pad_type=None,
                stride=stride,
            )
        return output_lengths.tolist()

    def forward(
            self,
            time_signal: torch.Tensor,
            sampling_rate: int,
            sequence_lengths: Optional[List[int]] = None,
            extract_features: bool = False,
            other_inputs: Optional[dict] = None,
    ) -> Tuple[torch.Tensor, Optional[List[int]]]:
        """Extract HuBERT features from raw waveform.

        Args:
            time_signal: Time signal of shape (batch, 1, time) or (batch, time)
                sampled at 16 kHz.
            sequence_lengths: List with number of samples per batch entry.

        Returns:
            x (torch.Tensor): HuBERT features of shape
                (batch, D, time frames).
            sequence_lengths (List[int]): List with number of time frames per
                batch entry.
        """
        del other_inputs

        if time_signal.ndim == 3:
            time_signal = einops.rearrange(time_signal, 'b c t -> (b c) t')

        time_signal = torchaudio.functional.resample(
            time_signal, sampling_rate, self.sampling_rate
        )
        if sequence_lengths is not None:
            if isinstance(sequence_lengths, (list, tuple)):
                sequence_lengths = torch.tensor(sequence_lengths).long() \
                    .to(time_signal.device)
            sequence_lengths = (
                sequence_lengths / sampling_rate * self.sampling_rate
            ).long()

        if self.freeze or self.detach:
            self.model.eval()
        with self.context:
            if self.backend == "torchaudio":
                self.model: torchaudio.models.Wav2Vec2Model
                x, sequence_lengths = self.model.extract_features(
                    time_signal, lengths=sequence_lengths,
                    num_layers=self.layer,
                )
                if isinstance(self.layer, int):
                    x = x[-1].transpose(1, 2)
                else:
                    raise NotImplementedError(self.layer)
                return x, sequence_lengths

            self.model: HubertModel
            n_pad = self.window_size - 1
            time_signal = F.pad(time_signal, (0, n_pad), value=0)
            if extract_features:
                features = self.model.feature_extractor(time_signal.float()) \
                    .transpose(1, 2)
                x = self.model.feature_projection(features).transpose(1, 2)
            else:
                outputs = self.model(
                    time_signal.float(), output_hidden_states=True
                )
                if isinstance(self.layer, int):
                    x = outputs.hidden_states[self.layer].transpose(1, 2)
                    if self.detach:
                        x = x.detach()
                elif self.layer == 'all':
                    hidden_states = []
                    for _, hidden_state in enumerate(outputs.hidden_states):
                        x = hidden_state.transpose(1, 2)
                        if self.detach:
                            x = x.detach()
                        hidden_states.append(x)
                    hidden_states = torch.stack(hidden_states)  # (L, B, D, T)
                    x = (hidden_states * self.weights[:, None, None, None]) \
                        .sum(dim=0)
                else:
                    raise ValueError(f'Unknown layer: {self.layer}')

            sequence_lengths = to_numpy(sequence_lengths)
            sequence_lengths = self.compute_output_lengths(sequence_lengths)

            return x.unsqueeze(1), sequence_lengths
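Note: a hypothetical usage sketch for the extractor above. Only `layer`, `backend`, `freeze` and `detach` are visible in this excerpt, so the remaining constructor behaviour (default model weights, the exact name of the non-torchaudio backend) is an assumption:

import torch
from pvq_manipulation.models.hubert import HubertExtractor, SID_LARGE_LAYER

extractor = HubertExtractor(
    layer=SID_LARGE_LAYER,   # one transformer layer; 'all' would use the learned layer weights
    backend="transformers",  # assumption: the HuggingFace HubertModel code path above
    freeze=True,
    detach=True,
)
wav = torch.randn(1, 1, 48_000)  # (batch, channel, samples), e.g. 3 s at 16 kHz
features, frame_lengths = extractor(wav, sampling_rate=16_000, sequence_lengths=[48_000])
# features: (batch, 1, D, frames); frame_lengths: number of frames per batch entry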
pvq_manipulation/models/ode_functions.py
ADDED
@@ -0,0 +1,96 @@
"""
Implementation of Δz = f(t, z, labels)
f() is a neural network with the architecture defined in StyleFlow:
StyleFlow: Attribute-conditioned Exploration of StyleGAN-Generated Images using Conditional Continuous Normalizing Flows
"""
import torch


class CNFNN(torch.nn.Module):
    def __init__(
            self,
            input_dim,
            condition_dim,
            hidden_channels,
    ):
        super().__init__()
        self.layers = torch.nn.ModuleList()
        hidden_dims = hidden_channels + [input_dim]
        self.input_dim = input_dim

        for idx, hidden_dim in enumerate(hidden_dims):
            self.layers.append(CNFBlock(
                input_dim=input_dim,
                condition_dim=condition_dim,
                output_dim=hidden_dim,
                output_layer=(idx == len(hidden_dims) - 1),
            ))
            input_dim = hidden_dim

    def forward(self, t, z, labels):
        """
        This function computes Δz = f(t, z, labels).

        Args:
            t (torch.Tensor): () Time step of the ODE
            z (torch.Tensor): (Batch_size, Input_dim) Intermediate value
            labels (torch.Tensor): (Batch_size, Condition_dim) Speaker attributes

        Returns:
            Δz (torch.Tensor): (Batch_size, Input_dim) Computed delta
        """
        for layer in self.layers:
            z = layer(t, z, labels)
        return z


class CNFBlock(torch.nn.Module):
    def __init__(
            self,
            input_dim,
            output_dim,
            condition_dim,
            output_layer,
    ):
        super().__init__()
        self._layer = torch.nn.Linear(input_dim, output_dim)
        self._hyper_bias = torch.nn.Linear(
            1 + condition_dim,
            output_dim,
            bias=False
        )
        self._hyper_gate = torch.nn.Linear(
            1 + condition_dim,
            output_dim
        )
        self.output_layer = output_layer

    def forward(self, t, z, labels):
        """
        Args:
            t (torch.Tensor): () Time step of the ODE
            z (torch.Tensor): (Batch_size, Input_dim) Intermediate value
            labels (torch.Tensor): (Batch_size, Condition_dim) Speaker attributes

        Returns:
            z (torch.Tensor): (Batch_size, Output_dim) Intermediate value
        """
        if labels.dim() == 1:
            labels = labels[:, None]
        elif labels.dim() == 3:
            labels = labels.squeeze(1)

        tz_cat = torch.cat((t.expand(z.shape[0], 1), labels), dim=1)

        gate = torch.sigmoid(self._hyper_gate(tz_cat))
        bias = self._hyper_bias(tz_cat)

        if z.dim() == 3:
            gate = gate.unsqueeze(1)
            bias = bias.unsqueeze(1)

        z = self._layer(z) * gate + bias

        if not self.output_layer:
            z = torch.tanh(z)
        return z
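Note: CNFNN only defines the conditional derivative Δz = f(t, z, labels); the actual integration is performed by the FFJORD model in this repository (not shown in this excerpt), which builds on torchdiffeq (listed in setup.py below). A minimal sketch of the idea, with made-up dimensions:

import torch
from torchdiffeq import odeint

f = CNFNN(input_dim=64, condition_dim=8, hidden_channels=[128, 128])

z0 = torch.randn(4, 64)      # batch of latent vectors
labels = torch.randn(4, 8)   # speaker attributes used as condition
t = torch.linspace(0.0, 1.0, 2)

# Bind the condition so the callable matches odeint's func(t, z) signature.
z_t = odeint(lambda t_, z_: f(t_, z_, labels), z0, t)
print(z_t.shape)             # torch.Size([2, 4, 64]); z at t=0 and t=1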
pvq_manipulation/models/vits.py
ADDED
@@ -0,0 +1,742 @@
"""
This is a wrapper for the TTS VITS model.
TTS.tts.models.vits
https://github.com/coqui-ai/TTS/blob/dev/TTS/tts/models/vits.py
"""
import os
import numpy as np
import paderbox as pb
import padertorch as pt
import torch

from coqpit import Coqpit
from padertorch.ops._stft import STFT
from pathlib import Path
from pvq_manipulation.helper.utils import VitsAudioConfig_NT, VitsConfig_NT, load_audio

from torch.utils.data import DataLoader
from torch.cuda.amp.autocast_mode import autocast
from TTS.tts.configs.shared_configs import CharactersConfig
from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
from TTS.tts.models.vits import Vits, VitsArgs, VitsDataset, spec_to_mel, wav_to_spec
from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.synthesis import embedding_to_torch, numpy_to_torch
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.helpers import generate_path, rand_segments, segment, sequence_mask
from TTS.utils.audio import AudioProcessor
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
from trainer.trainer import to_cuda
from typing import Dict, List, Union


STORAGE_ROOT = Path(os.getenv('STORAGE_ROOT')).expanduser()


class Vits_NT(Vits):
    def __init__(
            self,
            config: Coqpit,
            ap: "AudioProcessor" = None,
            tokenizer: "TTSTokenizer" = None,
            speaker_manager: SpeakerManager = None,
            language_manager: LanguageManager = None,
            sample_rate: int = None,
    ):
        super().__init__(
            config,
            ap,
            tokenizer,
            speaker_manager,
            language_manager
        )
        self.sample_rate = sample_rate
        self.embedded_speaker_dim = self.args.d_vector_dim
        self.posterior_encoder = PosteriorEncoder(
            self.args.out_channels,
            self.args.hidden_channels,
            self.args.hidden_channels,
            kernel_size=self.args.kernel_size_posterior_encoder,
            dilation_rate=self.args.dilation_rate_posterior_encoder,
            num_layers=self.args.num_layers_posterior_encoder,
            cond_channels=self.embedded_speaker_dim,
        )

        self.flow = ResidualCouplingBlocks(
            self.args.hidden_channels,
            self.args.hidden_channels,
            kernel_size=self.args.kernel_size_flow,
            dilation_rate=self.args.dilation_rate_flow,
            num_layers=self.args.num_layers_flow,
            cond_channels=self.embedded_speaker_dim,
        )

        self.text_encoder = TextEncoder(
            self.args.num_chars,
            self.args.hidden_channels,
            self.args.hidden_channels,
            self.args.hidden_channels_ffn_text_encoder,
            self.args.num_heads_text_encoder,
            self.args.num_layers_text_encoder,
            self.args.kernel_size_text_encoder,
            self.args.dropout_p_text_encoder,
            language_emb_dim=self.embedded_language_dim,
        )
        self.waveform_decoder = HifiganGenerator(
            self.args.hidden_channels,
            1,
            self.args.resblock_type_decoder,
            self.args.resblock_dilation_sizes_decoder,
            self.args.resblock_kernel_sizes_decoder,
            self.args.upsample_kernel_sizes_decoder,
            self.args.upsample_initial_channel_decoder,
            self.args.upsample_rates_decoder,
            inference_padding=0,
            cond_channels=self.embedded_speaker_dim if self.config.gan_speaker_conditioning else 0,
            conv_pre_weight_norm=False,
            conv_post_weight_norm=False,
            conv_post_bias=False,
        )
        self.speaker_manager = self.speaker_manager
        self.speaker_encoder = self.speaker_manager

        self.speaker_manager.eval()

        self.epoch = 0
        self.num_epochs = config['epochs']
        self.lr_lambda = 0
        self.config_solver = config['CONFIG_SOLVER']
        self.config = config

        self.stft = STFT(
            size=self.config.audio.win_length,
            shift=self.config.audio.hop_length,
            window_length=self.config.audio.win_length,
            fading=self.config.audio.fading,
            window=self.config.audio.window,
            pad=self.config.audio.pad
        )

    def get_spectogram_nt(self, wav):
        """
        Extracts a magnitude spectrogram from audio.
        Args:
            wav (torch.Tensor): (Batch_size, Num_samples)
        Returns:
            spectrogram (torch.Tensor): (Batch_size, Frequency_bins, Time) spectrogram
        """
        wav = wav.squeeze(1)
        stft_signal = self.stft(wav)
        stft_signal = torch.einsum('btf -> bft', stft_signal)
        spectrogram = stft_signal.real ** 2 + stft_signal.imag ** 2
        spectrogram = torch.sqrt(spectrogram + 1e-6)
        return spectrogram

    def get_aux_input_from_test_sentences(self, sentence_info):
        """
        Get aux input for the inference step from test sentences.
        Args:
            sentence_info (dict): Expected keys:
                - "d_vector_storage_root" (str)
                - "d_vector" (torch.Tensor)
                - "d_vector_man" (torch.Tensor) (optional)
        Returns:
            aux_input (dict): aux input for the inference step
        """
        if 'd_vector' not in sentence_info.keys():
            d_vector_file = sentence_info['d_vector_storage_root']
            d_vector = torch.load(d_vector_file)
            return {"d_vector": d_vector, **sentence_info}
        else:
            return sentence_info

    @staticmethod
    def init_from_config(
            config: "VitsConfig",
            samples=None,
            verbose=True
    ):
        """
        Initialize the model from a config.
        Args:
            config (VitsConfig): Model config.
            samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
                Defaults to None.
        Returns:
            model (Vits): Initialized model.
        """
        upsample_rate = torch.prod(torch.as_tensor(config.model_args.upsample_rates_decoder)).item()
        assert (upsample_rate == config.audio.hop_length), f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {config.audio.hop_length}"
        ap = AudioProcessor.init_from_config(config, verbose=verbose)
        tokenizer, new_config = TTSTokenizer.init_from_config(config)
        language_manager = LanguageManager.init_from_config(config)
        speaker_manager = pt.Module.from_storage_dir(
            config['d_vector_model_file'],
            checkpoint_name='ckpt_latest.pth',
            consider_mpi=False,
            config_name='config.json',
        )
        speaker_manager.num_speakers = config['num_speakers']
        for param in speaker_manager.parameters():
            param.requires_grad = False

        return Vits_NT(
            new_config,
            ap,
            tokenizer,
            speaker_manager=speaker_manager,
            language_manager=language_manager,
            sample_rate=config['sample_rate'],
        )

    @torch.no_grad()
    def inference(self, x, aux_input=None):
        """
        Note:
            To run in batch mode, provide `x_lengths`; otherwise the model assumes a batch size of 1.

        Args:
            x (torch.Tensor): (batch_size, T_seq) or (T_seq) Input character sequence IDs
            aux_input (dict): Expected keys:
                - d_vector (torch.Tensor): (batch_size, Feature_dim) speaker_embedding
                - x_lengths (torch.Tensor): (batch_size) Length of each input token sequence

        Returns:
            - model_outputs (torch.Tensor): (batch_size, T_wav) Synthesized waveform
        """
        speaker_embedding = aux_input['d_vector'].detach()[:, :, None]
        if aux_input['d_vector_man'] is not None:
            speaker_embedding_man = aux_input['d_vector_man'].detach()[:, :, None]
        else:
            speaker_embedding_man = speaker_embedding
        aux_input['tokens'] = x.clone()
        x_lengths = self._set_x_lengths(x, aux_input)
        x, m_p, logs_p, x_mask = self.text_encoder(
            x,
            x_lengths,
            lang_emb=None
        )
        logw = self.duration_predictor(
            x,
            x_mask,
            g=speaker_embedding,
            lang_emb=None,
        )

        w = torch.exp(logw) * x_mask * self.length_scale
        w_ceil = torch.ceil(w)
        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
        y_mask = sequence_mask(y_lengths, None).to(x_mask.dtype).unsqueeze(1)  # [B, 1, T_dec]

        attn_mask = x_mask * y_mask.transpose(1, 2)
        attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1).transpose(1, 2))
        m_p = torch.matmul(attn.transpose(1, 2), m_p.transpose(1, 2)).transpose(1, 2)
        logs_p = torch.matmul(attn.transpose(1, 2), logs_p.transpose(1, 2)).transpose(1, 2)

        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale

        z = self.flow(z_p, y_mask, g=speaker_embedding_man, reverse=True)
        z, _, _, y_mask = self.upsampling_z(
            z,
            y_lengths=y_lengths,
            y_mask=y_mask
        )
        o = self.waveform_decoder(
            (z * y_mask)[:, :, : self.max_inference_len],
            g=speaker_embedding_man if self.config.gan_speaker_conditioning else None
        )
        return o

    def forward(self, x, x_lengths, y, y_lengths, aux_input, inference=False):
        """
        Forward pass of the model.

        Args:
            x (torch.Tensor): (Batch, T_seq) Input character sequence IDs
            x_lengths (torch.Tensor): (Batch) Input character sequence lengths.
            y (torch.Tensor): (Batch_size, Frequency_bins, Time) Input spectrograms.
            y_lengths (torch.Tensor): (Batch) Input spectrogram lengths.
            aux_input (dict, optional): Expected keys:
                - d_vector (torch.Tensor): (batch_size, Feature_dim) speaker_embedding
                - waveform (torch.Tensor): (Batch_size, Num_samples) Target waveform
        Returns:
            Dict: model outputs keyed by the output name.
        """
        outputs = {}
        speaker_embedding = aux_input['d_vector'].detach()[:, :, None]
        x, m_p, logs_p, x_mask = self.text_encoder(
            x,
            x_lengths,
            lang_emb=None
        )
        z, m_q, logs_q, y_mask = self.posterior_encoder(
            y,
            y_lengths,
            g=speaker_embedding,
        )
        z_p = self.flow(z, y_mask, g=speaker_embedding)
        outputs, attn = self.forward_mas(
            outputs,
            z_p,
            m_p,
            logs_p,
            x,
            x_mask,
            y_mask,
            g=speaker_embedding,
            lang_emb=None,
        )
        m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
        logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])

        z_slice, slice_ids = rand_segments(
            z,
            y_lengths,
            self.spec_segment_size,
            let_short_samples=True,
            pad_short=True
        )
        z_slice, spec_segment_size, slice_ids, _ = self.upsampling_z(
            z_slice,
            slice_ids=slice_ids,
        )

        wav_seg = segment(
            aux_input['waveform'],
            slice_ids * self.config.audio.hop_length,
            spec_segment_size * self.config.audio.hop_length,
            pad_short=True,
        )
        o = self.waveform_decoder(
            z_slice,
            g=speaker_embedding if self.config.gan_speaker_conditioning else None
        )

        if self.args.use_speaker_encoder_as_loss and self.speaker_manager.encoder is not None:
            wavs_batch = torch.cat((wav_seg, o), dim=0)
            if self.audio_transform is not None:
                wavs_batch = self.audio_transform(wavs_batch)
            with torch.no_grad():
                pred_embs = self.speaker_manager.encoder.forward(wavs_batch, l2_norm=True)
            gt_spk_emb, syn_spk_emb = torch.chunk(pred_embs, 2, dim=0)
        else:
            gt_spk_emb, syn_spk_emb = None, None

        outputs.update(
            {
                "model_outputs": o,
                "alignments": attn.squeeze(1),
                "m_p": m_p,
                "logs_p": logs_p,
                "z": z,
                "z_p": z_p,
                "m_q": m_q,
                "logs_q": logs_q,
                "waveform_seg": wav_seg,
                "gt_spk_emb": gt_spk_emb,
                "syn_spk_emb": syn_spk_emb,
                "slice_ids": slice_ids,
                "z_slice": z_slice,
                "speaker_embedding": speaker_embedding,
            }
        )
        return outputs

    @staticmethod
    def load_model(model_path, checkpoint):
        """
        Load model from checkpoint.

        Args:
            model_path (str): model path
            checkpoint (str): checkpoint name

        Returns:
            model (pvq_manipulation.models.vits.Vits_NT): model
        """
        config = pb.io.load_json(model_path / "config.json")
        model_args = VitsArgs(**config['model_args'])
        audio_config = VitsAudioConfig_NT(**config['audio'])
        characters_config = CharactersConfig(**config['characters'])
        del config['audio']
        del config['characters']
        del config['model_args']

        config = VitsConfig_NT(
            model_args=model_args,
            audio=audio_config,
            characters=characters_config,
            **config,
        )
        model = Vits_NT.init_from_config(config)
        cp = torch.load(
            model_path / checkpoint,
            map_location=torch.device('cpu')
        )
        model_weights = cp['model'].copy()
        model.load_state_dict(model_weights, strict=False)
        model.eval()
        return model

    def synthesize_from_example(self, s_info):
        """
        Synthesize voice from example.

        Args:
            s_info (dict): Expected keys:
                - "speaker_id" (str),
                - "example_id" (str),
                - "audio_path" (str),
                - "d_vector_storage_root" (str),
                - "text" (str) specifying the text to synthesize
        Returns:
            wav (torch.Tensor): synthesized waveform
        """
        aux_inputs = self.get_aux_input_from_test_sentences(s_info)
        use_cuda = "cuda" in str(next(self.parameters()).device)

        device = next(self.parameters()).device
        if use_cuda:
            device = "cuda"

        text_inputs = np.asarray(
            self.tokenizer.text_to_ids(aux_inputs["text"], language=None),
            dtype=np.int32,
        )
        d_vector = embedding_to_torch(aux_inputs["d_vector"], device=device)

        if "d_vector_man" in aux_inputs.keys():
            d_vector_man = embedding_to_torch(aux_inputs["d_vector_man"], device=device)

        text_inputs = numpy_to_torch(text_inputs, torch.long, device=device)
        text_inputs = text_inputs.unsqueeze(0)

        wav = self.inference(
            text_inputs,
            aux_input={
                "x_lengths": torch.tensor(
                    text_inputs.shape[1:2]
                ).to(text_inputs.device),
                "d_vector": d_vector,
                "d_vector_man": d_vector_man if "d_vector_man" in aux_inputs.keys() else None
            }
        )[0].data.cpu().numpy().squeeze()
        return wav

    def format_batch_on_device(self, batch):
        """Format batch on device"""
        ac = self.config.audio

        batch['waveform'] = to_cuda(batch['waveform'])
        wav = batch["waveform"]

        batch['spec'] = self.get_spectogram_nt(wav)

        if self.args.encoder_sample_rate:
            spec_mel = wav_to_spec(batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False)
            if spec_mel.size(2) > int(batch["spec"].size(2) * self.interpolate_factor):
                spec_mel = spec_mel[:, :, : int(batch["spec"].size(2) * self.interpolate_factor)]
            else:
                batch["spec"] = batch["spec"][:, :, : int(spec_mel.size(2) / self.interpolate_factor)]
        else:
            spec_mel = batch["spec"]

        batch["mel"] = spec_to_mel(
            spec=spec_mel,
            n_fft=ac.fft_size,
            num_mels=ac.num_mels,
            sample_rate=ac.sample_rate,
            fmin=ac.mel_fmin,
            fmax=ac.mel_fmax,
        )

        if self.args.encoder_sample_rate:
            assert batch["spec"].shape[2] == int(
                batch["mel"].shape[2] / self.interpolate_factor
            ), f"{batch['spec'].shape[2]}, {batch['mel'].shape[2]}"
        else:
            assert batch["spec"].shape[2] == batch["mel"].shape[2], f"{batch['spec'].shape[2]}, {batch['mel'].shape[2]}"

        batch["spec_lens"] = (batch["spec"].shape[2] * batch["waveform_rel_lens"]).int()
        batch["mel_lens"] = (batch["mel"].shape[2] * batch["waveform_rel_lens"]).int()

        if self.args.encoder_sample_rate:
            assert (batch["spec_lens"] - (batch["mel_lens"] / self.interpolate_factor).int()).sum() == 0
        else:
            assert (batch["spec_lens"] - batch["mel_lens"]).sum() == 0

        batch["spec"] = batch["spec"] * sequence_mask(batch["spec_lens"]).unsqueeze(1)
        batch["mel"] = batch["mel"] * sequence_mask(batch["mel_lens"]).unsqueeze(1)
        return batch

    def train_step(
            self,
            batch: dict,
            criterion: torch.nn.Module,
            optimizer_idx: int,
    ):
        """
        Perform a single training step. Run the model forward pass and compute losses.

        Args:
            batch (Dict): Input tensors.
            criterion (nn.Module): Loss layer designed for the model.
            optimizer_idx (int): Index of optimizer to use. 0 for the generator and 1 for the discriminator networks.

        Returns:
            Tuple[Dict, Dict]: Model outputs and computed losses.
        """
        if optimizer_idx == 0:
            # generator pass
            outputs = self.forward(
                batch["tokens"],
                batch["token_lens"],
                batch["spec"],
                batch["spec_lens"],
                aux_input={
                    **batch,
                },
            )
            # cache tensors for the generator pass
            self.model_outputs_cache = outputs  # pylint: disable=attribute-defined-outside-init
            scores_disc_fake, _, scores_disc_real, _ = self.disc(
                outputs["model_outputs"].detach(),
                outputs["waveform_seg"]
            )
            # compute loss
            with autocast(enabled=False):  # use float32 for the criterion
                loss_dict = criterion[optimizer_idx](
                    scores_disc_real,
                    scores_disc_fake,
                )
            return outputs, loss_dict

        if optimizer_idx == 1:
            # compute melspec segment
            with autocast(enabled=False):
                if self.args.encoder_sample_rate:
                    spec_segment_size = self.spec_segment_size * int(self.interpolate_factor)
                else:
                    spec_segment_size = self.spec_segment_size
                mel_slice = segment(
                    batch["mel"].float(),
                    self.model_outputs_cache["slice_ids"],
                    spec_segment_size,
                    pad_short=True
                )

                spec = self.get_spectogram_nt(
                    self.model_outputs_cache["model_outputs"].float(),
                )
                mel_slice_hat = spec_to_mel(
                    spec=spec,
                    n_fft=self.config.audio.fft_size,
                    num_mels=self.config.audio.num_mels,
                    sample_rate=self.config.audio.sample_rate,
                    fmin=self.config.audio.mel_fmin,
                    fmax=self.config.audio.mel_fmax,
                )

            # compute discriminator scores and features
            scores_disc_fake, feats_disc_fake, _, feats_disc_real = self.disc(
                self.model_outputs_cache["model_outputs"],
                self.model_outputs_cache["waveform_seg"],
            )

            # compute losses
            with autocast(enabled=False):  # use float32 for the criterion
                loss_dict = criterion[optimizer_idx](
                    mel_slice_hat=mel_slice.float(),
                    mel_slice=mel_slice_hat.float(),
                    z_p=self.model_outputs_cache["z_p"].float(),
                    logs_q=self.model_outputs_cache["logs_q"].float(),
                    m_p=self.model_outputs_cache["m_p"].float(),
                    logs_p=self.model_outputs_cache["logs_p"].float(),
                    z_len=batch["spec_lens"],
                    scores_disc_fake=scores_disc_fake,
                    feats_disc_fake=feats_disc_fake,
                    feats_disc_real=feats_disc_real,
                    loss_duration=self.model_outputs_cache["loss_duration"],
                    use_speaker_encoder_as_loss=self.args.use_speaker_encoder_as_loss,
                    gt_spk_emb=self.model_outputs_cache["gt_spk_emb"],
                    syn_spk_emb=self.model_outputs_cache["syn_spk_emb"],
                )
            return self.model_outputs_cache, loss_dict
        raise ValueError(" [!] Unexpected `optimizer_idx`.")

    @torch.no_grad()
    def test_run(self, assets):
        """Generic test run for `tts` models used by `Trainer`.

        You can override this for a different behaviour.

        Returns:
            Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
        """
        print(" | > Synthesizing test sentences.")
        test_audios = {}
        test_figures = {}
        test_sentences = self.config.test_sentences
        for idx, s_info in enumerate(test_sentences):
            wav = self.synthesize_from_example(s_info)
            test_audios["{}-audio".format(idx)] = wav
        return {"figures": test_figures, "audios": test_audios}

    def get_data_loader(
            self,
            config: Coqpit,
            assets: Dict,
            is_eval: bool,
            samples: Union[List[Dict], List[List]],
            verbose: bool,
            num_gpus: int,
            rank: int = None,
    ) -> "DataLoader":
        dataset = VitsDataset_NT(
            model_args=self.args,
            speaker_manager=self.speaker_manager,
            config=self.config,
            use_phone_labels=config.use_phone_labels,
            sample_rate=self.sample_rate,
            samples=samples,
            batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
            min_text_len=config.min_text_len,
            max_text_len=config.max_text_len,
            min_audio_len=config.min_audio_len,
            max_audio_len=config.max_audio_len,
            phoneme_cache_path=config.phoneme_cache_path,
            precompute_num_workers=config.precompute_num_workers,
            verbose=verbose,
            tokenizer=self.tokenizer,
            start_by_longest=config.start_by_longest,
        )

        # sort input sequences from short to long
        dataset.preprocess_samples()

        # get samplers
        sampler = self.get_sampler(config, dataset, num_gpus)
        loader = DataLoader(
            dataset,
            batch_sampler=sampler,
            collate_fn=dataset.collate_fn,
            num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
            pin_memory=False,
        )
        return loader


class VitsDataset_NT(VitsDataset):
    def __init__(
            self,
            model_args,
            speaker_manager,
            sample_rate,
            config,
            use_phone_labels,
            *args,
            **kwargs
    ):
        super().__init__(model_args, *args, **kwargs)
        self.speaker_manager = speaker_manager
        self.sample_rate = sample_rate
        self.config = config
        self.use_phone_labels = use_phone_labels

    def __getitem__(self, idx):
        example = self.samples[idx]
        token_ids = self.get_token_ids(idx, example["text"])

        wav, _ = load_audio(example["audio_file"], target_sr=self.sample_rate)

        speaker_id = example['speaker_name']
        example_id = example['example_id']
        d_vector = None
        for dataset_dict_sub in self.config.dataset_dict['datasets'].values():
            d_vector_file = dataset_dict_sub['d_vector_storage_root']
            if (Path(d_vector_file) / f'{speaker_id}/{example_id}.pth').is_file():
                d_vector = torch.load(Path(d_vector_file) / f'{speaker_id}/{example_id}.pth')
                break
        if d_vector is None:
            raise ValueError(f'Could not find d_vector for example {example_id}')

        if d_vector.dim() == 1:
            d_vector = d_vector[None, :]
        return {
            "raw_text": example['text'],
            "token_ids": token_ids,
            "token_len": len(token_ids),
            "wav": wav,
            "d_vector": d_vector,
            "speaker_name": example["speaker_name"]
        }

    def collate_fn(self, batch):
        """
        Collate a list of samples from a Dataset into a batch for VITS.

        Args:
            batch (dict): Expected keys:
                - wav (list): list of tensors
                - token_ids (list):
                - token_len (list):
                - speaker_name (list):
                - language_name (list):
                - audiofile_path (list):
                - raw_text (list):
                - wav_d_vector (list):
        Returns:
            - tokens (torch.Tensor): (B, T)
            - token_lens (torch.Tensor): (B)
            - token_rel_lens (torch.Tensor): (B)
            - wav (torch.Tensor): (B, 1, T)
            - wav_lens (torch.Tensor): (B)
            - wav_rel_lens (torch.Tensor): (B)
            - speaker_names (torch.Tensor): (B)
            - language_names (torch.Tensor): (B)
            - audiofile_paths (torch.Tensor): (B)
            - raw_texts (torch.Tensor): (B)
            - audio_unique_names (torch.Tensor): (B)
        """
        B = len(batch)
        batch = {k: [dic[k] for dic in batch] for k in batch[0]}

        _, ids_sorted_decreasing = torch.sort(
            torch.LongTensor([x.size(1) for x in batch["wav"]]),
            dim=0,
            descending=True
        )

        max_text_len = max([len(x) for x in batch["token_ids"]])
        token_lens = torch.LongTensor(batch["token_len"])
        token_rel_lens = token_lens / token_lens.max()

        wav_lens = [w.shape[1] for w in batch["wav"]]
        wav_lens = torch.LongTensor(wav_lens)
        wav_lens_max = torch.max(wav_lens)
        wav_rel_lens = wav_lens / wav_lens_max

        token_padded = torch.LongTensor(B, max_text_len)
        wav_padded = torch.FloatTensor(B, 1, wav_lens_max)
        token_padded = token_padded.zero_() + self.pad_id
        wav_padded = wav_padded.zero_() + self.pad_id
        for i in range(len(ids_sorted_decreasing)):
            token_ids = batch["token_ids"][i]
            token_padded[i, : batch["token_len"][i]] = torch.LongTensor(token_ids)
            wav = batch["wav"][i]
            wav_padded[i, :, : wav.size(1)] = torch.FloatTensor(wav)

        return {
            "tokens": token_padded,
            "token_lens": token_lens,
            "token_rel_lens": token_rel_lens,
            "waveform": wav_padded,
            "waveform_lens": wav_lens,
            "waveform_rel_lens": wav_rel_lens,
            "speaker_names": batch["speaker_name"],
            "raw_text": batch["raw_text"],
            "d_vector": torch.concatenate(batch["d_vector"]) if 'd_vector' in batch.keys() else None,
        }
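Note: a toy illustration of the duration logic in `inference` above (made-up numbers, masking omitted). The predicted log-durations are exponentiated, scaled and ceiled per input token, and their sum gives the number of decoder frames, each of which corresponds to `hop_length` output samples:

import torch

logw = torch.tensor([[[0.0, 1.0, 2.0]]])    # predicted log-durations for 3 tokens
w_ceil = torch.ceil(torch.exp(logw) * 1.0)  # length_scale = 1.0
y_lengths = torch.clamp_min(w_ceil.sum([1, 2]), 1).long()
print(w_ceil)      # tensor([[[1., 3., 8.]]]) frames per token
print(y_lengths)   # tensor([12]) decoder frames in total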
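Note: a hypothetical call of `synthesize_from_example` above; the paths, the pre-loaded `tts_model` instance and the use of soundfile for writing are placeholders:

import torch
import soundfile as sf

sentence_info = {
    "text": "An example sentence to synthesize.",
    "d_vector": torch.load("path/to/d_vectors/speaker_0/utt_0001.pth"),
    # optionally "d_vector_man": a manipulated speaker embedding produced by the normalizing flow
}
wav = tts_model.synthesize_from_example(sentence_info)
sf.write("synthesized.wav", wav, tts_model.config.audio.sample_rate)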
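Note: `__getitem__` above expects one pre-computed d-vector file per utterance, implying a layout of the form `<d_vector_storage_root>/<speaker_name>/<example_id>.pth`. A dummy entry could be created like this (paths and the embedding size of 192 are purely illustrative):

import torch
from pathlib import Path

root = Path("d_vectors")  # one d_vector_storage_root per dataset in config.dataset_dict
(root / "speaker_0").mkdir(parents=True, exist_ok=True)
torch.save(torch.randn(1, 192), root / "speaker_0" / "utt_0001.pth")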
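Note: a standalone toy version of the waveform padding performed in `collate_fn` above (pad value 0 assumed), showing how the relative lengths used later for `spec_lens`/`mel_lens` are obtained:

import torch

wavs = [torch.randn(1, 12000), torch.randn(1, 8000)]  # two utterances of different length
wav_lens = torch.LongTensor([w.shape[1] for w in wavs])
wav_padded = torch.zeros(len(wavs), 1, int(wav_lens.max()))
for i, w in enumerate(wavs):
    wav_padded[i, :, : w.shape[1]] = w
print(wav_padded.shape)            # torch.Size([2, 1, 12000])
print(wav_lens / wav_lens.max())   # tensor([1.0000, 0.6667]) -> waveform_rel_lens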
setup.py
ADDED
@@ -0,0 +1,13 @@
from setuptools import setup, find_packages

setup(
    name='pvq_manipulation',
    version='0.0.0',
    author='Department of Communications Engineering, Paderborn University',
    author_email='sek@nt.upb.de',
    license='MIT',
    keywords='audio speech',
    packages=find_packages(),
    install_requires=[
        'torchdiffeq',
    ],
)
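Note: with setuptools handling `install_requires`, the package can be installed in editable mode from the repository root (the heavier dependencies used by the models, e.g. TTS, padertorch and transformers, are assumed to be installed separately):

pip install -e .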