# (HuggingFace Spaces page artifact removed: "Spaces: Running on Zero")
# adopted from https://github.com/autonomousvision/gaussian-opacity-fields/blob/main/extract_mesh.py
import os
import random
from argparse import ArgumentParser
from os import makedirs

import numpy as np
import torch
import trimesh
from tqdm import tqdm

from arguments import ModelParams, PipelineParams, get_combined_args
from gaussian_renderer import GaussianModel, integrate, render
from scene import Scene
from tetranerf.utils.extension import cpp
from utils.tetmesh import marching_tetrahedra
def evaluage_alpha(points, views, gaussians, pipeline, background, kernel_size):
    """Evaluate per-point occupancy by integrating alpha over all views.

    For each view, ``integrate`` returns an integrated alpha for every query
    point; keeping the minimum across views gives the most transparent
    observation, and ``1 - min`` is used as the point's occupancy.

    Args:
        points: (N, 3) CUDA tensor of query positions.
        views: iterable of camera objects accepted by ``integrate``.
        gaussians: GaussianModel to evaluate.
        pipeline: pipeline parameters forwarded to ``integrate``.
        background: background color tensor forwarded to ``integrate``.
        kernel_size: filter kernel size forwarded to ``integrate``.

    Returns:
        (N,) CUDA float32 tensor of alpha values.
    """
    final_alpha = torch.ones((points.shape[0]), dtype=torch.float32, device="cuda")
    with torch.no_grad():
        for view in tqdm(views, desc="Rendering progress"):
            ret = integrate(points, view, gaussians, pipeline, background, kernel_size=kernel_size)
            # Keep the most transparent observation seen so far.
            final_alpha = torch.min(final_alpha, ret["alpha_integrated"])
    return 1 - final_alpha
def evaluage_cull_alpha(points, views, masks, gaussians, pipeline, background, kernel_size):
    """Evaluate a pseudo-SDF per point, restricted to views where the point is visible.

    Tracks, for each point, the minimum integrated alpha over the views in
    which the point projects inside the (optional) masks, plus a visibility
    count. The result is shifted so the 0.5-alpha level set becomes the zero
    crossing; points never seen by any view are assigned -100.

    Args:
        points: (N, 3) CUDA tensor of query positions.
        views: list of camera objects accepted by ``integrate``.
        masks: optional per-view masks indexed by camera id, or None.
        gaussians: GaussianModel to evaluate.
        pipeline: pipeline parameters forwarded to ``integrate``.
        background: background color tensor forwarded to ``integrate``.
        kernel_size: filter kernel size forwarded to ``integrate``.

    Returns:
        (N,) CUDA float32 tensor: 0.5 - min_alpha for visible points, -100 otherwise.
    """
    final_sdf = torch.ones((points.shape[0]), dtype=torch.float32, device="cuda")
    weight = torch.zeros((points.shape[0]), dtype=torch.int32, device="cuda")
    with torch.no_grad():
        for cam_id, view in enumerate(tqdm(views, desc="Rendering progress")):
            torch.cuda.empty_cache()
            ret = integrate(points, view, gaussians, pipeline, background, kernel_size=kernel_size)
            alpha_integrated = ret["alpha_integrated"]
            point_coordinate = ret["point_coordinate"]
            # Map pixel coordinates into the [-1, 1] range expected by
            # grid_sample with align_corners=False.
            point_coordinate[:, 0] = (point_coordinate[:, 0] * 2 + 1) / (view.image_width - 1) - 1
            point_coordinate[:, 1] = (point_coordinate[:, 1] * 2 + 1) / (view.image_height - 1) - 1
            # Channel 7 of the render output is used as a visibility mask
            # here — presumably accumulated alpha; confirm against renderer.
            mask = ret["render"][7][None]
            if view.gt_mask is not None:
                mask = mask * view.gt_mask
            if masks is not None:
                mask = mask * masks[cam_id]
            # Sample the mask at each projected point; points outside the
            # image read zero (padding_mode='zeros') and count as invisible.
            valid_point_prob = torch.nn.functional.grid_sample(
                mask.type(torch.float32)[None],
                point_coordinate[None, None],
                padding_mode='zeros',
                align_corners=False,
            )
            valid_point = valid_point_prob[0, 0, 0] > 0.5
            final_sdf = torch.where(valid_point, torch.min(alpha_integrated, final_sdf), final_sdf)
            weight = torch.where(valid_point, weight + 1, weight)
    # Zero-center around alpha=0.5; unseen points get a large negative value.
    final_sdf = torch.where(weight > 0, 0.5 - final_sdf, -100)
    return final_sdf
