from __future__ import annotations

import gc
import math
import os

import torch
import torch.nn as nn
import wandb
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR, SequentialLR
from torch.utils.data import DataLoader, Dataset, SequentialSampler
from tqdm import tqdm

from f5_tts.model import CFM
from f5_tts.model.dataset import DynamicBatchSampler, collate_fn
from f5_tts.model.utils import default, exists
from unimodel import UniModel


# trainer


class RunningStats:
    """Streaming mean / variance accumulator (Welford's online algorithm)."""

    def __init__(self):
        self.count = 0
        self.mean = 0.0
        self.M2 = 0.0  # Sum of squared differences from the current mean

    def update(self, x):
        """Update the running statistics with a new value x."""
        self.count += 1
        delta = x - self.mean
        self.mean += delta / self.count
        delta2 = x - self.mean
        self.M2 += delta * delta2

    @property
    def variance(self):
        """Return the sample variance. Returns NaN if fewer than two samples."""
        return self.M2 / (self.count - 1) if self.count > 1 else float("nan")

    @property
    def std(self):
        """Return the sample standard deviation."""
        return math.sqrt(self.variance)


class Trainer:
    """Alternating generator / guidance trainer driven by ``accelerate``.

    The wrapped model is called twice per step: once with
    ``update_generator=False`` (guidance update, every step) and, every
    ``gen_update_ratio``-th step, once more with ``update_generator=True``
    (generator update). Each pass has its own optimizer and LR scheduler.
    """

    def __init__(
        self,
        model: UniModel,
        epochs,
        learning_rate,
        num_warmup_updates=20000,
        save_per_updates=1000,
        checkpoint_path=None,
        batch_size=32,
        batch_size_type: str = "sample",
        max_samples=32,
        grad_accumulation_steps=1,
        max_grad_norm=1.0,
        noise_scheduler: str | None = None,
        duration_predictor: torch.nn.Module | None = None,
        wandb_project="test_e2-tts",
        wandb_run_name="test_run",
        wandb_resume_id: str | None = None,
        last_per_steps=None,
        log_step=1000,
        accelerate_kwargs: dict | None = None,
        bnb_optimizer: bool = False,
        scale: float = 1.0,
        # training parameters for DMDSpeech
        num_student_step: int = 1,
        gen_update_ratio: int = 5,
        lambda_discriminator_loss: float = 1.0,
        lambda_generator_loss: float = 1.0,
        lambda_ctc_loss: float = 1.0,
        lambda_sim_loss: float = 1.0,
        num_GAN: int = 5000,
        num_D: int = 500,
        num_ctc: int = 5000,
        num_sim: int = 10000,
        num_simu: int = 1000,
    ):
        """Set up the accelerator, trackers, optimizers and hyper-parameters.

        Args:
            model: Model exposing ``feedforward_model`` and ``guidance_model``
                sub-modules; ``guidance_model.gen_cls_loss`` toggles GAN training.
            epochs: Number of passes over the training dataset.
            learning_rate: Learning rate shared by both optimizers.
            num_warmup_updates: Warm-up length (multiplied by the number of
                processes when the schedulers are built in ``train``).
            save_per_updates: A numbered checkpoint is written every
                ``save_per_updates * grad_accumulation_steps`` steps.
            checkpoint_path: Checkpoint directory; defaults to
                ``"ckpts/test_e2-tts"``.
            batch_size: Batch size handed to the loader / batch sampler.
            batch_size_type: ``"sample"`` or ``"frame"``.
            max_samples: Forwarded to ``DynamicBatchSampler`` in frame mode.
            grad_accumulation_steps: Gradient accumulation steps.
            max_grad_norm: Gradient-clipping norm; clipping is skipped if <= 0.
            noise_scheduler: Stored and reported in the tracker config only.
            duration_predictor: Stored on the trainer; not used by ``train``.
            wandb_project: Tracker project name.
            wandb_run_name: Tracker run name.
            wandb_resume_id: Optional wandb run id to resume.
            last_per_steps: Interval for refreshing ``model_last.pt``; defaults
                to ``save_per_updates * grad_accumulation_steps``.
            log_step: Audio samples are logged every ``log_step`` steps.
            accelerate_kwargs: Extra keyword arguments for ``Accelerator``.
            bnb_optimizer: Use bitsandbytes 8-bit AdamW instead of ``AdamW``.
            scale: Mel spectrograms are divided by this before the model and
                multiplied back before vocoding.
            num_student_step: Number of student steps in the distillation
                schedule.
            gen_update_ratio: The generator is updated once every
                ``gen_update_ratio`` steps (guidance is updated every step).
            lambda_discriminator_loss: Weight for the discriminator loss.
            lambda_generator_loss: Weight for the generator adversarial loss.
            lambda_ctc_loss: Weight for the CTC loss (also applied to the KL
                loss in ``train``).
            lambda_sim_loss: Weight for the similarity loss.
            num_GAN: Number of steps before adversarial training.
            num_D: Number of steps to train the discriminator before the
                generator receives the adversarial loss.
            num_ctc: Number of steps before CTC training.
            num_sim: Number of steps before similarity training.
            num_simu: Number of steps before using simulated data.
        """
        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)

        logger = "wandb" if wandb.api.api_key else None
        print(f"Using logger: {logger}")

        self.accelerator = Accelerator(
            log_with=logger,
            kwargs_handlers=[ddp_kwargs],
            gradient_accumulation_steps=grad_accumulation_steps,
            # a None default avoids sharing one dict object across instances
            **(accelerate_kwargs or {}),
        )

        if logger == "wandb":
            wandb_init = {"resume": "allow", "name": wandb_run_name}
            if exists(wandb_resume_id):
                wandb_init["id"] = wandb_resume_id
            self.accelerator.init_trackers(
                project_name=wandb_project,
                init_kwargs={"wandb": wandb_init},
                config={
                    "epochs": epochs,
                    "learning_rate": learning_rate,
                    "num_warmup_updates": num_warmup_updates,
                    "batch_size": batch_size,
                    "batch_size_type": batch_size_type,
                    "max_samples": max_samples,
                    "grad_accumulation_steps": grad_accumulation_steps,
                    "max_grad_norm": max_grad_norm,
                    "gpus": self.accelerator.num_processes,
                    "noise_scheduler": noise_scheduler,
                },
            )

        self.model = model

        self.scale = scale

        self.epochs = epochs
        self.num_warmup_updates = num_warmup_updates
        self.save_per_updates = save_per_updates
        self.last_per_steps = default(
            last_per_steps, save_per_updates * grad_accumulation_steps
        )
        self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts")

        self.batch_size = batch_size
        self.batch_size_type = batch_size_type
        self.max_samples = max_samples
        self.grad_accumulation_steps = grad_accumulation_steps
        self.max_grad_norm = max_grad_norm

        self.noise_scheduler = noise_scheduler

        self.duration_predictor = duration_predictor

        self.log_step = log_step

        # the generator is updated once every `gen_update_ratio` steps, while the
        # guidance model (fake score function and discriminator) is updated every step
        self.gen_update_ratio = gen_update_ratio
        self.lambda_discriminator_loss = lambda_discriminator_loss  # weight for discriminator loss (L_adv)
        self.lambda_generator_loss = lambda_generator_loss  # weight for generator loss (L_adv)
        self.lambda_ctc_loss = lambda_ctc_loss  # weight for ctc loss
        self.lambda_sim_loss = lambda_sim_loss  # weight for similarity loss

        # create distillation schedule for student model
        self.student_steps = torch.linspace(0.0, 1.0, num_student_step + 1)[:-1]

        self.GAN = model.guidance_model.gen_cls_loss  # whether to use GAN training
        self.num_GAN = num_GAN  # number of steps before adversarial training
        self.num_D = num_D  # number of steps to train the discriminator before adversarial training
        self.num_ctc = num_ctc  # number of steps before CTC training
        self.num_sim = num_sim  # number of steps before similarity training
        self.num_simu = num_simu  # number of steps before using simulated data

        # one optimizer per sub-model so the two updates can be stepped independently
        if bnb_optimizer:
            import bitsandbytes as bnb

            self.optimizer_generator = bnb.optim.AdamW8bit(
                self.model.feedforward_model.parameters(), lr=learning_rate
            )
            self.optimizer_guidance = bnb.optim.AdamW8bit(
                self.model.guidance_model.parameters(), lr=learning_rate
            )
        else:
            self.optimizer_generator = AdamW(
                self.model.feedforward_model.parameters(), lr=learning_rate, eps=1e-7
            )
            self.optimizer_guidance = AdamW(
                self.model.guidance_model.parameters(), lr=learning_rate, eps=1e-7
            )

        self.model, self.optimizer_generator, self.optimizer_guidance = (
            self.accelerator.prepare(
                self.model, self.optimizer_generator, self.optimizer_guidance
            )
        )

        # NOTE: these accumulators are created here but not consulted by train();
        # gradient-explosion skipping based on them is currently disabled.
        self.generator_norm = RunningStats()
        self.guidance_norm = RunningStats()

    @property
    def is_main(self):
        """Whether this process is the accelerator's main process."""
        return self.accelerator.is_main_process

    def save_checkpoint(self, step, last=False):
        """Write model, optimizer and scheduler state to ``checkpoint_path``.

        Must be called after ``train`` has created the schedulers. Only the
        main process writes; all processes synchronise first.

        Args:
            step: Global step stored in the checkpoint and, unless ``last`` is
                set, used in the file name ``model_{step}.pt``.
            last: Write to ``model_last.pt`` instead of a numbered file.
        """
        self.accelerator.wait_for_everyone()
        if not self.is_main:
            return

        checkpoint = dict(
            model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(),
            optimizer_generator_state_dict=self.accelerator.unwrap_model(
                self.optimizer_generator
            ).state_dict(),
            optimizer_guidance_state_dict=self.accelerator.unwrap_model(
                self.optimizer_guidance
            ).state_dict(),
            scheduler_generator_state_dict=self.scheduler_generator.state_dict(),
            scheduler_guidance_state_dict=self.scheduler_guidance.state_dict(),
            step=step,
        )

        # exist_ok avoids the check-then-create race
        os.makedirs(self.checkpoint_path, exist_ok=True)

        if last:
            self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
            print(f"Saved last checkpoint at step {step}")
        else:
            self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt")

    def load_checkpoint(self):
        """Restore model weights from the most recent checkpoint, if any.

        ``model_last.pt`` is preferred; otherwise the ``.pt`` file whose name
        contains the largest number is used. Weights are loaded with
        ``strict=False``.

        NOTE: optimizer and scheduler states are saved by ``save_checkpoint``
        but are intentionally not restored here.

        Returns:
            The step stored in the checkpoint, or 0 when there is no usable
            checkpoint.
        """
        if not exists(self.checkpoint_path) or not os.path.exists(self.checkpoint_path):
            return 0
        if not os.listdir(self.checkpoint_path):
            return 0

        self.accelerator.wait_for_everyone()

        entries = os.listdir(self.checkpoint_path)
        if "model_last.pt" in entries:
            latest_checkpoint = "model_last.pt"
        else:
            # keep only checkpoints whose name carries a step number, so that a stray
            # file (no ".pt" suffix, or no digits) cannot crash the selection
            numbered = []
            for filename in entries:
                if not filename.endswith(".pt"):
                    continue
                digits = "".join(filter(str.isdigit, filename))
                if digits:
                    numbered.append((int(digits), filename))
            if not numbered:
                return 0
            latest_checkpoint = sorted(numbered, key=lambda item: item[0])[-1][1]

        checkpoint = torch.load(
            f"{self.checkpoint_path}/{latest_checkpoint}",
            weights_only=True,
            map_location="cpu",
        )

        self.accelerator.unwrap_model(self.model).load_state_dict(
            checkpoint["model_state_dict"], strict=False
        )

        step = checkpoint["step"]

        del checkpoint
        gc.collect()
        return step

    @staticmethod
    def _build_lr_scheduler(optimizer, warmup_iters, decay_iters):
        """Return a linear warm-up followed by a linear decay for ``optimizer``."""
        warmup = LinearLR(
            optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_iters
        )
        decay = LinearLR(
            optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_iters
        )
        return SequentialLR(
            optimizer, schedulers=[warmup, decay], milestones=[warmup_iters]
        )

    def _decode_audio_sample(self, vocoder, mel, caption, sample_rate=24000):
        """Vocode the first item of ``mel`` and wrap it as a ``wandb.Audio``."""
        first = mel[0].unsqueeze(0).permute(0, 2, 1) * self.scale
        waveform = vocoder.decode(first.float().cpu())
        return wandb.Audio(
            waveform.float().numpy().squeeze(),
            sample_rate=sample_rate,
            caption=caption,
        )

    def _log_audio_samples(self, vocoder, log_dict, step):
        """Decode the tensors in a generator log dict and log them as audio."""
        time_caption = "time: " + str(log_dict["time"][0].float().cpu().numpy())
        dmtrain_caption = "dmtrain_time: " + str(
            log_dict["dmtrain_time"][0].float().cpu().numpy()
        )
        # (tracker key, key in log_dict, caption)
        entries = (
            ("noisy_input", "generator_input", time_caption),
            ("output", "generator_output", time_caption),
            ("cond", "generator_cond", time_caption),
            ("ground_truth", "ground_truth", time_caption),
            ("dmtrain_noisy_inp", "dmtrain_noisy_inp", dmtrain_caption),
            ("dmtrain_pred_real_image", "dmtrain_pred_real_image", dmtrain_caption),
            ("dmtrain_pred_fake_image", "dmtrain_pred_fake_image", dmtrain_caption),
        )
        with torch.no_grad():
            payload = {
                tracker_key: self._decode_audio_sample(
                    vocoder, log_dict[source_key], caption
                )
                for tracker_key, source_key, caption in entries
            }
        self.accelerator.log(payload, step=step)

    def train(
        self,
        train_dataset: Dataset,
        num_workers=64,
        resumable_with_seed: int | None = None,
        vocoder: nn.Module | None = None,
    ):
        """Run the full training loop.

        Args:
            train_dataset: Dataset whose batches (after ``collate_fn``) provide
                ``"text"``, ``"mel"`` and ``"mel_lengths"`` entries.
            num_workers: Number of ``DataLoader`` worker processes.
            resumable_with_seed: Seed for shuffling / dynamic batching. When
                given, already-consumed epochs and batches are skipped
                according to the step stored in the loaded checkpoint.
            vocoder: Optional vocoder with a ``decode`` method; when given,
                audio samples are logged every ``log_step`` steps.

        Raises:
            ValueError: If ``batch_size_type`` is neither ``"sample"`` nor
                ``"frame"``.
        """
        if exists(resumable_with_seed):
            generator = torch.Generator()
            generator.manual_seed(resumable_with_seed)
        else:
            generator = None

        # DataLoader rejects persistent_workers=True when num_workers == 0
        persistent_workers = num_workers > 0

        if self.batch_size_type == "sample":
            train_dataloader = DataLoader(
                train_dataset,
                collate_fn=collate_fn,
                num_workers=num_workers,
                pin_memory=True,
                persistent_workers=persistent_workers,
                batch_size=self.batch_size,
                shuffle=True,
                generator=generator,
            )
        elif self.batch_size_type == "frame":
            self.accelerator.even_batches = False
            sampler = SequentialSampler(train_dataset)
            batch_sampler = DynamicBatchSampler(
                sampler,
                self.batch_size,
                max_samples=self.max_samples,
                random_seed=resumable_with_seed,
                drop_last=False,
            )
            train_dataloader = DataLoader(
                train_dataset,
                collate_fn=collate_fn,
                num_workers=num_workers,
                pin_memory=True,
                persistent_workers=persistent_workers,
                batch_sampler=batch_sampler,
            )
        else:
            raise ValueError(
                f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}"
            )

        # accelerator.prepare() dispatches batches to devices;
        # which means the length of dataloader calculated before, should consider the number of devices
        warmup_steps = self.num_warmup_updates * self.accelerator.num_processes
        # consider a fixed warmup steps while using accelerate multi-gpu ddp
        # otherwise by default with split_batches=False, warmup steps change with num_processes
        total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps
        decay_steps = total_steps - warmup_steps

        # the generator scheduler is stepped only on generator updates, so its
        # horizon is shortened by gen_update_ratio (and grad accumulation)
        generator_divisor = self.gen_update_ratio * self.grad_accumulation_steps
        self.scheduler_generator = self._build_lr_scheduler(
            self.optimizer_generator,
            warmup_steps // generator_divisor,
            decay_steps // generator_divisor,
        )
        self.scheduler_guidance = self._build_lr_scheduler(
            self.optimizer_guidance, warmup_steps, decay_steps
        )

        train_dataloader, self.scheduler_generator, self.scheduler_guidance = (
            self.accelerator.prepare(
                train_dataloader, self.scheduler_generator, self.scheduler_guidance
            )
        )  # actual steps = 1 gpu steps / gpus
        start_step = self.load_checkpoint()
        global_step = start_step

        if exists(resumable_with_seed):
            orig_epoch_step = len(train_dataloader)
            skipped_epoch = int(start_step // orig_epoch_step)
            skipped_batch = start_step % orig_epoch_step
            skipped_dataloader = self.accelerator.skip_first_batches(
                train_dataloader, num_batches=skipped_batch
            )
        else:
            skipped_epoch = 0

        # Most recent log dict from a generator pass. Starts as None so that audio
        # logging is skipped (rather than raising NameError) if no generator pass
        # has run yet, e.g. right after resuming on a non-generator step.
        generator_log_dict = None

        for epoch in range(skipped_epoch, self.epochs):
            self.model.train()
            if exists(resumable_with_seed) and epoch == skipped_epoch:
                progress_bar = tqdm(
                    skipped_dataloader,
                    desc=f"Epoch {epoch+1}/{self.epochs}",
                    unit="step",
                    disable=not self.accelerator.is_local_main_process,
                    initial=skipped_batch,
                    total=orig_epoch_step,
                )
            else:
                progress_bar = tqdm(
                    train_dataloader,
                    desc=f"Epoch {epoch+1}/{self.epochs}",
                    unit="step",
                    disable=not self.accelerator.is_local_main_process,
                )

            for batch in progress_bar:
                update_generator = global_step % self.gen_update_ratio == 0

                with self.accelerator.accumulate(self.model):
                    metrics = {}
                    text_inputs = batch["text"]
                    mel_spec = batch["mel"].permute(0, 2, 1)
                    mel_lengths = batch["mel_lengths"]

                    mel_spec = mel_spec / self.scale

                    # ---- guidance (fake score / discriminator) update: every step ----
                    guidance_loss_dict, guidance_log_dict = self.model(
                        inp=mel_spec,
                        text=text_inputs,
                        lens=mel_lengths,
                        student_steps=self.student_steps,
                        update_generator=False,
                        use_simulated=global_step >= self.num_simu,
                    )

                    guidance_loss = 0
                    guidance_loss += guidance_loss_dict["loss_fake_mean"]
                    metrics["loss/fake_score"] = guidance_loss_dict["loss_fake_mean"]
                    metrics["loss/guidance_loss"] = guidance_loss

                    if self.GAN and update_generator:
                        # only add discriminator loss if GAN is enabled and generator is being updated;
                        # before num_GAN steps the term is kept in the graph with zero weight
                        guidance_cls_loss = guidance_loss_dict["guidance_cls_loss"] * (
                            self.lambda_discriminator_loss
                            if global_step >= self.num_GAN and update_generator
                            else 0
                        )
                        metrics["loss/discriminator_loss"] = guidance_loss_dict[
                            "guidance_cls_loss"
                        ]
                        guidance_loss += guidance_cls_loss

                    self.accelerator.backward(guidance_loss)

                    if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
                        metrics["grad_norm_guidance"] = (
                            self.accelerator.clip_grad_norm_(
                                self.model.parameters(), self.max_grad_norm
                            )
                        )

                    self.optimizer_guidance.step()
                    self.scheduler_guidance.step()
                    self.optimizer_guidance.zero_grad()
                    self.optimizer_generator.zero_grad()  # zero out the generator's gradient as well

                    # ---- generator update: every gen_update_ratio-th step ----
                    if update_generator:
                        # NOTE(review): the guidance pass gates use_simulated on num_simu,
                        # this pass gates it on num_ctc — confirm this is intended.
                        generator_loss_dict, generator_log_dict = self.model(
                            inp=mel_spec,
                            text=text_inputs,
                            lens=mel_lengths,
                            student_steps=self.student_steps,
                            update_generator=True,
                            use_simulated=global_step >= self.num_ctc,
                        )

                        generator_loss = 0
                        generator_loss += generator_loss_dict["loss_dm"]
                        if "loss_mse" in generator_loss_dict:
                            generator_loss += generator_loss_dict["loss_mse"]
                        generator_loss += generator_loss_dict["loss_ctc"] * (
                            self.lambda_ctc_loss if global_step >= self.num_ctc else 0
                        )
                        generator_loss += generator_loss_dict["loss_sim"] * (
                            self.lambda_sim_loss if global_step >= self.num_sim else 0
                        )
                        # NOTE(review): the KL term reuses the CTC weight and start step.
                        generator_loss += generator_loss_dict["loss_kl"] * (
                            self.lambda_ctc_loss if global_step >= self.num_ctc else 0
                        )

                        if self.GAN:
                            # the generator only receives the adversarial loss after the
                            # discriminator has had num_D extra steps on its own
                            gen_cls_loss = generator_loss_dict["gen_cls_loss"] * (
                                self.lambda_generator_loss
                                if global_step >= (self.num_GAN + self.num_D)
                                and update_generator
                                else 0
                            )
                            metrics["loss/gen_cls_loss"] = generator_loss_dict[
                                "gen_cls_loss"
                            ]
                            generator_loss += gen_cls_loss

                        metrics["loss/dm_loss"] = generator_loss_dict["loss_dm"]
                        metrics["loss/ctc_loss"] = generator_loss_dict["loss_ctc"]
                        metrics["loss/similarity_loss"] = generator_loss_dict["loss_sim"]
                        metrics["loss/generator_loss"] = generator_loss

                        if (
                            "loss_mse" in generator_loss_dict
                            and generator_loss_dict["loss_mse"] != 0
                        ):
                            metrics["loss/mse_loss"] = generator_loss_dict["loss_mse"]
                        if (
                            "loss_kl" in generator_loss_dict
                            and generator_loss_dict["loss_kl"] != 0
                        ):
                            metrics["loss/kl_loss"] = generator_loss_dict["loss_kl"]

                        self.accelerator.backward(generator_loss)

                        if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
                            metrics["grad_norm_generator"] = (
                                self.accelerator.clip_grad_norm_(
                                    self.model.parameters(), self.max_grad_norm
                                )
                            )

                        self.optimizer_generator.step()
                        self.scheduler_generator.step()
                        self.optimizer_generator.zero_grad()
                        self.optimizer_guidance.zero_grad()  # zero out the guidance's gradient as well

                global_step += 1

                if self.accelerator.is_local_main_process:
                    self.accelerator.log(
                        {
                            **metrics,
                            "lr_generator": self.scheduler_generator.get_last_lr()[0],
                            "lr_guidance": self.scheduler_guidance.get_last_lr()[0],
                        },
                        step=global_step,
                    )

                if (
                    global_step % self.log_step == 0
                    and self.accelerator.is_local_main_process
                    and vocoder is not None
                    and generator_log_dict is not None
                ):
                    # every log_step steps, log audio decoded from the most recent generator pass
                    self._log_audio_samples(vocoder, generator_log_dict, global_step)

                progress_bar.set_postfix(step=str(global_step), metrics=metrics)

                if (
                    global_step % (self.save_per_updates * self.grad_accumulation_steps)
                    == 0
                ):
                    self.save_checkpoint(global_step)

                if global_step % self.last_per_steps == 0:
                    self.save_checkpoint(global_step, last=True)

        self.save_checkpoint(global_step, last=True)

        self.accelerator.end_training()