Diffsplat / extensions /RaDe-GS /mesh_extract_tetrahedra.py
paulpanwang's picture
Upload folder using huggingface_hub
476e0f0 verified
# Adapted from https://github.com/autonomousvision/gaussian-opacity-fields/blob/main/extract_mesh.py
import torch
from scene import Scene
import os
from os import makedirs
from gaussian_renderer import render, integrate
import random
from tqdm import tqdm
from argparse import ArgumentParser
from arguments import ModelParams, PipelineParams, get_combined_args
from gaussian_renderer import GaussianModel
import numpy as np
import trimesh
from tetranerf.utils.extension import cpp
from utils.tetmesh import marching_tetrahedra
@torch.no_grad()
def evaluage_alpha(points, views, gaussians, pipeline, background, kernel_size):
    """Return one minus the smallest integrated alpha seen for each query point.

    Every view is passed to ``integrate`` together with ``points``; the
    element-wise minimum of the returned ``"alpha_integrated"`` tensors is
    tracked across views, starting from one.

    Args:
        points: Query points. Only ``points.shape[0]`` is read here; the
            tensor itself is forwarded to ``integrate``.
        views: Iterable of cameras, each forwarded to ``integrate``.
        gaussians: Forwarded unchanged to ``integrate``.
        pipeline: Forwarded unchanged to ``integrate``.
        background: Forwarded unchanged to ``integrate``.
        kernel_size: Forwarded to ``integrate`` as a keyword argument.

    Returns:
        ``1 - min_alpha``, where ``min_alpha`` is the running minimum
        described above (allocated on the ``"cuda"`` device).
    """
    num_points = points.shape[0]
    # Start at the neutral element of ``min`` for values bounded above by one.
    min_alpha = torch.ones(num_points, dtype=torch.float32, device="cuda")

    # The decorator already disables autograd for the whole call.
    for camera in tqdm(views, desc="Rendering progress"):
        integrated = integrate(points, camera, gaussians, pipeline, background, kernel_size=kernel_size)
        min_alpha = torch.min(min_alpha, integrated["alpha_integrated"])

    return 1 - min_alpha
@torch.no_grad()
def evaluage_cull_alpha(points, views, masks, gaussians, pipeline, background, kernel_size):
    """Evaluate a view-culled, signed value for every query point.

    For each view the points are passed to ``integrate``. A point counts as
    valid for that view when the view's mask, sampled with ``grid_sample`` at
    the coordinates ``integrate`` returns for the points, exceeds 0.5. The
    minimum ``"alpha_integrated"`` over all views in which a point is valid is
    kept.

    Args:
        points: Query points. ``points.shape[0]`` sets the output length; the
            tensor itself is forwarded to ``integrate``.
        views: Iterable of cameras exposing ``image_width``, ``image_height``
            and ``gt_mask`` (which may be ``None``).
        masks: Optional extra masks indexed by the view's position in
            ``views``, or ``None`` to skip them.
        gaussians: Forwarded unchanged to ``integrate``.
        pipeline: Forwarded unchanged to ``integrate``.
        background: Forwarded unchanged to ``integrate``.
        kernel_size: Forwarded unchanged to ``integrate``.

    Returns:
        A tensor holding ``0.5 - min_alpha`` for points that were valid in at
        least one view, and ``-100`` for points that were valid in none.
    """
    num_points = points.shape[0]
    final_sdf = torch.ones(num_points, dtype=torch.float32, device="cuda")
    # Number of views in which each point was considered valid.
    weight = torch.zeros(num_points, dtype=torch.int32, device="cuda")

    # The decorator already disables autograd, so no nested no_grad() is needed.
    for cam_id, view in enumerate(tqdm(views, desc="Rendering progress")):
        torch.cuda.empty_cache()
        ret = integrate(points, view, gaussians, pipeline, background, kernel_size)
        alpha_integrated = ret["alpha_integrated"]

        # Rescale the returned coordinates in place for use as a grid_sample
        # grid, which expects values normalised to [-1, 1].
        # NOTE(review): presumably these are pixel coordinates -- confirm
        # against integrate().
        point_coordinate = ret["point_coordinate"]
        point_coordinate[:, 0] = (point_coordinate[:, 0] * 2 + 1) / (view.image_width - 1) - 1
        point_coordinate[:, 1] = (point_coordinate[:, 1] * 2 + 1) / (view.image_height - 1) - 1

        # NOTE(review): index 7 of the render output is used as the rendered
        # mask; its exact meaning is defined by integrate() -- confirm.
        mask = ret["render"][7][None]
        if view.gt_mask is not None:
            mask = mask * view.gt_mask
        if masks is not None:
            mask = mask * masks[cam_id]

        valid_point_prob = torch.nn.functional.grid_sample(
            mask.type(torch.float32)[None],
            point_coordinate[None, None],
            padding_mode='zeros',
            align_corners=False,
        )
        valid_point = valid_point_prob[0, 0, 0] > 0.5

        # Only views in which the point is valid may lower its running minimum.
        final_sdf = torch.where(valid_point, torch.min(alpha_integrated, final_sdf), final_sdf)
        weight = torch.where(valid_point, weight + 1, weight)

    # Points never valid in any view get a strongly negative sentinel.
    return torch.where(weight > 0, 0.5 - final_sdf, -100)