def marching_tetrahedra_with_binary_search(model_path, name, iteration, views, gaussians: GaussianModel, pipeline, background, kernel_size):
    """Extract a mesh with marching tetrahedra, refining crossing edges by binary search.

    Builds tetrahedra from the Gaussians, evaluates the pseudo-SDF at the
    vertices, runs marching tetrahedra on CPU, then refines each crossing
    edge with 8 binary-search steps to localize the zero crossing. Long
    edges (relative to local Gaussian scales) are filtered out, and the
    mesh is written to ``<model_path>/recon.ply``.

    Args:
        model_path: output directory for the extracted mesh.
        name: unused here; kept for interface compatibility.
        iteration: unused here; kept for interface compatibility.
        views: training cameras used to evaluate the SDF.
        gaussians: GaussianModel providing tetra points and scales.
        pipeline, background, kernel_size: forwarded to the SDF evaluation.
    """
    # Generate tetra points and triangulate them into cells.
    points, points_scale = gaussians.get_tetra_points()
    cells = cpp.triangulate(points)
    mask = None
    sdf = evaluage_cull_alpha(points, views, mask, gaussians, pipeline, background, kernel_size)
    torch.cuda.empty_cache()
    # marching_tetrahedra costs much memory, so run it on CPU.
    verts_list, scale_list, faces_list, _ = marching_tetrahedra(
        points.cpu()[None], cells.cpu().long(), sdf[None].cpu(), points_scale[None].cpu())
    # Free the large intermediates before the refinement loop.
    del points
    del points_scale
    del cells
    end_points, end_sdf = verts_list[0]
    end_scales = scale_list[0]
    end_points, end_sdf, end_scales = end_points.cuda(), end_sdf.cuda(), end_scales.cuda()
    faces = faces_list[0].cpu().numpy()
    points = (end_points[:, 0, :] + end_points[:, 1, :]) / 2.
    # NOTE: left_*/right_* are views into end_points/end_sdf; the in-place
    # updates below intentionally mutate them between steps.
    left_points = end_points[:, 0, :]
    right_points = end_points[:, 1, :]
    left_sdf = end_sdf[:, 0, :]
    right_sdf = end_sdf[:, 1, :]
    left_scale = end_scales[:, 0, 0]
    right_scale = end_scales[:, 1, 0]
    distance = torch.norm(left_points - right_points, dim=-1)
    scale = left_scale + right_scale
    n_binary_steps = 8
    for step in range(n_binary_steps):
        print("binary search in step {}".format(step))
        mid_points = (left_points + right_points) / 2
        mid_sdf = evaluage_cull_alpha(mid_points, views, mask, gaussians, pipeline, background, kernel_size)
        mid_sdf = mid_sdf.unsqueeze(-1)
        # The midpoint replaces whichever endpoint shares its SDF sign,
        # shrinking the bracket around the zero crossing.
        ind_low = ((mid_sdf < 0) & (left_sdf < 0)) | ((mid_sdf > 0) & (left_sdf > 0))
        left_sdf[ind_low] = mid_sdf[ind_low]
        right_sdf[~ind_low] = mid_sdf[~ind_low]
        left_points[ind_low.flatten()] = mid_points[ind_low.flatten()]
        right_points[~ind_low.flatten()] = mid_points[~ind_low.flatten()]
        points = (left_points + right_points) / 2
    mesh = trimesh.Trimesh(vertices=points.cpu().numpy(), faces=faces, process=False)
    # Filter: drop vertices whose crossing edge is longer than the sum of the
    # endpoint scales (presumably spurious connections), and faces touching them.
    vertice_mask = (distance <= scale).cpu().numpy()
    face_mask = vertice_mask[faces].all(axis=1)
    mesh.update_vertices(vertice_mask)
    mesh.update_faces(face_mask)
    mesh.export(os.path.join(model_path, "recon.ply"))
def extract_mesh(dataset: ModelParams, iteration: int, pipeline: PipelineParams):
    """Load the trained scene for ``iteration`` and extract a mesh from it.

    Args:
        dataset: model parameters (model_path, sh_degree, white_background, kernel_size).
        iteration: training iteration whose point cloud is loaded.
        pipeline: pipeline parameters forwarded to the renderer.
    """
    with torch.no_grad():
        gaussians = GaussianModel(dataset.sh_degree)
        scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)
        # Explicitly (re)load the point cloud for the requested iteration.
        gaussians.load_ply(os.path.join(dataset.model_path, "point_cloud", f"iteration_{iteration}", "point_cloud.ply"))
        bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
        background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
        kernel_size = dataset.kernel_size
        cams = scene.getTrainCameras()
        marching_tetrahedra_with_binary_search(dataset.model_path, "test", iteration, cams, gaussians, pipeline, background, kernel_size)
if __name__ == "__main__":
    # Set up command line argument parser.
    parser = ArgumentParser(description="Testing script parameters")
    model = ModelParams(parser, sentinel=True)
    pipeline = PipelineParams(parser)
    parser.add_argument("--iteration", default=30000, type=int)
    parser.add_argument("--quiet", action="store_true")
    args = get_combined_args(parser)
    print("Rendering " + args.model_path)
    # Seed every RNG source for reproducible extraction.
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)
    torch.cuda.set_device(torch.device("cuda:0"))
    extract_mesh(model.extract(args), args.iteration, pipeline.extract(args))