import os
from typing import Any, List, Optional, Union

import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
from omegaconf import OmegaConf
from huggingface_hub import hf_hub_download

from direct3d_s2.modules import sparse as sp
from direct3d_s2.utils import (
    instantiate_from_config,
    preprocess_image,
    sort_block,
    extract_tokens_and_coords,
    normalize_mesh,
    mesh2index,
)

class Direct3DS2Pipeline(object):
    """Image-to-3D pipeline that cascades a dense diffusion stage, a sparse
    512-resolution stage, an optional sparse 1024-resolution stage, and a
    mesh refiner."""

    def __init__(self,
                 dense_vae,
                 dense_dit,
                 sparse_vae_512,
                 sparse_dit_512,
                 sparse_vae_1024,
                 sparse_dit_1024,
                 refiner,
                 dense_image_encoder,
                 sparse_image_encoder,
                 dense_scheduler,
                 sparse_scheduler_512,
                 sparse_scheduler_1024,
                 dtype=torch.float16,
                 ):
        self.dense_vae = dense_vae
        self.dense_dit = dense_dit
        self.sparse_vae_512 = sparse_vae_512
        self.sparse_dit_512 = sparse_dit_512
        self.sparse_vae_1024 = sparse_vae_1024
        self.sparse_dit_1024 = sparse_dit_1024
        self.refiner = refiner
        self.dense_image_encoder = dense_image_encoder
        self.sparse_image_encoder = sparse_image_encoder
        self.dense_scheduler = dense_scheduler
        self.sparse_scheduler_512 = sparse_scheduler_512
        self.sparse_scheduler_1024 = sparse_scheduler_1024
        self.dtype = dtype

    def to(self, device):
        self.device = torch.device(device)
        self.dense_vae.to(device)
        self.dense_dit.to(device)
        self.sparse_vae_512.to(device)
        self.sparse_dit_512.to(device)
        self.sparse_vae_1024.to(device)
        self.sparse_dit_1024.to(device)
        self.refiner.to(device)
        self.dense_image_encoder.to(device)
        self.sparse_image_encoder.to(device)

    @classmethod
    def from_pretrained(cls, pipeline_path, subfolder="direct3d-s2-v-1-1"):
        """Build the pipeline from a local directory or a Hugging Face repo id."""
        if os.path.isdir(pipeline_path):
            config_path = os.path.join(pipeline_path, 'config.yaml')
            model_dense_path = os.path.join(pipeline_path, 'model_dense.ckpt')
            model_sparse_512_path = os.path.join(pipeline_path, 'model_sparse_512.ckpt')
            model_sparse_1024_path = os.path.join(pipeline_path, 'model_sparse_1024.ckpt')
            model_refiner_path = os.path.join(pipeline_path, 'model_refiner.ckpt')
        else:
            config_path = hf_hub_download(
                repo_id=pipeline_path,
                subfolder=subfolder,
                filename="config.yaml",
                repo_type="model",
            )
            model_dense_path = hf_hub_download(
                repo_id=pipeline_path,
                subfolder=subfolder,
                filename="model_dense.ckpt",
                repo_type="model",
            )
            model_sparse_512_path = hf_hub_download(
                repo_id=pipeline_path,
                subfolder=subfolder,
                filename="model_sparse_512.ckpt",
                repo_type="model",
            )
            model_sparse_1024_path = hf_hub_download(
                repo_id=pipeline_path,
                subfolder=subfolder,
                filename="model_sparse_1024.ckpt",
                repo_type="model",
            )
            model_refiner_path = hf_hub_download(
                repo_id=pipeline_path,
                subfolder=subfolder,
                filename="model_refiner.ckpt",
                repo_type="model",
            )

        cfg = OmegaConf.load(config_path)

        state_dict_dense = torch.load(model_dense_path, map_location='cpu', weights_only=True)
        dense_vae = instantiate_from_config(cfg.dense_vae)
        dense_vae.load_state_dict(state_dict_dense["vae"], strict=True)
        dense_vae.eval()
        dense_dit = instantiate_from_config(cfg.dense_dit)
        dense_dit.load_state_dict(state_dict_dense["dit"], strict=True)
        dense_dit.eval()

        state_dict_sparse_512 = torch.load(model_sparse_512_path, map_location='cpu', weights_only=True)
        sparse_vae_512 = instantiate_from_config(cfg.sparse_vae_512)
        sparse_vae_512.load_state_dict(state_dict_sparse_512["vae"], strict=True)
        sparse_vae_512.eval()
        sparse_dit_512 = instantiate_from_config(cfg.sparse_dit_512)
        sparse_dit_512.load_state_dict(state_dict_sparse_512["dit"], strict=True)
        sparse_dit_512.eval()

        state_dict_sparse_1024 = torch.load(model_sparse_1024_path, map_location='cpu', weights_only=True)
        sparse_vae_1024 = instantiate_from_config(cfg.sparse_vae_1024)
        sparse_vae_1024.load_state_dict(state_dict_sparse_1024["vae"], strict=True)
        sparse_vae_1024.eval()
        sparse_dit_1024 = instantiate_from_config(cfg.sparse_dit_1024)
        sparse_dit_1024.load_state_dict(state_dict_sparse_1024["dit"], strict=True)
        sparse_dit_1024.eval()

        state_dict_refiner = torch.load(model_refiner_path, map_location='cpu', weights_only=True)
        refiner = instantiate_from_config(cfg.refiner)
        refiner.load_state_dict(state_dict_refiner["refiner"], strict=True)
        refiner.eval()

        dense_image_encoder = instantiate_from_config(cfg.dense_image_encoder)
        sparse_image_encoder = instantiate_from_config(cfg.sparse_image_encoder)
        dense_scheduler = instantiate_from_config(cfg.dense_scheduler)
        sparse_scheduler_512 = instantiate_from_config(cfg.sparse_scheduler_512)
        sparse_scheduler_1024 = instantiate_from_config(cfg.sparse_scheduler_1024)

        return cls(
            dense_vae=dense_vae,
            dense_dit=dense_dit,
            sparse_vae_512=sparse_vae_512,
            sparse_dit_512=sparse_dit_512,
            sparse_vae_1024=sparse_vae_1024,
            sparse_dit_1024=sparse_dit_1024,
            dense_image_encoder=dense_image_encoder,
            sparse_image_encoder=sparse_image_encoder,
            dense_scheduler=dense_scheduler,
            sparse_scheduler_512=sparse_scheduler_512,
            sparse_scheduler_1024=sparse_scheduler_1024,
            refiner=refiner,
        )
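
    # Loading sketch for the method above (the repo id below is an
    # illustrative assumption; this file does not pin a specific checkpoint):
    #
    #   pipeline = Direct3DS2Pipeline.from_pretrained(
    #       "wushuang98/Direct3D-S2", subfolder="direct3d-s2-v-1-1")
    #   pipeline.to("cuda")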

    def preprocess(self, image):
        if image.mode == 'RGBA':
            # RGBA inputs already carry an alpha matte.
            image = np.array(image)
        else:
            # Otherwise segment the foreground with BiRefNet (lazily loaded).
            if getattr(self, 'birefnet_model', None) is None:
                from direct3d_s2.utils import BiRefNet
                self.birefnet_model = BiRefNet(self.device)
            image = self.birefnet_model.run(image)
        image = preprocess_image(image)
        return image

    def prepare_image(self, image: Union[str, List[str], Image.Image, List[Image.Image]]):
        if not isinstance(image, list):
            image = [image]
        if isinstance(image[0], str):
            image = [Image.open(img) for img in image]
        image = [self.preprocess(img) for img in image]
        image = torch.stack(image).to(self.device)
        return image

    def encode_image(self, image: torch.Tensor, conditioner: Any,
                     do_classifier_free_guidance: bool = True, use_mask: bool = False):
        if use_mask:
            cond = conditioner(image[:, :3], image[:, 3:])
        else:
            cond = conditioner(image[:, :3])

        if isinstance(cond, tuple):
            cond, cond_mask = cond
            cond, cond_coords = extract_tokens_and_coords(cond, cond_mask)
        else:
            cond_mask, cond_coords = None, None

        # The unconditional branch for classifier-free guidance is an
        # all-zero embedding with the same shape as the conditional one.
        if do_classifier_free_guidance:
            uncond = torch.zeros_like(cond)
        else:
            uncond = None

        if cond_coords is not None:
            cond = sp.SparseTensor(cond, cond_coords.int())
            if uncond is not None:
                uncond = sp.SparseTensor(uncond, cond_coords.int())
        return cond, uncond

    def inference(
            self,
            image,
            vae,
            dit,
            conditioner,
            scheduler,
            num_inference_steps: int = 30,
            guidance_scale: float = 7.0,
            generator: Optional[torch.Generator] = None,
            latent_index: torch.Tensor = None,
            mode: str = 'dense',  # 'dense', 'sparse512' or 'sparse1024'
            remove_interior: bool = False,
            mc_threshold: float = 0.02):

        do_classifier_free_guidance = guidance_scale > 0

        if mode == 'dense':
            sparse_conditions = False
        else:
            sparse_conditions = dit.sparse_conditions
        cond, uncond = self.encode_image(image, conditioner,
                                         do_classifier_free_guidance, sparse_conditions)

        batch_size = cond.shape[0]
        if mode == 'dense':
            latent_shape = (batch_size, *dit.latent_shape)
        else:
            # Sparse stages denoise one feature vector per occupied voxel.
            latent_shape = (len(latent_index), dit.out_channels)
        latents = torch.randn(latent_shape, dtype=self.dtype, device=self.device, generator=generator)

        scheduler.set_timesteps(num_inference_steps, device=self.device)
        timesteps = scheduler.timesteps
        extra_step_kwargs = {
            "generator": generator
        }

        for i, t in enumerate(tqdm(timesteps, desc=f"{mode} Sampling:")):
            latent_model_input = latents
            timestep_tensor = torch.tensor([t], dtype=latent_model_input.dtype, device=self.device)

            if mode == 'dense':
                x_input = latent_model_input
            elif mode in ['sparse512', 'sparse1024']:
                x_input = sp.SparseTensor(latent_model_input, latent_index.int())

            diffusion_inputs = {
                "x": x_input,
                "t": timestep_tensor,
                "cond": cond,
            }

            noise_pred_cond = dit(**diffusion_inputs)
            if mode != 'dense':
                noise_pred_cond = noise_pred_cond.feats

            if do_classifier_free_guidance:
                diffusion_inputs["cond"] = uncond
                noise_pred_uncond = dit(**diffusion_inputs)
                if mode != 'dense':
                    noise_pred_uncond = noise_pred_uncond.feats
                # Classifier-free guidance: extrapolate away from the
                # unconditional prediction by the guidance scale.
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
            else:
                noise_pred = noise_pred_cond

            latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

        # Undo the VAE latent normalization before decoding.
        latents = 1. / vae.latents_scale * latents + vae.latents_shift

        if mode != 'dense':
            latents = sp.SparseTensor(latents, latent_index.int())
        decoder_inputs = {
            "latents": latents,
            "mc_threshold": mc_threshold,
        }
        if mode == 'dense':
            # The dense stage returns occupied-voxel indices for the next stage.
            decoder_inputs['return_index'] = True
        elif remove_interior:
            decoder_inputs['return_feat'] = True
        if mode == 'sparse1024':
            decoder_inputs['voxel_resolution'] = 1024

        outputs = vae.decode_mesh(**decoder_inputs)

        if remove_interior:
            # Free sampling tensors before running the memory-hungry refiner.
            # noise_pred_uncond is only bound when guidance is enabled.
            del latents, noise_pred, noise_pred_cond, x_input, cond, uncond
            if do_classifier_free_guidance:
                del noise_pred_uncond
            torch.cuda.empty_cache()
            outputs = self.refiner.run(*outputs, mc_threshold=mc_threshold * 2.0)

        return outputs

    def __call__(
            self,
            image: Union[str, List[str], Image.Image, List[Image.Image]] = None,
            sdf_resolution: int = 1024,
            dense_sampler_params: dict = {'num_inference_steps': 50, 'guidance_scale': 7.0},
            sparse_512_sampler_params: dict = {'num_inference_steps': 30, 'guidance_scale': 7.0},
            sparse_1024_sampler_params: dict = {'num_inference_steps': 15, 'guidance_scale': 7.0},
            generator: Optional[torch.Generator] = None,
            remesh: bool = False,
            simplify_ratio: float = 0.95,
            mc_threshold: float = 0.2):

        image = self.prepare_image(image)

        # Stage 1: dense diffusion predicts which voxels are occupied.
        latent_index = self.inference(image, self.dense_vae, self.dense_dit, self.dense_image_encoder,
                                      self.dense_scheduler, generator=generator, mode='dense',
                                      mc_threshold=0.1, **dense_sampler_params)[0]
        latent_index = sort_block(latent_index, self.sparse_dit_512.selection_block_size)
        torch.cuda.empty_cache()

        # Stage 2: sparse diffusion at 512 resolution. Interior removal is only
        # needed when the 1024 stage will re-voxelize the result afterwards.
        if sdf_resolution == 512:
            remove_interior = False
        else:
            remove_interior = True
        mesh = self.inference(image, self.sparse_vae_512, self.sparse_dit_512,
                              self.sparse_image_encoder, self.sparse_scheduler_512,
                              generator=generator, mode='sparse512',
                              mc_threshold=mc_threshold, latent_index=latent_index,
                              remove_interior=remove_interior, **sparse_512_sampler_params)[0]

        # Stage 3 (optional): refine at 1024 resolution on voxels sampled from
        # the 512 mesh.
        if sdf_resolution == 1024:
            del latent_index
            torch.cuda.empty_cache()
            mesh = normalize_mesh(mesh)
            latent_index = mesh2index(mesh, size=1024, factor=8)
            latent_index = sort_block(latent_index, self.sparse_dit_1024.selection_block_size)
            print(f"number of latent tokens: {len(latent_index)}")
            mesh = self.inference(image, self.sparse_vae_1024, self.sparse_dit_1024,
                                  self.sparse_image_encoder, self.sparse_scheduler_1024,
                                  generator=generator, mode='sparse1024',
                                  mc_threshold=mc_threshold, latent_index=latent_index,
                                  **sparse_1024_sampler_params)[0]

        if remesh:
            import trimesh
            from direct3d_s2.utils import postprocess_mesh
            filled_mesh = postprocess_mesh(
                vertices=mesh.vertices,
                faces=mesh.faces,
                simplify=True,
                simplify_ratio=simplify_ratio,
                verbose=True,
            )
            mesh = trimesh.Trimesh(filled_mesh[0], filled_mesh[1])

        outputs = {"mesh": mesh}
        return outputs
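

if __name__ == "__main__":
    # Minimal end-to-end sketch. The checkpoint id and file paths are
    # illustrative assumptions, not values taken from this module; requires a
    # CUDA device and the Direct3D-S2 dependencies.
    pipeline = Direct3DS2Pipeline.from_pretrained("wushuang98/Direct3D-S2")
    pipeline.to("cuda")
    result = pipeline(image="assets/example.png", sdf_resolution=1024, remesh=True)
    result["mesh"].export("output.obj")  # a trimesh.Trimesh when remesh=True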