File size: 28,038 Bytes
407412c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
from __future__ import annotations

import os
import gc
from tqdm import tqdm
import wandb

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset, SequentialSampler
from torch.optim.lr_scheduler import LinearLR, SequentialLR

from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs

from unimodel import UniModel
from f5_tts.model import CFM
from f5_tts.model.utils import exists, default
from f5_tts.model.dataset import DynamicBatchSampler, collate_fn


# trainer

import math

class RunningStats:
    """Online mean / variance accumulator using Welford's algorithm.

    Attributes:
        count: how many values have been folded in so far.
        mean: running arithmetic mean of the values seen.
        M2: running sum of squared deviations from the mean.
    """

    def __init__(self):
        self.count, self.mean, self.M2 = 0, 0.0, 0.0

    def update(self, x):
        """Fold one new observation ``x`` into the statistics."""
        self.count += 1
        shift = x - self.mean
        self.mean += shift / self.count
        # Welford's update: deviation from the old mean times deviation from the new one.
        self.M2 += shift * (x - self.mean)

    @property
    def variance(self):
        """Bessel-corrected sample variance; NaN until at least two values were seen."""
        if self.count <= 1:
            return float('nan')
        return self.M2 / (self.count - 1)

    @property
    def std(self):
        """Sample standard deviation (square root of ``variance``)."""
        return math.sqrt(self.variance)



