Spaces:
Sleeping
Sleeping
# -*- coding: utf-8 -*- | |
import os | |
import time | |
from collections import OrderedDict | |
from typing import Optional, List | |
import argparse | |
from functools import partial | |
from einops import repeat, rearrange | |
import numpy as np | |
from PIL import Image | |
import trimesh | |
import cv2 | |
import torch | |
import pytorch_lightning as pl | |
from michelangelo.models.tsal.tsal_base import Latent2MeshOutput | |
from michelangelo.models.tsal.inference_utils import extract_geometry | |
from michelangelo.utils.misc import get_config_from_file, instantiate_from_config | |
from michelangelo.utils.visualizers.pythreejs_viewer import PyThreeJSViewer | |
from michelangelo.utils.visualizers import html_util | |
def load_model(args):
    """Instantiate the model described by a config file and ready it for inference.

    Args:
        args: Namespace providing ``config_path`` (model configuration file)
            and ``ckpt_path`` (checkpoint path forwarded to the factory).

    Returns:
        The instantiated model after ``.cuda()`` and ``.eval()``.
    """
    config = get_config_from_file(args.config_path)
    # Some config files nest the model definition under a top-level ``model`` key.
    config = config.model if hasattr(config, "model") else config
    network = instantiate_from_config(config, ckpt_path=args.ckpt_path)
    return network.cuda().eval()
def load_surface(fp, num_points=4096, device="cuda"):
    """Load a point cloud with normals from an ``.npz`` file and subsample it.

    Args:
        fp: Path to an ``.npz`` archive containing a ``points`` array and a
            ``normals`` array, indexed by point along the first axis.
        num_points: Number of points to draw, without replacement.
        device: Torch device the resulting tensor is moved to.

    Returns:
        A float32 tensor of shape ``(1, num_points, P + N)`` where ``P`` and
        ``N`` are the last-axis sizes of ``points`` and ``normals``
        (presumably 3 + 3 — confirm against the data).

    Raises:
        ValueError: If the file holds fewer than ``num_points`` points.
    """
    # Read from the path that was passed in (not from a global ``args``), so
    # the function also works when this module is imported.
    with np.load(fp) as input_pc:
        surface = input_pc['points']
        normal = input_pc['normals']

    available = surface.shape[0]
    if available < num_points:
        raise ValueError(
            f"Point cloud '{fp}' has {available} points, "
            f"but {num_points} are required"
        )

    # NOTE: default_rng() without a seed draws fresh entropy, so the selected
    # subset differs between runs even if the global NumPy seed is fixed.
    rng = np.random.default_rng()
    ind = rng.choice(available, num_points, replace=False)

    surface = torch.as_tensor(surface[ind], dtype=torch.float32)
    normal = torch.as_tensor(normal[ind], dtype=torch.float32)
    # Concatenate per-point features and add a leading batch axis of size 1.
    surface = torch.cat([surface, normal], dim=-1).unsqueeze(0).to(device)
    return surface
def prepare_image(args, number_samples=2):
    """Read the conditioning image and turn it into a batched tensor.

    Args:
        args: Namespace providing ``image_path``.
        number_samples: Number of times the image is repeated along a new
            leading batch axis.

    Returns:
        A float tensor laid out as ``(number_samples, c, h, w)`` whose values
        are the RGB pixel values rescaled by ``x / 255 * 2 - 1``.

    Raises:
        ValueError: If ``args.image_path`` is unset, or the file cannot be
            decoded as an image.
        FileNotFoundError: If ``args.image_path`` does not point to a file.
    """
    if not args.image_path:
        raise ValueError("--image_path is required to prepare the conditioning image")
    if not os.path.isfile(args.image_path):
        raise FileNotFoundError(f"Image file not found: {args.image_path}")

    image = cv2.imread(f"{args.image_path}")
    if image is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise ValueError(f"Could not decode image: {args.image_path}")

    # OpenCV loads channels as BGR; convert to RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image_pt = torch.tensor(image).float()
    image_pt = image_pt / 255 * 2 - 1
    image_pt = rearrange(image_pt, "h w c -> c h w")
    image_pt = repeat(image_pt, "c h w -> b c h w", b=number_samples)
    return image_pt
