# NOTE: web-page banner residue captured with this file ("Spaces: Running on Zero"); not part of the module.
from copy import deepcopy
from dataclasses import dataclass
from typing import *
HDFS_DIR = "<HDFS_DIR>"  # data is stored in an internal HDFS in this project


@dataclass
class Options:
    """Configuration options shared by the tasks and models set up in this file.

    Fields are grouped by the section comments below (dataset, GSRecon,
    elevation estimation, GSVAE, GSDiff, training, visualization). Every field
    has a default, so an instance can be built with only the overrides passed
    as keyword arguments.

    `dataset_size` and `file_name_train` are derived from `dataset_name` in
    `__post_init__`; values passed for them to the constructor are overwritten.
    """

    # Dataset
    input_res: int = 256
    ## Camera
    num_input_views: int = 4
    num_views: int = 8
    load_even_views: bool = True
    exclude_topdown_views: bool = False
    norm_camera: bool = True
    norm_radius: float = 1.4  # the min distance in GObjaverse (cf. `RichDreamer` Sec. 3.1); only used when `norm_camera` is True
    fxfy: float = 1422.222 / 1024  # for GObjaverse only (https://github.com/modelscope/richdreamer/issues/10#issuecomment-1890870640)
    ## Content
    load_albedo: bool = False
    load_normal: bool = True
    load_coord: bool = True
    load_mr: bool = False
    load_canny: bool = False
    load_depth: bool = False
    normalize_depth: bool = True  # to [0, 1]
    dataset_name: Literal[
        "gobj83k",
        "gobj265k",
        "gobj1m",
    ] = "gobj83k"
    dataset_size: Optional[int] = None  # set later (in `__post_init__`)
    prompt_embed_dir: Optional[str] = None  # set later
    ## ParquetDataset
    file_dir_train: str = f"{HDFS_DIR}/GObjaverse_parquet"
    file_name_train: Optional[str] = None  # set later (in `__post_init__`)
    file_dir_test: str = "/tmp/test_dataset"
    file_name_test: str = "GObjaverse-val"
    # Shell command string; the second literal is deliberately NOT an f-string
    # so that the awk/xargs braces are kept verbatim
    dataset_setup_script: str = (
        f"mkdir -p /tmp/test_dataset && hdfs dfs -ls {HDFS_DIR}/GObjaverse_parquet/GObjaverse-val-* | grep '^-' | "
        + "awk '{print $8}' | xargs -n 1 -P 5 -I {} hdfs dfs -get {} /tmp/test_dataset"
    )

    # GSRecon
    input_albedo: bool = False
    input_normal: bool = True
    input_coord: bool = True
    input_mr: bool = False
    ## Transformer
    llama_style: bool = True
    patch_size: int = 8
    dim: int = 512
    num_blocks: int = 12
    num_heads: int = 8
    grad_checkpoint: bool = True
    ## Rendering
    render_type: Literal[
        "default",
        "deferred",
    ] = "default"
    deferred_bp_patch_size: int = 64
    znear: float = 0.01
    zfar: float = 100.
    scale_min: float = 0.0005
    scale_max: float = 0.02

    # Elevation estimation
    elevest_backbone_name: Literal[
        "dinov2_vits14_reg",
        "dinov2_vitb14_reg",
        "dinov2_vitl14_reg",
    ] = "dinov2_vitb14_reg"
    freeze_backbone: bool = False
    ele_min: float = -40.  # actual min: -30.
    ele_max: float = 10.  # actual max: 5.
    elevest_num_classes: int = 25
    elevest_reg_weight: float = 1.

    # GSVAE
    vae_from_scratch: bool = False
    use_tinyae: bool = False
    freeze_encoder: bool = False
    use_tiny_decoder: bool = False
    scaling_factor: Optional[float] = None
    shift_factor: Optional[float] = None

    # GSDiff
    pretrained_model_name_or_path: Literal[
        "stable-diffusion-v1-5/stable-diffusion-v1-5",
        "stabilityai/stable-diffusion-2-1",
        "stabilityai/stable-diffusion-xl-base-1.0",
        "PixArt-alpha/PixArt-XL-2-256x256",
        "PixArt-alpha/PixArt-XL-2-512x512",
        "PixArt-alpha/PixArt-XL-2-1024-MS",
        "PixArt-alpha/PixArt-Sigma-XL-2-256x256",
        "PixArt-alpha/PixArt-Sigma-XL-2-512-MS",
        "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
        "stabilityai/stable-diffusion-3-medium-diffusers",
        "stabilityai/stable-diffusion-3.5-medium",
        "stabilityai/stable-diffusion-3.5-large",
        "black-forest-labs/FLUX.1-dev",
        "madebyollin/sdxl-vae-fp16-fix",
        "lambdalabs/sd-image-variations-diffusers",
        "stabilityai/stable-diffusion-2-1-unclip",
        "chenguolin/sv3d-diffusers",
    ] = "stable-diffusion-v1-5/stable-diffusion-v1-5"
    load_fp16vae_for_sdxl: bool = True
    ## Config
    from_scratch: bool = False
    cfg_dropout_prob: float = 0.05  # actual prob is x2; see the training code
    snr_gamma: float = 0.  # Min-SNR trick; `0.` means not used
    num_inference_steps: int = 20
    noise_scheduler_type: Literal[
        "ddim",
        "dpmsolver++",
        "sde-dpmsolver++",
    ] = "dpmsolver++"
    prediction_type: Optional[str] = None  # `None` means using default prediction type
    beta_schedule: Optional[str] = None  # `None` means using the default beta schedule
    edm_style_training: bool = False  # EDM scheduling; cf. https://arxiv.org/pdf/2206.00364
    common_tricks: bool = True  # cf. https://arxiv.org/pdf/2305.08891 (including: 1. trailing timestep spacing, 2. rescaling to zero snr)
    ### SD3; cf. https://arxiv.org/pdf/2403.03206
    weighting_scheme: Literal[
        "sigma_sqrt",
        "logit_normal",
        "mode",
        "cosmap",
    ] = "logit_normal"
    logit_mean: float = 0.
    logit_std: float = 1.
    mode_scale: float = 1.29
    precondition_outputs: bool = False  # whether to predict x_0
    ## Model
    trainable_modules: Optional[str] = None  # train all parameters if None
    name_lr_mult: Optional[str] = None
    lr_mult: float = 1.
    ### Conditioning
    zero_init_conv_in: bool = True  # whether to zero-init new conv_in params
    view_concat_condition: bool = False  # `True` for image-cond
    input_concat_plucker: bool = True
    input_concat_binary_mask: bool = False
    num_cond_views: int = 1
    ### Inference
    init_std: float = 0.  # cf. Instant3D inference trick, `0.` means not used
    init_noise_strength: float = 0.98  # used with `init_std`; cf. Instant3D inference trick, `1.` means not used
    init_bg: float = 0.  # used with `init_std` and `init_noise_strength`; gray background for the initialization
    ### ControlNet
    controlnet_type: Literal[
        "normal",
        "depth",
        "canny",
    ] = "normal"
    controlnet_input_channels: int = 3
    guess_mode: bool = False
    controlnet_scale: float = 1.
    ## Rendering loss
    rendering_loss_prob: float = 0.
    snr_gamma_rendering: float = 0.  # Min-SNR trick for rendering loss; `0.` means not used

    # Training
    chunk_size: int = 1  # chunk size for GSRecon and GSVAE inference to save memory
    coord_weight: float = 0.  # render coords for supervision
    normal_weight: float = 0.  # render normals for supervision
    recon_weight: float = 1.  # GSVAE reconstruction weight
    render_weight: float = 1.  # GSVAE rendering weight
    diffusion_weight: float = 1.  # GSDiff diffusion weight
    ## LPIPS
    lpips_resize: int = 256  # `0` means no resizing
    lpips_weight: float = 1.  # lpips weight in GSRecon, GSVAE, GSDiff rendering
    lpips_warmup_start: int = 0
    lpips_warmup_end: int = 0

    # Visualization
    vis_pseudo_images: bool = False  # decode Gaussian latents by the image decoder
    vis_coords: bool = False
    vis_normals: bool = False

    def __post_init__(self):
        """Fill in `dataset_size` and `file_name_train` from `dataset_name`.

        Raises:
            ValueError: if `dataset_name` is neither "gobj83k" nor "gobj265k".
                This includes "gobj1m", which is listed in the annotation of
                `dataset_name` but has no size / file name registered here.
        """
        if self.dataset_name == "gobj83k":
            self.dataset_size = 83296
            self.file_name_train = "GObjaverse-train-280k-83k"
        elif self.dataset_name == "gobj265k":
            self.dataset_size = 265232
            self.file_name_train = "GObjaverse-train-280k"
        else:
            raise ValueError(f"Unknown dataset name: {self.dataset_name}")
