# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
#
# Part of the code comes from https://github.com/nerfstudio-project/gsplat,
# which is also licensed under the Apache License, Version 2.0.
import math
import random
from io import BytesIO
from typing import Dict, Literal, Optional, Tuple

import numpy as np
import torch
import trimesh
from gsplat.optimizers import SelectiveAdam
from scipy.spatial.transform import Rotation
from sklearn.neighbors import NearestNeighbors
from torch import Tensor

from embodied_gen.models.gs_model import GaussianOperator

__all__ = [
    "set_random_seed",
    "export_splats",
    "create_splats_with_optimizers",
    "compute_pinhole_intrinsics",
    "resize_pinhole_intrinsics",
    "restore_scene_scale_and_position",
]


def knn(x: Tensor, K: int = 4) -> Tensor:
    """Return the distances from each point to its K nearest neighbors (self included)."""
    x_np = x.cpu().numpy()
    model = NearestNeighbors(n_neighbors=K, metric="euclidean").fit(x_np)
    distances, _ = model.kneighbors(x_np)
    return torch.from_numpy(distances).to(x)
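

# Example (illustrative only): the query set equals the fit set, so column 0 of
# the result is each point's ~0 distance to itself and the "real" neighbor
# distances start at column 1.
#   >>> pts = torch.rand(1000, 3)
#   >>> d = knn(pts, K=4)           # shape (1000, 4); d[:, 0] is ~0
#   >>> avg_nn = d[:, 1:].mean(dim=-1)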


def rgb_to_sh(rgb: Tensor) -> Tensor:
    """Convert RGB values in [0, 1] to zero-order spherical-harmonic (DC) coefficients."""
    C0 = 0.28209479177387814
    return (rgb - 0.5) / C0
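

# Example (illustrative only): C0 = 0.5 * sqrt(1 / pi) is the DC band constant of
# real spherical harmonics, so mid-gray 0.5 maps to a zero coefficient and the
# mapping is inverted with sh * C0 + 0.5.
#   >>> rgb_to_sh(torch.tensor([0.5, 0.5, 0.5]))   # -> tensor([0., 0., 0.])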


def set_random_seed(seed: int):
    """Seed the Python, NumPy, and PyTorch RNGs for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def splat2ply_bytes(
    means: torch.Tensor,
    scales: torch.Tensor,
    quats: torch.Tensor,
    opacities: torch.Tensor,
    sh0: torch.Tensor,
    shN: torch.Tensor,
) -> bytes:
    """Serialize Gaussian Splat attributes into a binary little-endian PLY blob."""
    num_splats = means.shape[0]
    buffer = BytesIO()

    # Write PLY header
    buffer.write(b"ply\n")
    buffer.write(b"format binary_little_endian 1.0\n")
    buffer.write(f"element vertex {num_splats}\n".encode())
    buffer.write(b"property float x\n")
    buffer.write(b"property float y\n")
    buffer.write(b"property float z\n")
    for i, data in enumerate([sh0, shN]):
        prefix = "f_dc" if i == 0 else "f_rest"
        for j in range(data.shape[1]):
            buffer.write(f"property float {prefix}_{j}\n".encode())
    buffer.write(b"property float opacity\n")
    for i in range(scales.shape[1]):
        buffer.write(f"property float scale_{i}\n".encode())
    for i in range(quats.shape[1]):
        buffer.write(f"property float rot_{i}\n".encode())
    buffer.write(b"end_header\n")

    # Concatenate all tensors in the same order as the header properties
    splat_data = torch.cat(
        [means, sh0, shN, opacities.unsqueeze(1), scales, quats], dim=1
    )
    # Ensure correct dtype
    splat_data = splat_data.to(torch.float32)

    # Write binary data
    float_dtype = np.dtype(np.float32).newbyteorder("<")
    buffer.write(
        splat_data.detach().cpu().numpy().astype(float_dtype).tobytes()
    )

    return buffer.getvalue()


def export_splats(
    means: torch.Tensor,
    scales: torch.Tensor,
    quats: torch.Tensor,
    opacities: torch.Tensor,
    sh0: torch.Tensor,
    shN: torch.Tensor,
    format: Literal["ply"] = "ply",
    save_to: Optional[str] = None,
) -> bytes:
    """Export a Gaussian Splats model to bytes in PLY file format."""
    total_splats = means.shape[0]
    assert means.shape == (total_splats, 3), "Means must be of shape (N, 3)"
    assert scales.shape == (total_splats, 3), "Scales must be of shape (N, 3)"
    assert quats.shape == (
        total_splats,
        4,
    ), "Quaternions must be of shape (N, 4)"
    assert opacities.shape == (
        total_splats,
    ), "Opacities must be of shape (N,)"
    assert sh0.shape == (total_splats, 1, 3), "sh0 must be of shape (N, 1, 3)"
    assert (
        shN.ndim == 3 and shN.shape[0] == total_splats and shN.shape[2] == 3
    ), f"shN must be of shape (N, K, 3), got {shN.shape}"

    # Reshape spherical harmonics
    sh0 = sh0.squeeze(1)  # Shape (N, 3)
    shN = shN.permute(0, 2, 1).reshape(means.shape[0], -1)  # Shape (N, K * 3)

    # Check for NaN or Inf values.
    # `opacities` is already a per-splat (N,) tensor; reducing it with
    # `.any(dim=0)` would collapse it to a scalar and flag every splat.
    invalid_mask = (
        torch.isnan(means).any(dim=1)
        | torch.isinf(means).any(dim=1)
        | torch.isnan(scales).any(dim=1)
        | torch.isinf(scales).any(dim=1)
        | torch.isnan(quats).any(dim=1)
        | torch.isinf(quats).any(dim=1)
        | torch.isnan(opacities)
        | torch.isinf(opacities)
        | torch.isnan(sh0).any(dim=1)
        | torch.isinf(sh0).any(dim=1)
        | torch.isnan(shN).any(dim=1)
        | torch.isinf(shN).any(dim=1)
    )

    # Filter out invalid entries
    valid_mask = ~invalid_mask
    means = means[valid_mask]
    scales = scales[valid_mask]
    quats = quats[valid_mask]
    opacities = opacities[valid_mask]
    sh0 = sh0[valid_mask]
    shN = shN[valid_mask]

    if format == "ply":
        data = splat2ply_bytes(means, scales, quats, opacities, sh0, shN)
    else:
        raise ValueError(f"Unsupported format: {format}")

    if save_to:
        with open(save_to, "wb") as binary_file:
            binary_file.write(data)

    return data
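

# Example (illustrative only; tensor shapes follow the assertions above, and the
# output path is hypothetical):
#   >>> N = 1024
#   >>> export_splats(
#   ...     means=torch.randn(N, 3),
#   ...     scales=torch.rand(N, 3),
#   ...     quats=torch.randn(N, 4),
#   ...     opacities=torch.rand(N),
#   ...     sh0=torch.zeros(N, 1, 3),
#   ...     shN=torch.zeros(N, 15, 3),   # 15 rest coefficients for sh_degree=3
#   ...     save_to="splats.ply",
#   ... )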


def create_splats_with_optimizers(
    points: Optional[np.ndarray] = None,
    points_rgb: Optional[np.ndarray] = None,
    init_num_pts: int = 100_000,
    init_extent: float = 3.0,
    init_opacity: float = 0.1,
    init_scale: float = 1.0,
    means_lr: float = 1.6e-4,
    scales_lr: float = 5e-3,
    opacities_lr: float = 5e-2,
    quats_lr: float = 1e-3,
    sh0_lr: float = 2.5e-3,
    shN_lr: float = 2.5e-3 / 20,
    scene_scale: float = 1.0,
    sh_degree: int = 3,
    sparse_grad: bool = False,
    visible_adam: bool = False,
    batch_size: int = 1,
    feature_dim: Optional[int] = None,
    device: str = "cuda",
    world_rank: int = 0,
    world_size: int = 1,
) -> Tuple[torch.nn.ParameterDict, Dict[str, torch.optim.Optimizer]]:
    """Initialize Gaussian Splat parameters and one optimizer per parameter group."""
    if points is not None and points_rgb is not None:
        points = torch.from_numpy(points).float()
        rgbs = torch.from_numpy(points_rgb / 255.0).float()
    else:
        points = (
            init_extent * scene_scale * (torch.rand((init_num_pts, 3)) * 2 - 1)
        )
        rgbs = torch.rand((init_num_pts, 3))

    # Initialize the GS size to be the average dist of the 3 nearest neighbors
    dist2_avg = (knn(points, 4)[:, 1:] ** 2).mean(dim=-1)  # [N,]
    dist_avg = torch.sqrt(dist2_avg)
    scales = (
        torch.log(dist_avg * init_scale).unsqueeze(-1).repeat(1, 3)
    )  # [N, 3]

    # Distribute the GSs to different ranks (also works for single rank)
    points = points[world_rank::world_size]
    rgbs = rgbs[world_rank::world_size]
    scales = scales[world_rank::world_size]

    N = points.shape[0]
    quats = torch.rand((N, 4))  # [N, 4]
    opacities = torch.logit(torch.full((N,), init_opacity))  # [N,]

    params = [
        # name, value, lr
        ("means", torch.nn.Parameter(points), means_lr * scene_scale),
        ("scales", torch.nn.Parameter(scales), scales_lr),
        ("quats", torch.nn.Parameter(quats), quats_lr),
        ("opacities", torch.nn.Parameter(opacities), opacities_lr),
    ]

    if feature_dim is None:
        # color is SH coefficients.
        colors = torch.zeros((N, (sh_degree + 1) ** 2, 3))  # [N, K, 3]
        colors[:, 0, :] = rgb_to_sh(rgbs)
        params.append(("sh0", torch.nn.Parameter(colors[:, :1, :]), sh0_lr))
        params.append(("shN", torch.nn.Parameter(colors[:, 1:, :]), shN_lr))
    else:
        # features will be used for appearance and view-dependent shading
        features = torch.rand(N, feature_dim)  # [N, feature_dim]
        params.append(("features", torch.nn.Parameter(features), sh0_lr))
        colors = torch.logit(rgbs)  # [N, 3]
        params.append(("colors", torch.nn.Parameter(colors), sh0_lr))

    splats = torch.nn.ParameterDict({n: v for n, v, _ in params}).to(device)

    # Scale learning rate based on batch size, reference:
    # https://www.cs.princeton.edu/~smalladi/blog/2024/01/22/SDEs-ScalingRules/
    # Note that this would not make the training exactly equivalent, see
    # https://arxiv.org/pdf/2402.18824v1
    BS = batch_size * world_size

    if sparse_grad:
        optimizer_class = torch.optim.SparseAdam
    elif visible_adam:
        optimizer_class = SelectiveAdam
    else:
        optimizer_class = torch.optim.Adam
    optimizers = {
        name: optimizer_class(
            [{"params": splats[name], "lr": lr * math.sqrt(BS), "name": name}],
            eps=1e-15 / math.sqrt(BS),
            # TODO: revisit the betas scaling; once BS >= 10, betas[0] drops to <= 0.
            betas=(1 - BS * (1 - 0.9), 1 - BS * (1 - 0.999)),
        )
        for name, _, lr in params
    }

    return splats, optimizers
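

# Example (illustrative only; `points`/`points_rgb` would normally come from an
# SfM or mesh-sampling step, and device="cuda" assumes a GPU is available):
#   >>> pts = np.random.rand(10_000, 3).astype(np.float32)
#   >>> rgb = np.random.randint(0, 256, (10_000, 3)).astype(np.uint8)
#   >>> splats, optims = create_splats_with_optimizers(
#   ...     points=pts, points_rgb=rgb, scene_scale=1.0, device="cuda"
#   ... )
#   >>> sorted(splats.keys())
#   ['means', 'opacities', 'quats', 'scales', 'sh0', 'shN']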


def compute_pinhole_intrinsics(
    image_w: int, image_h: int, fov_deg: float
) -> np.ndarray:
    """Build a 3x3 pinhole intrinsic matrix from image size and horizontal FoV."""
    fov_rad = np.deg2rad(fov_deg)
    fx = image_w / (2 * np.tan(fov_rad / 2))
    fy = fx  # assuming square pixels
    cx = image_w / 2
    cy = image_h / 2
    K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])

    return K
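

# Example (illustrative numbers): a 90-degree horizontal FoV at 640x480 gives
# fx = fy = 640 / (2 * tan(45 deg)) = 320, with the principal point at the
# image center (320, 240).
#   >>> compute_pinhole_intrinsics(640, 480, 90.0)
#   array([[320.,   0., 320.],
#          [  0., 320., 240.],
#          [  0.,   0.,   1.]])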


def resize_pinhole_intrinsics(
    raw_K: np.ndarray | torch.Tensor,
    raw_hw: tuple[int, int],
    new_hw: tuple[int, int],
) -> np.ndarray | torch.Tensor:
    """Rescale a pinhole intrinsic matrix for an image resized from `raw_hw` to `new_hw`."""
    raw_h, raw_w = raw_hw
    new_h, new_w = new_hw
    scale_x = new_w / raw_w
    scale_y = new_h / raw_h

    new_K = raw_K.copy() if isinstance(raw_K, np.ndarray) else raw_K.clone()
    new_K[0, 0] *= scale_x  # fx
    new_K[0, 2] *= scale_x  # cx
    new_K[1, 1] *= scale_y  # fy
    new_K[1, 2] *= scale_y  # cy

    return new_K
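

# Example (illustrative only): downscaling a 480x640 image to 240x320 halves
# fx, fy, cx, and cy.
#   >>> K = compute_pinhole_intrinsics(640, 480, 90.0)
#   >>> resize_pinhole_intrinsics(K, raw_hw=(480, 640), new_hw=(240, 320))
#   array([[160.,   0., 160.],
#          [  0., 160., 120.],
#          [  0.,   0.,   1.]])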


def restore_scene_scale_and_position(
    real_height: float, mesh_path: str, gs_path: str
) -> None:
    """Scale a mesh and its corresponding GS model to match a given real-world height.

    Uses the 1st and 99th percentiles along the mesh's vertical axis to estimate
    the current height, then applies scaling and vertical alignment and updates
    both the mesh and the GS model in place.

    Args:
        real_height (float): Target real-world height along the vertical axis.
        mesh_path (str): Path to the input mesh file.
        gs_path (str): Path to the Gaussian Splatting model file.
    """
    mesh = trimesh.load(mesh_path)
    z_min = np.percentile(mesh.vertices[:, 1], 1)
    z_max = np.percentile(mesh.vertices[:, 1], 99)
    height = z_max - z_min
    scale = real_height / height

    rot = Rotation.from_quat([0, 1, 0, 0])
    mesh.vertices = rot.apply(mesh.vertices)
    mesh.vertices[:, 1] -= z_min
    mesh.vertices *= scale
    mesh.export(mesh_path)

    gs_model: GaussianOperator = GaussianOperator.load_from_ply(gs_path)
    gs_model = gs_model.get_gaussians(
        instance_pose=torch.tensor([0.0, -z_min, 0, 0, 1, 0, 0])
    )
    gs_model.rescale(scale)
    gs_model.save_to_ply(gs_path)
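

# Example (illustrative only; the paths are hypothetical and both files are
# overwritten in place):
#   >>> restore_scene_scale_and_position(
#   ...     real_height=0.25,             # target height in scene units, e.g. meters
#   ...     mesh_path="asset/mesh.obj",
#   ...     gs_path="asset/gs_model.ply",
#   ... )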