Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,242 Bytes
476e0f0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
import os
import torch
from random import randint
import sys
from scene import Scene, GaussianModel
from argparse import ArgumentParser, Namespace
from arguments import ModelParams, PipelineParams, OptimizationParams
import matplotlib.pyplot as plt
import math
import numpy as np
from scene.cameras import Camera
from gaussian_renderer import render
import open3d as o3d
import open3d.core as o3c
from scene.dataset_readers import sceneLoadTypeCallbacks
from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON
import json
def load_camera(args):
    """Load the training cameras of the scene at ``args.source_path``.

    The scene type is detected from what is on disk: a ``sparse`` entry
    selects the COLMAP loader, a ``transforms_train.json`` file selects the
    Blender loader.

    Args:
        args: Model parameters. Reads ``source_path``, ``images`` (COLMAP),
            ``white_background`` (Blender) and ``eval``, and is forwarded to
            ``cameraList_from_camInfos``.

    Returns:
        The list produced by ``cameraList_from_camInfos`` for the scene's
        training cameras at resolution scale 1.0.

    Raises:
        ValueError: If neither a COLMAP nor a Blender scene is found.
    """
    if os.path.exists(os.path.join(args.source_path, "sparse")):
        scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval)
    elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")):
        print("Found transforms_train.json file, assuming Blender data set!")
        scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval)
    else:
        # Without this branch scene_info would be unbound and the failure
        # would surface as an opaque UnboundLocalError.
        raise ValueError(
            f"Could not recognize scene type at {args.source_path}: expected a "
            "'sparse' folder (COLMAP) or 'transforms_train.json' (Blender)"
        )
    return cameraList_from_camInfos(scene_info.train_cameras, 1.0, args)
def _resolve_iteration(point_cloud_dir, checkpoint_iterations):
    """Pick the training iteration whose ``point_cloud.ply`` should be loaded.

    Args:
        point_cloud_dir: Directory expected to hold ``iteration_<N>`` folders.
        checkpoint_iterations: ``None``, an int, or a list/tuple of ints (as
            produced by argparse with ``nargs="+"``). For a sequence the last
            value is used; ``None`` or an empty sequence selects the highest
            ``iteration_<N>`` folder found in ``point_cloud_dir``.

    Returns:
        The iteration number as an ``int``.

    Raises:
        FileNotFoundError: If the latest iteration is requested but no
            ``iteration_<N>`` folder exists in ``point_cloud_dir``.
    """
    if isinstance(checkpoint_iterations, (list, tuple)):
        checkpoint_iterations = checkpoint_iterations[-1] if checkpoint_iterations else None
    if checkpoint_iterations is not None:
        return int(checkpoint_iterations)
    prefix = "iteration_"
    # Only consider well-formed "iteration_<digits>" entries so stray files
    # in the directory do not crash the int() conversion.
    iterations = [
        int(name[len(prefix):])
        for name in os.listdir(point_cloud_dir)
        if name.startswith(prefix) and name[len(prefix):].isdigit()
    ]
    if not iterations:
        raise FileNotFoundError(f"No '{prefix}<N>' folders found in {point_cloud_dir}")
    return max(iterations)


def _render_view(viewpoint_cam, gaussians, pipe, background, kernel_size, alpha_thres):
    """Render one camera and return ``(color, depth)`` as numpy arrays.

    ``color`` is the rendered image clamped to [0, 1] and transposed from
    channel-first to channel-last. ``depth`` is the rendered median depth with
    pixels zeroed where the camera's ``gt_mask`` (if any) is below 0.5 or the
    rendered ``mask`` is below ``alpha_thres``.
    """
    render_pkg = render(viewpoint_cam, gaussians, pipe, background, kernel_size)
    color = torch.clamp(render_pkg["render"], min=0, max=1.0).cpu().numpy().transpose(1, 2, 0)
    depth = render_pkg["median_depth"].clone()
    if viewpoint_cam.gt_mask is not None:
        depth[(viewpoint_cam.gt_mask < 0.5)] = 0
    depth[render_pkg["mask"] < alpha_thres] = 0
    # NOTE(review): assumes median_depth has a leading singleton channel,
    # i.e. (1, H, W) -> (H, W) after [0]; confirm against the renderer.
    return np.ascontiguousarray(color), depth[0].cpu().numpy()


def _camera_intrinsic(viewpoint_cam):
    """Build a float64 3x3 pinhole intrinsic (principal point at the image centre) from the camera FoV."""
    width, height = viewpoint_cam.image_width, viewpoint_cam.image_height
    fx = width / (2 * math.tan(viewpoint_cam.FoVx / 2.))
    fy = height / (2 * math.tan(viewpoint_cam.FoVy / 2.))
    matrix = np.array(
        [[fx, 0, float(width) / 2], [0, fy, float(height) / 2], [0, 0, 1]],
        dtype=np.float64,
    )
    return o3c.Tensor(matrix)