@torch.no_grad()
def marching_tetrahedra_with_binary_search(model_path, name, iteration, views, gaussians: GaussianModel, pipeline, background, kernel_size):
    """Extract a mesh with marching tetrahedra, refine it by bisection, and export it.

    Steps:
      1. Get tetrahedra points from ``gaussians`` and triangulate them.
      2. Evaluate ``evaluage_cull_alpha`` at those points and run
         ``marching_tetrahedra`` on CPU copies.
      3. For every mesh vertex, bisect between the two end points returned by
         ``marching_tetrahedra`` for a fixed number of steps, keeping the half
         whose end values differ in sign.
      4. Drop vertices whose original end-point distance exceeds the sum of
         the two end-point scales (and the faces using them), then write
         ``recon.ply`` into ``model_path``.

    Args:
        model_path: Directory that receives ``recon.ply``.
        name: Unused; kept for interface compatibility.
        iteration: Unused; kept for interface compatibility.
        views: Cameras forwarded to ``evaluage_cull_alpha``.
        gaussians: Model providing ``get_tetra_points()``; also forwarded to
            ``evaluage_cull_alpha``.
        pipeline: Forwarded to ``evaluage_cull_alpha``.
        background: Forwarded to ``evaluage_cull_alpha``.
        kernel_size: Forwarded to ``evaluage_cull_alpha``.

    Returns:
        None. The mesh is written to ``<model_path>/recon.ply``.
    """
    # generate tetra points here
    points, points_scale = gaussians.get_tetra_points()
    cells = cpp.triangulate(points)

    mask = None
    sdf = evaluage_cull_alpha(points, views, mask, gaussians, pipeline, background, kernel_size)
    torch.cuda.empty_cache()

    # the function marching_tetrahedra costs much memory, so we move it to cpu.
    verts_list, scale_list, faces_list, _ = marching_tetrahedra(
        points.cpu()[None], cells.cpu().long(), sdf[None].cpu(), points_scale[None].cpu()
    )
    # Drop the inputs that are no longer needed so their memory can be reused.
    del points, points_scale, cells, sdf

    end_points, end_sdf = verts_list[0]
    end_scales = scale_list[0]
    end_points, end_sdf, end_scales = end_points.cuda(), end_sdf.cuda(), end_scales.cuda()
    faces = faces_list[0].cpu().numpy()

    # Index 0 / 1 along dim 1 select the two end points of each vertex.
    # These are views, so the in-place updates below write into end_points
    # and end_sdf.
    left_points = end_points[:, 0, :]
    right_points = end_points[:, 1, :]
    left_sdf = end_sdf[:, 0, :]
    right_sdf = end_sdf[:, 1, :]

    # Both filter quantities must be taken BEFORE the bisection, because the
    # loop overwrites the end points in place.
    distance = torch.norm(left_points - right_points, dim=-1)
    scale = end_scales[:, 0, 0] + end_scales[:, 1, 0]

    n_binary_steps = 8
    for step in range(n_binary_steps):
        print("binary search in step {}".format(step))
        mid_points = (left_points + right_points) / 2
        mid_sdf = evaluage_cull_alpha(mid_points, views, mask, gaussians, pipeline, background, kernel_size)
        mid_sdf = mid_sdf.unsqueeze(-1)

        # True where the midpoint has the same strict sign as the left end,
        # i.e. the sign change lies in the right half and the left end moves.
        ind_low = ((mid_sdf < 0) & (left_sdf < 0)) | ((mid_sdf > 0) & (left_sdf > 0))
        ind_high = ~ind_low
        rows_low = ind_low.flatten()
        rows_high = ind_high.flatten()

        left_sdf[ind_low] = mid_sdf[ind_low]
        right_sdf[ind_high] = mid_sdf[ind_high]
        left_points[rows_low] = mid_points[rows_low]
        right_points[rows_high] = mid_points[rows_high]

    # Final vertex position: midpoint of the refined bracket.
    points = (left_points + right_points) / 2
    mesh = trimesh.Trimesh(vertices=points.cpu().numpy(), faces=faces, process=False)

    # filter: keep vertices whose original end-point distance does not exceed
    # the summed end-point scales, and faces whose vertices are all kept.
    vertex_mask = (distance <= scale).cpu().numpy()
    face_mask = vertex_mask[faces].all(axis=1)
    mesh.update_vertices(vertex_mask)
    mesh.update_faces(face_mask)

    mesh.export(os.path.join(model_path, "recon.ply"))
