from __future__ import annotations

import gc
import math
import os

import torch
import torch.nn as nn
import wandb
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR, SequentialLR
from torch.utils.data import DataLoader, Dataset, SequentialSampler
from tqdm import tqdm

from f5_tts.model import CFM
from f5_tts.model.dataset import DynamicBatchSampler, collate_fn
from f5_tts.model.utils import default, exists
from unimodel import UniModel

# trainer

class RunningStats:
    """Track a running mean and variance with Welford's online algorithm."""

    def __init__(self):
        self.count = 0
        self.mean = 0.0
        self.M2 = 0.0  # sum of squared differences from the current mean

    def update(self, x):
        """Update the running statistics with a new value x."""
        self.count += 1
        delta = x - self.mean
        self.mean += delta / self.count
        delta2 = x - self.mean
        self.M2 += delta * delta2

    def variance(self):
        """Return the sample variance. Returns NaN if fewer than two samples."""
        return self.M2 / (self.count - 1) if self.count > 1 else float("nan")

    def std(self):
        """Return the sample standard deviation."""
        return math.sqrt(self.variance())
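
# Minimal usage sketch for RunningStats (illustrative only; the commented-out
# gradient-norm guards further below follow this pattern, and the values here
# are made up):
#
#     stats = RunningStats()
#     for norm in (0.9, 1.1, 1.0):
#         stats.update(norm)
#     stats.mean   # 1.0
#     stats.std()  # ~0.1, sample std over the three values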

class Trainer:
    def __init__(
        self,
        model: UniModel,
        epochs,
        learning_rate,
        num_warmup_updates=20000,
        save_per_updates=1000,
        checkpoint_path=None,
        batch_size=32,
        batch_size_type: str = "sample",
        max_samples=32,
        grad_accumulation_steps=1,
        max_grad_norm=1.0,
        noise_scheduler: str | None = None,
        duration_predictor: torch.nn.Module | None = None,
        wandb_project="test_e2-tts",
        wandb_run_name="test_run",
        wandb_resume_id: str | None = None,
        last_per_steps=None,
        log_step=1000,
        accelerate_kwargs: dict = dict(),
        bnb_optimizer: bool = False,
        scale: float = 1.0,
        # training parameters for DMDSpeech
        num_student_step: int = 1,
        gen_update_ratio: int = 5,
        lambda_discriminator_loss: float = 1.0,
        lambda_generator_loss: float = 1.0,
        lambda_ctc_loss: float = 1.0,
        lambda_sim_loss: float = 1.0,
        num_GAN: int = 5000,
        num_D: int = 500,
        num_ctc: int = 5000,
        num_sim: int = 10000,
        num_simu: int = 1000,
    ):
        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
        logger = "wandb" if wandb.api.api_key else None
        print(f"Using logger: {logger}")

        self.accelerator = Accelerator(
            log_with=logger,
            kwargs_handlers=[ddp_kwargs],
            gradient_accumulation_steps=grad_accumulation_steps,
            **accelerate_kwargs,
        )

        if logger == "wandb":
            if exists(wandb_resume_id):
                init_kwargs = {
                    "wandb": {
                        "resume": "allow",
                        "name": wandb_run_name,
                        "id": wandb_resume_id,
                    }
                }
            else:
                init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
            self.accelerator.init_trackers(
                project_name=wandb_project,
                init_kwargs=init_kwargs,
                config={
                    "epochs": epochs,
                    "learning_rate": learning_rate,
                    "num_warmup_updates": num_warmup_updates,
                    "batch_size": batch_size,
                    "batch_size_type": batch_size_type,
                    "max_samples": max_samples,
                    "grad_accumulation_steps": grad_accumulation_steps,
                    "max_grad_norm": max_grad_norm,
                    "gpus": self.accelerator.num_processes,
                    "noise_scheduler": noise_scheduler,
                },
            )
        self.model = model
        self.scale = scale
        self.epochs = epochs
        self.num_warmup_updates = num_warmup_updates
        self.save_per_updates = save_per_updates
        self.last_per_steps = default(
            last_per_steps, save_per_updates * grad_accumulation_steps
        )
        self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts")

        self.batch_size = batch_size
        self.batch_size_type = batch_size_type
        self.max_samples = max_samples
        self.grad_accumulation_steps = grad_accumulation_steps
        self.max_grad_norm = max_grad_norm
        self.noise_scheduler = noise_scheduler

        self.duration_predictor = duration_predictor
        self.log_step = log_step

        # number of generator updates per guidance (fake score function and
        # discriminator) update
        self.gen_update_ratio = gen_update_ratio
        self.lambda_discriminator_loss = (
            lambda_discriminator_loss  # weight for discriminator loss (L_adv)
        )
        self.lambda_generator_loss = (
            lambda_generator_loss  # weight for generator loss (L_adv)
        )
        self.lambda_ctc_loss = lambda_ctc_loss  # weight for CTC loss
        self.lambda_sim_loss = lambda_sim_loss  # weight for similarity loss

        # create distillation schedule for student model
        self.student_steps = torch.linspace(0.0, 1.0, num_student_step + 1)[:-1]
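        # e.g. num_student_step=4 gives tensor([0.00, 0.25, 0.50, 0.75]): the
        # student's timesteps are evenly spaced in [0, 1) with t=1.0 excluded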
        self.GAN = model.guidance_model.gen_cls_loss  # whether to use GAN training
        self.num_GAN = num_GAN  # steps before adversarial training starts
        self.num_D = num_D  # discriminator-only warmup steps before the generator's adversarial loss kicks in
        self.num_ctc = num_ctc  # steps before CTC training
        self.num_sim = num_sim  # steps before similarity training
        self.num_simu = num_simu  # steps before using simulated data

        # separate optimizers: one for the student generator
        # (`self.model.feedforward_model`) and one for the guidance model
        # (fake score function plus discriminator)
        if bnb_optimizer:
            import bitsandbytes as bnb

            self.optimizer_generator = bnb.optim.AdamW8bit(
                self.model.feedforward_model.parameters(), lr=learning_rate
            )
            self.optimizer_guidance = bnb.optim.AdamW8bit(
                self.model.guidance_model.parameters(), lr=learning_rate
            )
        else:
            self.optimizer_generator = AdamW(
                self.model.feedforward_model.parameters(), lr=learning_rate, eps=1e-7
            )
            self.optimizer_guidance = AdamW(
                self.model.guidance_model.parameters(), lr=learning_rate, eps=1e-7
            )
        self.model, self.optimizer_generator, self.optimizer_guidance = (
            self.accelerator.prepare(
                self.model, self.optimizer_generator, self.optimizer_guidance
            )
        )

        self.generator_norm = RunningStats()
        self.guidance_norm = RunningStats()

    @property
    def is_main(self):
        return self.accelerator.is_main_process

    def save_checkpoint(self, step, last=False):
        self.accelerator.wait_for_everyone()
        if self.is_main:
            checkpoint = dict(
                model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(),
                optimizer_generator_state_dict=self.accelerator.unwrap_model(
                    self.optimizer_generator
                ).state_dict(),
                optimizer_guidance_state_dict=self.accelerator.unwrap_model(
                    self.optimizer_guidance
                ).state_dict(),
                scheduler_generator_state_dict=self.scheduler_generator.state_dict(),
                scheduler_guidance_state_dict=self.scheduler_guidance.state_dict(),
                step=step,
            )
            if not os.path.exists(self.checkpoint_path):
                os.makedirs(self.checkpoint_path)
            if last:
                self.accelerator.save(
                    checkpoint, f"{self.checkpoint_path}/model_last.pt"
                )
                print(f"Saved last checkpoint at step {step}")
            else:
                self.accelerator.save(
                    checkpoint, f"{self.checkpoint_path}/model_{step}.pt"
                )

    def load_checkpoint(self):
        if (
            not exists(self.checkpoint_path)
            or not os.path.exists(self.checkpoint_path)
            or not os.listdir(self.checkpoint_path)
        ):
            return 0

        self.accelerator.wait_for_everyone()
        if "model_last.pt" in os.listdir(self.checkpoint_path):
            latest_checkpoint = "model_last.pt"
        else:
            latest_checkpoint = sorted(
                [f for f in os.listdir(self.checkpoint_path) if f.endswith(".pt")],
                key=lambda x: int("".join(filter(str.isdigit, x))),
            )[-1]
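            # the sort key strips non-digits, e.g. "model_4000.pt" -> 4000, so
            # numbered checkpoints order by training step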
        # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device)  # rather use accelerator.load_state ಥ_ಥ
        checkpoint = torch.load(
            f"{self.checkpoint_path}/{latest_checkpoint}",
            weights_only=True,
            map_location="cpu",
        )
        self.accelerator.unwrap_model(self.model).load_state_dict(
            checkpoint["model_state_dict"], strict=False
        )
        # self.accelerator.unwrap_model(self.optimizer_generator).load_state_dict(checkpoint["optimizer_generator_state_dict"])
        # self.accelerator.unwrap_model(self.optimizer_guidance).load_state_dict(checkpoint["optimizer_guidance_state_dict"])
        # if self.scheduler_guidance:
        #     self.scheduler_guidance.load_state_dict(checkpoint["scheduler_guidance_state_dict"])
        # if self.scheduler_generator:
        #     self.scheduler_generator.load_state_dict(checkpoint["scheduler_generator_state_dict"])
        step = checkpoint["step"]

        del checkpoint
        gc.collect()
        return step

    def train(
        self,
        train_dataset: Dataset,
        num_workers=64,
        resumable_with_seed: int | None = None,
        vocoder: nn.Module | None = None,
    ):
        if exists(resumable_with_seed):
            generator = torch.Generator()
            generator.manual_seed(resumable_with_seed)
        else:
            generator = None

        if self.batch_size_type == "sample":
            train_dataloader = DataLoader(
                train_dataset,
                collate_fn=collate_fn,
                num_workers=num_workers,
                pin_memory=True,
                persistent_workers=True,
                batch_size=self.batch_size,
                shuffle=True,
                generator=generator,
            )
        elif self.batch_size_type == "frame":
            self.accelerator.even_batches = False
            sampler = SequentialSampler(train_dataset)
            batch_sampler = DynamicBatchSampler(
                sampler,
                self.batch_size,
                max_samples=self.max_samples,
                random_seed=resumable_with_seed,
                drop_last=False,
            )
            train_dataloader = DataLoader(
                train_dataset,
                collate_fn=collate_fn,
                num_workers=num_workers,
                pin_memory=True,
                persistent_workers=True,
                batch_sampler=batch_sampler,
            )
        else:
            raise ValueError(
                f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}"
            )

        # accelerator.prepare() dispatches batches across devices, so the
        # dataloader length computed above must account for the number of
        # processes. Scaling by num_processes keeps the effective warmup fixed
        # under multi-GPU DDP; otherwise, with split_batches=False, the warmup
        # duration would shrink as num_processes grows.
        warmup_steps = self.num_warmup_updates * self.accelerator.num_processes
        total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps
        decay_steps = total_steps - warmup_steps
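        # worked example (assumed numbers, not from this repo): with
        # num_warmup_updates=20000 on 4 GPUs, warmup_steps = 80000; for a
        # 1,000,000-batch dataloader, 1 epoch, and grad_accumulation_steps=1,
        # total_steps = 1e6 and decay_steps = 920000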

        # the generator only steps every gen_update_ratio-th iteration, so its
        # schedulers advance at 1/(gen_update_ratio * grad_accumulation_steps)
        # of the raw step count
        warmup_scheduler_generator = LinearLR(
            self.optimizer_generator,
            start_factor=1e-8,
            end_factor=1.0,
            total_iters=warmup_steps
            // (self.gen_update_ratio * self.grad_accumulation_steps),
        )
        decay_scheduler_generator = LinearLR(
            self.optimizer_generator,
            start_factor=1.0,
            end_factor=1e-8,
            total_iters=decay_steps
            // (self.gen_update_ratio * self.grad_accumulation_steps),
        )
        self.scheduler_generator = SequentialLR(
            self.optimizer_generator,
            schedulers=[warmup_scheduler_generator, decay_scheduler_generator],
            milestones=[
                warmup_steps // (self.gen_update_ratio * self.grad_accumulation_steps)
            ],
        )
        warmup_scheduler_guidance = LinearLR(
            self.optimizer_guidance,
            start_factor=1e-8,
            end_factor=1.0,
            total_iters=warmup_steps,
        )
        decay_scheduler_guidance = LinearLR(
            self.optimizer_guidance,
            start_factor=1.0,
            end_factor=1e-8,
            total_iters=decay_steps,
        )
        self.scheduler_guidance = SequentialLR(
            self.optimizer_guidance,
            schedulers=[warmup_scheduler_guidance, decay_scheduler_guidance],
            milestones=[warmup_steps],
        )
        train_dataloader, self.scheduler_generator, self.scheduler_guidance = (
            self.accelerator.prepare(
                train_dataloader, self.scheduler_generator, self.scheduler_guidance
            )
        )  # after prepare(), each process sees len(dataloader) / num_gpus batches
        start_step = self.load_checkpoint()
        global_step = start_step

        if exists(resumable_with_seed):
            orig_epoch_step = len(train_dataloader)
            skipped_epoch = int(start_step // orig_epoch_step)
            skipped_batch = start_step % orig_epoch_step
            skipped_dataloader = self.accelerator.skip_first_batches(
                train_dataloader, num_batches=skipped_batch
            )
        else:
            skipped_epoch = 0

        for epoch in range(skipped_epoch, self.epochs):
            self.model.train()
            if exists(resumable_with_seed) and epoch == skipped_epoch:
                progress_bar = tqdm(
                    skipped_dataloader,
                    desc=f"Epoch {epoch + 1}/{self.epochs}",
                    unit="step",
                    disable=not self.accelerator.is_local_main_process,
                    initial=skipped_batch,
                    total=orig_epoch_step,
                )
            else:
                progress_bar = tqdm(
                    train_dataloader,
                    desc=f"Epoch {epoch + 1}/{self.epochs}",
                    unit="step",
                    disable=not self.accelerator.is_local_main_process,
                )

            for batch in progress_bar:
                update_generator = global_step % self.gen_update_ratio == 0
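                # with the default gen_update_ratio=5, the guidance model is
                # trained every step while the generator updates only on steps
                # 0, 5, 10, ...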
                with self.accelerator.accumulate(self.model):
                    metrics = {}
                    text_inputs = batch["text"]
                    mel_spec = batch["mel"].permute(0, 2, 1)
                    mel_lengths = batch["mel_lengths"]
                    mel_spec = mel_spec / self.scale

                    guidance_loss_dict, guidance_log_dict = self.model(
                        inp=mel_spec,
                        text=text_inputs,
                        lens=mel_lengths,
                        student_steps=self.student_steps,
                        update_generator=False,
                        use_simulated=global_step >= self.num_simu,
                    )
                    # if self.GAN and update_generator:
                    #     # only add discriminator loss if GAN is enabled and generator is being updated
                    #     guidance_cls_loss = guidance_loss_dict["guidance_cls_loss"] * (self.lambda_discriminator_loss if global_step >= self.num_GAN and update_generator else 0)
                    #     metrics["loss/discriminator_loss"] = guidance_loss_dict["guidance_cls_loss"]
                    #     self.accelerator.backward(guidance_cls_loss, retain_graph=True)
                    #     if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
                    #         metrics["grad_norm_guidance"] = self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
                    guidance_loss = 0
                    guidance_loss += guidance_loss_dict["loss_fake_mean"]
                    metrics["loss/fake_score"] = guidance_loss_dict["loss_fake_mean"]
                    metrics["loss/guidance_loss"] = guidance_loss

                    if self.GAN and update_generator:
                        # only add the discriminator loss if GAN training is
                        # enabled and the generator is being updated this step
                        guidance_cls_loss = guidance_loss_dict["guidance_cls_loss"] * (
                            self.lambda_discriminator_loss
                            if global_step >= self.num_GAN and update_generator
                            else 0
                        )
                        metrics["loss/discriminator_loss"] = guidance_loss_dict[
                            "guidance_cls_loss"
                        ]
                        guidance_loss += guidance_cls_loss

                    self.accelerator.backward(guidance_loss)
                    if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
                        metrics["grad_norm_guidance"] = self.accelerator.clip_grad_norm_(
                            self.model.parameters(), self.max_grad_norm
                        )
                    # if self.guidance_norm.count < 100:
                    #     self.guidance_norm.update(metrics["grad_norm_guidance"])
                    # if metrics["grad_norm_guidance"] > self.guidance_norm.mean + 5 * self.guidance_norm.std():
                    #     self.optimizer_generator.zero_grad()
                    #     self.optimizer_guidance.zero_grad()
                    #     print("Gradient explosion detected. Skipping batch.")
                    # elif self.guidance_norm.count >= 100:
                    #     self.guidance_norm.update(metrics["grad_norm_guidance"])

                    self.optimizer_guidance.step()
                    self.scheduler_guidance.step()
                    self.optimizer_guidance.zero_grad()
                    self.optimizer_generator.zero_grad()  # zero out the generator's gradients as well

                    if update_generator:
                        generator_loss_dict, generator_log_dict = self.model(
                            inp=mel_spec,
                            text=text_inputs,
                            lens=mel_lengths,
                            student_steps=self.student_steps,
                            update_generator=True,
                            use_simulated=global_step >= self.num_ctc,
                        )
                        # if self.GAN:
                        #     gen_cls_loss = generator_loss_dict["gen_cls_loss"] * (self.lambda_generator_loss if global_step >= (self.num_GAN + self.num_D) and update_generator else 0)
                        #     metrics["loss/gen_cls_loss"] = generator_loss_dict["gen_cls_loss"]
                        #     self.accelerator.backward(gen_cls_loss, retain_graph=True)
                        #     if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
                        #         metrics["grad_norm_generator"] = self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
                        generator_loss = 0
                        generator_loss += generator_loss_dict["loss_dm"]
                        if "loss_mse" in generator_loss_dict:
                            generator_loss += generator_loss_dict["loss_mse"]
                        generator_loss += generator_loss_dict["loss_ctc"] * (
                            self.lambda_ctc_loss if global_step >= self.num_ctc else 0
                        )
                        generator_loss += generator_loss_dict["loss_sim"] * (
                            self.lambda_sim_loss if global_step >= self.num_sim else 0
                        )
                        generator_loss += generator_loss_dict["loss_kl"] * (
                            self.lambda_ctc_loss if global_step >= self.num_ctc else 0
                        )
                        if self.GAN:
                            gen_cls_loss = generator_loss_dict["gen_cls_loss"] * (
                                self.lambda_generator_loss
                                if global_step >= (self.num_GAN + self.num_D)
                                and update_generator
                                else 0
                            )
                            metrics["loss/gen_cls_loss"] = generator_loss_dict[
                                "gen_cls_loss"
                            ]
                            generator_loss += gen_cls_loss

                        metrics["loss/dm_loss"] = generator_loss_dict["loss_dm"]
                        metrics["loss/ctc_loss"] = generator_loss_dict["loss_ctc"]
                        metrics["loss/similarity_loss"] = generator_loss_dict["loss_sim"]
                        metrics["loss/generator_loss"] = generator_loss
                        if (
                            "loss_mse" in generator_loss_dict
                            and generator_loss_dict["loss_mse"] != 0
                        ):
                            metrics["loss/mse_loss"] = generator_loss_dict["loss_mse"]
                        if (
                            "loss_kl" in generator_loss_dict
                            and generator_loss_dict["loss_kl"] != 0
                        ):
                            metrics["loss/kl_loss"] = generator_loss_dict["loss_kl"]

                        self.accelerator.backward(generator_loss)
                        if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
                            metrics["grad_norm_generator"] = self.accelerator.clip_grad_norm_(
                                self.model.parameters(), self.max_grad_norm
                            )
                        # self.generator_norm.update(metrics["grad_norm_generator"])
                        # if metrics["grad_norm_generator"] > self.generator_norm.mean + 15 * self.generator_norm.std():
                        #     self.optimizer_generator.zero_grad()
                        #     self.optimizer_guidance.zero_grad()
                        #     update_generator = False
                        #     print("Gradient explosion detected. Skipping batch.")
                        if update_generator:
                            self.optimizer_generator.step()
                            self.scheduler_generator.step()
                            self.optimizer_generator.zero_grad()
                            self.optimizer_guidance.zero_grad()  # zero out the guidance's gradients as well

                global_step += 1

                if self.accelerator.is_local_main_process:
                    self.accelerator.log(
                        {
                            **metrics,
                            "lr_generator": self.scheduler_generator.get_last_lr()[0],
                            "lr_guidance": self.scheduler_guidance.get_last_lr()[0],
                        },
                        step=global_step,
                    )

                if (
                    global_step % self.log_step == 0
                    and self.accelerator.is_local_main_process
                    and vocoder is not None
                ):
                    # every log_step steps, decode the first sample of the
                    # latest generator batch and log it as audio
                    with torch.no_grad():

                        def log_audio(key, time_key):
                            # de-normalize the first mel in the batch, decode
                            # it with the vocoder, and wrap it for wandb
                            mel = (
                                generator_log_dict[key][0]
                                .unsqueeze(0)
                                .permute(0, 2, 1)
                                * self.scale
                            )
                            wav = vocoder.decode(mel.float().cpu())
                            return wandb.Audio(
                                wav.float().numpy().squeeze(),
                                sample_rate=24000,
                                caption=f"{time_key}: "
                                + str(
                                    generator_log_dict[time_key][0].float().cpu().numpy()
                                ),
                            )

                        self.accelerator.log(
                            {
                                "noisy_input": log_audio("generator_input", "time"),
                                "output": log_audio("generator_output", "time"),
                                "cond": log_audio("generator_cond", "time"),
                                "ground_truth": log_audio("ground_truth", "time"),
                                "dmtrain_noisy_inp": log_audio(
                                    "dmtrain_noisy_inp", "dmtrain_time"
                                ),
                                "dmtrain_pred_real_image": log_audio(
                                    "dmtrain_pred_real_image", "dmtrain_time"
                                ),
                                "dmtrain_pred_fake_image": log_audio(
                                    "dmtrain_pred_fake_image", "dmtrain_time"
                                ),
                            },
                            step=global_step,
                        )

                progress_bar.set_postfix(step=str(global_step), metrics=metrics)

                if (
                    global_step % (self.save_per_updates * self.grad_accumulation_steps)
                    == 0
                ):
                    self.save_checkpoint(global_step)
                if global_step % self.last_per_steps == 0:
                    self.save_checkpoint(global_step, last=True)

        self.save_checkpoint(global_step, last=True)
        self.accelerator.end_training()
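

# Minimal usage sketch (illustrative only: the UniModel arguments, dataset,
# and vocoder construction are assumptions, not part of this file):
#
#     model = UniModel(...)  # wraps the student (feedforward) + guidance models
#     trainer = Trainer(
#         model,
#         epochs=10,
#         learning_rate=1e-5,
#         batch_size=38400,
#         batch_size_type="frame",
#         wandb_project="dmd-tts",
#     )
#     trainer.train(train_dataset, num_workers=16, vocoder=vocoder)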