#
# Copyright (C) 2024, ShanghaiTech
# SVIP research group, https://github.com/svip-lab
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact huangbb@shanghaitech.edu.cn
#
# Adapted from https://github.com/hbb1/2d-gaussian-splatting/blob/main/utils/mesh_utils.py
import torch
import numpy as np
import os
import math
from tqdm import tqdm
from functools import partial
import open3d as o3d
import trimesh
from utils.depth_utils import depth_to_normal
# save_img_u8 / save_img_f32 are used by export_image below
from utils.render_utils import save_img_u8, save_img_f32
def post_process_mesh(mesh, cluster_to_keep=1000):
    """
    Post-process a mesh to filter out floaters and disconnected parts
    """
    import copy
    print("post processing the mesh to have {} clusters".format(cluster_to_keep))
    mesh_0 = copy.deepcopy(mesh)
    with o3d.utility.VerbosityContextManager(o3d.utility.VerbosityLevel.Debug) as cm:
        triangle_clusters, cluster_n_triangles, cluster_area = (mesh_0.cluster_connected_triangles())
    triangle_clusters = np.asarray(triangle_clusters)
    cluster_n_triangles = np.asarray(cluster_n_triangles)
    cluster_area = np.asarray(cluster_area)
    n_cluster = np.sort(cluster_n_triangles.copy())[-cluster_to_keep]
    n_cluster = max(n_cluster, 50)  # filter clusters smaller than 50 triangles
    triangles_to_remove = cluster_n_triangles[triangle_clusters] < n_cluster
    mesh_0.remove_triangles_by_mask(triangles_to_remove)
    mesh_0.remove_unreferenced_vertices()
    mesh_0.remove_degenerate_triangles()
    print("num vertices raw {}".format(len(mesh.vertices)))
    print("num vertices post {}".format(len(mesh_0.vertices)))
    return mesh_0
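
# A minimal usage sketch for post_process_mesh, assuming a fused mesh already
# exists on disk (the file names below are hypothetical):
#
#   >>> mesh = o3d.io.read_triangle_mesh("output/fuse.ply")
#   >>> mesh_post = post_process_mesh(mesh, cluster_to_keep=1)
#   >>> o3d.io.write_triangle_mesh("output/fuse_post.ply", mesh_post)
#
# cluster_to_keep=1 keeps only the largest connected component; larger values
# keep that many of the biggest clusters (each with at least 50 triangles).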
def to_cam_open3d(viewpoint_stack):
    camera_traj = []
    for i, viewpoint_cam in enumerate(viewpoint_stack):
        intrinsic = o3d.camera.PinholeCameraIntrinsic(
            width=viewpoint_cam.image_width,
            height=viewpoint_cam.image_height,
            cx=viewpoint_cam.image_width / 2,
            cy=viewpoint_cam.image_height / 2,
            fx=viewpoint_cam.image_width / (2 * math.tan(viewpoint_cam.FoVx / 2.)),
            fy=viewpoint_cam.image_height / (2 * math.tan(viewpoint_cam.FoVy / 2.)))
        extrinsic = np.asarray((viewpoint_cam.world_view_transform.T).cpu().numpy())
        camera = o3d.camera.PinholeCameraParameters()
        camera.extrinsic = extrinsic
        camera.intrinsic = intrinsic
        camera_traj.append(camera)
    return camera_traj
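
# Note: o3d.camera.PinholeCameraParameters.extrinsic expects the 4x4
# world-to-camera matrix. world_view_transform is stored transposed (the
# row-major convention used by the Gaussian-splatting rasterizer), hence the
# .T above. The intrinsics assume the principal point is at the image center.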
class GaussianExtractor(object):
    def __init__(self, gaussians, render, pipe, bg_color=None):
        """
        A class that extracts attributes of a scene represented by 2DGS.
        Usage example:
        >>> gaussExtractor = GaussianExtractor(gaussians, render, pipe)
        >>> gaussExtractor.reconstruction(view_points)
        >>> mesh = gaussExtractor.extract_mesh_bounded(...)
        """
        if bg_color is None:
            bg_color = [0, 0, 0]
        background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
        self.gaussians = gaussians
        self.render = partial(render, pipe=pipe, bg_color=background)
        self.clean()

    def clean(self):
        self.depthmaps = []
        self.alphamaps = []
        self.rgbmaps = []
        self.normals = []
        self.depth_normals = []
        self.viewpoint_stack = []
    def reconstruction(self, viewpoint_stack):
        """
        reconstruct radiance field given cameras
        """
        self.clean()
        self.viewpoint_stack = viewpoint_stack
        for i, viewpoint_cam in tqdm(enumerate(self.viewpoint_stack), desc="reconstruct radiance fields"):
            render_pkg = self.render(viewpoint_cam, self.gaussians)
            rgb = render_pkg['render']
            alpha = render_pkg['mask']
            normal = torch.nn.functional.normalize(render_pkg['normal'], dim=0)
            depth = render_pkg['middepth']
            depth_normal, _ = depth_to_normal(viewpoint_cam, depth)
            depth_normal = depth_normal.permute(2, 0, 1)
            # depth_normal = render_pkg['surf_normal']
            self.rgbmaps.append(rgb.cpu())
            self.depthmaps.append(depth.cpu())
            self.alphamaps.append(alpha.cpu())
            self.normals.append(normal.cpu())
            self.depth_normals.append(depth_normal.cpu())
        self.rgbmaps = torch.stack(self.rgbmaps, dim=0)
        self.depthmaps = torch.stack(self.depthmaps, dim=0)
        self.alphamaps = torch.stack(self.alphamaps, dim=0)
        self.depth_normals = torch.stack(self.depth_normals, dim=0)
    def extract_mesh_bounded(self, voxel_size=0.004, sdf_trunc=0.02, depth_trunc=3, mask_backgrond=True):
        """
        Perform TSDF fusion given a fixed depth range, used in the paper.
        voxel_size: the voxel size of the volume
        sdf_trunc: truncation value
        depth_trunc: maximum depth range, should depend on the scene's scale
        mask_backgrond: whether to mask the background; only works when the dataset has masks
        return o3d.mesh
        """
        print("Running tsdf volume integration ...")
        print(f'voxel_size: {voxel_size}')
        print(f'sdf_trunc: {sdf_trunc}')
        print(f'depth_trunc: {depth_trunc}')
        volume = o3d.pipelines.integration.ScalableTSDFVolume(
            voxel_length=voxel_size,
            sdf_trunc=sdf_trunc,
            color_type=o3d.pipelines.integration.TSDFVolumeColorType.RGB8
        )
        for i, cam_o3d in tqdm(enumerate(to_cam_open3d(self.viewpoint_stack)), desc="TSDF integration progress"):
            rgb = self.rgbmaps[i]
            depth = self.depthmaps[i]
            # if a mask is provided, use it to zero out background depth
            if mask_backgrond and (self.viewpoint_stack[i].gt_alpha_mask is not None):
                depth[(self.viewpoint_stack[i].gt_alpha_mask < 0.5)] = 0
            # make open3d rgbd
            rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth(
                o3d.geometry.Image(np.asarray(rgb.permute(1, 2, 0).cpu().numpy() * 255, order="C", dtype=np.uint8)),
                o3d.geometry.Image(np.asarray(depth.permute(1, 2, 0).cpu().numpy(), order="C")),
                depth_trunc=depth_trunc, convert_rgb_to_intensity=False,
                depth_scale=1.0
            )
            volume.integrate(rgbd, intrinsic=cam_o3d.intrinsic, extrinsic=cam_o3d.extrinsic)
        mesh = volume.extract_triangle_mesh()
        return mesh
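
    # A minimal end-to-end sketch for the bounded extraction, assuming
    # `gaussians`, `render`, `pipe` and `cameras` come from the surrounding
    # training / rendering code:
    #
    #   >>> extractor = GaussianExtractor(gaussians, render, pipe)
    #   >>> extractor.reconstruction(cameras)
    #   >>> mesh = extractor.extract_mesh_bounded(voxel_size=0.004, depth_trunc=3)
    #   >>> mesh = post_process_mesh(mesh, cluster_to_keep=1)
    #   >>> o3d.io.write_triangle_mesh("fuse_post.ply", mesh)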
    def extract_mesh_unbounded(self, resolution=1024):
        """
        Experimental feature: extract meshes from unbounded scenes, not fully tested across datasets.
        # TODO: support color mesh exporting
        sdf_trunc: truncation value
        return o3d.mesh
        """
        def contract(x):
            mag = torch.linalg.norm(x, ord=2, dim=-1)[..., None]
            return torch.where(mag < 1, x, (2 - (1 / mag)) * (x / mag))

        def uncontract(y):
            mag = torch.linalg.norm(y, ord=2, dim=-1)[..., None]
            return torch.where(mag < 1, y, (1 / (2 - mag) * (y / mag)))
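
        # contract is the Mip-NeRF 360 style scene contraction: points with
        # norm < 1 are kept as-is, everything farther away is squashed into the
        # shell 1 <= ||y|| < 2, so the unbounded scene fits in a ball of radius
        # 2; uncontract is its exact inverse on that ball.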
        def compute_sdf_perframe(i, points, depthmap, rgbmap, normalmap, viewpoint_cam):
            """
            compute per-frame sdf
            """
            new_points = torch.cat([points, torch.ones_like(points[..., :1])], dim=-1) @ viewpoint_cam.full_proj_transform
            z = new_points[..., -1:]
            pix_coords = (new_points[..., :2] / new_points[..., -1:])
            mask_proj = ((pix_coords > -1.) & (pix_coords < 1.) & (z > 0)).all(dim=-1)
            sampled_depth = torch.nn.functional.grid_sample(depthmap.cuda()[None], pix_coords[None, None], mode='bilinear', padding_mode='border', align_corners=True).reshape(-1, 1)
            sampled_rgb = torch.nn.functional.grid_sample(rgbmap.cuda()[None], pix_coords[None, None], mode='bilinear', padding_mode='border', align_corners=True).reshape(3, -1).T
            sampled_normal = torch.nn.functional.grid_sample(normalmap.cuda()[None], pix_coords[None, None], mode='bilinear', padding_mode='border', align_corners=True).reshape(3, -1).T
            sdf = (sampled_depth - z)
            return sdf, sampled_rgb, sampled_normal, mask_proj
        def compute_unbounded_tsdf(samples, inv_contraction, voxel_size, return_rgb=False):
            """
            Fuse all frames, performing adaptive SDF truncation in the contracted space.
            """
            if inv_contraction is not None:
                samples = inv_contraction(samples)
                mask = torch.linalg.norm(samples, dim=-1) > 1
                # adaptive sdf truncation: widen the truncation band for contracted
                # (far-field) samples, where a single voxel covers more metric space
                sdf_trunc = 5 * voxel_size * torch.ones_like(samples[:, 0])
                sdf_trunc[mask] *= 1 / (2 - torch.linalg.norm(samples, dim=-1)[mask].clamp(max=1.9))
            else:
                sdf_trunc = 5 * voxel_size

            tsdfs = torch.ones_like(samples[:, 0]) * 1
            rgbs = torch.zeros((samples.shape[0], 3)).cuda()
            weights = torch.ones_like(samples[:, 0])
            for i, viewpoint_cam in tqdm(enumerate(self.viewpoint_stack), desc="TSDF integration progress"):
                sdf, rgb, normal, mask_proj = compute_sdf_perframe(i, samples,
                    depthmap=self.depthmaps[i],
                    rgbmap=self.rgbmaps[i],
                    normalmap=self.depth_normals[i],
                )
                # volume integration: clamp the sdf to [-1, 1] and fold it into a
                # running weighted average over all frames that see each sample
                sdf = sdf.flatten()
                mask_proj = mask_proj & (sdf > -sdf_trunc)
                sdf = torch.clamp(sdf / sdf_trunc, min=-1.0, max=1.0)[mask_proj]
                w = weights[mask_proj]
                wp = w + 1
                tsdfs[mask_proj] = (tsdfs[mask_proj] * w + sdf) / wp
                rgbs[mask_proj] = (rgbs[mask_proj] * w[:, None] + rgb[mask_proj]) / wp[:, None]
                # update weight
                weights[mask_proj] = wp

            if return_rgb:
                return tsdfs, rgbs
            return tsdfs
        from utils.render_utils import transform_poses_pca, focus_point_fn
        torch.cuda.empty_cache()
        c2ws = np.array([np.linalg.inv(np.asarray((cam.world_view_transform.T).cpu().numpy())) for cam in self.viewpoint_stack])
        poses = c2ws[:, :3, :] @ np.diag([1, -1, -1, 1])
        center = (focus_point_fn(poses))
        radius = np.linalg.norm(c2ws[:, :3, 3] - center, axis=-1).min()
        center = torch.from_numpy(center).float().cuda()
        normalize = lambda x: (x - center) / radius
        unnormalize = lambda x: (x * radius) + center
        inv_contraction = lambda x: unnormalize(uncontract(x))

        N = resolution
        voxel_size = (radius * 2 / N)
        print(f"Computing sdf grid resolution {N} x {N} x {N}")
        print(f"Using voxel_size = {voxel_size}")
        sdf_function = lambda x: compute_unbounded_tsdf(x, inv_contraction, voxel_size)
        from utils.mcube_utils import marching_cubes_with_contraction
        R = contract(normalize(self.gaussians.get_xyz)).norm(dim=-1).cpu().numpy()
        R = np.quantile(R, q=0.95)
        R = min(R + 0.01, 1.9)

        mesh = marching_cubes_with_contraction(
            sdf=sdf_function,
            bounding_box_min=(-R, -R, -R),
            bounding_box_max=(R, R, R),
            level=0,
            resolution=N,
            inv_contraction=inv_contraction,
        )
        # coloring the mesh
        torch.cuda.empty_cache()
        mesh = mesh.as_open3d
        print("texturing mesh ... ")
        _, rgbs = compute_unbounded_tsdf(torch.tensor(np.asarray(mesh.vertices)).float().cuda(), inv_contraction=None, voxel_size=voxel_size, return_rgb=True)
        mesh.vertex_colors = o3d.utility.Vector3dVector(rgbs.cpu().numpy())
        return mesh
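
    # Usage sketch for the unbounded variant (experimental, see the docstring);
    # `extractor` is assumed to be a GaussianExtractor on which reconstruction()
    # has already been run:
    #
    #   >>> mesh = extractor.extract_mesh_unbounded(resolution=1024)
    #   >>> o3d.io.write_triangle_mesh("fuse_unbounded.ply", mesh)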
    def export_image(self, path):
        render_path = os.path.join(path, "renders")
        gts_path = os.path.join(path, "gt")
        vis_path = os.path.join(path, "vis")
        os.makedirs(render_path, exist_ok=True)
        os.makedirs(vis_path, exist_ok=True)
        os.makedirs(gts_path, exist_ok=True)
        for idx, viewpoint_cam in tqdm(enumerate(self.viewpoint_stack), desc="export images"):
            gt = viewpoint_cam.original_image[0:3, :, :]
            save_img_u8(gt.permute(1, 2, 0).cpu().numpy(), os.path.join(gts_path, '{0:05d}'.format(idx) + ".png"))
            save_img_u8(self.rgbmaps[idx].permute(1, 2, 0).cpu().numpy(), os.path.join(render_path, '{0:05d}'.format(idx) + ".png"))
            save_img_f32(self.depthmaps[idx][0].cpu().numpy(), os.path.join(vis_path, 'depth_{0:05d}'.format(idx) + ".tiff"))
            save_img_u8(self.normals[idx].permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5, os.path.join(vis_path, 'normal_{0:05d}'.format(idx) + ".png"))
            save_img_u8(self.depth_normals[idx].permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5, os.path.join(vis_path, 'depth_normal_{0:05d}'.format(idx) + ".png"))
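
    # export_image writes, per camera: the ground-truth image (gt/), the
    # rendered image (renders/), and depth / normal visualizations (vis/).
    # A minimal call, assuming reconstruction() has already been run
    # (the output path below is hypothetical):
    #
    #   >>> extractor.export_image("output/traj")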