# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.

from embodied_gen.utils.monkey_patches import monkey_patch_pano2room

# Apply the patches before importing any pano2room modules.
monkey_patch_pano2room()

import os

import cv2
import numpy as np
import torch
import trimesh
from equilib import cube2equi, equi2pers
from kornia.morphology import dilation
from PIL import Image

from embodied_gen.models.sr_model import ImageRealESRGAN
from embodied_gen.utils.config import Pano2MeshSRConfig
from embodied_gen.utils.gaussian import compute_pinhole_intrinsics
from embodied_gen.utils.log import logger
from thirdparty.pano2room.modules.geo_predictors import PanoJointPredictor
from thirdparty.pano2room.modules.geo_predictors.PanoFusionDistancePredictor import (
    PanoFusionDistancePredictor,
)
from thirdparty.pano2room.modules.inpainters import PanoPersFusionInpainter
from thirdparty.pano2room.modules.mesh_fusion.render import (
    features_to_world_space_mesh,
    render_mesh,
)
from thirdparty.pano2room.modules.mesh_fusion.sup_info import SupInfoPool
from thirdparty.pano2room.utils.camera_utils import gen_pano_rays
from thirdparty.pano2room.utils.functions import (
    depth_to_distance,
    get_cubemap_views_world_to_cam,
    resize_image_with_aspect_ratio,
    rot_z_world_to_cam,
    tensor_to_pil,
)


class Pano2MeshSRPipeline:
    """Converts a panoramic RGB image into a 3D mesh representation, followed by inpainting-based repair and refinement.

    This class integrates several key components:
    - Depth estimation from the RGB panorama
    - Inpainting of regions revealed under camera offsets
    - RGB-D to mesh conversion
    - Multi-view mesh repair
    - 3D Gaussian Splatting (3DGS) dataset generation

    Args:
        config (Pano2MeshSRConfig): Configuration object containing model and pipeline parameters.

    Example:
        ```python
        pipeline = Pano2MeshSRPipeline(config)
        pipeline(pano_image="example.png", output_dir="./output")
        ```
    """

    def __init__(self, config: Pano2MeshSRConfig) -> None:
        self.cfg = config
        self.device = config.device
        # Init models.
        self.inpainter = PanoPersFusionInpainter(save_path=None)
        self.geo_predictor = PanoJointPredictor(save_path=None)
        self.pano_fusion_distance_predictor = PanoFusionDistancePredictor()
        self.super_model = ImageRealESRGAN(outscale=self.cfg.upscale_factor)
        # Init poses.
        cubemap_w2cs = get_cubemap_views_world_to_cam()
        self.cubemap_w2cs = [p.to(self.device) for p in cubemap_w2cs]
        self.camera_poses = self.load_camera_poses(self.cfg.trajectory_dir)
        # Elliptical kernel used to dilate inpainting masks.
        kernel = cv2.getStructuringElement(
            cv2.MORPH_ELLIPSE, self.cfg.kernel_size
        )
        self.kernel = torch.from_numpy(kernel).float().to(self.device)

    def init_mesh_params(self) -> None:
        torch.set_default_device(self.device)
        self.inpaint_mask = torch.ones(
            (self.cfg.cubemap_h, self.cfg.cubemap_w), dtype=torch.bool
        )
        self.vertices = torch.empty((3, 0), requires_grad=False)
        self.colors = torch.empty((3, 0), requires_grad=False)
        self.faces = torch.empty((3, 0), dtype=torch.long, requires_grad=False)

    @staticmethod
    def read_camera_pose_file(filepath: str) -> np.ndarray:
        with open(filepath, "r") as f:
            values = [float(num) for line in f for num in line.split()]

        return np.array(values).reshape(4, 4)
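
    # A hypothetical illustration of the expected pose-file layout (the real
    # trajectory files live under `cfg.trajectory_dir`): each file holds the
    # 16 entries of a 4x4 camera pose matrix as whitespace-separated floats,
    # e.g. a file named like `camera_pose_0000.txt` containing:
    #
    #   1.0 0.0 0.0 0.0
    #   0.0 1.0 0.0 0.0
    #   0.0 0.0 1.0 0.5
    #   0.0 0.0 0.0 1.0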

    def load_camera_poses(self, trajectory_dir: str) -> list[torch.Tensor]:
        pose_filenames = sorted(
            fname
            for fname in os.listdir(trajectory_dir)
            if fname.startswith("camera_pose")
        )
        pano_pose_world = None
        relative_poses = []
        for idx, filename in enumerate(pose_filenames):
            pose_path = os.path.join(trajectory_dir, filename)
            pose_matrix = self.read_camera_pose_file(pose_path)
            if pano_pose_world is None:
                pano_pose_world = pose_matrix.copy()
                pano_pose_world[0, 3] += self.cfg.pano_center_offset[0]
                pano_pose_world[2, 3] += self.cfg.pano_center_offset[1]
            # Use a different reference for the first 6 cubemap views.
            reference_pose = pose_matrix if idx < 6 else pano_pose_world
            relative_matrix = pose_matrix @ np.linalg.inv(reference_pose)
            relative_matrix[0:2, :] *= -1  # Flip the x and y axes.
            relative_matrix = (
                relative_matrix @ rot_z_world_to_cam(180).cpu().numpy()
            )
            relative_matrix[:3, 3] *= self.cfg.pose_scale
            relative_matrix = torch.tensor(
                relative_matrix, dtype=torch.float32
            )
            relative_poses.append(relative_matrix)

        return relative_poses

    def load_inpaint_poses(
        self, poses: list[torch.Tensor]
    ) -> dict[int, torch.Tensor]:
        inpaint_poses = dict()
        sampled_views = poses[:: self.cfg.inpaint_frame_stride]
        init_pose = torch.eye(4)
        for idx, w2c_tensor in enumerate(sampled_views):
            w2c = w2c_tensor.cpu().numpy().astype(np.float32)
            c2w = np.linalg.inv(w2c)
            # Keep only the (negated) camera translation, identity rotation.
            pose_tensor = init_pose.clone()
            pose_tensor[:3, 3] = torch.from_numpy(c2w[:3, 3])
            pose_tensor[:3, 3] *= -1
            inpaint_poses[idx] = pose_tensor.to(self.device)

        return inpaint_poses

    def project(self, world_to_cam: torch.Tensor):
        (
            project_image,
            project_depth,
            inpaint_mask,
            _,
            z_buf,
            mesh,
        ) = render_mesh(
            vertices=self.vertices,
            faces=self.faces,
            vertex_features=self.colors,
            H=self.cfg.cubemap_h,
            W=self.cfg.cubemap_w,
            fov_in_degrees=self.cfg.fov,
            RT=world_to_cam,
            blur_radius=self.cfg.blur_radius,
            faces_per_pixel=self.cfg.faces_per_pixel,
        )
        # Zero out pixels that still need inpainting.
        project_image = project_image * ~inpaint_mask

        return project_image[:3, ...], inpaint_mask, project_depth

    def render_pano(self, pose: torch.Tensor):
        cubemap_list = []
        for cubemap_pose in self.cubemap_w2cs:
            project_pose = cubemap_pose @ pose
            rgb, inpaint_mask, depth = self.project(project_pose)
            distance_map = depth_to_distance(depth[None, ...])
            mask = inpaint_mask[None, ...]
            cubemap_list.append(torch.cat([rgb, distance_map, mask], dim=0))

        # Run cube2equi on CPU by temporarily switching the default device.
        with torch.device("cpu"):
            pano_rgbd = cube2equi(
                cubemap_list, "list", self.cfg.pano_h, self.cfg.pano_w
            )

        # Channel layout: 0-2 RGB, 3 distance, 4 inpaint mask.
        pano_rgb = pano_rgbd[:3, :, :]
        pano_depth = pano_rgbd[3:4, :, :].squeeze(0)
        pano_mask = pano_rgbd[4:, :, :].squeeze(0)

        return pano_rgb, pano_depth, pano_mask

    def rgbd_to_mesh(
        self,
        rgb: torch.Tensor,
        depth: torch.Tensor,
        inpaint_mask: torch.Tensor,
        world_to_cam: torch.Tensor | None = None,
        using_distance_map: bool = True,
    ) -> None:
        if world_to_cam is None:
            world_to_cam = torch.eye(4, dtype=torch.float32).to(self.device)
        if inpaint_mask.sum() == 0:
            return

        vertices, faces, colors = features_to_world_space_mesh(
            colors=rgb.squeeze(0),
            depth=depth,
            fov_in_degrees=self.cfg.fov,
            world_to_cam=world_to_cam,
            mask=inpaint_mask,
            faces=self.faces,
            vertices=self.vertices,
            using_distance_map=using_distance_map,
            edge_threshold=0.05,
        )
        # Offset the new face indices by the current vertex count.
        faces += self.vertices.shape[1]
        self.vertices = torch.cat([self.vertices, vertices], dim=1)
        self.colors = torch.cat([self.colors, colors], dim=1)
        self.faces = torch.cat([self.faces, faces], dim=1)

    def get_edge_image_by_depth(
        self, depth: torch.Tensor, dilate_iter: int = 1
    ) -> np.ndarray:
        if isinstance(depth, torch.Tensor):
            depth = depth.cpu().detach().numpy()

        gray = (depth / depth.max() * 255).astype(np.uint8)
        edges = cv2.Canny(gray, 60, 150)
        if dilate_iter > 0:
            kernel = np.ones((3, 3), np.uint8)
            edges = cv2.dilate(edges, kernel, iterations=dilate_iter)

        return edges

    def mesh_repair_by_greedy_view_selection(
        self, pose_dict: dict[int, torch.Tensor], output_dir: str
    ) -> list:
        inpainted_panos_w_pose = []
        while len(pose_dict) > 0:
            logger.info(f"Repairing mesh, {len(pose_dict)} rounds left.")
            sampled_views = []
            for key, pose in pose_dict.items():
                pano_rgb, pano_distance, pano_mask = self.render_pano(pose)
                completeness = torch.sum(1 - pano_mask) / pano_mask.numel()
                sampled_views.append((key, completeness.item(), pose))

            if len(sampled_views) == 0:
                break

            # Pick a view with relatively high completeness: two-thirds of
            # the way through the list sorted by ascending completeness.
            sampled_views = sorted(sampled_views, key=lambda x: x[1])
            key, _, pose = sampled_views[len(sampled_views) * 2 // 3]
            pose_dict.pop(key)
            pano_rgb, pano_distance, pano_mask = self.render_pano(pose)
            colors = pano_rgb.permute(1, 2, 0).clone()
            distances = pano_distance.unsqueeze(-1).clone()
            pano_inpaint_mask = pano_mask.clone()
            init_pose = pose.clone()
            normals = None
            if pano_inpaint_mask.min().item() < 0.5:
                colors, distances, normals = self.inpaint_panorama(
                    idx=key,
                    colors=colors,
                    distances=distances,
                    pano_mask=pano_inpaint_mask,
                )
                init_pose[0, 3], init_pose[1, 3], init_pose[2, 3] = (
                    -pose[0, 3],
                    pose[2, 3],
                    0,
                )
                rays = gen_pano_rays(
                    init_pose, self.cfg.pano_h, self.cfg.pano_w
                )
                conflict_mask = self.sup_pool.geo_check(
                    rays, distances.unsqueeze(-1)
                )  # 0 = conflict, 1 = no conflict.
                pano_inpaint_mask *= conflict_mask

            self.rgbd_to_mesh(
                colors.permute(2, 0, 1),
                distances,
                pano_inpaint_mask,
                world_to_cam=pose,
            )
            self.sup_pool.register_sup_info(
                pose=init_pose,
                mask=pano_inpaint_mask.clone(),
                rgb=colors,
                distance=distances.unsqueeze(-1),
                normal=normals,
            )
            colors = colors.permute(2, 0, 1).unsqueeze(0)
            inpainted_panos_w_pose.append([colors, pose])

            if self.cfg.visualize:
                from embodied_gen.data.utils import DiffrastRender

                tensor_to_pil(pano_rgb.unsqueeze(0)).save(
                    f"{output_dir}/rendered_pano_{key}.jpg"
                )
                tensor_to_pil(colors).save(
                    f"{output_dir}/inpainted_pano_{key}.jpg"
                )
                norm_depth = DiffrastRender.normalize_map_by_mask(
                    distances, torch.ones_like(distances)
                )
                heatmap = (norm_depth.cpu().numpy() * 255).astype(np.uint8)
                heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
                Image.fromarray(heatmap).save(
                    f"{output_dir}/inpainted_depth_{key}.png"
                )

        return inpainted_panos_w_pose

    def inpaint_panorama(
        self,
        idx: int,
        colors: torch.Tensor,
        distances: torch.Tensor,
        pano_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Dilate the inpainting mask to cover boundary artifacts.
        mask = (pano_mask[None, ..., None] > 0.5).float()
        mask = mask.permute(0, 3, 1, 2)
        mask = dilation(mask, kernel=self.kernel)
        mask = mask[0, 0, ..., None]  # HWC
        inpainted_img = self.inpainter.inpaint(idx, colors, mask)
        # Blend: keep known pixels, fill masked pixels with inpainted ones.
        inpainted_img = colors * (1 - mask) + inpainted_img * mask
        inpainted_distances, inpainted_normals = self.geo_predictor(
            idx,
            inpainted_img,
            distances[..., None],
            mask=mask,
            reg_loss_weight=0.0,
            normal_loss_weight=5e-2,
            normal_tv_loss_weight=5e-2,
        )

        return inpainted_img, inpainted_distances.squeeze(), inpainted_normals

    def preprocess_pano(
        self, image: Image.Image | str
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if isinstance(image, str):
            image = Image.open(image)

        image = image.convert("RGB")
        # Ensure landscape orientation for the equirectangular panorama.
        if image.size[0] < image.size[1]:
            image = image.transpose(Image.TRANSPOSE)

        image = resize_image_with_aspect_ratio(image, self.cfg.pano_w)
        image_rgb = torch.tensor(np.array(image)).permute(2, 0, 1) / 255
        image_rgb = image_rgb.to(self.device)
        image_depth = self.pano_fusion_distance_predictor.predict(
            image_rgb.permute(1, 2, 0)
        )
        # Normalize depth to a fixed scale.
        image_depth = (
            image_depth / image_depth.max() * self.cfg.depth_scale_factor
        )

        return image_rgb, image_depth

    def pano_to_perspective(
        self, pano_image: torch.Tensor, pitch: float, yaw: float, fov: float
    ) -> torch.Tensor:
        rots = dict(
            roll=0,
            pitch=pitch,
            yaw=yaw,
        )
        perspective = equi2pers(
            equi=pano_image.squeeze(0),
            rots=rots,
            height=self.cfg.cubemap_h,
            width=self.cfg.cubemap_w,
            fov_x=fov,
            mode="bilinear",
        ).unsqueeze(0)

        return perspective

    def pano_to_cubemap(self, pano_rgb: torch.Tensor):
        # Define the six canonical cube directions as (pitch, yaw) in radians.
        directions = [
            (0, 0),
            (0, 1.5 * np.pi),
            (0, 1.0 * np.pi),
            (0, 0.5 * np.pi),
            (-0.5 * np.pi, 0),
            (0.5 * np.pi, 0),
        ]
        cubemaps_rgb = []
        for pitch, yaw in directions:
            rgb_view = self.pano_to_perspective(
                pano_rgb, pitch, yaw, fov=self.cfg.fov
            )
            cubemaps_rgb.append(rgb_view.cpu())

        return cubemaps_rgb

    def save_mesh(self, output_path: str) -> None:
        vertices_np = self.vertices.T.cpu().numpy()
        colors_np = self.colors.T.cpu().numpy()
        faces_np = self.faces.T.cpu().numpy()
        mesh = trimesh.Trimesh(
            vertices=vertices_np, faces=faces_np, vertex_colors=colors_np
        )
        mesh.export(output_path)

    def mesh_pose_to_gs_pose(self, mesh_pose: torch.Tensor) -> np.ndarray:
        pose = mesh_pose.clone()
        # Flip the x and y axes of the world-to-camera pose.
        pose[0, :] *= -1
        pose[1, :] *= -1
        Rw2c = pose[:3, :3].cpu().numpy()
        Tw2c = pose[:3, 3:].cpu().numpy()
        # Reverse the y and z camera axes, then invert world-to-camera
        # into the camera-to-world pose expected by 3DGS training.
        yz_reverse = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]])
        Rc2w = (yz_reverse @ Rw2c).T
        Tc2w = -(Rc2w @ yz_reverse @ Tw2c)
        c2w = np.concatenate((Rc2w, Tc2w), axis=1)
        c2w = np.concatenate((c2w, np.array([[0, 0, 0, 1]])), axis=0)

        return c2w

    def __call__(
        self, pano_image: Image.Image | str, output_dir: str
    ) -> None:
        os.makedirs(output_dir, exist_ok=True)
        self.init_mesh_params()
        pano_rgb, pano_depth = self.preprocess_pano(pano_image)

        self.sup_pool = SupInfoPool()
        self.sup_pool.register_sup_info(
            pose=torch.eye(4).to(self.device),
            mask=torch.ones([self.cfg.pano_h, self.cfg.pano_w]),
            rgb=pano_rgb.permute(1, 2, 0),
            distance=pano_depth[..., None],
        )
        self.sup_pool.gen_occ_grid(res=256)

        logger.info("Init mesh from pano RGBD image...")
        # Exclude depth discontinuities from meshing to avoid stretched faces.
        depth_edge = self.get_edge_image_by_depth(pano_depth)
        inpaint_edge_mask = (
            ~torch.from_numpy(depth_edge).to(self.device).bool()
        )
        self.rgbd_to_mesh(pano_rgb, pano_depth, inpaint_edge_mask)

        repair_poses = self.load_inpaint_poses(self.camera_poses)
        inpainted_panos_w_poses = self.mesh_repair_by_greedy_view_selection(
            repair_poses, output_dir
        )
        torch.cuda.empty_cache()
        torch.set_default_device("cpu")

        if self.cfg.mesh_file is not None:
            mesh_path = os.path.join(output_dir, self.cfg.mesh_file)
            self.save_mesh(mesh_path)

        if self.cfg.gs_data_file is None:
            return

        logger.info("Dumping data for 3DGS training...")
        points_rgb = (self.colors.clip(0, 1) * 255).to(torch.uint8)
        data = {
            "points": self.vertices.permute(1, 0).cpu().numpy(),  # (N, 3)
            "points_rgb": points_rgb.permute(1, 0).cpu().numpy(),  # (N, 3)
            "train": [],
            "eval": [],
        }
        image_h = self.cfg.cubemap_h * self.cfg.upscale_factor
        image_w = self.cfg.cubemap_w * self.cfg.upscale_factor
        Ks = compute_pinhole_intrinsics(image_w, image_h, self.cfg.fov)
        for idx, (pano_img, pano_pose) in enumerate(inpainted_panos_w_poses):
            cubemaps = self.pano_to_cubemap(pano_img)
            for i in range(len(cubemaps)):
                # Super-resolve each cubemap face before dumping.
                cubemap = tensor_to_pil(cubemaps[i])
                cubemap = self.super_model(cubemap)
                mesh_pose = self.cubemap_w2cs[i] @ pano_pose
                c2w = self.mesh_pose_to_gs_pose(mesh_pose)
                data["train"].append(
                    {
                        "camtoworld": c2w.astype(np.float32),
                        "K": Ks.astype(np.float32),
                        "image": np.array(cubemap),
                        "image_h": image_h,
                        "image_w": image_w,
                        "image_id": len(cubemaps) * idx + i,
                    }
                )

        # Camera poses for evaluation.
        for idx in range(len(self.camera_poses)):
            c2w = self.mesh_pose_to_gs_pose(self.camera_poses[idx])
            data["eval"].append(
                {
                    "camtoworld": c2w.astype(np.float32),
                    "K": Ks.astype(np.float32),
                    "image_h": image_h,
                    "image_w": image_w,
                    "image_id": idx,
                }
            )

        data_path = os.path.join(output_dir, self.cfg.gs_data_file)
        torch.save(data, data_path)


if __name__ == "__main__":
    output_dir = "outputs/bg_v2/test3"
    input_pano = "apps/assets/example_scene/result_pano.png"
    config = Pano2MeshSRConfig()
    pipeline = Pano2MeshSRPipeline(config)
    pipeline(input_pano, output_dir)
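
# A minimal sketch of consuming the dumped 3DGS dataset. The keys mirror the
# `data` dict written in `__call__` above; the downstream trainer is assumed
# and not part of this module:
#
#   data = torch.load(os.path.join(output_dir, config.gs_data_file))
#   points = data["points"]          # (N, 3) float32 point positions
#   points_rgb = data["points_rgb"]  # (N, 3) uint8 point colors
#   for view in data["train"]:
#       image = view["image"]        # (H, W, 3) super-resolved cubemap face
#       K = view["K"]                # pinhole intrinsics
#       c2w = view["camtoworld"]     # 4x4 camera-to-world pose
#       # ...feed into a 3DGS trainer of your choice.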