# File size: 6,581 Bytes
# bc2085d
import os
import torch
import argparse
import mcubes
import trimesh
import numpy as np
from tqdm import tqdm
from omegaconf import OmegaConf
from utility.initialize import instantiate_from_config, get_obj_from_str
from utility.triplane_renderer.eg3d_renderer import sample_from_planes, generate_planes
# Parse the command line, then build the first-stage VAE and load the main
# model from a checkpoint. `args`, `configs`, `device`, `vae` and `model` are
# module-level names used by the rest of this script.
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, required=True,
                    help="Path to the config file loaded with OmegaConf.")
parser.add_argument("--ckpt", type=str, required=True,
                    help="Checkpoint passed to load_from_checkpoint.")
args = parser.parse_args()
configs = OmegaConf.load(args.config)
device = 'cuda'

# The VAE is instantiated from the `first_stage_config` section of the config
# (class path in 'target', constructor keyword arguments in 'params').
vae = get_obj_from_str(configs.model.params.first_stage_config['target'])(
    **configs.model.params.first_stage_config['params']
)
vae = vae.to(device)
vae.eval()

# The checkpoint is mapped to CPU first and the model is moved to `device`
# afterwards; strict=False is forwarded to load_from_checkpoint unchanged.
model = get_obj_from_str(configs.model["target"]).load_from_checkpoint(
    args.ckpt, map_location='cpu', strict=False, **configs.model.params
)
model = model.to(device)
# This script only runs inference, so put the model in eval mode like the VAE.
model.eval()
def extract_mesh(triplane_fname, save_name=None, res=128, threshold=8):
    """Extract a vertex-colored mesh from a saved latent and export it as PLY.

    The latent stored in ``triplane_fname`` is decoded into a triplane with
    the module-level ``model``. Density is sampled on a ``res``^3 grid and
    turned into a surface with marching cubes. Each vertex is then colored
    by rendering it from six axis-aligned viewpoints with the module-level
    ``vae`` and keeping the color from the view whose rendered depth is
    closest to the actual camera-to-vertex distance.

    Args:
        triplane_fname: Path to a ``.npy`` file holding the latent.
        save_name: Output path of the ``.ply`` file. If falsy, the mesh is
            written next to ``triplane_fname`` with the extension replaced
            by ``.ply``.
        res: Number of grid samples per axis used for marching cubes.
        threshold: Density iso-value passed to marching cubes.

    Raises:
        ValueError: If marching cubes finds no surface at ``threshold``.
    """
    bound = 1.2          # the grid spans [-bound, bound] on every axis
    res_triplane = 256   # spatial size the decoded triplane is reshaped to

    latent = torch.from_numpy(np.load(triplane_fname)).to(device)
    with torch.no_grad():
        with model.ema_scope():
            triplane = model.decode_first_stage(latent)
    planes = triplane.reshape(1, 3, -1, res_triplane, res_triplane)

    # Prepare the density volume for marching cubes.
    c_list = torch.linspace(-bound, bound, steps=res)
    grid_x, grid_y, grid_z = torch.meshgrid(
        c_list, c_list, c_list, indexing='ij'
    )
    coords = torch.stack([grid_x, grid_y, grid_z], -1).to(device)  # (res, res, res, 3)
    plane_axes = generate_planes()
    feats = sample_from_planes(
        plane_axes, planes, coords.reshape(1, -1, 3),
        padding_mode='zeros', box_warp=2.4
    )
    # Only 'sigma' is read from the decoder output here, so a constant
    # placeholder direction (1, 0, 0) is passed for every grid point.
    fake_dirs = torch.zeros_like(coords)
    fake_dirs[..., 0] = 1
    with torch.no_grad():
        out = vae.triplane_decoder.decoder(feats, fake_dirs)
        u = out['sigma'].reshape(res, res, res).detach().cpu().numpy()
        del out

    # Marching cubes returns vertices in grid-index units; map them back to
    # the [-bound, bound] cube the grid was sampled from.
    vertices, triangles = mcubes.marching_cubes(u, threshold)
    if len(vertices) == 0:
        raise ValueError(
            'marching cubes found no surface in %s at threshold %s'
            % (triplane_fname, threshold)
        )
    vertices = vertices / (res - 1) * (2 * bound) - bound
    pt_vertices = torch.from_numpy(vertices).to(device)

    # Extract vertex colors.
    render_kwargs = {
        'depth_resolution': 128,
        'disparity_space_sampling': False,
        'box_warp': 2.4,
        'depth_resolution_importance': 128,
        'clamp_mode': 'softplus',
        'white_back': True,
        'det': True
    }
    # Six camera origins at distance 2 along +/- z, y and x.
    rays_o_list = [
        np.array([0, 0, 2]),
        np.array([0, 0, -2]),
        np.array([0, 2, 0]),
        np.array([0, -2, 0]),
        np.array([2, 0, 0]),
        np.array([-2, 0, 0]),
    ]
    rgb_final = None
    diff_final = None
    # NOTE: every view renders all vertices in one call. An earlier revision
    # carried a (disabled) scheme that split the rays into chunks of 2**14;
    # reintroduce batching here if this step runs out of memory.
    for origin in tqdm(rays_o_list):
        rays_o = torch.from_numpy(origin.reshape(1, 3)).repeat(vertices.shape[0], 1).float().to(device)
        offsets = pt_vertices.reshape(-1, 3) - rays_o
        # One norm serves both the ray normalization and the reference
        # distance that the rendered depth is compared against.
        dist_t = torch.norm(offsets, dim=-1)
        rays_d = offsets / dist_t.reshape(-1, 1)
        dist = dist_t.cpu().numpy().reshape(-1)
        with torch.no_grad():
            render_out = vae.triplane_decoder(planes,
                                              rays_o.unsqueeze(0), rays_d.unsqueeze(0), render_kwargs,
                                              whole_img=False, tvloss=False)
        rgb = render_out['rgb_marched'].reshape(-1, 3).detach().cpu().numpy()
        depth = render_out['depth_final'].reshape(-1).detach().cpu().numpy()
        depth_diff = np.abs(dist - depth)
        if rgb_final is None:
            rgb_final = rgb.copy()
            diff_final = depth_diff.copy()
        else:
            # Keep, per vertex, the color from the view with the smallest
            # gap between rendered depth and true distance so far.
            ind = diff_final > depth_diff
            rgb_final[ind] = rgb[ind]
            diff_final[ind] = depth_diff[ind]

    # Swap the first and last color channels (labelled "bgr to rgb" in the
    # original code).
    rgb_final = rgb_final[:, [2, 1, 0]]

    # Export to PLY.
    mesh = trimesh.Trimesh(vertices, triangles, vertex_colors=(rgb_final * 255).astype(np.uint8))
    out_path = save_name if save_name else os.path.splitext(triplane_fname)[0] + '.ply'
    trimesh.exchange.export.export_mesh(mesh, out_path, file_type='ply')
