File size: 8,567 Bytes
31ca7a5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import numpy as np
import torch
import nvdiffrast.torch as dr
import trimesh
import os
from util import *
import render
import loss
import imageio
import sys
sys.path.append('..')
from flexicubes import FlexiCubes
###############################################################################
# Functions adapted from https://github.com/NVlabs/nvdiffrec
###############################################################################
def lr_schedule(iter):
    """Return the learning-rate multiplier used by ``LambdaLR`` at step ``iter``.

    The multiplier is ``10 ** (-0.0002 * iter)``: 1.0 at step 0 and 0.1 at
    step 5000. It keeps decaying after that; there is no floor other than 0.0.
    """
    exponent = -iter * 0.0002
    factor = 10 ** exponent
    # 10**x is never negative, so this floor can only ever return 0.0 when
    # the power underflows; it is kept for parity with the original schedule.
    return max(0.0, factor)
def _str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for any
    non-empty string, so ``--sdf_loss False`` would silently keep the loss on.
    This converter maps the usual spellings explicitly instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if text in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _deform_grid(x_nx3, deform, voxel_grid_res):
    """Return grid vertices ``x_nx3`` offset by ``tanh(deform)``.

    The offset is scaled by ``(2 - 1e-8) / (2 * voxel_grid_res)``, so each
    coordinate moves by strictly less than that amount. This presumably keeps
    vertices inside their cell neighbourhood of the x2-scaled grid — TODO
    confirm against ``FlexiCubes.construct_voxel_grid``.
    """
    return x_nx3 + (2 - 1e-8) / (voxel_grid_res * 2) * torch.tanh(deform)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='flexicubes optimization')
    # out_dir / ref_mesh previously had no usable default (None crashed later
    # in os.makedirs / load_mesh), so make argparse report them as missing.
    parser.add_argument('-o', '--out_dir', type=str, required=True)
    parser.add_argument('-rm', '--ref_mesh', type=str, required=True)
    parser.add_argument('-i', '--iter', type=int, default=1000)
    parser.add_argument('-b', '--batch', type=int, default=8)
    parser.add_argument('-r', '--train_res', nargs=2, type=int, default=[2048, 2048])
    parser.add_argument('-lr', '--learning_rate', type=float, default=0.01)
    parser.add_argument('--voxel_grid_res', type=int, default=64)
    parser.add_argument('--sdf_loss', type=_str2bool, default=True)
    parser.add_argument('--develop_reg', type=_str2bool, default=False)
    parser.add_argument('--sdf_regularizer', type=float, default=0.2)
    parser.add_argument('-dr', '--display_res', nargs=2, type=int, default=[512, 512])
    parser.add_argument('-si', '--save_interval', type=int, default=20)
    FLAGS = parser.parse_args()
    # At least one iteration is needed: the final export reuses the mesh
    # extracted on the last iteration. save_interval is used as a modulus.
    if FLAGS.iter < 1:
        parser.error('--iter must be at least 1')
    if FLAGS.save_interval < 1:
        parser.error('--save_interval must be at least 1')
    device = 'cuda'

    os.makedirs(FLAGS.out_dir, exist_ok=True)
    # NOTE(review): glctx is never passed to anything below; render.* presumably
    # manages its own context. Kept so start-up behaviour is unchanged.
    glctx = dr.RasterizeGLContext()

    # Load GT mesh
    gt_mesh = load_mesh(FLAGS.ref_mesh, device)
    gt_mesh.auto_normals()  # compute face normals for visualization

    # ==============================================================================================
    # Create and initialize FlexiCubes
    # ==============================================================================================
    fc = FlexiCubes(device)
    x_nx3, cube_fx8 = fc.construct_voxel_grid(FLAGS.voxel_grid_res)
    x_nx3 *= 2  # scale up the grid so that it's larger than the target object

    sdf = torch.rand_like(x_nx3[:, 0]) - 0.1  # randomly init SDF in [-0.1, 0.9)
    sdf = torch.nn.Parameter(sdf.clone().detach(), requires_grad=True)
    # Per-cube learnable weights, initialized to zero. Columns 0-11 are passed
    # as beta_fx12, 12-19 as alpha_fx8 and column 20 as gamma_f.
    weight = torch.zeros((cube_fx8.shape[0], 21), dtype=torch.float, device=device)
    weight = torch.nn.Parameter(weight.clone().detach(), requires_grad=True)
    deform = torch.nn.Parameter(torch.zeros_like(x_nx3), requires_grad=True)

    # Retrieve all the edges of the voxel grid; these edges will be utilized to
    # compute the regularization loss in subsequent steps of the process.
    all_edges = cube_fx8[:, fc.cube_edges].reshape(-1, 2)
    grid_edges = torch.unique(all_edges, dim=0)

    # ==============================================================================================
    # Setup optimizer
    # ==============================================================================================
    optimizer = torch.optim.Adam([sdf, weight, deform], lr=FLAGS.learning_rate)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_schedule)

    # ==============================================================================================
    # Train loop
    # ==============================================================================================
    for it in range(FLAGS.iter):
        optimizer.zero_grad()
        # sample random camera poses
        mv, mvp = render.get_random_camera_batch(FLAGS.batch, iter_res=FLAGS.train_res, device=device, use_kaolin=False)
        # render gt mesh
        target = render.render_mesh_paper(gt_mesh, mv, mvp, FLAGS.train_res)
        # extract and render FlexiCubes mesh
        grid_verts = _deform_grid(x_nx3, deform, FLAGS.voxel_grid_res)
        vertices, faces, L_dev = fc(grid_verts, sdf, cube_fx8, FLAGS.voxel_grid_res, beta_fx12=weight[:, :12],
                                    alpha_fx8=weight[:, 12:20], gamma_f=weight[:, 20], training=True)
        flexicubes_mesh = Mesh(vertices, faces)
        buffers = render.render_mesh_paper(flexicubes_mesh, mv, mvp, FLAGS.train_res)

        # evaluate reconstruction loss
        mask_loss = (buffers['mask'] - target['mask']).abs().mean()
        # Masked depth error: per-pixel L2 over the last dim, with 1e-8 added
        # before sqrt (presumably to keep the gradient finite at zero error).
        depth_loss = ((((buffers['depth'] - target['depth']) * target['mask']) ** 2).sum(-1) + 1e-8).sqrt().mean() * 10

        t_iter = it / FLAGS.iter
        # SDF regularizer weight decays linearly from sdf_regularizer to
        # sdf_regularizer / 20 over the first 25% of iterations, then stays there.
        sdf_weight = FLAGS.sdf_regularizer - (FLAGS.sdf_regularizer - FLAGS.sdf_regularizer / 20) * min(1.0, 4.0 * t_iter)
        reg_loss = loss.sdf_reg_loss(sdf, grid_edges).mean() * sdf_weight  # Loss to eliminate internal floaters that are not visible
        reg_loss += L_dev.mean() * 0.5
        reg_loss += (weight[:, :20]).abs().mean() * 0.1
        total_loss = mask_loss + depth_loss + reg_loss

        if FLAGS.sdf_loss:  # optionally add SDF loss to eliminate internal structures
            with torch.no_grad():
                pts = sample_random_points(1000, gt_mesh)
                gt_sdf = compute_sdf(pts, gt_mesh.vertices, gt_mesh.faces)
            # Outside no_grad: gradients must flow through the predicted mesh.
            pred_sdf = compute_sdf(pts, flexicubes_mesh.vertices, flexicubes_mesh.faces)
            total_loss += torch.nn.functional.mse_loss(pred_sdf, gt_sdf) * 2e3

        # optionally add developability regularizer, as described in paper section 5.2
        if FLAGS.develop_reg:
            reg_weight = max(0, t_iter - 0.8) * 5
            if reg_weight > 0:  # only applied after shape converges (last 20% of iterations)
                # NOTE(review): this *replaces* reg_loss and total_loss instead of
                # adding to them, so the SDF regularizer, the L_dev term and the
                # optional SDF loss above are dropped once this phase starts, and
                # reg_weight only gates the branch (it is not a multiplier).
                # Behaviour kept as-is; confirm whether that is intended.
                reg_loss = loss.mesh_developable_reg(flexicubes_mesh).mean() * 10
                reg_loss += (deform).abs().mean()
                reg_loss += (weight[:, :20]).abs().mean()
                total_loss = mask_loss + depth_loss + reg_loss

        total_loss.backward()
        optimizer.step()
        scheduler.step()

        if it % FLAGS.save_interval == 0 or it == (FLAGS.iter - 1):  # save normal image for visualization
            with torch.no_grad():
                # Rebuild the deformed grid from the *post-step* deform so it is
                # consistent with the post-step sdf/weight; the training-time
                # grid_verts was computed before optimizer.step().
                eval_verts = _deform_grid(x_nx3, deform, FLAGS.voxel_grid_res)
                # extract mesh with training=False
                vertices, faces, L_dev = fc(eval_verts, sdf, cube_fx8, FLAGS.voxel_grid_res, beta_fx12=weight[:, :12],
                                            alpha_fx8=weight[:, 12:20], gamma_f=weight[:, 20], training=False)
                flexicubes_mesh = Mesh(vertices, faces)
                flexicubes_mesh.auto_normals()  # compute face normals for visualization
                mv, mvp = render.get_rotate_camera(it // FLAGS.save_interval, iter_res=FLAGS.display_res, device=device, use_kaolin=False)
                val_buffers = render.render_mesh_paper(flexicubes_mesh, mv.unsqueeze(0), mvp.unsqueeze(0), FLAGS.display_res, return_types=["normal"], white_bg=True)
                # Map normals from [-1, 1] to [0, 255] for an 8-bit image.
                val_image = ((val_buffers["normal"][0].detach().cpu().numpy() + 1) / 2 * 255).astype(np.uint8)

                gt_buffers = render.render_mesh_paper(gt_mesh, mv.unsqueeze(0), mvp.unsqueeze(0), FLAGS.display_res, return_types=["normal"], white_bg=True)
                gt_image = ((gt_buffers["normal"][0].detach().cpu().numpy() + 1) / 2 * 255).astype(np.uint8)
                # Prediction on the left, ground truth on the right.
                imageio.imwrite(os.path.join(FLAGS.out_dir, '{:04d}.png'.format(it)), np.concatenate([val_image, gt_image], 1))
                print(f"Optimization Step [{it}/{FLAGS.iter}], Loss: {total_loss.item():.4f}")

    # ==============================================================================================
    # Save output
    # ==============================================================================================
    # vertices/faces are from the training=False extraction on the last
    # iteration (the save branch always runs when it == FLAGS.iter - 1).
    mesh_np = trimesh.Trimesh(vertices=vertices.detach().cpu().numpy(), faces=faces.detach().cpu().numpy(), process=False)
    mesh_np.export(os.path.join(FLAGS.out_dir, 'output_mesh.obj'))