def _update_opt(opt: Options, **kwargs) -> Options:
    """Return a deep copy of `opt` with the given attributes overridden.

    Args:
        opt: the options object to start from; it is not modified.
        **kwargs: attribute names and the values to assign on the copy.

    Returns:
        The modified copy.

    Note:
        Values are assigned with `setattr`, so `Options.__post_init__` is NOT
        re-run: `dataset_size` and `file_name_train` keep the values copied
        from `opt` even when `dataset_name` is overridden here.
    """
    new_opt = deepcopy(opt)
    for k, v in kwargs.items():
        setattr(new_opt, k, v)
    return new_opt
# Set all options for different tasks and models
# (maps a preset name to the `Options` instance used for that task/model)
opt_dict: Dict[str, Options] = {}


# GRM
opt_dict["gsrecon"] = Options(dataset_name="gobj265k")


# Elevation estimation
opt_dict["elevest"] = Options(
    input_res=224,
    load_even_views=False,
    exclude_topdown_views=True,
    load_normal=False,
    load_coord=False,
    dataset_name="gobj265k",
    name_lr_mult="backbone",
    lr_mult=0.1,
)


# GSVAE
## SD-based
opt_dict["gsvae"] = Options(dataset_name="gobj265k")
## SDXL-based
opt_dict["gsvae_sdxl"] = _update_opt(
    opt_dict["gsvae"],
    pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0",
)
opt_dict["gsvae_sdxl_fp16"] = _update_opt(
    opt_dict["gsvae"],
    pretrained_model_name_or_path="madebyollin/sdxl-vae-fp16-fix",
)
## SD3-based
opt_dict["gsvae_sd3m"] = _update_opt(
    opt_dict["gsvae"],
    pretrained_model_name_or_path="stabilityai/stable-diffusion-3-medium-diffusers",
)
opt_dict["gsvae_sd35m"] = _update_opt(
    opt_dict["gsvae"],
    pretrained_model_name_or_path="stabilityai/stable-diffusion-3.5-medium",
)


