Spaces:
Running
on
Zero
Running
on
Zero
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
from dataclasses import dataclass, field | |
from typing import List, Optional, Union | |
from gsplat.strategy import DefaultStrategy, MCMCStrategy | |
from typing_extensions import Literal, assert_never | |
# Public configuration classes exported by this module.
__all__ = ["Pano2MeshSRConfig", "GsplatTrainConfig"]
@dataclass
class Pano2MeshSRConfig:
    """Configuration for the Pano2MeshSR stage.

    Declared as a dataclass so any field can be overridden at construction,
    e.g. ``Pano2MeshSRConfig(pano_w=4096, device="cpu")``. Defaults remain
    readable as class attributes.
    """

    # File name of the mesh (a .ply by default).
    mesh_file: str = "mesh_model.ply"
    # File name of the "gs" data (presumably Gaussian-splat data; a .pt file).
    gs_data_file: str = "gs_data.pt"
    # Compute device string.
    device: str = "cuda"
    # Rasterization settings; names suggest a PyTorch3D-style rasterizer —
    # TODO confirm against the consumer.
    blur_radius: int = 0
    faces_per_pixel: int = 8
    # Field of view (presumably degrees; 90 would match one cubemap face).
    fov: int = 90
    # Panorama resolution (width, height).
    pano_w: int = 2048
    pano_h: int = 1024
    # Cubemap face resolution (width, height).
    cubemap_w: int = 512
    cubemap_h: int = 512
    # Scale applied to camera poses — exact semantics to confirm.
    pose_scale: float = 0.6
    # Two-value offset of the panorama center — axes/units to confirm.
    pano_center_offset: tuple = (-0.2, 0.3)
    # Stride between frames used for inpainting.
    inpaint_frame_stride: int = 20
    # Directory holding the camera trajectory.
    trajectory_dir: str = "apps/assets/example_scene/camera_trajectory"
    # Enable visualization output.
    visualize: bool = False
    # Depth scale factor; origin of this constant is not documented here.
    depth_scale_factor: float = 3.4092
    # Kernel size as a 2-tuple — which operation uses it is not visible here.
    kernel_size: tuple = (9, 9)
    # Integer upscaling factor.
    upscale_factor: int = 4
