import warnings | |
warnings.filterwarnings("ignore") # ignore all warnings | |
import diffusers.utils.logging as diffusion_logging | |
diffusion_logging.set_verbosity_error() # ignore diffusers warnings | |
from typing import * | |
from torch import Tensor | |
from torch.nn.parallel import DistributedDataParallel | |
from accelerate.optimizer import AcceleratedOptimizer | |
from accelerate.scheduler import AcceleratedScheduler | |
from accelerate.data_loader import DataLoaderShard | |
import os | |
import argparse | |
import logging | |
import math | |
from collections import defaultdict | |
from packaging import version | |
import gc | |
from tqdm import tqdm | |
import wandb | |
import numpy as np | |
from skimage.metrics import structural_similarity as calculate_ssim | |
from lpips import LPIPS | |
import torch | |
import torch.nn.functional as tF | |
from einops import rearrange, repeat | |
import accelerate | |
from accelerate import Accelerator | |
from accelerate.logging import get_logger as get_accelerate_logger | |
from accelerate import DataLoaderConfiguration, DeepSpeedPlugin | |
from diffusers import FlowMatchEulerDiscreteScheduler, AutoencoderKL | |
from diffusers.training_utils import compute_snr, compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3 | |
from src.options import opt_dict, Options | |
from src.data import GObjaverseParquetDataset, ParquetChunkDataSource, MultiEpochsChunkedDataLoader, yield_forever | |
from src.models import GSAutoencoderKL, GSRecon, get_optimizer, get_lr_scheduler | |
import src.utils.util as util | |
import src.utils.geo_util as geo_util | |
import src.utils.vis_util as vis_util | |
from extensions.diffusers_diffsplat import MyEMAModel, SD3TransformerMV2DModel, StableMVDiffusion3Pipeline | |
def log_validation( | |
dataloader, negative_prompt_embed, negative_pooled_prompt_embed, lpips_loss, gsrecon, gsvae, vae, transformer, | |
global_step, accelerator, args, opt: Options, | |
): | |
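# Run the multi-view SD3 pipeline on the validation set at several CFG scales, decode and
# render the generated GS latents with GSVAE + GSRecon, then log images and LPIPS/PSNR/SSIM
# (plus optional coord/normal metrics) to WandB.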
noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(opt.pretrained_model_name_or_path, subfolder="scheduler") | |
pipeline = StableMVDiffusion3Pipeline( | |
text_encoder=None, tokenizer=None, | |
text_encoder_2=None, tokenizer_2=None, | |
text_encoder_3=None, tokenizer_3=None, | |
vae=vae, transformer=accelerator.unwrap_model(transformer), | |
scheduler=noise_scheduler, | |
) | |
pipeline.set_progress_bar_config(disable=True) | |
# pipeline.enable_xformers_memory_efficient_attention() | |
if args.seed >= 0:
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)  # seeded generator for reproducible sampling
else:
generator = None
images_dictlist, metrics_dictlist = defaultdict(list), defaultdict(list) | |
val_progress_bar = tqdm( | |
range(len(dataloader)) if args.max_val_steps is None else range(args.max_val_steps), | |
desc=f"Validation", | |
ncols=125, | |
disable=not accelerator.is_main_process | |
) | |
for i, batch in enumerate(dataloader): | |
V_in, V_cond, V = opt.num_input_views, opt.num_cond_views, opt.num_views # TODO: V_cond > V_in is not supported yet
cond_idx = [0] # the first view must be in inputs | |
if V_cond > 1: | |
cond_idx += np.random.choice(range(1, V), V_cond-1, replace=False).tolist() | |
imgs_cond = batch["image"][:, cond_idx, ...] # (B, V_cond, 3, H, W) | |
B = imgs_cond.shape[0] | |
imgs_out = batch["image"] # (B, V, 3, H, W); for visualization and evaluation | |
imgs_out = rearrange(imgs_out, "b v c h w -> (b v) c h w") | |
prompt_embeds = batch["prompt_embed"] # (B, N, D) | |
negative_prompt_embeds = repeat(negative_prompt_embed.to(accelerator.device), "n d -> b n d", b=B) | |
pooled_prompt_embeds = batch["pooled_prompt_embed"] # (B, D) | |
negative_pooled_prompt_embeds = repeat(negative_pooled_prompt_embed.to(accelerator.device), "d -> b d", b=B) | |
C2W = batch["C2W"] | |
fxfycxcy = batch["fxfycxcy"] | |
input_C2W = C2W[:, :V_in, ...] | |
input_fxfycxcy = fxfycxcy[:, :V_in, ...] | |
cond_C2W = C2W[:, cond_idx,...] | |
cond_fxfycxcy = fxfycxcy[:, cond_idx,...] | |
# Plucker embeddings | |
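# Each pixel's camera ray is encoded by its 6-D Plucker coordinates (ray direction d and
# moment o x d), giving a (6, H, W) map per view that conditions the model on camera pose.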
if opt.input_concat_plucker: | |
H = W = opt.input_res | |
plucker, _ = geo_util.plucker_ray(H, W, input_C2W, input_fxfycxcy) # (B, V_in, 6, H, W) | |
if opt.view_concat_condition: | |
cond_plucker, _ = geo_util.plucker_ray(H, W, cond_C2W, cond_fxfycxcy) # (B, V_cond, 6, H, W) | |
plucker = torch.cat([cond_plucker, plucker], dim=1) # (B, V_cond+V_in, 6, H, W) | |
plucker = rearrange(plucker, "b v c h w -> (b v) c h w") | |
else: | |
plucker = None | |
images_dictlist["gt"].append(imgs_out) # (B*V, C=3, H, W) | |
if opt.vis_coords and opt.load_coord: | |
coords_out = rearrange(batch["coord"], "b v c h w -> (b v) c h w") # (B*V, C=3, H, W) | |
images_dictlist["gt_coord"].append(coords_out) | |
if opt.vis_normals and opt.load_normal: | |
normals_out = rearrange(batch["normal"], "b v c h w -> (b v) c h w") # (B*V, C=3, H, W) | |
images_dictlist["gt_normal"].append(normals_out) | |
with torch.autocast("cuda", torch.bfloat16): | |
for guidance_scale in sorted(args.val_guidance_scales): | |
out = pipeline( | |
imgs_cond, num_inference_steps=opt.num_inference_steps, guidance_scale=guidance_scale, | |
output_type="latent", generator=generator, | |
prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, | |
pooled_prompt_embeds=pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, | |
plucker=plucker, num_views=V_in, | |
init_std=opt.init_std, init_noise_strength=opt.init_noise_strength, init_bg=opt.init_bg, | |
).images | |
# Rendering GS latents | |
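# The sampled latents live in the normalized diffusion space; invert the GSVAE normalization
# (z / scaling_factor + shift_factor) before decoding and rendering at the target cameras.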
out = out / gsvae.scaling_factor + gsvae.shift_factor | |
render_outputs = gsvae.decode_and_render_gslatents(gsrecon, out, input_C2W, input_fxfycxcy, C2W, fxfycxcy) | |
render_images = rearrange(render_outputs["image"], "b v c h w -> (b v) c h w") # (B*V, C=3, H, W) | |
images_dictlist[f"pred_cfg{guidance_scale:.1f}"].append(render_images) | |
if opt.vis_coords: | |
render_coords = rearrange(render_outputs["coord"], "b v c h w -> (b v) c h w") # (B*V, 3, H, W) | |
images_dictlist[f"pred_coord_cfg{guidance_scale:.1f}"].append(render_coords) | |
if opt.vis_normals: | |
render_normals = rearrange(render_outputs["normal"], "b v c h w -> (b v) c h w") # (B*V, 3, H, W) | |
images_dictlist[f"pred_normal_cfg{guidance_scale:.1f}"].append(render_normals) | |
# Decode to pseudo images | |
if opt.vis_pseudo_images: | |
out = (out - gsvae.shift_factor) * gsvae.scaling_factor / vae.config.scaling_factor + vae.config.shift_factor | |
images = vae.decode(out).sample.clamp(-1., 1.) * 0.5 + 0.5 | |
images_dictlist[f"pred_image_cfg{guidance_scale:.1f}"].append(images) # (B*V_in, 3, H, W) | |
################################ Compute generation metrics ################################ | |
lpips = lpips_loss( | |
# Downsampled to at most 256 to reduce memory cost | |
tF.interpolate(imgs_out * 2. - 1., (256, 256), mode="bilinear", align_corners=False), | |
tF.interpolate(render_images * 2. - 1., (256, 256), mode="bilinear", align_corners=False) | |
).mean() | |
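# PSNR below assumes images in [0, 1]: PSNR = -10 * log10(MSE). SSIM is computed on uint8
# images with the (sample, channel) slices flattened into the channel axis.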
psnr = -10. * torch.log10(tF.mse_loss(imgs_out, render_images)) | |
ssim = torch.tensor(calculate_ssim( | |
(rearrange(imgs_out, "bv c h w -> (bv c) h w").cpu().float().numpy() * 255.).astype(np.uint8), | |
(rearrange(render_images, "bv c h w -> (bv c) h w").cpu().float().numpy() * 255.).astype(np.uint8), | |
channel_axis=0, | |
), device=render_images.device) | |
lpips = accelerator.gather_for_metrics(lpips.repeat(B)).mean() | |
psnr = accelerator.gather_for_metrics(psnr.repeat(B)).mean() | |
ssim = accelerator.gather_for_metrics(ssim.repeat(B)).mean() | |
metrics_dictlist[f"lpips_cfg{guidance_scale:.1f}"].append(lpips) | |
metrics_dictlist[f"psnr_cfg{guidance_scale:.1f}"].append(psnr) | |
metrics_dictlist[f"ssim_cfg{guidance_scale:.1f}"].append(ssim) | |
if opt.coord_weight > 0.: | |
assert opt.load_coord | |
coord_mse = tF.mse_loss(coords_out, render_coords) | |
coord_mse = accelerator.gather_for_metrics(coord_mse.repeat(B)).mean() | |
metrics_dictlist[f"coord_mse_cfg{guidance_scale:.1f}"].append(coord_mse) | |
if opt.normal_weight > 0.: | |
assert opt.load_normal | |
normal_cosim = tF.cosine_similarity(normals_out, render_normals, dim=1).mean() # dim=1: channel dim of (B*V, 3, H, W)
normal_cosim = accelerator.gather_for_metrics(normal_cosim.repeat(B)).mean() | |
metrics_dictlist[f"normal_cosim_cfg{guidance_scale:.1f}"].append(normal_cosim) | |
# Only log the last (biggest) cfg metrics in the progress bar | |
val_logs = { | |
"lpips": lpips.item(), | |
"psnr": psnr.item(), | |
"ssim": ssim.item(), | |
} | |
val_progress_bar.set_postfix(**val_logs) | |
val_progress_bar.update(1) | |
if args.max_val_steps is not None and i == (args.max_val_steps - 1): | |
break | |
val_progress_bar.close() | |
if accelerator.is_main_process: | |
formatted_images = [] | |
for k, v in images_dictlist.items(): # "gs_gt", "pred_cfg1.0", "pred_cfg3.0", ... | |
mvimages = torch.cat(v, dim=0) # (N*B*V, C, H, W) | |
mvimages = rearrange(mvimages, "(nb v) c h w -> nb v c h w", v=V if "image" not in k else V_in) | |
mvimages = mvimages[:min(mvimages.shape[0], 4), ...] # max show `4` samples; TODO: make it configurable | |
mvimages = rearrange(mvimages, "nb v c h w -> c (nb h) (v w)") | |
mvimages = vis_util.tensor_to_image(mvimages.detach()) | |
formatted_images.append(wandb.Image(mvimages, caption=k)) | |
wandb.log({"images/validation": formatted_images}, step=global_step) | |
for k, v in metrics_dictlist.items(): # "lpips_cfg1.0", "psnr_cfg3.0", ... | |
if "cfg1.0" in k: | |
wandb.log({f"validation_cfg1.0/{k}": torch.tensor(v).mean().item()}, step=global_step) | |
else: | |
wandb.log({f"validation/{k}": torch.tensor(v).mean().item()}, step=global_step) | |
def main(): | |
PROJECT_NAME = "DiffSplat" | |
parser = argparse.ArgumentParser( | |
description="Train a diffusion model for 3D object generation", | |
) | |
parser.add_argument( | |
"--config_file", | |
type=str, | |
required=True, | |
help="Path to the config file" | |
) | |
parser.add_argument( | |
"--tag", | |
type=str, | |
required=True, | |
help="Tag that refers to the current experiment" | |
) | |
parser.add_argument( | |
"--output_dir", | |
type=str, | |
default="out", | |
help="Path to the output directory" | |
) | |
parser.add_argument( | |
"--hdfs_dir", | |
type=str, | |
default=None, | |
help="Path to the HDFS directory to save checkpoints" | |
) | |
parser.add_argument( | |
"--wandb_token_path", | |
type=str, | |
default="wandb/token", | |
help="Path to the WandB login token" | |
) | |
parser.add_argument( | |
"--resume_from_iter", | |
type=int, | |
default=None, | |
help="The iteration to load the checkpoint from" | |
) | |
parser.add_argument( | |
"--seed", | |
type=int, | |
default=0, | |
help="Seed for the PRNG" | |
) | |
parser.add_argument( | |
"--offline_wandb", | |
action="store_true", | |
help="Use offline WandB for experiment tracking" | |
) | |
parser.add_argument( | |
"--max_train_steps", | |
type=int, | |
default=None, | |
help="The max iteration step for training" | |
) | |
parser.add_argument( | |
"--max_val_steps", | |
type=int, | |
default=1, | |
help="The max iteration step for validation" | |
) | |
parser.add_argument( | |
"--num_workers", | |
type=int, | |
default=32, | |
help="The number of processed spawned by the batch provider" | |
) | |
parser.add_argument( | |
"--pin_memory", | |
action="store_true", | |
help="Pin memory for the data loader" | |
) | |
parser.add_argument( | |
"--use_ema", | |
action="store_true", | |
help="Use EMA model for training" | |
) | |
parser.add_argument( | |
"--scale_lr", | |
action="store_true", | |
help="Scale lr with total batch size (base batch size: 256)" | |
) | |
parser.add_argument( | |
"--max_grad_norm", | |
type=float, | |
default=1., | |
help="Max gradient norm for gradient clipping" | |
) | |
parser.add_argument( | |
"--gradient_accumulation_steps", | |
type=int, | |
default=1, | |
help="Number of updates steps to accumulate before performing a backward/update pass" | |
) | |
parser.add_argument( | |
"--mixed_precision", | |
type=str, | |
default="fp16", | |
choices=["no", "fp16", "bf16"], | |
help="Type of mixed precision training" | |
) | |
parser.add_argument( | |
"--allow_tf32", | |
action="store_true", | |
help="Enable TF32 for faster training on Ampere GPUs" | |
) | |
parser.add_argument( | |
"--val_guidance_scales", | |
type=list, | |
nargs="+", | |
default=[1., 3., 7.5], | |
help="CFG scale used for validation" | |
) | |
parser.add_argument( | |
"--use_deepspeed", | |
action="store_true", | |
help="Use DeepSpeed for training" | |
) | |
parser.add_argument( | |
"--zero_stage", | |
type=int, | |
default=1, | |
choices=[1, 2, 3], # https://huggingface.co/docs/accelerate/usage_guides/deepspeed | |
help="ZeRO stage type for DeepSpeed" | |
) | |
parser.add_argument( | |
"--load_pretrained_gsrecon", | |
type=str, | |
default="gsrecon_gobj265k_cnp_even4", | |
help="Tag of a pretrained GSRecon in this project" | |
) | |
parser.add_argument( | |
"--load_pretrained_gsrecon_ckpt", | |
type=int, | |
default=-1, | |
help="Iteration of the pretrained GSRecon checkpoint" | |
) | |
parser.add_argument( | |
"--load_pretrained_gsvae", | |
type=str, | |
default="gsvae_gobj265k_sd3", | |
help="Tag of a pretrained GSVAE in this project" | |
) | |
parser.add_argument( | |
"--load_pretrained_gsvae_ckpt", | |
type=int, | |
default=-1, | |
help="Iteration of the pretrained GSVAE checkpoint" | |
) | |
parser.add_argument( | |
"--load_pretrained_model", | |
type=str, | |
default=None, | |
help="Tag of a pretrained MVTransformer in this project" | |
) | |
parser.add_argument( | |
"--load_pretrained_model_ckpt", | |
type=int, | |
default=-1, | |
help="Iteration of the pretrained MVTransformer checkpoint" | |
) | |
# Parse the arguments | |
args, extras = parser.parse_known_args() | |
args.val_guidance_scales = [float("".join(x)) if isinstance(x, list) else float(x) for x in args.val_guidance_scales]  # `type=list` splits CLI values into characters; rejoin before casting
# Parse the config file | |
configs = util.get_configs(args.config_file, extras) # change yaml configs by `extras` | |
# Parse the option dict | |
opt = opt_dict[configs["opt_type"]] | |
if "opt" in configs: | |
for k, v in configs["opt"].items(): | |
setattr(opt, k, v) | |
opt.__post_init__() | |
# Create an experiment directory using the `tag` | |
exp_dir = os.path.join(args.output_dir, args.tag) | |
ckpt_dir = os.path.join(exp_dir, "checkpoints") | |
os.makedirs(ckpt_dir, exist_ok=True) | |
if args.hdfs_dir is not None: | |
args.project_hdfs_dir = args.hdfs_dir | |
args.hdfs_dir = os.path.join(args.hdfs_dir, args.tag) | |
os.system(f"hdfs dfs -mkdir -p {args.hdfs_dir}") | |
# Initialize the logger | |
logging.basicConfig( | |
format="%(asctime)s - %(message)s", | |
datefmt="%Y/%m/%d %H:%M:%S", | |
level=logging.INFO | |
) | |
logger = get_accelerate_logger(__name__, log_level="INFO") | |
file_handler = logging.FileHandler(os.path.join(exp_dir, "log.txt")) # output to file | |
file_handler.setFormatter(logging.Formatter( | |
fmt="%(asctime)s - %(message)s", | |
datefmt="%Y/%m/%d %H:%M:%S" | |
)) | |
logger.logger.addHandler(file_handler) | |
logger.logger.propagate = True # propagate to the root logger (console) | |
# Set DeepSpeed config | |
if args.use_deepspeed: | |
deepspeed_plugin = DeepSpeedPlugin( | |
gradient_accumulation_steps=args.gradient_accumulation_steps, | |
gradient_clipping=args.max_grad_norm, | |
zero_stage=int(args.zero_stage), | |
offload_optimizer_device="cpu", # hard-coded here, TODO: make it configurable | |
) | |
else: | |
deepspeed_plugin = None | |
# Initialize the accelerator | |
accelerator = Accelerator( | |
project_dir=exp_dir, | |
gradient_accumulation_steps=args.gradient_accumulation_steps, | |
mixed_precision=args.mixed_precision, | |
split_batches=False, # batch size per GPU | |
dataloader_config=DataLoaderConfiguration(non_blocking=args.pin_memory), | |
deepspeed_plugin=deepspeed_plugin, | |
) | |
logger.info(f"Accelerator state:\n{accelerator.state}\n") | |
# Set the random seed | |
if args.seed >= 0: | |
accelerate.utils.set_seed(args.seed) | |
logger.info(f"You have chosen to seed([{args.seed}]) the experiment [{args.tag}]\n") | |
# Enable TF32 for faster training on Ampere GPUs | |
if args.allow_tf32: | |
torch.backends.cuda.matmul.allow_tf32 = True | |
# Prepare dataset | |
if accelerator.is_local_main_process: | |
if not os.path.exists("/tmp/test_dataset"): | |
os.system(opt.dataset_setup_script) | |
accelerator.wait_for_everyone() # other processes wait for the main process | |
# Load the training and validation dataset | |
assert opt.file_dir_train is not None and opt.file_name_train is not None and \ | |
opt.file_dir_test is not None and opt.file_name_test is not None | |
train_dataset = GObjaverseParquetDataset( | |
data_source=ParquetChunkDataSource(opt.file_dir_train, opt.file_name_train), | |
shuffle=True, | |
shuffle_buffer_size=-1, # `-1`: effectively no shuffling
chunks_queue_max_size=1, # number of preloading chunks | |
# GObjaverse | |
opt=opt, | |
training=True, | |
) | |
val_dataset = GObjaverseParquetDataset( | |
data_source=ParquetChunkDataSource(opt.file_dir_test, opt.file_name_test), | |
shuffle=True, # shuffle for various visualization | |
shuffle_buffer_size=-1, # `-1`: effectively no shuffling
chunks_queue_max_size=1, # number of preloading chunks | |
# GObjaverse | |
opt=opt, | |
training=False, | |
) | |
train_loader = MultiEpochsChunkedDataLoader( | |
train_dataset, | |
batch_size=configs["train"]["batch_size_per_gpu"], | |
num_workers=args.num_workers, | |
drop_last=True, | |
pin_memory=args.pin_memory, | |
) | |
val_loader = MultiEpochsChunkedDataLoader( | |
val_dataset, | |
batch_size=configs["val"]["batch_size_per_gpu"], | |
num_workers=args.num_workers, | |
drop_last=True, | |
pin_memory=args.pin_memory, | |
) | |
logger.info(f"Load [{len(train_dataset)}] training samples and [{len(val_dataset)}] validation samples\n") | |
negative_prompt_embed = train_dataset.negative_prompt_embed | |
negative_pooled_prompt_embed = train_dataset.negative_pooled_prompt_embed | |
# Compute the effective batch size and scale learning rate | |
total_batch_size = configs["train"]["batch_size_per_gpu"] * \ | |
accelerator.num_processes * args.gradient_accumulation_steps | |
configs["train"]["total_batch_size"] = total_batch_size | |
if args.scale_lr: | |
configs["optimizer"]["lr"] *= (total_batch_size / 256) | |
configs["lr_scheduler"]["max_lr"] = configs["optimizer"]["lr"] | |
# LPIPS loss | |
if accelerator.is_main_process: | |
_ = LPIPS(net="vgg") | |
del _ | |
accelerator.wait_for_everyone() # wait for pretrained backbone weights to be downloaded | |
lpips_loss = LPIPS(net="vgg").to(accelerator.device) | |
lpips_loss = lpips_loss.requires_grad_(False) | |
lpips_loss.eval() | |
# GSRecon | |
gsrecon = GSRecon(opt) | |
gsrecon = gsrecon.requires_grad_(False) | |
gsrecon = gsrecon.eval() | |
# Initialize the model, optimizer and lr scheduler | |
in_channels = 16 # hard-coded for SD3 | |
if opt.input_concat_plucker: | |
in_channels += 6 | |
if opt.input_concat_binary_mask: | |
in_channels += 1 | |
transformer_from_pretrained_kwargs = { | |
"sample_size": opt.input_res // 8, # `8` hard-coded for SD3 | |
"in_channels": in_channels, | |
"zero_init_conv_in": opt.zero_init_conv_in, | |
"view_concat_condition": opt.view_concat_condition, | |
"input_concat_plucker": opt.input_concat_plucker, | |
"input_concat_binary_mask": opt.input_concat_binary_mask, | |
} | |
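# NOTE (assumption): `zero_init_conv_in` presumably zero-initializes the conv-in weights for
# the extra input channels (Plucker rays / binary mask) so the pretrained SD3 behavior is
# preserved at the start of fine-tuning.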
vae = AutoencoderKL.from_pretrained(opt.pretrained_model_name_or_path, subfolder="vae") | |
if args.load_pretrained_model is None: | |
transformer, loading_info = SD3TransformerMV2DModel.from_pretrained_new(opt.pretrained_model_name_or_path, subfolder="transformer", | |
low_cpu_mem_usage=False, ignore_mismatched_sizes=True, output_loading_info=True, **transformer_from_pretrained_kwargs) | |
logger.info(f"Loading info: {loading_info}\n") | |
else: | |
logger.info(f"Load MVTransformer EMA checkpoint from [{args.load_pretrained_model}] iteration [{args.load_pretrained_model_ckpt:06d}]\n") | |
args.load_pretrained_model_ckpt = util.load_ckpt( | |
os.path.join(args.output_dir, args.load_pretrained_model, "checkpoints"), | |
args.load_pretrained_model_ckpt, | |
None if args.hdfs_dir is None else os.path.join(args.project_hdfs_dir, args.load_pretrained_model),
None, # `None`: not load model ckpt here | |
accelerator, # manage the process states | |
) | |
path = f"out/{args.load_pretrained_model}/checkpoints/{args.load_pretrained_model_ckpt:06d}" | |
os.system(f"python3 extensions/merge_safetensors.py {path}/transformer_ema") # merge safetensors for loading | |
transformer, loading_info = SD3TransformerMV2DModel.from_pretrained_new(path, subfolder="transformer_ema", | |
low_cpu_mem_usage=False, ignore_mismatched_sizes=True, output_loading_info=True, **transformer_from_pretrained_kwargs) | |
logger.info(f"Loading info: {loading_info}\n") | |
gsvae = GSAutoencoderKL(opt) | |
noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(opt.pretrained_model_name_or_path, subfolder="scheduler") | |
if args.use_ema: | |
ema_transformer = MyEMAModel( | |
transformer.parameters(), | |
model_cls=SD3TransformerMV2DModel, | |
model_config=transformer.config, | |
**configs["train"]["ema_kwargs"] | |
) | |
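# EMA keeps an exponential moving average of the transformer weights; it is updated after each
# optimizer step and swapped in for validation and for the `transformer_ema` checkpoints.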
# Freeze VAE and GSVAE | |
vae.requires_grad_(False) | |
gsvae.requires_grad_(False) | |
vae.eval() | |
gsvae.eval() | |
trainable_module_names = [] | |
if opt.trainable_modules is None: | |
transformer.requires_grad_(True) | |
else: | |
transformer.requires_grad_(False) | |
for name, module in transformer.named_modules(): | |
for module_name in tuple(opt.trainable_modules.split(",")): | |
if module_name in name: | |
for params in module.parameters(): | |
params.requires_grad = True | |
trainable_module_names.append(name) | |
logger.info(f"Trainable parameter names: {trainable_module_names}\n") | |
# transformer.enable_xformers_memory_efficient_attention() # use `tF.scaled_dot_product_attention` instead | |
# `accelerate` 0.16.0 will have better support for customized saving | |
if version.parse(accelerate.__version__) >= version.parse("0.16.0"): | |
# Create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format | |
def save_model_hook(models, weights, output_dir): | |
if accelerator.is_main_process: | |
if args.use_ema: | |
# NOTE: `pos_embed` of SD3 is not fixed parameters (register_buffer) as those in `PatchEmbed`, | |
# so `model = self.model_cls.from_config(self.model_config)` in `EMAModel` | |
# will initialize a wrong weight for `transformer.pos_embed.pos_embed`. | |
# Here, we manually handle this case for saving transformer EMA parameters. | |
# ema_transformer.save_pretrained(os.path.join(output_dir, "transformer_ema")) | |
from copy import deepcopy | |
model = deepcopy(accelerator.unwrap_model(transformer)) | |
state_dict = ema_transformer.state_dict() | |
state_dict.pop("shadow_params", None) | |
model.register_to_config(**state_dict) | |
ema_transformer.copy_to(model.parameters()) | |
model.save_pretrained(os.path.join(output_dir, "transformer_ema")) | |
del model | |
for i, model in enumerate(models): | |
model.save_pretrained(os.path.join(output_dir, "transformer")) | |
# Make sure to pop weight so that corresponding model is not saved again | |
if weights: | |
weights.pop() | |
def load_model_hook(models, input_dir): | |
if args.use_ema: | |
load_model = MyEMAModel.from_pretrained(os.path.join(input_dir, "transformer_ema"), SD3TransformerMV2DModel) | |
ema_transformer.load_state_dict(load_model.state_dict()) | |
ema_transformer.to(accelerator.device) | |
del load_model | |
for _ in range(len(models)): | |
# Pop models so that they are not loaded again | |
model = models.pop() | |
# Load diffusers style into model | |
load_model = SD3TransformerMV2DModel.from_pretrained(input_dir, subfolder="transformer") | |
model.register_to_config(**load_model.config) | |
model.load_state_dict(load_model.state_dict()) | |
del load_model | |
accelerator.register_save_state_pre_hook(save_model_hook) | |
accelerator.register_load_state_pre_hook(load_model_hook) | |
if opt.grad_checkpoint: | |
transformer.enable_gradient_checkpointing() | |
params, params_lr_mult, names_lr_mult = [], [], [] | |
for name, param in transformer.named_parameters(): | |
if opt.name_lr_mult is not None: | |
for k in opt.name_lr_mult.split(","): | |
if k in name: | |
params_lr_mult.append(param) | |
names_lr_mult.append(name) | |
if name not in names_lr_mult: | |
params.append(param) | |
else: | |
params.append(param) | |
optimizer = get_optimizer( | |
params=[ | |
{"params": params, "lr": configs["optimizer"]["lr"]}, | |
{"params": params_lr_mult, "lr": configs["optimizer"]["lr"] * opt.lr_mult} | |
], | |
**configs["optimizer"] | |
) | |
logger.info(f"Learning rate x [{opt.lr_mult}] parameter names: {names_lr_mult}\n") | |
configs["lr_scheduler"]["total_steps"] = configs["train"]["epochs"] * math.ceil( | |
len(train_loader) // accelerator.num_processes / args.gradient_accumulation_steps) # only account updated steps | |
configs["lr_scheduler"]["total_steps"] *= accelerator.num_processes # for lr scheduler setting | |
if "num_warmup_steps" in configs["lr_scheduler"]: | |
configs["lr_scheduler"]["num_warmup_steps"] *= accelerator.num_processes # for lr scheduler setting | |
lr_scheduler = get_lr_scheduler(optimizer=optimizer, **configs["lr_scheduler"]) | |
configs["lr_scheduler"]["total_steps"] //= accelerator.num_processes # reset for multi-gpu | |
if "num_warmup_steps" in configs["lr_scheduler"]: | |
configs["lr_scheduler"]["num_warmup_steps"] //= accelerator.num_processes # reset for multi-gpu | |
# Load pretrained reconstruction and gsvae models | |
logger.info(f"Load GSVAE checkpoint from [{args.load_pretrained_gsvae}] iteration [{args.load_pretrained_gsvae_ckpt:06d}]\n") | |
gsvae = util.load_ckpt( | |
os.path.join(args.output_dir, args.load_pretrained_gsvae, "checkpoints"), | |
args.load_pretrained_gsvae_ckpt, | |
None if args.hdfs_dir is None else os.path.join(args.project_hdfs_dir, args.load_pretrained_gsvae), | |
gsvae, accelerator | |
) | |
logger.info(f"Load GSRecon checkpoint from [{args.load_pretrained_gsrecon}] iteration [{args.load_pretrained_gsrecon_ckpt:06d}]\n") | |
gsrecon = util.load_ckpt( | |
os.path.join(args.output_dir, args.load_pretrained_gsrecon, "checkpoints"), | |
args.load_pretrained_gsrecon_ckpt, | |
None if args.hdfs_dir is None else os.path.join(args.project_hdfs_dir, args.load_pretrained_gsrecon), | |
gsrecon, accelerator | |
) | |
# Prepare everything with `accelerator` | |
transformer, optimizer, lr_scheduler, train_loader, val_loader = accelerator.prepare( | |
transformer, optimizer, lr_scheduler, train_loader, val_loader | |
) | |
# Set classes explicitly for everything | |
transformer: DistributedDataParallel | |
optimizer: AcceleratedOptimizer | |
lr_scheduler: AcceleratedScheduler | |
train_loader: DataLoaderShard | |
val_loader: DataLoaderShard | |
if args.use_ema: | |
ema_transformer.to(accelerator.device) | |
# For mixed precision training we cast all non-trainable weights to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required. | |
weight_dtype = torch.float32 | |
if accelerator.mixed_precision == "fp16": | |
weight_dtype = torch.float16 | |
elif accelerator.mixed_precision == "bf16": | |
weight_dtype = torch.bfloat16 | |
# Move `gsrecon`, `vae` and `gsvae` to gpu and cast to `weight_dtype` | |
gsrecon.to(accelerator.device, dtype=weight_dtype) | |
vae.to(accelerator.device, dtype=weight_dtype) | |
gsvae.to(accelerator.device, dtype=weight_dtype) | |
# Training configs after distribution and accumulation setup | |
updated_steps_per_epoch = math.ceil(len(train_loader) / args.gradient_accumulation_steps) | |
total_updated_steps = configs["lr_scheduler"]["total_steps"] | |
if args.max_train_steps is None: | |
args.max_train_steps = total_updated_steps | |
assert configs["train"]["epochs"] * updated_steps_per_epoch == total_updated_steps | |
logger.info(f"Total batch size: [{total_batch_size}]") | |
logger.info(f"Learning rate: [{configs['optimizer']['lr']}]") | |
logger.info(f"Gradient Accumulation steps: [{args.gradient_accumulation_steps}]") | |
logger.info(f"Total epochs: [{configs['train']['epochs']}]") | |
logger.info(f"Total steps: [{total_updated_steps}]") | |
logger.info(f"Steps for updating per epoch: [{updated_steps_per_epoch}]") | |
logger.info(f"Steps for validation: [{len(val_loader)}]\n") | |
# (Optional) Load checkpoint | |
global_update_step = 0 | |
if args.resume_from_iter is not None: | |
logger.info(f"Load checkpoint from iteration [{args.resume_from_iter}]\n") | |
# Download from HDFS | |
if not os.path.exists(os.path.join(ckpt_dir, f'{args.resume_from_iter:06d}')): | |
args.resume_from_iter = util.load_ckpt( | |
ckpt_dir, | |
args.resume_from_iter, | |
args.hdfs_dir, | |
None, # `None`: not load model ckpt here | |
accelerator, # manage the process states | |
) | |
# Load everything | |
accelerator.load_state(os.path.join(ckpt_dir, f"{args.resume_from_iter:06d}")) # torch < 2.4.0 here for `weights_only=False` | |
global_update_step = int(args.resume_from_iter) | |
# Save all experimental parameters and model architecture of this run to a file (args and configs) | |
if accelerator.is_main_process: | |
exp_params = util.save_experiment_params(args, configs, opt, exp_dir) | |
util.save_model_architecture(accelerator.unwrap_model(transformer), exp_dir) | |
# WandB logger | |
if accelerator.is_main_process: | |
if args.offline_wandb: | |
os.environ["WANDB_MODE"] = "offline" | |
with open(args.wandb_token_path, "r") as f: | |
os.environ["WANDB_API_KEY"] = f.read().strip() | |
wandb.init( | |
project=PROJECT_NAME, name=args.tag, | |
config=exp_params, dir=exp_dir, | |
resume=True | |
) | |
# Wandb artifact for logging experiment information | |
arti_exp_info = wandb.Artifact(args.tag, type="exp_info") | |
arti_exp_info.add_file(os.path.join(exp_dir, "params.yaml")) | |
arti_exp_info.add_file(os.path.join(exp_dir, "model.txt")) | |
arti_exp_info.add_file(os.path.join(exp_dir, "log.txt")) # only save the log before training | |
wandb.log_artifact(arti_exp_info) | |
def get_sigmas(timesteps: Tensor, n_dim=4, dtype=torch.float32): | |
sigmas = noise_scheduler.sigmas.to(dtype=dtype, device=accelerator.device) | |
schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device) | |
timesteps = timesteps.to(accelerator.device) | |
step_indices = [(schedule_timesteps == t).nonzero()[0].item() for t in timesteps] | |
sigma = sigmas[step_indices].flatten() | |
while len(sigma.shape) < n_dim: | |
sigma = sigma.unsqueeze(-1) | |
return sigma | |
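# `get_sigmas` (above) looks up the flow-matching sigma for each sampled timestep by matching
# it against the scheduler's timestep table, then unsqueezes it to broadcast over (C, H, W).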
# Start training | |
logger.logger.propagate = False # not propagate to the root logger (console) | |
progress_bar = tqdm( | |
range(total_updated_steps), | |
initial=global_update_step, | |
desc="Training", | |
ncols=125, | |
disable=not accelerator.is_main_process | |
) | |
for batch in yield_forever(train_loader): | |
if global_update_step == args.max_train_steps: | |
progress_bar.close() | |
logger.logger.propagate = True # propagate to the root logger (console) | |
if accelerator.is_main_process: | |
wandb.finish() | |
logger.info("Training finished!\n") | |
return | |
transformer.train() | |
with accelerator.accumulate(transformer): | |
V_in, V_cond, V = opt.num_input_views, opt.num_cond_views, opt.num_views # TODO: V_cond > V_in is not supported yet
cond_idx = [0] # the first view must be in inputs | |
if V_cond > 1: | |
cond_idx += np.random.choice(range(1, V), V_cond-1, replace=False).tolist() | |
imgs_cond = batch["image"][:, cond_idx, ...] # (B, V_cond, 3, H, W) | |
B = imgs_cond.shape[0] | |
prompt_embeds = batch["prompt_embed"] # (B, N, D) | |
negative_prompt_embeds = repeat(negative_prompt_embed.to(accelerator.device), "n d -> b n d", b=B) | |
pooled_prompt_embeds = batch["pooled_prompt_embed"] # (B, D) | |
negative_pooled_prompt_embeds = repeat(negative_pooled_prompt_embed.to(accelerator.device), "d -> b d", b=B) | |
imgs_out = batch["image"][:, :V_in, ...] | |
C2W = batch["C2W"] | |
fxfycxcy = batch["fxfycxcy"] | |
( | |
imgs_cond, prompt_embeds, negative_prompt_embeds, | |
pooled_prompt_embeds, negative_pooled_prompt_embeds, | |
imgs_out, C2W, fxfycxcy | |
) = ( | |
imgs_cond.to(weight_dtype), | |
prompt_embeds.to(weight_dtype), | |
negative_prompt_embeds.to(weight_dtype), | |
pooled_prompt_embeds.to(weight_dtype), | |
negative_pooled_prompt_embeds.to(weight_dtype), | |
imgs_out.to(weight_dtype), | |
C2W.to(weight_dtype), | |
fxfycxcy.to(weight_dtype), | |
) | |
input_C2W = C2W[:, :V_in, ...] | |
input_fxfycxcy = fxfycxcy[:, :V_in, ...] | |
cond_C2W = C2W[:, cond_idx, ...] | |
cond_fxfycxcy = fxfycxcy[:, cond_idx,...] | |
# (Optional) Plucker embeddings | |
if opt.input_concat_plucker: | |
H = W = opt.input_res | |
plucker, _ = geo_util.plucker_ray(H, W, input_C2W, input_fxfycxcy) # (B, V_in, 6, H, W) | |
if opt.view_concat_condition: | |
cond_plucker, _ = geo_util.plucker_ray(H, W, cond_C2W, cond_fxfycxcy) # (B, V_cond, 6, H, W) | |
plucker = torch.cat([cond_plucker, plucker], dim=1) # (B, V_cond+V_in, 6, H, W) | |
plucker = rearrange(plucker, "b v c h w -> (b v) c h w") | |
else: | |
plucker = None | |
# VAE input image condition | |
if opt.view_concat_condition: | |
with torch.no_grad(): | |
imgs_cond = rearrange(imgs_cond, "b v c h w -> (b v) c h w") | |
image_latents = vae.config.scaling_factor * (vae.encode( | |
imgs_cond * 2. - 1.).latent_dist.sample() - vae.config.shift_factor) # (B*V_cond, 4, H', W') | |
image_latents = rearrange(image_latents, "(b v) c h w -> b v c h w", v=V_cond) # (B, V_cond, 4, H', W') | |
# Get GS latents | |
if opt.input_normal: | |
imgs_out = torch.cat([imgs_out, batch["normal"][:, :V_in, ...].to(weight_dtype)], dim=2) | |
if opt.input_coord: | |
imgs_out = torch.cat([imgs_out, batch["coord"][:, :V_in, ...].to(weight_dtype)], dim=2) | |
with torch.no_grad(): | |
latents = gsvae.scaling_factor * (gsvae.get_gslatents(gsrecon, imgs_out, input_C2W, input_fxfycxcy) - gsvae.shift_factor) # (B*V_in, 4, H', W') | |
noise = torch.randn_like(latents) | |
# For weighting schemes where we sample timesteps non-uniformly | |
u = compute_density_for_timestep_sampling( | |
weighting_scheme=opt.weighting_scheme, | |
batch_size=B, | |
logit_mean=opt.logit_mean, | |
logit_std=opt.logit_std, | |
mode_scale=opt.mode_scale, | |
) | |
indices = (u * noise_scheduler.config.num_train_timesteps).long() | |
timesteps = noise_scheduler.timesteps[indices].to(accelerator.device) | |
timesteps = repeat(timesteps, "b -> (b v)", v=V_in) # same noise scale for different views of the same object | |
sigmas = get_sigmas(timesteps, len(latents.shape), weight_dtype) | |
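# Rectified-flow forward process: x_t = (1 - sigma) * x_0 + sigma * noise, with one shared
# timestep per object so that all views of the same object are noised to the same level.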
latent_model_input = noisy_latents = (1. - sigmas) * latents + sigmas * noise | |
if opt.cfg_dropout_prob > 0.: | |
# Drop a group of multi-view images as a whole | |
random_p = torch.rand(B, device=latents.device) | |
# Sample masks for the conditioning VAE images | |
if opt.view_concat_condition: | |
image_mask_dtype = image_latents.dtype | |
image_mask = 1 - ( | |
(random_p >= opt.cfg_dropout_prob).to(image_mask_dtype) | |
* (random_p < 3 * opt.cfg_dropout_prob).to(image_mask_dtype) | |
) # actual dropout rate is 2 * `opt.cfg_dropout_prob`
image_mask = image_mask.reshape(B, 1, 1, 1, 1) | |
# Final VAE image conditioning | |
image_latents = image_mask * image_latents | |
# Sample masks for the conditioning text prompts | |
text_mask_dtype = prompt_embeds.dtype | |
text_mask = 1 - ( | |
(random_p < 2 * opt.cfg_dropout_prob).to(text_mask_dtype) | |
) # actual dropout rate is 2 * `opt.cfg_dropout_prob`
text_mask = text_mask.reshape(B, 1, 1) | |
# Final text conditioning | |
prompt_embeds = text_mask * prompt_embeds + (1 - text_mask) * negative_prompt_embeds | |
# Final pooled text conditioning | |
text_mask = text_mask.reshape(B, 1) | |
pooled_prompt_embeds = text_mask * pooled_prompt_embeds + (1 - text_mask) * negative_pooled_prompt_embeds | |
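# Summary of the CFG dropout above: with p ~ U(0, 1) and d = cfg_dropout_prob, the image
# condition is dropped for d <= p < 3d and the text condition for p < 2d, so image-only,
# text-only, unconditional and fully-conditioned cases all appear (each drop has probability 2*d).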
# Concatenate input latents with others | |
latent_model_input = rearrange(latent_model_input, "(b v) c h w -> b v c h w", v=V_in) | |
if opt.view_concat_condition: | |
latent_model_input = torch.cat([image_latents, latent_model_input], dim=1) # (B, V_in+V_cond, 4, H', W') | |
if opt.input_concat_plucker: | |
plucker = tF.interpolate(plucker, size=latent_model_input.shape[-2:], mode="bilinear", align_corners=False) | |
plucker = rearrange(plucker, "(b v) c h w -> b v c h w", v=V_in + (V_cond if opt.view_concat_condition else 0)) | |
latent_model_input = torch.cat([latent_model_input, plucker], dim=2) # (B, V_in(+V_cond), 4+6, H', W') | |
plucker = rearrange(plucker, "b v c h w -> (b v) c h w") | |
if opt.input_concat_binary_mask: | |
if opt.view_concat_condition: | |
latent_model_input = torch.cat([ | |
torch.cat([latent_model_input[:, :V_cond, ...], torch.zeros_like(latent_model_input[:, :V_cond, 0:1, ...])], dim=2), | |
torch.cat([latent_model_input[:, V_cond:, ...], torch.ones_like(latent_model_input[:, V_cond:, 0:1, ...])], dim=2), | |
], dim=1) # (B, V_in+V_cond, 4+6+1, H', W') | |
else: | |
latent_model_input = torch.cat([ | |
torch.cat([latent_model_input, torch.ones_like(latent_model_input[:, :, 0:1, ...])], dim=2), | |
], dim=1) # (B, V_in, 4+6+1, H', W') | |
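# The binary channel appended above marks pure condition views with 0 and the noisy target
# views to be denoised with 1.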
latent_model_input = rearrange(latent_model_input, "b v c h w -> (b v) c h w") | |
timesteps_input = rearrange(timesteps, "(b v) -> b v", v=V_in)[:, 0] # (B,) | |
model_pred = transformer( | |
hidden_states=latent_model_input, | |
timestep=timesteps_input, | |
encoder_hidden_states=prompt_embeds, | |
pooled_projections=pooled_prompt_embeds, | |
joint_attention_kwargs=dict( | |
num_views=opt.num_input_views + (V_cond if opt.view_concat_condition else 0), | |
), | |
).sample | |
# Only keep the noise prediction for the latents | |
if opt.view_concat_condition: | |
model_pred = rearrange(model_pred, "(b v) c h w -> b v c h w", v=V_in+V_cond) | |
model_pred = rearrange(model_pred[:, V_cond:, ...], "b v c h w -> (b v) c h w") | |
if opt.precondition_outputs: # Section 5 of https://arxiv.org/abs/2206.00364 | |
model_pred = model_pred * (-sigmas) + noisy_latents # predicted x_0 | |
target = latents | |
else: # flow matching | |
target = noise - latents | |
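# Above: with `precondition_outputs` the network output is converted into a prediction of x_0
# (EDM-style preconditioning, see the reference) and regressed against the clean latents;
# otherwise the raw output is trained as the flow-matching velocity target (noise - x_0).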
# For these weighting schemes use a uniform timestep sampling, so post-weight the loss | |
weighting = compute_loss_weighting_for_sd3(opt.weighting_scheme, sigmas) | |
loss = weighting * tF.mse_loss(model_pred.float(), target.float(), reduction="none") | |
loss = rearrange(loss, "(b v) c h w -> b v c h w", v=V_in) | |
loss = loss.mean(dim=list(range(1, len(loss.shape)))) | |
# Rendering loss | |
use_rendering_loss = np.random.rand() < opt.rendering_loss_prob | |
if use_rendering_loss: | |
# Get predicted x_0 | |
if opt.precondition_outputs: | |
pred_original_latents = model_pred | |
else: | |
pred_original_latents = model_pred * (-sigmas) + noisy_latents | |
# Render the predicted latents | |
pred_original_latents = pred_original_latents.to(weight_dtype) | |
pred_original_latents = pred_original_latents / gsvae.scaling_factor + gsvae.shift_factor | |
pred_render_outputs = gsvae.decode_and_render_gslatents( | |
gsrecon, pred_original_latents, input_C2W, input_fxfycxcy, C2W, fxfycxcy, | |
use_tiny_decoder=opt.use_tiny_decoder, | |
) # (B, V, 3 or 1, H, W) | |
image_mse = tF.mse_loss(batch["image"], pred_render_outputs["image"], reduction="none") | |
mask_mse = tF.mse_loss(batch["mask"], pred_render_outputs["alpha"], reduction="none") | |
render_loss = image_mse + mask_mse # (B, V, C, H, W) | |
# Depth & Normal | |
if opt.coord_weight > 0: | |
assert opt.load_coord | |
coord_mse = tF.mse_loss(batch["coord"], pred_render_outputs["coord"], reduction="none") | |
render_loss += opt.coord_weight * coord_mse # (B, V, C, H, W) | |
else: | |
coord_mse = None | |
if opt.normal_weight > 0: | |
assert opt.load_normal | |
normal_cosim = tF.cosine_similarity(batch["normal"], pred_render_outputs["normal"], dim=2).unsqueeze(2) | |
render_loss += opt.normal_weight * (1. - normal_cosim) # (B, V, C, H, W) | |
else: | |
normal_cosim = None | |
# LPIPS | |
if opt.lpips_weight > 0.: | |
lpips, chunk = [], opt.chunk_size | |
for i in range(0, B*V, chunk): # step by `chunk` images to avoid overlapping slices
_lpips = lpips_loss( | |
# Downsampled to at most 256 to reduce memory cost | |
tF.interpolate( | |
rearrange(batch["image"], "b v c h w -> (b v) c h w")[i:min(B*V, i+chunk), ...] * 2. - 1., | |
(256, 256), mode="bilinear", align_corners=False | |
), | |
tF.interpolate( | |
rearrange(pred_render_outputs["image"], "b v c h w -> (b v) c h w")[i:min(B*V, i+chunk), ...] * 2. - 1., | |
(256, 256), mode="bilinear", align_corners=False | |
) | |
) # (`chunk`, 1, 1, 1) | |
lpips.append(_lpips) | |
lpips = torch.cat(lpips, dim=0) # (B*V, 1, 1, 1) | |
lpips = rearrange(lpips, "(b v) c h w -> b v c h w", v=V) | |
render_loss += opt.lpips_weight * lpips # (B, V, C, H, W) | |
render_loss = render_loss.mean(dim=list(range(1, len(render_loss.shape)))) # (B,) | |
if opt.snr_gamma_rendering > 0.: | |
timesteps = rearrange(timesteps, "(b v) -> b v", v=V_in)[:, 0] # (B,) | |
snr = compute_snr(noise_scheduler, timesteps) | |
mse_loss_weights = torch.stack([snr, opt.snr_gamma_rendering * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] | |
render_loss = mse_loss_weights * render_loss | |
loss = opt.diffusion_weight * loss + opt.render_weight * render_loss # (B,) | |
# Metric: PSNR, SSIM and LPIPS
with torch.no_grad(): | |
psnr = -10 * torch.log10(torch.mean((batch["image"] - pred_render_outputs["image"].detach()) ** 2)) | |
ssim = torch.tensor(calculate_ssim( | |
(rearrange(batch["image"], "b v c h w -> (b v c) h w") | |
.cpu().float().numpy() * 255.).astype(np.uint8), | |
(rearrange(pred_render_outputs["image"].detach(), "b v c h w -> (b v c) h w") | |
.cpu().float().numpy() * 255.).astype(np.uint8), | |
channel_axis=0, | |
), device=batch["image"].device) | |
if opt.lpips_weight <= 0.: | |
lpips = lpips_loss( | |
# Downsampled to at most 256 to reduce memory cost | |
tF.interpolate( | |
rearrange(batch["image"], "b v c h w -> (b v) c h w") * 2. - 1., | |
(256, 256), mode="bilinear", align_corners=False | |
), | |
tF.interpolate( | |
rearrange(pred_render_outputs["image"].detach(), "b v c h w -> (b v) c h w") * 2. - 1., | |
(256, 256), mode="bilinear", align_corners=False | |
) | |
) | |
# Backpropagate | |
accelerator.backward(loss.mean()) | |
if accelerator.sync_gradients: | |
accelerator.clip_grad_norm_(transformer.parameters(), args.max_grad_norm) | |
optimizer.step() | |
lr_scheduler.step() | |
optimizer.zero_grad() | |
# Checks if the accelerator has performed an optimization step behind the scenes | |
if accelerator.sync_gradients: | |
# Gather the losses across all processes for logging (if we use distributed training) | |
loss = accelerator.gather(loss.detach()).mean() | |
if use_rendering_loss: | |
psnr = accelerator.gather(psnr.detach()).mean() | |
ssim = accelerator.gather(ssim.detach()).mean() | |
lpips = accelerator.gather(lpips.detach()).mean() | |
render_loss = accelerator.gather(render_loss.detach()).mean() | |
if coord_mse is not None: | |
coord_mse = accelerator.gather(coord_mse.detach()).mean() | |
if normal_cosim is not None: | |
normal_cosim = accelerator.gather(normal_cosim.detach()).mean() | |
logs = { | |
"loss": loss.item(), | |
"lr": lr_scheduler.get_last_lr()[0] | |
} | |
if args.use_ema: | |
ema_transformer.step(transformer.parameters()) | |
logs.update({"ema": ema_transformer.cur_decay_value}) | |
if use_rendering_loss: | |
logs.update({"render_loss": render_loss.item()}) | |
progress_bar.set_postfix(**logs) | |
progress_bar.update(1) | |
global_update_step += 1 | |
logger.info(
f"[{global_update_step:06d} / {total_updated_steps:06d}] " +
f"loss: {logs['loss']:.4f}, lr: {logs['lr']:.2e}" +
(f", ema: {logs['ema']:.4f}" if args.use_ema else "") +
(f", render: {logs['render_loss']:.4f}" if use_rendering_loss else "")
) # parenthesize the optional parts so the step/loss prefix is always logged
# Log the training progress | |
if global_update_step % configs["train"]["log_freq"] == 0 or global_update_step == 1 \ | |
or global_update_step % updated_steps_per_epoch == 0: # last step of an epoch | |
if accelerator.is_main_process: | |
wandb.log({ | |
"training/loss": logs["loss"], | |
"training/lr": logs["lr"], | |
}, step=global_update_step) | |
if args.use_ema: | |
wandb.log({ | |
"training/ema": logs["ema"] | |
}, step=global_update_step) | |
if use_rendering_loss: | |
wandb.log({ | |
"training/psnr": psnr.item(), | |
"training/ssim": ssim.item(), | |
"training/lpips": lpips.item(), | |
"training/render_loss": logs["render_loss"] | |
}, step=global_update_step) | |
if coord_mse is not None: | |
wandb.log({ | |
"training/coord_mse": coord_mse.item() | |
}, step=global_update_step) | |
if normal_cosim is not None: | |
wandb.log({ | |
"training/normal_cosim": normal_cosim.item() | |
}, step=global_update_step) | |
# Save checkpoint | |
if (global_update_step % configs["train"]["save_freq"] == 0 # 1. every `save_freq` steps | |
or global_update_step % (configs["train"]["save_freq_epoch"] * updated_steps_per_epoch) == 0 # 2. every `save_freq_epoch` epochs | |
or global_update_step == total_updated_steps): # 3. the last training step
gc.collect() | |
if accelerator.distributed_type == accelerate.utils.DistributedType.DEEPSPEED: | |
# DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues | |
accelerator.save_state(os.path.join(ckpt_dir, f"{global_update_step:06d}")) | |
elif accelerator.is_main_process: | |
accelerator.save_state(os.path.join(ckpt_dir, f"{global_update_step:06d}")) | |
accelerator.wait_for_everyone() # ensure all processes have finished saving | |
if accelerator.is_main_process: | |
if args.hdfs_dir is not None: | |
util.save_ckpt(ckpt_dir, global_update_step, args.hdfs_dir) | |
gc.collect() | |
# Evaluate on the validation set | |
if (global_update_step == 1 | |
or (global_update_step % configs["train"]["early_eval_freq"] == 0 and | |
global_update_step < configs["train"]["early_eval"]) # 1. more frequently at the beginning | |
or global_update_step % configs["train"]["eval_freq"] == 0 # 2. every `eval_freq` steps | |
or global_update_step % (configs["train"]["eval_freq_epoch"] * updated_steps_per_epoch) == 0 # 3. every `eval_freq_epoch` epochs | |
or global_update_step == total_updated_steps): # 4. the last training step
# Visualize images for rendering loss | |
if accelerator.is_main_process and use_rendering_loss: | |
train_vis_dict = { | |
"images_render": pred_render_outputs["image"], # (B, V, 3, H, W) | |
"images_gt": batch["image"], # (B, V, 3, H, W) | |
} | |
if opt.vis_coords: | |
train_vis_dict.update({ | |
"images_coord": pred_render_outputs["coord"], # (B, V, 3, H, W) | |
}) | |
if opt.load_coord: | |
train_vis_dict.update({ | |
"images_gt_coord": batch["coord"] # (B, V, 3, H, W) | |
}) | |
if opt.vis_normals: | |
train_vis_dict.update({ | |
"images_normal": pred_render_outputs["normal"], # (B, V, 3, H, W) | |
}) | |
if opt.load_normal: | |
train_vis_dict.update({ | |
"images_gt_normal": batch["normal"] # (B, V, 3, H, W) | |
}) | |
wandb.log({ | |
"images/training": vis_util.wandb_mvimage_log(train_vis_dict) | |
}, step=global_update_step) | |
torch.cuda.empty_cache() | |
gc.collect() | |
# Use EMA parameters for evaluation | |
if args.use_ema: | |
# Store the Transformer parameters temporarily and load the EMA parameters to perform inference | |
ema_transformer.store(transformer.parameters()) | |
ema_transformer.copy_to(transformer.parameters()) | |
transformer.eval() | |
log_validation( | |
val_loader, | |
negative_prompt_embed, | |
negative_pooled_prompt_embed, | |
lpips_loss, | |
gsrecon, | |
gsvae, | |
vae, | |
transformer, | |
global_update_step, | |
accelerator, | |
args, | |
opt, | |
) | |
if args.use_ema: | |
# Switch back to the original Transformer parameters | |
ema_transformer.restore(transformer.parameters()) | |
torch.cuda.empty_cache() | |
gc.collect() | |
if __name__ == "__main__": | |
main() | |