# Locate the inputs for the batch run.
# NOTE: `folder` is a hard-coded, machine-specific path; point it at the
# directory that holds the sampled files before running this script.
folder = '/mnt/lustre/hongfangzhou.p/AE3D/log/diff_res32ch8_preprocess_ca_text_new_triplane_96_full_openaimodel_only_cap3d_high_quality_7w/sample_demo_424_prompts_for_demo_30_60_10'
save_folder = folder + '_extract_mesh'
os.makedirs(save_folder, exist_ok=True)

# Triplane file names are derived from the directory entries that start with
# '_': '_sample' becomes 'triplane' and 'mp4' becomes 'npy'.
fnames = sorted(
    f.replace('_sample', 'triplane').replace('mp4', 'npy')
    for f in os.listdir(folder)
    if f.startswith('_')
)

# Only prompts 30..59 of the prompt file are used; the number parsed from
# each file name indexes into this slice. The file is read inside a `with`
# block so the handle is closed deterministically.
with open('test/prompts_for_demo_2.txt', 'r') as prompt_file:
    prompts = [line.strip() for line in prompt_file][30:60]
def extract_number(s):
    """Return the second-to-last underscore-separated field of ``s`` as an int."""
    # Only the last two separators matter, so split from the right.
    fields = s.rsplit('_', 2)
    return int(fields[-2])
def extract_id(s):
    """Return the part of ``s`` after the last '_' and before the next '.'."""
    tail = s.rsplit('_', 1)[-1]
    head, _, _ = tail.partition('.')
    return head
# Convert every triplane file in turn. Any failure on one file is printed and
# the loop moves on to the next file.
for fname in fnames:
    try:
        print(fname)
        src_path = os.path.join(folder, fname)
        # Output name: the prompt (spaces -> underscores) plus the sample id.
        prompt_slug = prompts[extract_number(fname)].replace(' ', '_')
        ply_name = f"{prompt_slug}_{extract_id(fname)}.ply"
        extract_mesh(src_path, os.path.join(save_folder, ply_name))
    except Exception as e:
        print(e)
|