import os
import time
import random
import argparse
import numpy as np
from tqdm import tqdm
from accelerate import Accelerator
from einops import rearrange
from cached_path import cached_path
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

# BigVGAN neural vocoder
import bigvgan

from model.modules import MelSpec
from network.crossdit import CrossDiT
from dataset.capspeech import CapSpeech
from utils import load_checkpoint, make_pad_mask
from utils import get_lr_scheduler, load_yaml_with_includes
from inference import eval_model


def parse_args():
    parser = argparse.ArgumentParser()
    # Config settings
    parser.add_argument('--config-name', type=str, required=True)
    # Training settings
    parser.add_argument('--amp', type=str, default='fp16')
    parser.add_argument('--epochs', type=int, default=15)
    parser.add_argument('--num-workers', type=int, default=32)
    parser.add_argument('--num-threads', type=int, default=1)
    parser.add_argument('--eval-every-step', type=int, default=10000)
    # save all states, including the optimizer, every --save-every-step steps
    parser.add_argument('--save-every-step', type=int, default=10000)
    parser.add_argument('--resume-from', type=str, default=None,
                        help='Path to checkpoint to resume training from')
    # Logging and random seed
    parser.add_argument('--random-seed', type=int, default=2025)
    parser.add_argument('--log-step', type=int, default=500)
    parser.add_argument('--log-dir', type=str, default='./logs/')
    parser.add_argument('--save-dir', type=str, default='./ckpts/')
    return parser.parse_args()


def setup_directories(args, params):
    args.log_dir = os.path.join(args.log_dir, params['model_name']) + '/'
    args.save_dir = os.path.join(args.save_dir, params['model_name']) + '/'
    os.makedirs(args.log_dir, exist_ok=True)
    os.makedirs(args.save_dir, exist_ok=True)


def set_device(args):
    torch.set_num_threads(args.num_threads)
    if torch.cuda.is_available():
        args.device = 'cuda'
        torch.cuda.manual_seed_all(args.random_seed)
        torch.backends.cuda.matmul.allow_tf32 = True
        if torch.backends.cudnn.is_available():
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
    else:
        args.device = 'cpu'


def prepare_batch(batch, mel, latent_sr):
    x, x_lens = batch["x"], batch["x_lens"]
    y, y_lens = batch["y"], batch["y_lens"]
    c, c_lens = batch["c"], batch["c_lens"]
    tag = batch["tag"]
    # add one position for the CLAP embedding
    x_lens = x_lens + 1
    with torch.no_grad():
        audio_clip = mel(y)
        audio_clip = rearrange(audio_clip, 'b d n -> b n d')
    # convert durations to mel-frame lengths
    y_lens = (y_lens * latent_sr).long()
    return x, x_lens, audio_clip, y_lens, c, c_lens, tag


if __name__ == '__main__':
    args = parse_args()
    params = load_yaml_with_includes(args.config_name)

    # random seeds and device settings
    set_device(args)
    random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)

    accelerator = Accelerator(mixed_precision=args.amp,
                              gradient_accumulation_steps=params['opt']['accumulation_steps'],
                              step_scheduler_with_optimizer=False)

    # datasets
    train_set = CapSpeech(**params['data']['trainset'])
    train_loader = DataLoader(train_set, num_workers=args.num_workers,
                              batch_size=params['opt']['batch_size'],
                              shuffle=True, collate_fn=train_set.collate)
    val_set = CapSpeech(**params['data']['valset'])
    val_loader = DataLoader(val_set, num_workers=0, batch_size=1,
                            shuffle=False, collate_fn=val_set.collate)

    # load the DiT backbone
    model = CrossDiT(**params['model'])

    # mel spectrogram frontend - moved to the accelerator device after preparation
    mel = MelSpec(**params['mel'])
    latent_sr = params['mel']['target_sample_rate'] / params['mel']['hop_length']
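    # For reference: with the 24 kHz, 100-band, 256x-hop BigVGAN checkpoint
    # loaded below, and assuming the mel config mirrors the vocoder's settings,
    # latent_sr = 24000 / 256 = 93.75 mel frames per second of audio;
    # prepare_batch multiplies the duration labels y_lens by this rate to get
    # per-utterance mel-frame lengths.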
    # load the vocoder
    vocoder = bigvgan.BigVGAN.from_pretrained('nvidia/bigvgan_v2_24khz_100band_256x',
                                              use_cuda_kernel=False)
    vocoder.remove_weight_norm()
    vocoder = vocoder.eval().to(accelerator.device)

    # prepare the optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=params['opt']['learning_rate'])

    if args.resume_from is not None and os.path.exists(args.resume_from):
        checkpoint = torch.load(args.resume_from, map_location='cpu')
        model.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        global_step = checkpoint["global_step"]
        start_epoch = checkpoint["epoch"] + 1  # continue from the next epoch
        print(f"Resuming training from checkpoint: {args.resume_from}, starting from epoch {start_epoch}.")
    else:
        global_step = 0
        start_epoch = 0

    lr_scheduler = get_lr_scheduler(optimizer, 'customized', **params['opt']['lr_scheduler'])

    # prepare with accelerator
    (model, optimizer, lr_scheduler,
     train_loader, val_loader) = accelerator.prepare(model, optimizer, lr_scheduler,
                                                     train_loader, val_loader)

    # move mel and vocoder to the same device as the model AFTER preparation
    mel = mel.to(accelerator.device)
    vocoder = vocoder.to(accelerator.device)

    # synchronization point
    accelerator.wait_for_everyone()

    losses = 0.0
    if accelerator.is_main_process:
        setup_directories(args, params)
        trainable_params = sum(p.nelement() for p in model.parameters() if p.requires_grad)
        print("Number of trainable parameters: %.2fM" % (trainable_params / 1e6))

    # synchronization point
    accelerator.wait_for_everyone()

    # NOTE: the initial evaluation was removed to prevent a distributed deadlock;
    # the model is first evaluated at the first eval step instead.

    for epoch in range(start_epoch, args.epochs):
        model.train()
        # show the progress bar only on the local main process in distributed setups
        progress_bar = tqdm(train_loader, disable=not accelerator.is_local_main_process)
        for step, batch in enumerate(progress_bar):
            with accelerator.accumulate(model):
                (text, text_lens, audio_clips, audio_lens,
                 prompt, prompt_lens, clap) = prepare_batch(batch, mel, latent_sr)

                # prepare flow matching
                x1 = audio_clips               # data sample
                x0 = torch.randn_like(x1)      # Gaussian noise sample
                t = torch.rand((x1.shape[0],), dtype=x1.dtype, device=x1.device)
                sigma = rearrange(t, 'b -> b 1 1')
                noisy_x1 = (1 - sigma) * x0 + sigma * x1   # linear interpolant x_t
                flow = x1 - x0                             # velocity target
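                # Conditional flow matching: with x0 ~ N(0, I), data x1, and
                # t ~ U(0, 1), the straight-line interpolant is
                #     x_t = (1 - t) * x0 + t * x1,
                # and its constant velocity dx_t/dt = x1 - x0 is the regression
                # target; at inference, an ODE solver integrates the predicted
                # velocity from noise (t = 0) to data (t = 1).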
                # option: audio-prompt based zero-shot TTS
                # tts_mask = create_tts_mask(seq_len, x1.shape[1], params['opt']['mask_range'])
                # cond = torch.where(tts_mask[..., None], torch.zeros_like(x1), x1)
                cond = None

                # classifier-free guidance: randomly drop conditioning signals
                # (text/CLAP are only dropped when the prompt is also dropped)
                drop_prompt = (torch.rand(x1.shape[0]) < params['opt']['drop_spk'])
                drop_text = drop_prompt & (torch.rand(x1.shape[0]) < params['opt']['drop_text'])
                prompt[drop_prompt] = 0.0
                prompt_lens[drop_prompt] = 1
                clap[drop_text] = 0.0
                text[drop_text] = -1

                # padding masks for the audio latents and the prompt
                seq_len_audio = audio_clips.shape[1]
                pad_mask = make_pad_mask(audio_lens, seq_len_audio).to(audio_clips.device)
                seq_len_prompt = prompt.shape[1]
                prompt_mask = make_pad_mask(prompt_lens, seq_len_prompt).to(prompt.device)

                pred = model(x=noisy_x1, cond=cond, prompt=prompt, clap=clap,
                             text=text, time=t, mask=pad_mask, prompt_mask=prompt_mask)

                # MSE on the velocity, averaged over non-padded positions only
                loss = F.mse_loss(pred, flow, reduction="none")
                loss = loss[pad_mask].mean()

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    if 'grad_clip' in params['opt'] and params['opt']['grad_clip'] > 0:
                        accelerator.clip_grad_norm_(model.parameters(),
                                                    max_norm=params['opt']['grad_clip'])
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # count steps once per optimizer update, not per accumulation micro-step
            if accelerator.sync_gradients:
                global_step += 1
                losses += loss.item()

                if accelerator.is_local_main_process:
                    progress_bar.set_description(f"Epoch {epoch + 1}, Loss: {loss.item():.6f}")

                if global_step % args.log_step == 0:
                    losses = losses / args.log_step  # average loss over the logging window
                    if accelerator.is_main_process:
                        current_time = time.asctime(time.localtime(time.time()))
                        epoch_info = f'Epoch: [{epoch + 1}/{args.epochs}]'
                        batch_info = f'Global Step: {global_step}'
                        loss_info = f'Loss: {losses:.6f}'
                        # current learning rate from the optimizer
                        lr = optimizer.param_groups[0]['lr']
                        lr_info = f'Learning Rate: {lr:.6f}'
                        log_message = f'{current_time}\n{epoch_info} {batch_info} {loss_info} {lr_info}\n'
                        with open(args.log_dir + 'log.txt', mode='a') as f:
                            f.write(log_message)
                    # reset the loss accumulator
                    losses = 0.0

                # evaluation
                if global_step % args.eval_every_step == 0:
                    model.eval()
                    # synchronize before evaluation
                    accelerator.wait_for_everyone()
                    if accelerator.is_main_process:
                        # unwrap the model for evaluation
                        unwrapped_model = accelerator.unwrap_model(model)
                        # run evaluation on the main process only, without an
                        # explicit device argument
                        eval_model(unwrapped_model, vocoder, mel, val_loader, params,
                                   steps=25, cfg=2.0, sway_sampling_coef=-1.0,
                                   epoch=global_step,
                                   save_path=args.log_dir + 'output/', val_num=1)
                        # save a model checkpoint
                        accelerator.save({
                            "model": unwrapped_model.state_dict(),
                            "optimizer": optimizer.state_dict(),
                            "epoch": epoch,
                            "global_step": global_step,
                        }, args.save_dir + str(global_step) + '.pt')
                        # save the full training state (optimizer included) if needed
                        if global_step % args.save_every_step == 0:
                            accelerator.save_state(f"{args.save_dir}{global_step}")
                    # synchronize after evaluation and saving
                    accelerator.wait_for_everyone()
                    model.train()

        # synchronize at the end of each epoch
        accelerator.wait_for_everyone()
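# Example launch command (illustrative; substitute this repo's actual script
# name and config path):
#   accelerate launch train.py --config-name configs/capspeech.yaml --epochs 15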