# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
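"""Monkey patches for the vendored pano2room modules.

The patched constructors below fetch pretrained weights via torch.hub and
the Hugging Face Hub and build inference-only predictors and inpainters.
"""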
import os
import sys
import zipfile

import torch
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf
from PIL import Image
from torchvision import transforms


def monkey_patch_pano2room():
    """Patch pano2room classes to load pretrained weights from remote hubs."""
    current_file_path = os.path.abspath(__file__)
    current_dir = os.path.dirname(current_file_path)
    sys.path.append(os.path.join(current_dir, "../.."))
    sys.path.append(os.path.join(current_dir, "../../thirdparty/pano2room"))

    from thirdparty.pano2room.modules.geo_predictors.omnidata.omnidata_normal_predictor import (
        OmnidataNormalPredictor,
    )
    from thirdparty.pano2room.modules.geo_predictors.omnidata.omnidata_predictor import (
        OmnidataPredictor,
    )
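
    # Replace the Omnidata constructors so the pretrained DPT-hybrid depth
    # and surface-normal models are loaded directly from torch.hub.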
    def patched_omni_depth_init(self):
        self.img_size = 384
        self.model = torch.hub.load(
            'alexsax/omnidata_models', 'depth_dpt_hybrid_384'
        )
        self.model.eval()
        self.trans_totensor = transforms.Compose(
            [
                transforms.Resize(self.img_size, interpolation=Image.BILINEAR),
                transforms.CenterCrop(self.img_size),
                transforms.Normalize(mean=0.5, std=0.5),
            ]
        )

    OmnidataPredictor.__init__ = patched_omni_depth_init

    def patched_omni_normal_init(self):
        self.img_size = 384
        self.model = torch.hub.load(
            'alexsax/omnidata_models', 'surface_normal_dpt_hybrid_384'
        )
        self.model.eval()
        self.trans_totensor = transforms.Compose(
            [
                transforms.Resize(self.img_size, interpolation=Image.BILINEAR),
                transforms.CenterCrop(self.img_size),
                transforms.Normalize(mean=0.5, std=0.5),
            ]
        )

    OmnidataNormalPredictor.__init__ = patched_omni_normal_init
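
    # PanoJointPredictor only needs the two patched predictors above, so its
    # constructor is reduced accordingly.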
    def patched_panojoint_init(self, save_path=None):
        self.depth_predictor = OmnidataPredictor()
        self.normal_predictor = OmnidataNormalPredictor()
        self.save_path = save_path

    from modules.geo_predictors import PanoJointPredictor

    PanoJointPredictor.__init__ = patched_panojoint_init

    # NOTE: We use gsplat instead.
    # import depth_diff_gaussian_rasterization_min as ddgr
    # from dataclasses import dataclass
    # @dataclass
    # class PatchedGaussianRasterizationSettings:
    #     image_height: int
    #     image_width: int
    #     tanfovx: float
    #     tanfovy: float
    #     bg: torch.Tensor
    #     scale_modifier: float
    #     viewmatrix: torch.Tensor
    #     projmatrix: torch.Tensor
    #     sh_degree: int
    #     campos: torch.Tensor
    #     prefiltered: bool
    #     debug: bool = False
    # ddgr.GaussianRasterizationSettings = PatchedGaussianRasterizationSettings

    # disable get_has_ddp_rank print in `BaseInpaintingTrainingModule`
    os.environ["NODE_RANK"] = "0"

    from thirdparty.pano2room.modules.inpainters.lama.saicinpainting.training.trainers import (
        load_checkpoint,
    )
    from thirdparty.pano2room.modules.inpainters.lama_inpainter import (
        LamaInpainter,
    )
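
    # Fetch the Big-LaMa weights from the Hugging Face Hub, unpack them once,
    # and build the LaMa inpainting model from the bundled config/checkpoint.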
    def patched_lama_inpaint_init(self):
        zip_path = hf_hub_download(
            repo_id="smartywu/big-lama",
            filename="big-lama.zip",
            repo_type="model",
        )
        extract_dir = os.path.splitext(zip_path)[0]
        if not os.path.exists(extract_dir):
            os.makedirs(extract_dir, exist_ok=True)
            with zipfile.ZipFile(zip_path, "r") as zip_ref:
                zip_ref.extractall(extract_dir)

        config_path = os.path.join(extract_dir, 'big-lama', 'config.yaml')
        checkpoint_path = os.path.join(
            extract_dir, 'big-lama/models/best.ckpt'
        )
        train_config = OmegaConf.load(config_path)
        train_config.training_model.predict_only = True
        train_config.visualizer.kind = 'noop'

        self.model = load_checkpoint(
            train_config, checkpoint_path, strict=False, map_location='cpu'
        )
        self.model.freeze()

    LamaInpainter.__init__ = patched_lama_inpaint_init

    from diffusers import StableDiffusionInpaintPipeline
    from thirdparty.pano2room.modules.inpainters.SDFT_inpainter import (
        SDFTInpainter,
    )
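
    # Replace the SDFT inpainter constructor with a stock Stable Diffusion 2
    # inpainting pipeline loaded in fp16.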
    def patched_sd_inpaint_init(self, subset_name=None):
        super(SDFTInpainter, self).__init__()
        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-inpainting",
            torch_dtype=torch.float16,
        ).to("cuda")
        pipe.enable_model_cpu_offload()
        self.inpaint_pipe = pipe

    SDFTInpainter.__init__ = patched_sd_inpaint_init
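

# Minimal usage sketch (an assumption for illustration, not part of the
# upstream API): the patches must be applied before any of the patched
# pano2room classes are instantiated.
if __name__ == "__main__":
    monkey_patch_pano2room()
    # With the patches in place, the predictors can be built without manually
    # downloaded checkpoints, for example:
    # from modules.geo_predictors import PanoJointPredictor
    # predictor = PanoJointPredictor()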