# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.

import os
import sys
import zipfile

import torch
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf
from PIL import Image
from torchvision import transforms


def monkey_patch_pano2room():
    """Replace the constructors of several pano2room classes in place.

    The following ``__init__`` methods are overwritten:

    * ``OmnidataPredictor`` / ``OmnidataNormalPredictor``: models are
      obtained through ``torch.hub.load('alexsax/omnidata_models', ...)``.
    * ``PanoJointPredictor``: builds the two predictors above and stores
      ``save_path``.
    * ``LamaInpainter``: weights come from the ``smartywu/big-lama``
      archive fetched with ``hf_hub_download``.
    * ``SDFTInpainter``: wraps the
      ``stabilityai/stable-diffusion-2-inpainting`` pipeline.

    Side effects: extends ``sys.path`` so that ``thirdparty.pano2room`` and
    its internal ``modules`` package are importable, and sets the
    ``NODE_RANK`` environment variable to ``"0"``.
    """
    current_file_path = os.path.abspath(__file__)
    current_dir = os.path.dirname(current_file_path)
    # Register the import roots only once so that calling this function
    # repeatedly does not keep growing ``sys.path`` with duplicates.
    for extra_path in (
        os.path.join(current_dir, "../.."),
        os.path.join(current_dir, "../../thirdparty/pano2room"),
    ):
        if extra_path not in sys.path:
            sys.path.append(extra_path)

    # These imports depend on the ``sys.path`` entries added above, hence
    # they live inside the function rather than at module level.
    from thirdparty.pano2room.modules.geo_predictors.omnidata.omnidata_normal_predictor import (
        OmnidataNormalPredictor,
    )
    from thirdparty.pano2room.modules.geo_predictors.omnidata.omnidata_predictor import (
        OmnidataPredictor,
    )

    def _omnidata_preprocess(img_size):
        """Return the resize / center-crop / normalize pipeline."""
        return transforms.Compose(
            [
                transforms.Resize(img_size, interpolation=Image.BILINEAR),
                transforms.CenterCrop(img_size),
                transforms.Normalize(mean=0.5, std=0.5),
            ]
        )

    def patched_omni_depth_init(self):
        """Build the depth predictor from the ``depth_dpt_hybrid_384`` hub entry."""
        self.img_size = 384
        self.model = torch.hub.load(
            'alexsax/omnidata_models', 'depth_dpt_hybrid_384'
        )
        self.model.eval()
        self.trans_totensor = _omnidata_preprocess(self.img_size)

    OmnidataPredictor.__init__ = patched_omni_depth_init

    def patched_omni_normal_init(self):
        """Build the normal predictor from the ``surface_normal_dpt_hybrid_384`` hub entry."""
        self.img_size = 384
        self.model = torch.hub.load(
            'alexsax/omnidata_models', 'surface_normal_dpt_hybrid_384'
        )
        self.model.eval()
        self.trans_totensor = _omnidata_preprocess(self.img_size)

    OmnidataNormalPredictor.__init__ = patched_omni_normal_init

    def patched_panojoint_init(self, save_path=None):
        """Create the patched depth/normal predictors and remember ``save_path``."""
        self.depth_predictor = OmnidataPredictor()
        self.normal_predictor = OmnidataNormalPredictor()
        self.save_path = save_path

    # Resolved through the ``thirdparty/pano2room`` entry on ``sys.path``.
    from modules.geo_predictors import PanoJointPredictor

    PanoJointPredictor.__init__ = patched_panojoint_init

    # NOTE: We use gsplat instead.
    # import depth_diff_gaussian_rasterization_min as ddgr
    # from dataclasses import dataclass
    # @dataclass
    # class PatchedGaussianRasterizationSettings:
    #     image_height: int
    #     image_width: int
    #     tanfovx: float
    #     tanfovy: float
    #     bg: torch.Tensor
    #     scale_modifier: float
    #     viewmatrix: torch.Tensor
    #     projmatrix: torch.Tensor
    #     sh_degree: int
    #     campos: torch.Tensor
    #     prefiltered: bool
    #     debug: bool = False
    # ddgr.GaussianRasterizationSettings = PatchedGaussianRasterizationSettings

    # disable get_has_ddp_rank print in `BaseInpaintingTrainingModule`
    os.environ["NODE_RANK"] = "0"

    from thirdparty.pano2room.modules.inpainters.lama.saicinpainting.training.trainers import (
        load_checkpoint,
    )
    from thirdparty.pano2room.modules.inpainters.lama_inpainter import (
        LamaInpainter,
    )

    def patched_lama_inpaint_init(self):
        """Load the big-lama checkpoint from the Hugging Face Hub archive."""
        zip_path = hf_hub_download(
            repo_id="smartywu/big-lama",
            filename="big-lama.zip",
            repo_type="model",
        )
        # The archive is unpacked next to the zip, in a directory named
        # after the zip file without its extension.
        extract_dir = os.path.splitext(zip_path)[0]
        config_path = os.path.join(extract_dir, 'big-lama', 'config.yaml')
        checkpoint_path = os.path.join(
            extract_dir, 'big-lama/models/best.ckpt'
        )

        # Check for the files that are actually needed rather than for the
        # directory: an interrupted extraction leaves the directory behind
        # and would otherwise never be repaired.
        if not (
            os.path.isfile(config_path) and os.path.isfile(checkpoint_path)
        ):
            os.makedirs(extract_dir, exist_ok=True)
            with zipfile.ZipFile(zip_path, "r") as zip_ref:
                zip_ref.extractall(extract_dir)

        train_config = OmegaConf.load(config_path)
        train_config.training_model.predict_only = True
        train_config.visualizer.kind = 'noop'

        self.model = load_checkpoint(
            train_config, checkpoint_path, strict=False, map_location='cpu'
        )
        self.model.freeze()

    LamaInpainter.__init__ = patched_lama_inpaint_init

    from diffusers import StableDiffusionInpaintPipeline
    from thirdparty.pano2room.modules.inpainters.SDFT_inpainter import (
        SDFTInpainter,
    )

    def patched_sd_inpaint_init(self, subset_name=None):
        """Create the Stable Diffusion 2 inpainting pipeline with CPU offload.

        ``subset_name`` is accepted for signature compatibility and is not
        used here.
        """
        super(SDFTInpainter, self).__init__()
        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-inpainting",
            torch_dtype=torch.float16,
        )
        # Do not call ``.to("cuda")`` first: model CPU offload manages device
        # placement itself, and moving the whole pipeline to the GPU
        # beforehand only causes a transient full-size GPU allocation.
        pipe.enable_model_cpu_offload()
        self.inpaint_pipe = pipe

    SDFTInpainter.__init__ = patched_sd_inpaint_init