class Trainer:
    """DMDSpeech-style distillation trainer for a ``UniModel``.

    Two optimizers are used: ``optimizer_generator`` over
    ``model.feedforward_model`` and ``optimizer_guidance`` over
    ``model.guidance_model``. Every batch performs one guidance update; a
    generator update is additionally performed on every
    ``gen_update_ratio``-th batch. Several loss terms are phased in after
    configurable step counts (``num_GAN``, ``num_D``, ``num_ctc``,
    ``num_sim``, ``num_simu``); before that they are multiplied by 0.
    """

    def __init__(
        self,
        model: UniModel,
        epochs,
        learning_rate,
        num_warmup_updates=20000,
        save_per_updates=1000,
        checkpoint_path=None,
        batch_size=32,
        batch_size_type: str = "sample",
        max_samples=32,
        grad_accumulation_steps=1,
        max_grad_norm=1.0,
        noise_scheduler: str | None = None,
        duration_predictor: torch.nn.Module | None = None,
        wandb_project="test_e2-tts",
        wandb_run_name="test_run",
        wandb_resume_id: str | None = None,
        last_per_steps=None,
        log_step=1000,
        accelerate_kwargs: dict | None = None,
        bnb_optimizer: bool = False,
        scale: float = 1.0,
        # training parameters for DMDSpeech
        num_student_step: int = 1,
        gen_update_ratio: int = 5,
        lambda_discriminator_loss: float = 1.0,
        lambda_generator_loss: float = 1.0,
        lambda_ctc_loss: float = 1.0,
        lambda_sim_loss: float = 1.0,
        num_GAN: int = 5000,
        num_D: int = 500,
        num_ctc: int = 5000,
        num_sim: int = 10000,
        num_simu: int = 1000,
    ):
        """Set up the accelerator, optional wandb tracking and both optimizers.

        Args:
            model: model exposing ``feedforward_model`` and ``guidance_model``
                sub-modules and a ``guidance_model.gen_cls_loss`` flag.
            epochs: number of passes over the training data.
            learning_rate: peak learning rate for both optimizers.
            num_warmup_updates: linear warmup length in optimizer updates.
            save_per_updates: ``model_<step>.pt`` is written every
                ``save_per_updates * grad_accumulation_steps`` batches.
            checkpoint_path: checkpoint directory (default ``ckpts/test_e2-tts``).
            batch_size: samples per batch for ``"sample"``; for ``"frame"`` it
                is passed to ``DynamicBatchSampler`` (presumably a frame budget).
            batch_size_type: ``"sample"`` or ``"frame"``.
            max_samples: forwarded to ``DynamicBatchSampler``.
            grad_accumulation_steps: accelerator gradient-accumulation steps.
            max_grad_norm: gradient clipping threshold; ``<= 0`` disables it.
            noise_scheduler: stored and recorded in the tracker config only.
            duration_predictor: stored on the instance; unused by this class.
            wandb_project, wandb_run_name, wandb_resume_id: wandb settings,
                used only when a wandb API key is configured.
            last_per_steps: batches between ``model_last.pt`` saves
                (default ``save_per_updates * grad_accumulation_steps``).
            log_step: batches between vocoded audio logs.
            accelerate_kwargs: extra keyword arguments for ``Accelerator``
                (``None`` means no extras).
            bnb_optimizer: use bitsandbytes 8-bit AdamW instead of torch AdamW.
            scale: mels are divided by this before the model and multiplied
                back before vocoding.
            num_student_step: the student schedule is
                ``linspace(0, 1, num_student_step + 1)[:-1]``.
            gen_update_ratio: batches per generator update.
            lambda_discriminator_loss, lambda_generator_loss, lambda_ctc_loss,
                lambda_sim_loss: loss weights (the CTC weight is also applied
                to the KL term).
            num_GAN: step from which the discriminator loss is weighted.
            num_D: extra steps before the generator adversarial loss is
                weighted (from ``num_GAN + num_D``).
            num_ctc: step from which CTC/KL losses are weighted; also gates
                simulated data in the generator pass.
            num_sim: step from which the similarity loss is weighted.
            num_simu: step from which the guidance pass uses simulated data.
        """
        accelerate_kwargs = {} if accelerate_kwargs is None else accelerate_kwargs
        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)

        logger = "wandb" if wandb.api.api_key else None
        print(f"Using logger: {logger}")

        self.accelerator = Accelerator(
            log_with=logger,
            kwargs_handlers=[ddp_kwargs],
            gradient_accumulation_steps=grad_accumulation_steps,
            **accelerate_kwargs,
        )

        if logger == "wandb":
            wandb_init = {"resume": "allow", "name": wandb_run_name}
            if exists(wandb_resume_id):
                wandb_init["id"] = wandb_resume_id
            self.accelerator.init_trackers(
                project_name=wandb_project,
                init_kwargs={"wandb": wandb_init},
                config={
                    "epochs": epochs,
                    "learning_rate": learning_rate,
                    "num_warmup_updates": num_warmup_updates,
                    "batch_size": batch_size,
                    "batch_size_type": batch_size_type,
                    "max_samples": max_samples,
                    "grad_accumulation_steps": grad_accumulation_steps,
                    "max_grad_norm": max_grad_norm,
                    "gpus": self.accelerator.num_processes,
                    "noise_scheduler": noise_scheduler,
                },
            )

        self.model = model

        self.scale = scale

        self.epochs = epochs
        self.num_warmup_updates = num_warmup_updates
        self.save_per_updates = save_per_updates
        self.last_per_steps = default(last_per_steps, save_per_updates * grad_accumulation_steps)
        self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts")

        self.batch_size = batch_size
        self.batch_size_type = batch_size_type
        self.max_samples = max_samples
        self.grad_accumulation_steps = grad_accumulation_steps
        self.max_grad_norm = max_grad_norm

        self.noise_scheduler = noise_scheduler

        self.duration_predictor = duration_predictor

        self.log_step = log_step

        self.gen_update_ratio = gen_update_ratio  # batches per generator update (guidance updates every batch)
        self.lambda_discriminator_loss = lambda_discriminator_loss  # weight for discriminator loss (L_adv)
        self.lambda_generator_loss = lambda_generator_loss  # weight for generator loss (L_adv)
        self.lambda_ctc_loss = lambda_ctc_loss  # weight for ctc loss (also used for the KL term)
        self.lambda_sim_loss = lambda_sim_loss  # weight for similarity loss

        # distillation schedule for the student model: num_student_step points in [0, 1)
        self.student_steps = torch.linspace(0.0, 1.0, num_student_step + 1)[:-1]

        self.GAN = model.guidance_model.gen_cls_loss  # whether to use GAN training
        self.num_GAN = num_GAN  # number of steps before adversarial training
        self.num_D = num_D  # number of steps to train the discriminator before adversarial training
        self.num_ctc = num_ctc  # number of steps before CTC training
        self.num_sim = num_sim  # number of steps before similarity training
        self.num_simu = num_simu  # number of steps before using simulated data

        # Separate optimizers for the student (feedforward_model) and the guidance model.
        if bnb_optimizer:
            import bitsandbytes as bnb
            self.optimizer_generator = bnb.optim.AdamW8bit(self.model.feedforward_model.parameters(), lr=learning_rate)
            self.optimizer_guidance = bnb.optim.AdamW8bit(self.model.guidance_model.parameters(), lr=learning_rate)
        else:
            self.optimizer_generator = AdamW(self.model.feedforward_model.parameters(), lr=learning_rate, eps=1e-7)
            self.optimizer_guidance = AdamW(self.model.guidance_model.parameters(), lr=learning_rate, eps=1e-7)

        self.model, self.optimizer_generator, self.optimizer_guidance = self.accelerator.prepare(
            self.model, self.optimizer_generator, self.optimizer_guidance
        )

        # Gradient-norm trackers; currently not consulted by the training loop.
        self.generator_norm = RunningStats()
        self.guidance_norm = RunningStats()

    @property
    def is_main(self):
        """True on the accelerator's main process."""
        return self.accelerator.is_main_process

    def save_checkpoint(self, step, last=False):
        """Save model, optimizer and scheduler state plus ``step`` (main process only).

        Writes ``model_last.pt`` when ``last`` is true, otherwise
        ``model_<step>.pt``, inside ``self.checkpoint_path``. Must be called
        after ``train`` has created the schedulers.
        """
        self.accelerator.wait_for_everyone()
        if not self.is_main:
            return

        checkpoint = dict(
            model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(),
            # Optimizers are not nn.Modules; there is nothing to unwrap.
            optimizer_generator_state_dict=self.optimizer_generator.state_dict(),
            optimizer_guidance_state_dict=self.optimizer_guidance.state_dict(),
            scheduler_generator_state_dict=self.scheduler_generator.state_dict(),
            scheduler_guidance_state_dict=self.scheduler_guidance.state_dict(),
            step=step,
        )

        os.makedirs(self.checkpoint_path, exist_ok=True)
        if last:
            self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
            print(f"Saved last checkpoint at step {step}")
        else:
            self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt")

    def load_checkpoint(self):
        """Load model weights from the newest checkpoint and return its step.

        ``model_last.pt`` is preferred; otherwise the ``.pt`` file whose
        digits form the largest number is used (files without digits are
        ignored). Only the model weights (``strict=False``) and the step are
        restored; optimizer and scheduler states are deliberately not loaded.

        Returns:
            The saved step, or 0 when there is no usable checkpoint.
        """
        if (
            not exists(self.checkpoint_path)
            or not os.path.exists(self.checkpoint_path)
            or not os.listdir(self.checkpoint_path)
        ):
            return 0

        self.accelerator.wait_for_everyone()
        files = os.listdir(self.checkpoint_path)
        if "model_last.pt" in files:
            latest_checkpoint = "model_last.pt"
        else:
            numbered = [f for f in files if f.endswith(".pt") and any(c.isdigit() for c in f)]
            if not numbered:
                # Non-empty directory but nothing loadable: start from scratch.
                return 0
            latest_checkpoint = sorted(
                numbered,
                key=lambda x: int("".join(filter(str.isdigit, x))),
            )[-1]
        checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu")

        self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"], strict=False)
        step = checkpoint["step"]

        del checkpoint
        gc.collect()
        return step

    def _build_dataloader(self, train_dataset, num_workers, resumable_with_seed):
        """Create the (not yet accelerator-prepared) training DataLoader.

        Raises:
            ValueError: if ``batch_size_type`` is neither ``"sample"`` nor ``"frame"``.
        """
        if self.batch_size_type == "sample":
            generator = None
            if exists(resumable_with_seed):
                generator = torch.Generator()
                generator.manual_seed(resumable_with_seed)
            return DataLoader(
                train_dataset,
                collate_fn=collate_fn,
                num_workers=num_workers,
                pin_memory=True,
                persistent_workers=True,
                batch_size=self.batch_size,
                shuffle=True,
                generator=generator,
            )
        if self.batch_size_type == "frame":
            # Presumably DynamicBatchSampler yields unevenly sized batches, so
            # accelerate is told not to even out per-process batches.
            self.accelerator.even_batches = False
            sampler = SequentialSampler(train_dataset)
            batch_sampler = DynamicBatchSampler(
                sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False
            )
            return DataLoader(
                train_dataset,
                collate_fn=collate_fn,
                num_workers=num_workers,
                pin_memory=True,
                persistent_workers=True,
                batch_sampler=batch_sampler,
            )
        raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}")

    def _build_schedulers(self, num_batches):
        """Create warmup+linear-decay LR schedulers for both optimizers.

        ``num_batches`` is the dataloader length measured before
        ``accelerator.prepare``. Sets ``self.scheduler_generator`` and
        ``self.scheduler_guidance`` (not yet prepared).
        """
        # accelerator.prepare() dispatches batches to devices, so the length
        # measured before preparing must account for the number of devices.
        # Scaling by num_processes keeps the warmup fixed in optimizer updates
        # under multi-GPU DDP (with split_batches=False it would otherwise
        # change with num_processes).
        warmup_steps = self.num_warmup_updates * self.accelerator.num_processes
        total_steps = num_batches * self.epochs / self.grad_accumulation_steps
        decay_steps = total_steps - warmup_steps

        # The generator only steps on every gen_update_ratio-th batch.
        # NOTE(review): decay_steps is already divided by grad_accumulation_steps;
        # dividing by it again here only cancels out when it equals 1 — confirm.
        generator_divisor = self.gen_update_ratio * self.grad_accumulation_steps
        generator_warmup = warmup_steps // generator_divisor
        warmup_scheduler_generator = LinearLR(self.optimizer_generator, start_factor=1e-8, end_factor=1.0, total_iters=generator_warmup)
        decay_scheduler_generator = LinearLR(self.optimizer_generator, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps // generator_divisor)
        self.scheduler_generator = SequentialLR(
            self.optimizer_generator,
            schedulers=[warmup_scheduler_generator, decay_scheduler_generator],
            milestones=[generator_warmup],
        )

        warmup_scheduler_guidance = LinearLR(self.optimizer_guidance, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
        decay_scheduler_guidance = LinearLR(self.optimizer_guidance, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
        self.scheduler_guidance = SequentialLR(
            self.optimizer_guidance,
            schedulers=[warmup_scheduler_guidance, decay_scheduler_guidance],
            milestones=[warmup_steps],
        )

    def _guidance_step(self, mel_spec, text_inputs, mel_lengths, global_step, update_generator):
        """Run one forward/backward/optimizer step for the guidance model.

        Returns:
            A dict of metrics for this step.
        """
        metrics = {}
        guidance_loss_dict, _ = self.model(
            inp=mel_spec,
            text=text_inputs,
            lens=mel_lengths,
            student_steps=self.student_steps,
            update_generator=False,
            use_simulated=global_step >= self.num_simu,
        )

        guidance_loss = 0
        guidance_loss += guidance_loss_dict["loss_fake_mean"]
        metrics["loss/fake_score"] = guidance_loss_dict["loss_fake_mean"]

        if self.GAN and update_generator:
            # Discriminator loss is only added on generator-update steps; it is
            # kept in the graph with weight 0 until num_GAN steps have passed.
            discriminator_weight = self.lambda_discriminator_loss if global_step >= self.num_GAN else 0
            guidance_loss += guidance_loss_dict["guidance_cls_loss"] * discriminator_weight
            metrics["loss/discriminator_loss"] = guidance_loss_dict["guidance_cls_loss"]

        # Recorded after all terms are summed: this is the total that is backpropagated.
        metrics["loss/guidance_loss"] = guidance_loss

        self.accelerator.backward(guidance_loss)

        if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
            metrics["grad_norm_guidance"] = self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

        self.optimizer_guidance.step()
        self.scheduler_guidance.step()
        self.optimizer_guidance.zero_grad()
        self.optimizer_generator.zero_grad()  # zero out the generator's gradient as well
        return metrics

    def _generator_step(self, mel_spec, text_inputs, mel_lengths, global_step):
        """Run one forward/backward/optimizer step for the generator (student).

        Returns:
            ``(metrics, generator_log_dict)`` where ``generator_log_dict`` is the
            model's log output for this pass (used for audio logging).
        """
        metrics = {}
        generator_loss_dict, generator_log_dict = self.model(
            inp=mel_spec,
            text=text_inputs,
            lens=mel_lengths,
            student_steps=self.student_steps,
            update_generator=True,
            # NOTE(review): gated on num_ctc here but on num_simu in the
            # guidance pass — confirm this is intended.
            use_simulated=global_step >= self.num_ctc,
        )

        # Phased-in weights: terms stay in the graph with weight 0 until their start step.
        ctc_weight = self.lambda_ctc_loss if global_step >= self.num_ctc else 0
        sim_weight = self.lambda_sim_loss if global_step >= self.num_sim else 0

        generator_loss = 0
        generator_loss += generator_loss_dict["loss_dm"]
        if "loss_mse" in generator_loss_dict:
            generator_loss += generator_loss_dict["loss_mse"]
        generator_loss += generator_loss_dict["loss_ctc"] * ctc_weight
        generator_loss += generator_loss_dict["loss_sim"] * sim_weight
        if "loss_kl" in generator_loss_dict:
            # NOTE(review): the KL term reuses the CTC weight and schedule — confirm.
            generator_loss += generator_loss_dict["loss_kl"] * ctc_weight
        if self.GAN:
            gen_cls_weight = self.lambda_generator_loss if global_step >= (self.num_GAN + self.num_D) else 0
            metrics["loss/gen_cls_loss"] = generator_loss_dict["gen_cls_loss"]
            generator_loss += generator_loss_dict["gen_cls_loss"] * gen_cls_weight

        metrics["loss/dm_loss"] = generator_loss_dict["loss_dm"]
        metrics["loss/ctc_loss"] = generator_loss_dict["loss_ctc"]
        metrics["loss/similarity_loss"] = generator_loss_dict["loss_sim"]
        metrics["loss/generator_loss"] = generator_loss

        if "loss_mse" in generator_loss_dict and generator_loss_dict["loss_mse"] != 0:
            metrics["loss/mse_loss"] = generator_loss_dict["loss_mse"]
        if "loss_kl" in generator_loss_dict and generator_loss_dict["loss_kl"] != 0:
            metrics["loss/kl_loss"] = generator_loss_dict["loss_kl"]

        self.accelerator.backward(generator_loss)

        if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
            metrics["grad_norm_generator"] = self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

        self.optimizer_generator.step()
        self.scheduler_generator.step()
        self.optimizer_generator.zero_grad()
        self.optimizer_guidance.zero_grad()  # zero out the guidance's gradient as well
        return metrics, generator_log_dict

    def _to_wandb_audio(self, vocoder, mel, caption):
        """Vocode the first item of ``mel`` and wrap it as a 24 kHz ``wandb.Audio``.

        The item's last two axes are swapped (undoing the ``permute(0, 2, 1)``
        applied to inputs in ``train``) and the ``self.scale`` normalisation is
        undone before ``vocoder.decode``.
        """
        sample = mel[0].unsqueeze(0).permute(0, 2, 1) * self.scale
        waveform = vocoder.decode(sample.float().cpu())
        return wandb.Audio(
            waveform.float().numpy().squeeze(),
            sample_rate=24000,
            caption=caption,
        )

    def _log_audio_samples(self, vocoder, log_dict, step):
        """Vocode the tensors in a generator log dict and send them to the tracker."""
        with torch.no_grad():
            time_caption = "time: " + str(log_dict['time'][0].float().cpu().numpy())
            dmtrain_caption = "dmtrain_time: " + str(log_dict['dmtrain_time'][0].float().cpu().numpy())
            # (tracker key, log_dict key, caption)
            entries = (
                ("noisy_input", "generator_input", time_caption),
                ("output", "generator_output", time_caption),
                ("cond", "generator_cond", time_caption),
                ("ground_truth", "ground_truth", time_caption),
                ("dmtrain_noisy_inp", "dmtrain_noisy_inp", dmtrain_caption),
                ("dmtrain_pred_real_image", "dmtrain_pred_real_image", dmtrain_caption),
                ("dmtrain_pred_fake_image", "dmtrain_pred_fake_image", dmtrain_caption),
            )
            audio = {
                name: self._to_wandb_audio(vocoder, log_dict[key], caption)
                for name, key, caption in entries
            }
            self.accelerator.log(audio, step=step)

    def train(self, train_dataset: Dataset, num_workers=64, resumable_with_seed: int = None, vocoder: nn.Module = None):
        """Run the distillation loop over ``train_dataset`` for ``self.epochs`` epochs.

        Args:
            train_dataset: dataset whose items ``collate_fn`` batches into dicts
                with ``"text"``, ``"mel"`` and ``"mel_lengths"`` entries.
            num_workers: DataLoader worker processes.
            resumable_with_seed: when given, seeds batch ordering and, on
                resume, skips the already-seen batches of the interrupted epoch.
            vocoder: optional object with a ``decode`` method; when given, audio
                from the most recent generator pass is logged every
                ``log_step`` batches.

        Raises:
            ValueError: if ``batch_size_type`` is neither ``"sample"`` nor ``"frame"``.
        """
        train_dataloader = self._build_dataloader(train_dataset, num_workers, resumable_with_seed)
        # Schedulers are sized from the dataloader length before it is prepared.
        self._build_schedulers(len(train_dataloader))

        train_dataloader, self.scheduler_generator, self.scheduler_guidance = self.accelerator.prepare(
            train_dataloader, self.scheduler_generator, self.scheduler_guidance
        )  # actual steps = 1 gpu steps / gpus
        start_step = self.load_checkpoint()
        global_step = start_step

        if exists(resumable_with_seed):
            orig_epoch_step = len(train_dataloader)
            skipped_epoch = int(start_step // orig_epoch_step)
            skipped_batch = start_step % orig_epoch_step
            skipped_dataloader = self.accelerator.skip_first_batches(train_dataloader, num_batches=skipped_batch)
        else:
            skipped_epoch = 0

        # Log dict of the most recent generator pass; stays None until the
        # first generator update so audio logging never reads an unbound name.
        generator_log_dict = None

        for epoch in range(skipped_epoch, self.epochs):
            self.model.train()
            bar_kwargs = dict(
                desc=f"Epoch {epoch+1}/{self.epochs}",
                unit="step",
                disable=not self.accelerator.is_local_main_process,
            )
            if exists(resumable_with_seed) and epoch == skipped_epoch:
                progress_bar = tqdm(skipped_dataloader, initial=skipped_batch, total=orig_epoch_step, **bar_kwargs)
            else:
                progress_bar = tqdm(train_dataloader, **bar_kwargs)

            for batch in progress_bar:
                update_generator = global_step % self.gen_update_ratio == 0

                with self.accelerator.accumulate(self.model):
                    text_inputs = batch["text"]
                    mel_spec = batch["mel"].permute(0, 2, 1) / self.scale
                    mel_lengths = batch["mel_lengths"]

                    metrics = self._guidance_step(mel_spec, text_inputs, mel_lengths, global_step, update_generator)

                    if update_generator:
                        generator_metrics, generator_log_dict = self._generator_step(
                            mel_spec, text_inputs, mel_lengths, global_step
                        )
                        metrics.update(generator_metrics)

                global_step += 1

                if self.accelerator.is_local_main_process:
                    self.accelerator.log(
                        {
                            **metrics,
                            "lr_generator": self.scheduler_generator.get_last_lr()[0],
                            "lr_guidance": self.scheduler_guidance.get_last_lr()[0],
                        },
                        step=global_step,
                    )

                if (
                    global_step % self.log_step == 0
                    and self.accelerator.is_local_main_process
                    and vocoder is not None
                    and generator_log_dict is not None
                ):
                    # Samples come from the most recent generator pass, which may
                    # precede the current batch on guidance-only steps.
                    self._log_audio_samples(vocoder, generator_log_dict, global_step)

                progress_bar.set_postfix(step=str(global_step), metrics=metrics)

                if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
                    self.save_checkpoint(global_step)

                if global_step % self.last_per_steps == 0:
                    self.save_checkpoint(global_step, last=True)

        self.save_checkpoint(global_step, last=True)

        self.accelerator.end_training()