import logging
import os
import random
import time
import warnings
from dataclasses import dataclass, field
from shutil import copy, rmtree

import torch
import torchvision
import tyro
from huggingface_hub import snapshot_download
from packaging import version

# Suppress noisy third-party warnings and logs.
warnings.filterwarnings("ignore", category=FutureWarning)
logging.getLogger("transformers").setLevel(logging.ERROR)
logging.getLogger("diffusers").setLevel(logging.ERROR)

# TorchVision monkey patch for >0.16: recreate the removed
# `functional_tensor` module for libraries that still import it.
if version.parse(torchvision.__version__) >= version.parse("0.16"):
    import sys
    import types

    import torchvision.transforms.functional as TF

    functional_tensor = types.ModuleType(
        "torchvision.transforms.functional_tensor"
    )
    functional_tensor.rgb_to_grayscale = TF.rgb_to_grayscale
    sys.modules["torchvision.transforms.functional_tensor"] = functional_tensor

from gsplat.distributed import cli
from txt2panoimg import Text2360PanoramaImagePipeline
from embodied_gen.trainer.gsplat_trainer import (
    DefaultStrategy,
    GsplatTrainConfig,
)
from embodied_gen.trainer.gsplat_trainer import entrypoint as gsplat_entrypoint
from embodied_gen.trainer.pono2mesh_trainer import Pano2MeshSRPipeline
from embodied_gen.utils.config import Pano2MeshSRConfig
from embodied_gen.utils.gaussian import restore_scene_scale_and_position
from embodied_gen.utils.gpt_clients import GPT_CLIENT
from embodied_gen.utils.log import logger
from embodied_gen.utils.process_media import is_image_file, parse_text_prompts
from embodied_gen.validators.quality_checkers import (
    PanoHeightEstimator,
    PanoImageOccChecker,
)

__all__ = [
    "generate_pano_image",
    "entrypoint",
]


@dataclass
class Scene3DGenConfig:
    # Text descriptions of the indoor scene, or style reference image paths.
    prompts: list[str]
    output_dir: str
    seed: int | None = None
    # Real height of the room in meters; estimated from the panorama if unset.
    real_height: float | None = None
    pano_image_only: bool = False
    disable_pano_check: bool = False
    keep_middle_result: bool = False
    n_retry: int = 7
    gs3d: GsplatTrainConfig = field(
        default_factory=lambda: GsplatTrainConfig(
            strategy=DefaultStrategy(verbose=True),
            max_steps=4000,
            init_opa=0.9,
            opacity_reg=2e-3,
            sh_degree=0,
            means_lr=1e-4,
            scales_lr=1e-3,
        )
    )


def generate_pano_image(
    prompt: str,
    output_path: str,
    pipeline,
    seed: int,
    n_retry: int,
    checker=None,
    num_inference_steps: int = 40,
) -> None:
    for i in range(n_retry):
        logger.info(
            f"GEN Panorama: Retry {i+1}/{n_retry} for prompt: {prompt}, seed: {seed}"
        )
        if is_image_file(prompt):
            raise NotImplementedError("Image mode not implemented yet.")
        else:
            txt_prompt = f"{prompt}, spacious, empty, wide open, open floor, minimal furniture"
            inputs = {
                "prompt": txt_prompt,
                "num_inference_steps": num_inference_steps,
                "upscale": False,
                "seed": seed,
            }
            pano_image = pipeline(inputs)

        pano_image.save(output_path)
        if checker is None:
            break

        flag, response = checker(pano_image)
        logger.warning(f"{response}, image saved in {output_path}")
        if flag is True or flag is None:
            break

        # Retry with a new random seed if the checker rejects the panorama.
        seed = random.randint(0, 100000)

    return


def entrypoint(*args, **kwargs):
    cfg = tyro.cli(Scene3DGenConfig)

    # Init global models.
    model_path = snapshot_download("archerfmy0831/sd-t2i-360panoimage")
    IMG2PANO_PIPE = Text2360PanoramaImagePipeline(
        model_path, torch_dtype=torch.float16, device="cuda"
    )
    PANOMESH_CFG = Pano2MeshSRConfig()
    PANO2MESH_PIPE = Pano2MeshSRPipeline(PANOMESH_CFG)
    PANO_CHECKER = PanoImageOccChecker(GPT_CLIENT, box_hw=[95, 1000])
    PANOHEIGHT_ESTOR = PanoHeightEstimator(GPT_CLIENT)

    prompts = parse_text_prompts(cfg.prompts)
    for idx, prompt in enumerate(prompts):
        start_time = time.time()
        output_dir = os.path.join(cfg.output_dir, f"scene_{idx:04d}")
        os.makedirs(output_dir, exist_ok=True)
        pano_path = os.path.join(output_dir, "pano_image.png")
        with open(f"{output_dir}/prompt.txt", "w") as f:
            f.write(prompt)

        generate_pano_image(
            prompt,
            pano_path,
            IMG2PANO_PIPE,
            cfg.seed if cfg.seed is not None else random.randint(0, 100000),
            cfg.n_retry,
            checker=None if cfg.disable_pano_check else PANO_CHECKER,
        )
        if cfg.pano_image_only:
            continue

        logger.info("GEN and REPAIR Mesh from Panorama...")
        PANO2MESH_PIPE(pano_path, output_dir)

        logger.info("TRAIN 3DGS from Mesh Init and Cube Image...")
        cfg.gs3d.data_dir = output_dir
        cfg.gs3d.result_dir = f"{output_dir}/gaussian"
        cfg.gs3d.adjust_steps(cfg.gs3d.steps_scaler)
        torch.set_default_device("cpu")  # recover default setting.
        cli(gsplat_entrypoint, cfg.gs3d, verbose=True)

        # Clean up the middle results.
        gs_path = (
            f"{cfg.gs3d.result_dir}/ply/point_cloud_{cfg.gs3d.max_steps-1}.ply"
        )
        copy(gs_path, f"{output_dir}/gs_model.ply")
        video_path = f"{cfg.gs3d.result_dir}/renders/video_step{cfg.gs3d.max_steps-1}.mp4"
        copy(video_path, f"{output_dir}/video.mp4")
        gs_cfg_path = f"{cfg.gs3d.result_dir}/cfg.yml"
        copy(gs_cfg_path, f"{output_dir}/gsplat_cfg.yml")
        if not cfg.keep_middle_result:
            rmtree(cfg.gs3d.result_dir, ignore_errors=True)
            os.remove(f"{output_dir}/{PANOMESH_CFG.gs_data_file}")

        real_height = (
            PANOHEIGHT_ESTOR(pano_path)
            if cfg.real_height is None
            else cfg.real_height
        )
        gs_path = os.path.join(output_dir, "gs_model.ply")
        mesh_path = os.path.join(output_dir, "mesh_model.ply")
        restore_scene_scale_and_position(real_height, mesh_path, gs_path)

        elapsed_time = (time.time() - start_time) / 60
        logger.info(
            f"FINISHED 3D scene generation in {output_dir} in {elapsed_time:.2f} mins."
        )


if __name__ == "__main__":
    entrypoint()
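
# Example invocation (a sketch only): tyro derives the CLI flags from
# Scene3DGenConfig, so exact flag spellings (especially for the nested
# gs3d.* fields and list arguments) may vary with the tyro version, and the
# script filename below is a placeholder, not part of this module.
#
#   python gen_scene_3d.py \
#       --prompts "a cozy modern living room with wooden floor" \
#       --output-dir outputs/scene3d \
#       --seed 42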