# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
# Part of the code is adapted from https://github.com/nerfstudio-project/gsplat,
# which is also released under the Apache License, Version 2.0.
import math
import random
from io import BytesIO
from typing import Dict, Literal, Optional, Tuple

import numpy as np
import torch
import trimesh
from gsplat.optimizers import SelectiveAdam
from scipy.spatial.transform import Rotation
from sklearn.neighbors import NearestNeighbors
from torch import Tensor

from embodied_gen.models.gs_model import GaussianOperator

__all__ = [
"set_random_seed",
"export_splats",
"create_splats_with_optimizers",
"compute_pinhole_intrinsics",
"resize_pinhole_intrinsics",
"restore_scene_scale_and_position",
]


def knn(x: Tensor, K: int = 4) -> Tensor:
    """Return distances from each point in `x` to its K nearest neighbors (self included), shape (N, K)."""
x_np = x.cpu().numpy()
model = NearestNeighbors(n_neighbors=K, metric="euclidean").fit(x_np)
distances, _ = model.kneighbors(x_np)
return torch.from_numpy(distances).to(x)
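# Example usage of `knn` (a minimal sketch; the random points are illustrative):
#   pts = torch.rand(100, 3)
#   dists = knn(pts, K=4)  # shape (100, 4); column 0 is each point's zero distance to itself
#   nn_dists = dists[:, 1:]  # distances to the 3 true nearest neighbors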


def rgb_to_sh(rgb: Tensor) -> Tensor:
    """Convert RGB values in [0, 1] to 0th-order (DC) spherical harmonic coefficients."""
    C0 = 0.28209479177387814  # Y_0^0 = sqrt(1 / (4 * pi))
return (rgb - 0.5) / C0
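# Example for `rgb_to_sh` (a minimal sketch): with the degree-0 SH convention
# color = C0 * sh + 0.5, converting and then evaluating reproduces the input:
#   sh = rgb_to_sh(torch.tensor([[1.0, 0.5, 0.0]]))
#   rgb = sh * 0.28209479177387814 + 0.5  # back to [[1.0, 0.5, 0.0]]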


def set_random_seed(seed: int) -> None:
    """Seed Python, NumPy, and PyTorch RNGs for reproducible runs."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)


def splat2ply_bytes(
means: torch.Tensor,
scales: torch.Tensor,
quats: torch.Tensor,
opacities: torch.Tensor,
sh0: torch.Tensor,
shN: torch.Tensor,
) -> bytes:
    """Serialize Gaussian splat parameters into binary little-endian PLY bytes."""
    num_splats = means.shape[0]
buffer = BytesIO()
# Write PLY header
buffer.write(b"ply\n")
buffer.write(b"format binary_little_endian 1.0\n")
buffer.write(f"element vertex {num_splats}\n".encode())
buffer.write(b"property float x\n")
buffer.write(b"property float y\n")
buffer.write(b"property float z\n")
for i, data in enumerate([sh0, shN]):
prefix = "f_dc" if i == 0 else "f_rest"
for j in range(data.shape[1]):
buffer.write(f"property float {prefix}_{j}\n".encode())
buffer.write(b"property float opacity\n")
for i in range(scales.shape[1]):
buffer.write(f"property float scale_{i}\n".encode())
for i in range(quats.shape[1]):
buffer.write(f"property float rot_{i}\n".encode())
buffer.write(b"end_header\n")
# Concatenate all tensors in the correct order
splat_data = torch.cat(
[means, sh0, shN, opacities.unsqueeze(1), scales, quats], dim=1
)
# Ensure correct dtype
splat_data = splat_data.to(torch.float32)
# Write binary data
float_dtype = np.dtype(np.float32).newbyteorder("<")
buffer.write(
splat_data.detach().cpu().numpy().astype(float_dtype).tobytes()
)
return buffer.getvalue()


def export_splats(
means: torch.Tensor,
scales: torch.Tensor,
quats: torch.Tensor,
opacities: torch.Tensor,
sh0: torch.Tensor,
shN: torch.Tensor,
format: Literal["ply"] = "ply",
save_to: Optional[str] = None,
) -> bytes:
"""Export a Gaussian Splats model to bytes in PLY file format."""
total_splats = means.shape[0]
assert means.shape == (total_splats, 3), "Means must be of shape (N, 3)"
assert scales.shape == (total_splats, 3), "Scales must be of shape (N, 3)"
assert quats.shape == (
total_splats,
4,
), "Quaternions must be of shape (N, 4)"
assert opacities.shape == (
total_splats,
), "Opacities must be of shape (N,)"
assert sh0.shape == (total_splats, 1, 3), "sh0 must be of shape (N, 1, 3)"
assert (
shN.ndim == 3 and shN.shape[0] == total_splats and shN.shape[2] == 3
), f"shN must be of shape (N, K, 3), got {shN.shape}"
# Reshape spherical harmonics
sh0 = sh0.squeeze(1) # Shape (N, 3)
shN = shN.permute(0, 2, 1).reshape(means.shape[0], -1) # Shape (N, K * 3)
# Check for NaN or Inf values
invalid_mask = (
torch.isnan(means).any(dim=1)
| torch.isinf(means).any(dim=1)
| torch.isnan(scales).any(dim=1)
| torch.isinf(scales).any(dim=1)
| torch.isnan(quats).any(dim=1)
| torch.isinf(quats).any(dim=1)
        | torch.isnan(opacities)
        | torch.isinf(opacities)
| torch.isnan(sh0).any(dim=1)
| torch.isinf(sh0).any(dim=1)
| torch.isnan(shN).any(dim=1)
| torch.isinf(shN).any(dim=1)
)
# Filter out invalid entries
valid_mask = ~invalid_mask
means = means[valid_mask]
scales = scales[valid_mask]
quats = quats[valid_mask]
opacities = opacities[valid_mask]
sh0 = sh0[valid_mask]
shN = shN[valid_mask]
if format == "ply":
data = splat2ply_bytes(means, scales, quats, opacities, sh0, shN)
else:
raise ValueError(f"Unsupported format: {format}")
if save_to:
with open(save_to, "wb") as binary_file:
binary_file.write(data)
return data
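# Example usage of `export_splats` (a minimal sketch with random, illustrative
# inputs; the output path is hypothetical):
#   n = 1024
#   ply_bytes = export_splats(
#       means=torch.rand(n, 3),
#       scales=torch.rand(n, 3),
#       quats=torch.rand(n, 4),
#       opacities=torch.rand(n),
#       sh0=torch.rand(n, 1, 3),
#       shN=torch.rand(n, 15, 3),  # 15 = (3 + 1) ** 2 - 1 higher-order coeffs for SH degree 3
#       save_to="/tmp/splats.ply",
#   )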


def create_splats_with_optimizers(
    points: Optional[np.ndarray] = None,
    points_rgb: Optional[np.ndarray] = None,
init_num_pts: int = 100_000,
init_extent: float = 3.0,
init_opacity: float = 0.1,
init_scale: float = 1.0,
means_lr: float = 1.6e-4,
scales_lr: float = 5e-3,
opacities_lr: float = 5e-2,
quats_lr: float = 1e-3,
sh0_lr: float = 2.5e-3,
shN_lr: float = 2.5e-3 / 20,
scene_scale: float = 1.0,
sh_degree: int = 3,
sparse_grad: bool = False,
visible_adam: bool = False,
batch_size: int = 1,
feature_dim: Optional[int] = None,
device: str = "cuda",
world_rank: int = 0,
world_size: int = 1,
) -> Tuple[torch.nn.ParameterDict, Dict[str, torch.optim.Optimizer]]:
    """Initialize Gaussian splat parameters and one optimizer per parameter group."""
    if points is not None and points_rgb is not None:
points = torch.from_numpy(points).float()
rgbs = torch.from_numpy(points_rgb / 255.0).float()
else:
points = (
init_extent * scene_scale * (torch.rand((init_num_pts, 3)) * 2 - 1)
)
rgbs = torch.rand((init_num_pts, 3))
    # Initialize each GS scale from the RMS distance to its 3 nearest neighbors
dist2_avg = (knn(points, 4)[:, 1:] ** 2).mean(dim=-1) # [N,]
dist_avg = torch.sqrt(dist2_avg)
scales = (
torch.log(dist_avg * init_scale).unsqueeze(-1).repeat(1, 3)
) # [N, 3]
# Distribute the GSs to different ranks (also works for single rank)
points = points[world_rank::world_size]
rgbs = rgbs[world_rank::world_size]
scales = scales[world_rank::world_size]
N = points.shape[0]
quats = torch.rand((N, 4)) # [N, 4]
opacities = torch.logit(torch.full((N,), init_opacity)) # [N,]
params = [
# name, value, lr
("means", torch.nn.Parameter(points), means_lr * scene_scale),
("scales", torch.nn.Parameter(scales), scales_lr),
("quats", torch.nn.Parameter(quats), quats_lr),
("opacities", torch.nn.Parameter(opacities), opacities_lr),
]
if feature_dim is None:
# color is SH coefficients.
colors = torch.zeros((N, (sh_degree + 1) ** 2, 3)) # [N, K, 3]
colors[:, 0, :] = rgb_to_sh(rgbs)
params.append(("sh0", torch.nn.Parameter(colors[:, :1, :]), sh0_lr))
params.append(("shN", torch.nn.Parameter(colors[:, 1:, :]), shN_lr))
else:
# features will be used for appearance and view-dependent shading
features = torch.rand(N, feature_dim) # [N, feature_dim]
params.append(("features", torch.nn.Parameter(features), sh0_lr))
colors = torch.logit(rgbs) # [N, 3]
params.append(("colors", torch.nn.Parameter(colors), sh0_lr))
splats = torch.nn.ParameterDict({n: v for n, v, _ in params}).to(device)
# Scale learning rate based on batch size, reference:
# https://www.cs.princeton.edu/~smalladi/blog/2024/01/22/SDEs-ScalingRules/
# Note that this would not make the training exactly equivalent, see
# https://arxiv.org/pdf/2402.18824v1
BS = batch_size * world_size
optimizer_class = None
if sparse_grad:
optimizer_class = torch.optim.SparseAdam
elif visible_adam:
optimizer_class = SelectiveAdam
else:
optimizer_class = torch.optim.Adam
optimizers = {
name: optimizer_class(
[{"params": splats[name], "lr": lr * math.sqrt(BS), "name": name}],
eps=1e-15 / math.sqrt(BS),
            # TODO: revisit the betas scaling; at BS = 10 betas[0] reaches zero and goes negative beyond that.
betas=(1 - BS * (1 - 0.9), 1 - BS * (1 - 0.999)),
)
for name, _, lr in params
}
return splats, optimizers
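# Example usage of `create_splats_with_optimizers` (a minimal sketch; assumes a
# CUDA device and uses a random, illustrative point cloud):
#   pts = np.random.rand(10_000, 3).astype(np.float32)
#   rgb = np.random.randint(0, 256, size=(10_000, 3)).astype(np.uint8)
#   splats, optimizers = create_splats_with_optimizers(
#       points=pts, points_rgb=rgb, scene_scale=1.0, sh_degree=3, device="cuda"
#   )
#   # `splats` holds "means", "scales", "quats", "opacities", "sh0" and "shN";
#   # `optimizers` maps each of those names to its Adam-style optimizer.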


def compute_pinhole_intrinsics(
image_w: int, image_h: int, fov_deg: float
) -> np.ndarray:
    """Build a 3x3 pinhole intrinsic matrix from image size and horizontal FOV in degrees."""
    fov_rad = np.deg2rad(fov_deg)
fx = image_w / (2 * np.tan(fov_rad / 2))
fy = fx # assuming square pixels
cx = image_w / 2
cy = image_h / 2
K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])
return K
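# Worked example for `compute_pinhole_intrinsics`: a 640x480 image with a
# 60 degree horizontal FOV gives fx = fy = 640 / (2 * tan(30 deg)) ~= 554.26,
# cx = 320, cy = 240:
#   K = compute_pinhole_intrinsics(image_w=640, image_h=480, fov_deg=60.0)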


def resize_pinhole_intrinsics(
    raw_K: np.ndarray | torch.Tensor,
    raw_hw: tuple[int, int],
    new_hw: tuple[int, int],
) -> np.ndarray | torch.Tensor:
    """Rescale pinhole intrinsics for an image resized from `raw_hw` to `new_hw`, both given as (height, width)."""
raw_h, raw_w = raw_hw
new_h, new_w = new_hw
scale_x = new_w / raw_w
scale_y = new_h / raw_h
new_K = raw_K.copy() if isinstance(raw_K, np.ndarray) else raw_K.clone()
new_K[0, 0] *= scale_x # fx
new_K[0, 2] *= scale_x # cx
new_K[1, 1] *= scale_y # fy
new_K[1, 2] *= scale_y # cy
return new_K
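# Example for `resize_pinhole_intrinsics`, continuing the 640x480 intrinsics
# above (a minimal sketch): halving both dimensions halves fx, fy, cx, and cy:
#   K = compute_pinhole_intrinsics(image_w=640, image_h=480, fov_deg=60.0)
#   K_half = resize_pinhole_intrinsics(K, raw_hw=(480, 640), new_hw=(240, 320))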


def restore_scene_scale_and_position(
real_height: float, mesh_path: str, gs_path: str
) -> None:
"""Scales a mesh and corresponding GS model to match a given real-world height.
Uses the 1st and 99th percentile of mesh Z-axis to estimate height,
applies scaling and vertical alignment, and updates both the mesh and GS model.
Args:
real_height (float): Target real-world height among Z axis.
mesh_path (str): Path to the input mesh file.
gs_path (str): Path to the Gaussian Splatting model file.
"""
mesh = trimesh.load(mesh_path)
z_min = np.percentile(mesh.vertices[:, 1], 1)
z_max = np.percentile(mesh.vertices[:, 1], 99)
height = z_max - z_min
scale = real_height / height
    rot = Rotation.from_quat([0, 1, 0, 0])  # 180 deg about the up axis, scipy (x, y, z, w) order
mesh.vertices = rot.apply(mesh.vertices)
mesh.vertices[:, 1] -= z_min
mesh.vertices *= scale
mesh.export(mesh_path)
gs_model: GaussianOperator = GaussianOperator.load_from_ply(gs_path)
gs_model = gs_model.get_gaussians(
instance_pose=torch.tensor([0.0, -z_min, 0, 0, 1, 0, 0])
)
gs_model.rescale(scale)
gs_model.save_to_ply(gs_path)
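# Example usage of `restore_scene_scale_and_position` (a sketch; the paths and
# target height below are hypothetical). Both files are overwritten in place:
#   restore_scene_scale_and_position(
#       real_height=0.75,  # target height in meters
#       mesh_path="outputs/asset/mesh.obj",
#       gs_path="outputs/asset/gs_model.ply",
#   )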