# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.

from embodied_gen.utils.monkey_patches import monkey_patch_pano2room

# Apply the pano2room patches before the thirdparty imports below.
monkey_patch_pano2room()

import os

import cv2
import numpy as np
import torch
import trimesh
from equilib import cube2equi, equi2pers
from kornia.morphology import dilation
from PIL import Image

from embodied_gen.models.sr_model import ImageRealESRGAN
from embodied_gen.utils.config import Pano2MeshSRConfig
from embodied_gen.utils.gaussian import compute_pinhole_intrinsics
from embodied_gen.utils.log import logger
from thirdparty.pano2room.modules.geo_predictors import PanoJointPredictor
from thirdparty.pano2room.modules.geo_predictors.PanoFusionDistancePredictor import (
    PanoFusionDistancePredictor,
)
from thirdparty.pano2room.modules.inpainters import PanoPersFusionInpainter
from thirdparty.pano2room.modules.mesh_fusion.render import (
    features_to_world_space_mesh,
    render_mesh,
)
from thirdparty.pano2room.modules.mesh_fusion.sup_info import SupInfoPool
from thirdparty.pano2room.utils.camera_utils import gen_pano_rays
from thirdparty.pano2room.utils.functions import (
    depth_to_distance,
    get_cubemap_views_world_to_cam,
    resize_image_with_aspect_ratio,
    rot_z_world_to_cam,
    tensor_to_pil,
)


class Pano2MeshSRPipeline:
    """Converts a panoramic RGB image into a 3D mesh representation,
    followed by inpainting and mesh refinement.

    This class integrates several key components, including:
    - Depth estimation from an RGB panorama
    - Inpainting of missing regions revealed under camera offsets
    - RGB-D to mesh conversion
    - Multi-view mesh repair
    - 3D Gaussian Splatting (3DGS) dataset generation

    Args:
        config (Pano2MeshSRConfig): Configuration object containing model
            and pipeline parameters.

    Example:
        ```python
        pipeline = Pano2MeshSRPipeline(config)
        pipeline(pano_image="example.png", output_dir="./output")
        ```
    """

    def __init__(self, config: Pano2MeshSRConfig) -> None:
        self.cfg = config
        self.device = config.device
        # Init models.
        self.inpainter = PanoPersFusionInpainter(save_path=None)
        self.geo_predictor = PanoJointPredictor(save_path=None)
        self.pano_fusion_distance_predictor = PanoFusionDistancePredictor()
        self.super_model = ImageRealESRGAN(outscale=self.cfg.upscale_factor)
        # Init poses.
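        # The six cubemap world-to-cam extrinsics define the canonical cube
        # faces used by render_pano() to rasterize perspective views and
        # stitch them back into a panorama; camera_poses holds the repair
        # trajectory loaded from cfg.trajectory_dir, and the elliptical
        # kernel dilates inpainting masks in inpaint_panorama().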
        cubemap_w2cs = get_cubemap_views_world_to_cam()
        self.cubemap_w2cs = [p.to(self.device) for p in cubemap_w2cs]
        self.camera_poses = self.load_camera_poses(self.cfg.trajectory_dir)
        kernel = cv2.getStructuringElement(
            cv2.MORPH_ELLIPSE, self.cfg.kernel_size
        )
        self.kernel = torch.from_numpy(kernel).float().to(self.device)

    def init_mesh_params(self) -> None:
        torch.set_default_device(self.device)
        self.inpaint_mask = torch.ones(
            (self.cfg.cubemap_h, self.cfg.cubemap_w), dtype=torch.bool
        )
        self.vertices = torch.empty((3, 0), requires_grad=False)
        self.colors = torch.empty((3, 0), requires_grad=False)
        self.faces = torch.empty((3, 0), dtype=torch.long, requires_grad=False)

    @staticmethod
    def read_camera_pose_file(filepath: str) -> np.ndarray:
        with open(filepath, "r") as f:
            values = [float(num) for line in f for num in line.split()]

        return np.array(values).reshape(4, 4)

    def load_camera_poses(self, trajectory_dir: str) -> list[torch.Tensor]:
        pose_filenames = sorted(
            [
                fname
                for fname in os.listdir(trajectory_dir)
                if fname.startswith("camera_pose")
            ]
        )
        pano_pose_world = None
        relative_poses = []
        for idx, filename in enumerate(pose_filenames):
            pose_path = os.path.join(trajectory_dir, filename)
            pose_matrix = self.read_camera_pose_file(pose_path)
            if pano_pose_world is None:
                pano_pose_world = pose_matrix.copy()
                pano_pose_world[0, 3] += self.cfg.pano_center_offset[0]
                pano_pose_world[2, 3] += self.cfg.pano_center_offset[1]

            # Use a different reference for the first 6 cubemap views.
            reference_pose = pose_matrix if idx < 6 else pano_pose_world
            relative_matrix = pose_matrix @ np.linalg.inv(reference_pose)
            relative_matrix[0:2, :] *= -1  # flip_xy
            relative_matrix = (
                relative_matrix @ rot_z_world_to_cam(180).cpu().numpy()
            )
            relative_matrix[:3, 3] *= self.cfg.pose_scale
            relative_matrix = torch.tensor(
                relative_matrix, dtype=torch.float32
            )
            relative_poses.append(relative_matrix)

        return relative_poses

    def load_inpaint_poses(
        self, poses: list[torch.Tensor]
    ) -> dict[int, torch.Tensor]:
        inpaint_poses = dict()
        sampled_views = poses[:: self.cfg.inpaint_frame_stride]
        init_pose = torch.eye(4)
        for idx, w2c_tensor in enumerate(sampled_views):
            w2c = w2c_tensor.cpu().numpy().astype(np.float32)
            c2w = np.linalg.inv(w2c)
            pose_tensor = init_pose.clone()
            pose_tensor[:3, 3] = torch.from_numpy(c2w[:3, 3])
            pose_tensor[:3, 3] *= -1
            inpaint_poses[idx] = pose_tensor.to(self.device)

        return inpaint_poses

    def project(self, world_to_cam: torch.Tensor):
        (
            project_image,
            project_depth,
            inpaint_mask,
            _,
            z_buf,
            mesh,
        ) = render_mesh(
            vertices=self.vertices,
            faces=self.faces,
            vertex_features=self.colors,
            H=self.cfg.cubemap_h,
            W=self.cfg.cubemap_w,
            fov_in_degrees=self.cfg.fov,
            RT=world_to_cam,
            blur_radius=self.cfg.blur_radius,
            faces_per_pixel=self.cfg.faces_per_pixel,
        )
        project_image = project_image * ~inpaint_mask

        return project_image[:3, ...], inpaint_mask, project_depth

    def render_pano(self, pose: torch.Tensor):
        cubemap_list = []
        for cubemap_pose in self.cubemap_w2cs:
            project_pose = cubemap_pose @ pose
            rgb, inpaint_mask, depth = self.project(project_pose)
            distance_map = depth_to_distance(depth[None, ...])
            mask = inpaint_mask[None, ...]
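            # Stack each cube face as a 5-channel tensor: RGB (3) +
            # distance (1) + inpaint mask (1), so cube2equi can warp color,
            # geometry, and hole masks to the equirectangular pano in one pass.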
            cubemap_list.append(torch.cat([rgb, distance_map, mask], dim=0))

        # Run cube2equi on CPU by temporarily switching the default device.
        with torch.device("cpu"):
            pano_rgbd = cube2equi(
                cubemap_list, "list", self.cfg.pano_h, self.cfg.pano_w
            )

        pano_rgb = pano_rgbd[:3, :, :]
        pano_depth = pano_rgbd[3:4, :, :].squeeze(0)
        pano_mask = pano_rgbd[4:, :, :].squeeze(0)

        return pano_rgb, pano_depth, pano_mask

    def rgbd_to_mesh(
        self,
        rgb: torch.Tensor,
        depth: torch.Tensor,
        inpaint_mask: torch.Tensor,
        world_to_cam: torch.Tensor = None,
        using_distance_map: bool = True,
    ) -> None:
        if world_to_cam is None:
            world_to_cam = torch.eye(4, dtype=torch.float32).to(self.device)

        if inpaint_mask.sum() == 0:
            return

        vertices, faces, colors = features_to_world_space_mesh(
            colors=rgb.squeeze(0),
            depth=depth,
            fov_in_degrees=self.cfg.fov,
            world_to_cam=world_to_cam,
            mask=inpaint_mask,
            faces=self.faces,
            vertices=self.vertices,
            using_distance_map=using_distance_map,
            edge_threshold=0.05,
        )
        faces += self.vertices.shape[1]

        self.vertices = torch.cat([self.vertices, vertices], dim=1)
        self.colors = torch.cat([self.colors, colors], dim=1)
        self.faces = torch.cat([self.faces, faces], dim=1)

    def get_edge_image_by_depth(
        self, depth: torch.Tensor, dilate_iter: int = 1
    ) -> np.ndarray:
        if isinstance(depth, torch.Tensor):
            depth = depth.cpu().detach().numpy()

        gray = (depth / depth.max() * 255).astype(np.uint8)
        edges = cv2.Canny(gray, 60, 150)
        if dilate_iter > 0:
            kernel = np.ones((3, 3), np.uint8)
            edges = cv2.dilate(edges, kernel, iterations=dilate_iter)

        return edges

    def mesh_repair_by_greedy_view_selection(
        self, pose_dict: dict[int, torch.Tensor], output_dir: str
    ) -> list:
        inpainted_panos_w_pose = []
        while len(pose_dict) > 0:
            logger.info(f"Repairing mesh, rounds left: {len(pose_dict)}")
            sampled_views = []
            for key, pose in pose_dict.items():
                pano_rgb, pano_distance, pano_mask = self.render_pano(pose)
                completeness = torch.sum(1 - pano_mask) / (pano_mask.numel())
                sampled_views.append((key, completeness.item(), pose))

            if len(sampled_views) == 0:
                break

            # Sort views by completeness and pick the one at the 2/3
            # quantile to inpaint next.
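            # Heuristic note: deferring the least complete views until more
            # of the mesh has been filled keeps each inpainting step small.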
            sampled_views = sorted(sampled_views, key=lambda x: x[1])
            key, _, pose = sampled_views[len(sampled_views) * 2 // 3]
            pose_dict.pop(key)
            pano_rgb, pano_distance, pano_mask = self.render_pano(pose)
            colors = pano_rgb.permute(1, 2, 0).clone()
            distances = pano_distance.unsqueeze(-1).clone()
            pano_inpaint_mask = pano_mask.clone()
            init_pose = pose.clone()
            normals = None
            if pano_inpaint_mask.min().item() < 0.5:
                colors, distances, normals = self.inpaint_panorama(
                    idx=key,
                    colors=colors,
                    distances=distances,
                    pano_mask=pano_inpaint_mask,
                )

            init_pose[0, 3], init_pose[1, 3], init_pose[2, 3] = (
                -pose[0, 3],
                pose[2, 3],
                0,
            )
            rays = gen_pano_rays(
                init_pose, self.cfg.pano_h, self.cfg.pano_w
            )
            # conflict_mask: 0 means conflict, 1 means no conflict.
            conflict_mask = self.sup_pool.geo_check(
                rays, distances.unsqueeze(-1)
            )
            pano_inpaint_mask *= conflict_mask

            self.rgbd_to_mesh(
                colors.permute(2, 0, 1),
                distances,
                pano_inpaint_mask,
                world_to_cam=pose,
            )
            self.sup_pool.register_sup_info(
                pose=init_pose,
                mask=pano_inpaint_mask.clone(),
                rgb=colors,
                distance=distances.unsqueeze(-1),
                normal=normals,
            )
            colors = colors.permute(2, 0, 1).unsqueeze(0)
            inpainted_panos_w_pose.append([colors, pose])

            if self.cfg.visualize:
                from embodied_gen.data.utils import DiffrastRender

                tensor_to_pil(pano_rgb.unsqueeze(0)).save(
                    f"{output_dir}/rendered_pano_{key}.jpg"
                )
                tensor_to_pil(colors).save(
                    f"{output_dir}/inpainted_pano_{key}.jpg"
                )
                norm_depth = DiffrastRender.normalize_map_by_mask(
                    distances, torch.ones_like(distances)
                )
                heatmap = (norm_depth.cpu().numpy() * 255).astype(np.uint8)
                heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
                Image.fromarray(heatmap).save(
                    f"{output_dir}/inpainted_depth_{key}.png"
                )

        return inpainted_panos_w_pose

    def inpaint_panorama(
        self,
        idx: int,
        colors: torch.Tensor,
        distances: torch.Tensor,
        pano_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        mask = (pano_mask[None, ..., None] > 0.5).float()
        mask = mask.permute(0, 3, 1, 2)
        mask = dilation(mask, kernel=self.kernel)
        mask = mask[0, 0, ..., None]  # hwc
        inpainted_img = self.inpainter.inpaint(idx, colors, mask)
        inpainted_img = colors * (1 - mask) + inpainted_img * mask
        inpainted_distances, inpainted_normals = self.geo_predictor(
            idx,
            inpainted_img,
            distances[..., None],
            mask=mask,
            reg_loss_weight=0.0,
            normal_loss_weight=5e-2,
            normal_tv_loss_weight=5e-2,
        )

        return inpainted_img, inpainted_distances.squeeze(), inpainted_normals

    def preprocess_pano(
        self, image: Image.Image | str
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if isinstance(image, str):
            image = Image.open(image)

        image = image.convert("RGB")
        if image.size[0] < image.size[1]:
            image = image.transpose(Image.TRANSPOSE)

        image = resize_image_with_aspect_ratio(image, self.cfg.pano_w)
        image_rgb = torch.tensor(np.array(image)).permute(2, 0, 1) / 255
        image_rgb = image_rgb.to(self.device)
        image_depth = self.pano_fusion_distance_predictor.predict(
            image_rgb.permute(1, 2, 0)
        )
        image_depth = (
            image_depth / image_depth.max() * self.cfg.depth_scale_factor
        )

        return image_rgb, image_depth

    def pano_to_perspective(
        self, pano_image: torch.Tensor, pitch: float, yaw: float, fov: float
    ) -> torch.Tensor:
        rots = dict(
            roll=0,
            pitch=pitch,
            yaw=yaw,
        )
        perspective = equi2pers(
            equi=pano_image.squeeze(0),
            rots=rots,
            height=self.cfg.cubemap_h,
            width=self.cfg.cubemap_w,
            fov_x=fov,
            mode="bilinear",
        ).unsqueeze(0)

        return perspective

    def pano_to_cubemap(self, pano_rgb: torch.Tensor):
        # Define the six canonical cube directions as (pitch, yaw) pairs.
        directions = [
            (0, 0),
            (0, 1.5 * np.pi),
            (0, 1.0 * np.pi),
            (0, 0.5 * np.pi),
            (-0.5 * np.pi, 0),
            (0.5 * np.pi, 0),
        ]
        cubemaps_rgb = []
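        # One perspective render per cube face; the ordering follows the
        # (pitch, yaw) list above (four horizontal views, then two vertical).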
        for pitch, yaw in directions:
            rgb_view = self.pano_to_perspective(
                pano_rgb, pitch, yaw, fov=self.cfg.fov
            )
            cubemaps_rgb.append(rgb_view.cpu())

        return cubemaps_rgb

    def save_mesh(self, output_path: str) -> None:
        vertices_np = self.vertices.T.cpu().numpy()
        colors_np = self.colors.T.cpu().numpy()
        faces_np = self.faces.T.cpu().numpy()
        mesh = trimesh.Trimesh(
            vertices=vertices_np, faces=faces_np, vertex_colors=colors_np
        )
        mesh.export(output_path)

    def mesh_pose_to_gs_pose(self, mesh_pose: torch.Tensor) -> np.ndarray:
        pose = mesh_pose.clone()
        pose[0, :] *= -1
        pose[1, :] *= -1
        Rw2c = pose[:3, :3].cpu().numpy()
        Tw2c = pose[:3, 3:].cpu().numpy()
        yz_reverse = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]])
        Rc2w = (yz_reverse @ Rw2c).T
        Tc2w = -(Rc2w @ yz_reverse @ Tw2c)
        c2w = np.concatenate((Rc2w, Tc2w), axis=1)
        c2w = np.concatenate((c2w, np.array([[0, 0, 0, 1]])), axis=0)

        return c2w

    def __call__(self, pano_image: Image.Image | str, output_dir: str):
        self.init_mesh_params()
        pano_rgb, pano_depth = self.preprocess_pano(pano_image)
        self.sup_pool = SupInfoPool()
        self.sup_pool.register_sup_info(
            pose=torch.eye(4).to(self.device),
            mask=torch.ones([self.cfg.pano_h, self.cfg.pano_w]),
            rgb=pano_rgb.permute(1, 2, 0),
            distance=pano_depth[..., None],
        )
        self.sup_pool.gen_occ_grid(res=256)

        logger.info("Init mesh from pano RGBD image...")
        depth_edge = self.get_edge_image_by_depth(pano_depth)
        inpaint_edge_mask = (
            ~torch.from_numpy(depth_edge).to(self.device).bool()
        )
        self.rgbd_to_mesh(pano_rgb, pano_depth, inpaint_edge_mask)

        repair_poses = self.load_inpaint_poses(self.camera_poses)
        inpainted_panos_w_poses = self.mesh_repair_by_greedy_view_selection(
            repair_poses, output_dir
        )
        torch.cuda.empty_cache()
        torch.set_default_device("cpu")

        if self.cfg.mesh_file is not None:
            mesh_path = os.path.join(output_dir, self.cfg.mesh_file)
            self.save_mesh(mesh_path)

        if self.cfg.gs_data_file is None:
            return

        logger.info("Dump data for 3DGS training...")
        points_rgb = (self.colors.clip(0, 1) * 255).to(torch.uint8)
        data = {
            "points": self.vertices.permute(1, 0).cpu().numpy(),  # (N, 3)
            "points_rgb": points_rgb.permute(1, 0).cpu().numpy(),  # (N, 3)
            "train": [],
            "eval": [],
        }
        image_h = self.cfg.cubemap_h * self.cfg.upscale_factor
        image_w = self.cfg.cubemap_w * self.cfg.upscale_factor
        Ks = compute_pinhole_intrinsics(image_w, image_h, self.cfg.fov)
        for idx, (pano_img, pano_pose) in enumerate(inpainted_panos_w_poses):
            cubemaps = self.pano_to_cubemap(pano_img)
            for i in range(len(cubemaps)):
                cubemap = tensor_to_pil(cubemaps[i])
                cubemap = self.super_model(cubemap)
                mesh_pose = self.cubemap_w2cs[i] @ pano_pose
                c2w = self.mesh_pose_to_gs_pose(mesh_pose)
                data["train"].append(
                    {
                        "camtoworld": c2w.astype(np.float32),
                        "K": Ks.astype(np.float32),
                        "image": np.array(cubemap),
                        "image_h": image_h,
                        "image_w": image_w,
                        "image_id": len(cubemaps) * idx + i,
                    }
                )

        # Camera poses for evaluation.
        for idx in range(len(self.camera_poses)):
            c2w = self.mesh_pose_to_gs_pose(self.camera_poses[idx])
            data["eval"].append(
                {
                    "camtoworld": c2w.astype(np.float32),
                    "K": Ks.astype(np.float32),
                    "image_h": image_h,
                    "image_w": image_w,
                    "image_id": idx,
                }
            )

        data_path = os.path.join(output_dir, self.cfg.gs_data_file)
        torch.save(data, data_path)


if __name__ == "__main__":
    output_dir = "outputs/bg_v2/test3"
    input_pano = "apps/assets/example_scene/result_pano.png"

    config = Pano2MeshSRConfig()
    pipeline = Pano2MeshSRPipeline(config)
    pipeline(input_pano, output_dir)
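    # When set in the config, cfg.mesh_file (the fused mesh) and
    # cfg.gs_data_file (the 3DGS training dict) are written to output_dir.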