# GSDiff
## SD15-based
opt_dict["gsdiff_sd15"] = Options(
    prompt_embed_dir="/tmp/GObjaverse_sd15_prompt_embeds",
    pretrained_model_name_or_path="stable-diffusion-v1-5/stable-diffusion-v1-5",
)
## SDXL-based
opt_dict["gsdiff_sdxl"] = Options(
    prompt_embed_dir="/tmp/GObjaverse_sdxl_prompt_embeds",
    pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0",
)
## PAA-based
opt_dict["gsdiff_paa"] = Options(
    prompt_embed_dir="/tmp/GObjaverse_paa_prompt_embeds",
    pretrained_model_name_or_path="PixArt-alpha/PixArt-XL-2-512x512",
)
## PAS-based
opt_dict["gsdiff_pas"] = Options(
    prompt_embed_dir="/tmp/GObjaverse_pas_prompt_embeds",
    pretrained_model_name_or_path="PixArt-alpha/PixArt-Sigma-XL-2-512-MS",
)
## SD3-based
opt_dict["gsdiff_sd3m"] = Options(
    prompt_embed_dir="/tmp/GObjaverse_sd3m_prompt_embeds",
    pretrained_model_name_or_path="stabilityai/stable-diffusion-3-medium-diffusers",
)
opt_dict["gsdiff_sd35m"] = Options(
    prompt_embed_dir="/tmp/GObjaverse_sd35m_prompt_embeds",
    pretrained_model_name_or_path="stabilityai/stable-diffusion-3.5-medium",
)