def save_output(args, mesh_outputs):
    """Export each generated mesh as ``<index>_out_mesh.obj`` in ``args.output_dir``.

    Args:
        args: Namespace providing ``output_dir``; the directory is created if
            it does not exist.
        mesh_outputs: Iterable of objects exposing ``mesh_v`` (vertices) and
            ``mesh_f`` (faces) arrays.

    Returns:
        Always ``0``.
    """
    os.makedirs(args.output_dir, exist_ok=True)

    for i, mesh in enumerate(mesh_outputs):
        # Reverse the vertex order within every face before export. This is
        # done on a local value so the caller's mesh object is left untouched
        # and saving the same outputs twice yields the same files.
        faces = mesh.mesh_f[:, ::-1]
        mesh_output = trimesh.Trimesh(mesh.mesh_v, faces)
        name = f"{i}_out_mesh.obj"
        mesh_output.export(os.path.join(args.output_dir, name), include_normals=True)

    print('-----------------------------------------------------------------------------')
    print(f'>>> Finished and mesh saved in {args.output_dir}')
    print('-----------------------------------------------------------------------------')
    return 0
def reconstruction(args, model, bounds=(-1.25, -1.25, -1.25, 1.25, 1.25, 1.25), octree_depth=7, num_chunks=10000):
    """Encode a point cloud, decode it again and save the reconstructed mesh.

    Args:
        args: Namespace providing ``pointcloud_path`` and ``output_dir``.
        model: Loaded model exposing ``model.encode_shape_embed`` and
            ``model.shape_model`` (``encode_kl_embed``, ``decode``,
            ``query_geometry``).
        bounds: Bounding box forwarded to ``extract_geometry``.
        octree_depth: Resolution parameter forwarded to ``extract_geometry``.
        num_chunks: Chunking parameter forwarded to ``extract_geometry``.

    Returns:
        Always ``0``; the mesh is written to
        ``<args.output_dir>/reconstruction.obj``.
    """
    surface = load_surface(args.pointcloud_path)

    # Pure inference: no gradients are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        # encoding
        _shape_embed, shape_latents = model.model.encode_shape_embed(surface, return_latents=True)
        shape_zq, _posterior = model.model.shape_model.encode_kl_embed(shape_latents)

        # decoding
        latents = model.model.shape_model.decode(shape_zq)
        geometric_func = partial(model.model.shape_model.query_geometry, latents=latents)

        # reconstruction
        # NOTE(review): the second return value (has_surface) is ignored; if
        # extract_geometry can report "no surface" for the first batch item,
        # mesh_v_f[0] may not hold usable arrays — confirm and guard if so.
        mesh_v_f, _has_surface = extract_geometry(
            geometric_func=geometric_func,
            device=surface.device,
            batch_size=surface.shape[0],
            bounds=bounds,
            octree_depth=octree_depth,
            num_chunks=num_chunks,
        )

    # The batch holds a single point cloud, so only the first entry is used.
    vertices, faces = mesh_v_f[0][0], mesh_v_f[0][1]
    recon_mesh = trimesh.Trimesh(vertices, faces)

    # save
    os.makedirs(args.output_dir, exist_ok=True)
    out_path = os.path.join(args.output_dir, 'reconstruction.obj')
    recon_mesh.export(out_path)

    print('-----------------------------------------------------------------------------')
    print(f'>>> Finished and mesh saved in {out_path}')
    print('-----------------------------------------------------------------------------')
    return 0
