# DMOSpeech2 — f5_tts/model/trainer.py
# Upstream commit 597cecf ("pt 1") by mrfakename.
from __future__ import annotations
import gc
import os
import torch
import torchaudio
import wandb
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs
from ema_pytorch import EMA
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR, SequentialLR
from torch.utils.data import DataLoader, Dataset, SequentialSampler
from tqdm import tqdm
from f5_tts.model import CFM
from f5_tts.model.dataset import DynamicBatchSampler, collate_fn
from f5_tts.model.utils import default, exists
# trainer
class Trainer:
    """Train a CFM model with Hugging Face Accelerate, EMA tracking and optional logging.

    The trainer owns the optimizer, the learning-rate scheduler (built in ``train``), an
    EMA copy of the model (main process only) and a running estimate of the mel
    spectrogram standard deviation (``scale``). Training inputs are divided by ``scale``,
    and the value is stored in every checkpoint.
    """

    # Keys deleted from loaded state dicts for backward compatibility (patch 305e3ea).
    _LEGACY_MEL_KEYS = (
        "mel_spec.mel_stft.mel_scale.fb",
        "mel_spec.mel_stft.spectrogram.window",
    )

    def __init__(
        self,
        model: CFM,
        epochs,
        learning_rate,
        num_warmup_updates=20000,
        save_per_updates=1000,
        checkpoint_path=None,
        batch_size=32,
        batch_size_type: str = "sample",
        max_samples=32,
        grad_accumulation_steps=1,
        max_grad_norm=1.0,
        noise_scheduler: str | None = None,
        duration_predictor: torch.nn.Module | None = None,
        logger: str | None = "wandb",  # "wandb" | "tensorboard" | None
        wandb_project="test_e2-tts",
        wandb_run_name="test_run",
        wandb_resume_id: str | None = None,
        log_samples: bool = False,
        last_per_steps=None,
        accelerate_kwargs: dict | None = None,
        ema_kwargs: dict | None = None,
        bnb_optimizer: bool = False,
        mel_spec_type: str = "vocos",  # "vocos" | "bigvgan"
        is_local_vocoder: bool = False,  # use local path vocoder
        local_vocoder_path: str = "",  # local vocoder path
    ):
        """Set up Accelerate, logging, the EMA model and the optimizer.

        Args:
            model: the CFM model to train. It is prepared with Accelerate together with
                the optimizer.
            epochs: number of passes over the training set in ``train``.
            learning_rate: peak learning rate, reached at the end of warmup.
            num_warmup_updates: warmup length. When the scheduler is built, this is
                multiplied by the number of processes.
            save_per_updates: a numbered checkpoint is written every
                ``save_per_updates * grad_accumulation_steps`` batches.
            checkpoint_path: checkpoint directory (default ``"ckpts/test_e2-tts"``).
            batch_size: samples per batch for ``"sample"`` batching. For ``"frame"``
                batching, it is the size argument passed to ``DynamicBatchSampler``.
            batch_size_type: ``"sample"`` or ``"frame"``.
            max_samples: ``max_samples`` for ``DynamicBatchSampler`` (frame batching only).
            grad_accumulation_steps: batches accumulated per optimizer update.
            max_grad_norm: gradient-norm clipping threshold. ``<= 0`` disables clipping.
            noise_scheduler: forwarded to the model's forward call.
            duration_predictor: optional module. Its loss is only logged, not trained
                (see the TODO in ``train``).
            logger: ``"wandb"``, ``"tensorboard"`` or ``None``. ``"wandb"`` falls back
                to ``None`` when no wandb API key is configured.
            wandb_project, wandb_run_name, wandb_resume_id: wandb run settings. The run
                name also names the TensorBoard log directory ``runs/<name>``.
            log_samples: vocode and log audio samples whenever a numbered checkpoint is
                saved.
            last_per_steps: ``model_last.pt`` is rewritten every this many batches
                (default ``save_per_updates * grad_accumulation_steps``).
            accelerate_kwargs: extra keyword arguments for ``Accelerator``.
            ema_kwargs: extra keyword arguments for ``EMA``.
            bnb_optimizer: use bitsandbytes ``AdamW8bit`` instead of ``torch.optim.AdamW``.
            mel_spec_type: vocoder used for sample logging, ``"vocos"`` or ``"bigvgan"``.
            is_local_vocoder: load the vocoder from ``local_vocoder_path``.
            local_vocoder_path: local vocoder path.
        """
        # None instead of a shared mutable ``dict()`` default.
        if accelerate_kwargs is None:
            accelerate_kwargs = {}
        if ema_kwargs is None:
            ema_kwargs = {}

        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)

        # Without an API key wandb cannot log, so disable logging instead of failing.
        if logger == "wandb" and not wandb.api.api_key:
            logger = None
        print(f"Using logger: {logger}")
        self.log_samples = log_samples

        self.accelerator = Accelerator(
            log_with=logger if logger == "wandb" else None,
            kwargs_handlers=[ddp_kwargs],
            gradient_accumulation_steps=grad_accumulation_steps,
            **accelerate_kwargs,
        )

        self.logger = logger
        if self.logger == "wandb":
            wandb_init = {"resume": "allow", "name": wandb_run_name}
            if exists(wandb_resume_id):
                wandb_init["id"] = wandb_resume_id
            self.accelerator.init_trackers(
                project_name=wandb_project,
                init_kwargs={"wandb": wandb_init},
                config={
                    "epochs": epochs,
                    "learning_rate": learning_rate,
                    "num_warmup_updates": num_warmup_updates,
                    "batch_size": batch_size,
                    "batch_size_type": batch_size_type,
                    "max_samples": max_samples,
                    "grad_accumulation_steps": grad_accumulation_steps,
                    "max_grad_norm": max_grad_norm,
                    "gpus": self.accelerator.num_processes,
                    "noise_scheduler": noise_scheduler,
                },
            )
        elif self.logger == "tensorboard":
            from torch.utils.tensorboard import SummaryWriter

            self.writer = SummaryWriter(log_dir=f"runs/{wandb_run_name}")

        self.model = model

        # The EMA copy exists only on the main process. Every use of it below is
        # guarded by ``self.is_main``.
        if self.is_main:
            self.ema_model = EMA(model, include_online_model=False, **ema_kwargs)
            self.ema_model.to(self.accelerator.device)

        self.epochs = epochs
        self.num_warmup_updates = num_warmup_updates
        self.save_per_updates = save_per_updates
        self.last_per_steps = default(
            last_per_steps, save_per_updates * grad_accumulation_steps
        )
        self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts")

        self.batch_size = batch_size
        self.batch_size_type = batch_size_type
        self.max_samples = max_samples
        self.grad_accumulation_steps = grad_accumulation_steps
        self.max_grad_norm = max_grad_norm

        # mel vocoder config
        self.vocoder_name = mel_spec_type
        self.is_local_vocoder = is_local_vocoder
        self.local_vocoder_path = local_vocoder_path

        self.noise_scheduler = noise_scheduler
        self.duration_predictor = duration_predictor

        if bnb_optimizer:
            import bitsandbytes as bnb

            self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate)
        else:
            self.optimizer = AdamW(model.parameters(), lr=learning_rate)
        self.model, self.optimizer = self.accelerator.prepare(
            self.model, self.optimizer
        )

        # The scheduler is built in train(). Starting from None keeps
        # save_checkpoint/load_checkpoint safe before that.
        self.scheduler = None
        # Running mean of the per-batch mel std (None until the first batch).
        self.scale = None
        # Number of batches folded into ``self.scale``.
        self.count = 0

    @property
    def is_main(self):
        """Whether this is Accelerate's main process.

        The main process owns the EMA model and writes checkpoints.
        """
        return self.accelerator.is_main_process

    def save_checkpoint(self, step, last=False):
        """Write a training checkpoint from the main process.

        Every process must call this, because it starts with a barrier. Only the main
        process builds and saves the checkpoint.

        Args:
            step: global step stored in the checkpoint and used in the file name.
            last: write ``model_last.pt`` (same file every time) instead of
                ``model_{step}.pt``.
        """
        self.accelerator.wait_for_everyone()
        if not self.is_main:
            return

        checkpoint = dict(
            model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(),
            # The prepared optimizer exposes the wrapped optimizer's state_dict directly.
            optimizer_state_dict=self.optimizer.state_dict(),
            ema_model_state_dict=self.ema_model.state_dict(),
            scheduler_state_dict=(
                self.scheduler.state_dict() if self.scheduler is not None else None
            ),
            step=step,
            scale=self.scale,
            count=self.count,
        )
        os.makedirs(self.checkpoint_path, exist_ok=True)
        if last:
            self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
            print(f"Saved last checkpoint at step {step}")
        else:
            self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt")

    @staticmethod
    def _latest_checkpoint_name(filenames):
        """Return the checkpoint file name to resume from, or None if there is no ``.pt``.

        If ``model_last.pt`` is present, it wins. Otherwise the ``.pt`` file with the
        largest number formed by the digits in its name is chosen. Names without any
        digits rank below every numbered file; previously they raised ``ValueError``.
        """
        candidates = [name for name in filenames if name.endswith(".pt")]
        if not candidates:
            return None
        if "model_last.pt" in candidates:
            return "model_last.pt"

        def step_of(name):
            digits = "".join(filter(str.isdigit, name))
            return int(digits) if digits else -1

        return sorted(candidates, key=step_of)[-1]

    @staticmethod
    def _drop_keys(state_dict, keys, prefix=""):
        """Delete ``prefix + key`` from ``state_dict`` in place for each key present."""
        for key in keys:
            state_dict.pop(prefix + key, None)

    def load_checkpoint(self):
        """Restore state from the newest checkpoint in ``checkpoint_path``.

        The file is chosen by ``_latest_checkpoint_name``. A checkpoint with a
        ``"step"`` key restores the model, optimizer and (if built) scheduler. A
        checkpoint without one only has EMA weights: those initialise the model, and
        training restarts at step 0. The EMA model is restored on the main process in
        both cases.

        Returns:
            The global step to resume from, or 0 when there is nothing to load.
        """
        if not exists(self.checkpoint_path) or not os.path.isdir(self.checkpoint_path):
            return 0
        latest_checkpoint = self._latest_checkpoint_name(
            os.listdir(self.checkpoint_path)
        )
        if latest_checkpoint is None:
            return 0

        self.accelerator.wait_for_everyone()
        # rather use accelerator.load_state ಥ_ಥ
        checkpoint = torch.load(
            f"{self.checkpoint_path}/{latest_checkpoint}",
            weights_only=True,
            map_location="cpu",
        )

        self._drop_keys(
            checkpoint["ema_model_state_dict"], self._LEGACY_MEL_KEYS, prefix="ema_model."
        )
        if self.is_main:
            self.ema_model.load_state_dict(checkpoint["ema_model_state_dict"])

        unwrapped_model = self.accelerator.unwrap_model(self.model)
        if "step" in checkpoint:
            self._drop_keys(checkpoint["model_state_dict"], self._LEGACY_MEL_KEYS)
            unwrapped_model.load_state_dict(checkpoint["model_state_dict"])
            self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
            if (
                self.scheduler is not None
                and checkpoint.get("scheduler_state_dict") is not None
            ):
                self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
            step = checkpoint["step"]
        else:
            # Build model weights from the EMA weights: strip the "ema_model." prefix and
            # skip the EMA bookkeeping entries.
            model_state_dict = {
                k.replace("ema_model.", ""): v
                for k, v in checkpoint["ema_model_state_dict"].items()
                if k not in ["initted", "step"]
            }
            unwrapped_model.load_state_dict(model_state_dict)
            step = 0

        # A scale of None means the checkpoint was written before any training batch.
        if checkpoint.get("scale") is not None:
            self.scale = float(checkpoint["scale"])
            # Set on the unwrapped module. A DDP wrapper does not forward attributes
            # set on itself.
            unwrapped_model.scale = self.scale
        if "count" in checkpoint:
            self.count = int(checkpoint["count"])

        del checkpoint
        gc.collect()
        return step

    def _build_dataloader(self, train_dataset, num_workers, resumable_with_seed):
        """Create the training DataLoader for the configured ``batch_size_type``.

        ``"sample"``: shuffled batches of ``batch_size`` samples. The shuffle is seeded
        when ``resumable_with_seed`` is given.

        ``"frame"``: a ``SequentialSampler`` wrapped in ``DynamicBatchSampler``, which
        receives ``batch_size``, ``max_samples`` and ``resumable_with_seed`` as its
        random seed. Accelerate's ``even_batches`` is disabled for this mode.

        Raises:
            ValueError: if ``batch_size_type`` is neither ``"sample"`` nor ``"frame"``.
        """
        # DataLoader rejects persistent_workers=True when num_workers == 0.
        persistent_workers = num_workers > 0

        if self.batch_size_type == "sample":
            generator = None
            if exists(resumable_with_seed):
                generator = torch.Generator()
                generator.manual_seed(resumable_with_seed)
            return DataLoader(
                train_dataset,
                collate_fn=collate_fn,
                num_workers=num_workers,
                pin_memory=True,
                persistent_workers=persistent_workers,
                batch_size=self.batch_size,
                shuffle=True,
                generator=generator,
            )

        if self.batch_size_type == "frame":
            self.accelerator.even_batches = False
            batch_sampler = DynamicBatchSampler(
                SequentialSampler(train_dataset),
                self.batch_size,
                max_samples=self.max_samples,
                random_seed=resumable_with_seed,
                drop_last=False,
            )
            return DataLoader(
                train_dataset,
                collate_fn=collate_fn,
                num_workers=num_workers,
                pin_memory=True,
                persistent_workers=persistent_workers,
                batch_sampler=batch_sampler,
            )

        raise ValueError(
            f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}"
        )

    def _build_scheduler(self, num_batches):
        """Build a linear warmup followed by a linear decay schedule.

        ``num_batches`` is the length of the dataloader before ``accelerator.prepare``.
        """
        # accelerator.prepare() dispatches batches to devices, so the dataloader length
        # measured here must account for the number of devices. Multiplying warmup by
        # num_processes keeps it fixed under multi-GPU DDP. Otherwise, with the default
        # split_batches=False, the warmup length would change with num_processes.
        warmup_steps = self.num_warmup_updates * self.accelerator.num_processes
        total_steps = num_batches * self.epochs / self.grad_accumulation_steps
        decay_steps = total_steps - warmup_steps

        warmup_scheduler = LinearLR(
            self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps
        )
        decay_scheduler = LinearLR(
            self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps
        )
        return SequentialLR(
            self.optimizer,
            schedulers=[warmup_scheduler, decay_scheduler],
            milestones=[warmup_steps],
        )

    def _log_audio_samples(self, vocoder, cond, pred, t, global_step, sample_rate, samples_dir):
        """Vocode the first item's predicted and conditioning mels, then log them.

        With the wandb logger, the clips are sent to wandb as ``gen_audio`` and
        ``ref_audio``. With any other logger, they are written as WAV files into
        ``samples_dir``.

        Raises:
            ValueError: if ``self.vocoder_name`` is neither ``"vocos"`` nor ``"bigvgan"``.
        """
        # Multiply by scale to undo the normalisation applied to the training inputs.
        gen_mel_spec = pred[0].unsqueeze(0).permute(0, 2, 1) * self.scale
        ref_mel_spec = cond[0].unsqueeze(0).permute(0, 2, 1) * self.scale
        with torch.inference_mode():
            if self.vocoder_name == "vocos":
                gen_audio = vocoder.decode(gen_mel_spec).cpu()
                ref_audio = vocoder.decode(ref_mel_spec).cpu()
            elif self.vocoder_name == "bigvgan":
                gen_audio = vocoder(gen_mel_spec).squeeze(0).cpu()
                ref_audio = vocoder(ref_mel_spec).squeeze(0).cpu()
            else:
                raise ValueError(
                    f"Unsupported vocoder for sample logging: {self.vocoder_name}"
                )

        if self.logger == "wandb":
            caption = "time: " + str(t[0].squeeze().float().cpu().numpy())
            self.accelerator.log(
                {
                    "gen_audio": wandb.Audio(
                        gen_audio.float().numpy().squeeze(),
                        sample_rate=sample_rate,
                        caption=caption,
                    ),
                    "ref_audio": wandb.Audio(
                        ref_audio.float().numpy().squeeze(),
                        sample_rate=sample_rate,
                        caption=caption,
                    ),
                },
                step=global_step,
            )
        else:
            # NOTE(review): assumes the vocoder output here is (channels, samples), as
            # torchaudio.save expects. Confirm for both vocoders.
            torchaudio.save(
                f"{samples_dir}/step_{global_step}_gen.wav", gen_audio.float(), sample_rate
            )
            torchaudio.save(
                f"{samples_dir}/step_{global_step}_ref.wav", ref_audio.float(), sample_rate
            )

    def train(
        self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int | None = None
    ):
        """Train on ``train_dataset`` for ``self.epochs`` epochs, then save ``model_last.pt``.

        Training resumes from the newest checkpoint in ``checkpoint_path`` (see
        ``load_checkpoint``). ``global_step`` counts batches seen by this process.
        Accelerate performs one optimizer update every ``grad_accumulation_steps``
        batches.

        Args:
            train_dataset: dataset whose items are batched by ``collate_fn``. Batches
                must provide ``"text"``, ``"mel"`` and ``"mel_lengths"``.
            num_workers: DataLoader worker processes (0 loads data in the main process).
            resumable_with_seed: if given, seeds batch sampling. On resume, the batches
                of the interrupted epoch that were already consumed are skipped.
        """
        if self.log_samples:
            from f5_tts.infer.utils_infer import load_vocoder

            vocoder = load_vocoder(
                vocoder_name=self.vocoder_name,
                is_local=self.is_local_vocoder,
                local_path=self.local_vocoder_path,
            )
            target_sample_rate = self.accelerator.unwrap_model(
                self.model
            ).mel_spec.target_sample_rate
            log_samples_path = f"{self.checkpoint_path}/samples"
            os.makedirs(log_samples_path, exist_ok=True)

        train_dataloader = self._build_dataloader(
            train_dataset, num_workers, resumable_with_seed
        )
        self.scheduler = self._build_scheduler(len(train_dataloader))
        train_dataloader, self.scheduler = self.accelerator.prepare(
            train_dataloader, self.scheduler
        )  # actual steps = 1 gpu steps / gpus

        start_step = self.load_checkpoint()
        global_step = start_step

        if exists(resumable_with_seed):
            # Recover the epoch and the in-epoch offset from the resumed step count.
            orig_epoch_step = len(train_dataloader)
            skipped_epoch = int(start_step // orig_epoch_step)
            skipped_batch = start_step % orig_epoch_step
            skipped_dataloader = self.accelerator.skip_first_batches(
                train_dataloader, num_batches=skipped_batch
            )
        else:
            skipped_epoch = 0

        for epoch in range(skipped_epoch, self.epochs):
            self.model.train()
            if exists(resumable_with_seed) and epoch == skipped_epoch:
                # First epoch after a seeded resume: start the bar at the skipped offset.
                progress_bar = tqdm(
                    skipped_dataloader,
                    desc=f"Epoch {epoch+1}/{self.epochs}",
                    unit="step",
                    disable=not self.accelerator.is_local_main_process,
                    initial=skipped_batch,
                    total=orig_epoch_step,
                )
            else:
                progress_bar = tqdm(
                    train_dataloader,
                    desc=f"Epoch {epoch+1}/{self.epochs}",
                    unit="step",
                    disable=not self.accelerator.is_local_main_process,
                )

            for batch in progress_bar:
                with self.accelerator.accumulate(self.model):
                    text_inputs = batch["text"]
                    mel_spec = batch["mel"].permute(0, 2, 1)
                    mel_lengths = batch["mel_lengths"]

                    # Keep a running mean of the per-batch mel std and normalise by it.
                    # The value is saved in checkpoints as "scale".
                    self.count += 1
                    batch_std = mel_spec.std()
                    if self.scale is None:
                        self.scale = batch_std
                    else:
                        self.scale += (batch_std - self.scale) / self.count
                    mel_spec = mel_spec / self.scale

                    # TODO. add duration predictor training (its loss is only logged for now)
                    if (
                        self.duration_predictor is not None
                        and self.accelerator.is_local_main_process
                    ):
                        dur_loss = self.duration_predictor(
                            mel_spec, lens=batch.get("durations")
                        )
                        self.accelerator.log(
                            {"duration loss": dur_loss.item()}, step=global_step
                        )

                    loss, cond, pred, t = self.model(
                        mel_spec,
                        text=text_inputs,
                        lens=mel_lengths,
                        noise_scheduler=self.noise_scheduler,
                    )
                    self.accelerator.backward(loss)

                    if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
                        self.accelerator.clip_grad_norm_(
                            self.model.parameters(), self.max_grad_norm
                        )

                    self.optimizer.step()
                    self.scheduler.step()
                    self.optimizer.zero_grad()

                if self.is_main and self.accelerator.sync_gradients:
                    self.ema_model.update()

                global_step += 1
                loss_value = loss.item()
                if self.accelerator.is_local_main_process:
                    lr = self.scheduler.get_last_lr()[0]
                    self.accelerator.log(
                        {"loss": loss_value, "lr": lr}, step=global_step
                    )
                    if self.logger == "tensorboard":
                        self.writer.add_scalar("loss", loss_value, global_step)
                        self.writer.add_scalar("lr", lr, global_step)
                progress_bar.set_postfix(step=str(global_step), loss=loss_value)

                if (
                    global_step % (self.save_per_updates * self.grad_accumulation_steps)
                    == 0
                ):
                    self.save_checkpoint(global_step)
                    if self.log_samples and self.accelerator.is_local_main_process:
                        self._log_audio_samples(
                            vocoder,
                            cond,
                            pred,
                            t,
                            global_step,
                            sample_rate=target_sample_rate,
                            samples_dir=log_samples_path,
                        )

                if global_step % self.last_per_steps == 0:
                    self.save_checkpoint(global_step, last=True)

        self.save_checkpoint(global_step, last=True)
        if self.logger == "tensorboard":
            # Flush pending TensorBoard events and release the event file.
            self.writer.close()
        self.accelerator.end_training()