def extract_mesh(dataset, pipe, checkpoint_iterations=None, *,
                 voxel_size=0.002, alpha_thres=0.5, depth_max=8.0):
    """Fuse rendered depth/color of trained Gaussians into a TSDF mesh.

    Loads ``<model_path>/point_cloud/iteration_<N>/point_cloud.ply`` and
    renders every training camera returned by ``load_camera``. Each view is
    integrated into an Open3D ``VoxelBlockGrid`` on the CPU, and the extracted
    triangle mesh is written to ``<model_path>/recon.ply``.

    Args:
        dataset: Extracted model parameters; reads ``sh_degree``,
            ``model_path``, ``kernel_size`` and the fields used by
            ``load_camera``.
        pipe: Extracted pipeline parameters, forwarded to ``render``.
        checkpoint_iterations: Iteration to load (an int, or a list/tuple
            whose last value is used). ``None`` loads the latest iteration
            found on disk.
        voxel_size: TSDF voxel edge length, in scene units.
        alpha_thres: Pixels whose rendered ``mask`` is below this value get
            depth 0 (treated as invalid depth during fusion).
        depth_max: Maximum depth passed to Open3D block allocation and
            integration.

    Raises:
        FileNotFoundError: If no iteration is given and none exists on disk.
    """
    gaussians = GaussianModel(dataset.sh_degree)
    point_cloud_dir = os.path.join(dataset.model_path, "point_cloud")
    iteration = _resolve_iteration(point_cloud_dir, checkpoint_iterations)
    ply_path = os.path.join(point_cloud_dir, "iteration_" + str(iteration), "point_cloud.ply")
    gaussians.load_ply(ply_path)
    print(f'Loaded gaussians from {ply_path}')

    kernel_size = dataset.kernel_size
    # The background is always white here; dataset.white_background is not consulted.
    background = torch.tensor([1, 1, 1], dtype=torch.float32, device="cuda")

    o3d_device = o3d.core.Device("CPU:0")
    vbg = o3d.t.geometry.VoxelBlockGrid(
        attr_names=('tsdf', 'weight', 'color'),
        attr_dtypes=(o3c.float32, o3c.float32, o3c.float32),
        attr_channels=((1), (1), (3)),
        voxel_size=voxel_size,
        block_resolution=16,
        block_count=50000,
        device=o3d_device,
    )

    # Render and fuse one view at a time, so host memory only ever holds a
    # single frame instead of every rendered view.
    for viewpoint_cam in load_camera(dataset):
        color_np, depth_np = _render_view(
            viewpoint_cam, gaussians, pipe, background, kernel_size, alpha_thres
        )
        torch.cuda.empty_cache()

        depth = o3d.t.geometry.Image(depth_np).to(o3d_device)
        color = o3d.t.geometry.Image(color_np).to(o3d_device)
        intrinsic = _camera_intrinsic(viewpoint_cam)
        # NOTE(review): assumes world_view_transform is stored transposed
        # (3DGS row-vector convention), so .T is the world-to-camera
        # matrix Open3D expects — confirm.
        extrinsic = o3c.Tensor(
            viewpoint_cam.world_view_transform.T.cpu().numpy().astype(np.float64)
        )
        # depth_scale=1.0: rendered depths are already float values, not
        # integer sensor units that would need rescaling.
        frustum_block_coords = vbg.compute_unique_block_coordinates(
            depth, intrinsic, extrinsic, 1.0, depth_max
        )
        vbg.integrate(
            frustum_block_coords, depth, color, intrinsic, extrinsic, 1.0, depth_max
        )

    mesh = vbg.extract_triangle_mesh()
    mesh.compute_vertex_normals()
    o3d.io.write_triangle_mesh(os.path.join(dataset.model_path, "recon.ply"), mesh.to_legacy())
    print("done!")
if __name__ == "__main__":
    # Command-line entry point: parse model/pipeline options and extract the mesh.
    parser = ArgumentParser(description="Mesh extraction script parameters")
    lp = ModelParams(parser)
    pp = PipelineParams(parser)
    # nargs="+" is kept so existing command lines still parse; only the last
    # value given is used, since a single point cloud is loaded.
    parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=None)
    args = parser.parse_args(sys.argv[1:])
    checkpoint_iteration = args.checkpoint_iterations[-1] if args.checkpoint_iterations else None
    # Inference only: no autograd graph is needed while rendering.
    with torch.no_grad():
        extract_mesh(lp.extract(args), pp.extract(args), checkpoint_iteration)
|