import warnings
warnings.filterwarnings("ignore")  # ignore all warnings
import diffusers.utils.logging as diffusion_logging
diffusion_logging.set_verbosity_error()  # ignore diffusers warnings

from src.utils.typing_utils import *

import os
import argparse
import logging
import time
import math
import gc
from collections import defaultdict
from packaging import version

import trimesh
from PIL import Image
import numpy as np
import wandb
from tqdm import tqdm

import torch
import torch.nn.functional as tF
from torch.nn.parallel import DistributedDataParallel
import accelerate
from accelerate import Accelerator, DataLoaderConfiguration, DeepSpeedPlugin
from accelerate.logging import get_logger as get_accelerate_logger
from accelerate.optimizer import AcceleratedOptimizer
from accelerate.scheduler import AcceleratedScheduler
from accelerate.data_loader import DataLoaderShard
# omegaconf types are registered as safe globals before resuming from a checkpoint
from omegaconf import DictConfig, ListConfig
from omegaconf.base import ContainerMetadata, Metadata
from omegaconf.nodes import AnyNode
from diffusers.training_utils import (
    compute_density_for_timestep_sampling,
    compute_loss_weighting_for_sd3
)
from transformers import (
    BitImageProcessor,
    Dinov2Model,
)

from src.schedulers import RectifiedFlowScheduler
from src.models.autoencoders import TripoSGVAEModel
from src.models.transformers import PartCrafterDiTModel
from src.pipelines.pipeline_partcrafter import PartCrafterPipeline
from src.datasets import (
    ObjaversePartDataset,
    BatchedObjaversePartDataset,
    MultiEpochsDataLoader,
    yield_forever
)
from src.utils.data_utils import get_colored_mesh_composition
from src.utils.train_utils import (
    MyEMAModel,
    get_configs,
    get_optimizer,
    get_lr_scheduler,
    save_experiment_params,
    save_model_architecture,
)
from src.utils.render_utils import (
    render_views_around_mesh,
    render_normal_views_around_mesh,
    make_grid_for_images_or_videos,
    export_renderings
)
from src.utils.metric_utils import compute_cd_and_f_score_in_training


def main():
    PROJECT_NAME = "PartCrafter"

    parser = argparse.ArgumentParser(
        description="Train a diffusion model for 3D object generation",
    )
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path to the config file"
    )
    parser.add_argument(
        "--tag",
        type=str,
        default=None,
        help="Tag that refers to the current experiment"
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output",
        help="Path to the output directory"
    )
    parser.add_argument(
        "--resume_from_iter",
        type=int,
        default=None,
        help="The iteration to load the checkpoint from"
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="Seed for the PRNG"
    )
    parser.add_argument(
        "--offline_wandb",
        action="store_true",
        help="Use offline WandB for experiment tracking"
    )
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="The max iteration step for training"
    )
    parser.add_argument(
        "--max_val_steps",
        type=int,
        default=2,
        help="The max iteration step for validation"
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=32,
        help="The number of processes spawned by the batch provider"
    )
    parser.add_argument(
        "--pin_memory",
        action="store_true",
        help="Pin memory for the data loader"
    )
    parser.add_argument(
        "--use_ema",
        action="store_true",
        help="Use EMA model for training"
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        help="Scale lr with total batch size (base batch size: 256)"
    )
    parser.add_argument(
        "--max_grad_norm",
        type=float,
        default=1.,
        help="Max gradient norm for gradient clipping"
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of update steps to accumulate before performing a backward/update pass"
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="fp16",
        choices=["no", "fp16", "bf16"],
        help="Type of mixed precision training"
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help="Enable TF32 for faster training on Ampere GPUs"
    )
    parser.add_argument(
        "--val_guidance_scales",
        type=float,
        nargs="+",
        default=[7.0],
        help="CFG scales used for validation"
    )
    parser.add_argument(
        "--use_deepspeed",
        action="store_true",
        help="Use DeepSpeed for training"
    )
    parser.add_argument(
        "--zero_stage",
        type=int,
        default=1,
        choices=[1, 2, 3],  # https://huggingface.co/docs/accelerate/usage_guides/deepspeed
        help="ZeRO stage type for DeepSpeed"
    )
    parser.add_argument(
        "--from_scratch",
        action="store_true",
        help="Train from scratch"
    )
    parser.add_argument(
        "--load_pretrained_model",
        type=str,
        default=None,
        help="Tag of a pretrained PartCrafterDiTModel in this project"
    )
    parser.add_argument(
        "--load_pretrained_model_ckpt",
        type=int,
        default=-1,
        help="Iteration of the pretrained PartCrafterDiTModel checkpoint"
    )
    # Parse the arguments
    args, extras = parser.parse_known_args()

    # Parse the config file
    configs = get_configs(args.config, extras)  # change yaml configs by `extras`

    args.val_guidance_scales = [float(x) for x in args.val_guidance_scales]

    if args.max_val_steps > 0:
        # If validation is enabled, `max_val_steps` must be a multiple of `nrow`.
        # The validation batch size is always kept at 1.
        divider = configs["val"]["nrow"]
        args.max_val_steps = max(args.max_val_steps, divider)
        if args.max_val_steps % divider != 0:
            args.max_val_steps = (args.max_val_steps // divider + 1) * divider
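        # E.g., with nrow = 4, a requested --max_val_steps of 6 is first clamped
        # to at least 4 and then rounded up to 8, so the validation image grid
        # (one sample per step, batch size 1) is always filled with whole rows.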

    # Create an experiment directory using the `tag`
    if args.tag is None:
        args.tag = time.strftime("%Y%m%d_%H_%M_%S")
    exp_dir = os.path.join(args.output_dir, args.tag)
    ckpt_dir = os.path.join(exp_dir, "checkpoints")
    eval_dir = os.path.join(exp_dir, "evaluations")
    os.makedirs(ckpt_dir, exist_ok=True)
    os.makedirs(eval_dir, exist_ok=True)

    # Initialize the logger
    logging.basicConfig(
        format="%(asctime)s - %(message)s",
        datefmt="%Y/%m/%d %H:%M:%S",
        level=logging.INFO
    )
    logger = get_accelerate_logger(__name__, log_level="INFO")
    file_handler = logging.FileHandler(os.path.join(exp_dir, "log.txt"))  # output to file
    file_handler.setFormatter(logging.Formatter(
        fmt="%(asctime)s - %(message)s",
        datefmt="%Y/%m/%d %H:%M:%S"
    ))
    logger.logger.addHandler(file_handler)
    logger.logger.propagate = True  # propagate to the root logger (console)
    # Set DeepSpeed config
    if args.use_deepspeed:
        deepspeed_plugin = DeepSpeedPlugin(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            gradient_clipping=args.max_grad_norm,
            zero_stage=int(args.zero_stage),
            offload_optimizer_device="cpu",  # hard-coded here, TODO: make it configurable
        )
    else:
        deepspeed_plugin = None

    # Initialize the accelerator
    accelerator = Accelerator(
        project_dir=exp_dir,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        split_batches=False,  # batch size per GPU
        dataloader_config=DataLoaderConfiguration(non_blocking=args.pin_memory),
        deepspeed_plugin=deepspeed_plugin,
    )
    logger.info(f"Accelerator state:\n{accelerator.state}\n")

    # Set the random seed
    if args.seed >= 0:
        accelerate.utils.set_seed(args.seed)
        logger.info(f"You have chosen to seed([{args.seed}]) the experiment [{args.tag}]\n")

    # Enable TF32 for faster training on Ampere GPUs
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True
    train_dataset = BatchedObjaversePartDataset(
        configs=configs,
        batch_size=configs["train"]["batch_size_per_gpu"],
        is_main_process=accelerator.is_main_process,
        shuffle=True,
        training=True,
    )
    val_dataset = ObjaversePartDataset(
        configs=configs,
        training=False,
    )
    train_loader = MultiEpochsDataLoader(
        train_dataset,
        batch_size=configs["train"]["batch_size_per_gpu"],
        num_workers=args.num_workers,
        drop_last=True,
        pin_memory=args.pin_memory,
        collate_fn=train_dataset.collate_fn,
    )
    val_loader = MultiEpochsDataLoader(
        val_dataset,
        batch_size=configs["val"]["batch_size_per_gpu"],
        num_workers=args.num_workers,
        drop_last=True,
        pin_memory=args.pin_memory,
    )
    random_val_loader = MultiEpochsDataLoader(
        val_dataset,
        batch_size=configs["val"]["batch_size_per_gpu"],
        shuffle=True,
        num_workers=args.num_workers,
        drop_last=True,
        pin_memory=args.pin_memory,
    )
    logger.info(f"Loaded [{len(train_dataset)}] training samples and [{len(val_dataset)}] validation samples\n")

    # Compute the effective batch size and scale learning rate
    total_batch_size = configs["train"]["batch_size_per_gpu"] * \
        accelerator.num_processes * args.gradient_accumulation_steps
    configs["train"]["total_batch_size"] = total_batch_size
    if args.scale_lr:
        configs["optimizer"]["lr"] *= (total_batch_size / 256)
        configs["lr_scheduler"]["max_lr"] = configs["optimizer"]["lr"]

    # Initialize the model
    logger.info("Initializing the model...")
    vae = TripoSGVAEModel.from_pretrained(
        configs["model"]["pretrained_model_name_or_path"],
        subfolder="vae"
    )
    feature_extractor_dinov2 = BitImageProcessor.from_pretrained(
        configs["model"]["pretrained_model_name_or_path"],
        subfolder="feature_extractor_dinov2"
    )
    image_encoder_dinov2 = Dinov2Model.from_pretrained(
        configs["model"]["pretrained_model_name_or_path"],
        subfolder="image_encoder_dinov2"
    )

    enable_part_embedding = configs["model"]["transformer"].get("enable_part_embedding", True)
    enable_local_cross_attn = configs["model"]["transformer"].get("enable_local_cross_attn", True)
    enable_global_cross_attn = configs["model"]["transformer"].get("enable_global_cross_attn", True)
    global_attn_block_ids = configs["model"]["transformer"].get("global_attn_block_ids", None)
    if global_attn_block_ids is not None:
        global_attn_block_ids = list(global_attn_block_ids)
    global_attn_block_id_range = configs["model"]["transformer"].get("global_attn_block_id_range", None)
    if global_attn_block_id_range is not None:
        global_attn_block_id_range = list(global_attn_block_id_range)

    if args.from_scratch:
        logger.info("Initialize PartCrafterDiTModel from scratch\n")
        transformer = PartCrafterDiTModel.from_config(
            os.path.join(
                configs["model"]["pretrained_model_name_or_path"],
                "transformer"
            ),
            enable_part_embedding=enable_part_embedding,
            enable_local_cross_attn=enable_local_cross_attn,
            enable_global_cross_attn=enable_global_cross_attn,
            global_attn_block_ids=global_attn_block_ids,
            global_attn_block_id_range=global_attn_block_id_range,
        )
    elif args.load_pretrained_model is None:
        logger.info(f"Load pretrained TripoSGDiTModel to initialize PartCrafterDiTModel from [{configs['model']['pretrained_model_name_or_path']}]\n")
        transformer, loading_info = PartCrafterDiTModel.from_pretrained(
            configs["model"]["pretrained_model_name_or_path"],
            subfolder="transformer",
            low_cpu_mem_usage=False,
            output_loading_info=True,
            enable_part_embedding=enable_part_embedding,
            enable_local_cross_attn=enable_local_cross_attn,
            enable_global_cross_attn=enable_global_cross_attn,
            global_attn_block_ids=global_attn_block_ids,
            global_attn_block_id_range=global_attn_block_id_range,
        )
    else:
        logger.info(f"Load PartCrafterDiTModel EMA checkpoint from [{args.load_pretrained_model}] iteration [{args.load_pretrained_model_ckpt:06d}]\n")
        path = os.path.join(
            args.output_dir,
            args.load_pretrained_model,
            "checkpoints",
            f"{args.load_pretrained_model_ckpt:06d}"
        )
        transformer, loading_info = PartCrafterDiTModel.from_pretrained(
            path,
            subfolder="transformer_ema",
            low_cpu_mem_usage=False,
            output_loading_info=True,
            enable_part_embedding=enable_part_embedding,
            enable_local_cross_attn=enable_local_cross_attn,
            enable_global_cross_attn=enable_global_cross_attn,
            global_attn_block_ids=global_attn_block_ids,
            global_attn_block_id_range=global_attn_block_id_range,
        )

    if not args.from_scratch:
        for v in loading_info.values():
            if v:
                logger.info(f"Loading info of PartCrafterDiTModel: {loading_info}\n")
                break

    noise_scheduler = RectifiedFlowScheduler.from_pretrained(
        configs["model"]["pretrained_model_name_or_path"],
        subfolder="scheduler"
    )

    if args.use_ema:
        ema_transformer = MyEMAModel(
            transformer.parameters(),
            model_cls=PartCrafterDiTModel,
            model_config=transformer.config,
            **configs["train"]["ema_kwargs"]
        )
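    # The EMA model keeps an exponential moving average of the transformer
    # weights (decay behavior comes from configs["train"]["ema_kwargs"]); the
    # averaged weights are swapped in for validation further below, which
    # typically produces more stable samples than the raw training weights.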

    # Freeze VAE and image encoder
    vae.requires_grad_(False)
    image_encoder_dinov2.requires_grad_(False)
    vae.eval()
    image_encoder_dinov2.eval()

    trainable_modules = configs["train"].get("trainable_modules", None)
    if trainable_modules is None:
        transformer.requires_grad_(True)
    else:
        trainable_module_names = []
        transformer.requires_grad_(False)
        for name, module in transformer.named_modules():
            for module_name in tuple(trainable_modules.split(",")):
                if module_name in name:
                    for params in module.parameters():
                        params.requires_grad = True
                    trainable_module_names.append(name)
        logger.info(f"Trainable parameter names: {trainable_module_names}\n")

    # transformer.enable_xformers_memory_efficient_attention()  # use `tF.scaled_dot_product_attention` instead

    # `accelerate` 0.16.0 will have better support for customized saving
    if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
        # Create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
        def save_model_hook(models, weights, output_dir):
            if accelerator.is_main_process:
                if args.use_ema:
                    ema_transformer.save_pretrained(os.path.join(output_dir, "transformer_ema"))
                for i, model in enumerate(models):
                    model.save_pretrained(os.path.join(output_dir, "transformer"))
                    # Make sure to pop weight so that the corresponding model is not saved again
                    if weights:
                        weights.pop()

        def load_model_hook(models, input_dir):
            if args.use_ema:
                load_model = MyEMAModel.from_pretrained(os.path.join(input_dir, "transformer_ema"), PartCrafterDiTModel)
                ema_transformer.load_state_dict(load_model.state_dict())
                ema_transformer.to(accelerator.device)
                del load_model
            for _ in range(len(models)):
                # Pop models so that they are not loaded again
                model = models.pop()
                # Load diffusers style into model
                load_model = PartCrafterDiTModel.from_pretrained(input_dir, subfolder="transformer")
                model.register_to_config(**load_model.config)
                model.load_state_dict(load_model.state_dict())
                del load_model

        accelerator.register_save_state_pre_hook(save_model_hook)
        accelerator.register_load_state_pre_hook(load_model_hook)

    if configs["train"]["grad_checkpoint"]:
        transformer.enable_gradient_checkpointing()
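        # Gradient checkpointing trades compute for memory: activations of the
        # DiT blocks are recomputed during the backward pass instead of being
        # kept from the forward pass.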

    # Initialize the optimizer and learning rate scheduler
    logger.info("Initializing the optimizer and learning rate scheduler...\n")
    name_lr_mult = configs["train"].get("name_lr_mult", None)
    lr_mult = configs["train"].get("lr_mult", 1.0)
    params, params_lr_mult, names_lr_mult = [], [], []
    for name, param in transformer.named_parameters():
        if name_lr_mult is not None:
            for k in name_lr_mult.split(","):
                if k in name:
                    params_lr_mult.append(param)
                    names_lr_mult.append(name)
            if name not in names_lr_mult:
                params.append(param)
        else:
            params.append(param)
    optimizer = get_optimizer(
        params=[
            {"params": params, "lr": configs["optimizer"]["lr"]},
            {"params": params_lr_mult, "lr": configs["optimizer"]["lr"] * lr_mult}
        ],
        **configs["optimizer"]
    )
    if name_lr_mult is not None:
        logger.info(f"Learning rate x [{lr_mult}] parameter names: {names_lr_mult}\n")

    configs["lr_scheduler"]["total_steps"] = configs["train"]["epochs"] * math.ceil(
        len(train_loader) // accelerator.num_processes / args.gradient_accumulation_steps)  # only count update steps
    configs["lr_scheduler"]["total_steps"] *= accelerator.num_processes  # for lr scheduler setting
    if "num_warmup_steps" in configs["lr_scheduler"]:
        configs["lr_scheduler"]["num_warmup_steps"] *= accelerator.num_processes  # for lr scheduler setting
    lr_scheduler = get_lr_scheduler(optimizer=optimizer, **configs["lr_scheduler"])
    configs["lr_scheduler"]["total_steps"] //= accelerator.num_processes  # reset for multi-gpu
    if "num_warmup_steps" in configs["lr_scheduler"]:
        configs["lr_scheduler"]["num_warmup_steps"] //= accelerator.num_processes  # reset for multi-gpu

    # Prepare everything with `accelerator`
    transformer, optimizer, lr_scheduler, train_loader, val_loader, random_val_loader = accelerator.prepare(
        transformer, optimizer, lr_scheduler, train_loader, val_loader, random_val_loader
    )
    # Set classes explicitly for everything
    transformer: DistributedDataParallel
    optimizer: AcceleratedOptimizer
    lr_scheduler: AcceleratedScheduler
    train_loader: DataLoaderShard
    val_loader: DataLoaderShard
    random_val_loader: DataLoaderShard

    if args.use_ema:
        ema_transformer.to(accelerator.device)

    # For mixed precision training we cast all non-trainable weights to half-precision,
    # as these weights are only used for inference; keeping them in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move `vae` and `image_encoder_dinov2` to GPU and cast to `weight_dtype`
    vae.to(accelerator.device, dtype=weight_dtype)
    image_encoder_dinov2.to(accelerator.device, dtype=weight_dtype)

    # Training configs after distribution and accumulation setup
    updated_steps_per_epoch = math.ceil(len(train_loader) / args.gradient_accumulation_steps)
    total_updated_steps = configs["lr_scheduler"]["total_steps"]
    if args.max_train_steps is None:
        args.max_train_steps = total_updated_steps
    assert configs["train"]["epochs"] * updated_steps_per_epoch == total_updated_steps
    if accelerator.num_processes > 1 and accelerator.is_main_process:
        print()
    accelerator.wait_for_everyone()
    logger.info(f"Total batch size: [{total_batch_size}]")
    logger.info(f"Learning rate: [{configs['optimizer']['lr']}]")
    logger.info(f"Gradient Accumulation steps: [{args.gradient_accumulation_steps}]")
    logger.info(f"Total epochs: [{configs['train']['epochs']}]")
    logger.info(f"Total steps: [{total_updated_steps}]")
    logger.info(f"Steps for updating per epoch: [{updated_steps_per_epoch}]")
    logger.info(f"Steps for validation: [{len(val_loader)}]\n")

    # (Optional) Load checkpoint
    global_update_step = 0
    if args.resume_from_iter is not None:
        if args.resume_from_iter < 0:
            args.resume_from_iter = int(sorted(os.listdir(ckpt_dir))[-1])
        logger.info(f"Load checkpoint from iteration [{args.resume_from_iter}]\n")
        # Load everything
        if version.parse(torch.__version__) >= version.parse("2.4.0"):
            torch.serialization.add_safe_globals([
                int, list, dict,
                defaultdict,
                Any,
                DictConfig, ListConfig, Metadata, ContainerMetadata, AnyNode
            ])  # avoid deserialization errors when loading the optimizer state
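            # `add_safe_globals` (available since torch 2.4) allow-lists these
            # types for unpickling; newer torch versions restrict `torch.load`
            # by default, which would otherwise reject the pickled
            # config/container objects inside the saved optimizer state.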
        accelerator.load_state(os.path.join(ckpt_dir, f"{args.resume_from_iter:06d}"))  # torch < 2.4.0 loads with `weights_only=False` here
        global_update_step = int(args.resume_from_iter)

    # Save all experimental parameters and model architecture of this run to a file (args and configs)
    if accelerator.is_main_process:
        exp_params = save_experiment_params(args, configs, exp_dir)
        save_model_architecture(accelerator.unwrap_model(transformer), exp_dir)

    # WandB logger
    if accelerator.is_main_process:
        if args.offline_wandb:
            os.environ["WANDB_MODE"] = "offline"
        wandb.init(
            project=PROJECT_NAME, name=args.tag,
            config=exp_params, dir=exp_dir,
            resume=True
        )
        # WandB artifact for logging experiment information
        arti_exp_info = wandb.Artifact(args.tag, type="exp_info")
        arti_exp_info.add_file(os.path.join(exp_dir, "params.yaml"))
        arti_exp_info.add_file(os.path.join(exp_dir, "model.txt"))
        arti_exp_info.add_file(os.path.join(exp_dir, "log.txt"))  # only saves the log before training
        wandb.log_artifact(arti_exp_info)

    def get_sigmas(timesteps: Tensor, n_dim: int, dtype=torch.float32):
        sigmas = noise_scheduler.sigmas.to(dtype=dtype, device=accelerator.device)
        schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device)
        timesteps = timesteps.to(accelerator.device)
        step_indices = [(schedule_timesteps == t).nonzero()[0].item() for t in timesteps]
        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < n_dim:
            sigma = sigma.unsqueeze(-1)
        return sigma
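    # `get_sigmas` looks up the noise level of each sampled timestep in the
    # scheduler's sigma table and right-pads the result with singleton dims so
    # it broadcasts against the latents, e.g. latents of shape [N, T, C] get
    # sigmas of shape [N, 1, 1].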

    # Start training
    if accelerator.is_main_process:
        print()
        logger.info(f"Start training into {exp_dir}\n")
        logger.logger.propagate = False  # do not propagate to the root logger (console)
    progress_bar = tqdm(
        range(total_updated_steps),
        initial=global_update_step,
        desc="Training",
        ncols=125,
        disable=not accelerator.is_main_process
    )

    for batch in yield_forever(train_loader):
        if global_update_step == args.max_train_steps:
            progress_bar.close()
            logger.logger.propagate = True  # propagate to the root logger (console)
            if accelerator.is_main_process:
                wandb.finish()
            logger.info("Training finished!\n")
            return
        transformer.train()

        with accelerator.accumulate(transformer):
            images = batch["images"]  # [N, H, W, 3]
            with torch.no_grad():
                images = feature_extractor_dinov2(images=images, return_tensors="pt").pixel_values
            images = images.to(device=accelerator.device, dtype=weight_dtype)
            with torch.no_grad():
                image_embeds = image_encoder_dinov2(images).last_hidden_state
            negative_image_embeds = torch.zeros_like(image_embeds)

            part_surfaces = batch["part_surfaces"]  # [N, P, 6]
            part_surfaces = part_surfaces.to(device=accelerator.device, dtype=weight_dtype)
            num_parts = batch["num_parts"]  # [M, ] The shape of num_parts is not fixed
            num_objects = num_parts.shape[0]  # M

            with torch.no_grad():
                latents = vae.encode(
                    part_surfaces,
                    **configs["model"]["vae"]
                ).latent_dist.sample()
            noise = torch.randn_like(latents)

            # For weighting schemes where we sample timesteps non-uniformly
            u = compute_density_for_timestep_sampling(
                weighting_scheme=configs["train"]["weighting_scheme"],
                batch_size=num_objects,
                logit_mean=configs["train"]["logit_mean"],
                logit_std=configs["train"]["logit_std"],
                mode_scale=configs["train"]["mode_scale"],
            )
            indices = (u * noise_scheduler.config.num_train_timesteps).long()
            timesteps = noise_scheduler.timesteps[indices].to(accelerator.device)  # [M, ]
            # Repeat the timesteps for each part
            timesteps = timesteps.repeat_interleave(num_parts)  # [N, ]
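            # E.g., num_parts = [2, 3] expands object-level timesteps [t0, t1]
            # to part-level [t0, t0, t1, t1, t1], so all parts of one object
            # share the same noise level.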
            sigmas = get_sigmas(timesteps, len(latents.shape), weight_dtype)
            latent_model_input = noisy_latents = (1. - sigmas) * latents + sigmas * noise
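            # Rectified-flow forward process: x_sigma = (1 - sigma) * x_0 + sigma * noise,
            # a straight-line interpolation from clean latents (sigma = 0) to
            # pure noise (sigma = 1).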
if configs["train"]["cfg_dropout_prob"] > 0: | |
# We use the same dropout mask for the same part | |
dropout_mask = torch.rand(num_objects, device=accelerator.device) < configs["train"]["cfg_dropout_prob"] # [M, ] | |
dropout_mask = dropout_mask.repeat_interleave(num_parts) # [N, ] | |
if dropout_mask.any(): | |
image_embeds[dropout_mask] = negative_image_embeds[dropout_mask] | |
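            # Replacing the image embeddings with zeros at probability
            # `cfg_dropout_prob` trains the unconditional branch needed for
            # classifier-free guidance at inference:
            # pred = uncond + guidance_scale * (cond - uncond).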
            model_pred = transformer(
                hidden_states=latent_model_input,
                timestep=timesteps,
                encoder_hidden_states=image_embeds,
                attention_kwargs={"num_parts": num_parts}
            ).sample

            if configs["train"]["training_objective"] == "x0":  # Section 5 of https://arxiv.org/abs/2206.00364
                model_pred = model_pred * (-sigmas) + noisy_latents  # predicted x_0
                target = latents
            elif configs["train"]["training_objective"] == 'v':  # flow matching
                target = noise - latents
            elif configs["train"]["training_objective"] == '-v':  # reversed flow matching
                # The training objective for TripoSG is the reverse of the usual flow matching objective:
                # it predicts the negative velocity. This is likely an engineering accident rather than a
                # deliberate choice, and it is harmless as long as the scheduler is consistent with it.
                # In TripoSG's rectified flow scheduler, prev_sample = sample + (sigma - sigma_next) * model_output,
                # see https://github.com/VAST-AI-Research/TripoSG/blob/main/triposg/schedulers/scheduling_rectified_flow.py#L296
                # while in diffusers' flow matching scheduler, prev_sample = sample + (sigma_next - sigma) * model_output,
                # see https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py#L454
                target = latents - noise
            else:
                raise ValueError(f"Unknown training objective [{configs['train']['training_objective']}]")

            # These weighting schemes use uniform timestep sampling, so post-weight the loss
            weighting = compute_loss_weighting_for_sd3(
                configs["train"]["weighting_scheme"],
                sigmas
            )
            loss = weighting * tF.mse_loss(model_pred.float(), target.float(), reduction="none")
            loss = loss.mean(dim=list(range(1, len(loss.shape))))
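            # `loss` now holds one scalar per part ([N, ]); the per-part values
            # are gathered for logging below, and only their mean is used for
            # the backward pass.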
            # Backpropagate
            accelerator.backward(loss.mean())
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(transformer.parameters(), args.max_grad_norm)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        # Check whether the accelerator has performed an optimization step behind the scenes
        if accelerator.sync_gradients:
            # Gather the losses across all processes for logging (if we use distributed training)
            loss = accelerator.gather(loss.detach()).mean()
            logs = {
                "loss": loss.item(),
                "lr": lr_scheduler.get_last_lr()[0]
            }
            if args.use_ema:
                ema_transformer.step(transformer.parameters())
                logs.update({"ema": ema_transformer.cur_decay_value})
            progress_bar.set_postfix(**logs)
            progress_bar.update(1)
            global_update_step += 1
            logger.info(
                f"[{global_update_step:06d} / {total_updated_steps:06d}] " +
                f"loss: {logs['loss']:.4f}, lr: {logs['lr']:.2e}" +
                (f", ema: {logs['ema']:.4f}" if args.use_ema else "")
            )

            # Log the training progress
            if (
                global_update_step % configs["train"]["log_freq"] == 0
                or global_update_step == 1
                or global_update_step % updated_steps_per_epoch == 0  # last step of an epoch
            ):
                if accelerator.is_main_process:
                    wandb.log({
                        "training/loss": logs["loss"],
                        "training/lr": logs["lr"],
                    }, step=global_update_step)
                    if args.use_ema:
                        wandb.log({
                            "training/ema": logs["ema"]
                        }, step=global_update_step)

            # Save checkpoint
            if (
                global_update_step % configs["train"]["save_freq"] == 0  # 1. every `save_freq` steps
                or global_update_step % (configs["train"]["save_freq_epoch"] * updated_steps_per_epoch) == 0  # 2. every `save_freq_epoch` epochs
                or global_update_step == total_updated_steps  # 3. last step of training
                # or global_update_step == 1  # 4. first step
            ):
                gc.collect()
                if accelerator.distributed_type == accelerate.utils.DistributedType.DEEPSPEED:
                    # DeepSpeed requires saving weights on every device; saving only on the main process would cause issues
                    accelerator.save_state(os.path.join(ckpt_dir, f"{global_update_step:06d}"))
                elif accelerator.is_main_process:
                    accelerator.save_state(os.path.join(ckpt_dir, f"{global_update_step:06d}"))
                accelerator.wait_for_everyone()  # ensure all processes have finished saving
                gc.collect()

            # Evaluate on the validation set
            if args.max_val_steps > 0 and (
                (global_update_step % configs["train"]["early_eval_freq"] == 0 and global_update_step < configs["train"]["early_eval"])  # 1. more frequently at the beginning
                or global_update_step % configs["train"]["eval_freq"] == 0  # 2. every `eval_freq` steps
                or global_update_step % (configs["train"]["eval_freq_epoch"] * updated_steps_per_epoch) == 0  # 3. every `eval_freq_epoch` epochs
                or global_update_step == total_updated_steps  # 4. last step of training
                or global_update_step == 1  # 5. first step
            ):
                # Use EMA parameters for evaluation
                if args.use_ema:
                    # Store the Transformer parameters temporarily and load the EMA parameters to perform inference
                    ema_transformer.store(transformer.parameters())
                    ema_transformer.copy_to(transformer.parameters())
                transformer.eval()
                log_validation(
                    val_loader, random_val_loader,
                    feature_extractor_dinov2, image_encoder_dinov2,
                    vae, transformer,
                    global_update_step, eval_dir,
                    accelerator, logger,
                    args, configs
                )
                if args.use_ema:
                    # Switch back to the original Transformer parameters
                    ema_transformer.restore(transformer.parameters())
                torch.cuda.empty_cache()
                gc.collect()


def log_validation(
    dataloader, random_dataloader,
    feature_extractor_dinov2, image_encoder_dinov2,
    vae, transformer,
    global_step, eval_dir,
    accelerator, logger,
    args, configs
):
    val_noise_scheduler = RectifiedFlowScheduler.from_pretrained(
        configs["model"]["pretrained_model_name_or_path"],
        subfolder="scheduler"
    )
    pipeline = PartCrafterPipeline(
        vae=vae,
        transformer=accelerator.unwrap_model(transformer),
        scheduler=val_noise_scheduler,
        feature_extractor_dinov2=feature_extractor_dinov2,
        image_encoder_dinov2=image_encoder_dinov2,
    )
    pipeline.set_progress_bar_config(disable=True)
    # pipeline.enable_xformers_memory_efficient_attention()

    if args.seed >= 0:
        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
    else:
        generator = None

    val_progress_bar = tqdm(
        range(len(dataloader)) if args.max_val_steps is None else range(args.max_val_steps),
        desc=f"Validation [{global_step:06d}]",
        ncols=125,
        disable=not accelerator.is_main_process
    )
    medias_dictlist, metrics_dictlist = defaultdict(list), defaultdict(list)
    val_dataloader, random_val_dataloader = yield_forever(dataloader), yield_forever(random_dataloader)
    val_step = 0
    while val_step < args.max_val_steps:
        if val_step < args.max_val_steps // 2:
            # Use fixed samples for the first half
            batch = next(val_dataloader)
        else:
            # Randomly sample batches for the second half
            batch = next(random_val_dataloader)

        images = batch["images"]
        if len(images.shape) == 5:
            images = images[0]  # (1, N, H, W, 3) -> (N, H, W, 3)
        images = [Image.fromarray(image) for image in images.cpu().numpy()]
        part_surfaces = batch["part_surfaces"].cpu().numpy()
        if len(part_surfaces.shape) == 4:
            part_surfaces = part_surfaces[0]  # (1, N, P, 6) -> (N, P, 6)
        N = len(images)
        val_progress_bar.set_postfix(
            {"num_parts": N}
        )
        with torch.autocast("cuda", torch.float16):
            for guidance_scale in sorted(args.val_guidance_scales):
                pred_part_meshes = pipeline(
                    images,
                    num_inference_steps=configs['val']['num_inference_steps'],
                    num_tokens=configs['model']['vae']['num_tokens'],
                    guidance_scale=guidance_scale,
                    attention_kwargs={"num_parts": N},
                    generator=generator,
                    max_num_expanded_coords=configs['val']['max_num_expanded_coords'],
                    use_flash_decoder=configs['val']['use_flash_decoder'],
                ).meshes

                # Save the generated meshes
                if accelerator.is_main_process:
                    local_eval_dir = os.path.join(eval_dir, f"{global_step:06d}", f"guidance_scale_{guidance_scale:.1f}")
                    os.makedirs(local_eval_dir, exist_ok=True)
                    rendered_images_list, rendered_normals_list = [], []
                    # 1. Save the GT image
                    images[0].save(os.path.join(local_eval_dir, f"{val_step:04d}.png"))
                    # 2. Save the generated part meshes
                    for n in range(N):
                        if pred_part_meshes[n] is None:
                            # If the generated mesh is None (decoding error), use a dummy mesh
                            pred_part_meshes[n] = trimesh.Trimesh(vertices=[[0, 0, 0]], faces=[[0, 0, 0]])
                        pred_part_meshes[n].export(os.path.join(local_eval_dir, f"{val_step:04d}_{n:02d}.glb"))
                    # 3. Render the generated mesh and save the rendered images
                    pred_mesh = get_colored_mesh_composition(pred_part_meshes)
                    rendered_images: List[Image.Image] = render_views_around_mesh(
                        pred_mesh,
                        num_views=configs['val']['rendering']['num_views'],
                        radius=configs['val']['rendering']['radius'],
                    )
                    rendered_normals: List[Image.Image] = render_normal_views_around_mesh(
                        pred_mesh,
                        num_views=configs['val']['rendering']['num_views'],
                        radius=configs['val']['rendering']['radius'],
                    )
                    export_renderings(
                        rendered_images,
                        os.path.join(local_eval_dir, f"{val_step:04d}.gif"),
                        fps=configs['val']['rendering']['fps']
                    )
                    export_renderings(
                        rendered_normals,
                        os.path.join(local_eval_dir, f"{val_step:04d}_normals.gif"),
                        fps=configs['val']['rendering']['fps']
                    )
                    rendered_images_list.append(rendered_images)
                    rendered_normals_list.append(rendered_normals)
                    medias_dictlist[f"guidance_scale_{guidance_scale:.1f}/gt_image"] += [images[0]]  # List[Image.Image] TODO: support batch size > 1
                    medias_dictlist[f"guidance_scale_{guidance_scale:.1f}/pred_rendered_images"] += rendered_images_list  # List[List[Image.Image]]
                    medias_dictlist[f"guidance_scale_{guidance_scale:.1f}/pred_rendered_normals"] += rendered_normals_list  # List[List[Image.Image]]
                    ################################ Compute generation metrics ################################
                    parts_chamfer_distances, parts_f_scores = [], []
                    for n in range(N):
                        # gt_part_surface = part_surfaces[n]
                        # pred_part_mesh = pred_part_meshes[n]
                        # if pred_part_mesh is None:
                        #     # If the generated mesh is None (decoding error), use a dummy mesh
                        #     pred_part_mesh = trimesh.Trimesh(vertices=[[0, 0, 0]], faces=[[0, 0, 0]])
                        # part_cd, part_f = compute_cd_and_f_score_in_training(
                        #     gt_part_surface, pred_part_mesh,
                        #     num_samples=configs['val']['metric']['cd_num_samples'],
                        #     threshold=configs['val']['metric']['f1_score_threshold'],
                        #     metric=configs['val']['metric']['cd_metric']
                        # )
                        # # Avoid NaN
                        # part_cd = configs['val']['metric']['default_cd'] if np.isnan(part_cd) else part_cd
                        # part_f = configs['val']['metric']['default_f1'] if np.isnan(part_f) else part_f
                        # parts_chamfer_distances.append(part_cd)
                        # parts_f_scores.append(part_f)
                        # TODO: Fix this. Chamfer distance and F1 score are disabled for now.
                        parts_chamfer_distances.append(0.0)
                        parts_f_scores.append(0.0)
                    parts_chamfer_distances = torch.tensor(parts_chamfer_distances, device=accelerator.device)
                    parts_f_scores = torch.tensor(parts_f_scores, device=accelerator.device)
                    metrics_dictlist[f"parts_chamfer_distance_cfg{guidance_scale:.1f}"].append(parts_chamfer_distances.mean())
                    metrics_dictlist[f"parts_f_score_cfg{guidance_scale:.1f}"].append(parts_f_scores.mean())
                    # Only the metrics of the last (largest) CFG scale remain in the progress bar
                    val_logs = {
                        "parts_chamfer_distance": parts_chamfer_distances.mean().item(),
                        "parts_f_score": parts_f_scores.mean().item(),
                    }
                    val_progress_bar.set_postfix(**val_logs)
                    logger.info(
                        f"Validation [{val_step:02d}/{args.max_val_steps:02d}] " +
                        f"parts_chamfer_distance: {val_logs['parts_chamfer_distance']:.4f}, parts_f_score: {val_logs['parts_f_score']:.4f}"
                    )
                    logger.info(
                        f"parts_chamfer_distances: {[f'{x:.4f}' for x in parts_chamfer_distances.tolist()]}"
                    )
                    logger.info(
                        f"parts_f_scores: {[f'{x:.4f}' for x in parts_f_scores.tolist()]}"
                    )
        val_step += 1
        val_progress_bar.update(1)
    val_progress_bar.close()

    if accelerator.is_main_process:
        for key, value in medias_dictlist.items():
            if isinstance(value[0], Image.Image):  # assuming gt_image
                image_grid = make_grid_for_images_or_videos(
                    value,
                    nrow=configs['val']['nrow'],
                    return_type='pil',
                )
                image_grid.save(os.path.join(eval_dir, f"{global_step:06d}", f"{key}.png"))
                wandb.log({f"validation/{key}": wandb.Image(image_grid)}, step=global_step)
            else:  # assuming pred_rendered_images or pred_rendered_normals
                image_grids = make_grid_for_images_or_videos(
                    value,
                    nrow=configs['val']['nrow'],
                    return_type='ndarray',
                )
                wandb.log({
                    f"validation/{key}": wandb.Video(
                        image_grids,
                        fps=configs['val']['rendering']['fps'],
                        format="gif"
                    )}, step=global_step)
                image_grids = [Image.fromarray(image_grid.transpose(1, 2, 0)) for image_grid in image_grids]
                export_renderings(
                    image_grids,
                    os.path.join(eval_dir, f"{global_step:06d}", f"{key}.gif"),
                    fps=configs['val']['rendering']['fps']
                )
        for k, v in metrics_dictlist.items():
            wandb.log({f"validation/{k}": torch.tensor(v).mean().item()}, step=global_step)


if __name__ == "__main__":
    main()