# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
# Part of the code comes from https://github.com/nerfstudio-project/gsplat/blob/main/examples/simple_trainer.py
# Both under the Apache License, Version 2.0.


import json
import os
import time
from collections import defaultdict
from typing import Dict, Optional, Tuple

import cv2
import imageio
import numpy as np
import torch
import torch.nn.functional as F
import tqdm
import tyro
import yaml
from fused_ssim import fused_ssim
from gsplat.distributed import cli
from gsplat.rendering import rasterization
from gsplat.strategy import DefaultStrategy, MCMCStrategy
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter
from torchmetrics.image import (
    PeakSignalNoiseRatio,
    StructuralSimilarityIndexMeasure,
)
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from typing_extensions import Literal, assert_never
from embodied_gen.data.datasets import PanoGSplatDataset
from embodied_gen.utils.config import GsplatTrainConfig
from embodied_gen.utils.gaussian import (
    create_splats_with_optimizers,
    export_splats,
    resize_pinhole_intrinsics,
    set_random_seed,
)


class Runner:
    """Engine for training and testing from gsplat example.

    Code from https://github.com/nerfstudio-project/gsplat/blob/main/examples/simple_trainer.py

    The runner owns the splat parameters and their optimizers, the
    densification strategy state, the image metrics and the output
    directories (``ckpts``, ``stats``, ``renders``, ``ply``, ``tb``) created
    under ``cfg.result_dir``.
    """

    def __init__(
        self,
        local_rank: int,
        world_rank,
        world_size: int,
        cfg: GsplatTrainConfig,
    ) -> None:
        """Build datasets, splats, optimizers, strategy state and metrics.

        Args:
            local_rank: Index used to pick the CUDA device (``cuda:{local_rank}``)
                and to offset the random seed.
            world_rank: Rank of this process; rank 0 does the logging/dumping.
            world_size: Number of processes; values > 1 enable distributed
                rasterization.
            cfg: Training configuration.

        Raises:
            ValueError: If ``cfg.lpips_net`` is neither ``"alex"`` nor ``"vgg"``.
        """
        set_random_seed(42 + local_rank)

        self.cfg = cfg
        self.world_rank = world_rank
        self.local_rank = local_rank
        self.world_size = world_size
        self.device = f"cuda:{local_rank}"

        # Where to dump results.
        os.makedirs(cfg.result_dir, exist_ok=True)

        # Setup output directories.
        self.ckpt_dir = f"{cfg.result_dir}/ckpts"
        os.makedirs(self.ckpt_dir, exist_ok=True)
        self.stats_dir = f"{cfg.result_dir}/stats"
        os.makedirs(self.stats_dir, exist_ok=True)
        self.render_dir = f"{cfg.result_dir}/renders"
        os.makedirs(self.render_dir, exist_ok=True)
        self.ply_dir = f"{cfg.result_dir}/ply"
        os.makedirs(self.ply_dir, exist_ok=True)

        # Tensorboard
        self.writer = SummaryWriter(log_dir=f"{cfg.result_dir}/tb")

        self.trainset = PanoGSplatDataset(cfg.data_dir, split="train")
        # NOTE: the validation set is a small subset drawn from the "train"
        # split (at most 6 samples), not from a held-out split.
        self.valset = PanoGSplatDataset(
            cfg.data_dir, split="train", max_sample_num=6
        )
        self.testset = PanoGSplatDataset(cfg.data_dir, split="eval")
        self.scene_scale = cfg.scene_scale

        # Model
        self.splats, self.optimizers = create_splats_with_optimizers(
            self.trainset.points,
            self.trainset.points_rgb,
            init_num_pts=cfg.init_num_pts,
            init_extent=cfg.init_extent,
            init_opacity=cfg.init_opa,
            init_scale=cfg.init_scale,
            means_lr=cfg.means_lr,
            scales_lr=cfg.scales_lr,
            opacities_lr=cfg.opacities_lr,
            quats_lr=cfg.quats_lr,
            sh0_lr=cfg.sh0_lr,
            shN_lr=cfg.shN_lr,
            scene_scale=self.scene_scale,
            sh_degree=cfg.sh_degree,
            sparse_grad=cfg.sparse_grad,
            visible_adam=cfg.visible_adam,
            batch_size=cfg.batch_size,
            feature_dim=None,
            device=self.device,
            world_rank=world_rank,
            world_size=world_size,
        )
        print("Model initialized. Number of GS:", len(self.splats["means"]))

        # Densification Strategy
        self.cfg.strategy.check_sanity(self.splats, self.optimizers)

        if isinstance(self.cfg.strategy, DefaultStrategy):
            self.strategy_state = self.cfg.strategy.initialize_state(
                scene_scale=self.scene_scale
            )
        elif isinstance(self.cfg.strategy, MCMCStrategy):
            self.strategy_state = self.cfg.strategy.initialize_state()
        else:
            assert_never(self.cfg.strategy)

        # Losses & Metrics.
        self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(
            self.device
        )
        self.psnr = PeakSignalNoiseRatio(data_range=1.0).to(self.device)

        if cfg.lpips_net == "alex":
            self.lpips = LearnedPerceptualImagePatchSimilarity(
                net_type="alex", normalize=True
            ).to(self.device)
        elif cfg.lpips_net == "vgg":
            # The 3DGS official repo uses lpips vgg, which is equivalent with the following:
            self.lpips = LearnedPerceptualImagePatchSimilarity(
                net_type="vgg", normalize=False
            ).to(self.device)
        else:
            raise ValueError(f"Unknown LPIPS network: {cfg.lpips_net}")

    def rasterize_splats(
        self,
        camtoworlds: Tensor,
        Ks: Tensor,
        width: int,
        height: int,
        masks: Optional[Tensor] = None,
        rasterize_mode: Optional[Literal["classic", "antialiased"]] = None,
        camera_model: Optional[Literal["pinhole", "ortho", "fisheye"]] = None,
        **kwargs,
    ) -> Tuple[Tensor, Tensor, Dict]:
        """Render the current splats for a batch of cameras.

        Args:
            camtoworlds: Camera-to-world matrices, [C, 4, 4]; inverted here to
                obtain the view matrices.
            Ks: Camera intrinsics, [C, 3, 3].
            width: Render width in pixels.
            height: Render height in pixels.
            masks: Optional boolean mask; rendered colors where the mask is
                False are set to 0.
            rasterize_mode: Overrides the mode derived from
                ``cfg.antialiased`` when given.
            camera_model: Overrides ``cfg.camera_model`` when given.
            **kwargs: Forwarded to ``gsplat.rendering.rasterization``. An
                ``image_ids`` entry is accepted but dropped before forwarding.

        Returns:
            ``(render_colors, render_alphas, info)`` as returned by
            ``rasterization``.
        """
        means = self.splats["means"]  # [N, 3]
        # quats = F.normalize(self.splats["quats"], dim=-1)  # [N, 4]
        # rasterization does normalization internally
        quats = self.splats["quats"]  # [N, 4]
        scales = torch.exp(self.splats["scales"])  # [N, 3]
        opacities = torch.sigmoid(self.splats["opacities"])  # [N,]

        # `image_ids` is supplied by the training loop but is not an argument
        # of `rasterization`, so strip it from the forwarded kwargs.
        kwargs.pop("image_ids", None)

        colors = torch.cat(
            [self.splats["sh0"], self.splats["shN"]], 1
        )  # [N, K, 3]

        if rasterize_mode is None:
            rasterize_mode = (
                "antialiased" if self.cfg.antialiased else "classic"
            )
        if camera_model is None:
            camera_model = self.cfg.camera_model

        render_colors, render_alphas, info = rasterization(
            means=means,
            quats=quats,
            scales=scales,
            opacities=opacities,
            colors=colors,
            viewmats=torch.linalg.inv(camtoworlds),  # [C, 4, 4]
            Ks=Ks,  # [C, 3, 3]
            width=width,
            height=height,
            packed=self.cfg.packed,
            absgrad=(
                self.cfg.strategy.absgrad
                if isinstance(self.cfg.strategy, DefaultStrategy)
                else False
            ),
            sparse_grad=self.cfg.sparse_grad,
            rasterize_mode=rasterize_mode,
            distributed=self.world_size > 1,
            # Use the resolved value so an explicit `camera_model` argument is
            # honoured instead of being silently replaced by the config value.
            camera_model=camera_model,
            with_ut=self.cfg.with_ut,
            with_eval3d=self.cfg.with_eval3d,
            **kwargs,
        )
        if masks is not None:
            render_colors[~masks] = 0
        return render_colors, render_alphas, info

    def train(self) -> None:
        """Run the optimization loop for ``cfg.max_steps`` steps.

        Each step renders a training batch, computes the L1 + SSIM loss (plus
        optional depth loss and opacity/scale regularizers), backpropagates,
        steps the optimizers and the densification strategy. Checkpoints, PLY
        exports, tensorboard scalars and evaluations are emitted at the steps
        configured in ``cfg``.
        """
        cfg = self.cfg
        device = self.device
        world_rank = self.world_rank

        # Dump cfg.
        if world_rank == 0:
            with open(f"{cfg.result_dir}/cfg.yml", "w") as f:
                yaml.dump(vars(cfg), f)

        max_steps = cfg.max_steps
        init_step = 0

        schedulers = [
            # means has a learning rate schedule, that end at 0.01 of the initial value
            torch.optim.lr_scheduler.ExponentialLR(
                self.optimizers["means"], gamma=0.01 ** (1.0 / max_steps)
            ),
        ]
        trainloader = torch.utils.data.DataLoader(
            self.trainset,
            batch_size=cfg.batch_size,
            shuffle=True,
            num_workers=4,
            persistent_workers=True,
            pin_memory=True,
        )
        trainloader_iter = iter(trainloader)

        # The configured steps are 1-based while `step` is 0-based; build the
        # lookup sets once instead of rebuilding a list on every iteration.
        save_at = {i - 1 for i in cfg.save_steps}
        ply_at = {i - 1 for i in cfg.ply_steps}
        eval_at = {i - 1 for i in cfg.eval_steps}
        last_step = max_steps - 1

        # Training loop.
        global_tic = time.time()
        pbar = tqdm.tqdm(range(init_step, max_steps))
        for step in pbar:
            try:
                data = next(trainloader_iter)
            except StopIteration:
                # Loader exhausted: start a new epoch.
                trainloader_iter = iter(trainloader)
                data = next(trainloader_iter)

            camtoworlds = data["camtoworld"].to(device)  # [1, 4, 4]
            Ks = data["K"].to(device)  # [1, 3, 3]
            pixels = data["image"].to(device) / 255.0  # [1, H, W, 3]
            image_ids = data["image_id"].to(device)
            masks = (
                data["mask"].to(device) if "mask" in data else None
            )  # [1, H, W]
            if cfg.depth_loss:
                points = data["points"].to(device)  # [1, M, 2]
                depths_gt = data["depths"].to(device)  # [1, M]

            height, width = pixels.shape[1:3]

            # sh schedule
            sh_degree_to_use = min(
                step // cfg.sh_degree_interval, cfg.sh_degree
            )

            # forward
            renders, alphas, info = self.rasterize_splats(
                camtoworlds=camtoworlds,
                Ks=Ks,
                width=width,
                height=height,
                sh_degree=sh_degree_to_use,
                near_plane=cfg.near_plane,
                far_plane=cfg.far_plane,
                image_ids=image_ids,
                render_mode="RGB+ED" if cfg.depth_loss else "RGB",
                masks=masks,
            )
            if renders.shape[-1] == 4:
                colors, depths = renders[..., 0:3], renders[..., 3:4]
            else:
                colors, depths = renders, None

            if cfg.random_bkgd:
                bkgd = torch.rand(1, 3, device=device)
                colors = colors + bkgd * (1.0 - alphas)

            self.cfg.strategy.step_pre_backward(
                params=self.splats,
                optimizers=self.optimizers,
                state=self.strategy_state,
                step=step,
                info=info,
            )

            # loss
            l1loss = F.l1_loss(colors, pixels)
            ssimloss = 1.0 - fused_ssim(
                colors.permute(0, 3, 1, 2),
                pixels.permute(0, 3, 1, 2),
                padding="valid",
            )
            loss = (
                l1loss * (1.0 - cfg.ssim_lambda) + ssimloss * cfg.ssim_lambda
            )
            if cfg.depth_loss:
                # query depths from depth map
                points = torch.stack(
                    [
                        points[:, :, 0] / (width - 1) * 2 - 1,
                        points[:, :, 1] / (height - 1) * 2 - 1,
                    ],
                    dim=-1,
                )  # normalize to [-1, 1]
                grid = points.unsqueeze(2)  # [1, M, 1, 2]
                depths = F.grid_sample(
                    depths.permute(0, 3, 1, 2), grid, align_corners=True
                )  # [1, 1, M, 1]
                depths = depths.squeeze(3).squeeze(1)  # [1, M]
                # calculate loss in disparity space
                disp = torch.where(
                    depths > 0.0, 1.0 / depths, torch.zeros_like(depths)
                )
                disp_gt = 1.0 / depths_gt  # [1, M]
                depthloss = F.l1_loss(disp, disp_gt) * self.scene_scale
                loss += depthloss * cfg.depth_lambda

            # regularizations
            if cfg.opacity_reg > 0.0:
                loss += (
                    cfg.opacity_reg
                    * torch.sigmoid(self.splats["opacities"]).mean()
                )
            if cfg.scale_reg > 0.0:
                loss += cfg.scale_reg * torch.exp(self.splats["scales"]).mean()

            loss.backward()

            desc = (
                f"loss={loss.item():.3f}| " f"sh degree={sh_degree_to_use}| "
            )
            if cfg.depth_loss:
                desc += f"depth loss={depthloss.item():.6f}| "
            pbar.set_description(desc)

            # write images (gt and render)
            # if world_rank == 0 and step % 800 == 0:
            #     canvas = torch.cat([pixels, colors], dim=2).detach().cpu().numpy()
            #     canvas = canvas.reshape(-1, *canvas.shape[2:])
            #     imageio.imwrite(
            #         f"{self.render_dir}/train_rank{self.world_rank}.png",
            #         (canvas * 255).astype(np.uint8),
            #     )

            if (
                world_rank == 0
                and cfg.tb_every > 0
                and step % cfg.tb_every == 0
            ):
                mem = torch.cuda.max_memory_allocated() / 1024**3
                self.writer.add_scalar("train/loss", loss.item(), step)
                self.writer.add_scalar("train/l1loss", l1loss.item(), step)
                self.writer.add_scalar("train/ssimloss", ssimloss.item(), step)
                self.writer.add_scalar(
                    "train/num_GS", len(self.splats["means"]), step
                )
                self.writer.add_scalar("train/mem", mem, step)
                if cfg.depth_loss:
                    self.writer.add_scalar(
                        "train/depthloss", depthloss.item(), step
                    )
                if cfg.tb_save_image:
                    canvas = (
                        torch.cat([pixels, colors], dim=2)
                        .detach()
                        .cpu()
                        .numpy()
                    )
                    canvas = canvas.reshape(-1, *canvas.shape[2:])
                    # The canvas is laid out as [B*H, 2*W, 3] (channels last),
                    # whereas add_image defaults to "CHW"; state the layout
                    # explicitly so the image is interpreted correctly.
                    self.writer.add_image(
                        "train/render", canvas, step, dataformats="HWC"
                    )
                self.writer.flush()

            # save checkpoint before updating the model
            if step in save_at or step == last_step:
                mem = torch.cuda.max_memory_allocated() / 1024**3
                stats = {
                    "mem": mem,
                    "ellipse_time": time.time() - global_tic,
                    "num_GS": len(self.splats["means"]),
                }
                print("Step: ", step, stats)
                with open(
                    f"{self.stats_dir}/train_step{step:04d}_rank{self.world_rank}.json",
                    "w",
                ) as f:
                    json.dump(stats, f)
                data = {"step": step, "splats": self.splats.state_dict()}
                torch.save(
                    data,
                    f"{self.ckpt_dir}/ckpt_{step}_rank{self.world_rank}.pt",
                )
            if (step in ply_at or step == last_step) and cfg.save_ply:
                sh0 = self.splats["sh0"]
                shN = self.splats["shN"]
                means = self.splats["means"]
                scales = self.splats["scales"]
                quats = self.splats["quats"]
                opacities = self.splats["opacities"]
                export_splats(
                    means=means,
                    scales=scales,
                    quats=quats,
                    opacities=opacities,
                    sh0=sh0,
                    shN=shN,
                    format="ply",
                    save_to=f"{self.ply_dir}/point_cloud_{step}.ply",
                )

            # Turn Gradients into Sparse Tensor before running optimizer
            if cfg.sparse_grad:
                assert (
                    cfg.packed
                ), "Sparse gradients only work with packed mode."
                gaussian_ids = info["gaussian_ids"]
                for k in self.splats.keys():
                    grad = self.splats[k].grad
                    if grad is None or grad.is_sparse:
                        continue
                    self.splats[k].grad = torch.sparse_coo_tensor(
                        indices=gaussian_ids[None],  # [1, nnz]
                        values=grad[gaussian_ids],  # [nnz, ...]
                        size=self.splats[k].size(),  # [N, ...]
                        is_coalesced=len(Ks) == 1,
                    )

            if cfg.visible_adam:
                if cfg.packed:
                    visibility_mask = torch.zeros_like(
                        self.splats["opacities"], dtype=bool
                    )
                    visibility_mask.scatter_(0, info["gaussian_ids"], 1)
                else:
                    visibility_mask = (info["radii"] > 0).all(-1).any(0)

            # optimize
            for optimizer in self.optimizers.values():
                if cfg.visible_adam:
                    optimizer.step(visibility_mask)
                else:
                    optimizer.step()
                optimizer.zero_grad(set_to_none=True)
            for scheduler in schedulers:
                scheduler.step()

            # Run post-backward steps after backward and optimizer
            if isinstance(self.cfg.strategy, DefaultStrategy):
                self.cfg.strategy.step_post_backward(
                    params=self.splats,
                    optimizers=self.optimizers,
                    state=self.strategy_state,
                    step=step,
                    info=info,
                    packed=cfg.packed,
                )
            elif isinstance(self.cfg.strategy, MCMCStrategy):
                self.cfg.strategy.step_post_backward(
                    params=self.splats,
                    optimizers=self.optimizers,
                    state=self.strategy_state,
                    step=step,
                    info=info,
                    lr=schedulers[0].get_last_lr()[0],
                )
            else:
                assert_never(self.cfg.strategy)

            # eval the full set
            if step in eval_at:
                self.eval(step)
                self.render_video(step)

    @torch.no_grad()
    def eval(
        self,
        step: int,
        stage: str = "val",
        canvas_h: int = 512,
        canvas_w: int = 1024,
    ) -> None:
        """Entry for evaluation.

        Renders every sample of the validation set, resizes ground truth and
        render to ``(canvas_h, canvas_w // 2)``, and accumulates PSNR, SSIM
        and LPIPS. Rank 0 additionally writes side-by-side PNGs, a JSON stats
        file and tensorboard scalars.

        Args:
            step: Training step used to tag output files and scalars.
            stage: Prefix for output file names and tensorboard tags.
            canvas_h: Height of each resized image.
            canvas_w: Twice the width of each resized image.
        """
        print("Running evaluation...")
        cfg = self.cfg
        device = self.device
        world_rank = self.world_rank

        valloader = torch.utils.data.DataLoader(
            self.valset, batch_size=1, shuffle=False, num_workers=1
        )
        ellipse_time = 0
        metrics = defaultdict(list)
        for i, data in enumerate(valloader):
            camtoworlds = data["camtoworld"].to(device)
            Ks = data["K"].to(device)
            pixels = data["image"].to(device) / 255.0
            height, width = pixels.shape[1:3]
            masks = data["mask"].to(device) if "mask" in data else None

            pixels = pixels.permute(0, 3, 1, 2)  # NHWC -> NCHW
            pixels = F.interpolate(pixels, size=(canvas_h, canvas_w // 2))

            # Synchronize around the render so the wall-clock timing covers
            # the GPU work of this render only.
            torch.cuda.synchronize()
            tic = time.time()
            colors, _, _ = self.rasterize_splats(
                camtoworlds=camtoworlds,
                Ks=Ks,
                width=width,
                height=height,
                sh_degree=cfg.sh_degree,
                near_plane=cfg.near_plane,
                far_plane=cfg.far_plane,
                masks=masks,
            )  # [1, H, W, 3]
            torch.cuda.synchronize()
            ellipse_time += max(time.time() - tic, 1e-10)
            colors = colors.permute(0, 3, 1, 2)  # NHWC -> NCHW
            colors = F.interpolate(colors, size=(canvas_h, canvas_w // 2))
            colors = torch.clamp(colors, 0.0, 1.0)
            canvas_list = [pixels, colors]

            if world_rank == 0:
                canvas = torch.cat(canvas_list, dim=2).squeeze(0)
                canvas = canvas.permute(1, 2, 0)  # CHW -> HWC
                canvas = (canvas * 255).to(torch.uint8).cpu().numpy()
                # cv2 expects BGR channel order, hence the channel flip.
                cv2.imwrite(
                    f"{self.render_dir}/{stage}_step{step}_{i:04d}.png",
                    canvas[..., ::-1],
                )
                metrics["psnr"].append(self.psnr(colors, pixels))
                metrics["ssim"].append(self.ssim(colors, pixels))
                metrics["lpips"].append(self.lpips(colors, pixels))

        if world_rank == 0:
            ellipse_time /= len(valloader)

            stats = {
                k: torch.stack(v).mean().item() for k, v in metrics.items()
            }
            stats.update(
                {
                    "ellipse_time": ellipse_time,
                    "num_GS": len(self.splats["means"]),
                }
            )
            print(
                f"PSNR: {stats['psnr']:.3f}, SSIM: {stats['ssim']:.4f}, LPIPS: {stats['lpips']:.3f} "
                f"Time: {stats['ellipse_time']:.3f}s/image "
                f"Number of GS: {stats['num_GS']}"
            )
            # save stats as json
            with open(
                f"{self.stats_dir}/{stage}_step{step:04d}.json", "w"
            ) as f:
                json.dump(stats, f)
            # save stats to tensorboard
            for k, v in stats.items():
                self.writer.add_scalar(f"{stage}/{k}", v, step)
            self.writer.flush()

    @torch.no_grad()
    def render_video(
        self, step: int, canvas_h: int = 512, canvas_w: int = 1024
    ) -> None:
        """Render the test trajectory to an RGB | depth side-by-side video.

        Every test camera is rendered at ``(canvas_h, canvas_w // 2)``. Depth
        maps are normalized with the global min/max over all frames, colorized
        with the JET colormap and concatenated to the right of the RGB frame.
        The result is written to ``{render_dir}/video_step{step}.mp4``.

        Args:
            step: Training step used to tag the output file.
            canvas_h: Frame height.
            canvas_w: Full frame width (RGB and depth take half each).
        """
        testloader = torch.utils.data.DataLoader(
            self.testset, batch_size=1, shuffle=False, num_workers=1
        )

        images_cache = []
        depth_global_min, depth_global_max = float("inf"), -float("inf")
        for data in testloader:
            camtoworlds = data["camtoworld"].to(self.device)
            Ks = resize_pinhole_intrinsics(
                data["K"].squeeze(),
                raw_hw=(data["image_h"].item(), data["image_w"].item()),
                new_hw=(canvas_h, canvas_w // 2),
            ).to(self.device)
            renders, _, _ = self.rasterize_splats(
                camtoworlds=camtoworlds,
                Ks=Ks[None, ...],
                width=canvas_w // 2,
                height=canvas_h,
                sh_degree=self.cfg.sh_degree,
                near_plane=self.cfg.near_plane,
                far_plane=self.cfg.far_plane,
                render_mode="RGB+ED",
            )  # [1, H, W, 4]
            colors = torch.clamp(renders[0, ..., 0:3], 0.0, 1.0)  # [H, W, 3]
            colors = (colors * 255).to(torch.uint8).cpu().numpy()
            # Copy the depth to CPU: caching a device-side view would keep the
            # whole render of every frame alive on the GPU until the video is
            # written.
            depths = renders[0, ..., 3:4].cpu()  # [H, W, 1]
            images_cache.append([colors, depths])
            depth_global_min = min(depth_global_min, depths.min().item())
            depth_global_max = max(depth_global_max, depths.max().item())

        # Guard against a degenerate range (constant depth, NaN, or no frames)
        # which would otherwise divide by zero and yield undefined uint8 data.
        depth_range = depth_global_max - depth_global_min
        if not depth_range > 0.0:
            depth_range = 1.0

        video_path = f"{self.render_dir}/video_step{step}.mp4"
        writer = imageio.get_writer(video_path, fps=30)
        try:
            for rgb, depth in images_cache:
                depth_normalized = torch.clip(
                    (depth - depth_global_min) / depth_range,
                    0,
                    1,
                )
                depth_normalized = (
                    (depth_normalized * 255).to(torch.uint8).cpu().numpy()
                )
                depth_map = cv2.applyColorMap(
                    depth_normalized, cv2.COLORMAP_JET
                )
                image = np.concatenate([rgb, depth_map], axis=1)
                writer.append_data(image)
        finally:
            # Always release the writer, even if a frame fails to encode.
            writer.close()


def entrypoint(
    local_rank: int, world_rank: int, world_size: int, cfg: GsplatTrainConfig
) -> None:
    """Per-process entry: evaluate given checkpoints, or train from scratch.

    When ``cfg.ckpt`` is set, the listed checkpoint files are loaded, their
    splat tensors are concatenated per key into the runner's parameters, and
    only evaluation plus video rendering are run. Otherwise the model is
    trained and a final video is rendered.

    Args:
        local_rank: Local device index for this process.
        world_rank: Rank of this process.
        world_size: Number of processes.
        cfg: Training configuration.
    """
    runner = Runner(local_rank, world_rank, world_size, cfg)

    if cfg.ckpt is not None:
        # run eval only
        ckpts = [
            torch.load(file, map_location=runner.device, weights_only=True)
            for file in cfg.ckpt
        ]
        for k in runner.splats.keys():
            runner.splats[k].data = torch.cat(
                [ckpt["splats"][k] for ckpt in ckpts]
            )
        step = ckpts[0]["step"]
        runner.eval(step=step)
        runner.render_video(step=step)
    else:
        runner.train()
        runner.render_video(step=cfg.max_steps - 1)


if __name__ == "__main__":
    # Named presets selectable from the command line; each entry is
    # (description, default config).
    configs = {
        "default": (
            "Gaussian splatting training using densification heuristics from the original paper.",
            GsplatTrainConfig(
                strategy=DefaultStrategy(verbose=True),
            ),
        ),
        "mcmc": (
            "Gaussian splatting training using densification from the paper '3D Gaussian Splatting as Markov Chain Monte Carlo'.",
            GsplatTrainConfig(
                init_scale=0.1,
                opacity_reg=0.01,
                scale_reg=0.01,
                strategy=MCMCStrategy(verbose=True),
            ),
        ),
    }
    cfg = tyro.extras.overridable_config_cli(configs)
    cfg.adjust_steps(cfg.steps_scaler)

    cli(entrypoint, cfg, verbose=True)