""" |
|
This script trains a ZipVoice model with the flow-matching loss. |
|
|
|
Usage: |
|
|
|
python3 -m zipvoice.bin.train_zipvoice \ |
|
--world-size 8 \ |
|
--use-fp16 1 \ |
|
--num-epochs 11 \ |
|
--max-duration 500 \ |
|
--lr-hours 30000 \ |
|
--model-config conf/zipvoice_base.json \ |
|
--tokenizer emilia \ |
|
--token-file "data/tokens_emilia.txt" \ |
|
--dataset emilia \ |
|
--manifest-dir data/fbank \ |
|
--exp-dir exp/zipvoice |
|
""" |

import argparse
import copy
import json
import logging
import os
from functools import partial
from pathlib import Path
from shutil import copyfile
from typing import List, Optional, Tuple, Union

import torch
import torch.multiprocessing as mp
import torch.nn as nn
from lhotse.cut import Cut, CutSet
from lhotse.utils import fix_random_seed
from torch import Tensor
from torch.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.utils.tensorboard import SummaryWriter

import zipvoice.utils.diagnostics as diagnostics
from zipvoice.dataset.datamodule import TtsDataModule
from zipvoice.models.zipvoice import ZipVoice
from zipvoice.tokenizer.tokenizer import (
    EmiliaTokenizer,
    EspeakTokenizer,
    LibriTTSTokenizer,
    SimpleTokenizer,
)
from zipvoice.utils.checkpoint import (
    load_checkpoint,
    remove_checkpoints,
    resume_checkpoint,
    save_checkpoint,
    save_checkpoint_with_global_batch_idx,
    update_averaged_model,
)
from zipvoice.utils.common import (
    AttributeDict,
    MetricsTracker,
    cleanup_dist,
    get_adjusted_batch_count,
    get_env_info,
    get_parameter_groups_with_lrs,
    prepare_input,
    set_batch_count,
    setup_dist,
    setup_logger,
    str2bool,
)
from zipvoice.utils.hooks import register_inf_check_hooks
from zipvoice.utils.lr_scheduler import Eden, FixedLRScheduler, LRScheduler
from zipvoice.utils.optim import ScaledAdam

LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, LRScheduler]


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--world-size",
        type=int,
        default=1,
        help="Number of GPUs for DDP training.",
    )

    parser.add_argument(
        "--master-port",
        type=int,
        default=12356,
        help="Master port to use for DDP training.",
    )

    parser.add_argument(
        "--tensorboard",
        type=str2bool,
        default=True,
        help="Whether to log various information to tensorboard.",
    )

    parser.add_argument(
        "--num-epochs",
        type=int,
        default=11,
        help="Number of epochs to train.",
    )

    parser.add_argument(
        "--num-iters",
        type=int,
        default=0,
        help="Number of iterations to train; if > 0, --num-epochs is ignored.",
    )

    parser.add_argument(
        "--start-epoch",
        type=int,
        default=1,
        help="""Resume training from this epoch. It should be positive.
        If larger than 1, it will load checkpoint from
        exp-dir/epoch-{start_epoch-1}.pt
        """,
    )

    parser.add_argument(
        "--checkpoint",
        type=str,
        default=None,
        help="""Checkpoint of a pre-trained model; it will be loaded if not None.
        """,
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="exp/zipvoice",
        help="""The experiment dir.
        It specifies the directory where all training related
        files, e.g., checkpoints, logs, etc, are saved.
        """,
    )

    parser.add_argument(
        "--base-lr", type=float, default=0.02, help="The base learning rate."
    )

    parser.add_argument(
        "--lr-batches",
        type=float,
        default=7500,
        help="""Number of steps that affects how rapidly the learning rate
        decreases. We suggest not to change this.""",
    )

    parser.add_argument(
        "--lr-epochs",
        type=float,
        default=10,
        help="""Number of epochs that affects how rapidly the learning rate decreases.
        """,
    )

    parser.add_argument(
        "--lr-hours",
        type=float,
        default=0,
        help="""If positive, --lr-epochs is ignored and this specifies the number of
        hours that affects how rapidly the learning rate decreases.
        """,
    )

    parser.add_argument(
        "--ref-duration",
        type=float,
        default=50,
        help="""Reference batch duration for purposes of adjusting batch counts for
        setting various schedules inside the model.
        """,
    )

    parser.add_argument(
        "--finetune",
        type=str2bool,
        default=False,
        help="Whether to use the fine-tuning mode, which will use a fixed learning "
        "rate schedule and skip the large-dropout phase.",
    )

    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="The seed for random generators, intended for reproducibility.",
    )

    parser.add_argument(
        "--print-diagnostics",
        type=str2bool,
        default=False,
        help="Accumulate stats on activations, print them and exit.",
    )

    parser.add_argument(
        "--scan-oom",
        type=str2bool,
        default=False,
        help="Scan pessimistic batches to see whether they cause OOMs.",
    )

    parser.add_argument(
        "--inf-check",
        type=str2bool,
        default=False,
        help="Add hooks to check for infinite module outputs and gradients.",
    )

    parser.add_argument(
        "--save-every-n",
        type=int,
        default=5000,
        help="""Save a checkpoint after processing this number of batches
        periodically. We save checkpoint to exp-dir/ whenever
        params.batch_idx_train % save_every_n == 0. The checkpoint filename
        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
        end of each epoch where `xxx` is the epoch number counting from 1.
        """,
    )

    parser.add_argument(
        "--keep-last-k",
        type=int,
        default=30,
        help="""Only keep this number of checkpoints on disk.
        For instance, if it is 3, there are only 3 checkpoints
        in the exp-dir with filenames `checkpoint-xxx.pt`.
        It does not affect checkpoints with name `epoch-xxx.pt`.
        """,
    )

    parser.add_argument(
        "--average-period",
        type=int,
        default=200,
        help="""Update the averaged model, namely `model_avg`, after processing
        this number of batches. `model_avg` is a separate version of model,
        in which each floating-point parameter is the average of all the
        parameters from the start of training. Each time we take the average,
        we do: `model_avg = model * (average_period / batch_idx_train) +
        model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
        """,
    )

    parser.add_argument(
        "--use-fp16",
        type=str2bool,
        default=True,
        help="Whether to use half precision training.",
    )

    parser.add_argument(
        "--feat-scale",
        type=float,
        default=0.1,
        help="The scale factor of the fbank features.",
    )

    parser.add_argument(
        "--condition-drop-ratio",
        type=float,
        default=0.2,
        help="The drop rate of the text condition during training.",
    )

    parser.add_argument(
        "--dataset",
        type=str,
        default="emilia",
        choices=["emilia", "libritts", "custom"],
        help="The training dataset to use.",
    )

    parser.add_argument(
        "--train-manifest",
        type=str,
        help="Path of the training manifest.",
    )

    parser.add_argument(
        "--dev-manifest",
        type=str,
        help="Path of the validation manifest.",
    )

    parser.add_argument(
        "--min-len",
        type=float,
        default=1.0,
        help="The minimum audio length (in seconds) used for training.",
    )

    parser.add_argument(
        "--max-len",
        type=float,
        default=30.0,
        help="The maximum audio length (in seconds) used for training.",
    )

    parser.add_argument(
        "--model-config",
        type=str,
        default="conf/zipvoice_base.json",
        help="The model configuration file.",
    )

    parser.add_argument(
        "--tokenizer",
        type=str,
        default="emilia",
        choices=["emilia", "libritts", "espeak", "simple"],
        help="Tokenizer type.",
    )

    parser.add_argument(
        "--lang",
        type=str,
        default="en-us",
        help="Language identifier, used when tokenizer type is espeak. See "
        "https://github.com/rhasspy/espeak-ng/blob/master/docs/languages.md",
    )

    parser.add_argument(
        "--token-file",
        type=str,
        default="data/tokens_emilia.txt",
        help="The file that maps tokens to ids, "
        "which is a text file with '{token}\t{token_id}' per line.",
    )

    return parser


def get_params() -> AttributeDict:
    """Return a dict containing training parameters.

    All training related parameters that are not passed from the commandline
    are saved in the variable `params`.

    Commandline options are merged into `params` after they are parsed, so
    you can also access them via `params`.

    Explanation of options saved in `params`:

    - best_train_loss: Best training loss so far. It is used to select
                       the model that has the lowest training loss. It is
                       updated during the training.

    - best_valid_loss: Best validation loss so far. It is used to select
                       the model that has the lowest validation loss. It is
                       updated during the training.

    - best_train_epoch: It is the epoch that has the best training loss.

    - best_valid_epoch: It is the epoch that has the best validation loss.

    - batch_idx_train: Used for writing statistics to tensorboard. It
                       contains the number of batches trained so far across
                       epochs.

    - log_interval: Print training loss if batch_idx % log_interval is 0.

    - reset_interval: Reset statistics if batch_idx % reset_interval is 0.

    - env_info: A dict containing information about the environment.
    """
    params = AttributeDict(
        {
            "best_train_loss": float("inf"),
            "best_valid_loss": float("inf"),
            "best_train_epoch": -1,
            "best_valid_epoch": -1,
            "batch_idx_train": 0,
            "log_interval": 50,
            "reset_interval": 200,
            "env_info": get_env_info(),
        }
    )

    return params


def compute_fbank_loss(
    params: AttributeDict,
    model: Union[nn.Module, DDP],
    features: Tensor,
    features_lens: Tensor,
    tokens: List[List[int]],
    is_training: bool,
) -> Tuple[Tensor, MetricsTracker]:
    """
    Compute the loss given the model and its inputs.

    Args:
      params:
        Parameters for training. See :func:`get_params`.
      model:
        The model for training.
      features:
        The target acoustic features.
      features_lens:
        The number of frames of each utterance.
      tokens:
        Input tokens representing the transcripts.
      is_training:
        True for training. False for validation. When it is True, this
        function enables autograd during computation; when it is False, it
        disables autograd.
    """
    device = model.device if isinstance(model, DDP) else next(model.parameters()).device

    batch_size, num_frames, _ = features.shape
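    # Note: num_frames is taken from features.shape above, so the pad below adds
    # zero frames; it appears to be kept only as a safeguard in case the target
    # length ever differs from the tensor's current length.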
    features = torch.nn.functional.pad(
        features, (0, 0, 0, num_frames - features.size(1))
    )
    noise = torch.randn_like(features)

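    # Flow-matching timestep t: during training it is sampled uniformly at random
    # per utterance, while for validation it is spread evenly over [0, 1) so the
    # validation loss does not fluctuate with random timesteps.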
    if is_training:
        t = torch.rand(batch_size, 1, 1, device=device)
    else:
        t = (
            (torch.arange(batch_size, device=device) / batch_size)
            .unsqueeze(1)
            .unsqueeze(2)
        )

    with torch.set_grad_enabled(is_training):
        loss = model(
            tokens=tokens,
            features=features,
            features_lens=features_lens,
            noise=noise,
            t=t,
            condition_drop_ratio=params.condition_drop_ratio,
        )

    assert loss.requires_grad == is_training
    info = MetricsTracker()
    num_frames = features_lens.sum().item()
    info["frames"] = num_frames
    info["loss"] = loss.detach().cpu().item() * num_frames

    return loss, info


def train_one_epoch(
    params: AttributeDict,
    model: Union[nn.Module, DDP],
    optimizer: Optimizer,
    scheduler: LRSchedulerType,
    train_dl: torch.utils.data.DataLoader,
    valid_dl: torch.utils.data.DataLoader,
    scaler: GradScaler,
    model_avg: Optional[nn.Module] = None,
    tb_writer: Optional[SummaryWriter] = None,
    world_size: int = 1,
    rank: int = 0,
) -> None:
    """Train the model for one epoch.

    The training loss from the mean of all frames is saved in
    `params.train_loss`. It runs the validation process every
    `params.valid_interval` batches.

    Args:
      params:
        It is returned by :func:`get_params`.
      model:
        The model for training.
      optimizer:
        The optimizer.
      scheduler:
        The learning rate scheduler; we call step() every epoch.
      train_dl:
        Dataloader for the training dataset.
      valid_dl:
        Dataloader for the validation dataset.
      scaler:
        The scaler used for mixed precision training.
      tb_writer:
        Writer to write log messages to tensorboard.
      world_size:
        Number of GPUs used in DDP training. If it is 1, DDP is disabled.
      rank:
        The rank of the process in DDP training. If no DDP is used, it should
        be set to 0.
    """
    model.train()
    device = model.device if isinstance(model, DDP) else next(model.parameters()).device

    tot_loss = MetricsTracker()

    saved_bad_model = False

    def save_bad_model(suffix: str = ""):
        save_checkpoint(
            filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
            model=model,
            model_avg=model_avg,
            params=params,
            optimizer=optimizer,
            scheduler=scheduler,
            sampler=train_dl.sampler,
            scaler=scaler,
            rank=0,
        )

    for batch_idx, batch in enumerate(train_dl):
        if batch_idx % 10 == 0:
            if params.finetune:
                set_batch_count(model, get_adjusted_batch_count(params) + 100000)
            else:
                set_batch_count(model, get_adjusted_batch_count(params))

        if (
            params.batch_idx_train > 0
            and params.batch_idx_train % params.valid_interval == 0
            and not params.print_diagnostics
        ):
            logging.info("Computing validation loss")
            valid_info = compute_validation_loss(
                params=params,
                model=model,
                valid_dl=valid_dl,
                world_size=world_size,
            )
            model.train()
            logging.info(
                f"Epoch {params.cur_epoch}, global_batch_idx: {params.batch_idx_train},"
                f" validation: {valid_info}"
            )
            logging.info(
                f"Maximum memory allocated so far is "
                f"{torch.cuda.max_memory_allocated() // 1000000}MB"
            )
            if tb_writer is not None:
                valid_info.write_summary(
                    tb_writer, "train/valid_", params.batch_idx_train
                )

        params.batch_idx_train += 1

        batch_size = len(batch["text"])

        tokens, features, features_lens = prepare_input(
            params=params,
            batch=batch,
            device=device,
            return_tokens=True,
            return_feature=True,
        )

        try:
            with autocast("cuda", enabled=params.use_fp16):
                loss, loss_info = compute_fbank_loss(
                    params=params,
                    model=model,
                    features=features,
                    features_lens=features_lens,
                    tokens=tokens,
                    is_training=True,
                )

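            # Keep a running, exponentially decayed total of the loss stats; the
            # effective averaging window is roughly `params.reset_interval` batches.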
            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info

            scaler.scale(loss).backward()

            scheduler.step_batch(params.batch_idx_train)

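            # When the schedule is defined in hours (--lr-hours > 0), approximate
            # the hours of audio seen so far as
            # batches * max_duration (seconds per batch) * world_size / 3600,
            # and step the "epoch" schedule with that value.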
            if params.lr_hours > 0:
                scheduler.step_epoch(
                    params.batch_idx_train
                    * params.max_duration
                    * params.world_size
                    / 3600
                )
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
        except Exception as e:
            logging.info(f"Caught exception : {e}.")
            save_bad_model()
            raise

        if params.print_diagnostics and batch_idx == 5:
            return

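        # model_avg keeps a running average of the parameters since the start of
        # training (see the --average-period help); only rank 0 maintains it.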
        if (
            rank == 0
            and params.batch_idx_train > 0
            and params.batch_idx_train % params.average_period == 0
        ):
            update_averaged_model(
                params=params,
                model_cur=model,
                model_avg=model_avg,
            )

        if (
            params.batch_idx_train > 0
            and params.batch_idx_train % params.save_every_n == 0
        ):
            save_checkpoint_with_global_batch_idx(
                out_dir=params.exp_dir,
                global_batch_idx=params.batch_idx_train,
                model=model,
                model_avg=model_avg,
                params=params,
                optimizer=optimizer,
                scheduler=scheduler,
                sampler=train_dl.sampler,
                scaler=scaler,
                rank=rank,
            )
            remove_checkpoints(
                out_dir=params.exp_dir,
                topk=params.keep_last_k,
                rank=rank,
            )

        if params.num_iters > 0 and params.batch_idx_train > params.num_iters:
            break

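        # With fp16, the GradScaler's scale is nudged upwards periodically; if it
        # collapses towards zero (a sign of repeated inf/nan gradients), the bad
        # batch is saved and training is aborted.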
        if params.batch_idx_train % 100 == 0 and params.use_fp16:
            cur_grad_scale = scaler._scale.item()

            if cur_grad_scale < 1024.0 or (
                cur_grad_scale < 4096.0 and params.batch_idx_train % 400 == 0
            ):
                scaler.update(cur_grad_scale * 2.0)
            if cur_grad_scale < 0.01:
                if not saved_bad_model:
                    save_bad_model(suffix="-first-warning")
                    saved_bad_model = True
                logging.warning(f"Grad scale is small: {cur_grad_scale}")
            if cur_grad_scale < 1.0e-05:
                save_bad_model()
                raise RuntimeError(
                    f"grad_scale is too small, exiting: {cur_grad_scale}"
                )

        if params.batch_idx_train % params.log_interval == 0:
            cur_lr = max(scheduler.get_last_lr())
            cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0

            logging.info(
                f"Epoch {params.cur_epoch}, batch {batch_idx}, "
                f"global_batch_idx: {params.batch_idx_train}, "
                f"batch size: {batch_size}, "
                f"loss[{loss_info}], tot_loss[{tot_loss}], "
                f"cur_lr: {cur_lr:.2e}, "
                + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
            )

            if tb_writer is not None:
                tb_writer.add_scalar(
                    "train/learning_rate", cur_lr, params.batch_idx_train
                )
                loss_info.write_summary(
                    tb_writer, "train/current_", params.batch_idx_train
                )
                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
                if params.use_fp16:
                    tb_writer.add_scalar(
                        "train/grad_scale",
                        cur_grad_scale,
                        params.batch_idx_train,
                    )

    loss_value = tot_loss["loss"]
    params.train_loss = loss_value
    if params.train_loss < params.best_train_loss:
        params.best_train_epoch = params.cur_epoch
        params.best_train_loss = params.train_loss


def compute_validation_loss(
    params: AttributeDict,
    model: Union[nn.Module, DDP],
    valid_dl: torch.utils.data.DataLoader,
    world_size: int = 1,
) -> MetricsTracker:
    """Run the validation process."""
    model.eval()
    device = model.device if isinstance(model, DDP) else next(model.parameters()).device

    tot_loss = MetricsTracker()

    for batch_idx, batch in enumerate(valid_dl):
        tokens, features, features_lens = prepare_input(
            params=params,
            batch=batch,
            device=device,
            return_tokens=True,
            return_feature=True,
        )

        loss, loss_info = compute_fbank_loss(
            params=params,
            model=model,
            features=features,
            features_lens=features_lens,
            tokens=tokens,
            is_training=False,
        )
        assert loss.requires_grad is False
        tot_loss = tot_loss + loss_info

    if world_size > 1:
        tot_loss.reduce(loss.device)

    loss_value = tot_loss["loss"]
    if loss_value < params.best_valid_loss:
        params.best_valid_epoch = params.cur_epoch
        params.best_valid_loss = loss_value

    return tot_loss


def display_and_save_batch(
    batch: dict,
    params: AttributeDict,
) -> None:
    """Display the batch statistics and save the batch into disk.

    Args:
      batch:
        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
        for the content in it.
      params:
        Parameters for training. See :func:`get_params`.
    """
    from lhotse.utils import uuid4

    filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
    logging.info(f"Saving batch to {filename}")
    torch.save(batch, filename)

    features = batch["features"]
    tokens = batch["tokens"]

    logging.info(f"features shape: {features.shape}")
    num_tokens = sum(len(i) for i in tokens)
    logging.info(f"num tokens: {num_tokens}")


def scan_pessimistic_batches_for_oom(
    model: Union[nn.Module, DDP],
    train_dl: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    params: AttributeDict,
):
    from lhotse.dataset import find_pessimistic_batches

    logging.info(
        "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
    )
    device = model.device if isinstance(model, DDP) else next(model.parameters()).device

    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
    for criterion, cuts in batches.items():
        batch = train_dl.dataset[cuts]
        tokens, features, features_lens = prepare_input(
            params=params,
            batch=batch,
            device=device,
            return_tokens=True,
            return_feature=True,
        )
        try:
            with autocast("cuda", enabled=params.use_fp16):
                loss, loss_info = compute_fbank_loss(
                    params=params,
                    model=model,
                    features=features,
                    features_lens=features_lens,
                    tokens=tokens,
                    is_training=True,
                )
            loss.backward()
            optimizer.zero_grad()
        except Exception as e:
            if "CUDA out of memory" in str(e):
                logging.error(
                    "Your GPU ran out of memory with the current "
                    "max_duration setting. We recommend decreasing "
                    "max_duration and trying again.\n"
                    f"Failing criterion: {criterion} "
                    f"(={crit_values[criterion]}) ..."
                )
            display_and_save_batch(batch, params=params)
            raise
        logging.info(
            f"Maximum memory allocated so far is "
            f"{torch.cuda.max_memory_allocated() // 1000000}MB"
        )


def tokenize_text(c: Cut, tokenizer):
    text = c.supervisions[0].text
    tokens = tokenizer.texts_to_token_ids([text])
    c.supervisions[0].tokens = tokens[0]
    return c


def run(rank, world_size, args):
    """
    Args:
      rank:
        It is a value between 0 and `world_size-1`, which is
        passed automatically by `mp.spawn()` in :func:`main`.
        The node with rank 0 is responsible for saving checkpoints.
      world_size:
        Number of GPUs for DDP training.
      args:
        The return value of get_parser().parse_args()
    """
    params = get_params()
    params.update(vars(args))
    params.valid_interval = params.save_every_n

    if params.num_iters > 0:
        params.num_epochs = 1000000
    with open(params.model_config, "r") as f:
        model_config = json.load(f)
        params.update(model_config["model"])
        params.update(model_config["feature"])

    fix_random_seed(params.seed)
    if world_size > 1:
        setup_dist(rank, world_size, params.master_port)

    os.makedirs(f"{params.exp_dir}", exist_ok=True)
    copyfile(src=params.model_config, dst=f"{params.exp_dir}/model.json")
    copyfile(src=params.token_file, dst=f"{params.exp_dir}/tokens.txt")
    setup_logger(f"{params.exp_dir}/log/log-train")

    if args.tensorboard and rank == 0:
        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
    else:
        tb_writer = None

    if torch.cuda.is_available():
        params.device = torch.device("cuda", rank)
    else:
        params.device = torch.device("cpu")
    logging.info(f"Device: {params.device}")

    if params.tokenizer == "emilia":
        tokenizer = EmiliaTokenizer(token_file=params.token_file)
    elif params.tokenizer == "libritts":
        tokenizer = LibriTTSTokenizer(token_file=params.token_file)
    elif params.tokenizer == "espeak":
        tokenizer = EspeakTokenizer(token_file=params.token_file, lang=params.lang)
    else:
        assert params.tokenizer == "simple"
        tokenizer = SimpleTokenizer(token_file=params.token_file)

    tokenizer_config = {"vocab_size": tokenizer.vocab_size, "pad_id": tokenizer.pad_id}
    params.update(tokenizer_config)

    logging.info(params)

    logging.info("About to create model")

    model = ZipVoice(
        **model_config["model"],
        **tokenizer_config,
    )

    if params.checkpoint is not None:
        logging.info(f"Loading pre-trained model from {params.checkpoint}")
        _ = load_checkpoint(filename=params.checkpoint, model=model, strict=True)
    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of parameters : {num_param}")

    model_avg: Optional[nn.Module] = None
    if rank == 0:
        model_avg = copy.deepcopy(model).to(torch.float64)

    assert params.start_epoch > 0, params.start_epoch
    if params.start_epoch > 1:
        checkpoints = resume_checkpoint(params=params, model=model, model_avg=model_avg)

    model = model.to(params.device)
    if world_size > 1:
        logging.info("Using DDP")
        model = DDP(model, device_ids=[rank], find_unused_parameters=True)

    optimizer = ScaledAdam(
        get_parameter_groups_with_lrs(
            model,
            lr=params.base_lr,
            include_names=True,
        ),
        lr=params.base_lr,
        clipping_scale=2.0,
    )

    assert params.lr_hours >= 0

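    # Learning-rate schedule: fine-tuning uses a fixed LR; otherwise Eden decays
    # the LR as a function of batches and either hours of audio (--lr-hours > 0)
    # or epochs.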
    if params.finetune:
        scheduler = FixedLRScheduler(optimizer)
    elif params.lr_hours > 0:
        scheduler = Eden(optimizer, params.lr_batches, params.lr_hours)
    else:
        scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)

    scaler = GradScaler("cuda", enabled=params.use_fp16)

    if params.start_epoch > 1 and checkpoints is not None:
        if "optimizer" in checkpoints:
            logging.info("Loading optimizer state dict")
            optimizer.load_state_dict(checkpoints["optimizer"])

        if "scheduler" in checkpoints:
            logging.info("Loading scheduler state dict")
            scheduler.load_state_dict(checkpoints["scheduler"])

        if "grad_scaler" in checkpoints:
            logging.info("Loading grad scaler state dict")
            scaler.load_state_dict(checkpoints["grad_scaler"])

    if params.print_diagnostics:
        opts = diagnostics.TensorDiagnosticOptions(512)
        diagnostic = diagnostics.attach_diagnostics(model, opts)

    if params.inf_check:
        register_inf_check_hooks(model)

    def remove_short_and_long_utt(c: Cut, min_len: float, max_len: float):
        if c.duration < min_len or c.duration > max_len:
            return False
        return True

    _remove_short_and_long_utt = partial(
        remove_short_and_long_utt, min_len=params.min_len, max_len=params.max_len
    )

    datamodule = TtsDataModule(args)
    if params.dataset == "emilia":
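        # CutSet.mux interleaves the English and Chinese Emilia subsets; the
        # weights set their relative sampling frequency (presumably proportional
        # to the subset sizes, e.g. hours of audio).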
        train_cuts = CutSet.mux(
            datamodule.train_emilia_EN_cuts(),
            datamodule.train_emilia_ZH_cuts(),
            weights=[46000, 49000],
        )
        train_cuts = train_cuts.filter(_remove_short_and_long_utt)
        dev_cuts = CutSet.mux(
            datamodule.dev_emilia_EN_cuts(),
            datamodule.dev_emilia_ZH_cuts(),
            weights=[0.5, 0.5],
        )
    elif params.dataset == "libritts":
        train_cuts = datamodule.train_libritts_cuts()
        train_cuts = train_cuts.filter(_remove_short_and_long_utt)
        dev_cuts = datamodule.dev_libritts_cuts()
    else:
        assert params.dataset == "custom"
        train_cuts = datamodule.train_custom_cuts(params.train_manifest)
        train_cuts = train_cuts.filter(_remove_short_and_long_utt)
        dev_cuts = datamodule.dev_custom_cuts(params.dev_manifest)

    dev_cuts = dev_cuts.filter(_remove_short_and_long_utt)

    _tokenize_text = partial(tokenize_text, tokenizer=tokenizer)
    train_cuts = train_cuts.map(_tokenize_text)
    dev_cuts = dev_cuts.map(_tokenize_text)

    train_dl = datamodule.train_dataloaders(train_cuts)

    valid_dl = datamodule.dev_dataloaders(dev_cuts)

    if params.scan_oom:
        scan_pessimistic_batches_for_oom(
            model=model,
            train_dl=train_dl,
            optimizer=optimizer,
            params=params,
        )

    logging.info("Training started")

    for epoch in range(params.start_epoch, params.num_epochs + 1):
        logging.info(f"Start epoch {epoch}")

        if params.lr_hours == 0:
            scheduler.step_epoch(epoch - 1)
        fix_random_seed(params.seed + epoch - 1)
        train_dl.sampler.set_epoch(epoch - 1)

        params.cur_epoch = epoch

        if tb_writer is not None:
            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)

        train_one_epoch(
            params=params,
            model=model,
            model_avg=model_avg,
            optimizer=optimizer,
            scheduler=scheduler,
            train_dl=train_dl,
            valid_dl=valid_dl,
            scaler=scaler,
            tb_writer=tb_writer,
            world_size=world_size,
            rank=rank,
        )

        if params.num_iters > 0 and params.batch_idx_train > params.num_iters:
            break

        if params.print_diagnostics:
            diagnostic.print_diagnostics()
            break

        filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
        save_checkpoint(
            filename=filename,
            params=params,
            model=model,
            model_avg=model_avg,
            optimizer=optimizer,
            scheduler=scheduler,
            sampler=train_dl.sampler,
            scaler=scaler,
            rank=rank,
        )

        if rank == 0:
            if params.best_train_epoch == params.cur_epoch:
                best_train_filename = params.exp_dir / "best-train-loss.pt"
                copyfile(src=filename, dst=best_train_filename)

            if params.best_valid_epoch == params.cur_epoch:
                best_valid_filename = params.exp_dir / "best-valid-loss.pt"
                copyfile(src=filename, dst=best_valid_filename)

    logging.info("Done!")

    if world_size > 1:
        torch.distributed.barrier()
        cleanup_dist()


def main():
    parser = get_parser()
    TtsDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    world_size = args.world_size
    assert world_size >= 1
    if world_size > 1:
        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
    else:
        run(rank=0, world_size=1, args=args)


if __name__ == "__main__":
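    # Limit intra-op and inter-op threads to one per process to avoid CPU
    # oversubscription when several DDP processes and dataloader workers run
    # on the same machine.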
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
    main()