lj1995 committed
Commit e27e3fe · 1 Parent(s): 6b0dc77

Delete module

module/__init__.py DELETED
File without changes
module/attentions.py DELETED
@@ -1,709 +0,0 @@
- import math
- import torch
- from torch import nn
- from torch.nn import functional as F
- 
- from module import commons
- from module.modules import LayerNorm
- 
- 
- class Encoder(nn.Module):
-     def __init__(
-         self,
-         hidden_channels,
-         filter_channels,
-         n_heads,
-         n_layers,
-         kernel_size=1,
-         p_dropout=0.0,
-         window_size=4,
-         isflow=False,
-         **kwargs
-     ):
-         super().__init__()
-         self.hidden_channels = hidden_channels
-         self.filter_channels = filter_channels
-         self.n_heads = n_heads
-         self.n_layers = n_layers
-         self.kernel_size = kernel_size
-         self.p_dropout = p_dropout
-         self.window_size = window_size
- 
-         self.drop = nn.Dropout(p_dropout)
-         self.attn_layers = nn.ModuleList()
-         self.norm_layers_1 = nn.ModuleList()
-         self.ffn_layers = nn.ModuleList()
-         self.norm_layers_2 = nn.ModuleList()
-         for i in range(self.n_layers):
-             self.attn_layers.append(
-                 MultiHeadAttention(
-                     hidden_channels,
-                     hidden_channels,
-                     n_heads,
-                     p_dropout=p_dropout,
-                     window_size=window_size,
-                 )
-             )
-             self.norm_layers_1.append(LayerNorm(hidden_channels))
-             self.ffn_layers.append(
-                 FFN(
-                     hidden_channels,
-                     hidden_channels,
-                     filter_channels,
-                     kernel_size,
-                     p_dropout=p_dropout,
-                 )
-             )
-             self.norm_layers_2.append(LayerNorm(hidden_channels))
-         if isflow:
-             cond_layer = torch.nn.Conv1d(
-                 kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1
-             )
-             self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1)
-             self.cond_layer = weight_norm_modules(cond_layer, name="weight")
-             self.gin_channels = kwargs["gin_channels"]
- 
-     def forward(self, x, x_mask, g=None):
-         attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
-         x = x * x_mask
-         if g is not None:
-             g = self.cond_layer(g)
- 
-         for i in range(self.n_layers):
-             if g is not None:
-                 x = self.cond_pre(x)
-                 cond_offset = i * 2 * self.hidden_channels
-                 g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
-                 x = commons.fused_add_tanh_sigmoid_multiply(
-                     x, g_l, torch.IntTensor([self.hidden_channels])
-                 )
-             y = self.attn_layers[i](x, x, attn_mask)
-             y = self.drop(y)
-             x = self.norm_layers_1[i](x + y)
- 
-             y = self.ffn_layers[i](x, x_mask)
-             y = self.drop(y)
-             x = self.norm_layers_2[i](x + y)
-         x = x * x_mask
-         return x
- 
- 
- class Decoder(nn.Module):
-     def __init__(
-         self,
-         hidden_channels,
-         filter_channels,
-         n_heads,
-         n_layers,
-         kernel_size=1,
-         p_dropout=0.0,
-         proximal_bias=False,
-         proximal_init=True,
-         **kwargs
-     ):
-         super().__init__()
-         self.hidden_channels = hidden_channels
-         self.filter_channels = filter_channels
-         self.n_heads = n_heads
-         self.n_layers = n_layers
-         self.kernel_size = kernel_size
-         self.p_dropout = p_dropout
-         self.proximal_bias = proximal_bias
-         self.proximal_init = proximal_init
- 
-         self.drop = nn.Dropout(p_dropout)
-         self.self_attn_layers = nn.ModuleList()
-         self.norm_layers_0 = nn.ModuleList()
-         self.encdec_attn_layers = nn.ModuleList()
-         self.norm_layers_1 = nn.ModuleList()
-         self.ffn_layers = nn.ModuleList()
-         self.norm_layers_2 = nn.ModuleList()
-         for i in range(self.n_layers):
-             self.self_attn_layers.append(
-                 MultiHeadAttention(
-                     hidden_channels,
-                     hidden_channels,
-                     n_heads,
-                     p_dropout=p_dropout,
-                     proximal_bias=proximal_bias,
-                     proximal_init=proximal_init,
-                 )
-             )
-             self.norm_layers_0.append(LayerNorm(hidden_channels))
-             self.encdec_attn_layers.append(
-                 MultiHeadAttention(
-                     hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
-                 )
-             )
-             self.norm_layers_1.append(LayerNorm(hidden_channels))
-             self.ffn_layers.append(
-                 FFN(
-                     hidden_channels,
-                     hidden_channels,
-                     filter_channels,
-                     kernel_size,
-                     p_dropout=p_dropout,
-                     causal=True,
-                 )
-             )
-             self.norm_layers_2.append(LayerNorm(hidden_channels))
- 
-     def forward(self, x, x_mask, h, h_mask):
-         """
-         x: decoder input
-         h: encoder output
-         """
-         self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
-             device=x.device, dtype=x.dtype
-         )
-         encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
-         x = x * x_mask
-         for i in range(self.n_layers):
-             y = self.self_attn_layers[i](x, x, self_attn_mask)
-             y = self.drop(y)
-             x = self.norm_layers_0[i](x + y)
- 
-             y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
-             y = self.drop(y)
-             x = self.norm_layers_1[i](x + y)
- 
-             y = self.ffn_layers[i](x, x_mask)
-             y = self.drop(y)
-             x = self.norm_layers_2[i](x + y)
-         x = x * x_mask
-         return x
- 
- 
- class MultiHeadAttention(nn.Module):
-     def __init__(
-         self,
-         channels,
-         out_channels,
-         n_heads,
-         p_dropout=0.0,
-         window_size=None,
-         heads_share=True,
-         block_length=None,
-         proximal_bias=False,
-         proximal_init=False,
-     ):
-         super().__init__()
-         assert channels % n_heads == 0
- 
-         self.channels = channels
-         self.out_channels = out_channels
-         self.n_heads = n_heads
-         self.p_dropout = p_dropout
-         self.window_size = window_size
-         self.heads_share = heads_share
-         self.block_length = block_length
-         self.proximal_bias = proximal_bias
-         self.proximal_init = proximal_init
-         self.attn = None
- 
-         self.k_channels = channels // n_heads
-         self.conv_q = nn.Conv1d(channels, channels, 1)
-         self.conv_k = nn.Conv1d(channels, channels, 1)
-         self.conv_v = nn.Conv1d(channels, channels, 1)
-         self.conv_o = nn.Conv1d(channels, out_channels, 1)
-         self.drop = nn.Dropout(p_dropout)
- 
-         if window_size is not None:
-             n_heads_rel = 1 if heads_share else n_heads
-             rel_stddev = self.k_channels**-0.5
-             self.emb_rel_k = nn.Parameter(
-                 torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
-                 * rel_stddev
-             )
-             self.emb_rel_v = nn.Parameter(
-                 torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
-                 * rel_stddev
-             )
- 
-         nn.init.xavier_uniform_(self.conv_q.weight)
-         nn.init.xavier_uniform_(self.conv_k.weight)
-         nn.init.xavier_uniform_(self.conv_v.weight)
-         if proximal_init:
-             with torch.no_grad():
-                 self.conv_k.weight.copy_(self.conv_q.weight)
-                 self.conv_k.bias.copy_(self.conv_q.bias)
- 
-     def forward(self, x, c, attn_mask=None):
-         q = self.conv_q(x)
-         k = self.conv_k(c)
-         v = self.conv_v(c)
- 
-         x, self.attn = self.attention(q, k, v, mask=attn_mask)
- 
-         x = self.conv_o(x)
-         return x
- 
-     def attention(self, query, key, value, mask=None):
-         # reshape [b, d, t] -> [b, n_h, t, d_k]
-         b, d, t_s, t_t = (*key.size(), query.size(2))
-         query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
-         key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
-         value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
- 
-         scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
-         if self.window_size is not None:
-             assert (
-                 t_s == t_t
-             ), "Relative attention is only available for self-attention."
-             key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
-             rel_logits = self._matmul_with_relative_keys(
-                 query / math.sqrt(self.k_channels), key_relative_embeddings
-             )
-             scores_local = self._relative_position_to_absolute_position(rel_logits)
-             scores = scores + scores_local
-         if self.proximal_bias:
-             assert t_s == t_t, "Proximal bias is only available for self-attention."
-             scores = scores + self._attention_bias_proximal(t_s).to(
-                 device=scores.device, dtype=scores.dtype
-             )
-         if mask is not None:
-             scores = scores.masked_fill(mask == 0, -1e4)
-             if self.block_length is not None:
-                 assert (
-                     t_s == t_t
-                 ), "Local attention is only available for self-attention."
-                 block_mask = (
-                     torch.ones_like(scores)
-                     .triu(-self.block_length)
-                     .tril(self.block_length)
-                 )
-                 scores = scores.masked_fill(block_mask == 0, -1e4)
-         p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
-         p_attn = self.drop(p_attn)
-         output = torch.matmul(p_attn, value)
-         if self.window_size is not None:
-             relative_weights = self._absolute_position_to_relative_position(p_attn)
-             value_relative_embeddings = self._get_relative_embeddings(
-                 self.emb_rel_v, t_s
-             )
-             output = output + self._matmul_with_relative_values(
-                 relative_weights, value_relative_embeddings
-             )
-         output = (
-             output.transpose(2, 3).contiguous().view(b, d, t_t)
-         )  # [b, n_h, t_t, d_k] -> [b, d, t_t]
-         return output, p_attn
- 
-     def _matmul_with_relative_values(self, x, y):
-         """
-         x: [b, h, l, m]
-         y: [h or 1, m, d]
-         ret: [b, h, l, d]
-         """
-         ret = torch.matmul(x, y.unsqueeze(0))
-         return ret
- 
-     def _matmul_with_relative_keys(self, x, y):
-         """
-         x: [b, h, l, d]
-         y: [h or 1, m, d]
-         ret: [b, h, l, m]
-         """
-         ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
-         return ret
- 
-     def _get_relative_embeddings(self, relative_embeddings, length):
-         max_relative_position = 2 * self.window_size + 1
-         # Pad first before slice to avoid using cond ops.
-         pad_length = max(length - (self.window_size + 1), 0)
-         slice_start_position = max((self.window_size + 1) - length, 0)
-         slice_end_position = slice_start_position + 2 * length - 1
-         if pad_length > 0:
-             padded_relative_embeddings = F.pad(
-                 relative_embeddings,
-                 commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
-             )
-         else:
-             padded_relative_embeddings = relative_embeddings
-         used_relative_embeddings = padded_relative_embeddings[
-             :, slice_start_position:slice_end_position
-         ]
-         return used_relative_embeddings
- 
-     def _relative_position_to_absolute_position(self, x):
-         """
-         x: [b, h, l, 2*l-1]
-         ret: [b, h, l, l]
-         """
-         batch, heads, length, _ = x.size()
-         # Concat columns of pad to shift from relative to absolute indexing.
-         x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
- 
-         # Concat extra elements so to add up to shape (len+1, 2*len-1).
-         x_flat = x.view([batch, heads, length * 2 * length])
-         x_flat = F.pad(
-             x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
-         )
- 
-         # Reshape and slice out the padded elements.
-         x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
-             :, :, :length, length - 1 :
-         ]
-         return x_final
- 
-     def _absolute_position_to_relative_position(self, x):
-         """
-         x: [b, h, l, l]
-         ret: [b, h, l, 2*l-1]
-         """
-         batch, heads, length, _ = x.size()
-         # pad along column
-         x = F.pad(
-             x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
-         )
-         x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
-         # add 0's in the beginning that will skew the elements after reshape
-         x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
-         x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
-         return x_final
- 
-     def _attention_bias_proximal(self, length):
-         """Bias for self-attention to encourage attention to close positions.
-         Args:
-             length: an integer scalar.
-         Returns:
-             a Tensor with shape [1, 1, length, length]
-         """
-         r = torch.arange(length, dtype=torch.float32)
-         diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
-         return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
- 
- 
- class FFN(nn.Module):
-     def __init__(
-         self,
-         in_channels,
-         out_channels,
-         filter_channels,
-         kernel_size,
-         p_dropout=0.0,
-         activation=None,
-         causal=False,
-     ):
-         super().__init__()
-         self.in_channels = in_channels
-         self.out_channels = out_channels
-         self.filter_channels = filter_channels
-         self.kernel_size = kernel_size
-         self.p_dropout = p_dropout
-         self.activation = activation
-         self.causal = causal
- 
-         if causal:
-             self.padding = self._causal_padding
-         else:
-             self.padding = self._same_padding
- 
-         self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
-         self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
-         self.drop = nn.Dropout(p_dropout)
- 
-     def forward(self, x, x_mask):
-         x = self.conv_1(self.padding(x * x_mask))
-         if self.activation == "gelu":
-             x = x * torch.sigmoid(1.702 * x)
-         else:
-             x = torch.relu(x)
-         x = self.drop(x)
-         x = self.conv_2(self.padding(x * x_mask))
-         return x * x_mask
- 
-     def _causal_padding(self, x):
-         if self.kernel_size == 1:
-             return x
-         pad_l = self.kernel_size - 1
-         pad_r = 0
-         padding = [[0, 0], [0, 0], [pad_l, pad_r]]
-         x = F.pad(x, commons.convert_pad_shape(padding))
-         return x
- 
-     def _same_padding(self, x):
-         if self.kernel_size == 1:
-             return x
-         pad_l = (self.kernel_size - 1) // 2
-         pad_r = self.kernel_size // 2
-         padding = [[0, 0], [0, 0], [pad_l, pad_r]]
-         x = F.pad(x, commons.convert_pad_shape(padding))
-         return x
- 
- 
- import torch.nn as nn
- from torch.nn.utils import remove_weight_norm, weight_norm
- 
- 
- class Depthwise_Separable_Conv1D(nn.Module):
-     def __init__(
-         self,
-         in_channels,
-         out_channels,
-         kernel_size,
-         stride=1,
-         padding=0,
-         dilation=1,
-         bias=True,
-         padding_mode="zeros",  # TODO: refine this type
-         device=None,
-         dtype=None,
-     ):
-         super().__init__()
-         self.depth_conv = nn.Conv1d(
-             in_channels=in_channels,
-             out_channels=in_channels,
-             kernel_size=kernel_size,
-             groups=in_channels,
-             stride=stride,
-             padding=padding,
-             dilation=dilation,
-             bias=bias,
-             padding_mode=padding_mode,
-             device=device,
-             dtype=dtype,
-         )
-         self.point_conv = nn.Conv1d(
-             in_channels=in_channels,
-             out_channels=out_channels,
-             kernel_size=1,
-             bias=bias,
-             device=device,
-             dtype=dtype,
-         )
- 
-     def forward(self, input):
-         return self.point_conv(self.depth_conv(input))
- 
-     def weight_norm(self):
-         self.depth_conv = weight_norm(self.depth_conv, name="weight")
-         self.point_conv = weight_norm(self.point_conv, name="weight")
- 
-     def remove_weight_norm(self):
-         self.depth_conv = remove_weight_norm(self.depth_conv, name="weight")
-         self.point_conv = remove_weight_norm(self.point_conv, name="weight")
- 
- 
- class Depthwise_Separable_TransposeConv1D(nn.Module):
-     def __init__(
-         self,
-         in_channels,
-         out_channels,
-         kernel_size,
-         stride=1,
-         padding=0,
-         output_padding=0,
-         bias=True,
-         dilation=1,
-         padding_mode="zeros",  # TODO: refine this type
-         device=None,
-         dtype=None,
-     ):
-         super().__init__()
-         self.depth_conv = nn.ConvTranspose1d(
-             in_channels=in_channels,
-             out_channels=in_channels,
-             kernel_size=kernel_size,
-             groups=in_channels,
-             stride=stride,
-             output_padding=output_padding,
-             padding=padding,
-             dilation=dilation,
-             bias=bias,
-             padding_mode=padding_mode,
-             device=device,
-             dtype=dtype,
-         )
-         self.point_conv = nn.Conv1d(
-             in_channels=in_channels,
-             out_channels=out_channels,
-             kernel_size=1,
-             bias=bias,
-             device=device,
-             dtype=dtype,
-         )
- 
-     def forward(self, input):
-         return self.point_conv(self.depth_conv(input))
- 
-     def weight_norm(self):
-         self.depth_conv = weight_norm(self.depth_conv, name="weight")
-         self.point_conv = weight_norm(self.point_conv, name="weight")
- 
-     def remove_weight_norm(self):
-         remove_weight_norm(self.depth_conv, name="weight")
-         remove_weight_norm(self.point_conv, name="weight")
- 
- 
- def weight_norm_modules(module, name="weight", dim=0):
-     if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(
-         module, Depthwise_Separable_TransposeConv1D
-     ):
-         module.weight_norm()
-         return module
-     else:
-         return weight_norm(module, name, dim)
- 
- 
- def remove_weight_norm_modules(module, name="weight"):
-     if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(
-         module, Depthwise_Separable_TransposeConv1D
-     ):
-         module.remove_weight_norm()
-     else:
-         remove_weight_norm(module, name)
- 
- 
- class FFT(nn.Module):
-     def __init__(
-         self,
-         hidden_channels,
-         filter_channels,
-         n_heads,
-         n_layers=1,
-         kernel_size=1,
-         p_dropout=0.0,
-         proximal_bias=False,
-         proximal_init=True,
-         isflow=False,
-         **kwargs
-     ):
-         super().__init__()
-         self.hidden_channels = hidden_channels
-         self.filter_channels = filter_channels
-         self.n_heads = n_heads
-         self.n_layers = n_layers
-         self.kernel_size = kernel_size
-         self.p_dropout = p_dropout
-         self.proximal_bias = proximal_bias
-         self.proximal_init = proximal_init
-         if isflow:
-             cond_layer = torch.nn.Conv1d(
-                 kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1
-             )
-             self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1)
-             self.cond_layer = weight_norm_modules(cond_layer, name="weight")
-             self.gin_channels = kwargs["gin_channels"]
-         self.drop = nn.Dropout(p_dropout)
-         self.self_attn_layers = nn.ModuleList()
-         self.norm_layers_0 = nn.ModuleList()
-         self.ffn_layers = nn.ModuleList()
-         self.norm_layers_1 = nn.ModuleList()
-         for i in range(self.n_layers):
-             self.self_attn_layers.append(
-                 MultiHeadAttention(
-                     hidden_channels,
-                     hidden_channels,
-                     n_heads,
-                     p_dropout=p_dropout,
-                     proximal_bias=proximal_bias,
-                     proximal_init=proximal_init,
-                 )
-             )
-             self.norm_layers_0.append(LayerNorm(hidden_channels))
-             self.ffn_layers.append(
-                 FFN(
-                     hidden_channels,
-                     hidden_channels,
-                     filter_channels,
-                     kernel_size,
-                     p_dropout=p_dropout,
-                     causal=True,
-                 )
-             )
-             self.norm_layers_1.append(LayerNorm(hidden_channels))
- 
-     def forward(self, x, x_mask, g=None):
-         """
-         x: decoder input
-         g: optional conditioning tensor (used when isflow=True)
-         """
-         if g is not None:
-             g = self.cond_layer(g)
- 
-         self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
-             device=x.device, dtype=x.dtype
-         )
-         x = x * x_mask
-         for i in range(self.n_layers):
-             if g is not None:
-                 x = self.cond_pre(x)
-                 cond_offset = i * 2 * self.hidden_channels
-                 g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
-                 x = commons.fused_add_tanh_sigmoid_multiply(
-                     x, g_l, torch.IntTensor([self.hidden_channels])
-                 )
-             y = self.self_attn_layers[i](x, x, self_attn_mask)
-             y = self.drop(y)
-             x = self.norm_layers_0[i](x + y)
- 
-             y = self.ffn_layers[i](x, x_mask)
-             y = self.drop(y)
-             x = self.norm_layers_1[i](x + y)
-         x = x * x_mask
-         return x
- 
- 
- class TransformerCouplingLayer(nn.Module):
-     def __init__(
-         self,
-         channels,
-         hidden_channels,
-         kernel_size,
-         n_layers,
-         n_heads,
-         p_dropout=0,
-         filter_channels=0,
-         mean_only=False,
-         wn_sharing_parameter=None,
-         gin_channels=0,
-     ):
-         assert channels % 2 == 0, "channels should be divisible by 2"
-         super().__init__()
-         self.channels = channels
-         self.hidden_channels = hidden_channels
-         self.kernel_size = kernel_size
-         self.n_layers = n_layers
-         self.half_channels = channels // 2
-         self.mean_only = mean_only
- 
-         self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
-         self.enc = (
-             Encoder(
-                 hidden_channels,
-                 filter_channels,
-                 n_heads,
-                 n_layers,
-                 kernel_size,
-                 p_dropout,
-                 isflow=True,
-                 gin_channels=gin_channels,
-             )
-             if wn_sharing_parameter is None
-             else wn_sharing_parameter
-         )
-         self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
-         self.post.weight.data.zero_()
-         self.post.bias.data.zero_()
- 
-     def forward(self, x, x_mask, g=None, reverse=False):
-         x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
-         h = self.pre(x0) * x_mask
-         h = self.enc(h, x_mask, g=g)
-         stats = self.post(h) * x_mask
-         if not self.mean_only:
-             m, logs = torch.split(stats, [self.half_channels] * 2, 1)
-         else:
-             m = stats
-             logs = torch.zeros_like(m)
- 
-         if not reverse:
-             x1 = m + x1 * torch.exp(logs) * x_mask
-             x = torch.cat([x0, x1], 1)
-             logdet = torch.sum(logs, [1, 2])
-             return x, logdet
-         else:
-             x1 = (x1 - m) * torch.exp(-logs) * x_mask
-             x = torch.cat([x0, x1], 1)
-             return x
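The trickiest part of the file above is the pad-and-reshape conversion in _relative_position_to_absolute_position: it turns logits indexed by relative offset (j - i + l - 1) into logits indexed by absolute key position j using only two pads and two views, with no gather op. A minimal standalone sketch that checks the trick against naive indexing; the helper names here are illustrative, not part of the module:

import torch
import torch.nn.functional as F

def rel_to_abs(x):
    # Same steps as the module method: pad, flatten, pad, reshape, slice.
    b, h, l, _ = x.size()
    x = F.pad(x, [0, 1])                      # [b, h, l, 2*l]
    x = x.view(b, h, l * 2 * l)
    x = F.pad(x, [0, l - 1])                  # l*2*l + l - 1 = (l+1)*(2*l-1) elements
    return x.view(b, h, l + 1, 2 * l - 1)[:, :, :l, l - 1 :]

def rel_to_abs_naive(x):
    # Direct definition: out[i, j] = x[i, (j - i) + (l - 1)].
    b, h, l, _ = x.size()
    out = torch.zeros(b, h, l, l)
    for i in range(l):          # query position
        for j in range(l):      # key position
            out[:, :, i, j] = x[:, :, i, (j - i) + (l - 1)]
    return out

x = torch.randn(1, 2, 5, 9)     # l = 5, so 2*l - 1 = 9 relative positions
assert torch.allclose(rel_to_abs(x), rel_to_abs_naive(x))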
module/attentions_onnx.py DELETED
@@ -1,354 +0,0 @@
- import logging
- import math
- import torch
- from torch import nn
- from torch.nn import functional as F
- 
- from module import commons
- 
- 
- class LayerNorm(nn.Module):
-     def __init__(self, channels, eps=1e-5):
-         super().__init__()
-         self.channels = channels
-         self.eps = eps
- 
-         self.gamma = nn.Parameter(torch.ones(channels))
-         self.beta = nn.Parameter(torch.zeros(channels))
- 
-     def forward(self, x):
-         x = x.transpose(1, -1)
-         x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
-         return x.transpose(1, -1)
- 
- 
- @torch.jit.script
- def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
-     n_channels_int = n_channels[0]
-     in_act = input_a + input_b
-     t_act = torch.tanh(in_act[:, :n_channels_int, :])
-     s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
-     acts = t_act * s_act
-     return acts
- 
- 
- class Encoder(nn.Module):
-     def __init__(
-         self,
-         hidden_channels,
-         filter_channels,
-         n_heads,
-         n_layers,
-         kernel_size=1,
-         p_dropout=0.0,
-         window_size=4,
-         isflow=True,
-         **kwargs
-     ):
-         super().__init__()
-         self.hidden_channels = hidden_channels
-         self.filter_channels = filter_channels
-         self.n_heads = n_heads
-         self.n_layers = n_layers
-         self.kernel_size = kernel_size
-         self.p_dropout = p_dropout
-         self.window_size = window_size
-         # if isflow:
-         #     cond_layer = torch.nn.Conv1d(256, 2*hidden_channels*n_layers, 1)
-         #     self.cond_pre = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, 1)
-         #     self.cond_layer = weight_norm(cond_layer, name='weight')
-         #     self.gin_channels = 256
-         self.cond_layer_idx = self.n_layers
-         if "gin_channels" in kwargs:
-             self.gin_channels = kwargs["gin_channels"]
-             if self.gin_channels != 0:
-                 self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels)
-                 # vits2 says 3rd block, so idx is 2 by default
-                 self.cond_layer_idx = (
-                     kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2
-                 )
-             logging.debug("gin_channels: %s, cond_layer_idx: %s", self.gin_channels, self.cond_layer_idx)
-             assert (
-                 self.cond_layer_idx < self.n_layers
-             ), "cond_layer_idx should be less than n_layers"
-         self.drop = nn.Dropout(p_dropout)
-         self.attn_layers = nn.ModuleList()
-         self.norm_layers_1 = nn.ModuleList()
-         self.ffn_layers = nn.ModuleList()
-         self.norm_layers_2 = nn.ModuleList()
-         for i in range(self.n_layers):
-             self.attn_layers.append(
-                 MultiHeadAttention(
-                     hidden_channels,
-                     hidden_channels,
-                     n_heads,
-                     p_dropout=p_dropout,
-                     window_size=window_size,
-                 )
-             )
-             self.norm_layers_1.append(LayerNorm(hidden_channels))
-             self.ffn_layers.append(
-                 FFN(
-                     hidden_channels,
-                     hidden_channels,
-                     filter_channels,
-                     kernel_size,
-                     p_dropout=p_dropout,
-                 )
-             )
-             self.norm_layers_2.append(LayerNorm(hidden_channels))
- 
-     def forward(self, x, x_mask, g=None):
-         attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
-         x = x * x_mask
-         for i in range(self.n_layers):
-             if i == self.cond_layer_idx and g is not None:
-                 g = self.spk_emb_linear(g.transpose(1, 2))
-                 g = g.transpose(1, 2)
-                 x = x + g
-                 x = x * x_mask
-             y = self.attn_layers[i](x, x, attn_mask)
-             y = self.drop(y)
-             x = self.norm_layers_1[i](x + y)
- 
-             y = self.ffn_layers[i](x, x_mask)
-             y = self.drop(y)
-             x = self.norm_layers_2[i](x + y)
-         x = x * x_mask
-         return x
- 
- 
- class MultiHeadAttention(nn.Module):
-     def __init__(
-         self,
-         channels,
-         out_channels,
-         n_heads,
-         p_dropout=0.0,
-         window_size=None,
-         heads_share=True,
-         block_length=None,
-         proximal_bias=False,
-         proximal_init=False,
-     ):
-         super().__init__()
-         assert channels % n_heads == 0
- 
-         self.channels = channels
-         self.out_channels = out_channels
-         self.n_heads = n_heads
-         self.p_dropout = p_dropout
-         self.window_size = window_size
-         self.heads_share = heads_share
-         self.block_length = block_length
-         self.proximal_bias = proximal_bias
-         self.proximal_init = proximal_init
-         self.attn = None
- 
-         self.k_channels = channels // n_heads
-         self.conv_q = nn.Conv1d(channels, channels, 1)
-         self.conv_k = nn.Conv1d(channels, channels, 1)
-         self.conv_v = nn.Conv1d(channels, channels, 1)
-         self.conv_o = nn.Conv1d(channels, out_channels, 1)
-         self.drop = nn.Dropout(p_dropout)
- 
-         if window_size is not None:
-             n_heads_rel = 1 if heads_share else n_heads
-             rel_stddev = self.k_channels**-0.5
-             self.emb_rel_k = nn.Parameter(
-                 torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
-                 * rel_stddev
-             )
-             self.emb_rel_v = nn.Parameter(
-                 torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
-                 * rel_stddev
-             )
- 
-         nn.init.xavier_uniform_(self.conv_q.weight)
-         nn.init.xavier_uniform_(self.conv_k.weight)
-         nn.init.xavier_uniform_(self.conv_v.weight)
-         if proximal_init:
-             with torch.no_grad():
-                 self.conv_k.weight.copy_(self.conv_q.weight)
-                 self.conv_k.bias.copy_(self.conv_q.bias)
- 
-     def forward(self, x, c, attn_mask=None):
-         q = self.conv_q(x)
-         k = self.conv_k(c)
-         v = self.conv_v(c)
- 
-         x, self.attn = self.attention(q, k, v, mask=attn_mask)
- 
-         x = self.conv_o(x)
-         return x
- 
-     def attention(self, query, key, value, mask=None):
-         # reshape [b, d, t] -> [b, n_h, t, d_k]
-         b, d, t_s, _ = (*key.size(), query.size(2))
-         query = query.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3)
-         key = key.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3)
-         value = value.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3)
-         scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
- 
-         if self.window_size is not None:
-             key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
-             rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings)
-             scores_local = self._relative_position_to_absolute_position(rel_logits)
-             scores = scores + scores_local
- 
-         if mask is not None:
-             scores = scores.masked_fill(mask == 0, -1e4)
- 
-         p_attn = F.softmax(scores, dim=-1)
-         p_attn = self.drop(p_attn)
-         output = torch.matmul(p_attn, value)
- 
-         if self.window_size is not None:
-             relative_weights = self._absolute_position_to_relative_position(p_attn)
-             value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
-             output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
- 
-         output = output.transpose(2, 3).contiguous().view(b, d, -1)
-         return output, p_attn
- 
-     def _matmul_with_relative_values(self, x, y):
-         """
-         x: [b, h, l, m]
-         y: [h or 1, m, d]
-         ret: [b, h, l, d]
-         """
-         ret = torch.matmul(x, y.unsqueeze(0))
-         return ret
- 
-     def _matmul_with_relative_keys(self, x, y):
-         """
-         x: [b, h, l, d]
-         y: [h or 1, m, d]
-         ret: [b, h, l, m]
-         """
-         ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
-         return ret
- 
-     def _get_relative_embeddings(self, relative_embeddings, length):
-         max_relative_position = 2 * self.window_size + 1
-         # Pad first before slice to avoid using cond ops.
-         pad_l = torch.zeros((1), dtype=torch.int64) + length - (self.window_size + 1)
-         pad_s = torch.zeros((1), dtype=torch.int64) + (self.window_size + 1) - length
-         pad_length = torch.max(pad_l, other=torch.zeros((1), dtype=torch.int64))
-         slice_start_position = torch.max(pad_s, other=torch.zeros((1), dtype=torch.int64))
- 
-         slice_end_position = slice_start_position + 2 * length - 1
-         padded_relative_embeddings = F.pad(
-             relative_embeddings,
-             commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
-         )
-         used_relative_embeddings = padded_relative_embeddings[
-             :, slice_start_position:slice_end_position
-         ]
-         return used_relative_embeddings
- 
-     def _relative_position_to_absolute_position(self, x):
-         """
-         x: [b, h, l, 2*l-1]
-         ret: [b, h, l, l]
-         """
-         batch, heads, length, _ = x.size()
-         # Concat columns of pad to shift from relative to absolute indexing.
-         x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
- 
-         # Concat extra elements so to add up to shape (len+1, 2*len-1).
-         x_flat = x.view([batch, heads, length * 2 * length])
-         x_flat = F.pad(
-             x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
-         )
- 
-         # Reshape and slice out the padded elements.
-         x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
-             :, :, :length, length - 1 :
-         ]
-         return x_final
- 
-     def _absolute_position_to_relative_position(self, x):
-         """
-         x: [b, h, l, l]
-         ret: [b, h, l, 2*l-1]
-         """
-         batch, heads, length, _ = x.size()
-         # pad along column
-         x = F.pad(
-             x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
-         )
-         x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
-         # add 0's in the beginning that will skew the elements after reshape
-         x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
-         x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
-         return x_final
- 
-     def _attention_bias_proximal(self, length):
-         """Bias for self-attention to encourage attention to close positions.
-         Args:
-             length: an integer scalar.
-         Returns:
-             a Tensor with shape [1, 1, length, length]
-         """
-         r = torch.arange(length, dtype=torch.float32)
-         diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
-         return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
- 
- 
- class FFN(nn.Module):
-     def __init__(
-         self,
-         in_channels,
-         out_channels,
-         filter_channels,
-         kernel_size,
-         p_dropout=0.0,
-         activation=None,
-         causal=False,
-     ):
-         super().__init__()
-         self.in_channels = in_channels
-         self.out_channels = out_channels
-         self.filter_channels = filter_channels
-         self.kernel_size = kernel_size
-         self.p_dropout = p_dropout
-         self.activation = activation
-         self.causal = causal
- 
-         if causal:
-             self.padding = self._causal_padding
-         else:
-             self.padding = self._same_padding
- 
-         self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
-         self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
-         self.drop = nn.Dropout(p_dropout)
- 
-     def forward(self, x, x_mask):
-         x = self.conv_1(self.padding(x * x_mask))
-         if self.activation == "gelu":
-             x = x * torch.sigmoid(1.702 * x)
-         else:
-             x = torch.relu(x)
-         x = self.drop(x)
-         x = self.conv_2(self.padding(x * x_mask))
-         return x * x_mask
- 
-     def _causal_padding(self, x):
-         if self.kernel_size == 1:
-             return x
-         pad_l = self.kernel_size - 1
-         pad_r = 0
-         padding = [[0, 0], [0, 0], [pad_l, pad_r]]
-         x = F.pad(x, commons.convert_pad_shape(padding))
-         return x
- 
-     def _same_padding(self, x):
-         if self.kernel_size == 1:
-             return x
-         pad_l = (self.kernel_size - 1) // 2
-         pad_r = self.kernel_size // 2
-         padding = [[0, 0], [0, 0], [pad_l, pad_r]]
-         x = F.pad(x, commons.convert_pad_shape(padding))
-         return x
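The substantive difference from the eager _get_relative_embeddings in module/attentions.py above is that the pad and slice offsets here are computed with int64 tensor ops instead of Python max() and an if branch: during ONNX tracing, Python-level control flow is evaluated once and frozen into constants for the example input's sequence length, while tensor ops stay in the exported graph. A quick check (using only the standard torch API) that the two formulations agree:

import torch

window_size = 4
for length in (1, 3, 5, 20):
    # eager form (module/attentions.py)
    pad_eager = max(length - (window_size + 1), 0)
    start_eager = max((window_size + 1) - length, 0)
    # trace-friendly form (module/attentions_onnx.py)
    zero = torch.zeros(1, dtype=torch.int64)
    pad_traced = torch.max(zero + length - (window_size + 1), other=zero)
    start_traced = torch.max(zero + (window_size + 1) - length, other=zero)
    assert pad_traced.item() == pad_eager
    assert start_traced.item() == start_eager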
module/commons.py DELETED
@@ -1,189 +0,0 @@
- import math
- import torch
- from torch.nn import functional as F
- 
- 
- def init_weights(m, mean=0.0, std=0.01):
-     classname = m.__class__.__name__
-     if classname.find("Conv") != -1:
-         m.weight.data.normal_(mean, std)
- 
- 
- def get_padding(kernel_size, dilation=1):
-     return int((kernel_size * dilation - dilation) / 2)
- 
- 
- def convert_pad_shape(pad_shape):
-     l = pad_shape[::-1]
-     pad_shape = [item for sublist in l for item in sublist]
-     return pad_shape
- 
- 
- def intersperse(lst, item):
-     result = [item] * (len(lst) * 2 + 1)
-     result[1::2] = lst
-     return result
- 
- 
- def kl_divergence(m_p, logs_p, m_q, logs_q):
-     """KL(P||Q)"""
-     kl = (logs_q - logs_p) - 0.5
-     kl += (
-         0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
-     )
-     return kl
- 
- 
- def rand_gumbel(shape):
-     """Sample from the Gumbel distribution, protect from overflows."""
-     uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
-     return -torch.log(-torch.log(uniform_samples))
- 
- 
- def rand_gumbel_like(x):
-     g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
-     return g
- 
- 
- def slice_segments(x, ids_str, segment_size=4):
-     ret = torch.zeros_like(x[:, :, :segment_size])
-     for i in range(x.size(0)):
-         idx_str = ids_str[i]
-         idx_end = idx_str + segment_size
-         ret[i] = x[i, :, idx_str:idx_end]
-     return ret
- 
- 
- def rand_slice_segments(x, x_lengths=None, segment_size=4):
-     b, d, t = x.size()
-     if x_lengths is None:
-         x_lengths = t
-     ids_str_max = x_lengths - segment_size + 1
-     ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
-     ret = slice_segments(x, ids_str, segment_size)
-     return ret, ids_str
- 
- 
- def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
-     position = torch.arange(length, dtype=torch.float)
-     num_timescales = channels // 2
-     log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
-         num_timescales - 1
-     )
-     inv_timescales = min_timescale * torch.exp(
-         torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
-     )
-     scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
-     signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
-     signal = F.pad(signal, [0, 0, 0, channels % 2])
-     signal = signal.view(1, channels, length)
-     return signal
- 
- 
- def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
-     b, channels, length = x.size()
-     signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
-     return x + signal.to(dtype=x.dtype, device=x.device)
- 
- 
- def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
-     b, channels, length = x.size()
-     signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
-     return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
- 
- 
- def subsequent_mask(length):
-     mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
-     return mask
- 
- 
- @torch.jit.script
- def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
-     n_channels_int = n_channels[0]
-     in_act = input_a + input_b
-     t_act = torch.tanh(in_act[:, :n_channels_int, :])
-     s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
-     acts = t_act * s_act
-     return acts
- 
- 
- def convert_pad_shape(pad_shape):
-     l = pad_shape[::-1]
-     pad_shape = [item for sublist in l for item in sublist]
-     return pad_shape
- 
- 
- def shift_1d(x):
-     x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
-     return x
- 
- 
- def sequence_mask(length, max_length=None):
-     if max_length is None:
-         max_length = length.max()
-     x = torch.arange(max_length, dtype=length.dtype, device=length.device)
-     return x.unsqueeze(0) < length.unsqueeze(1)
- 
- 
- def generate_path(duration, mask):
-     """
-     duration: [b, 1, t_x]
-     mask: [b, 1, t_y, t_x]
-     """
-     device = duration.device
- 
-     b, _, t_y, t_x = mask.shape
-     cum_duration = torch.cumsum(duration, -1)
- 
-     cum_duration_flat = cum_duration.view(b * t_x)
-     path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
-     path = path.view(b, t_x, t_y)
-     path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
-     path = path.unsqueeze(1).transpose(2, 3) * mask
-     return path
- 
- 
- def clip_grad_value_(parameters, clip_value, norm_type=2):
-     if isinstance(parameters, torch.Tensor):
-         parameters = [parameters]
-     parameters = list(filter(lambda p: p.grad is not None, parameters))
-     norm_type = float(norm_type)
-     if clip_value is not None:
-         clip_value = float(clip_value)
- 
-     total_norm = 0
-     for p in parameters:
-         param_norm = p.grad.data.norm(norm_type)
-         total_norm += param_norm.item() ** norm_type
-         if clip_value is not None:
-             p.grad.data.clamp_(min=-clip_value, max=clip_value)
-     total_norm = total_norm ** (1.0 / norm_type)
-     return total_norm
- 
- 
- def squeeze(x, x_mask=None, n_sqz=2):
-     b, c, t = x.size()
- 
-     t = (t // n_sqz) * n_sqz
-     x = x[:, :, :t]
-     x_sqz = x.view(b, c, t // n_sqz, n_sqz)
-     x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous().view(b, c * n_sqz, t // n_sqz)
- 
-     if x_mask is not None:
-         x_mask = x_mask[:, :, n_sqz - 1 :: n_sqz]
-     else:
-         x_mask = torch.ones(b, 1, t // n_sqz).to(device=x.device, dtype=x.dtype)
-     return x_sqz * x_mask, x_mask
- 
- 
- def unsqueeze(x, x_mask=None, n_sqz=2):
-     b, c, t = x.size()
- 
-     x_unsqz = x.view(b, n_sqz, c // n_sqz, t)
-     x_unsqz = x_unsqz.permute(0, 2, 3, 1).contiguous().view(b, c // n_sqz, t * n_sqz)
- 
-     if x_mask is not None:
-         x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1, n_sqz).view(b, 1, t * n_sqz)
-     else:
-         x_mask = torch.ones(b, 1, t * n_sqz).to(device=x.device, dtype=x.dtype)
-     return x_unsqz * x_mask, x_mask
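Two of the helpers above are used throughout the attention code and are worth pinning down. convert_pad_shape turns a per-dimension [[before, after], ...] spec (outermost dimension first) into the flat, last-dimension-first list that torch.nn.functional.pad expects, and sequence_mask turns per-item lengths into a boolean [b, t_max] mask. A small illustration, assuming the file is importable as module.commons as the imports in the other files suggest:

import torch
from module.commons import convert_pad_shape, sequence_mask

# Pad only the last of three dimensions: 1 on the left, 2 on the right.
assert convert_pad_shape([[0, 0], [0, 0], [1, 2]]) == [1, 2, 0, 0, 0, 0]

lengths = torch.tensor([2, 4])
print(sequence_mask(lengths))
# tensor([[ True,  True, False, False],
#         [ True,  True,  True,  True]])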
module/core_vq.py DELETED
@@ -1,383 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
- #
- # This implementation is inspired from
- # https://github.com/lucidrains/vector-quantize-pytorch
- # which is released under MIT License. Hereafter, the original license:
- # MIT License
- #
- # Copyright (c) 2020 Phil Wang
- #
- # Permission is hereby granted, free of charge, to any person obtaining a copy
- # of this software and associated documentation files (the "Software"), to deal
- # in the Software without restriction, including without limitation the rights
- # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- # copies of the Software, and to permit persons to whom the Software is
- # furnished to do so, subject to the following conditions:
- #
- # The above copyright notice and this permission notice shall be included in all
- # copies or substantial portions of the Software.
- #
- # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- # SOFTWARE.
- 
- """Core vector quantization implementation."""
- import typing as tp
- 
- from einops import rearrange, repeat
- import torch
- from torch import nn
- import torch.nn.functional as F
- from tqdm import tqdm
- 
- 
- def default(val: tp.Any, d: tp.Any) -> tp.Any:
-     return val if val is not None else d
- 
- 
- def ema_inplace(moving_avg, new, decay: float):
-     moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
- 
- 
- def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
-     return (x + epsilon) / (x.sum() + n_categories * epsilon)
- 
- 
- def uniform_init(*shape: int):
-     t = torch.empty(shape)
-     nn.init.kaiming_uniform_(t)
-     return t
- 
- 
- def sample_vectors(samples, num: int):
-     num_samples, device = samples.shape[0], samples.device
- 
-     if num_samples >= num:
-         indices = torch.randperm(num_samples, device=device)[:num]
-     else:
-         indices = torch.randint(0, num_samples, (num,), device=device)
- 
-     return samples[indices]
- 
- 
- def kmeans(samples, num_clusters: int, num_iters: int = 10):
-     dim, dtype = samples.shape[-1], samples.dtype
-     max_kmeans_samples = 500
-     samples = samples[:max_kmeans_samples, :]
-     means = sample_vectors(samples, num_clusters)
- 
-     print("kmeans start ... ")
-     for _ in tqdm(range(num_iters)):
-         diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
-         dists = -(diffs**2).sum(dim=-1)
- 
-         buckets = dists.max(dim=-1).indices
-         bins = torch.bincount(buckets, minlength=num_clusters)
-         zero_mask = bins == 0
-         bins_min_clamped = bins.masked_fill(zero_mask, 1)
- 
-         new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
-         new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
-         new_means = new_means / bins_min_clamped[..., None]
- 
-         means = torch.where(zero_mask[..., None], means, new_means)
- 
-     return means, bins
- 
- 
- class EuclideanCodebook(nn.Module):
-     """Codebook with Euclidean distance.
-     Args:
-         dim (int): Dimension.
-         codebook_size (int): Codebook size.
-         kmeans_init (bool): Whether to use k-means to initialize the codebooks.
-             If set to true, run the k-means algorithm on the first training batch and use
-             the learned centroids as initialization.
-         kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
-         decay (float): Decay for exponential moving average over the codebooks.
-         epsilon (float): Epsilon value for numerical stability.
-         threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
-             that have an exponential moving average cluster size less than the specified threshold with
-             randomly selected vector from the current batch.
-     """
- 
-     def __init__(
-         self,
-         dim: int,
-         codebook_size: int,
-         kmeans_init: bool = False,
-         kmeans_iters: int = 10,
-         decay: float = 0.99,
-         epsilon: float = 1e-5,
-         threshold_ema_dead_code: int = 2,
-     ):
-         super().__init__()
-         self.decay = decay
-         init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = (
-             uniform_init if not kmeans_init else torch.zeros
-         )
-         embed = init_fn(codebook_size, dim)
- 
-         self.codebook_size = codebook_size
- 
-         self.kmeans_iters = kmeans_iters
-         self.epsilon = epsilon
-         self.threshold_ema_dead_code = threshold_ema_dead_code
- 
-         self.register_buffer("inited", torch.Tensor([not kmeans_init]))
-         self.register_buffer("cluster_size", torch.zeros(codebook_size))
-         self.register_buffer("embed", embed)
-         self.register_buffer("embed_avg", embed.clone())
- 
-     @torch.jit.ignore
-     def init_embed_(self, data):
-         if self.inited:
-             return
- 
-         embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
-         self.embed.data.copy_(embed)
-         self.embed_avg.data.copy_(embed.clone())
-         self.cluster_size.data.copy_(cluster_size)
-         self.inited.data.copy_(torch.Tensor([True]))
-         # Make sure all buffers across workers are in sync after initialization
-         # broadcast_tensors(self.buffers())
- 
-     def replace_(self, samples, mask):
-         modified_codebook = torch.where(
-             mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
-         )
-         self.embed.data.copy_(modified_codebook)
- 
-     def expire_codes_(self, batch_samples):
-         if self.threshold_ema_dead_code == 0:
-             return
- 
-         expired_codes = self.cluster_size < self.threshold_ema_dead_code
-         if not torch.any(expired_codes):
-             return
- 
-         batch_samples = rearrange(batch_samples, "... d -> (...) d")
-         self.replace_(batch_samples, mask=expired_codes)
-         # broadcast_tensors(self.buffers())
- 
-     def preprocess(self, x):
-         x = rearrange(x, "... d -> (...) d")
-         return x
- 
-     def quantize(self, x):
-         embed = self.embed.t()
-         dist = -(
-             x.pow(2).sum(1, keepdim=True)
-             - 2 * x @ embed
-             + embed.pow(2).sum(0, keepdim=True)
-         )
-         embed_ind = dist.max(dim=-1).indices
-         return embed_ind
- 
-     def postprocess_emb(self, embed_ind, shape):
-         return embed_ind.view(*shape[:-1])
- 
-     def dequantize(self, embed_ind):
-         quantize = F.embedding(embed_ind, self.embed)
-         return quantize
- 
-     def encode(self, x):
-         shape = x.shape
-         # pre-process
-         x = self.preprocess(x)
-         # quantize
-         embed_ind = self.quantize(x)
-         # post-process
-         embed_ind = self.postprocess_emb(embed_ind, shape)
-         return embed_ind
- 
-     def decode(self, embed_ind):
-         quantize = self.dequantize(embed_ind)
-         return quantize
- 
-     def forward(self, x):
-         shape, dtype = x.shape, x.dtype
-         x = self.preprocess(x)
- 
-         self.init_embed_(x)
- 
-         embed_ind = self.quantize(x)
-         embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
-         embed_ind = self.postprocess_emb(embed_ind, shape)
-         quantize = self.dequantize(embed_ind)
- 
-         if self.training:
-             # We do the expiry of code at that point as buffers are in sync
-             # and all the workers will take the same decision.
-             self.expire_codes_(x)
-             ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
-             embed_sum = x.t() @ embed_onehot
-             ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
-             cluster_size = (
-                 laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
-                 * self.cluster_size.sum()
-             )
-             embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
-             self.embed.data.copy_(embed_normalized)
- 
-         return quantize, embed_ind
- 
- 
- class VectorQuantization(nn.Module):
-     """Vector quantization implementation.
-     Currently supports only euclidean distance.
-     Args:
-         dim (int): Dimension
-         codebook_size (int): Codebook size
-         codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
-         decay (float): Decay for exponential moving average over the codebooks.
-         epsilon (float): Epsilon value for numerical stability.
-         kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
-         kmeans_iters (int): Number of iterations used for kmeans initialization.
-         threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
-             that have an exponential moving average cluster size less than the specified threshold with
-             randomly selected vector from the current batch.
-         commitment_weight (float): Weight for commitment loss.
-     """
- 
-     def __init__(
-         self,
-         dim: int,
-         codebook_size: int,
-         codebook_dim: tp.Optional[int] = None,
-         decay: float = 0.99,
-         epsilon: float = 1e-5,
-         kmeans_init: bool = True,
-         kmeans_iters: int = 50,
-         threshold_ema_dead_code: int = 2,
-         commitment_weight: float = 1.0,
-     ):
-         super().__init__()
-         _codebook_dim: int = default(codebook_dim, dim)
- 
-         requires_projection = _codebook_dim != dim
-         self.project_in = (
-             nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
-         )
-         self.project_out = (
-             nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
-         )
- 
-         self.epsilon = epsilon
-         self.commitment_weight = commitment_weight
- 
-         self._codebook = EuclideanCodebook(
-             dim=_codebook_dim,
-             codebook_size=codebook_size,
-             kmeans_init=kmeans_init,
-             kmeans_iters=kmeans_iters,
-             decay=decay,
-             epsilon=epsilon,
-             threshold_ema_dead_code=threshold_ema_dead_code,
-         )
-         self.codebook_size = codebook_size
- 
-     @property
-     def codebook(self):
-         return self._codebook.embed
- 
-     def encode(self, x):
-         x = rearrange(x, "b d n -> b n d")
-         x = self.project_in(x)
-         embed_in = self._codebook.encode(x)
-         return embed_in
- 
-     def decode(self, embed_ind):
-         quantize = self._codebook.decode(embed_ind)
-         quantize = self.project_out(quantize)
-         quantize = rearrange(quantize, "b n d -> b d n")
-         return quantize
- 
-     def forward(self, x):
-         device = x.device
-         x = rearrange(x, "b d n -> b n d")
-         x = self.project_in(x)
- 
-         quantize, embed_ind = self._codebook(x)
- 
-         if self.training:
-             quantize = x + (quantize - x).detach()
- 
-         loss = torch.tensor([0.0], device=device, requires_grad=self.training)
- 
-         if self.training:
-             if self.commitment_weight > 0:
-                 commit_loss = F.mse_loss(quantize.detach(), x)
-                 loss = loss + commit_loss * self.commitment_weight
- 
-         quantize = self.project_out(quantize)
-         quantize = rearrange(quantize, "b n d -> b d n")
-         return quantize, embed_ind, loss
- 
- 
- class ResidualVectorQuantization(nn.Module):
-     """Residual vector quantization implementation.
-     Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
-     """
- 
-     def __init__(self, *, num_quantizers, **kwargs):
-         super().__init__()
-         self.layers = nn.ModuleList(
-             [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
-         )
- 
-     def forward(
-         self, x, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None
-     ):
-         quantized_out = 0.0
-         residual = x
- 
-         all_losses = []
-         all_indices = []
-         out_quantized = []
- 
-         n_q = n_q or len(self.layers)
- 
-         for i, layer in enumerate(self.layers[:n_q]):
-             quantized, indices, loss = layer(residual)
-             residual = residual - quantized
-             quantized_out = quantized_out + quantized
- 
-             all_indices.append(indices)
-             all_losses.append(loss)
-             if layers and i in layers:
-                 out_quantized.append(quantized)
- 
-         out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
-         return quantized_out, out_indices, out_losses, out_quantized
- 
-     def encode(
-         self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
-     ) -> torch.Tensor:
-         residual = x
-         all_indices = []
-         n_q = n_q or len(self.layers)
-         st = st or 0
-         for layer in self.layers[st:n_q]:
-             indices = layer.encode(residual)
-             quantized = layer.decode(indices)
-             residual = residual - quantized
-             all_indices.append(indices)
-         out_indices = torch.stack(all_indices)
-         return out_indices
- 
-     def decode(self, q_indices: torch.Tensor, st: int = 0) -> torch.Tensor:
-         quantized_out = torch.tensor(0.0, device=q_indices.device)
-         for i, indices in enumerate(q_indices):
-             layer = self.layers[st + i]
-             quantized = layer.decode(indices)
-             quantized_out = quantized_out + quantized
-         return quantized_out
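A minimal usage sketch for the residual quantizer above, assuming the file is importable as module.core_vq; the hyperparameters are illustrative, and the shape conventions follow the rearrange calls (the quantizers consume [batch, dim, n_frames]):

import torch
from module.core_vq import ResidualVectorQuantization

rvq = ResidualVectorQuantization(
    num_quantizers=4, dim=256, codebook_size=1024, kmeans_init=False
)
x = torch.randn(2, 256, 50)        # [b, d, n]
quantized, indices, losses, _ = rvq(x)
print(quantized.shape)             # torch.Size([2, 256, 50])
print(indices.shape)               # torch.Size([4, 2, 50]), one code per layer and frame
print(losses.shape)                # torch.Size([4, 1]), commitment loss per layer

Each layer quantizes the residual left by the previous layers, so decoding sums the per-layer codebook vectors; in training mode the straight-through estimator (quantize = x + (quantize - x).detach()) lets gradients reach the encoder while the codebooks themselves are updated by EMA rather than backprop.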
module/data_utils.py DELETED
@@ -1,332 +0,0 @@
1
- import time
2
- import logging
3
- import os
4
- import random
5
- import traceback
6
- import numpy as np
7
- import torch
8
- import torch.utils.data
9
- from tqdm import tqdm
10
-
11
- from module import commons
12
- from module.mel_processing import spectrogram_torch
13
- from text import cleaned_text_to_sequence
14
- from utils import load_wav_to_torch, load_filepaths_and_text
15
- import torch.nn.functional as F
16
- from functools import lru_cache
17
- import requests
18
- from scipy.io import wavfile
19
- from io import BytesIO
20
- from tools.my_utils import load_audio
21
- version = os.environ.get('version',None)
22
- # ZeroDivisionError fixed by Tybost (https://github.com/RVC-Boss/GPT-SoVITS/issues/79)
23
- class TextAudioSpeakerLoader(torch.utils.data.Dataset):
24
- """
25
- 1) loads audio, speaker_id, text pairs
26
- 2) normalizes text and converts them to sequences of integers
27
- 3) computes spectrograms from audio files.
28
- """
29
-
30
- def __init__(self, hparams, val=False):
31
- exp_dir = hparams.exp_dir
32
- self.path2 = "%s/2-name2text.txt" % exp_dir
33
- self.path4 = "%s/4-cnhubert" % exp_dir
34
- self.path5 = "%s/5-wav32k" % exp_dir
35
- assert os.path.exists(self.path2)
36
- assert os.path.exists(self.path4)
37
- assert os.path.exists(self.path5)
38
- names4 = set([name[:-3] for name in list(os.listdir(self.path4))]) # strip the ".pt" suffix
39
- names5 = set(os.listdir(self.path5))
40
- self.phoneme_data = {}
41
- with open(self.path2, "r", encoding="utf8") as f:
42
- lines = f.read().strip("\n").split("\n")
43
-
44
- for line in lines:
45
- tmp = line.split("\t")
46
- if (len(tmp) != 4):
47
- continue
48
- self.phoneme_data[tmp[0]] = [tmp[1]]
49
-
50
- self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5)
51
- tmp = self.audiopaths_sid_text
52
- leng = len(tmp)
53
- min_num = 100
54
- if (leng < min_num):
55
- self.audiopaths_sid_text = []
56
- for _ in range(max(2, int(min_num / leng))):
57
- self.audiopaths_sid_text += tmp
58
- self.max_wav_value = hparams.max_wav_value
59
- self.sampling_rate = hparams.sampling_rate
60
- self.filter_length = hparams.filter_length
61
- self.hop_length = hparams.hop_length
62
- self.win_length = hparams.win_length
63
- self.sampling_rate = hparams.sampling_rate
64
- self.val = val
65
-
66
- random.seed(1234)
67
- random.shuffle(self.audiopaths_sid_text)
68
-
69
- print("phoneme_data_len:", len(self.phoneme_data.keys()))
70
- print("wav_data_len:", len(self.audiopaths_sid_text))
71
-
72
- audiopaths_sid_text_new = []
73
- lengths = []
74
- skipped_phone = 0
75
- skipped_dur = 0
76
- for audiopath in tqdm(self.audiopaths_sid_text):
77
- try:
78
- phoneme = self.phoneme_data[audiopath][0]
79
- phoneme = phoneme.split(' ')
80
- phoneme_ids = cleaned_text_to_sequence(phoneme, version)
81
- except Exception:
82
- print(f"{audiopath} not in self.phoneme_data !")
83
- skipped_phone += 1
84
- continue
85
-
86
- size = os.path.getsize("%s/%s" % (self.path5, audiopath))
87
- duration = size / self.sampling_rate / 2
88
-
89
- if duration == 0:
90
- print(f"Zero duration for {audiopath}, skipping...")
91
- skipped_dur += 1
92
- continue
93
-
94
- if 54 > duration > 0.6 or self.val:
95
- audiopaths_sid_text_new.append([audiopath, phoneme_ids])
96
- lengths.append(size // (2 * self.hop_length))
97
- else:
98
- skipped_dur += 1
99
- continue
100
-
101
- print("skipped_phone: ", skipped_phone, ", skipped_dur: ", skipped_dur)
102
- print("total left: ", len(audiopaths_sid_text_new))
103
- assert len(audiopaths_sid_text_new) > 1 # must be at least enough to fill a batch; TODO
104
- self.audiopaths_sid_text = audiopaths_sid_text_new
105
- self.lengths = lengths
106
-
107
- def get_audio_text_speaker_pair(self, audiopath_sid_text):
108
- audiopath, phoneme_ids = audiopath_sid_text
109
- text = torch.FloatTensor(phoneme_ids)
110
- try:
111
- spec, wav = self.get_audio("%s/%s" % (self.path5, audiopath))
112
- with torch.no_grad():
113
- ssl = torch.load("%s/%s.pt" % (self.path4, audiopath), map_location="cpu")
114
- if (ssl.shape[-1] != spec.shape[-1]):
115
- ssl_dtype = ssl.dtype
116
- ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(ssl_dtype)
117
- ssl.requires_grad = False
118
- except Exception:
119
- traceback.print_exc()
120
- spec = torch.zeros(1025, 100)
121
- wav = torch.zeros(1, 100 * self.hop_length)
122
- ssl = torch.zeros(1, 768, 100)
123
- text = text[-1:]
124
- print("load audio or ssl error!!!!!!", audiopath)
125
- return (ssl, spec, wav, text)
126
-
127
- def get_audio(self, filename):
128
- audio_array = load_audio(filename, self.sampling_rate) # load_audio already normalizes to [-1, 1]; no further /32768 is needed
129
- audio = torch.FloatTensor(audio_array) # /32768
130
- audio_norm = audio
131
- audio_norm = audio_norm.unsqueeze(0)
132
- spec = spectrogram_torch(audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length,
133
- center=False)
134
- spec = torch.squeeze(spec, 0)
135
- return spec, audio_norm
136
-
137
- def get_sid(self, sid):
138
- sid = torch.LongTensor([int(sid)])
139
- return sid
140
-
141
- def __getitem__(self, index):
142
- # with torch.no_grad():
143
- return self.get_audio_text_speaker_pair(self.audiopaths_sid_text[index])
144
-
145
- def __len__(self):
146
- return len(self.audiopaths_sid_text)
147
-
148
- def random_slice(self, ssl, wav, mel):
149
- assert abs(ssl.shape[-1] - wav.shape[-1] // self.hop_length) < 3, (
150
- "first", ssl.shape, wav.shape)
151
-
152
- len_mel = mel.shape[1]
153
- if self.val:
154
- reference_mel = mel[:, :len_mel // 3]
155
- return reference_mel, ssl, wav, mel
156
- dir = random.randint(0, 1)
157
- sep_point = random.randint(int(len_mel // 3), int(len_mel // 3 * 2))
158
-
159
- if dir == 0:
160
- reference_mel = mel[:, :sep_point]
161
- ssl = ssl[:, :, sep_point:]
162
- wav2 = wav[:, sep_point * self.hop_length:]
163
- mel = mel[:, sep_point:]
164
- else:
165
- reference_mel = mel[:, sep_point:]
166
- ssl = ssl[:, :, :sep_point]
167
- wav2 = wav[:, :sep_point * self.hop_length]
168
- mel = mel[:, :sep_point]
169
-
170
- assert abs(ssl.shape[-1] - wav2.shape[-1] // self.hop_length) < 3, (
171
- ssl.shape, wav.shape, wav2.shape, mel.shape, sep_point, self.hop_length, sep_point * self.hop_length, dir)
172
- return reference_mel, ssl, wav2, mel
173
-
174
-
175
- class TextAudioSpeakerCollate():
176
- """ Zero-pads model inputs and targets
177
- """
178
-
179
- def __init__(self, return_ids=False):
180
- self.return_ids = return_ids
181
-
182
- def __call__(self, batch):
183
- """Collates a training batch from normalized text, audio and speaker identities
184
- PARAMS
185
- ------
186
- batch: [text_normalized, spec_normalized, wav_normalized, sid]
187
- """
188
- # Right zero-pad all one-hot text sequences to max input length
189
- _, ids_sorted_decreasing = torch.sort(
190
- torch.LongTensor([x[1].size(1) for x in batch]),
191
- dim=0, descending=True)
192
-
193
- max_ssl_len = max([x[0].size(2) for x in batch])
194
- max_ssl_len = int(2 * ((max_ssl_len // 2) + 1))
195
- max_spec_len = max([x[1].size(1) for x in batch])
196
- max_spec_len = int(2 * ((max_spec_len // 2) + 1))
197
- max_wav_len = max([x[2].size(1) for x in batch])
198
- max_text_len = max([x[3].size(0) for x in batch])
199
-
200
- ssl_lengths = torch.LongTensor(len(batch))
201
- spec_lengths = torch.LongTensor(len(batch))
202
- wav_lengths = torch.LongTensor(len(batch))
203
- text_lengths = torch.LongTensor(len(batch))
204
-
205
- spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
206
- wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
207
- ssl_padded = torch.FloatTensor(len(batch), batch[0][0].size(1), max_ssl_len)
208
- text_padded = torch.LongTensor(len(batch), max_text_len)
209
-
210
- spec_padded.zero_()
211
- wav_padded.zero_()
212
- ssl_padded.zero_()
213
- text_padded.zero_()
214
-
215
- for i in range(len(ids_sorted_decreasing)):
216
- row = batch[ids_sorted_decreasing[i]]
217
-
218
- ssl = row[0]
219
- ssl_padded[i, :, :ssl.size(2)] = ssl[0, :, :]
220
- ssl_lengths[i] = ssl.size(2)
221
-
222
- spec = row[1]
223
- spec_padded[i, :, :spec.size(1)] = spec
224
- spec_lengths[i] = spec.size(1)
225
-
226
- wav = row[2]
227
- wav_padded[i, :, :wav.size(1)] = wav
228
- wav_lengths[i] = wav.size(1)
229
-
230
- text = row[3]
231
- text_padded[i, :text.size(0)] = text
232
- text_lengths[i] = text.size(0)
233
-
234
- return ssl_padded, ssl_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, text_padded, text_lengths
235
-
236
-
237
- class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
238
- """
239
- Maintain similar input lengths in a batch.
240
- Length groups are specified by boundaries.
241
- Ex) boundaries = [b1, b2, b3] -> each batch contains only samples x with either b1 < length(x) <= b2 or b2 < length(x) <= b3.
242
-
243
- It removes samples which are not included in the boundaries.
244
- Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded.
245
- """
246
-
247
- def __init__(self, dataset, batch_size, boundaries, num_replicas=None, rank=None, shuffle=True):
248
- super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
249
- self.lengths = dataset.lengths
250
- self.batch_size = batch_size
251
- self.boundaries = boundaries
252
-
253
- self.buckets, self.num_samples_per_bucket = self._create_buckets()
254
- self.total_size = sum(self.num_samples_per_bucket)
255
- self.num_samples = self.total_size // self.num_replicas
256
-
257
- def _create_buckets(self):
258
- buckets = [[] for _ in range(len(self.boundaries) - 1)]
259
- for i in range(len(self.lengths)):
260
- length = self.lengths[i]
261
- idx_bucket = self._bisect(length)
262
- if idx_bucket != -1:
263
- buckets[idx_bucket].append(i)
264
-
265
- i = len(buckets) - 1
266
- while i >= 0:
267
- if len(buckets[i]) == 0:
268
- buckets.pop(i)
269
- self.boundaries.pop(i + 1)
270
- i -= 1
271
-
272
- num_samples_per_bucket = []
273
- for i in range(len(buckets)):
274
- len_bucket = len(buckets[i])
275
- total_batch_size = self.num_replicas * self.batch_size
276
- rem = (total_batch_size - (len_bucket % total_batch_size)) % total_batch_size
277
- num_samples_per_bucket.append(len_bucket + rem)
278
- return buckets, num_samples_per_bucket
279
-
280
- def __iter__(self):
281
- g = torch.Generator()
282
- g.manual_seed(self.epoch)
283
-
284
- indices = []
285
- if self.shuffle:
286
- for bucket in self.buckets:
287
- indices.append(torch.randperm(len(bucket), generator=g).tolist())
288
- else:
289
- for bucket in self.buckets:
290
- indices.append(list(range(len(bucket))))
291
-
292
- batches = []
293
- for i in range(len(self.buckets)):
294
- bucket = self.buckets[i]
295
- len_bucket = len(bucket)
296
- ids_bucket = indices[i]
297
- num_samples_bucket = self.num_samples_per_bucket[i]
298
-
299
- rem = num_samples_bucket - len_bucket
300
- ids_bucket = ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[:(rem % len_bucket)]
301
-
302
- ids_bucket = ids_bucket[self.rank::self.num_replicas]
303
-
304
- for j in range(len(ids_bucket) // self.batch_size):
305
- batch = [bucket[idx] for idx in ids_bucket[j * self.batch_size:(j + 1) * self.batch_size]]
306
- batches.append(batch)
307
-
308
- if self.shuffle:
309
- batch_ids = torch.randperm(len(batches), generator=g).tolist()
310
- batches = [batches[i] for i in batch_ids]
311
- self.batches = batches
312
-
313
- assert len(self.batches) * self.batch_size == self.num_samples
314
- return iter(self.batches)
315
-
316
- def _bisect(self, x, lo=0, hi=None):
317
- if hi is None:
318
- hi = len(self.boundaries) - 1
319
-
320
- if hi > lo:
321
- mid = (hi + lo) // 2
322
- if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]:
323
- return mid
324
- elif x <= self.boundaries[mid]:
325
- return self._bisect(x, lo, mid)
326
- else:
327
- return self._bisect(x, mid + 1, hi)
328
- else:
329
- return -1
330
-
331
- def __len__(self):
332
- return self.num_samples // self.batch_size
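
The sampler's bucketing rule, stated compactly: a clip of length l lands in bucket i iff boundaries[i] < l <= boundaries[i + 1], and anything outside (boundaries[0], boundaries[-1]] is dropped, which is exactly what _bisect above computes. A small sketch of that rule with assumed example lengths (stdlib bisect stands in for the recursive _bisect):

import bisect

boundaries = [32, 300, 400, 500]           # illustrative boundaries, not the project's
for l in [10, 100, 350, 450, 600]:
    i = bisect.bisect_left(boundaries, l)  # index of first boundary >= l
    if 0 < i < len(boundaries):
        print(l, "-> bucket", i - 1)       # boundaries[i-1] < l <= boundaries[i]
    else:
        print(l, "-> discarded")           # the cases where _bisect returns -1
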
module/losses.py DELETED
@@ -1,73 +0,0 @@
1
- import math
2
-
3
- import torch
4
- from torch.nn import functional as F
5
-
6
-
7
- def feature_loss(fmap_r, fmap_g):
8
- loss = 0
9
- for dr, dg in zip(fmap_r, fmap_g):
10
- for rl, gl in zip(dr, dg):
11
- rl = rl.float().detach()
12
- gl = gl.float()
13
- loss += torch.mean(torch.abs(rl - gl))
14
-
15
- return loss * 2
16
-
17
-
18
- def discriminator_loss(disc_real_outputs, disc_generated_outputs):
19
- loss = 0
20
- r_losses = []
21
- g_losses = []
22
- for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
23
- dr = dr.float()
24
- dg = dg.float()
25
- r_loss = torch.mean((1 - dr) ** 2)
26
- g_loss = torch.mean(dg**2)
27
- loss += r_loss + g_loss
28
- r_losses.append(r_loss.item())
29
- g_losses.append(g_loss.item())
30
-
31
- return loss, r_losses, g_losses
32
-
33
-
34
- def generator_loss(disc_outputs):
35
- loss = 0
36
- gen_losses = []
37
- for dg in disc_outputs:
38
- dg = dg.float()
39
- l = torch.mean((1 - dg) ** 2)
40
- gen_losses.append(l)
41
- loss += l
42
-
43
- return loss, gen_losses
44
-
45
-
46
- def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
47
- """
48
- z_p, logs_q: [b, h, t_t]
49
- m_p, logs_p: [b, h, t_t]
50
- """
51
- z_p = z_p.float()
52
- logs_q = logs_q.float()
53
- m_p = m_p.float()
54
- logs_p = logs_p.float()
55
- z_mask = z_mask.float()
56
-
57
- kl = logs_p - logs_q - 0.5
58
- kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
59
- kl = torch.sum(kl * z_mask)
60
- l = kl / torch.sum(z_mask)
61
- return l
62
-
63
-
64
- def mle_loss(z, m, logs, logdet, mask):
65
- l = torch.sum(logs) + 0.5 * torch.sum(
66
- torch.exp(-2 * logs) * ((z - m) ** 2)
67
- ) # neg normal likelihood w/o the constant term
68
- l = l - torch.sum(logdet) # log jacobian determinant
69
- l = l / torch.sum(
70
- torch.ones_like(z) * mask
71
- ) # averaging across batch, channel and time axes
72
- l = l + 0.5 * math.log(2 * math.pi) # add the remaining constant term
73
- return l
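
kl_loss above is a masked, per-sample Monte-Carlo estimate of KL(q || p) between diagonal Gaussians, evaluated at the posterior sample z_p. Averaged over many samples z ~ q it matches the closed form KL(q || p) = log(s_p / s_q) - 1/2 + (s_q^2 + (m_q - m_p)^2) / (2 s_p^2). A quick numeric check of that identity in the scalar case, with assumed toy values:

import torch

m_q, logs_q = torch.tensor(0.3), torch.tensor(-0.2)
m_p, logs_p = torch.tensor(0.0), torch.tensor(0.1)

z = m_q + torch.randn(200_000) * torch.exp(logs_q)   # samples from q
est = (logs_p - logs_q - 0.5
       + 0.5 * (z - m_p) ** 2 * torch.exp(-2.0 * logs_p)).mean()
closed = (logs_p - logs_q - 0.5
          + (torch.exp(2 * logs_q) + (m_q - m_p) ** 2) / (2 * torch.exp(2 * logs_p)))
print(est.item(), closed.item())  # agree up to Monte-Carlo noise
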
module/mel_processing.py DELETED
@@ -1,153 +0,0 @@
1
- import math
2
- import os
3
- import random
4
- import torch
5
- from torch import nn
6
- import torch.nn.functional as F
7
- import torch.utils.data
8
- import numpy as np
9
- import librosa
10
- import librosa.util as librosa_util
11
- from librosa.util import normalize, pad_center, tiny
12
- from scipy.signal import get_window
13
- from scipy.io.wavfile import read
14
- from librosa.filters import mel as librosa_mel_fn
15
-
16
- MAX_WAV_VALUE = 32768.0
17
-
18
-
19
- def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
20
- """
21
- PARAMS
22
- ------
23
- C: compression factor
24
- """
25
- return torch.log(torch.clamp(x, min=clip_val) * C)
26
-
27
-
28
- def dynamic_range_decompression_torch(x, C=1):
29
- """
30
- PARAMS
31
- ------
32
- C: compression factor used to compress
33
- """
34
- return torch.exp(x) / C
35
-
36
-
37
- def spectral_normalize_torch(magnitudes):
38
- output = dynamic_range_compression_torch(magnitudes)
39
- return output
40
-
41
-
42
- def spectral_de_normalize_torch(magnitudes):
43
- output = dynamic_range_decompression_torch(magnitudes)
44
- return output
45
-
46
-
47
- mel_basis = {}
48
- hann_window = {}
49
-
50
-
51
- def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
52
- if torch.min(y) < -1.0:
53
- print("min value is ", torch.min(y))
54
- if torch.max(y) > 1.0:
55
- print("max value is ", torch.max(y))
56
-
57
- global hann_window
58
- dtype_device = str(y.dtype) + "_" + str(y.device)
59
- wnsize_dtype_device = str(win_size) + "_" + dtype_device
60
- if wnsize_dtype_device not in hann_window:
61
- hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
62
- dtype=y.dtype, device=y.device
63
- )
64
-
65
- y = torch.nn.functional.pad(
66
- y.unsqueeze(1),
67
- (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
68
- mode="reflect",
69
- )
70
- y = y.squeeze(1)
71
- spec = torch.stft(
72
- y,
73
- n_fft,
74
- hop_length=hop_size,
75
- win_length=win_size,
76
- window=hann_window[wnsize_dtype_device],
77
- center=center,
78
- pad_mode="reflect",
79
- normalized=False,
80
- onesided=True,
81
- return_complex=False,
82
- )
83
-
84
- spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
85
- return spec
86
-
87
-
88
- def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
89
- global mel_basis
90
- dtype_device = str(spec.dtype) + "_" + str(spec.device)
91
- fmax_dtype_device = str(fmax) + "_" + dtype_device
92
- if fmax_dtype_device not in mel_basis:
93
- mel = librosa_mel_fn(
94
- sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
95
- )
96
- mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
97
- dtype=spec.dtype, device=spec.device
98
- )
99
- spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
100
- spec = spectral_normalize_torch(spec)
101
- return spec
102
-
103
-
104
- def mel_spectrogram_torch(
105
- y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
106
- ):
107
- if torch.min(y) < -1.0:
108
- print("min value is ", torch.min(y))
109
- if torch.max(y) > 1.0:
110
- print("max value is ", torch.max(y))
111
-
112
- global mel_basis, hann_window
113
- dtype_device = str(y.dtype) + "_" + str(y.device)
114
- fmax_dtype_device = str(fmax) + "_" + dtype_device
115
- wnsize_dtype_device = str(win_size) + "_" + dtype_device
116
- if fmax_dtype_device not in mel_basis:
117
- mel = librosa_mel_fn(
118
- sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
119
- )
120
- mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
121
- dtype=y.dtype, device=y.device
122
- )
123
- if wnsize_dtype_device not in hann_window:
124
- hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
125
- dtype=y.dtype, device=y.device
126
- )
127
-
128
- y = torch.nn.functional.pad(
129
- y.unsqueeze(1),
130
- (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
131
- mode="reflect",
132
- )
133
- y = y.squeeze(1)
134
-
135
- spec = torch.stft(
136
- y,
137
- n_fft,
138
- hop_length=hop_size,
139
- win_length=win_size,
140
- window=hann_window[wnsize_dtype_device],
141
- center=center,
142
- pad_mode="reflect",
143
- normalized=False,
144
- onesided=True,
145
- return_complex=False,
146
- )
147
-
148
- spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
149
-
150
- spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
151
- spec = spectral_normalize_torch(spec)
152
-
153
- return spec
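
One property of spectrogram_torch above worth calling out: with center=False and the manual reflect padding of (n_fft - hop_size) / 2 samples on each side, the STFT yields exactly len(y) // hop_size frames, keeping spectrogram frames aligned one-to-one with hop-sized chunks of audio. A sketch with assumed, purely illustrative hyperparameters:

import torch

n_fft, hop, win = 2048, 640, 2048
y = torch.randn(1, hop * 100)                    # 100 hops of audio
pad = (n_fft - hop) // 2
y_p = torch.nn.functional.pad(y.unsqueeze(1), (pad, pad), mode="reflect").squeeze(1)
spec = torch.stft(y_p, n_fft, hop_length=hop, win_length=win,
                  window=torch.hann_window(win), center=False, return_complex=True)
print(spec.shape)  # torch.Size([1, 1025, 100]): one frame per hop
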
module/models.py DELETED
@@ -1,1040 +0,0 @@
1
- import warnings
2
- warnings.filterwarnings("ignore")
3
- import copy
4
- import math
5
- import os
6
- import pdb
7
-
8
- import torch
9
- from torch import nn
10
- from torch.nn import functional as F
11
-
12
- from module import commons
13
- from module import modules
14
- from module import attentions
15
-
16
- from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
17
- from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
18
- from module.commons import init_weights, get_padding
19
- from module.mrte_model import MRTE
20
- from module.quantize import ResidualVectorQuantizer
21
- # from text import symbols
22
- from text import symbols as symbols_v1
23
- from text import symbols2 as symbols_v2
24
- from torch.cuda.amp import autocast
25
- import contextlib
26
-
27
-
28
- class StochasticDurationPredictor(nn.Module):
29
- def __init__(
30
- self,
31
- in_channels,
32
- filter_channels,
33
- kernel_size,
34
- p_dropout,
35
- n_flows=4,
36
- gin_channels=0,
37
- ):
38
- super().__init__()
39
- filter_channels = in_channels # it needs to be removed in a future version.
40
- self.in_channels = in_channels
41
- self.filter_channels = filter_channels
42
- self.kernel_size = kernel_size
43
- self.p_dropout = p_dropout
44
- self.n_flows = n_flows
45
- self.gin_channels = gin_channels
46
-
47
- self.log_flow = modules.Log()
48
- self.flows = nn.ModuleList()
49
- self.flows.append(modules.ElementwiseAffine(2))
50
- for i in range(n_flows):
51
- self.flows.append(
52
- modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
53
- )
54
- self.flows.append(modules.Flip())
55
-
56
- self.post_pre = nn.Conv1d(1, filter_channels, 1)
57
- self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
58
- self.post_convs = modules.DDSConv(
59
- filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
60
- )
61
- self.post_flows = nn.ModuleList()
62
- self.post_flows.append(modules.ElementwiseAffine(2))
63
- for i in range(4):
64
- self.post_flows.append(
65
- modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
66
- )
67
- self.post_flows.append(modules.Flip())
68
-
69
- self.pre = nn.Conv1d(in_channels, filter_channels, 1)
70
- self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
71
- self.convs = modules.DDSConv(
72
- filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
73
- )
74
- if gin_channels != 0:
75
- self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
76
-
77
- def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
78
- x = torch.detach(x)
79
- x = self.pre(x)
80
- if g is not None:
81
- g = torch.detach(g)
82
- x = x + self.cond(g)
83
- x = self.convs(x, x_mask)
84
- x = self.proj(x) * x_mask
85
-
86
- if not reverse:
87
- flows = self.flows
88
- assert w is not None
89
-
90
- logdet_tot_q = 0
91
- h_w = self.post_pre(w)
92
- h_w = self.post_convs(h_w, x_mask)
93
- h_w = self.post_proj(h_w) * x_mask
94
- e_q = (
95
- torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
96
- * x_mask
97
- )
98
- z_q = e_q
99
- for flow in self.post_flows:
100
- z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
101
- logdet_tot_q += logdet_q
102
- z_u, z1 = torch.split(z_q, [1, 1], 1)
103
- u = torch.sigmoid(z_u) * x_mask
104
- z0 = (w - u) * x_mask
105
- logdet_tot_q += torch.sum(
106
- (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
107
- )
108
- logq = (
109
- torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
110
- - logdet_tot_q
111
- )
112
-
113
- logdet_tot = 0
114
- z0, logdet = self.log_flow(z0, x_mask)
115
- logdet_tot += logdet
116
- z = torch.cat([z0, z1], 1)
117
- for flow in flows:
118
- z, logdet = flow(z, x_mask, g=x, reverse=reverse)
119
- logdet_tot = logdet_tot + logdet
120
- nll = (
121
- torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
122
- - logdet_tot
123
- )
124
- return nll + logq # [b]
125
- else:
126
- flows = list(reversed(self.flows))
127
- flows = flows[:-2] + [flows[-1]] # remove a useless vflow
128
- z = (
129
- torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
130
- * noise_scale
131
- )
132
- for flow in flows:
133
- z = flow(z, x_mask, g=x, reverse=reverse)
134
- z0, z1 = torch.split(z, [1, 1], 1)
135
- logw = z0
136
- return logw
137
-
138
-
139
- class DurationPredictor(nn.Module):
140
- def __init__(
141
- self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
142
- ):
143
- super().__init__()
144
-
145
- self.in_channels = in_channels
146
- self.filter_channels = filter_channels
147
- self.kernel_size = kernel_size
148
- self.p_dropout = p_dropout
149
- self.gin_channels = gin_channels
150
-
151
- self.drop = nn.Dropout(p_dropout)
152
- self.conv_1 = nn.Conv1d(
153
- in_channels, filter_channels, kernel_size, padding=kernel_size // 2
154
- )
155
- self.norm_1 = modules.LayerNorm(filter_channels)
156
- self.conv_2 = nn.Conv1d(
157
- filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
158
- )
159
- self.norm_2 = modules.LayerNorm(filter_channels)
160
- self.proj = nn.Conv1d(filter_channels, 1, 1)
161
-
162
- if gin_channels != 0:
163
- self.cond = nn.Conv1d(gin_channels, in_channels, 1)
164
-
165
- def forward(self, x, x_mask, g=None):
166
- x = torch.detach(x)
167
- if g is not None:
168
- g = torch.detach(g)
169
- x = x + self.cond(g)
170
- x = self.conv_1(x * x_mask)
171
- x = torch.relu(x)
172
- x = self.norm_1(x)
173
- x = self.drop(x)
174
- x = self.conv_2(x * x_mask)
175
- x = torch.relu(x)
176
- x = self.norm_2(x)
177
- x = self.drop(x)
178
- x = self.proj(x * x_mask)
179
- return x * x_mask
180
-
181
-
182
- class TextEncoder(nn.Module):
183
- def __init__(
184
- self,
185
- out_channels,
186
- hidden_channels,
187
- filter_channels,
188
- n_heads,
189
- n_layers,
190
- kernel_size,
191
- p_dropout,
192
- latent_channels=192,
193
- version = "v2",
194
- ):
195
- super().__init__()
196
- self.out_channels = out_channels
197
- self.hidden_channels = hidden_channels
198
- self.filter_channels = filter_channels
199
- self.n_heads = n_heads
200
- self.n_layers = n_layers
201
- self.kernel_size = kernel_size
202
- self.p_dropout = p_dropout
203
- self.latent_channels = latent_channels
204
- self.version = version
205
-
206
- self.ssl_proj = nn.Conv1d(768, hidden_channels, 1)
207
-
208
- self.encoder_ssl = attentions.Encoder(
209
- hidden_channels,
210
- filter_channels,
211
- n_heads,
212
- n_layers // 2,
213
- kernel_size,
214
- p_dropout,
215
- )
216
-
217
- self.encoder_text = attentions.Encoder(
218
- hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
219
- )
220
-
221
- if self.version == "v1":
222
- symbols = symbols_v1.symbols
223
- else:
224
- symbols = symbols_v2.symbols
225
- self.text_embedding = nn.Embedding(len(symbols), hidden_channels)
226
-
227
- self.mrte = MRTE()
228
-
229
- self.encoder2 = attentions.Encoder(
230
- hidden_channels,
231
- filter_channels,
232
- n_heads,
233
- n_layers // 2,
234
- kernel_size,
235
- p_dropout,
236
- )
237
-
238
- self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
239
-
240
- def forward(self, y, y_lengths, text, text_lengths, ge, speed=1,test=None):
241
- y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
242
- y.dtype
243
- )
244
-
245
- y = self.ssl_proj(y * y_mask) * y_mask
246
-
247
- y = self.encoder_ssl(y * y_mask, y_mask)
248
-
249
- text_mask = torch.unsqueeze(
250
- commons.sequence_mask(text_lengths, text.size(1)), 1
251
- ).to(y.dtype)
252
- if test == 1:
253
- text[:, :] = 0
254
- text = self.text_embedding(text).transpose(1, 2)
255
- text = self.encoder_text(text * text_mask, text_mask)
256
- y = self.mrte(y, y_mask, text, text_mask, ge)
257
- y = self.encoder2(y * y_mask, y_mask)
258
- if(speed!=1):
259
- y = F.interpolate(y, size=int(y.shape[-1] / speed)+1, mode="linear")
260
- y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest")
261
- stats = self.proj(y) * y_mask
262
- m, logs = torch.split(stats, self.out_channels, dim=1)
263
- return y, m, logs, y_mask
264
-
265
- def extract_latent(self, x):
266
- x = self.ssl_proj(x)
267
- quantized, codes, commit_loss, quantized_list = self.quantizer(x)
268
- return codes.transpose(0, 1)
269
-
270
- def decode_latent(self, codes, y_mask, refer, refer_mask, ge):
271
- quantized = self.quantizer.decode(codes)
272
-
273
- y = self.vq_proj(quantized) * y_mask
274
- y = self.encoder_ssl(y * y_mask, y_mask)
275
-
276
- y = self.mrte(y, y_mask, refer, refer_mask, ge)
277
-
278
- y = self.encoder2(y * y_mask, y_mask)
279
-
280
- stats = self.proj(y) * y_mask
281
- m, logs = torch.split(stats, self.out_channels, dim=1)
282
- return y, m, logs, y_mask, quantized
283
-
284
-
285
- class ResidualCouplingBlock(nn.Module):
286
- def __init__(
287
- self,
288
- channels,
289
- hidden_channels,
290
- kernel_size,
291
- dilation_rate,
292
- n_layers,
293
- n_flows=4,
294
- gin_channels=0,
295
- ):
296
- super().__init__()
297
- self.channels = channels
298
- self.hidden_channels = hidden_channels
299
- self.kernel_size = kernel_size
300
- self.dilation_rate = dilation_rate
301
- self.n_layers = n_layers
302
- self.n_flows = n_flows
303
- self.gin_channels = gin_channels
304
-
305
- self.flows = nn.ModuleList()
306
- for i in range(n_flows):
307
- self.flows.append(
308
- modules.ResidualCouplingLayer(
309
- channels,
310
- hidden_channels,
311
- kernel_size,
312
- dilation_rate,
313
- n_layers,
314
- gin_channels=gin_channels,
315
- mean_only=True,
316
- )
317
- )
318
- self.flows.append(modules.Flip())
319
-
320
- def forward(self, x, x_mask, g=None, reverse=False):
321
- if not reverse:
322
- for flow in self.flows:
323
- x, _ = flow(x, x_mask, g=g, reverse=reverse)
324
- else:
325
- for flow in reversed(self.flows):
326
- x = flow(x, x_mask, g=g, reverse=reverse)
327
- return x
328
-
329
-
330
- class PosteriorEncoder(nn.Module):
331
- def __init__(
332
- self,
333
- in_channels,
334
- out_channels,
335
- hidden_channels,
336
- kernel_size,
337
- dilation_rate,
338
- n_layers,
339
- gin_channels=0,
340
- ):
341
- super().__init__()
342
- self.in_channels = in_channels
343
- self.out_channels = out_channels
344
- self.hidden_channels = hidden_channels
345
- self.kernel_size = kernel_size
346
- self.dilation_rate = dilation_rate
347
- self.n_layers = n_layers
348
- self.gin_channels = gin_channels
349
-
350
- self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
351
- self.enc = modules.WN(
352
- hidden_channels,
353
- kernel_size,
354
- dilation_rate,
355
- n_layers,
356
- gin_channels=gin_channels,
357
- )
358
- self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
359
-
360
- def forward(self, x, x_lengths, g=None):
361
- if g is not None:
362
- g = g.detach()
363
- x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
364
- x.dtype
365
- )
366
- x = self.pre(x) * x_mask
367
- x = self.enc(x, x_mask, g=g)
368
- stats = self.proj(x) * x_mask
369
- m, logs = torch.split(stats, self.out_channels, dim=1)
370
- z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
371
- return z, m, logs, x_mask
372
-
373
-
374
- class WNEncoder(nn.Module):
375
- def __init__(
376
- self,
377
- in_channels,
378
- out_channels,
379
- hidden_channels,
380
- kernel_size,
381
- dilation_rate,
382
- n_layers,
383
- gin_channels=0,
384
- ):
385
- super().__init__()
386
- self.in_channels = in_channels
387
- self.out_channels = out_channels
388
- self.hidden_channels = hidden_channels
389
- self.kernel_size = kernel_size
390
- self.dilation_rate = dilation_rate
391
- self.n_layers = n_layers
392
- self.gin_channels = gin_channels
393
-
394
- self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
395
- self.enc = modules.WN(
396
- hidden_channels,
397
- kernel_size,
398
- dilation_rate,
399
- n_layers,
400
- gin_channels=gin_channels,
401
- )
402
- self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
403
- self.norm = modules.LayerNorm(out_channels)
404
-
405
- def forward(self, x, x_lengths, g=None):
406
- x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
407
- x.dtype
408
- )
409
- x = self.pre(x) * x_mask
410
- x = self.enc(x, x_mask, g=g)
411
- out = self.proj(x) * x_mask
412
- out = self.norm(out)
413
- return out
414
-
415
-
416
- class Generator(torch.nn.Module):
417
- def __init__(
418
- self,
419
- initial_channel,
420
- resblock,
421
- resblock_kernel_sizes,
422
- resblock_dilation_sizes,
423
- upsample_rates,
424
- upsample_initial_channel,
425
- upsample_kernel_sizes,
426
- gin_channels=0,
427
- ):
428
- super(Generator, self).__init__()
429
- self.num_kernels = len(resblock_kernel_sizes)
430
- self.num_upsamples = len(upsample_rates)
431
- self.conv_pre = Conv1d(
432
- initial_channel, upsample_initial_channel, 7, 1, padding=3
433
- )
434
- resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
435
-
436
- self.ups = nn.ModuleList()
437
- for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
438
- self.ups.append(
439
- weight_norm(
440
- ConvTranspose1d(
441
- upsample_initial_channel // (2**i),
442
- upsample_initial_channel // (2 ** (i + 1)),
443
- k,
444
- u,
445
- padding=(k - u) // 2,
446
- )
447
- )
448
- )
449
-
450
- self.resblocks = nn.ModuleList()
451
- for i in range(len(self.ups)):
452
- ch = upsample_initial_channel // (2 ** (i + 1))
453
- for j, (k, d) in enumerate(
454
- zip(resblock_kernel_sizes, resblock_dilation_sizes)
455
- ):
456
- self.resblocks.append(resblock(ch, k, d))
457
-
458
- self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
459
- self.ups.apply(init_weights)
460
-
461
- if gin_channels != 0:
462
- self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
463
-
464
- def forward(self, x, g=None):
465
- x = self.conv_pre(x)
466
- if g is not None:
467
- x = x + self.cond(g)
468
-
469
- for i in range(self.num_upsamples):
470
- x = F.leaky_relu(x, modules.LRELU_SLOPE)
471
- x = self.ups[i](x)
472
- xs = None
473
- for j in range(self.num_kernels):
474
- if xs is None:
475
- xs = self.resblocks[i * self.num_kernels + j](x)
476
- else:
477
- xs += self.resblocks[i * self.num_kernels + j](x)
478
- x = xs / self.num_kernels
479
- x = F.leaky_relu(x)
480
- x = self.conv_post(x)
481
- x = torch.tanh(x)
482
-
483
- return x
484
-
485
- def remove_weight_norm(self):
486
- print("Removing weight norm...")
487
- for l in self.ups:
488
- remove_weight_norm(l)
489
- for l in self.resblocks:
490
- l.remove_weight_norm()
491
-
492
-
493
- class DiscriminatorP(torch.nn.Module):
494
- def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
495
- super(DiscriminatorP, self).__init__()
496
- self.period = period
497
- self.use_spectral_norm = use_spectral_norm
498
- norm_f = weight_norm if not use_spectral_norm else spectral_norm
499
- self.convs = nn.ModuleList(
500
- [
501
- norm_f(
502
- Conv2d(
503
- 1,
504
- 32,
505
- (kernel_size, 1),
506
- (stride, 1),
507
- padding=(get_padding(kernel_size, 1), 0),
508
- )
509
- ),
510
- norm_f(
511
- Conv2d(
512
- 32,
513
- 128,
514
- (kernel_size, 1),
515
- (stride, 1),
516
- padding=(get_padding(kernel_size, 1), 0),
517
- )
518
- ),
519
- norm_f(
520
- Conv2d(
521
- 128,
522
- 512,
523
- (kernel_size, 1),
524
- (stride, 1),
525
- padding=(get_padding(kernel_size, 1), 0),
526
- )
527
- ),
528
- norm_f(
529
- Conv2d(
530
- 512,
531
- 1024,
532
- (kernel_size, 1),
533
- (stride, 1),
534
- padding=(get_padding(kernel_size, 1), 0),
535
- )
536
- ),
537
- norm_f(
538
- Conv2d(
539
- 1024,
540
- 1024,
541
- (kernel_size, 1),
542
- 1,
543
- padding=(get_padding(kernel_size, 1), 0),
544
- )
545
- ),
546
- ]
547
- )
548
- self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
549
-
550
- def forward(self, x):
551
- fmap = []
552
-
553
- # 1d to 2d
554
- b, c, t = x.shape
555
- if t % self.period != 0: # pad first
556
- n_pad = self.period - (t % self.period)
557
- x = F.pad(x, (0, n_pad), "reflect")
558
- t = t + n_pad
559
- x = x.view(b, c, t // self.period, self.period)
560
-
561
- for l in self.convs:
562
- x = l(x)
563
- x = F.leaky_relu(x, modules.LRELU_SLOPE)
564
- fmap.append(x)
565
- x = self.conv_post(x)
566
- fmap.append(x)
567
- x = torch.flatten(x, 1, -1)
568
-
569
- return x, fmap
570
-
571
-
572
- class DiscriminatorS(torch.nn.Module):
573
- def __init__(self, use_spectral_norm=False):
574
- super(DiscriminatorS, self).__init__()
575
- norm_f = weight_norm if not use_spectral_norm else spectral_norm
576
- self.convs = nn.ModuleList(
577
- [
578
- norm_f(Conv1d(1, 16, 15, 1, padding=7)),
579
- norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
580
- norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
581
- norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
582
- norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
583
- norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
584
- ]
585
- )
586
- self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
587
-
588
- def forward(self, x):
589
- fmap = []
590
-
591
- for l in self.convs:
592
- x = l(x)
593
- x = F.leaky_relu(x, modules.LRELU_SLOPE)
594
- fmap.append(x)
595
- x = self.conv_post(x)
596
- fmap.append(x)
597
- x = torch.flatten(x, 1, -1)
598
-
599
- return x, fmap
600
-
601
-
602
- class MultiPeriodDiscriminator(torch.nn.Module):
603
- def __init__(self, use_spectral_norm=False):
604
- super(MultiPeriodDiscriminator, self).__init__()
605
- periods = [2, 3, 5, 7, 11]
606
-
607
- discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
608
- discs = discs + [
609
- DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
610
- ]
611
- self.discriminators = nn.ModuleList(discs)
612
-
613
- def forward(self, y, y_hat):
614
- y_d_rs = []
615
- y_d_gs = []
616
- fmap_rs = []
617
- fmap_gs = []
618
- for i, d in enumerate(self.discriminators):
619
- y_d_r, fmap_r = d(y)
620
- y_d_g, fmap_g = d(y_hat)
621
- y_d_rs.append(y_d_r)
622
- y_d_gs.append(y_d_g)
623
- fmap_rs.append(fmap_r)
624
- fmap_gs.append(fmap_g)
625
-
626
- return y_d_rs, y_d_gs, fmap_rs, fmap_gs
627
-
628
-
629
- class ReferenceEncoder(nn.Module):
630
- """
631
- inputs --- [N, Ty/r, n_mels*r] mels
632
- outputs --- [N, ref_enc_gru_size]
633
- """
634
-
635
- def __init__(self, spec_channels, gin_channels=0):
636
- super().__init__()
637
- self.spec_channels = spec_channels
638
- ref_enc_filters = [32, 32, 64, 64, 128, 128]
639
- K = len(ref_enc_filters)
640
- filters = [1] + ref_enc_filters
641
- convs = [
642
- weight_norm(
643
- nn.Conv2d(
644
- in_channels=filters[i],
645
- out_channels=filters[i + 1],
646
- kernel_size=(3, 3),
647
- stride=(2, 2),
648
- padding=(1, 1),
649
- )
650
- )
651
- for i in range(K)
652
- ]
653
- self.convs = nn.ModuleList(convs)
654
- # self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)])
655
-
656
- out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
657
- self.gru = nn.GRU(
658
- input_size=ref_enc_filters[-1] * out_channels,
659
- hidden_size=256 // 2,
660
- batch_first=True,
661
- )
662
- self.proj = nn.Linear(128, gin_channels)
663
-
664
- def forward(self, inputs):
665
- N = inputs.size(0)
666
- out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs]
667
- for conv in self.convs:
668
- out = conv(out)
669
- # out = wn(out)
670
- out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K]
671
-
672
- out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K]
673
- T = out.size(1)
674
- N = out.size(0)
675
- out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K]
676
-
677
- self.gru.flatten_parameters()
678
- memory, out = self.gru(out) # out --- [1, N, 128]
679
-
680
- return self.proj(out.squeeze(0)).unsqueeze(-1)
681
-
682
- def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
683
- for i in range(n_convs):
684
- L = (L - kernel_size + 2 * pad) // stride + 1
685
- return L
686
-
687
-
688
- class Quantizer_module(torch.nn.Module):
689
- def __init__(self, n_e, e_dim):
690
- super(Quantizer_module, self).__init__()
691
- self.embedding = nn.Embedding(n_e, e_dim)
692
- self.embedding.weight.data.uniform_(-1.0 / n_e, 1.0 / n_e)
693
-
694
- def forward(self, x):
695
- d = (
696
- torch.sum(x**2, 1, keepdim=True)
697
- + torch.sum(self.embedding.weight**2, 1)
698
- - 2 * torch.matmul(x, self.embedding.weight.T)
699
- )
700
- min_indices = torch.argmin(d, 1)
701
- z_q = self.embedding(min_indices)
702
- return z_q, min_indices
703
-
704
-
705
- class Quantizer(torch.nn.Module):
706
- def __init__(self, embed_dim=512, n_code_groups=4, n_codes=160):
707
- super(Quantizer, self).__init__()
708
- assert embed_dim % n_code_groups == 0
709
- self.quantizer_modules = nn.ModuleList(
710
- [
711
- Quantizer_module(n_codes, embed_dim // n_code_groups)
712
- for _ in range(n_code_groups)
713
- ]
714
- )
715
- self.n_code_groups = n_code_groups
716
- self.embed_dim = embed_dim
717
-
718
- def forward(self, xin):
719
- # B, C, T
720
- B, C, T = xin.shape
721
- xin = xin.transpose(1, 2)
722
- x = xin.reshape(-1, self.embed_dim)
723
- x = torch.split(x, self.embed_dim // self.n_code_groups, dim=-1)
724
- min_indices = []
725
- z_q = []
726
- for _x, m in zip(x, self.quantizer_modules):
727
- _z_q, _min_indices = m(_x)
728
- z_q.append(_z_q)
729
- min_indices.append(_min_indices) # B * T,
730
- z_q = torch.cat(z_q, -1).reshape(xin.shape)
731
- loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean(
732
- (z_q - xin.detach()) ** 2
733
- )
734
- z_q = xin + (z_q - xin).detach()
735
- z_q = z_q.transpose(1, 2)
736
- codes = torch.stack(min_indices, -1).reshape(B, T, self.n_code_groups)
737
- return z_q, loss, codes.transpose(1, 2)
738
-
739
- def embed(self, x):
740
- # idx: N, 4, T
741
- x = x.transpose(1, 2)
742
- x = torch.split(x, 1, 2)
743
- ret = []
744
- for q, embed in zip(x, self.quantizer_modules):
745
- q = embed.embedding(q.squeeze(-1))
746
- ret.append(q)
747
- ret = torch.cat(ret, -1)
748
- return ret.transpose(1, 2) # N, C, T
749
-
750
-
751
- class CodePredictor(nn.Module):
752
- def __init__(
753
- self,
754
- hidden_channels,
755
- filter_channels,
756
- n_heads,
757
- n_layers,
758
- kernel_size,
759
- p_dropout,
760
- n_q=8,
761
- dims=1024,
762
- ssl_dim=768,
763
- ):
764
- super().__init__()
765
- self.hidden_channels = hidden_channels
766
- self.filter_channels = filter_channels
767
- self.n_heads = n_heads
768
- self.n_layers = n_layers
769
- self.kernel_size = kernel_size
770
- self.p_dropout = p_dropout
771
-
772
- self.vq_proj = nn.Conv1d(ssl_dim, hidden_channels, 1)
773
- self.ref_enc = modules.MelStyleEncoder(
774
- ssl_dim, style_vector_dim=hidden_channels
775
- )
776
-
777
- self.encoder = attentions.Encoder(
778
- hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
779
- )
780
-
781
- self.out_proj = nn.Conv1d(hidden_channels, (n_q - 1) * dims, 1)
782
- self.n_q = n_q
783
- self.dims = dims
784
-
785
- def forward(self, x, x_mask, refer, codes, infer=False):
786
- x = x.detach()
787
- x = self.vq_proj(x * x_mask) * x_mask
788
- g = self.ref_enc(refer, x_mask)
789
- x = x + g
790
- x = self.encoder(x * x_mask, x_mask)
791
- x = self.out_proj(x * x_mask) * x_mask
792
- logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose(
793
- 2, 3
794
- )
795
- target = codes[1:].transpose(0, 1)
796
- if not infer:
797
- logits = logits.reshape(-1, self.dims)
798
- target = target.reshape(-1)
799
- loss = torch.nn.functional.cross_entropy(logits, target)
800
- return loss
801
- else:
802
- _, top10_preds = torch.topk(logits, 10, dim=-1)
803
- correct_top10 = torch.any(top10_preds == target.unsqueeze(-1), dim=-1)
804
- top10_acc = 100 * torch.mean(correct_top10.float()).detach().cpu().item()
805
-
806
- print("Top-10 Accuracy:", top10_acc, "%")
807
-
808
- pred_codes = torch.argmax(logits, dim=-1)
809
- acc = 100 * torch.mean((pred_codes == target).float()).detach().cpu().item()
810
- print("Top-1 Accuracy:", acc, "%")
811
-
812
- return pred_codes.transpose(0, 1)
813
-
814
-
815
- class SynthesizerTrn(nn.Module):
816
- """
817
- Synthesizer for Training
818
- """
819
-
820
- def __init__(
821
- self,
822
- spec_channels,
823
- segment_size,
824
- inter_channels,
825
- hidden_channels,
826
- filter_channels,
827
- n_heads,
828
- n_layers,
829
- kernel_size,
830
- p_dropout,
831
- resblock,
832
- resblock_kernel_sizes,
833
- resblock_dilation_sizes,
834
- upsample_rates,
835
- upsample_initial_channel,
836
- upsample_kernel_sizes,
837
- n_speakers=0,
838
- gin_channels=0,
839
- use_sdp=True,
840
- semantic_frame_rate=None,
841
- freeze_quantizer=None,
842
- version = "v2",
843
- **kwargs
844
- ):
845
- super().__init__()
846
- self.spec_channels = spec_channels
847
- self.inter_channels = inter_channels
848
- self.hidden_channels = hidden_channels
849
- self.filter_channels = filter_channels
850
- self.n_heads = n_heads
851
- self.n_layers = n_layers
852
- self.kernel_size = kernel_size
853
- self.p_dropout = p_dropout
854
- self.resblock = resblock
855
- self.resblock_kernel_sizes = resblock_kernel_sizes
856
- self.resblock_dilation_sizes = resblock_dilation_sizes
857
- self.upsample_rates = upsample_rates
858
- self.upsample_initial_channel = upsample_initial_channel
859
- self.upsample_kernel_sizes = upsample_kernel_sizes
860
- self.segment_size = segment_size
861
- self.n_speakers = n_speakers
862
- self.gin_channels = gin_channels
863
- self.version = version
864
-
865
- self.use_sdp = use_sdp
866
- self.enc_p = TextEncoder(
867
- inter_channels,
868
- hidden_channels,
869
- filter_channels,
870
- n_heads,
871
- n_layers,
872
- kernel_size,
873
- p_dropout,
874
- version = version,
875
- )
876
- self.dec = Generator(
877
- inter_channels,
878
- resblock,
879
- resblock_kernel_sizes,
880
- resblock_dilation_sizes,
881
- upsample_rates,
882
- upsample_initial_channel,
883
- upsample_kernel_sizes,
884
- gin_channels=gin_channels,
885
- )
886
- self.enc_q = PosteriorEncoder(
887
- spec_channels,
888
- inter_channels,
889
- hidden_channels,
890
- 5,
891
- 1,
892
- 16,
893
- gin_channels=gin_channels,
894
- )
895
- self.flow = ResidualCouplingBlock(
896
- inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
897
- )
898
-
899
- # self.version=os.environ.get("version","v1")
900
- if(self.version=="v1"):
901
- self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)
902
- else:
903
- self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels)
904
-
905
- ssl_dim = 768
906
- assert semantic_frame_rate in ["25hz", "50hz"]
907
- self.semantic_frame_rate = semantic_frame_rate
908
- if semantic_frame_rate == "25hz":
909
- self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2)
910
- else:
911
- self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1)
912
-
913
- self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
914
- self.freeze_quantizer = freeze_quantizer
915
- self.sv_emb = nn.Linear(20480, gin_channels)
916
- self.ge_to512 = nn.Linear(gin_channels, 512)
917
- self.prelu = nn.PReLU(num_parameters=gin_channels)
918
-
919
- def forward(self, ssl, y, y_lengths, text, text_lengths, sv_emb=None):  # sv_emb is used below when version != "v1"
920
- y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
921
- y.dtype
922
- )
923
- if(self.version=="v1"):
924
- ge = self.ref_enc(y * y_mask, y_mask)
925
- else:
926
- ge = self.ref_enc(y[:,:704] * y_mask, y_mask)
927
- sv_emb = self.sv_emb(sv_emb) # B*20480->B*512
928
- ge += sv_emb.unsqueeze(-1)
929
- ge = self.prelu(ge)
930
- ge512 = self.ge_to512(ge.transpose(2, 1)).transpose(2, 1)
931
- with autocast(enabled=False):
932
- maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext()
933
- with maybe_no_grad:
934
- if self.freeze_quantizer:
935
- self.ssl_proj.eval()
936
- self.quantizer.eval()
937
- ssl = self.ssl_proj(ssl)
938
- quantized, codes, commit_loss, quantized_list = self.quantizer(
939
- ssl, layers=[0]
940
- )
941
-
942
- if self.semantic_frame_rate == "25hz":
943
- quantized = F.interpolate(
944
- quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
945
- )
946
-
947
- x, m_p, logs_p, y_mask = self.enc_p(
948
- quantized, y_lengths, text, text_lengths, ge512
949
- )
950
- z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
951
- z_p = self.flow(z, y_mask, g=ge)
952
-
953
- z_slice, ids_slice = commons.rand_slice_segments(
954
- z, y_lengths, self.segment_size
955
- )
956
- o = self.dec(z_slice, g=ge)
957
- return (
958
- o,
959
- commit_loss,
960
- ids_slice,
961
- y_mask,
962
- y_mask,
963
- (z, z_p, m_p, logs_p, m_q, logs_q),
964
- quantized,
965
- )
966
-
967
- def infer(self, ssl, y, y_lengths, text, text_lengths, test=None, noise_scale=0.5):
968
- y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
969
- y.dtype
970
- )
971
- if(self.version=="v1"):
972
- ge = self.ref_enc(y * y_mask, y_mask)
973
- else:
974
- ge = self.ref_enc(y[:,:704] * y_mask, y_mask)
975
-
976
- ssl = self.ssl_proj(ssl)
977
- quantized, codes, commit_loss, _ = self.quantizer(ssl, layers=[0])
978
- if self.semantic_frame_rate == "25hz":
979
- quantized = F.interpolate(
980
- quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
981
- )
982
-
983
- x, m_p, logs_p, y_mask = self.enc_p(
984
- quantized, y_lengths, text, text_lengths, ge, test=test
985
- )
986
- z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
987
-
988
- z = self.flow(z_p, y_mask, g=ge, reverse=True)
989
-
990
- o = self.dec((z * y_mask)[:, :, :], g=ge)
991
- return o, y_mask, (z, z_p, m_p, logs_p)
992
-
993
- @torch.no_grad()
994
- def decode(self, codes, text, refer, noise_scale=0.5,speed=1, sv_emb=None):
995
- def get_ge(refer, sv_emb):
996
- ge = None
997
- if refer is not None:
998
- refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
999
- refer_mask = torch.unsqueeze(
1000
- commons.sequence_mask(refer_lengths, refer.size(2)), 1
1001
- ).to(refer.dtype)
1002
- if (self.version == "v1"):
1003
- ge = self.ref_enc(refer * refer_mask, refer_mask)
1004
- else:
1005
- ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
1006
- sv_emb = self.sv_emb(sv_emb) # B*20480->B*512
1007
- ge += sv_emb.unsqueeze(-1)
1008
- ge = self.prelu(ge)
1009
- return ge
1010
- if isinstance(refer, list):
1011
- ges=[]
1012
- for idx,_refer in enumerate(refer):
1013
- ge=get_ge(_refer,sv_emb[idx])
1014
- ges.append(ge)
1015
- ge=torch.stack(ges,0).mean(0)
1016
- else:
1017
- ge = get_ge(refer, sv_emb)
1018
-
1019
- y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
1020
- text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
1021
-
1022
- quantized = self.quantizer.decode(codes)
1023
- if self.semantic_frame_rate == "25hz":
1024
- quantized = F.interpolate(
1025
- quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
1026
- )
1027
- x, m_p, logs_p, y_mask = self.enc_p(
1028
- quantized, y_lengths, text, text_lengths, self.ge_to512(ge.transpose(2,1)).transpose(2,1),speed
1029
- )
1030
- z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
1031
-
1032
- z = self.flow(z_p, y_mask, g=ge, reverse=True)
1033
-
1034
- o = self.dec((z * y_mask)[:, :, :], g=ge)
1035
- return o
1036
-
1037
- def extract_latent(self, x):
1038
- ssl = self.ssl_proj(x)
1039
- quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
1040
- return codes.transpose(0, 1)
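
A small sketch of the "25hz" branch in SynthesizerTrn above: when the quantizer runs at 25 Hz, half the SSL feature rate, the quantized sequence is nearest-neighbor upsampled by 2x before entering the text encoder, which amounts to repeating every frame once along time:

import torch
import torch.nn.functional as F

quantized = torch.randn(1, 768, 123)          # (batch, ssl_dim, T) at 25 Hz
up = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
print(up.shape)                               # torch.Size([1, 768, 246])
assert torch.equal(up[..., 0::2], quantized)  # nearest here = frame repetition
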
module/models_onnx.py DELETED
@@ -1,918 +0,0 @@
1
- import copy
2
- import math
3
- import torch
4
- from torch import nn
5
- from torch.nn import functional as F
6
-
7
- from module import commons
8
- from module import modules
9
- from module import attentions_onnx as attentions
10
-
11
- from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
12
- from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
13
- from module.commons import init_weights, get_padding
14
- from module.mrte_model import MRTE
15
- from module.quantize import ResidualVectorQuantizer
16
- from text import symbols
17
- from torch.cuda.amp import autocast
18
-
19
-
20
- class StochasticDurationPredictor(nn.Module):
21
- def __init__(
22
- self,
23
- in_channels,
24
- filter_channels,
25
- kernel_size,
26
- p_dropout,
27
- n_flows=4,
28
- gin_channels=0,
29
- ):
30
- super().__init__()
31
- filter_channels = in_channels # it needs to be removed in a future version.
32
- self.in_channels = in_channels
33
- self.filter_channels = filter_channels
34
- self.kernel_size = kernel_size
35
- self.p_dropout = p_dropout
36
- self.n_flows = n_flows
37
- self.gin_channels = gin_channels
38
-
39
- self.log_flow = modules.Log()
40
- self.flows = nn.ModuleList()
41
- self.flows.append(modules.ElementwiseAffine(2))
42
- for i in range(n_flows):
43
- self.flows.append(
44
- modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
45
- )
46
- self.flows.append(modules.Flip())
47
-
48
- self.post_pre = nn.Conv1d(1, filter_channels, 1)
49
- self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
50
- self.post_convs = modules.DDSConv(
51
- filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
52
- )
53
- self.post_flows = nn.ModuleList()
54
- self.post_flows.append(modules.ElementwiseAffine(2))
55
- for i in range(4):
56
- self.post_flows.append(
57
- modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
58
- )
59
- self.post_flows.append(modules.Flip())
60
-
61
- self.pre = nn.Conv1d(in_channels, filter_channels, 1)
62
- self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
63
- self.convs = modules.DDSConv(
64
- filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
65
- )
66
- if gin_channels != 0:
67
- self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
68
-
69
- def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
70
- x = torch.detach(x)
71
- x = self.pre(x)
72
- if g is not None:
73
- g = torch.detach(g)
74
- x = x + self.cond(g)
75
- x = self.convs(x, x_mask)
76
- x = self.proj(x) * x_mask
77
-
78
- if not reverse:
79
- flows = self.flows
80
- assert w is not None
81
-
82
- logdet_tot_q = 0
83
- h_w = self.post_pre(w)
84
- h_w = self.post_convs(h_w, x_mask)
85
- h_w = self.post_proj(h_w) * x_mask
86
- e_q = (
87
- torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
88
- * x_mask
89
- )
90
- z_q = e_q
91
- for flow in self.post_flows:
92
- z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
93
- logdet_tot_q += logdet_q
94
- z_u, z1 = torch.split(z_q, [1, 1], 1)
95
- u = torch.sigmoid(z_u) * x_mask
96
- z0 = (w - u) * x_mask
97
- logdet_tot_q += torch.sum(
98
- (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
99
- )
100
- logq = (
101
- torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
102
- - logdet_tot_q
103
- )
104
-
105
- logdet_tot = 0
106
- z0, logdet = self.log_flow(z0, x_mask)
107
- logdet_tot += logdet
108
- z = torch.cat([z0, z1], 1)
109
- for flow in flows:
110
- z, logdet = flow(z, x_mask, g=x, reverse=reverse)
111
- logdet_tot = logdet_tot + logdet
112
- nll = (
113
- torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
114
- - logdet_tot
115
- )
116
- return nll + logq # [b]
117
- else:
118
- flows = list(reversed(self.flows))
119
- flows = flows[:-2] + [flows[-1]] # remove a useless vflow
120
- z = (
121
- torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
122
- * noise_scale
123
- )
124
- for flow in flows:
125
- z = flow(z, x_mask, g=x, reverse=reverse)
126
- z0, z1 = torch.split(z, [1, 1], 1)
127
- logw = z0
128
- return logw
129
-
130
-
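At inference the stochastic predictor runs its flows in reverse (the `else` branch above) and returns `logw`; downstream code then turns those log-durations into integer frame counts. A minimal, hedged sketch of that post-processing with dummy tensors (the real `logw` would come from `StochasticDurationPredictor(..., reverse=True)`, following the usual VITS convention):

```python
import torch

# Toy stand-in for the predictor's reverse-mode output: per-token log-durations.
logw = 0.5 * torch.randn(1, 1, 10)
x_mask = torch.ones(1, 1, 10)

w = torch.exp(logw) * x_mask  # strictly positive durations
w_ceil = torch.ceil(w)        # integer frame counts used for length regulation
print(w_ceil.squeeze())
```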
131
- class DurationPredictor(nn.Module):
132
- def __init__(
133
- self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
134
- ):
135
- super().__init__()
136
-
137
- self.in_channels = in_channels
138
- self.filter_channels = filter_channels
139
- self.kernel_size = kernel_size
140
- self.p_dropout = p_dropout
141
- self.gin_channels = gin_channels
142
-
143
- self.drop = nn.Dropout(p_dropout)
144
- self.conv_1 = nn.Conv1d(
145
- in_channels, filter_channels, kernel_size, padding=kernel_size // 2
146
- )
147
- self.norm_1 = modules.LayerNorm(filter_channels)
148
- self.conv_2 = nn.Conv1d(
149
- filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
150
- )
151
- self.norm_2 = modules.LayerNorm(filter_channels)
152
- self.proj = nn.Conv1d(filter_channels, 1, 1)
153
-
154
- if gin_channels != 0:
155
- self.cond = nn.Conv1d(gin_channels, in_channels, 1)
156
-
157
- def forward(self, x, x_mask, g=None):
158
- x = torch.detach(x)
159
- if g is not None:
160
- g = torch.detach(g)
161
- x = x + self.cond(g)
162
- x = self.conv_1(x * x_mask)
163
- x = torch.relu(x)
164
- x = self.norm_1(x)
165
- x = self.drop(x)
166
- x = self.conv_2(x * x_mask)
167
- x = torch.relu(x)
168
- x = self.norm_2(x)
169
- x = self.drop(x)
170
- x = self.proj(x * x_mask)
171
- return x * x_mask
172
-
173
-
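Both duration predictors call `torch.detach` on their inputs, so the duration loss updates only the predictor and never leaks gradients into the text encoder. A tiny self-contained demonstration of that stop-gradient behaviour:

```python
import torch

x = torch.randn(2, 3, requires_grad=True)   # stands in for encoder features
w = torch.randn(3, requires_grad=True)      # stands in for predictor weights
loss = (x.detach() * w).sum()               # detach cuts the path back to x
loss.backward()
print(x.grad)              # None: the duration loss cannot reach the encoder
print(w.grad is not None)  # True: the predictor itself still trains
```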
174
- class TextEncoder(nn.Module):
175
- def __init__(
176
- self,
177
- out_channels,
178
- hidden_channels,
179
- filter_channels,
180
- n_heads,
181
- n_layers,
182
- kernel_size,
183
- p_dropout,
184
- latent_channels=192,
185
- ):
186
- super().__init__()
187
- self.out_channels = out_channels
188
- self.hidden_channels = hidden_channels
189
- self.filter_channels = filter_channels
190
- self.n_heads = n_heads
191
- self.n_layers = n_layers
192
- self.kernel_size = kernel_size
193
- self.p_dropout = p_dropout
194
- self.latent_channels = latent_channels
195
-
196
- self.ssl_proj = nn.Conv1d(768, hidden_channels, 1)
197
-
198
- self.encoder_ssl = attentions.Encoder(
199
- hidden_channels,
200
- filter_channels,
201
- n_heads,
202
- n_layers // 2,
203
- kernel_size,
204
- p_dropout,
205
- )
206
-
207
- self.encoder_text = attentions.Encoder(
208
- hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
209
- )
210
- self.text_embedding = nn.Embedding(len(symbols), hidden_channels)
211
-
212
- self.mrte = MRTE()
213
-
214
- self.encoder2 = attentions.Encoder(
215
- hidden_channels,
216
- filter_channels,
217
- n_heads,
218
- n_layers // 2,
219
- kernel_size,
220
- p_dropout,
221
- )
222
-
223
- self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
224
-
225
- def forward(self, y, text, ge):
226
- y_mask = torch.ones_like(y[:1, :1, :])
227
-
228
- y = self.ssl_proj(y * y_mask) * y_mask
229
- y = self.encoder_ssl(y * y_mask, y_mask)
230
-
231
- text_mask = torch.ones_like(text).to(y.dtype).unsqueeze(0)
232
-
233
- text = self.text_embedding(text).transpose(1, 2)
234
- text = self.encoder_text(text * text_mask, text_mask)
235
- y = self.mrte(y, y_mask, text, text_mask, ge)
236
-
237
- y = self.encoder2(y * y_mask, y_mask)
238
-
239
- stats = self.proj(y) * y_mask
240
- m, logs = torch.split(stats, self.out_channels, dim=1)
241
- return y, m, logs, y_mask
242
-
243
- def extract_latent(self, x):
- # NOTE: relies on `self.quantizer`, which is not defined on TextEncoder (it lives on SynthesizerTrn).
244
- x = self.ssl_proj(x)
245
- quantized, codes, commit_loss, quantized_list = self.quantizer(x)
246
- return codes.transpose(0, 1)
247
-
248
- def decode_latent(self, codes, y_mask, refer, refer_mask, ge):
- # NOTE: relies on `self.quantizer` and `self.vq_proj`, neither of which is defined on TextEncoder.
249
- quantized = self.quantizer.decode(codes)
250
-
251
- y = self.vq_proj(quantized) * y_mask
252
- y = self.encoder_ssl(y * y_mask, y_mask)
253
-
254
- y = self.mrte(y, y_mask, refer, refer_mask, ge)
255
-
256
- y = self.encoder2(y * y_mask, y_mask)
257
-
258
- stats = self.proj(y) * y_mask
259
- m, logs = torch.split(stats, self.out_channels, dim=1)
260
- return y, m, logs, y_mask, quantized
261
-
262
-
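`TextEncoder.forward` ends by projecting to `2 * out_channels` and splitting into a mean `m` and log-std `logs`; downstream code samples from that Gaussian with the reparameterization trick. A toy check of the split-and-sample pattern (shapes chosen to match the defaults; nothing repository-specific):

```python
import torch

out_channels = 192
stats = torch.randn(1, 2 * out_channels, 50)       # output of the 1x1 projection
m, logs = torch.split(stats, out_channels, dim=1)  # mean / log-std halves
z = m + torch.randn_like(m) * torch.exp(logs)      # reparameterized sample
print(z.shape)                                     # torch.Size([1, 192, 50])
```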
263
- class ResidualCouplingBlock(nn.Module):
264
- def __init__(
265
- self,
266
- channels,
267
- hidden_channels,
268
- kernel_size,
269
- dilation_rate,
270
- n_layers,
271
- n_flows=4,
272
- gin_channels=0,
273
- ):
274
- super().__init__()
275
- self.channels = channels
276
- self.hidden_channels = hidden_channels
277
- self.kernel_size = kernel_size
278
- self.dilation_rate = dilation_rate
279
- self.n_layers = n_layers
280
- self.n_flows = n_flows
281
- self.gin_channels = gin_channels
282
-
283
- self.flows = nn.ModuleList()
284
- for i in range(n_flows):
285
- self.flows.append(
286
- modules.ResidualCouplingLayer(
287
- channels,
288
- hidden_channels,
289
- kernel_size,
290
- dilation_rate,
291
- n_layers,
292
- gin_channels=gin_channels,
293
- mean_only=True,
294
- )
295
- )
296
- self.flows.append(modules.Flip())
297
-
298
- def forward(self, x, x_mask, g=None, reverse=False):
299
- if not reverse:
300
- for flow in self.flows:
301
- x, _ = flow(x, x_mask, g=g, reverse=reverse)
302
- else:
303
- for flow in reversed(self.flows):
304
- x = flow(x, x_mask, g=g, reverse=reverse)
305
- return x
306
-
307
-
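The coupling layers stacked here are invertible by construction: the reverse pass subtracts exactly what the forward pass added. A stripped-down, runnable illustration of the mean-only affine coupling idea (a hand-rolled toy, not the actual `modules.ResidualCouplingLayer`):

```python
import torch

x0, x1 = torch.randn(1, 96, 20), torch.randn(1, 96, 20)
m = torch.tanh(x0.mean(dim=1, keepdim=True))  # any function of x0 works here

y1 = x1 + m        # forward pass (logs == 0 in the mean-only case)
x1_rec = y1 - m    # reverse pass undoes it exactly
print(torch.allclose(x1, x1_rec))  # True
```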
308
- class PosteriorEncoder(nn.Module):
309
- def __init__(
310
- self,
311
- in_channels,
312
- out_channels,
313
- hidden_channels,
314
- kernel_size,
315
- dilation_rate,
316
- n_layers,
317
- gin_channels=0,
318
- ):
319
- super().__init__()
320
- self.in_channels = in_channels
321
- self.out_channels = out_channels
322
- self.hidden_channels = hidden_channels
323
- self.kernel_size = kernel_size
324
- self.dilation_rate = dilation_rate
325
- self.n_layers = n_layers
326
- self.gin_channels = gin_channels
327
-
328
- self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
329
- self.enc = modules.WN(
330
- hidden_channels,
331
- kernel_size,
332
- dilation_rate,
333
- n_layers,
334
- gin_channels=gin_channels,
335
- )
336
- self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
337
-
338
- def forward(self, x, x_lengths, g=None):
339
- if g is not None:
340
- g = g.detach()
341
- x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
342
- x.dtype
343
- )
344
- x = self.pre(x) * x_mask
345
- x = self.enc(x, x_mask, g=g)
346
- stats = self.proj(x) * x_mask
347
- m, logs = torch.split(stats, self.out_channels, dim=1)
348
- z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
349
- return z, m, logs, x_mask
350
-
351
-
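`commons.sequence_mask` (used above) is not shown in this diff; assuming the standard VITS definition, it can be reproduced in a few lines. A hedged re-implementation for reference:

```python
import torch

def sequence_mask(lengths, max_len):
    # Assumed equivalent of commons.sequence_mask: True where t < length.
    x = torch.arange(max_len, device=lengths.device)
    return x.unsqueeze(0) < lengths.unsqueeze(1)

x_lengths = torch.tensor([3, 5])
x_mask = sequence_mask(x_lengths, 5).unsqueeze(1).float()  # [B, 1, T]
print(x_mask)
```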
352
- class WNEncoder(nn.Module):
353
- def __init__(
354
- self,
355
- in_channels,
356
- out_channels,
357
- hidden_channels,
358
- kernel_size,
359
- dilation_rate,
360
- n_layers,
361
- gin_channels=0,
362
- ):
363
- super().__init__()
364
- self.in_channels = in_channels
365
- self.out_channels = out_channels
366
- self.hidden_channels = hidden_channels
367
- self.kernel_size = kernel_size
368
- self.dilation_rate = dilation_rate
369
- self.n_layers = n_layers
370
- self.gin_channels = gin_channels
371
-
372
- self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
373
- self.enc = modules.WN(
374
- hidden_channels,
375
- kernel_size,
376
- dilation_rate,
377
- n_layers,
378
- gin_channels=gin_channels,
379
- )
380
- self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
381
- self.norm = modules.LayerNorm(out_channels)
382
-
383
- def forward(self, x, x_lengths, g=None):
384
- x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
385
- x.dtype
386
- )
387
- x = self.pre(x) * x_mask
388
- x = self.enc(x, x_mask, g=g)
389
- out = self.proj(x) * x_mask
390
- out = self.norm(out)
391
- return out
392
-
393
-
394
- class Generator(torch.nn.Module):
395
- def __init__(
396
- self,
397
- initial_channel,
398
- resblock,
399
- resblock_kernel_sizes,
400
- resblock_dilation_sizes,
401
- upsample_rates,
402
- upsample_initial_channel,
403
- upsample_kernel_sizes,
404
- gin_channels=0,
405
- ):
406
- super(Generator, self).__init__()
407
- self.num_kernels = len(resblock_kernel_sizes)
408
- self.num_upsamples = len(upsample_rates)
409
- self.conv_pre = Conv1d(
410
- initial_channel, upsample_initial_channel, 7, 1, padding=3
411
- )
412
- resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
413
-
414
- self.ups = nn.ModuleList()
415
- for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
416
- self.ups.append(
417
- weight_norm(
418
- ConvTranspose1d(
419
- upsample_initial_channel // (2**i),
420
- upsample_initial_channel // (2 ** (i + 1)),
421
- k,
422
- u,
423
- padding=(k - u) // 2,
424
- )
425
- )
426
- )
427
-
428
- self.resblocks = nn.ModuleList()
429
- for i in range(len(self.ups)):
430
- ch = upsample_initial_channel // (2 ** (i + 1))
431
- for j, (k, d) in enumerate(
432
- zip(resblock_kernel_sizes, resblock_dilation_sizes)
433
- ):
434
- self.resblocks.append(resblock(ch, k, d))
435
-
436
- self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
437
- self.ups.apply(init_weights)
438
-
439
- if gin_channels != 0:
440
- self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
441
-
442
- def forward(self, x, g=None):
443
- x = self.conv_pre(x)
444
- if g is not None:
445
- x = x + self.cond(g)
446
-
447
- for i in range(self.num_upsamples):
448
- x = F.leaky_relu(x, modules.LRELU_SLOPE)
449
- x = self.ups[i](x)
450
- xs = None
451
- for j in range(self.num_kernels):
452
- if xs is None:
453
- xs = self.resblocks[i * self.num_kernels + j](x)
454
- else:
455
- xs += self.resblocks[i * self.num_kernels + j](x)
456
- x = xs / self.num_kernels
457
- x = F.leaky_relu(x)
458
- x = self.conv_post(x)
459
- x = torch.tanh(x)
460
-
461
- return x
462
-
463
- def remove_weight_norm(self):
464
- print("Removing weight norm...")
465
- for l in self.ups:
466
- remove_weight_norm(l)
467
- for l in self.resblocks:
468
- l.remove_weight_norm()
469
-
470
-
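The generator's overall upsampling factor is the product of `upsample_rates`, i.e. how many waveform samples each input frame becomes. A one-liner with hypothetical config values (the real rates come from the training config, which is not part of this diff):

```python
import math

upsample_rates = [10, 8, 2, 2]    # hypothetical HiFi-GAN-style configuration
print(math.prod(upsample_rates))  # 320: samples generated per input frame
```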
471
- class DiscriminatorP(torch.nn.Module):
472
- def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
473
- super(DiscriminatorP, self).__init__()
474
- self.period = period
475
- self.use_spectral_norm = use_spectral_norm
476
- norm_f = spectral_norm if use_spectral_norm else weight_norm
477
- self.convs = nn.ModuleList(
478
- [
479
- norm_f(
480
- Conv2d(
481
- 1,
482
- 32,
483
- (kernel_size, 1),
484
- (stride, 1),
485
- padding=(get_padding(kernel_size, 1), 0),
486
- )
487
- ),
488
- norm_f(
489
- Conv2d(
490
- 32,
491
- 128,
492
- (kernel_size, 1),
493
- (stride, 1),
494
- padding=(get_padding(kernel_size, 1), 0),
495
- )
496
- ),
497
- norm_f(
498
- Conv2d(
499
- 128,
500
- 512,
501
- (kernel_size, 1),
502
- (stride, 1),
503
- padding=(get_padding(kernel_size, 1), 0),
504
- )
505
- ),
506
- norm_f(
507
- Conv2d(
508
- 512,
509
- 1024,
510
- (kernel_size, 1),
511
- (stride, 1),
512
- padding=(get_padding(kernel_size, 1), 0),
513
- )
514
- ),
515
- norm_f(
516
- Conv2d(
517
- 1024,
518
- 1024,
519
- (kernel_size, 1),
520
- 1,
521
- padding=(get_padding(kernel_size, 1), 0),
522
- )
523
- ),
524
- ]
525
- )
526
- self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
527
-
528
- def forward(self, x):
529
- fmap = []
530
-
531
- # 1d to 2d
532
- b, c, t = x.shape
533
- if t % self.period != 0: # pad first
534
- n_pad = self.period - (t % self.period)
535
- x = F.pad(x, (0, n_pad), "reflect")
536
- t = t + n_pad
537
- x = x.view(b, c, t // self.period, self.period)
538
-
539
- for l in self.convs:
540
- x = l(x)
541
- x = F.leaky_relu(x, modules.LRELU_SLOPE)
542
- fmap.append(x)
543
- x = self.conv_post(x)
544
- fmap.append(x)
545
- x = torch.flatten(x, 1, -1)
546
-
547
- return x, fmap
548
-
549
-
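The 1d-to-2d trick above pads the waveform to a multiple of `period` and folds it so each column holds samples exactly `period` apart. A small runnable check of that reshape:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 100)               # [B, C, T] waveform
period = 3
n_pad = period - (x.shape[-1] % period)  # 100 % 3 == 1, so pad 2 samples
x = F.pad(x, (0, n_pad), "reflect")
x2d = x.view(1, 1, -1, period)           # [B, C, T // period, period]
print(x2d.shape)                         # torch.Size([1, 1, 34, 3])
```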
550
- class DiscriminatorS(torch.nn.Module):
551
- def __init__(self, use_spectral_norm=False):
552
- super(DiscriminatorS, self).__init__()
553
- norm_f = spectral_norm if use_spectral_norm else weight_norm
554
- self.convs = nn.ModuleList(
555
- [
556
- norm_f(Conv1d(1, 16, 15, 1, padding=7)),
557
- norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
558
- norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
559
- norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
560
- norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
561
- norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
562
- ]
563
- )
564
- self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
565
-
566
- def forward(self, x):
567
- fmap = []
568
-
569
- for l in self.convs:
570
- x = l(x)
571
- x = F.leaky_relu(x, modules.LRELU_SLOPE)
572
- fmap.append(x)
573
- x = self.conv_post(x)
574
- fmap.append(x)
575
- x = torch.flatten(x, 1, -1)
576
-
577
- return x, fmap
578
-
579
-
580
- class MultiPeriodDiscriminator(torch.nn.Module):
581
- def __init__(self, use_spectral_norm=False):
582
- super(MultiPeriodDiscriminator, self).__init__()
583
- periods = [2, 3, 5, 7, 11]
584
-
585
- discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
586
- discs = discs + [
587
- DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
588
- ]
589
- self.discriminators = nn.ModuleList(discs)
590
-
591
- def forward(self, y, y_hat):
592
- y_d_rs = []
593
- y_d_gs = []
594
- fmap_rs = []
595
- fmap_gs = []
596
- for i, d in enumerate(self.discriminators):
597
- y_d_r, fmap_r = d(y)
598
- y_d_g, fmap_g = d(y_hat)
599
- y_d_rs.append(y_d_r)
600
- y_d_gs.append(y_d_g)
601
- fmap_rs.append(fmap_r)
602
- fmap_gs.append(fmap_g)
603
-
604
- return y_d_rs, y_d_gs, fmap_rs, fmap_gs
605
-
606
-
607
- class ReferenceEncoder(nn.Module):
608
- """
609
- inputs --- [N, Ty/r, n_mels*r] mels
610
- outputs --- [N, ref_enc_gru_size]
611
- """
612
-
613
- def __init__(self, spec_channels, gin_channels=0):
614
- super().__init__()
615
- self.spec_channels = spec_channels
616
- ref_enc_filters = [32, 32, 64, 64, 128, 128]
617
- K = len(ref_enc_filters)
618
- filters = [1] + ref_enc_filters
619
- convs = [
620
- weight_norm(
621
- nn.Conv2d(
622
- in_channels=filters[i],
623
- out_channels=filters[i + 1],
624
- kernel_size=(3, 3),
625
- stride=(2, 2),
626
- padding=(1, 1),
627
- )
628
- )
629
- for i in range(K)
630
- ]
631
- self.convs = nn.ModuleList(convs)
632
- # self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)])
633
-
634
- out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
635
- self.gru = nn.GRU(
636
- input_size=ref_enc_filters[-1] * out_channels,
637
- hidden_size=256 // 2,
638
- batch_first=True,
639
- )
640
- self.proj = nn.Linear(128, gin_channels)
641
-
642
- def forward(self, inputs):
643
- N = inputs.size(0)
644
- out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs]
645
- for conv in self.convs:
646
- out = conv(out)
647
- # out = wn(out)
648
- out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K]
649
-
650
- out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K]
651
- T = out.size(1)
652
- N = out.size(0)
653
- out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K]
654
-
655
- self.gru.flatten_parameters()
656
- memory, out = self.gru(out) # out --- [1, N, 128]
657
-
658
- return self.proj(out.squeeze(0)).unsqueeze(-1)
659
-
660
- def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
661
- for i in range(n_convs):
662
- L = (L - kernel_size + 2 * pad) // stride + 1
663
- return L
664
-
665
-
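`calculate_channels` mirrors the standard Conv2d output-size formula, so it can be sanity-checked against an actual convolution. A quick verification with toy sizes:

```python
import torch

def calculate_channels(L, kernel_size, stride, pad, n_convs):
    # Same formula as ReferenceEncoder.calculate_channels above.
    for _ in range(n_convs):
        L = (L - kernel_size + 2 * pad) // stride + 1
    return L

conv = torch.nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1)
out = conv(torch.randn(1, 1, 100, 80))                    # [N, 1, Ty, n_mels]
print(out.shape[-1], calculate_channels(80, 3, 2, 1, 1))  # 40 40
```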
666
- class Quantizer_module(torch.nn.Module):
667
- def __init__(self, n_e, e_dim):
668
- super(Quantizer_module, self).__init__()
669
- self.embedding = nn.Embedding(n_e, e_dim)
670
- self.embedding.weight.data.uniform_(-1.0 / n_e, 1.0 / n_e)
671
-
672
- def forward(self, x):
673
- d = (
674
- torch.sum(x**2, 1, keepdim=True)
675
- + torch.sum(self.embedding.weight**2, 1)
676
- - 2 * torch.matmul(x, self.embedding.weight.T)
677
- )
678
- min_indices = torch.argmin(d, 1)
679
- z_q = self.embedding(min_indices)
680
- return z_q, min_indices
681
-
682
-
683
- class Quantizer(torch.nn.Module):
684
- def __init__(self, embed_dim=512, n_code_groups=4, n_codes=160):
685
- super(Quantizer, self).__init__()
686
- assert embed_dim % n_code_groups == 0
687
- self.quantizer_modules = nn.ModuleList(
688
- [
689
- Quantizer_module(n_codes, embed_dim // n_code_groups)
690
- for _ in range(n_code_groups)
691
- ]
692
- )
693
- self.n_code_groups = n_code_groups
694
- self.embed_dim = embed_dim
695
-
696
- def forward(self, xin):
697
- # B, C, T
698
- B, C, T = xin.shape
699
- xin = xin.transpose(1, 2)
700
- x = xin.reshape(-1, self.embed_dim)
701
- x = torch.split(x, self.embed_dim // self.n_code_groups, dim=-1)
702
- min_indices = []
703
- z_q = []
704
- for _x, m in zip(x, self.quantizer_modules):
705
- _z_q, _min_indices = m(_x)
706
- z_q.append(_z_q)
707
- min_indices.append(_min_indices)  # B * T,
708
- z_q = torch.cat(z_q, -1).reshape(xin.shape)
709
- loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean(
710
- (z_q - xin.detach()) ** 2
711
- )
712
- z_q = xin + (z_q - xin).detach()
713
- z_q = z_q.transpose(1, 2)
714
- codes = torch.stack(min_indices, -1).reshape(B, T, self.n_code_groups)
715
- return z_q, loss, codes.transpose(1, 2)
716
-
717
- def embed(self, x):
718
- # idx: N, 4, T
719
- x = x.transpose(1, 2)
720
- x = torch.split(x, 1, 2)
721
- ret = []
722
- for q, embed in zip(x, self.quantizer_modules):
723
- q = embed.embedding(q.squeeze(-1))
724
- ret.append(q)
725
- ret = torch.cat(ret, -1)
726
- return ret.transpose(1, 2) # N, C, T
727
-
728
-
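`Quantizer_module.forward` computes squared distances to every codebook entry via the expansion ||x - e||^2 = ||x||^2 + ||e||^2 - 2 x·e. A self-contained check of that identity against `torch.cdist`:

```python
import torch

x = torch.randn(7, 128)    # flattened input vectors
e = torch.randn(160, 128)  # codebook: n_codes x e_dim

d = x.pow(2).sum(1, keepdim=True) + e.pow(2).sum(1) - 2 * x @ e.T
d_ref = torch.cdist(x, e).pow(2)            # ||x - e||^2 computed directly
print(torch.allclose(d, d_ref, atol=1e-3))  # True
```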
729
- class CodePredictor(nn.Module):
730
- def __init__(
731
- self,
732
- hidden_channels,
733
- filter_channels,
734
- n_heads,
735
- n_layers,
736
- kernel_size,
737
- p_dropout,
738
- n_q=8,
739
- dims=1024,
740
- ssl_dim=768,
741
- ):
742
- super().__init__()
743
- self.hidden_channels = hidden_channels
744
- self.filter_channels = filter_channels
745
- self.n_heads = n_heads
746
- self.n_layers = n_layers
747
- self.kernel_size = kernel_size
748
- self.p_dropout = p_dropout
749
-
750
- self.vq_proj = nn.Conv1d(ssl_dim, hidden_channels, 1)
751
- self.ref_enc = modules.MelStyleEncoder(
752
- ssl_dim, style_vector_dim=hidden_channels
753
- )
754
-
755
- self.encoder = attentions.Encoder(
756
- hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
757
- )
758
-
759
- self.out_proj = nn.Conv1d(hidden_channels, (n_q - 1) * dims, 1)
760
- self.n_q = n_q
761
- self.dims = dims
762
-
763
- def forward(self, x, x_mask, refer, codes, infer=False):
764
- x = x.detach()
765
- x = self.vq_proj(x * x_mask) * x_mask
766
- g = self.ref_enc(refer, x_mask)
767
- x = x + g
768
- x = self.encoder(x * x_mask, x_mask)
769
- x = self.out_proj(x * x_mask) * x_mask
770
- logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose(
771
- 2, 3
772
- )
773
- target = codes[1:].transpose(0, 1)
774
- if not infer:
775
- logits = logits.reshape(-1, self.dims)
776
- target = target.reshape(-1)
777
- loss = torch.nn.functional.cross_entropy(logits, target)
778
- return loss
779
- else:
780
- _, top10_preds = torch.topk(logits, 10, dim=-1)
781
- correct_top10 = torch.any(top10_preds == target.unsqueeze(-1), dim=-1)
782
- top3_acc = 100 * torch.mean(correct_top10.float()).detach().cpu().item()
783
-
784
- print("Top-10 Accuracy:", top3_acc, "%")
785
-
786
- pred_codes = torch.argmax(logits, dim=-1)
787
- acc = 100 * torch.mean((pred_codes == target).float()).detach().cpu().item()
788
- print("Top-1 Accuracy:", acc, "%")
789
-
790
- return pred_codes.transpose(0, 1)
791
-
792
-
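For the training branch, `CodePredictor.forward` flattens the `[B, n_q-1, T, dims]` logits and the matching targets before `cross_entropy`. A shape-level sketch with random data (dimensions follow the defaults above):

```python
import torch
import torch.nn.functional as F

n_q, dims, B, T = 8, 1024, 2, 50
logits = torch.randn(B, n_q - 1, T, dims)         # shape produced above
target = torch.randint(0, dims, (B, n_q - 1, T))  # random stand-in codes
loss = F.cross_entropy(logits.reshape(-1, dims), target.reshape(-1))
print(loss.item())  # roughly log(1024) ≈ 6.93 for random logits
```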
793
- class SynthesizerTrn(nn.Module):
794
- """
795
- Synthesizer for Training
796
- """
797
-
798
- def __init__(
799
- self,
800
- spec_channels,
801
- segment_size,
802
- inter_channels,
803
- hidden_channels,
804
- filter_channels,
805
- n_heads,
806
- n_layers,
807
- kernel_size,
808
- p_dropout,
809
- resblock,
810
- resblock_kernel_sizes,
811
- resblock_dilation_sizes,
812
- upsample_rates,
813
- upsample_initial_channel,
814
- upsample_kernel_sizes,
815
- n_speakers=0,
816
- gin_channels=0,
817
- use_sdp=True,
818
- semantic_frame_rate=None,
819
- freeze_quantizer=None,
820
- **kwargs
821
- ):
822
- super().__init__()
823
- self.spec_channels = spec_channels
824
- self.inter_channels = inter_channels
825
- self.hidden_channels = hidden_channels
826
- self.filter_channels = filter_channels
827
- self.n_heads = n_heads
828
- self.n_layers = n_layers
829
- self.kernel_size = kernel_size
830
- self.p_dropout = p_dropout
831
- self.resblock = resblock
832
- self.resblock_kernel_sizes = resblock_kernel_sizes
833
- self.resblock_dilation_sizes = resblock_dilation_sizes
834
- self.upsample_rates = upsample_rates
835
- self.upsample_initial_channel = upsample_initial_channel
836
- self.upsample_kernel_sizes = upsample_kernel_sizes
837
- self.segment_size = segment_size
838
- self.n_speakers = n_speakers
839
- self.gin_channels = gin_channels
840
-
841
- self.use_sdp = use_sdp
842
- self.enc_p = TextEncoder(
843
- inter_channels,
844
- hidden_channels,
845
- filter_channels,
846
- n_heads,
847
- n_layers,
848
- kernel_size,
849
- p_dropout,
850
- )
851
- self.dec = Generator(
852
- inter_channels,
853
- resblock,
854
- resblock_kernel_sizes,
855
- resblock_dilation_sizes,
856
- upsample_rates,
857
- upsample_initial_channel,
858
- upsample_kernel_sizes,
859
- gin_channels=gin_channels,
860
- )
861
- self.enc_q = PosteriorEncoder(
862
- spec_channels,
863
- inter_channels,
864
- hidden_channels,
865
- 5,
866
- 1,
867
- 16,
868
- gin_channels=gin_channels,
869
- )
870
- self.flow = ResidualCouplingBlock(
871
- inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
872
- )
873
-
874
- self.ref_enc = modules.MelStyleEncoder(
875
- spec_channels, style_vector_dim=gin_channels
876
- )
877
-
878
- ssl_dim = 768
879
- self.ssl_dim = ssl_dim
880
- assert semantic_frame_rate in ["25hz", "50hz"]
881
- self.semantic_frame_rate = semantic_frame_rate
882
- if semantic_frame_rate == "25hz":
883
- self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2)
884
- else:
885
- self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1)
886
-
887
- self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
888
- if freeze_quantizer:
889
- self.ssl_proj.requires_grad_(False)
890
- self.quantizer.requires_grad_(False)
891
- # self.enc_p.text_embedding.requires_grad_(False)
892
- # self.enc_p.encoder_text.requires_grad_(False)
893
- # self.enc_p.mrte.requires_grad_(False)
894
-
895
- def forward(self, codes, text, refer):
896
- refer_mask = torch.ones_like(refer[:1, :1, :])
897
- ge = self.ref_enc(refer * refer_mask, refer_mask)
898
-
899
- quantized = self.quantizer.decode(codes)
900
- if self.semantic_frame_rate == "25hz":
901
- dquantized = torch.cat([quantized, quantized]).permute(1, 2, 0)
902
- quantized = dquantized.contiguous().view(1, self.ssl_dim, -1)
903
-
904
- x, m_p, logs_p, y_mask = self.enc_p(
905
- quantized, text, ge
906
- )
907
-
908
- z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p)
909
-
910
- z = self.flow(z_p, y_mask, g=ge, reverse=True)
911
-
912
- o = self.dec(z * y_mask, g=ge)
913
- return o
914
-
915
- def extract_latent(self, x):
916
- ssl = self.ssl_proj(x)
917
- quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
918
- return codes.transpose(0, 1)
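The `25hz` branch of `forward` doubles the temporal resolution of the decoded codes by repeating each frame once via a cat/permute/view dance; it is equivalent to `torch.repeat_interleave(..., 2, dim=-1)`. A toy verification:

```python
import torch

q = torch.tensor([[[1., 2., 3.]]])  # [B=1, C=1, T=3] codes at 25 Hz
dq = torch.cat([q, q]).permute(1, 2, 0).contiguous().view(1, 1, -1)
print(dq)  # tensor([[[1., 1., 2., 2., 3., 3.]]]) — each frame doubled to 50 Hz
print(torch.equal(dq, torch.repeat_interleave(q, 2, dim=-1)))  # True
```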
 
module/modules.py DELETED
@@ -1,923 +0,0 @@
1
- import math
2
- import numpy as np
3
- import torch
4
- from torch import nn
5
- from torch.nn import functional as F
6
-
7
- from torch.nn import Conv1d
8
- from torch.nn.utils import weight_norm, remove_weight_norm
9
-
10
- from module import commons
11
- from module.commons import init_weights, get_padding
12
- from module.transforms import piecewise_rational_quadratic_transform
13
- import torch.distributions as D
14
-
15
-
16
- LRELU_SLOPE = 0.1
17
-
18
-
19
- class LayerNorm(nn.Module):
20
- def __init__(self, channels, eps=1e-5):
21
- super().__init__()
22
- self.channels = channels
23
- self.eps = eps
24
-
25
- self.gamma = nn.Parameter(torch.ones(channels))
26
- self.beta = nn.Parameter(torch.zeros(channels))
27
-
28
- def forward(self, x):
29
- x = x.transpose(1, -1)
30
- x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
31
- return x.transpose(1, -1)
32
-
33
-
34
- class ConvReluNorm(nn.Module):
35
- def __init__(
36
- self,
37
- in_channels,
38
- hidden_channels,
39
- out_channels,
40
- kernel_size,
41
- n_layers,
42
- p_dropout,
43
- ):
44
- super().__init__()
45
- self.in_channels = in_channels
46
- self.hidden_channels = hidden_channels
47
- self.out_channels = out_channels
48
- self.kernel_size = kernel_size
49
- self.n_layers = n_layers
50
- self.p_dropout = p_dropout
51
- assert n_layers > 1, "Number of layers should be larger than 1."
52
-
53
- self.conv_layers = nn.ModuleList()
54
- self.norm_layers = nn.ModuleList()
55
- self.conv_layers.append(
56
- nn.Conv1d(
57
- in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
58
- )
59
- )
60
- self.norm_layers.append(LayerNorm(hidden_channels))
61
- self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
62
- for _ in range(n_layers - 1):
63
- self.conv_layers.append(
64
- nn.Conv1d(
65
- hidden_channels,
66
- hidden_channels,
67
- kernel_size,
68
- padding=kernel_size // 2,
69
- )
70
- )
71
- self.norm_layers.append(LayerNorm(hidden_channels))
72
- self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
73
- self.proj.weight.data.zero_()
74
- self.proj.bias.data.zero_()
75
-
76
- def forward(self, x, x_mask):
77
- x_org = x
78
- for i in range(self.n_layers):
79
- x = self.conv_layers[i](x * x_mask)
80
- x = self.norm_layers[i](x)
81
- x = self.relu_drop(x)
82
- x = x_org + self.proj(x)
83
- return x * x_mask
84
-
85
-
86
- class DDSConv(nn.Module):
87
- """
88
- Dilated and Depth-Separable Convolution
89
- """
90
-
91
- def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
92
- super().__init__()
93
- self.channels = channels
94
- self.kernel_size = kernel_size
95
- self.n_layers = n_layers
96
- self.p_dropout = p_dropout
97
-
98
- self.drop = nn.Dropout(p_dropout)
99
- self.convs_sep = nn.ModuleList()
100
- self.convs_1x1 = nn.ModuleList()
101
- self.norms_1 = nn.ModuleList()
102
- self.norms_2 = nn.ModuleList()
103
- for i in range(n_layers):
104
- dilation = kernel_size**i
105
- padding = (kernel_size * dilation - dilation) // 2
106
- self.convs_sep.append(
107
- nn.Conv1d(
108
- channels,
109
- channels,
110
- kernel_size,
111
- groups=channels,
112
- dilation=dilation,
113
- padding=padding,
114
- )
115
- )
116
- self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
117
- self.norms_1.append(LayerNorm(channels))
118
- self.norms_2.append(LayerNorm(channels))
119
-
120
- def forward(self, x, x_mask, g=None):
121
- if g is not None:
122
- x = x + g
123
- for i in range(self.n_layers):
124
- y = self.convs_sep[i](x * x_mask)
125
- y = self.norms_1[i](y)
126
- y = F.gelu(y)
127
- y = self.convs_1x1[i](y)
128
- y = self.norms_2[i](y)
129
- y = F.gelu(y)
130
- y = self.drop(y)
131
- x = x + y
132
- return x * x_mask
133
-
134
-
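DDSConv picks `dilation = kernel_size ** i` per layer with matching padding, so every layer preserves sequence length while the receptive field grows exponentially. A runnable check of the length-preservation property:

```python
import torch

kernel_size, channels = 3, 8
for i in range(3):
    dilation = kernel_size ** i
    padding = (kernel_size * dilation - dilation) // 2
    conv = torch.nn.Conv1d(channels, channels, kernel_size,
                           groups=channels, dilation=dilation, padding=padding)
    print(conv(torch.randn(1, channels, 50)).shape[-1])  # 50 at every layer
```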
135
- class WN(torch.nn.Module):
136
- def __init__(
137
- self,
138
- hidden_channels,
139
- kernel_size,
140
- dilation_rate,
141
- n_layers,
142
- gin_channels=0,
143
- p_dropout=0,
144
- ):
145
- super(WN, self).__init__()
146
- assert kernel_size % 2 == 1
147
- self.hidden_channels = hidden_channels
148
- self.kernel_size = (kernel_size,)
149
- self.dilation_rate = dilation_rate
150
- self.n_layers = n_layers
151
- self.gin_channels = gin_channels
152
- self.p_dropout = p_dropout
153
-
154
- self.in_layers = torch.nn.ModuleList()
155
- self.res_skip_layers = torch.nn.ModuleList()
156
- self.drop = nn.Dropout(p_dropout)
157
-
158
- if gin_channels != 0:
159
- cond_layer = torch.nn.Conv1d(
160
- gin_channels, 2 * hidden_channels * n_layers, 1
161
- )
162
- self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
163
-
164
- for i in range(n_layers):
165
- dilation = dilation_rate**i
166
- padding = int((kernel_size * dilation - dilation) / 2)
167
- in_layer = torch.nn.Conv1d(
168
- hidden_channels,
169
- 2 * hidden_channels,
170
- kernel_size,
171
- dilation=dilation,
172
- padding=padding,
173
- )
174
- in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
175
- self.in_layers.append(in_layer)
176
-
177
- # last one is not necessary
178
- if i < n_layers - 1:
179
- res_skip_channels = 2 * hidden_channels
180
- else:
181
- res_skip_channels = hidden_channels
182
-
183
- res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
184
- res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
185
- self.res_skip_layers.append(res_skip_layer)
186
-
187
- def forward(self, x, x_mask, g=None, **kwargs):
188
- output = torch.zeros_like(x)
189
- n_channels_tensor = torch.IntTensor([self.hidden_channels])
190
-
191
- if g is not None:
192
- g = self.cond_layer(g)
193
-
194
- for i in range(self.n_layers):
195
- x_in = self.in_layers[i](x)
196
- if g is not None:
197
- cond_offset = i * 2 * self.hidden_channels
198
- g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
199
- else:
200
- g_l = torch.zeros_like(x_in)
201
-
202
- acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
203
- acts = self.drop(acts)
204
-
205
- res_skip_acts = self.res_skip_layers[i](acts)
206
- if i < self.n_layers - 1:
207
- res_acts = res_skip_acts[:, : self.hidden_channels, :]
208
- x = (x + res_acts) * x_mask
209
- output = output + res_skip_acts[:, self.hidden_channels :, :]
210
- else:
211
- output = output + res_skip_acts
212
- return output * x_mask
213
-
214
- def remove_weight_norm(self):
215
- if self.gin_channels != 0:
216
- torch.nn.utils.remove_weight_norm(self.cond_layer)
217
- for l in self.in_layers:
218
- torch.nn.utils.remove_weight_norm(l)
219
- for l in self.res_skip_layers:
220
- torch.nn.utils.remove_weight_norm(l)
221
-
222
-
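`WN.forward` relies on `commons.fused_add_tanh_sigmoid_multiply`, the WaveNet-style gated activation (the same computation appears unfused at the bottom of module/mrte_model.py in this diff). Its core, written out plainly:

```python
import torch

hidden = 192
h = torch.randn(1, 2 * hidden, 40)  # an in_layer output: 2x hidden channels
acts = torch.tanh(h[:, :hidden, :]) * torch.sigmoid(h[:, hidden:, :])
print(acts.shape)                   # torch.Size([1, 192, 40])
```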
223
- class ResBlock1(torch.nn.Module):
224
- def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
225
- super(ResBlock1, self).__init__()
226
- self.convs1 = nn.ModuleList(
227
- [
228
- weight_norm(
229
- Conv1d(
230
- channels,
231
- channels,
232
- kernel_size,
233
- 1,
234
- dilation=dilation[0],
235
- padding=get_padding(kernel_size, dilation[0]),
236
- )
237
- ),
238
- weight_norm(
239
- Conv1d(
240
- channels,
241
- channels,
242
- kernel_size,
243
- 1,
244
- dilation=dilation[1],
245
- padding=get_padding(kernel_size, dilation[1]),
246
- )
247
- ),
248
- weight_norm(
249
- Conv1d(
250
- channels,
251
- channels,
252
- kernel_size,
253
- 1,
254
- dilation=dilation[2],
255
- padding=get_padding(kernel_size, dilation[2]),
256
- )
257
- ),
258
- ]
259
- )
260
- self.convs1.apply(init_weights)
261
-
262
- self.convs2 = nn.ModuleList(
263
- [
264
- weight_norm(
265
- Conv1d(
266
- channels,
267
- channels,
268
- kernel_size,
269
- 1,
270
- dilation=1,
271
- padding=get_padding(kernel_size, 1),
272
- )
273
- ),
274
- weight_norm(
275
- Conv1d(
276
- channels,
277
- channels,
278
- kernel_size,
279
- 1,
280
- dilation=1,
281
- padding=get_padding(kernel_size, 1),
282
- )
283
- ),
284
- weight_norm(
285
- Conv1d(
286
- channels,
287
- channels,
288
- kernel_size,
289
- 1,
290
- dilation=1,
291
- padding=get_padding(kernel_size, 1),
292
- )
293
- ),
294
- ]
295
- )
296
- self.convs2.apply(init_weights)
297
-
298
- def forward(self, x, x_mask=None):
299
- for c1, c2 in zip(self.convs1, self.convs2):
300
- xt = F.leaky_relu(x, LRELU_SLOPE)
301
- if x_mask is not None:
302
- xt = xt * x_mask
303
- xt = c1(xt)
304
- xt = F.leaky_relu(xt, LRELU_SLOPE)
305
- if x_mask is not None:
306
- xt = xt * x_mask
307
- xt = c2(xt)
308
- x = xt + x
309
- if x_mask is not None:
310
- x = x * x_mask
311
- return x
312
-
313
- def remove_weight_norm(self):
314
- for l in self.convs1:
315
- remove_weight_norm(l)
316
- for l in self.convs2:
317
- remove_weight_norm(l)
318
-
319
-
320
- class ResBlock2(torch.nn.Module):
321
- def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
322
- super(ResBlock2, self).__init__()
323
- self.convs = nn.ModuleList(
324
- [
325
- weight_norm(
326
- Conv1d(
327
- channels,
328
- channels,
329
- kernel_size,
330
- 1,
331
- dilation=dilation[0],
332
- padding=get_padding(kernel_size, dilation[0]),
333
- )
334
- ),
335
- weight_norm(
336
- Conv1d(
337
- channels,
338
- channels,
339
- kernel_size,
340
- 1,
341
- dilation=dilation[1],
342
- padding=get_padding(kernel_size, dilation[1]),
343
- )
344
- ),
345
- ]
346
- )
347
- self.convs.apply(init_weights)
348
-
349
- def forward(self, x, x_mask=None):
350
- for c in self.convs:
351
- xt = F.leaky_relu(x, LRELU_SLOPE)
352
- if x_mask is not None:
353
- xt = xt * x_mask
354
- xt = c(xt)
355
- x = xt + x
356
- if x_mask is not None:
357
- x = x * x_mask
358
- return x
359
-
360
- def remove_weight_norm(self):
361
- for l in self.convs:
362
- remove_weight_norm(l)
363
-
364
-
365
- class Log(nn.Module):
366
- def forward(self, x, x_mask, reverse=False, **kwargs):
367
- if not reverse:
368
- y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
369
- logdet = torch.sum(-y, [1, 2])
370
- return y, logdet
371
- else:
372
- x = torch.exp(x) * x_mask
373
- return x
374
-
375
-
376
- class Flip(nn.Module):
377
- def forward(self, x, *args, reverse=False, **kwargs):
378
- x = torch.flip(x, [1])
379
- if not reverse:
380
- logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
381
- return x, logdet
382
- else:
383
- return x
384
-
385
-
386
- class ElementwiseAffine(nn.Module):
387
- def __init__(self, channels):
388
- super().__init__()
389
- self.channels = channels
390
- self.m = nn.Parameter(torch.zeros(channels, 1))
391
- self.logs = nn.Parameter(torch.zeros(channels, 1))
392
-
393
- def forward(self, x, x_mask, reverse=False, **kwargs):
394
- if not reverse:
395
- y = self.m + torch.exp(self.logs) * x
396
- y = y * x_mask
397
- logdet = torch.sum(self.logs * x_mask, [1, 2])
398
- return y, logdet
399
- else:
400
- x = (x - self.m) * torch.exp(-self.logs) * x_mask
401
- return x
402
-
403
-
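`ElementwiseAffine` is the simplest flow in the stack: forward scales and shifts, reverse undoes it, and the log-determinant is just the sum of `logs` over the mask. A round-trip check:

```python
import torch

m, logs = torch.randn(2, 1), torch.randn(2, 1)  # per-channel parameters
x = torch.randn(1, 2, 30)
x_mask = torch.ones(1, 1, 30)

y = (m + torch.exp(logs) * x) * x_mask          # forward
x_rec = (y - m) * torch.exp(-logs) * x_mask     # reverse recovers x
print(torch.allclose(x, x_rec, atol=1e-5))      # True
logdet = torch.sum(logs * x_mask, [1, 2])       # log|det J| is just sum(logs)
```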
404
- class ResidualCouplingLayer(nn.Module):
405
- def __init__(
406
- self,
407
- channels,
408
- hidden_channels,
409
- kernel_size,
410
- dilation_rate,
411
- n_layers,
412
- p_dropout=0,
413
- gin_channels=0,
414
- mean_only=False,
415
- ):
416
- assert channels % 2 == 0, "channels should be divisible by 2"
417
- super().__init__()
418
- self.channels = channels
419
- self.hidden_channels = hidden_channels
420
- self.kernel_size = kernel_size
421
- self.dilation_rate = dilation_rate
422
- self.n_layers = n_layers
423
- self.half_channels = channels // 2
424
- self.mean_only = mean_only
425
-
426
- self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
427
- self.enc = WN(
428
- hidden_channels,
429
- kernel_size,
430
- dilation_rate,
431
- n_layers,
432
- p_dropout=p_dropout,
433
- gin_channels=gin_channels,
434
- )
435
- self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
436
- self.post.weight.data.zero_()
437
- self.post.bias.data.zero_()
438
-
439
- def forward(self, x, x_mask, g=None, reverse=False):
440
- x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
441
- h = self.pre(x0) * x_mask
442
- h = self.enc(h, x_mask, g=g)
443
- stats = self.post(h) * x_mask
444
- if not self.mean_only:
445
- m, logs = torch.split(stats, [self.half_channels] * 2, 1)
446
- else:
447
- m = stats
448
- logs = torch.zeros_like(m)
449
-
450
- if not reverse:
451
- x1 = m + x1 * torch.exp(logs) * x_mask
452
- x = torch.cat([x0, x1], 1)
453
- logdet = torch.sum(logs, [1, 2])
454
- return x, logdet
455
- else:
456
- x1 = (x1 - m) * torch.exp(-logs) * x_mask
457
- x = torch.cat([x0, x1], 1)
458
- return x
459
-
460
-
461
- class ConvFlow(nn.Module):
462
- def __init__(
463
- self,
464
- in_channels,
465
- filter_channels,
466
- kernel_size,
467
- n_layers,
468
- num_bins=10,
469
- tail_bound=5.0,
470
- ):
471
- super().__init__()
472
- self.in_channels = in_channels
473
- self.filter_channels = filter_channels
474
- self.kernel_size = kernel_size
475
- self.n_layers = n_layers
476
- self.num_bins = num_bins
477
- self.tail_bound = tail_bound
478
- self.half_channels = in_channels // 2
479
-
480
- self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
481
- self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
482
- self.proj = nn.Conv1d(
483
- filter_channels, self.half_channels * (num_bins * 3 - 1), 1
484
- )
485
- self.proj.weight.data.zero_()
486
- self.proj.bias.data.zero_()
487
-
488
- def forward(self, x, x_mask, g=None, reverse=False):
489
- x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
490
- h = self.pre(x0)
491
- h = self.convs(h, x_mask, g=g)
492
- h = self.proj(h) * x_mask
493
-
494
- b, c, t = x0.shape
495
- h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2)  # [b, c*n, t] -> [b, c, t, n]
496
-
497
- unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
498
- unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(
499
- self.filter_channels
500
- )
501
- unnormalized_derivatives = h[..., 2 * self.num_bins :]
502
-
503
- x1, logabsdet = piecewise_rational_quadratic_transform(
504
- x1,
505
- unnormalized_widths,
506
- unnormalized_heights,
507
- unnormalized_derivatives,
508
- inverse=reverse,
509
- tails="linear",
510
- tail_bound=self.tail_bound,
511
- )
512
-
513
- x = torch.cat([x0, x1], 1) * x_mask
514
- logdet = torch.sum(logabsdet * x_mask, [1, 2])
515
- if not reverse:
516
- return x, logdet
517
- else:
518
- return x
519
-
520
-
521
- class LinearNorm(nn.Module):
522
- def __init__(
523
- self,
524
- in_channels,
525
- out_channels,
526
- bias=True,
527
- spectral_norm=False,
528
- ):
529
- super(LinearNorm, self).__init__()
530
- self.fc = nn.Linear(in_channels, out_channels, bias)
531
-
532
- if spectral_norm:
533
- self.fc = nn.utils.spectral_norm(self.fc)
534
-
535
- def forward(self, input):
536
- out = self.fc(input)
537
- return out
538
-
539
-
540
- class Mish(nn.Module):
541
- def __init__(self):
542
- super(Mish, self).__init__()
543
-
544
- def forward(self, x):
545
- return x * torch.tanh(F.softplus(x))
546
-
547
-
548
- class Conv1dGLU(nn.Module):
549
- """
550
- Conv1d + GLU (Gated Linear Unit) with a residual connection.
551
- For GLU, see https://arxiv.org/abs/1612.08083.
552
- """
553
-
554
- def __init__(self, in_channels, out_channels, kernel_size, dropout):
555
- super(Conv1dGLU, self).__init__()
556
- self.out_channels = out_channels
557
- self.conv1 = ConvNorm(in_channels, 2 * out_channels, kernel_size=kernel_size)
558
- self.dropout = nn.Dropout(dropout)
559
-
560
- def forward(self, x):
561
- residual = x
562
- x = self.conv1(x)
563
- x1, x2 = torch.split(x, split_size_or_sections=self.out_channels, dim=1)
564
- x = x1 * torch.sigmoid(x2)
565
- x = residual + self.dropout(x)
566
- return x
567
-
568
-
569
- class ConvNorm(nn.Module):
570
- def __init__(
571
- self,
572
- in_channels,
573
- out_channels,
574
- kernel_size=1,
575
- stride=1,
576
- padding=None,
577
- dilation=1,
578
- bias=True,
579
- spectral_norm=False,
580
- ):
581
- super(ConvNorm, self).__init__()
582
-
583
- if padding is None:
584
- assert kernel_size % 2 == 1
585
- padding = int(dilation * (kernel_size - 1) / 2)
586
-
587
- self.conv = torch.nn.Conv1d(
588
- in_channels,
589
- out_channels,
590
- kernel_size=kernel_size,
591
- stride=stride,
592
- padding=padding,
593
- dilation=dilation,
594
- bias=bias,
595
- )
596
-
597
- if spectral_norm:
598
- self.conv = nn.utils.spectral_norm(self.conv)
599
-
600
- def forward(self, input):
601
- out = self.conv(input)
602
- return out
603
-
604
-
605
- class MultiHeadAttention(nn.Module):
606
- """Multi-Head Attention module"""
607
-
608
- def __init__(self, n_head, d_model, d_k, d_v, dropout=0.0, spectral_norm=False):
609
- super().__init__()
610
-
611
- self.n_head = n_head
612
- self.d_k = d_k
613
- self.d_v = d_v
614
-
615
- self.w_qs = nn.Linear(d_model, n_head * d_k)
616
- self.w_ks = nn.Linear(d_model, n_head * d_k)
617
- self.w_vs = nn.Linear(d_model, n_head * d_v)
618
-
619
- self.attention = ScaledDotProductAttention(
620
- temperature=np.power(d_model, 0.5), dropout=dropout
621
- )
622
-
623
- self.fc = nn.Linear(n_head * d_v, d_model)
624
- self.dropout = nn.Dropout(dropout)
625
-
626
- if spectral_norm:
627
- self.w_qs = nn.utils.spectral_norm(self.w_qs)
628
- self.w_ks = nn.utils.spectral_norm(self.w_ks)
629
- self.w_vs = nn.utils.spectral_norm(self.w_vs)
630
- self.fc = nn.utils.spectral_norm(self.fc)
631
-
632
- def forward(self, x, mask=None):
633
- d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
634
- sz_b, len_x, _ = x.size()
635
-
636
- residual = x
637
-
638
- q = self.w_qs(x).view(sz_b, len_x, n_head, d_k)
639
- k = self.w_ks(x).view(sz_b, len_x, n_head, d_k)
640
- v = self.w_vs(x).view(sz_b, len_x, n_head, d_v)
641
- q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_k) # (n*b) x lq x dk
642
- k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_k) # (n*b) x lk x dk
643
- v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_v) # (n*b) x lv x dv
644
-
645
- if mask is not None:
646
- slf_mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x ..
647
- else:
648
- slf_mask = None
649
- output, attn = self.attention(q, k, v, mask=slf_mask)
650
-
651
- output = output.view(n_head, sz_b, len_x, d_v)
652
- output = (
653
- output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_x, -1)
654
- ) # b x lq x (n*dv)
655
-
656
- output = self.fc(output)
657
-
658
- output = self.dropout(output) + residual
659
- return output, attn
660
-
661
-
662
- class ScaledDotProductAttention(nn.Module):
663
- """Scaled Dot-Product Attention"""
664
-
665
- def __init__(self, temperature, dropout):
666
- super().__init__()
667
- self.temperature = temperature
668
- self.softmax = nn.Softmax(dim=2)
669
- self.dropout = nn.Dropout(dropout)
670
-
671
- def forward(self, q, k, v, mask=None):
672
- attn = torch.bmm(q, k.transpose(1, 2))
673
- attn = attn / self.temperature
674
-
675
- if mask is not None:
676
- attn = attn.masked_fill(mask, -np.inf)
677
-
678
- attn = self.softmax(attn)
679
- p_attn = self.dropout(attn)
680
-
681
- output = torch.bmm(p_attn, v)
682
- return output, attn
683
-
684
-
685
- class MelStyleEncoder(nn.Module):
686
- """MelStyleEncoder"""
687
-
688
- def __init__(
689
- self,
690
- n_mel_channels=80,
691
- style_hidden=128,
692
- style_vector_dim=256,
693
- style_kernel_size=5,
694
- style_head=2,
695
- dropout=0.1,
696
- ):
697
- super(MelStyleEncoder, self).__init__()
698
- self.in_dim = n_mel_channels
699
- self.hidden_dim = style_hidden
700
- self.out_dim = style_vector_dim
701
- self.kernel_size = style_kernel_size
702
- self.n_head = style_head
703
- self.dropout = dropout
704
-
705
- self.spectral = nn.Sequential(
706
- LinearNorm(self.in_dim, self.hidden_dim),
707
- Mish(),
708
- nn.Dropout(self.dropout),
709
- LinearNorm(self.hidden_dim, self.hidden_dim),
710
- Mish(),
711
- nn.Dropout(self.dropout),
712
- )
713
-
714
- self.temporal = nn.Sequential(
715
- Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
716
- Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
717
- )
718
-
719
- self.slf_attn = MultiHeadAttention(
720
- self.n_head,
721
- self.hidden_dim,
722
- self.hidden_dim // self.n_head,
723
- self.hidden_dim // self.n_head,
724
- self.dropout,
725
- )
726
-
727
- self.fc = LinearNorm(self.hidden_dim, self.out_dim)
728
-
729
- def temporal_avg_pool(self, x, mask=None):
730
- if mask is None:
731
- out = torch.mean(x, dim=1)
732
- else:
733
- len_ = (~mask).sum(dim=1).unsqueeze(1)
734
- x = x.masked_fill(mask.unsqueeze(-1), 0)
735
- x = x.sum(dim=1)
736
- out = torch.div(x, len_)
737
- return out
738
-
739
- def forward(self, x, mask=None):
740
- x = x.transpose(1, 2)
741
- if mask is not None:
742
- mask = (mask.int() == 0).squeeze(1)
743
- max_len = x.shape[1]
744
- slf_attn_mask = (
745
- mask.unsqueeze(1).expand(-1, max_len, -1) if mask is not None else None
746
- )
747
-
748
- # spectral
749
- x = self.spectral(x)
750
- # temporal
751
- x = x.transpose(1, 2)
752
- x = self.temporal(x)
753
- x = x.transpose(1, 2)
754
- # self-attention
755
- if mask is not None:
756
- x = x.masked_fill(mask.unsqueeze(-1), 0)
757
- x, _ = self.slf_attn(x, mask=slf_attn_mask)
758
- # fc
759
- x = self.fc(x)
760
- # temporal average pooling
761
- w = self.temporal_avg_pool(x, mask=mask)
762
-
763
- return w.unsqueeze(-1)
764
-
765
-
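`temporal_avg_pool` averages only over valid (unmasked) frames by zeroing padded positions and dividing by the true length. A minimal check against a plain mean over the valid prefix:

```python
import torch

x = torch.randn(1, 5, 4)                                  # [N, T, C]
mask = torch.tensor([[False, False, False, True, True]])  # True marks padding

len_ = (~mask).sum(dim=1).unsqueeze(1)                    # valid lengths, [N, 1]
out = x.masked_fill(mask.unsqueeze(-1), 0).sum(dim=1) / len_
print(torch.allclose(out, x[0, :3].mean(dim=0)))          # True
```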
766
- class MelStyleEncoderVAE(nn.Module):
767
- def __init__(self, spec_channels, z_latent_dim, emb_dim):
768
- super().__init__()
769
- self.ref_encoder = MelStyleEncoder(spec_channels, style_vector_dim=emb_dim)
770
- self.fc1 = nn.Linear(emb_dim, z_latent_dim)
771
- self.fc2 = nn.Linear(emb_dim, z_latent_dim)
772
- self.fc3 = nn.Linear(z_latent_dim, emb_dim)
773
- self.z_latent_dim = z_latent_dim
774
-
775
- def reparameterize(self, mu, logvar):
776
- if self.training:
777
- std = torch.exp(0.5 * logvar)
778
- eps = torch.randn_like(std)
779
- return eps.mul(std).add_(mu)
780
- else:
781
- return mu
782
-
783
- def forward(self, inputs, mask=None):
784
- enc_out = self.ref_encoder(inputs.squeeze(-1), mask).squeeze(-1)
785
- mu = self.fc1(enc_out)
786
- logvar = self.fc2(enc_out)
787
- posterior = D.Normal(mu, torch.exp(logvar))
788
- kl_divergence = D.kl_divergence(
789
- posterior, D.Normal(torch.zeros_like(mu), torch.ones_like(logvar))
790
- )
791
- loss_kl = kl_divergence.mean()
792
-
793
- z = posterior.rsample()
794
- style_embed = self.fc3(z)
795
-
796
- return style_embed.unsqueeze(-1), loss_kl
797
-
798
- def infer(self, inputs=None, random_sample=False, manual_latent=None):
799
- if manual_latent is None:
800
- if random_sample:
801
- dev = next(self.parameters()).device
802
- posterior = D.Normal(
803
- torch.zeros(1, self.z_latent_dim, device=dev),
804
- torch.ones(1, self.z_latent_dim, device=dev),
805
- )
806
- z = posterior.rsample()
807
- else:
808
- enc_out = self.ref_encoder(inputs.transpose(1, 2))
809
- mu = self.fc1(enc_out)
810
- z = mu
811
- else:
812
- z = manual_latent
813
- style_embed = self.fc3(z)
814
- return style_embed.unsqueeze(-1), z
815
-
816
-
817
- class ActNorm(nn.Module):
818
- def __init__(self, channels, ddi=False, **kwargs):
819
- super().__init__()
820
- self.channels = channels
821
- self.initialized = not ddi
822
-
823
- self.logs = nn.Parameter(torch.zeros(1, channels, 1))
824
- self.bias = nn.Parameter(torch.zeros(1, channels, 1))
825
-
826
- def forward(self, x, x_mask=None, g=None, reverse=False, **kwargs):
827
- if x_mask is None:
828
- x_mask = torch.ones(x.size(0), 1, x.size(2)).to(
829
- device=x.device, dtype=x.dtype
830
- )
831
- x_len = torch.sum(x_mask, [1, 2])
832
- if not self.initialized:
833
- self.initialize(x, x_mask)
834
- self.initialized = True
835
-
836
- if reverse:
837
- z = (x - self.bias) * torch.exp(-self.logs) * x_mask
838
- logdet = None
839
- return z
840
- else:
841
- z = (self.bias + torch.exp(self.logs) * x) * x_mask
842
- logdet = torch.sum(self.logs) * x_len # [b]
843
- return z, logdet
844
-
845
- def store_inverse(self):
846
- pass
847
-
848
- def set_ddi(self, ddi):
849
- self.initialized = not ddi
850
-
851
- def initialize(self, x, x_mask):
852
- with torch.no_grad():
853
- denom = torch.sum(x_mask, [0, 2])
854
- m = torch.sum(x * x_mask, [0, 2]) / denom
855
- m_sq = torch.sum(x * x * x_mask, [0, 2]) / denom
856
- v = m_sq - (m**2)
857
- logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6))
858
-
859
- bias_init = (
860
- (-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype)
861
- )
862
- logs_init = (-logs).view(*self.logs.shape).to(dtype=self.logs.dtype)
863
-
864
- self.bias.data.copy_(bias_init)
865
- self.logs.data.copy_(logs_init)
866
-
867
-
868
- class InvConvNear(nn.Module):
869
- def __init__(self, channels, n_split=4, no_jacobian=False, **kwargs):
870
- super().__init__()
871
- assert n_split % 2 == 0
872
- self.channels = channels
873
- self.n_split = n_split
874
- self.no_jacobian = no_jacobian
875
-
876
- w_init = torch.linalg.qr(
877
- torch.FloatTensor(self.n_split, self.n_split).normal_()
878
- )[0]
879
- if torch.det(w_init) < 0:
880
- w_init[:, 0] = -1 * w_init[:, 0]
881
- self.weight = nn.Parameter(w_init)
882
-
883
- def forward(self, x, x_mask=None, g=None, reverse=False, **kwargs):
884
- b, c, t = x.size()
885
- assert c % self.n_split == 0
886
- if x_mask is None:
887
- x_mask = 1
888
- x_len = torch.ones((b,), dtype=x.dtype, device=x.device) * t
889
- else:
890
- x_len = torch.sum(x_mask, [1, 2])
891
-
892
- x = x.view(b, 2, c // self.n_split, self.n_split // 2, t)
893
- x = (
894
- x.permute(0, 1, 3, 2, 4)
895
- .contiguous()
896
- .view(b, self.n_split, c // self.n_split, t)
897
- )
898
-
899
- if reverse:
900
- if hasattr(self, "weight_inv"):
901
- weight = self.weight_inv
902
- else:
903
- weight = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype)
904
- logdet = None
905
- else:
906
- weight = self.weight
907
- if self.no_jacobian:
908
- logdet = 0
909
- else:
910
- logdet = torch.logdet(self.weight) * (c / self.n_split) * x_len # [b]
911
-
912
- weight = weight.view(self.n_split, self.n_split, 1, 1)
913
- z = F.conv2d(x, weight)
914
-
915
- z = z.view(b, 2, self.n_split // 2, c // self.n_split, t)
916
- z = z.permute(0, 1, 3, 2, 4).contiguous().view(b, c, t) * x_mask
917
- if reverse:
918
- return z
919
- else:
920
- return z, logdet
921
-
922
- def store_inverse(self):
923
- self.weight_inv = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype)
 
module/mrte_model.py DELETED
@@ -1,192 +0,0 @@
1
- # This is the multi-reference timbre encoder (MRTE).
2
-
3
- import torch
4
- from torch import nn
5
- from torch.nn.utils import remove_weight_norm, weight_norm
6
- from module.attentions import MultiHeadAttention
7
-
8
-
9
- class MRTE(nn.Module):
10
- def __init__(
11
- self,
12
- content_enc_channels=192,
13
- hidden_size=512,
14
- out_channels=192,
15
- kernel_size=5,
16
- n_heads=4,
17
- ge_layer=2,
18
- ):
19
- super(MRTE, self).__init__()
20
- self.cross_attention = MultiHeadAttention(hidden_size, hidden_size, n_heads)
21
- self.c_pre = nn.Conv1d(content_enc_channels, hidden_size, 1)
22
- self.text_pre = nn.Conv1d(content_enc_channels, hidden_size, 1)
23
- self.c_post = nn.Conv1d(hidden_size, out_channels, 1)
24
-
25
- def forward(self, ssl_enc, ssl_mask, text, text_mask, ge, test=None):
26
- if ge is None:
27
- ge = 0
28
- attn_mask = text_mask.unsqueeze(2) * ssl_mask.unsqueeze(-1)
29
-
30
- ssl_enc = self.c_pre(ssl_enc * ssl_mask)
31
- text_enc = self.text_pre(text * text_mask)
32
- if test is not None:
33
- if test == 0:
34
- x = (
35
- self.cross_attention(
36
- ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
37
- )
38
- + ssl_enc
39
- + ge
40
- )
41
- elif test == 1:
42
- x = ssl_enc + ge
43
- elif test == 2:
44
- x = (
45
- self.cross_attention(
46
- ssl_enc * 0 * ssl_mask, text_enc * text_mask, attn_mask
47
- )
48
- + ge
49
- )
50
- else:
51
- raise ValueError("test should be 0,1,2")
52
- else:
53
- x = (
54
- self.cross_attention(
55
- ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
56
- )
57
- + ssl_enc
58
- + ge
59
- )
60
- x = self.c_post(x * ssl_mask)
61
- return x
62
-
63
-
64
- class SpeakerEncoder(torch.nn.Module):
65
- def __init__(
66
- self,
67
- mel_n_channels=80,
68
- model_num_layers=2,
69
- model_hidden_size=256,
70
- model_embedding_size=256,
71
- ):
72
- super(SpeakerEncoder, self).__init__()
73
- self.lstm = nn.LSTM(
74
- mel_n_channels, model_hidden_size, model_num_layers, batch_first=True
75
- )
76
- self.linear = nn.Linear(model_hidden_size, model_embedding_size)
77
- self.relu = nn.ReLU()
78
-
79
- def forward(self, mels):
80
- self.lstm.flatten_parameters()
81
- _, (hidden, _) = self.lstm(mels.transpose(-1, -2))
82
- embeds_raw = self.relu(self.linear(hidden[-1]))
83
- return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)
84
-
85
-
86
- class MELEncoder(nn.Module):
87
- def __init__(
88
- self,
89
- in_channels,
90
- out_channels,
91
- hidden_channels,
92
- kernel_size,
93
- dilation_rate,
94
- n_layers,
95
- ):
96
- super().__init__()
97
- self.in_channels = in_channels
98
- self.out_channels = out_channels
99
- self.hidden_channels = hidden_channels
100
- self.kernel_size = kernel_size
101
- self.dilation_rate = dilation_rate
102
- self.n_layers = n_layers
103
-
104
- self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
105
- self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers)
106
- self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
107
-
108
- def forward(self, x):
109
- # print(x.shape,x_lengths.shape)
110
- x = self.pre(x)
111
- x = self.enc(x)
112
- x = self.proj(x)
113
- return x
114
-
115
-
116
- class WN(torch.nn.Module):
117
- def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers):
118
- super(WN, self).__init__()
119
- assert kernel_size % 2 == 1
120
- self.hidden_channels = hidden_channels
121
- self.kernel_size = kernel_size
122
- self.dilation_rate = dilation_rate
123
- self.n_layers = n_layers
124
-
125
- self.in_layers = torch.nn.ModuleList()
126
- self.res_skip_layers = torch.nn.ModuleList()
127
-
128
- for i in range(n_layers):
129
- dilation = dilation_rate**i
130
- padding = int((kernel_size * dilation - dilation) / 2)
131
- in_layer = nn.Conv1d(
132
- hidden_channels,
133
- 2 * hidden_channels,
134
- kernel_size,
135
- dilation=dilation,
136
- padding=padding,
137
- )
138
- in_layer = weight_norm(in_layer)
139
- self.in_layers.append(in_layer)
140
-
141
- # last one is not necessary
142
- if i < n_layers - 1:
143
- res_skip_channels = 2 * hidden_channels
144
- else:
145
- res_skip_channels = hidden_channels
146
-
147
- res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
148
- res_skip_layer = weight_norm(res_skip_layer, name="weight")
149
- self.res_skip_layers.append(res_skip_layer)
150
-
151
- def forward(self, x):
152
- output = torch.zeros_like(x)
153
- n_channels_tensor = torch.IntTensor([self.hidden_channels])
154
-
155
- for i in range(self.n_layers):
156
- x_in = self.in_layers[i](x)
157
-
158
- acts = fused_add_tanh_sigmoid_multiply(x_in, n_channels_tensor)
159
-
160
- res_skip_acts = self.res_skip_layers[i](acts)
161
- if i < self.n_layers - 1:
162
- res_acts = res_skip_acts[:, : self.hidden_channels, :]
163
- x = x + res_acts
164
- output = output + res_skip_acts[:, self.hidden_channels :, :]
165
- else:
166
- output = output + res_skip_acts
167
- return output
168
-
169
- def remove_weight_norm(self):
170
- for l in self.in_layers:
171
- remove_weight_norm(l)
172
- for l in self.res_skip_layers:
173
- remove_weight_norm(l)
174
-
175
-
176
- @torch.jit.script
177
- def fused_add_tanh_sigmoid_multiply(input, n_channels):
178
- n_channels_int = n_channels[0]
179
- t_act = torch.tanh(input[:, :n_channels_int, :])
180
- s_act = torch.sigmoid(input[:, n_channels_int:, :])
181
- acts = t_act * s_act
182
- return acts
183
-
184
-
185
- if __name__ == "__main__":
186
- content_enc = torch.randn(3, 192, 100)
187
- content_mask = torch.ones(3, 1, 100)
188
- ref_mel = torch.randn(3, 128, 30)
189
- ref_mask = torch.ones(3, 1, 30)
190
- model = MRTE()
191
- out = model(content_enc, content_mask, ref_mel, ref_mask)
192
- print(out.shape)
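
For orientation, here is an editor's sketch (not part of the deleted file) of how MRTE is driven: each frame of the SSL content sequence queries the reference encoding through cross-attention, the result is added back residually, and a global timbre embedding ge is broadcast over time. The shapes below assume the deleted module is still importable; ge may also be passed as None.

    import torch
    from module.mrte_model import MRTE

    mrte = MRTE()                       # defaults: 192 -> 512 -> 192 channels
    ssl_enc = torch.randn(2, 192, 100)  # SSL content features (B, C, T_content)
    ssl_mask = torch.ones(2, 1, 100)
    ref = torch.randn(2, 192, 40)       # reference features for the text_pre branch
    ref_mask = torch.ones(2, 1, 40)
    ge = torch.randn(2, 512, 1)         # global timbre embedding, broadcast over time

    out = mrte(ssl_enc, ssl_mask, ref, ref_mask, ge)
    print(out.shape)                    # torch.Size([2, 192, 100])
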
module/quantize.py DELETED
@@ -1,119 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
-
- """Residual vector quantizer implementation."""
-
- from dataclasses import dataclass, field
- import math
- import typing as tp
-
- import torch
- from torch import nn
-
- from module.core_vq import ResidualVectorQuantization
-
-
- @dataclass
- class QuantizedResult:
-     quantized: torch.Tensor
-     codes: torch.Tensor
-     bandwidth: torch.Tensor  # bandwidth in kb/s used, per batch item.
-     penalty: tp.Optional[torch.Tensor] = None
-     metrics: dict = field(default_factory=dict)
-
-
- class ResidualVectorQuantizer(nn.Module):
-     """Residual Vector Quantizer.
-     Args:
-         dimension (int): Dimension of the codebooks.
-         n_q (int): Number of residual vector quantizers used.
-         bins (int): Codebook size.
-         decay (float): Decay for exponential moving average over the codebooks.
-         kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
-         kmeans_iters (int): Number of iterations used for kmeans initialization.
-         threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any code
-             whose exponential moving average cluster size falls below this threshold with a
-             randomly selected vector from the current batch.
-     """
-
-     def __init__(
-         self,
-         dimension: int = 256,
-         n_q: int = 8,
-         bins: int = 1024,
-         decay: float = 0.99,
-         kmeans_init: bool = True,
-         kmeans_iters: int = 50,
-         threshold_ema_dead_code: int = 2,
-     ):
-         super().__init__()
-         self.n_q = n_q
-         self.dimension = dimension
-         self.bins = bins
-         self.decay = decay
-         self.kmeans_init = kmeans_init
-         self.kmeans_iters = kmeans_iters
-         self.threshold_ema_dead_code = threshold_ema_dead_code
-         self.vq = ResidualVectorQuantization(
-             dim=self.dimension,
-             codebook_size=self.bins,
-             num_quantizers=self.n_q,
-             decay=self.decay,
-             kmeans_init=self.kmeans_init,
-             kmeans_iters=self.kmeans_iters,
-             threshold_ema_dead_code=self.threshold_ema_dead_code,
-         )
-
-     # NOTE: despite the QuantizedResult dataclass above, this returns a plain tuple.
-     def forward(
-         self,
-         x: torch.Tensor,
-         n_q: tp.Optional[int] = None,
-         layers: tp.Optional[list] = None,
-     ) -> tp.Tuple[torch.Tensor, torch.Tensor, torch.Tensor, tp.Optional[list]]:
-         """Residual vector quantization of the given input tensor.
-         Args:
-             x (torch.Tensor): Input tensor.
-             n_q (int): Number of quantizers to use. Default: all quantizers.
-             layers (list): Indices of the layers whose quantized output should
-                 also be returned. Default: None.
-         Returns:
-             A tuple of the quantized (or approximately quantized) representation,
-             the codes, the mean commitment loss, and the per-layer quantized
-             outputs requested via ``layers``.
-         """
-         n_q = n_q if n_q else self.n_q
-         if layers and max(layers) >= n_q:
-             raise ValueError(
-                 f"Last layer index in layers: A {max(layers)}. Number of quantizers in RVQ: B {self.n_q}. A must be less than B."
-             )
-         quantized, codes, commit_loss, quantized_list = self.vq(
-             x, n_q=n_q, layers=layers
-         )
-         return quantized, codes, torch.mean(commit_loss), quantized_list
-
-     def encode(
-         self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
-     ) -> torch.Tensor:
-         """Encode the given input tensor into discrete codes.
-         The RVQ encode method sets the appropriate number of quantizers to use
-         and returns the indices produced by each of them.
-         Args:
-             x (torch.Tensor): Input tensor.
-             n_q (int): Number of quantizers to use. Default: all quantizers.
-             st (int): Index of the quantizer layer to start encoding from. Default: 0.
-         """
-         n_q = n_q if n_q else self.n_q
-         st = st or 0
-         codes = self.vq.encode(x, n_q=n_q, st=st)
-         return codes
-
-     def decode(self, codes: torch.Tensor, st: int = 0) -> torch.Tensor:
-         """Decode the given codes into the quantized representation.
-         Args:
-             codes (torch.Tensor): Input indices for each quantizer.
-             st (int): Index of the quantizer layer to start decoding from. Default: 0.
-         """
-         quantized = self.vq.decode(codes, st=st)
-         return quantized
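
A note on the mechanics (editor's sketch, not part of the commit): each quantizer stage encodes the residual left over by the previous stage, so decoding only the first k codebooks yields a progressively coarser reconstruction. Assuming module.core_vq from this same deleted package is importable, the encode/decode contract looks roughly like this; the exact codes layout depends on core_vq, and (n_q, batch, time) is an assumption:

    import torch
    from module.quantize import ResidualVectorQuantizer

    rvq = ResidualVectorQuantizer(dimension=256, n_q=8, bins=1024)
    x = torch.randn(4, 256, 100)     # (batch, dimension, time)

    codes = rvq.encode(x, n_q=4)     # use only the first 4 quantizer stages
    x_hat = rvq.decode(codes)        # sum of the 4 codebook lookups
    print(codes.shape, x_hat.shape)  # assumed: (4, 4, 100) codes, (4, 256, 100) output
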
module/transforms.py DELETED
@@ -1,209 +0,0 @@
- import torch
- from torch.nn import functional as F
-
- import numpy as np
-
-
- DEFAULT_MIN_BIN_WIDTH = 1e-3
- DEFAULT_MIN_BIN_HEIGHT = 1e-3
- DEFAULT_MIN_DERIVATIVE = 1e-3
-
-
- def piecewise_rational_quadratic_transform(
-     inputs,
-     unnormalized_widths,
-     unnormalized_heights,
-     unnormalized_derivatives,
-     inverse=False,
-     tails=None,
-     tail_bound=1.0,
-     min_bin_width=DEFAULT_MIN_BIN_WIDTH,
-     min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
-     min_derivative=DEFAULT_MIN_DERIVATIVE,
- ):
-     if tails is None:
-         spline_fn = rational_quadratic_spline
-         spline_kwargs = {}
-     else:
-         spline_fn = unconstrained_rational_quadratic_spline
-         spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
-
-     outputs, logabsdet = spline_fn(
-         inputs=inputs,
-         unnormalized_widths=unnormalized_widths,
-         unnormalized_heights=unnormalized_heights,
-         unnormalized_derivatives=unnormalized_derivatives,
-         inverse=inverse,
-         min_bin_width=min_bin_width,
-         min_bin_height=min_bin_height,
-         min_derivative=min_derivative,
-         **spline_kwargs
-     )
-     return outputs, logabsdet
-
-
- def searchsorted(bin_locations, inputs, eps=1e-6):
-     bin_locations[..., -1] += eps
-     return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
-
-
- def unconstrained_rational_quadratic_spline(
-     inputs,
-     unnormalized_widths,
-     unnormalized_heights,
-     unnormalized_derivatives,
-     inverse=False,
-     tails="linear",
-     tail_bound=1.0,
-     min_bin_width=DEFAULT_MIN_BIN_WIDTH,
-     min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
-     min_derivative=DEFAULT_MIN_DERIVATIVE,
- ):
-     inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
-     outside_interval_mask = ~inside_interval_mask
-
-     outputs = torch.zeros_like(inputs)
-     logabsdet = torch.zeros_like(inputs)
-
-     if tails == "linear":
-         unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
-         constant = np.log(np.exp(1 - min_derivative) - 1)
-         unnormalized_derivatives[..., 0] = constant
-         unnormalized_derivatives[..., -1] = constant
-
-         outputs[outside_interval_mask] = inputs[outside_interval_mask]
-         logabsdet[outside_interval_mask] = 0
-     else:
-         raise RuntimeError("{} tails are not implemented.".format(tails))
-
-     (
-         outputs[inside_interval_mask],
-         logabsdet[inside_interval_mask],
-     ) = rational_quadratic_spline(
-         inputs=inputs[inside_interval_mask],
-         unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
-         unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
-         unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
-         inverse=inverse,
-         left=-tail_bound,
-         right=tail_bound,
-         bottom=-tail_bound,
-         top=tail_bound,
-         min_bin_width=min_bin_width,
-         min_bin_height=min_bin_height,
-         min_derivative=min_derivative,
-     )
-
-     return outputs, logabsdet
-
-
- def rational_quadratic_spline(
-     inputs,
-     unnormalized_widths,
-     unnormalized_heights,
-     unnormalized_derivatives,
-     inverse=False,
-     left=0.0,
-     right=1.0,
-     bottom=0.0,
-     top=1.0,
-     min_bin_width=DEFAULT_MIN_BIN_WIDTH,
-     min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
-     min_derivative=DEFAULT_MIN_DERIVATIVE,
- ):
-     if torch.min(inputs) < left or torch.max(inputs) > right:
-         raise ValueError("Input to a transform is not within its domain")
-
-     num_bins = unnormalized_widths.shape[-1]
-
-     if min_bin_width * num_bins > 1.0:
-         raise ValueError("Minimal bin width too large for the number of bins")
-     if min_bin_height * num_bins > 1.0:
-         raise ValueError("Minimal bin height too large for the number of bins")
-
-     widths = F.softmax(unnormalized_widths, dim=-1)
-     widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
-     cumwidths = torch.cumsum(widths, dim=-1)
-     cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
-     cumwidths = (right - left) * cumwidths + left
-     cumwidths[..., 0] = left
-     cumwidths[..., -1] = right
-     widths = cumwidths[..., 1:] - cumwidths[..., :-1]
-
-     derivatives = min_derivative + F.softplus(unnormalized_derivatives)
-
-     heights = F.softmax(unnormalized_heights, dim=-1)
-     heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
-     cumheights = torch.cumsum(heights, dim=-1)
-     cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
-     cumheights = (top - bottom) * cumheights + bottom
-     cumheights[..., 0] = bottom
-     cumheights[..., -1] = top
-     heights = cumheights[..., 1:] - cumheights[..., :-1]
-
-     if inverse:
-         bin_idx = searchsorted(cumheights, inputs)[..., None]
-     else:
-         bin_idx = searchsorted(cumwidths, inputs)[..., None]
-
-     input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
-     input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
-
-     input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
-     delta = heights / widths
-     input_delta = delta.gather(-1, bin_idx)[..., 0]
-
-     input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
-     input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
-
-     input_heights = heights.gather(-1, bin_idx)[..., 0]
-
-     if inverse:
-         a = (inputs - input_cumheights) * (
-             input_derivatives + input_derivatives_plus_one - 2 * input_delta
-         ) + input_heights * (input_delta - input_derivatives)
-         b = input_heights * input_derivatives - (inputs - input_cumheights) * (
-             input_derivatives + input_derivatives_plus_one - 2 * input_delta
-         )
-         c = -input_delta * (inputs - input_cumheights)
-
-         discriminant = b.pow(2) - 4 * a * c
-         assert (discriminant >= 0).all()
-
-         root = (2 * c) / (-b - torch.sqrt(discriminant))
-         outputs = root * input_bin_widths + input_cumwidths
-
-         theta_one_minus_theta = root * (1 - root)
-         denominator = input_delta + (
-             (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
-             * theta_one_minus_theta
-         )
-         derivative_numerator = input_delta.pow(2) * (
-             input_derivatives_plus_one * root.pow(2)
-             + 2 * input_delta * theta_one_minus_theta
-             + input_derivatives * (1 - root).pow(2)
-         )
-         logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
-
-         return outputs, -logabsdet
-     else:
-         theta = (inputs - input_cumwidths) / input_bin_widths
-         theta_one_minus_theta = theta * (1 - theta)
-
-         numerator = input_heights * (
-             input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
-         )
-         denominator = input_delta + (
-             (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
-             * theta_one_minus_theta
-         )
-         outputs = input_cumheights + numerator / denominator
-
-         derivative_numerator = input_delta.pow(2) * (
-             input_derivatives_plus_one * theta.pow(2)
-             + 2 * input_delta * theta_one_minus_theta
-             + input_derivatives * (1 - theta).pow(2)
-         )
-         logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
-
-         return outputs, logabsdet
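
Since the forward and inverse passes share the same bin parameters, a quick round-trip makes a useful sanity check (editor's sketch, assuming the deleted module is importable): applying the transform and then its inverse should recover the inputs, and the two log-determinants should cancel. With tails="linear", unnormalized_derivatives carries num_bins - 1 interior values because the two boundary derivatives are filled in by the padding above.

    import torch
    from module.transforms import piecewise_rational_quadratic_transform

    num_bins = 10
    x = torch.rand(5, 20) * 8 - 4             # points inside the tail bound
    w = torch.randn(5, 20, num_bins)          # unnormalized bin widths
    h = torch.randn(5, 20, num_bins)          # unnormalized bin heights
    d = torch.randn(5, 20, num_bins - 1)      # interior knot derivatives

    y, logdet = piecewise_rational_quadratic_transform(
        x, w, h, d, tails="linear", tail_bound=5.0
    )
    x_rec, logdet_inv = piecewise_rational_quadratic_transform(
        y, w, h, d, inverse=True, tails="linear", tail_bound=5.0
    )
    assert torch.allclose(x, x_rec, atol=1e-4)            # bijective on the support
    assert torch.allclose(logdet, -logdet_inv, atol=1e-4)  # log-dets cancel
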