# (Non-code residue from the hosting web page, preserved as a comment so the
#  file remains valid Python:)
# Spaces: Running on Zero
# Running on Zero
import json
import math
import os
import sys
from argparse import ArgumentParser, Namespace
from random import randint

import matplotlib.pyplot as plt
import numpy as np
import open3d as o3d
import open3d.core as o3c
import torch

from arguments import ModelParams, OptimizationParams, PipelineParams
from gaussian_renderer import render
from scene import GaussianModel, Scene
from scene.cameras import Camera
from scene.dataset_readers import sceneLoadTypeCallbacks
from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON
def load_camera(args):
    """Load the training cameras for a dataset.

    The dataset flavor is detected from the files present in
    ``args.source_path``: a COLMAP ``sparse`` directory, or a Blender-style
    ``transforms_train.json``.

    Args:
        args: Parsed model parameters; must provide ``source_path``,
            ``images``, ``white_background`` and ``eval``.

    Returns:
        The list of training cameras built by ``cameraList_from_camInfos``
        at full resolution (scale 1.0).

    Raises:
        ValueError: if the source path matches neither dataset layout.
    """
    if os.path.exists(os.path.join(args.source_path, "sparse")):
        scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval)
    elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")):
        print("Found transforms_train.json file, assuming Blender data set!")
        scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval)
    else:
        # Previously this fell through and crashed later with an opaque
        # NameError on scene_info; fail loudly with a useful message instead.
        raise ValueError(f"Could not recognize scene type in {args.source_path}")
    return cameraList_from_camInfos(scene_info.train_cameras, 1.0, args)
def extract_mesh(dataset, pipe, checkpoint_iterations=None):
    """Render every training view from a trained Gaussian model and fuse the
    per-view color/median-depth maps into a triangle mesh via TSDF integration.

    The resulting mesh is written to ``<model_path>/recon.ply``.

    Args:
        dataset: Model parameters (``ModelParams.extract`` result); provides
            ``sh_degree``, ``model_path`` and ``kernel_size``.
        pipe: Pipeline parameters forwarded to ``render``.
        checkpoint_iterations: Which saved iteration to load. ``None`` picks
            the latest ``iteration_<N>`` folder under ``point_cloud``. A list
            (as produced by ``argparse`` with ``nargs="+"``) is also accepted;
            its largest entry is used. Previously a list was interpolated
            directly into the path (``iteration_[30000]``), which never exists.
    """
    gaussians = GaussianModel(dataset.sh_degree)
    output_path = os.path.join(dataset.model_path, "point_cloud")
    if checkpoint_iterations is None:
        # Auto-detect the most recent checkpoint folder ("iteration_<N>").
        iteration = 0
        for folder_name in os.listdir(output_path):
            parts = folder_name.split('_')
            # Skip stray files/folders that don't follow the naming scheme.
            if len(parts) >= 2 and parts[1].isdigit():
                iteration = max(iteration, int(parts[1]))
    elif isinstance(checkpoint_iterations, (list, tuple)):
        # argparse nargs="+" hands us a list; use the latest requested one.
        iteration = max(checkpoint_iterations)
    else:
        iteration = checkpoint_iterations
    output_path = os.path.join(output_path, "iteration_" + str(iteration), "point_cloud.ply")
    gaussians.load_ply(output_path)
    print(f'Loaded gaussians from {output_path}')

    kernel_size = dataset.kernel_size
    bg_color = [1, 1, 1]  # white background, matching training convention
    background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
    viewpoint_cam_list = load_camera(dataset)

    # Pass 1: render color and median depth for every training view.
    depth_list = []
    color_list = []
    alpha_thres = 0.5  # drop depth where accumulated alpha is too low to trust
    for viewpoint_cam in viewpoint_cam_list:
        # Rendering offscreen from that camera.
        render_pkg = render(viewpoint_cam, gaussians, pipe, background, kernel_size)
        rendered_img = torch.clamp(render_pkg["render"], min=0, max=1.0).cpu().numpy().transpose(1, 2, 0)
        color_list.append(np.ascontiguousarray(rendered_img))
        depth = render_pkg["median_depth"].clone()
        if viewpoint_cam.gt_mask is not None:
            # Zero out depth outside the ground-truth foreground mask.
            depth[(viewpoint_cam.gt_mask < 0.5)] = 0
        depth[render_pkg["mask"] < alpha_thres] = 0
        depth_list.append(depth[0].cpu().numpy())
    torch.cuda.empty_cache()

    # Pass 2: TSDF fusion of all (color, depth) frames on the CPU.
    voxel_size = 0.002
    o3d_device = o3d.core.Device("CPU:0")
    vbg = o3d.t.geometry.VoxelBlockGrid(attr_names=('tsdf', 'weight', 'color'),
                                        attr_dtypes=(o3c.float32,
                                                     o3c.float32,
                                                     o3c.float32),
                                        attr_channels=((1), (1), (3)),
                                        voxel_size=voxel_size,
                                        block_resolution=16,
                                        block_count=50000,
                                        device=o3d_device)
    for color, depth, viewpoint_cam in zip(color_list, depth_list, viewpoint_cam_list):
        depth = o3d.t.geometry.Image(depth)
        depth = depth.to(o3d_device)
        color = o3d.t.geometry.Image(color)
        color = color.to(o3d_device)
        W, H = viewpoint_cam.image_width, viewpoint_cam.image_height
        # Recover pinhole intrinsics from the camera's fields of view.
        fx = W / (2 * math.tan(viewpoint_cam.FoVx / 2.))
        fy = H / (2 * math.tan(viewpoint_cam.FoVy / 2.))
        intrinsic = np.array([[fx, 0, float(W) / 2], [0, fy, float(H) / 2], [0, 0, 1]], dtype=np.float64)
        intrinsic = o3d.core.Tensor(intrinsic)
        extrinsic = o3d.core.Tensor((viewpoint_cam.world_view_transform.T).cpu().numpy().astype(np.float64))
        # depth_scale=1.0 (depths already in world units), depth_max=8.0.
        frustum_block_coords = vbg.compute_unique_block_coordinates(
            depth,
            intrinsic,
            extrinsic,
            1.0, 8.0
        )
        vbg.integrate(
            frustum_block_coords,
            depth,
            color,
            intrinsic,
            extrinsic,
            1.0, 8.0
        )
    mesh = vbg.extract_triangle_mesh()
    mesh.compute_vertex_normals()
    o3d.io.write_triangle_mesh(os.path.join(dataset.model_path, "recon.ply"), mesh.to_legacy())
    print("done!")
if __name__ == "__main__": | |
parser = ArgumentParser(description="Training script parameters") | |
lp = ModelParams(parser) | |
pp = PipelineParams(parser) | |
parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=None) | |
args = parser.parse_args(sys.argv[1:]) | |
with torch.no_grad(): | |
extract_mesh(lp.extract(args), pp.extract(args), args.checkpoint_iterations) | |