# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
# Part of the code is adapted from https://github.com/nerfstudio-project/gsplat/blob/main/examples/simple_trainer.py,
# also under the Apache License, Version 2.0.

import json
import os
import time
from collections import defaultdict
from typing import Dict, Optional, Tuple

import cv2
import imageio
import numpy as np
import torch
import torch.nn.functional as F
import tqdm
import tyro
import yaml
from fused_ssim import fused_ssim
from gsplat.distributed import cli
from gsplat.rendering import rasterization
from gsplat.strategy import DefaultStrategy, MCMCStrategy
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter
from torchmetrics.image import (
PeakSignalNoiseRatio,
StructuralSimilarityIndexMeasure,
)
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from typing_extensions import Literal, assert_never

from embodied_gen.data.datasets import PanoGSplatDataset
from embodied_gen.utils.config import GsplatTrainConfig
from embodied_gen.utils.gaussian import (
create_splats_with_optimizers,
export_splats,
resize_pinhole_intrinsics,
set_random_seed,
)


class Runner:
"""Engine for training and testing from gsplat example.
Code from https://github.com/nerfstudio-project/gsplat/blob/main/examples/simple_trainer.py
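
    Minimal single-process usage sketch (paths are placeholders):

        cfg = GsplatTrainConfig(
            data_dir="<data_dir>",
            result_dir="<result_dir>",
            strategy=DefaultStrategy(verbose=True),
        )
        runner = Runner(local_rank=0, world_rank=0, world_size=1, cfg=cfg)
        runner.train()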
"""
def __init__(
self,
local_rank: int,
        world_rank: int,
world_size: int,
cfg: GsplatTrainConfig,
) -> None:
set_random_seed(42 + local_rank)
self.cfg = cfg
self.world_rank = world_rank
self.local_rank = local_rank
self.world_size = world_size
self.device = f"cuda:{local_rank}"
# Where to dump results.
os.makedirs(cfg.result_dir, exist_ok=True)
# Setup output directories.
self.ckpt_dir = f"{cfg.result_dir}/ckpts"
os.makedirs(self.ckpt_dir, exist_ok=True)
self.stats_dir = f"{cfg.result_dir}/stats"
os.makedirs(self.stats_dir, exist_ok=True)
self.render_dir = f"{cfg.result_dir}/renders"
os.makedirs(self.render_dir, exist_ok=True)
self.ply_dir = f"{cfg.result_dir}/ply"
os.makedirs(self.ply_dir, exist_ok=True)
# Tensorboard
self.writer = SummaryWriter(log_dir=f"{cfg.result_dir}/tb")
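        # Datasets: the train split drives optimization; a small subsample of
        # the same split serves as quick validation, and the eval split is
        # rendered for the test video.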
self.trainset = PanoGSplatDataset(cfg.data_dir, split="train")
self.valset = PanoGSplatDataset(
cfg.data_dir, split="train", max_sample_num=6
)
self.testset = PanoGSplatDataset(cfg.data_dir, split="eval")
self.scene_scale = cfg.scene_scale
# Model
self.splats, self.optimizers = create_splats_with_optimizers(
self.trainset.points,
self.trainset.points_rgb,
init_num_pts=cfg.init_num_pts,
init_extent=cfg.init_extent,
init_opacity=cfg.init_opa,
init_scale=cfg.init_scale,
means_lr=cfg.means_lr,
scales_lr=cfg.scales_lr,
opacities_lr=cfg.opacities_lr,
quats_lr=cfg.quats_lr,
sh0_lr=cfg.sh0_lr,
shN_lr=cfg.shN_lr,
scene_scale=self.scene_scale,
sh_degree=cfg.sh_degree,
sparse_grad=cfg.sparse_grad,
visible_adam=cfg.visible_adam,
batch_size=cfg.batch_size,
feature_dim=None,
device=self.device,
world_rank=world_rank,
world_size=world_size,
)
print("Model initialized. Number of GS:", len(self.splats["means"]))
# Densification Strategy
self.cfg.strategy.check_sanity(self.splats, self.optimizers)
if isinstance(self.cfg.strategy, DefaultStrategy):
self.strategy_state = self.cfg.strategy.initialize_state(
scene_scale=self.scene_scale
)
elif isinstance(self.cfg.strategy, MCMCStrategy):
self.strategy_state = self.cfg.strategy.initialize_state()
else:
assert_never(self.cfg.strategy)
# Losses & Metrics.
self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(
self.device
)
self.psnr = PeakSignalNoiseRatio(data_range=1.0).to(self.device)
if cfg.lpips_net == "alex":
self.lpips = LearnedPerceptualImagePatchSimilarity(
net_type="alex", normalize=True
).to(self.device)
elif cfg.lpips_net == "vgg":
            # The 3DGS official repo uses LPIPS-VGG, which is equivalent to the following:
self.lpips = LearnedPerceptualImagePatchSimilarity(
net_type="vgg", normalize=False
).to(self.device)
else:
raise ValueError(f"Unknown LPIPS network: {cfg.lpips_net}")

    def rasterize_splats(
self,
camtoworlds: Tensor,
Ks: Tensor,
width: int,
height: int,
masks: Optional[Tensor] = None,
rasterize_mode: Optional[Literal["classic", "antialiased"]] = None,
camera_model: Optional[Literal["pinhole", "ortho", "fisheye"]] = None,
**kwargs,
) -> Tuple[Tensor, Tensor, Dict]:
means = self.splats["means"] # [N, 3]
# quats = F.normalize(self.splats["quats"], dim=-1) # [N, 4]
# rasterization does normalization internally
quats = self.splats["quats"] # [N, 4]
scales = torch.exp(self.splats["scales"]) # [N, 3]
opacities = torch.sigmoid(self.splats["opacities"]) # [N,]
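        # image_ids is accepted for API compatibility but unused here (no
        # per-image appearance embeddings); pop it so it is not forwarded
        # to rasterization().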
        kwargs.pop("image_ids", None)
colors = torch.cat(
[self.splats["sh0"], self.splats["shN"]], 1
) # [N, K, 3]
if rasterize_mode is None:
rasterize_mode = (
"antialiased" if self.cfg.antialiased else "classic"
)
if camera_model is None:
camera_model = self.cfg.camera_model
render_colors, render_alphas, info = rasterization(
means=means,
quats=quats,
scales=scales,
opacities=opacities,
colors=colors,
viewmats=torch.linalg.inv(camtoworlds), # [C, 4, 4]
Ks=Ks, # [C, 3, 3]
width=width,
height=height,
packed=self.cfg.packed,
absgrad=(
self.cfg.strategy.absgrad
if isinstance(self.cfg.strategy, DefaultStrategy)
else False
),
sparse_grad=self.cfg.sparse_grad,
rasterize_mode=rasterize_mode,
distributed=self.world_size > 1,
            camera_model=camera_model,
with_ut=self.cfg.with_ut,
with_eval3d=self.cfg.with_eval3d,
**kwargs,
)
if masks is not None:
render_colors[~masks] = 0
return render_colors, render_alphas, info

    def train(self):
cfg = self.cfg
device = self.device
world_rank = self.world_rank
# Dump cfg.
if world_rank == 0:
with open(f"{cfg.result_dir}/cfg.yml", "w") as f:
yaml.dump(vars(cfg), f)
max_steps = cfg.max_steps
init_step = 0
schedulers = [
            # means has a learning rate schedule that ends at 1% of the initial value
torch.optim.lr_scheduler.ExponentialLR(
self.optimizers["means"], gamma=0.01 ** (1.0 / max_steps)
),
]
trainloader = torch.utils.data.DataLoader(
self.trainset,
batch_size=cfg.batch_size,
shuffle=True,
num_workers=4,
persistent_workers=True,
pin_memory=True,
)
trainloader_iter = iter(trainloader)
# Training loop.
global_tic = time.time()
pbar = tqdm.tqdm(range(init_step, max_steps))
for step in pbar:
try:
data = next(trainloader_iter)
except StopIteration:
trainloader_iter = iter(trainloader)
data = next(trainloader_iter)
camtoworlds = data["camtoworld"].to(device) # [1, 4, 4]
Ks = data["K"].to(device) # [1, 3, 3]
pixels = data["image"].to(device) / 255.0 # [1, H, W, 3]
image_ids = data["image_id"].to(device)
masks = (
data["mask"].to(device) if "mask" in data else None
) # [1, H, W]
if cfg.depth_loss:
points = data["points"].to(device) # [1, M, 2]
depths_gt = data["depths"].to(device) # [1, M]
height, width = pixels.shape[1:3]
# sh schedule
sh_degree_to_use = min(
step // cfg.sh_degree_interval, cfg.sh_degree
)
# forward
renders, alphas, info = self.rasterize_splats(
camtoworlds=camtoworlds,
Ks=Ks,
width=width,
height=height,
sh_degree=sh_degree_to_use,
near_plane=cfg.near_plane,
far_plane=cfg.far_plane,
image_ids=image_ids,
render_mode="RGB+ED" if cfg.depth_loss else "RGB",
masks=masks,
)
if renders.shape[-1] == 4:
colors, depths = renders[..., 0:3], renders[..., 3:4]
else:
colors, depths = renders, None
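            # Composite a random background over semi-transparent pixels so the
            # model cannot bake in a fixed background color.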
if cfg.random_bkgd:
bkgd = torch.rand(1, 3, device=device)
colors = colors + bkgd * (1.0 - alphas)
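            # Strategy pre-backward hook (e.g. DefaultStrategy retains grads on
            # the projected 2D means for densification statistics).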
self.cfg.strategy.step_pre_backward(
params=self.splats,
optimizers=self.optimizers,
state=self.strategy_state,
step=step,
info=info,
)
# loss
l1loss = F.l1_loss(colors, pixels)
ssimloss = 1.0 - fused_ssim(
colors.permute(0, 3, 1, 2),
pixels.permute(0, 3, 1, 2),
padding="valid",
)
loss = (
l1loss * (1.0 - cfg.ssim_lambda) + ssimloss * cfg.ssim_lambda
)
if cfg.depth_loss:
# query depths from depth map
points = torch.stack(
[
points[:, :, 0] / (width - 1) * 2 - 1,
points[:, :, 1] / (height - 1) * 2 - 1,
],
dim=-1,
) # normalize to [-1, 1]
grid = points.unsqueeze(2) # [1, M, 1, 2]
depths = F.grid_sample(
depths.permute(0, 3, 1, 2), grid, align_corners=True
) # [1, 1, M, 1]
depths = depths.squeeze(3).squeeze(1) # [1, M]
# calculate loss in disparity space
disp = torch.where(
depths > 0.0, 1.0 / depths, torch.zeros_like(depths)
)
disp_gt = 1.0 / depths_gt # [1, M]
depthloss = F.l1_loss(disp, disp_gt) * self.scene_scale
loss += depthloss * cfg.depth_lambda
# regularizations
if cfg.opacity_reg > 0.0:
loss += (
cfg.opacity_reg
* torch.sigmoid(self.splats["opacities"]).mean()
)
if cfg.scale_reg > 0.0:
loss += cfg.scale_reg * torch.exp(self.splats["scales"]).mean()
loss.backward()
            desc = f"loss={loss.item():.3f}| sh degree={sh_degree_to_use}| "
if cfg.depth_loss:
desc += f"depth loss={depthloss.item():.6f}| "
pbar.set_description(desc)
# write images (gt and render)
# if world_rank == 0 and step % 800 == 0:
# canvas = torch.cat([pixels, colors], dim=2).detach().cpu().numpy()
# canvas = canvas.reshape(-1, *canvas.shape[2:])
# imageio.imwrite(
# f"{self.render_dir}/train_rank{self.world_rank}.png",
# (canvas * 255).astype(np.uint8),
# )
if (
world_rank == 0
and cfg.tb_every > 0
and step % cfg.tb_every == 0
):
mem = torch.cuda.max_memory_allocated() / 1024**3
self.writer.add_scalar("train/loss", loss.item(), step)
self.writer.add_scalar("train/l1loss", l1loss.item(), step)
self.writer.add_scalar("train/ssimloss", ssimloss.item(), step)
self.writer.add_scalar(
"train/num_GS", len(self.splats["means"]), step
)
self.writer.add_scalar("train/mem", mem, step)
if cfg.depth_loss:
self.writer.add_scalar(
"train/depthloss", depthloss.item(), step
)
if cfg.tb_save_image:
canvas = (
torch.cat([pixels, colors], dim=2)
.detach()
.cpu()
.numpy()
)
canvas = canvas.reshape(-1, *canvas.shape[2:])
self.writer.add_image("train/render", canvas, step)
self.writer.flush()
# save checkpoint before updating the model
if (
step in [i - 1 for i in cfg.save_steps]
or step == max_steps - 1
):
mem = torch.cuda.max_memory_allocated() / 1024**3
stats = {
"mem": mem,
"ellipse_time": time.time() - global_tic,
"num_GS": len(self.splats["means"]),
}
print("Step: ", step, stats)
with open(
f"{self.stats_dir}/train_step{step:04d}_rank{self.world_rank}.json",
"w",
) as f:
json.dump(stats, f)
data = {"step": step, "splats": self.splats.state_dict()}
torch.save(
data,
f"{self.ckpt_dir}/ckpt_{step}_rank{self.world_rank}.pt",
)
if (
step in [i - 1 for i in cfg.ply_steps] or step == max_steps - 1
) and cfg.save_ply:
sh0 = self.splats["sh0"]
shN = self.splats["shN"]
means = self.splats["means"]
scales = self.splats["scales"]
quats = self.splats["quats"]
opacities = self.splats["opacities"]
export_splats(
means=means,
scales=scales,
quats=quats,
opacities=opacities,
sh0=sh0,
shN=shN,
format="ply",
save_to=f"{self.ply_dir}/point_cloud_{step}.ply",
)
# Turn Gradients into Sparse Tensor before running optimizer
if cfg.sparse_grad:
assert (
cfg.packed
), "Sparse gradients only work with packed mode."
gaussian_ids = info["gaussian_ids"]
for k in self.splats.keys():
grad = self.splats[k].grad
if grad is None or grad.is_sparse:
continue
self.splats[k].grad = torch.sparse_coo_tensor(
indices=gaussian_ids[None], # [1, nnz]
values=grad[gaussian_ids], # [nnz, ...]
size=self.splats[k].size(), # [N, ...]
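                        # With a single camera the gaussian_ids are unique and
                        # sorted, so the COO tensor is already coalesced.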
is_coalesced=len(Ks) == 1,
)
            if cfg.visible_adam:
if cfg.packed:
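                    # Packed mode only reports intersected Gaussians; mark
                    # them visible via their ids.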
visibility_mask = torch.zeros_like(
self.splats["opacities"], dtype=bool
)
visibility_mask.scatter_(0, info["gaussian_ids"], 1)
else:
visibility_mask = (info["radii"] > 0).all(-1).any(0)
# optimize
for optimizer in self.optimizers.values():
if cfg.visible_adam:
optimizer.step(visibility_mask)
else:
optimizer.step()
optimizer.zero_grad(set_to_none=True)
for scheduler in schedulers:
scheduler.step()
# Run post-backward steps after backward and optimizer
if isinstance(self.cfg.strategy, DefaultStrategy):
self.cfg.strategy.step_post_backward(
params=self.splats,
optimizers=self.optimizers,
state=self.strategy_state,
step=step,
info=info,
packed=cfg.packed,
)
elif isinstance(self.cfg.strategy, MCMCStrategy):
self.cfg.strategy.step_post_backward(
params=self.splats,
optimizers=self.optimizers,
state=self.strategy_state,
step=step,
info=info,
lr=schedulers[0].get_last_lr()[0],
)
else:
assert_never(self.cfg.strategy)
# eval the full set
if step in [i - 1 for i in cfg.eval_steps]:
self.eval(step)
self.render_video(step)

    @torch.no_grad()
def eval(
self,
step: int,
stage: str = "val",
canvas_h: int = 512,
canvas_w: int = 1024,
):
"""Entry for evaluation."""
print("Running evaluation...")
cfg = self.cfg
device = self.device
world_rank = self.world_rank
valloader = torch.utils.data.DataLoader(
self.valset, batch_size=1, shuffle=False, num_workers=1
)
ellipse_time = 0
metrics = defaultdict(list)
for i, data in enumerate(valloader):
camtoworlds = data["camtoworld"].to(device)
Ks = data["K"].to(device)
pixels = data["image"].to(device) / 255.0
height, width = pixels.shape[1:3]
masks = data["mask"].to(device) if "mask" in data else None
pixels = pixels.permute(0, 3, 1, 2) # NHWC -> NCHW
pixels = F.interpolate(pixels, size=(canvas_h, canvas_w // 2))
torch.cuda.synchronize()
tic = time.time()
colors, _, _ = self.rasterize_splats(
camtoworlds=camtoworlds,
Ks=Ks,
width=width,
height=height,
sh_degree=cfg.sh_degree,
near_plane=cfg.near_plane,
far_plane=cfg.far_plane,
masks=masks,
) # [1, H, W, 3]
torch.cuda.synchronize()
ellipse_time += max(time.time() - tic, 1e-10)
colors = colors.permute(0, 3, 1, 2) # NHWC -> NCHW
colors = F.interpolate(colors, size=(canvas_h, canvas_w // 2))
colors = torch.clamp(colors, 0.0, 1.0)
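            # Stack GT above the render along the height axis to form the
            # comparison canvas.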
canvas_list = [pixels, colors]
if world_rank == 0:
canvas = torch.cat(canvas_list, dim=2).squeeze(0)
canvas = canvas.permute(1, 2, 0) # CHW -> HWC
canvas = (canvas * 255).to(torch.uint8).cpu().numpy()
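                # cv2.imwrite expects BGR, so reverse the channel order.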
cv2.imwrite(
f"{self.render_dir}/{stage}_step{step}_{i:04d}.png",
canvas[..., ::-1],
)
metrics["psnr"].append(self.psnr(colors, pixels))
metrics["ssim"].append(self.ssim(colors, pixels))
metrics["lpips"].append(self.lpips(colors, pixels))
if world_rank == 0:
ellipse_time /= len(valloader)
stats = {
k: torch.stack(v).mean().item() for k, v in metrics.items()
}
stats.update(
{
"ellipse_time": ellipse_time,
"num_GS": len(self.splats["means"]),
}
)
print(
f"PSNR: {stats['psnr']:.3f}, SSIM: {stats['ssim']:.4f}, LPIPS: {stats['lpips']:.3f} "
f"Time: {stats['ellipse_time']:.3f}s/image "
f"Number of GS: {stats['num_GS']}"
)
# save stats as json
with open(
f"{self.stats_dir}/{stage}_step{step:04d}.json", "w"
) as f:
json.dump(stats, f)
# save stats to tensorboard
for k, v in stats.items():
self.writer.add_scalar(f"{stage}/{k}", v, step)
self.writer.flush()

    @torch.no_grad()
def render_video(
self, step: int, canvas_h: int = 512, canvas_w: int = 1024
):
testloader = torch.utils.data.DataLoader(
self.testset, batch_size=1, shuffle=False, num_workers=1
)
images_cache = []
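        # Cache frames and track the global depth range so every frame shares
        # the same depth-colormap normalization.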
depth_global_min, depth_global_max = float("inf"), -float("inf")
for data in testloader:
camtoworlds = data["camtoworld"].to(self.device)
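            # Rescale the pinhole intrinsics from the dataset resolution to the
            # video canvas resolution.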
Ks = resize_pinhole_intrinsics(
data["K"].squeeze(),
raw_hw=(data["image_h"].item(), data["image_w"].item()),
new_hw=(canvas_h, canvas_w // 2),
).to(self.device)
renders, _, _ = self.rasterize_splats(
camtoworlds=camtoworlds,
Ks=Ks[None, ...],
width=canvas_w // 2,
height=canvas_h,
sh_degree=self.cfg.sh_degree,
near_plane=self.cfg.near_plane,
far_plane=self.cfg.far_plane,
render_mode="RGB+ED",
) # [1, H, W, 4]
colors = torch.clamp(renders[0, ..., 0:3], 0.0, 1.0) # [H, W, 3]
colors = (colors * 255).to(torch.uint8).cpu().numpy()
depths = renders[0, ..., 3:4] # [H, W, 1], tensor in device.
images_cache.append([colors, depths])
depth_global_min = min(depth_global_min, depths.min().item())
depth_global_max = max(depth_global_max, depths.max().item())
video_path = f"{self.render_dir}/video_step{step}.mp4"
writer = imageio.get_writer(video_path, fps=30)
for rgb, depth in images_cache:
depth_normalized = torch.clip(
(depth - depth_global_min)
/ (depth_global_max - depth_global_min),
0,
1,
)
depth_normalized = (
(depth_normalized * 255).to(torch.uint8).cpu().numpy()
)
            depth_map = cv2.applyColorMap(depth_normalized, cv2.COLORMAP_JET)
            depth_map = depth_map[..., ::-1]  # applyColorMap returns BGR; imageio expects RGB.
image = np.concatenate([rgb, depth_map], axis=1)
writer.append_data(image)
writer.close()


def entrypoint(
    local_rank: int, world_rank: int, world_size: int, cfg: GsplatTrainConfig
):
runner = Runner(local_rank, world_rank, world_size, cfg)
if cfg.ckpt is not None:
# run eval only
ckpts = [
torch.load(file, map_location=runner.device, weights_only=True)
for file in cfg.ckpt
]
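        # Merge per-rank checkpoint shards back into a single set of splats.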
for k in runner.splats.keys():
runner.splats[k].data = torch.cat(
[ckpt["splats"][k] for ckpt in ckpts]
)
step = ckpts[0]["step"]
runner.eval(step=step)
runner.render_video(step=step)
else:
runner.train()
runner.render_video(step=cfg.max_steps - 1)
if __name__ == "__main__":
configs = {
"default": (
"Gaussian splatting training using densification heuristics from the original paper.",
GsplatTrainConfig(
strategy=DefaultStrategy(verbose=True),
),
),
"mcmc": (
"Gaussian splatting training using densification from the paper '3D Gaussian Splatting as Markov Chain Monte Carlo'.",
GsplatTrainConfig(
init_scale=0.1,
opacity_reg=0.01,
scale_reg=0.01,
strategy=MCMCStrategy(verbose=True),
),
),
}
cfg = tyro.extras.overridable_config_cli(configs)
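    # Rescale step-based schedules (eval/save/ply steps etc.) by steps_scaler.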
cfg.adjust_steps(cfg.steps_scaler)
cli(entrypoint, cfg, verbose=True)