#
# Copyright (C) 2024, ShanghaiTech
# SVIP research group, https://github.com/svip-lab
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact huangbb@shanghaitech.edu.cn
#
# copied from https://github.com/hbb1/2d-gaussian-splatting/blob/main/utils/mesh_utils.py
import torch
import numpy as np
import os
import math
from tqdm import tqdm
from functools import partial
import open3d as o3d
import trimesh
from utils.depth_utils import depth_to_normal
from utils.render_utils import save_img_u8, save_img_f32  # needed by export_image below
def post_process_mesh(mesh, cluster_to_keep=1000):
"""
Post-process a mesh to filter out floaters and disconnected parts
"""
import copy
print("post processing the mesh to have {} clusterscluster_to_kep".format(cluster_to_keep))
mesh_0 = copy.deepcopy(mesh)
with o3d.utility.VerbosityContextManager(o3d.utility.VerbosityLevel.Debug) as cm:
triangle_clusters, cluster_n_triangles, cluster_area = (mesh_0.cluster_connected_triangles())
triangle_clusters = np.asarray(triangle_clusters)
cluster_n_triangles = np.asarray(cluster_n_triangles)
cluster_area = np.asarray(cluster_area)
n_cluster = np.sort(cluster_n_triangles.copy())[-cluster_to_keep]
    n_cluster = max(n_cluster, 50) # keep clusters with at least 50 triangles
triangles_to_remove = cluster_n_triangles[triangle_clusters] < n_cluster
mesh_0.remove_triangles_by_mask(triangles_to_remove)
mesh_0.remove_unreferenced_vertices()
mesh_0.remove_degenerate_triangles()
print("num vertices raw {}".format(len(mesh.vertices)))
print("num vertices post {}".format(len(mesh_0.vertices)))
return mesh_0
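# Example usage (the cluster count is illustrative): keep only the largest connected
# components of a fused mesh before exporting it:
#   mesh = post_process_mesh(mesh, cluster_to_keep=50)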
def to_cam_open3d(viewpoint_stack):
camera_traj = []
for i, viewpoint_cam in enumerate(viewpoint_stack):
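        # Build the pinhole intrinsics from the camera FoV: f = size / (2 * tan(FoV / 2)),
        # with the principal point assumed to lie at the image center.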
intrinsic=o3d.camera.PinholeCameraIntrinsic(width=viewpoint_cam.image_width,
height=viewpoint_cam.image_height,
cx = viewpoint_cam.image_width/2,
cy = viewpoint_cam.image_height/2,
fx = viewpoint_cam.image_width / (2 * math.tan(viewpoint_cam.FoVx / 2.)),
fy = viewpoint_cam.image_height / (2 * math.tan(viewpoint_cam.FoVy / 2.)))
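        # world_view_transform is stored transposed (rasterizer convention); transposing it
        # back yields the 4x4 world-to-camera matrix Open3D expects as the extrinsic.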
extrinsic=np.asarray((viewpoint_cam.world_view_transform.T).cpu().numpy())
camera = o3d.camera.PinholeCameraParameters()
camera.extrinsic = extrinsic
camera.intrinsic = intrinsic
camera_traj.append(camera)
return camera_traj
class GaussianExtractor(object):
def __init__(self, gaussians, render, pipe, bg_color=None):
"""
        A class that extracts attributes of a scene represented by 2DGS.
        Usage example:
        >>> gaussExtractor = GaussianExtractor(gaussians, render, pipe)
        >>> gaussExtractor.reconstruction(view_points)
        >>> mesh = gaussExtractor.extract_mesh_bounded(...)
"""
if bg_color is None:
bg_color = [0, 0, 0]
background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
self.gaussians = gaussians
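        # Bind the pipeline settings and background color once, so the renderer can later
        # be called with only a viewpoint camera and the Gaussians.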
self.render = partial(render, pipe=pipe, bg_color=background)
self.clean()
@torch.no_grad()
def clean(self):
self.depthmaps = []
self.alphamaps = []
self.rgbmaps = []
self.normals = []
self.depth_normals = []
self.viewpoint_stack = []
@torch.no_grad()
def reconstruction(self, viewpoint_stack):
"""
reconstruct radiance field given cameras
"""
self.clean()
self.viewpoint_stack = viewpoint_stack
for i, viewpoint_cam in tqdm(enumerate(self.viewpoint_stack), desc="reconstruct radiance fields"):
render_pkg = self.render(viewpoint_cam, self.gaussians)
rgb = render_pkg['render']
alpha = render_pkg['mask']
normal = torch.nn.functional.normalize(render_pkg['normal'], dim=0)
depth = render_pkg['middepth']
depth_normal, _ = depth_to_normal(viewpoint_cam, depth)
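            # depth_to_normal returns an H x W x 3 map; permute to 3 x H x W so it matches
            # the layout of the other rendered buffers.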
depth_normal = depth_normal.permute(2,0,1)
# depth_normal = render_pkg['surf_normal']
self.rgbmaps.append(rgb.cpu())
self.depthmaps.append(depth.cpu())
self.alphamaps.append(alpha.cpu())
self.normals.append(normal.cpu())
self.depth_normals.append(depth_normal.cpu())
self.rgbmaps = torch.stack(self.rgbmaps, dim=0)
self.depthmaps = torch.stack(self.depthmaps, dim=0)
self.alphamaps = torch.stack(self.alphamaps, dim=0)
self.depth_normals = torch.stack(self.depth_normals, dim=0)
@torch.no_grad()
def extract_mesh_bounded(self, voxel_size=0.004, sdf_trunc=0.02, depth_trunc=3, mask_backgrond=True):
"""
Perform TSDF fusion given a fixed depth range, used in the paper.
voxel_size: the voxel size of the volume
sdf_trunc: truncation value
        depth_trunc: maximum depth range, should depend on the scene's scale
        mask_backgrond: whether to mask the background, only works when the dataset has masks
return o3d.mesh
"""
print("Running tsdf volume integration ...")
print(f'voxel_size: {voxel_size}')
print(f'sdf_trunc: {sdf_trunc}')
        print(f'depth_trunc: {depth_trunc}')
volume = o3d.pipelines.integration.ScalableTSDFVolume(
voxel_length= voxel_size,
sdf_trunc=sdf_trunc,
color_type=o3d.pipelines.integration.TSDFVolumeColorType.RGB8
)
for i, cam_o3d in tqdm(enumerate(to_cam_open3d(self.viewpoint_stack)), desc="TSDF integration progress"):
rgb = self.rgbmaps[i]
depth = self.depthmaps[i]
# if we have mask provided, use it
if mask_backgrond and (self.viewpoint_stack[i].gt_alpha_mask is not None):
depth[(self.viewpoint_stack[i].gt_alpha_mask < 0.5)] = 0
# make open3d rgbd
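            # depth_scale=1.0 because the rendered depths are already in scene units;
            # depth_trunc drops anything beyond the chosen maximum range.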
rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth(
o3d.geometry.Image(np.asarray(rgb.permute(1,2,0).cpu().numpy() * 255, order="C", dtype=np.uint8)),
o3d.geometry.Image(np.asarray(depth.permute(1,2,0).cpu().numpy(), order="C")),
depth_trunc = depth_trunc, convert_rgb_to_intensity=False,
depth_scale = 1.0
)
volume.integrate(rgbd, intrinsic=cam_o3d.intrinsic, extrinsic=cam_o3d.extrinsic)
mesh = volume.extract_triangle_mesh()
return mesh
@torch.no_grad()
def extract_mesh_unbounded(self, resolution=1024):
"""
        Experimental feature: extract meshes from unbounded scenes; not fully tested across datasets.
        #TODO: support color mesh exporting
        resolution: marching cubes grid resolution
return o3d.mesh
"""
def contract(x):
mag = torch.linalg.norm(x, ord=2, dim=-1)[..., None]
return torch.where(mag < 1, x, (2 - (1 / mag)) * (x / mag))
def uncontract(y):
mag = torch.linalg.norm(y, ord=2, dim=-1)[..., None]
return torch.where(mag < 1, y, (1 / (2-mag) * (y/mag)))
def compute_sdf_perframe(i, points, depthmap, rgbmap, normalmap, viewpoint_cam):
"""
compute per frame sdf
"""
new_points = torch.cat([points, torch.ones_like(points[...,:1])], dim=-1) @ viewpoint_cam.full_proj_transform
z = new_points[..., -1:]
pix_coords = (new_points[..., :2] / new_points[..., -1:])
mask_proj = ((pix_coords > -1. ) & (pix_coords < 1.) & (z > 0)).all(dim=-1)
sampled_depth = torch.nn.functional.grid_sample(depthmap.cuda()[None], pix_coords[None, None], mode='bilinear', padding_mode='border', align_corners=True).reshape(-1, 1)
sampled_rgb = torch.nn.functional.grid_sample(rgbmap.cuda()[None], pix_coords[None, None], mode='bilinear', padding_mode='border', align_corners=True).reshape(3,-1).T
sampled_normal = torch.nn.functional.grid_sample(normalmap.cuda()[None], pix_coords[None, None], mode='bilinear', padding_mode='border', align_corners=True).reshape(3,-1).T
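            # Projective signed distance: rendered depth minus the sample's depth along the
            # viewing ray (positive in front of the observed surface, negative behind it).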
sdf = (sampled_depth-z)
return sdf, sampled_rgb, sampled_normal, mask_proj
def compute_unbounded_tsdf(samples, inv_contraction, voxel_size, return_rgb=False):
"""
            Fuse all frames, performing adaptive SDF truncation in the contracted space.
"""
if inv_contraction is not None:
samples = inv_contraction(samples)
mask = torch.linalg.norm(samples, dim=-1) > 1
# adaptive sdf_truncation
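                # Widen the truncation band for samples in the contracted (far-away) region,
                # where the effective grid resolution is coarser.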
sdf_trunc = 5 * voxel_size * torch.ones_like(samples[:, 0])
sdf_trunc[mask] *= 1/(2-torch.linalg.norm(samples, dim=-1)[mask].clamp(max=1.9))
else:
sdf_trunc = 5 * voxel_size
tsdfs = torch.ones_like(samples[:,0]) * 1
rgbs = torch.zeros((samples.shape[0], 3)).cuda()
weights = torch.ones_like(samples[:,0])
for i, viewpoint_cam in tqdm(enumerate(self.viewpoint_stack), desc="TSDF integration progress"):
sdf, rgb, normal, mask_proj = compute_sdf_perframe(i, samples,
depthmap = self.depthmaps[i],
rgbmap = self.rgbmaps[i],
normalmap = self.depth_normals[i],
viewpoint_cam=self.viewpoint_stack[i],
)
# volume integration
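                # Standard TSDF fusion update: fold this frame into a running weighted
                # average of truncated SDF and color, with a per-frame weight of 1.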
sdf = sdf.flatten()
mask_proj = mask_proj & (sdf > -sdf_trunc)
sdf = torch.clamp(sdf / sdf_trunc, min=-1.0, max=1.0)[mask_proj]
w = weights[mask_proj]
wp = w + 1
tsdfs[mask_proj] = (tsdfs[mask_proj] * w + sdf) / wp
rgbs[mask_proj] = (rgbs[mask_proj] * w[:,None] + rgb[mask_proj]) / wp[:,None]
# update weight
weights[mask_proj] = wp
if return_rgb:
return tsdfs, rgbs
return tsdfs
from utils.render_utils import transform_poses_pca, focus_point_fn
torch.cuda.empty_cache()
c2ws = np.array([np.linalg.inv(np.asarray((cam.world_view_transform.T).cpu().numpy())) for cam in self.viewpoint_stack])
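        # Flip the y and z camera axes: the poses follow the COLMAP/OpenCV convention,
        # while focus_point_fn expects OpenGL-style (y up, -z forward) poses.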
poses = c2ws[:,:3,:] @ np.diag([1, -1, -1, 1])
center = (focus_point_fn(poses))
radius = np.linalg.norm(c2ws[:,:3,3] - center, axis=-1).min()
center = torch.from_numpy(center).float().cuda()
normalize = lambda x: (x - center) / radius
unnormalize = lambda x: (x * radius) + center
inv_contraction = lambda x: unnormalize(uncontract(x))
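        # The sampling grid lives in contracted, normalized coordinates; inv_contraction
        # undoes both the contraction and the normalization to return to world space.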
N = resolution
voxel_size = (radius * 2 / N)
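        # One voxel spans roughly 2 * radius / N world units near the scene center;
        # compute_unbounded_tsdf truncates the SDF at 5 voxels (wider in contracted regions).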
print(f"Computing sdf gird resolution {N} x {N} x {N}")
print(f"Define the voxel_size as {voxel_size}")
sdf_function = lambda x: compute_unbounded_tsdf(x, inv_contraction, voxel_size)
from utils.mcube_utils import marching_cubes_with_contraction
R = contract(normalize(self.gaussians.get_xyz)).norm(dim=-1).cpu().numpy()
R = np.quantile(R, q=0.95)
R = min(R+0.01, 1.9)
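        # Bound the marching-cubes volume by the 95th percentile of the contracted,
        # normalized Gaussian radii, capped just below 2 (the limit of the contraction).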
mesh = marching_cubes_with_contraction(
sdf=sdf_function,
bounding_box_min=(-R, -R, -R),
bounding_box_max=(R, R, R),
level=0,
resolution=N,
inv_contraction=inv_contraction,
)
# coloring the mesh
torch.cuda.empty_cache()
mesh = mesh.as_open3d
print("texturing mesh ... ")
_, rgbs = compute_unbounded_tsdf(torch.tensor(np.asarray(mesh.vertices)).float().cuda(), inv_contraction=None, voxel_size=voxel_size, return_rgb=True)
mesh.vertex_colors = o3d.utility.Vector3dVector(rgbs.cpu().numpy())
return mesh
@torch.no_grad()
def export_image(self, path):
render_path = os.path.join(path, "renders")
gts_path = os.path.join(path, "gt")
vis_path = os.path.join(path, "vis")
os.makedirs(render_path, exist_ok=True)
os.makedirs(vis_path, exist_ok=True)
os.makedirs(gts_path, exist_ok=True)
for idx, viewpoint_cam in tqdm(enumerate(self.viewpoint_stack), desc="export images"):
gt = viewpoint_cam.original_image[0:3, :, :]
save_img_u8(gt.permute(1,2,0).cpu().numpy(), os.path.join(gts_path, '{0:05d}'.format(idx) + ".png"))
save_img_u8(self.rgbmaps[idx].permute(1,2,0).cpu().numpy(), os.path.join(render_path, '{0:05d}'.format(idx) + ".png"))
save_img_f32(self.depthmaps[idx][0].cpu().numpy(), os.path.join(vis_path, 'depth_{0:05d}'.format(idx) + ".tiff"))
save_img_u8(self.normals[idx].permute(1,2,0).cpu().numpy() * 0.5 + 0.5, os.path.join(vis_path, 'normal_{0:05d}'.format(idx) + ".png"))
save_img_u8(self.depth_normals[idx].permute(1,2,0).cpu().numpy() * 0.5 + 0.5, os.path.join(vis_path, 'depth_normal_{0:05d}'.format(idx) + ".png"))