# xinjie.wang
# update
# 575f14d
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import os
import sys
import zipfile
import torch
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf
from PIL import Image
from torchvision import transforms
def monkey_patch_pano2room():
    """Replace the constructors of several pano2room classes in place.

    After this call the following classes use the patched ``__init__``
    functions defined below:

    * ``OmnidataPredictor`` / ``OmnidataNormalPredictor`` load their models
      through ``torch.hub.load('alexsax/omnidata_models', ...)``.
    * ``PanoJointPredictor`` only builds the two Omnidata predictors and
      stores ``save_path``.
    * ``LamaInpainter`` downloads ``big-lama.zip`` from the Hugging Face
      repo ``smartywu/big-lama``, extracts it next to the downloaded
      archive and loads the checkpoint on CPU.
    * ``SDFTInpainter`` loads the
      ``stabilityai/stable-diffusion-2-inpainting`` pipeline in float16.

    Side effects: extends ``sys.path`` with the repository root and the
    ``thirdparty/pano2room`` directory (each entry added at most once) and
    sets the ``NODE_RANK`` environment variable to ``"0"``.
    """
    current_dir = os.path.dirname(os.path.abspath(__file__))
    # pano2room uses both ``thirdparty.pano2room.*`` and bare ``modules.*``
    # imports, so both roots must be importable. Skip entries that are
    # already present so repeated calls do not keep growing ``sys.path``.
    for relative_root in ("../..", "../../thirdparty/pano2room"):
        search_path = os.path.join(current_dir, relative_root)
        if search_path not in sys.path:
            sys.path.append(search_path)

    from thirdparty.pano2room.modules.geo_predictors.omnidata.omnidata_normal_predictor import (
        OmnidataNormalPredictor,
    )
    from thirdparty.pano2room.modules.geo_predictors.omnidata.omnidata_predictor import (
        OmnidataPredictor,
    )

    def make_omnidata_init(hub_entry):
        """Build an ``__init__`` that loads ``hub_entry`` via ``torch.hub``.

        The depth and normal predictors share everything except the name of
        the hub entry point, so both constructors come from this factory.
        """

        def patched_omnidata_init(self):
            self.img_size = 384
            self.model = torch.hub.load('alexsax/omnidata_models', hub_entry)
            self.model.eval()
            self.trans_totensor = transforms.Compose(
                [
                    transforms.Resize(
                        self.img_size, interpolation=Image.BILINEAR
                    ),
                    transforms.CenterCrop(self.img_size),
                    transforms.Normalize(mean=0.5, std=0.5),
                ]
            )

        return patched_omnidata_init

    OmnidataPredictor.__init__ = make_omnidata_init('depth_dpt_hybrid_384')
    OmnidataNormalPredictor.__init__ = make_omnidata_init(
        'surface_normal_dpt_hybrid_384'
    )

    def patched_panojoint_init(self, save_path=None):
        """Create the patched depth/normal predictors and keep ``save_path``."""
        self.depth_predictor = OmnidataPredictor()
        self.normal_predictor = OmnidataNormalPredictor()
        self.save_path = save_path

    from modules.geo_predictors import PanoJointPredictor

    PanoJointPredictor.__init__ = patched_panojoint_init

    # NOTE: We use gsplat instead.
    # import depth_diff_gaussian_rasterization_min as ddgr
    # from dataclasses import dataclass
    # @dataclass
    # class PatchedGaussianRasterizationSettings:
    #     image_height: int
    #     image_width: int
    #     tanfovx: float
    #     tanfovy: float
    #     bg: torch.Tensor
    #     scale_modifier: float
    #     viewmatrix: torch.Tensor
    #     projmatrix: torch.Tensor
    #     sh_degree: int
    #     campos: torch.Tensor
    #     prefiltered: bool
    #     debug: bool = False
    # ddgr.GaussianRasterizationSettings = PatchedGaussianRasterizationSettings

    # disable get_has_ddp_rank print in `BaseInpaintingTrainingModule`
    # (set before the LaMa trainer modules are imported below).
    os.environ["NODE_RANK"] = "0"

    from thirdparty.pano2room.modules.inpainters.lama.saicinpainting.training.trainers import (
        load_checkpoint,
    )
    from thirdparty.pano2room.modules.inpainters.lama_inpainter import (
        LamaInpainter,
    )

    def patched_lama_inpaint_init(self):
        """Download, unpack and load the big-lama checkpoint on CPU."""
        zip_path = hf_hub_download(
            repo_id="smartywu/big-lama",
            filename="big-lama.zip",
            repo_type="model",
        )
        # The archive is unpacked into a sibling directory named after the
        # zip file without its extension.
        extract_dir = os.path.splitext(zip_path)[0]
        config_path = os.path.join(extract_dir, 'big-lama', 'config.yaml')
        checkpoint_path = os.path.join(
            extract_dir, 'big-lama/models/best.ckpt'
        )

        # Check for the files that are actually loaded rather than for the
        # directory: an interrupted extraction leaves the directory behind,
        # and testing only the directory would then skip extraction forever.
        files_ready = os.path.exists(config_path) and os.path.exists(
            checkpoint_path
        )
        if not files_ready:
            os.makedirs(extract_dir, exist_ok=True)
            with zipfile.ZipFile(zip_path, "r") as zip_ref:
                zip_ref.extractall(extract_dir)

        train_config = OmegaConf.load(config_path)
        train_config.training_model.predict_only = True
        train_config.visualizer.kind = 'noop'

        self.model = load_checkpoint(
            train_config, checkpoint_path, strict=False, map_location='cpu'
        )
        self.model.freeze()

    LamaInpainter.__init__ = patched_lama_inpaint_init

    from diffusers import StableDiffusionInpaintPipeline

    from thirdparty.pano2room.modules.inpainters.SDFT_inpainter import (
        SDFTInpainter,
    )

    def patched_sd_inpaint_init(self, subset_name=None):
        """Load the Stable Diffusion 2 inpainting pipeline in float16.

        ``subset_name`` is accepted but not used by this patched constructor.
        """
        super(SDFTInpainter, self).__init__()
        # NOTE(review): the pipeline is moved to CUDA and CPU offload is
        # enabled right after; the explicit `.to("cuda")` may be redundant
        # with offloading -- confirm against the installed diffusers version.
        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-inpainting",
            torch_dtype=torch.float16,
        ).to("cuda")
        pipe.enable_model_cpu_offload()
        self.inpaint_pipe = pipe

    SDFTInpainter.__init__ = patched_sd_inpaint_init