def image2mesh(args, model, guidance_scale=7.5, box_v=1.1, octree_depth=7, num_samples=2):
    """Generate meshes conditioned on an input image and save them.

    Args:
        args: Namespace providing ``image_path`` and ``output_dir``.
        model: Loaded model exposing ``sample``.
        guidance_scale: Guidance scale forwarded to ``model.sample``.
        box_v: Half-extent of the cubic bounding box passed as ``bounds``.
        octree_depth: Resolution parameter forwarded to ``model.sample``.
        num_samples: Number of copies of the image in the conditioning batch
            (mirrors the parameter of the same name on ``text2mesh``).

    Returns:
        Always ``0``.

    Raises:
        ValueError: If ``args.image_path`` is not set.
    """
    # Fail before any expensive work if the required input is missing.
    if args.image_path is None:
        raise ValueError("--image_path is required for the image2mesh task")

    sample_inputs = {
        "image": prepare_image(args, number_samples=num_samples)
    }
    mesh_outputs = model.sample(
        sample_inputs,
        sample_times=1,
        guidance_scale=guidance_scale,
        return_intermediates=False,
        bounds=[-box_v, -box_v, -box_v, box_v, box_v, box_v],
        octree_depth=octree_depth,
    )[0]  # only the first element of the returned sequence is saved

    save_output(args, mesh_outputs)
    return 0
def text2mesh(args, model, num_samples=2, guidance_scale=7.5, box_v=1.1, octree_depth=7):
    """Generate meshes conditioned on a text prompt and save them.

    Args:
        args: Namespace providing ``text`` and ``output_dir``.
        model: Loaded model exposing ``sample``.
        num_samples: Number of copies of the prompt in the conditioning batch.
        guidance_scale: Guidance scale forwarded to ``model.sample``.
        box_v: Half-extent of the cubic bounding box passed as ``bounds``.
        octree_depth: Resolution parameter forwarded to ``model.sample``.

    Returns:
        Always ``0``.

    Raises:
        ValueError: If ``args.text`` is not set.
    """
    # Without this check a missing --text would send [None, None] to the model.
    if args.text is None:
        raise ValueError("--text is required for the text2mesh task")

    sample_inputs = {
        "text": [args.text] * num_samples
    }
    mesh_outputs = model.sample(
        sample_inputs,
        sample_times=1,
        guidance_scale=guidance_scale,
        return_intermediates=False,
        bounds=[-box_v, -box_v, -box_v, box_v, box_v, box_v],
        octree_depth=octree_depth,
    )[0]  # only the first element of the returned sequence is saved

    save_output(args, mesh_outputs)
    return 0
# Maps each value accepted by --task to the function that implements it.
task_dict = {
    'reconstruction': reconstruction,
    'image2mesh': image2mesh,
    'text2mesh': text2mesh,
}
# Backward-compatible alias for the original, misspelled name.
task_dick = task_dict

if __name__ == "__main__":
    # Supported tasks:
    #   1. Reconstruct point cloud
    #   2. Image-conditioned generation
    #   3. Text-conditioned generation
    parser = argparse.ArgumentParser()
    # Derive the choices from the dispatch table so the two cannot drift apart.
    parser.add_argument("--task", type=str, choices=list(task_dict), required=True)
    parser.add_argument("--config_path", type=str, required=True)
    parser.add_argument("--ckpt_path", type=str, required=True)
    parser.add_argument("--pointcloud_path", type=str, default='./example_data/surface.npz', help='Path to the input point cloud')
    parser.add_argument("--image_path", type=str, help='Path to the input image')
    parser.add_argument("--text", type=str, help='Input text within a format: A 3D model of motorcar; Porsche 911.')
    parser.add_argument("--output_dir", type=str, default='./output')
    parser.add_argument("-s", "--seed", type=int, default=0)
    # ``args`` is intentionally kept as a module-level name: other code in
    # this file may read it as a global.
    args = parser.parse_args()

    # Reject a missing task-specific input before loading the checkpoint.
    if args.task == 'image2mesh' and args.image_path is None:
        parser.error("--image_path is required for the image2mesh task")
    if args.task == 'text2mesh' and args.text is None:
        parser.error("--text is required for the text2mesh task")

    pl.seed_everything(args.seed)

    print('-----------------------------------------------------------------------------')
    print(f'>>> Running {args.task}')
    # Each task writes into its own sub-directory of the requested output dir.
    args.output_dir = os.path.join(args.output_dir, args.task)
    print(f'>>> Output directory: {args.output_dir}')
    print('-----------------------------------------------------------------------------')

    task_dict[args.task](args, load_model(args))