def extract_mesh(dataset : ModelParams, iteration : int, pipeline : PipelineParams):
    """Load a trained model for ``iteration`` and run tetrahedra mesh extraction.

    Args:
        dataset: Model parameters; ``sh_degree``, ``model_path``,
            ``white_background`` and ``kernel_size`` are read here.
        iteration: Iteration number used both as ``load_iteration`` for the
            scene and to locate ``point_cloud/iteration_<iteration>/point_cloud.ply``.
        pipeline: Pipeline parameters forwarded to the extraction routine.

    Returns:
        None. Output is produced by ``marching_tetrahedra_with_binary_search``.
    """
    with torch.no_grad():
        gaussians = GaussianModel(dataset.sh_degree)
        scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)

        # Explicitly load the point cloud saved for the requested iteration.
        ply_path = os.path.join(
            dataset.model_path,
            "point_cloud",
            f"iteration_{iteration}",
            "point_cloud.ply",
        )
        gaussians.load_ply(ply_path)

        # White or black background colour, depending on the dataset setting.
        if dataset.white_background:
            bg_color = [1, 1, 1]
        else:
            bg_color = [0, 0, 0]
        background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")

        kernel_size = dataset.kernel_size
        train_cameras = scene.getTrainCameras()

        marching_tetrahedra_with_binary_search(
            dataset.model_path,
            "test",
            iteration,
            train_cameras,
            gaussians,
            pipeline,
            background,
            kernel_size,
        )
if __name__ == "__main__":
    # Build the command line: model and pipeline parameter groups register
    # their own options on the parser; two script-level options follow.
    cli = ArgumentParser(description="Testing script parameters")
    model_params = ModelParams(cli, sentinel=True)
    pipeline_params = PipelineParams(cli)
    cli.add_argument("--iteration", default=30000, type=int)
    # Parsed for CLI compatibility; not read anywhere in this script.
    cli.add_argument("--quiet", action="store_true")
    args = get_combined_args(cli)

    print("Rendering " + args.model_path)

    # Seed the Python, NumPy and PyTorch random number generators with 0.
    for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
        seed_rng(0)

    # Make the first CUDA device the current one.
    torch.cuda.set_device(torch.device("cuda:0"))

    dataset_args = model_params.extract(args)
    target_iteration = args.iteration
    pipeline_args = pipeline_params.extract(args)
    extract_mesh(dataset_args, target_iteration, pipeline_args)