lj1995 committed
Commit d979016 · verified · Parent: 0db7fec

Delete s2_train.py

Files changed (1)
  1. s2_train.py +0 -601
s2_train.py DELETED
@@ -1,601 +0,0 @@
- import warnings
- warnings.filterwarnings("ignore")
- import utils, os
- hps = utils.get_hparams(stage=2)
- os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
- import torch
- from torch.nn import functional as F
- from torch.utils.data import DataLoader
- from torch.utils.tensorboard import SummaryWriter
- import torch.multiprocessing as mp
- import torch.distributed as dist
- from torch.nn.parallel import DistributedDataParallel as DDP
- from torch.cuda.amp import autocast, GradScaler
- from tqdm import tqdm
- import logging, traceback
-
- logging.getLogger("matplotlib").setLevel(logging.INFO)
- logging.getLogger("h5py").setLevel(logging.INFO)
- logging.getLogger("numba").setLevel(logging.INFO)
- from random import randint
- from module import commons
-
- from module.data_utils import (
-     TextAudioSpeakerLoader,
-     TextAudioSpeakerCollate,
-     DistributedBucketSampler,
- )
- from module.models import (
-     SynthesizerTrn,
-     MultiPeriodDiscriminator,
- )
- from module.losses import generator_loss, discriminator_loss, feature_loss, kl_loss
- from module.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
- from process_ckpt import savee
-
- torch.backends.cudnn.benchmark = False
- torch.backends.cudnn.deterministic = False
- ### fp32 is fast enough on an A100 anyway, so try tf32
- torch.backends.cuda.matmul.allow_tf32 = True
- torch.backends.cudnn.allow_tf32 = True
- torch.set_float32_matmul_precision("medium")  # lowest precision but fastest (only marginally so); has no effect on the results
- # from config import pretrained_s2G, pretrained_s2D
- global_step = 0
-
- device = "cpu"  # device for non-CUDA backends; mps support will be added once it is better optimized
-
-
- def main():
-     if torch.cuda.is_available():
-         n_gpus = torch.cuda.device_count()
-     else:
-         n_gpus = 1
-     os.environ["MASTER_ADDR"] = "localhost"
-     os.environ["MASTER_PORT"] = str(randint(20000, 55555))
-
-     mp.spawn(
-         run,
-         nprocs=n_gpus,
-         args=(
-             n_gpus,
-             hps,
-         ),
-     )
-
-
- def run(rank, n_gpus, hps):
-     global global_step
-     if rank == 0:
-         logger = utils.get_logger(hps.data.exp_dir)
-         logger.info(hps)
-         # utils.check_git_hash(hps.s2_ckpt_dir)
-         writer = SummaryWriter(log_dir=hps.s2_ckpt_dir)
-         writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
-
-     dist.init_process_group(
-         backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
-         init_method="env://",
-         world_size=n_gpus,
-         rank=rank,
-     )
-     torch.manual_seed(hps.train.seed)
-     if torch.cuda.is_available():
-         torch.cuda.set_device(rank)
-
-     train_dataset = TextAudioSpeakerLoader(hps.data)
-     train_sampler = DistributedBucketSampler(
-         train_dataset,
-         hps.train.batch_size,
-         [32, 300, 400, 500, 600, 700, 800, 900, 1000, 1100, 1200, 1300, 1400, 1500, 1600, 1700, 1800, 1900],
-         num_replicas=n_gpus,
-         rank=rank,
-         shuffle=True,
-     )
-     collate_fn = TextAudioSpeakerCollate()
-     train_loader = DataLoader(
-         train_dataset,
-         num_workers=6,
-         shuffle=False,
-         pin_memory=True,
-         collate_fn=collate_fn,
-         batch_sampler=train_sampler,
-         persistent_workers=True,
-         prefetch_factor=4,
-     )
-     # if rank == 0:
-     #     eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data, val=True)
-     #     eval_loader = DataLoader(eval_dataset, num_workers=0, shuffle=False,
-     #                              batch_size=1, pin_memory=True,
-     #                              drop_last=False, collate_fn=collate_fn)
-
-     net_g = SynthesizerTrn(
-         hps.data.filter_length // 2 + 1,
-         hps.train.segment_size // hps.data.hop_length,
-         n_speakers=hps.data.n_speakers,
-         **hps.model,
-     )
-     net_g = net_g.cuda(rank) if torch.cuda.is_available() else net_g.to(device)
-
-     net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm)
-     net_d = net_d.cuda(rank) if torch.cuda.is_available() else net_d.to(device)
-     for name, param in net_g.named_parameters():
-         if not param.requires_grad:
-             print(name, "not requires_grad")
-
-     te_p = list(map(id, net_g.enc_p.text_embedding.parameters()))
-     et_p = list(map(id, net_g.enc_p.encoder_text.parameters()))
-     mrte_p = list(map(id, net_g.enc_p.mrte.parameters()))
-     base_params = filter(
-         lambda p: id(p) not in te_p + et_p + mrte_p and p.requires_grad,
-         net_g.parameters(),
-     )
-
-     # te_p = net_g.enc_p.text_embedding.parameters()
-     # et_p = net_g.enc_p.encoder_text.parameters()
-     # mrte_p = net_g.enc_p.mrte.parameters()
-
-     optim_g = torch.optim.AdamW(
-         # filter(lambda p: p.requires_grad, net_g.parameters()),  ### default: the same lr for every layer
-         [
-             {"params": base_params, "lr": hps.train.learning_rate},
-             {
-                 "params": net_g.enc_p.text_embedding.parameters(),
-                 "lr": hps.train.learning_rate * hps.train.text_low_lr_rate,
-             },
-             {
-                 "params": net_g.enc_p.encoder_text.parameters(),
-                 "lr": hps.train.learning_rate * hps.train.text_low_lr_rate,
-             },
-             {
-                 "params": net_g.enc_p.mrte.parameters(),
-                 "lr": hps.train.learning_rate * hps.train.text_low_lr_rate,
-             },
-         ],
-         hps.train.learning_rate,
-         betas=hps.train.betas,
-         eps=hps.train.eps,
-     )
-     optim_d = torch.optim.AdamW(
-         net_d.parameters(),
-         hps.train.learning_rate,
-         betas=hps.train.betas,
-         eps=hps.train.eps,
-     )
-     if torch.cuda.is_available():
-         net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
-         net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
-     else:
-         net_g = net_g.to(device)
-         net_d = net_d.to(device)
-
-     try:  # resume automatically if a checkpoint can be loaded
-         _, _, _, epoch_str = utils.load_checkpoint(
-             utils.latest_checkpoint_path("%s/logs_s2" % hps.data.exp_dir, "D_*.pth"),
-             net_d,
-             optim_d,
-         )  # loading D rarely causes problems
-         if rank == 0:
-             logger.info("loaded D")
-         # _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g, load_opt=0)
-         _, _, _, epoch_str = utils.load_checkpoint(
-             utils.latest_checkpoint_path("%s/logs_s2" % hps.data.exp_dir, "G_*.pth"),
-             net_g,
-             optim_g,
-         )
-         global_step = (epoch_str - 1) * len(train_loader)
-         # epoch_str = 1
-         # global_step = 0
-     except Exception:  # if nothing can be loaded on the first run, load the pretrained weights
-         # traceback.print_exc()
-         epoch_str = 1
-         global_step = 0
-         if hps.train.pretrained_s2G != "" and hps.train.pretrained_s2G is not None and os.path.exists(hps.train.pretrained_s2G):
-             if rank == 0:
-                 logger.info("loaded pretrained %s" % hps.train.pretrained_s2G)
-             print(
-                 net_g.module.load_state_dict(
-                     torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
-                     strict=False,
-                 ) if torch.cuda.is_available() else net_g.load_state_dict(
-                     torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
-                     strict=False,
-                 )
-             )  ## testing: do not load the optimizer state
-         if hps.train.pretrained_s2D != "" and hps.train.pretrained_s2D is not None and os.path.exists(hps.train.pretrained_s2D):
-             if rank == 0:
-                 logger.info("loaded pretrained %s" % hps.train.pretrained_s2D)
-             print(
-                 net_d.module.load_state_dict(
-                     torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"]
-                 ) if torch.cuda.is_available() else net_d.load_state_dict(
-                     torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"]
-                 )
-             )
-
-     # scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
-     # scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
-
-     scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
-         optim_g, gamma=hps.train.lr_decay, last_epoch=-1
-     )
-     scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
-         optim_d, gamma=hps.train.lr_decay, last_epoch=-1
-     )
-     for _ in range(epoch_str):
-         scheduler_g.step()
-         scheduler_d.step()
-
-     scaler = GradScaler(enabled=hps.train.fp16_run)
-
-     for epoch in range(epoch_str, hps.train.epochs + 1):
-         if rank == 0:
-             train_and_evaluate(
-                 rank,
-                 epoch,
-                 hps,
-                 [net_g, net_d],
-                 [optim_g, optim_d],
-                 [scheduler_g, scheduler_d],
-                 scaler,
-                 # [train_loader, eval_loader], logger, [writer, writer_eval])
-                 [train_loader, None],
-                 logger,
-                 [writer, writer_eval],
-             )
-         else:
-             train_and_evaluate(
-                 rank,
-                 epoch,
-                 hps,
-                 [net_g, net_d],
-                 [optim_g, optim_d],
-                 [scheduler_g, scheduler_d],
-                 scaler,
-                 [train_loader, None],
-                 None,
-                 None,
-             )
-         scheduler_g.step()
-         scheduler_d.step()
-
-
- def train_and_evaluate(
-     rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers
- ):
-     net_g, net_d = nets
-     optim_g, optim_d = optims
-     # scheduler_g, scheduler_d = schedulers
-     train_loader, eval_loader = loaders
-     if writers is not None:
-         writer, writer_eval = writers
-
-     train_loader.batch_sampler.set_epoch(epoch)
-     global global_step
-
-     net_g.train()
-     net_d.train()
-     for batch_idx, (
-         ssl,
-         ssl_lengths,
-         spec,
-         spec_lengths,
-         y,
-         y_lengths,
-         text,
-         text_lengths,
-     ) in enumerate(tqdm(train_loader)):
-         if torch.cuda.is_available():
-             spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(
-                 rank, non_blocking=True
-             )
-             y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(
-                 rank, non_blocking=True
-             )
-             ssl = ssl.cuda(rank, non_blocking=True)
-             ssl.requires_grad = False
-             # ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
-             text, text_lengths = text.cuda(rank, non_blocking=True), text_lengths.cuda(
-                 rank, non_blocking=True
-             )
-         else:
-             spec, spec_lengths = spec.to(device), spec_lengths.to(device)
-             y, y_lengths = y.to(device), y_lengths.to(device)
-             ssl = ssl.to(device)
-             ssl.requires_grad = False
-             # ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
-             text, text_lengths = text.to(device), text_lengths.to(device)
-
-         with autocast(enabled=hps.train.fp16_run):
-             (
-                 y_hat,
-                 kl_ssl,
-                 ids_slice,
-                 x_mask,
-                 z_mask,
-                 (z, z_p, m_p, logs_p, m_q, logs_q),
-                 stats_ssl,
-             ) = net_g(ssl, spec, spec_lengths, text, text_lengths)
-
-             mel = spec_to_mel_torch(
-                 spec,
-                 hps.data.filter_length,
-                 hps.data.n_mel_channels,
-                 hps.data.sampling_rate,
-                 hps.data.mel_fmin,
-                 hps.data.mel_fmax,
-             )
-             y_mel = commons.slice_segments(
-                 mel, ids_slice, hps.train.segment_size // hps.data.hop_length
-             )
-             y_hat_mel = mel_spectrogram_torch(
-                 y_hat.squeeze(1),
-                 hps.data.filter_length,
-                 hps.data.n_mel_channels,
-                 hps.data.sampling_rate,
-                 hps.data.hop_length,
-                 hps.data.win_length,
-                 hps.data.mel_fmin,
-                 hps.data.mel_fmax,
-             )
-
-             y = commons.slice_segments(
-                 y, ids_slice * hps.data.hop_length, hps.train.segment_size
-             )  # slice
-
-             # Discriminator
-             y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
-             with autocast(enabled=False):
-                 loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
-                     y_d_hat_r, y_d_hat_g
-                 )
-                 loss_disc_all = loss_disc
-         optim_d.zero_grad()
-         scaler.scale(loss_disc_all).backward()
-         scaler.unscale_(optim_d)
-         grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
-         scaler.step(optim_d)
-
-         with autocast(enabled=hps.train.fp16_run):
-             # Generator
-             y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
-             with autocast(enabled=False):
-                 loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
-                 loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
-
-                 loss_fm = feature_loss(fmap_r, fmap_g)
-                 loss_gen, losses_gen = generator_loss(y_d_hat_g)
-                 loss_gen_all = loss_gen + loss_fm + loss_mel + kl_ssl * 1 + loss_kl
-
-         optim_g.zero_grad()
-         scaler.scale(loss_gen_all).backward()
-         scaler.unscale_(optim_g)
-         grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
-         scaler.step(optim_g)
-         scaler.update()
-
-         if rank == 0:
-             if global_step % hps.train.log_interval == 0:
-                 lr = optim_g.param_groups[0]["lr"]
-                 losses = [loss_disc, loss_gen, loss_fm, loss_mel, kl_ssl, loss_kl]
-                 logger.info(
-                     "Train Epoch: {} [{:.0f}%]".format(
-                         epoch, 100.0 * batch_idx / len(train_loader)
-                     )
-                 )
-                 logger.info([x.item() for x in losses] + [global_step, lr])
-
-                 scalar_dict = {
-                     "loss/g/total": loss_gen_all,
-                     "loss/d/total": loss_disc_all,
-                     "learning_rate": lr,
-                     "grad_norm_d": grad_norm_d,
-                     "grad_norm_g": grad_norm_g,
-                 }
-                 scalar_dict.update(
-                     {
-                         "loss/g/fm": loss_fm,
-                         "loss/g/mel": loss_mel,
-                         "loss/g/kl_ssl": kl_ssl,
-                         "loss/g/kl": loss_kl,
-                     }
-                 )
-
-                 # scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)})
-                 # scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)})
-                 # scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)})
-                 image_dict = {
-                     "slice/mel_org": utils.plot_spectrogram_to_numpy(
-                         y_mel[0].data.cpu().numpy()
-                     ),
-                     "slice/mel_gen": utils.plot_spectrogram_to_numpy(
-                         y_hat_mel[0].data.cpu().numpy()
-                     ),
-                     "all/mel": utils.plot_spectrogram_to_numpy(
-                         mel[0].data.cpu().numpy()
-                     ),
-                     "all/stats_ssl": utils.plot_spectrogram_to_numpy(
-                         stats_ssl[0].data.cpu().numpy()
-                     ),
-                 }
-                 utils.summarize(
-                     writer=writer,
-                     global_step=global_step,
-                     images=image_dict,
-                     scalars=scalar_dict,
-                 )
-         global_step += 1
-     if epoch % hps.train.save_every_epoch == 0 and rank == 0:
-         if hps.train.if_save_latest == 0:
-             utils.save_checkpoint(
-                 net_g,
-                 optim_g,
-                 hps.train.learning_rate,
-                 epoch,
-                 os.path.join(
-                     "%s/logs_s2" % hps.data.exp_dir, "G_{}.pth".format(global_step)
-                 ),
-             )
-             utils.save_checkpoint(
-                 net_d,
-                 optim_d,
-                 hps.train.learning_rate,
-                 epoch,
-                 os.path.join(
-                     "%s/logs_s2" % hps.data.exp_dir, "D_{}.pth".format(global_step)
-                 ),
-             )
-         else:
-             utils.save_checkpoint(
-                 net_g,
-                 optim_g,
-                 hps.train.learning_rate,
-                 epoch,
-                 os.path.join(
-                     "%s/logs_s2" % hps.data.exp_dir, "G_{}.pth".format(233333333333)
-                 ),
-             )
-             utils.save_checkpoint(
-                 net_d,
-                 optim_d,
-                 hps.train.learning_rate,
-                 epoch,
-                 os.path.join(
-                     "%s/logs_s2" % hps.data.exp_dir, "D_{}.pth".format(233333333333)
-                 ),
-             )
-         if rank == 0 and hps.train.if_save_every_weights:
-             if hasattr(net_g, "module"):
-                 ckpt = net_g.module.state_dict()
-             else:
-                 ckpt = net_g.state_dict()
-             logger.info(
-                 "saving ckpt %s_e%s:%s"
-                 % (
-                     hps.name,
-                     epoch,
-                     savee(
-                         ckpt,
-                         hps.name + "_e%s_s%s" % (epoch, global_step),
-                         epoch,
-                         global_step,
-                         hps,
-                     ),
-                 )
-             )
-
-     if rank == 0:
-         logger.info("====> Epoch: {}".format(epoch))
-
-
- def evaluate(hps, generator, eval_loader, writer_eval):
-     generator.eval()
-     image_dict = {}
-     audio_dict = {}
-     print("Evaluating ...")
-     with torch.no_grad():
-         for batch_idx, (
-             ssl,
-             ssl_lengths,
-             spec,
-             spec_lengths,
-             y,
-             y_lengths,
-             text,
-             text_lengths,
-         ) in enumerate(eval_loader):
-             if torch.cuda.is_available():
-                 spec, spec_lengths = spec.cuda(), spec_lengths.cuda()
-                 y, y_lengths = y.cuda(), y_lengths.cuda()
-                 ssl = ssl.cuda()
-                 text, text_lengths = text.cuda(), text_lengths.cuda()
-             else:
-                 spec, spec_lengths = spec.to(device), spec_lengths.to(device)
-                 y, y_lengths = y.to(device), y_lengths.to(device)
-                 ssl = ssl.to(device)
-                 text, text_lengths = text.to(device), text_lengths.to(device)
-             for test in [0, 1]:
-                 y_hat, mask, *_ = generator.module.infer(
-                     ssl, spec, spec_lengths, text, text_lengths, test=test
-                 ) if torch.cuda.is_available() else generator.infer(
-                     ssl, spec, spec_lengths, text, text_lengths, test=test
-                 )
-                 y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length
-
-                 mel = spec_to_mel_torch(
-                     spec,
-                     hps.data.filter_length,
-                     hps.data.n_mel_channels,
-                     hps.data.sampling_rate,
-                     hps.data.mel_fmin,
-                     hps.data.mel_fmax,
-                 )
-                 y_hat_mel = mel_spectrogram_torch(
-                     y_hat.squeeze(1).float(),
-                     hps.data.filter_length,
-                     hps.data.n_mel_channels,
-                     hps.data.sampling_rate,
-                     hps.data.hop_length,
-                     hps.data.win_length,
-                     hps.data.mel_fmin,
-                     hps.data.mel_fmax,
-                 )
-                 image_dict.update(
-                     {
-                         f"gen/mel_{batch_idx}_{test}": utils.plot_spectrogram_to_numpy(
-                             y_hat_mel[0].cpu().numpy()
-                         )
-                     }
-                 )
-                 audio_dict.update(
-                     {f"gen/audio_{batch_idx}_{test}": y_hat[0, :, : y_hat_lengths[0]]}
-                 )
-                 image_dict.update(
-                     {
-                         f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy(
-                             mel[0].cpu().numpy()
-                         )
-                     }
-                 )
-                 audio_dict.update({f"gt/audio_{batch_idx}": y[0, :, : y_lengths[0]]})
-
-             # y_hat, mask, *_ = generator.module.infer(ssl, spec_lengths, speakers, y=None)
-             # audio_dict.update({
-             #     f"gen/audio_{batch_idx}_style_pred": y_hat[0, :, :]
-             # })
-
-     utils.summarize(
-         writer=writer_eval,
-         global_step=global_step,
-         images=image_dict,
-         audios=audio_dict,
-         audio_sampling_rate=hps.data.sampling_rate,
-     )
-     generator.train()
-
-
- if __name__ == "__main__":
-     main()