@dataclass
class GsplatTrainConfig:
    """Hyper-parameters for Gaussian-splatting training/evaluation with gsplat.

    Must be a dataclass: several fields use ``field(default_factory=...)``,
    which only produces real defaults under ``@dataclass``. Step-based
    schedules can be rescaled together with :meth:`adjust_steps`.
    """

    # Path to the .pt files. If provided, training is skipped and only
    # evaluation runs.
    ckpt: Optional[List[str]] = None
    # Render trajectory path
    render_traj_path: str = "interp"
    # Path to the dataset (Mip-NeRF 360-style layout)
    data_dir: str = "outputs/bg"
    # Downsample factor for the dataset
    data_factor: int = 4
    # Directory to save results
    result_dir: str = "outputs/bg"
    # Every N images there is a test image
    test_every: int = 8
    # Random crop size for training (experimental)
    patch_size: Optional[int] = None
    # A global scaler that applies to the scene size related parameters
    global_scale: float = 1.0
    # Normalize the world space
    normalize_world_space: bool = True
    # Camera model
    camera_model: Literal["pinhole", "ortho", "fisheye"] = "pinhole"
    # Port for the viewer server
    port: int = 8080
    # Batch size for training. Learning rates are scaled automatically
    batch_size: int = 1
    # A global factor to scale the number of training steps
    steps_scaler: float = 1.0
    # Number of training steps
    max_steps: int = 30_000
    # Steps to evaluate the model
    eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
    # Steps to save the model
    save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
    # Whether to save ply file (storage size can be large)
    save_ply: bool = True
    # Steps to save the model as ply
    ply_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
    # Whether to disable video generation during training and evaluation
    disable_video: bool = False
    # Initial number of GSs. Ignored if using sfm
    init_num_pts: int = 100_000
    # Initial extent of GSs as a multiple of the camera extent. Ignored if
    # using sfm
    init_extent: float = 3.0
    # Degree of spherical harmonics
    sh_degree: int = 1
    # Turn on another SH degree every this many steps
    sh_degree_interval: int = 1000
    # Initial opacity of GS
    init_opa: float = 0.1
    # Initial scale of GS
    init_scale: float = 1.0
    # Weight for SSIM loss
    ssim_lambda: float = 0.2
    # Near plane clipping distance
    near_plane: float = 0.01
    # Far plane clipping distance
    far_plane: float = 1e10
    # Strategy for GS densification
    strategy: Union[DefaultStrategy, MCMCStrategy] = field(
        default_factory=DefaultStrategy
    )
    # Use packed mode for rasterization; less memory but slightly slower.
    packed: bool = False
    # Use sparse gradients for optimization. (experimental)
    sparse_grad: bool = False
    # Use visible adam from Taming 3DGS. (experimental)
    visible_adam: bool = False
    # Anti-aliasing in rasterization. Might slightly hurt quantitative metrics.
    antialiased: bool = False
    # Use random background for training to discourage transparency
    random_bkgd: bool = False
    # LR for 3D point positions
    means_lr: float = 1.6e-4
    # LR for Gaussian scale factors
    scales_lr: float = 5e-3
    # LR for alpha blending weights
    opacities_lr: float = 5e-2
    # LR for orientation (quaternions)
    quats_lr: float = 1e-3
    # LR for SH band 0 (brightness)
    sh0_lr: float = 2.5e-3
    # LR for higher-order SH (detail)
    shN_lr: float = 2.5e-3 / 20
    # Opacity regularization
    opacity_reg: float = 0.0
    # Scale regularization
    scale_reg: float = 0.0
    # Enable depth loss. (experimental)
    depth_loss: bool = False
    # Weight for depth loss
    depth_lambda: float = 1e-2
    # Dump information to tensorboard every this many steps
    tb_every: int = 200
    # Save training images to tensorboard
    tb_save_image: bool = False
    # Backbone network for the LPIPS metric
    lpips_net: Literal["vgg", "alex"] = "alex"
    # 3DGUT (unscented transform + eval 3D)
    with_ut: bool = False
    with_eval3d: bool = False
    # Scene scale factor; how it is consumed is not visible here — TODO doc.
    scene_scale: float = 1.0

    def adjust_steps(self, factor: float) -> None:
        """Scale every step-based schedule in place by ``factor``.

        Scaled values are truncated with ``int()``. Periods ("every N
        steps": ``sh_degree_interval``, ``refine_every``, ``reset_every``)
        are clamped to at least 1, because a small ``factor`` would
        otherwise truncate them to 0, and a 0-step period is meaningless
        (any ``step % period`` schedule would divide by zero).

        Args:
            factor: Multiplier applied to all step counts.

        Raises:
            AssertionError: If ``self.strategy`` is neither a
                ``DefaultStrategy`` nor an ``MCMCStrategy`` (raised by
                ``assert_never``). Checked before anything is mutated.
        """

        def scale(steps: int) -> int:
            # Plain step count / boundary: 0 is a valid result.
            return int(steps * factor)

        def scale_period(steps: int) -> int:
            # Recurring interval: never let truncation produce 0.
            return max(1, int(steps * factor))

        strategy = self.strategy
        # Validate up front so an unsupported strategy leaves self untouched.
        if not isinstance(strategy, (DefaultStrategy, MCMCStrategy)):
            assert_never(strategy)

        self.eval_steps = [scale(step) for step in self.eval_steps]
        self.save_steps = [scale(step) for step in self.save_steps]
        self.ply_steps = [scale(step) for step in self.ply_steps]
        self.max_steps = scale(self.max_steps)
        self.sh_degree_interval = scale_period(self.sh_degree_interval)

        # Refinement window and period are scaled for both strategy types.
        strategy.refine_start_iter = scale(strategy.refine_start_iter)
        strategy.refine_stop_iter = scale(strategy.refine_stop_iter)
        strategy.refine_every = scale_period(strategy.refine_every)
        if isinstance(strategy, DefaultStrategy):
            # Only DefaultStrategy's reset_every is rescaled here.
            strategy.reset_every = scale_period(strategy.reset_every)