PartCrafter / scripts /inference_partcrafter.py
alexnasa's picture
Upload 85 files
bef5729 verified
import argparse
import os
import sys
import time
from glob import glob
from typing import Any, List, Tuple, Union

import numpy as np
import torch
import trimesh
from accelerate.utils import set_seed
from huggingface_hub import snapshot_download
from PIL import Image

from src.models.briarmbg import BriaRMBG
from src.pipelines.pipeline_partcrafter import PartCrafterPipeline
from src.utils.data_utils import get_colored_mesh_composition, scene_to_parts, load_surfaces
from src.utils.image_utils import prepare_image
from src.utils.render_utils import render_views_around_mesh, render_normal_views_around_mesh, make_grid_for_images_or_videos, export_renderings
@torch.no_grad()
def run_triposg(
    pipe: Any,
    image_input: Union[str, Image.Image],
    num_parts: int,
    rmbg_net: Any,
    seed: int,
    num_tokens: int = 1024,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.0,
    max_num_expanded_coords: int = int(1e9),
    use_flash_decoder: bool = False,
    rmbg: bool = False,
    dtype: torch.dtype = torch.float16,
    device: str = "cuda",
) -> Tuple[List[trimesh.Trimesh], Image.Image]:
    """Generate ``num_parts`` part meshes from a single image.

    The same conditioning image is passed to the pipeline once per part, and
    ``num_parts`` is forwarded through ``attention_kwargs``.

    Args:
        pipe: Loaded PartCrafter pipeline. It must expose ``.device`` and,
            when called, return an object with a ``.meshes`` attribute.
        image_input: Path to the input image, or an already-loaded PIL image.
        num_parts: Number of parts to generate.
        rmbg_net: Background-removal network handed to ``prepare_image``.
            Only used when ``rmbg`` is True, so it may be ``None`` otherwise.
        seed: Seed for the ``torch.Generator`` created on ``pipe.device``.
        num_tokens, num_inference_steps, guidance_scale,
        max_num_expanded_coords, use_flash_decoder: Forwarded unchanged to
            the pipeline call.
        rmbg: If True, preprocess the image with ``prepare_image`` using a
            white ``bg_color`` (presumably removing the background with
            ``rmbg_net``; see ``prepare_image``). Otherwise the image is
            used as-is.
        dtype, device: Not used by this function; kept so existing callers
            keep working. The generator uses ``pipe.device``.

    Returns:
        ``(meshes, image)``: one mesh per generated part and the PIL image
        that was fed to the pipeline. Parts the pipeline returned as ``None``
        (decoding errors) are replaced by a one-vertex, degenerate-face
        placeholder mesh.
    """
    if rmbg:
        img_pil = prepare_image(image_input, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net)
    elif isinstance(image_input, Image.Image):
        # Already decoded; Image.open() only accepts paths / file objects.
        img_pil = image_input
    else:
        img_pil = Image.open(image_input)

    start_time = time.perf_counter()  # monotonic clock for interval timing
    meshes = pipe(
        image=[img_pil] * num_parts,
        attention_kwargs={"num_parts": num_parts},
        num_tokens=num_tokens,
        generator=torch.Generator(device=pipe.device).manual_seed(seed),
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        max_num_expanded_coords=max_num_expanded_coords,
        use_flash_decoder=use_flash_decoder,
    ).meshes
    print(f"Time elapsed: {time.perf_counter() - start_time:.2f} seconds")

    # A None entry means that part's mesh failed to decode; substitute a dummy
    # mesh so later export / composition always receives a mesh object.
    meshes = [
        mesh if mesh is not None else trimesh.Trimesh(vertices=[[0, 0, 0]], faces=[[0, 0, 0]])
        for mesh in meshes
    ]
    return meshes, img_pil
MAX_NUM_PARTS = 16  # upper bound accepted for --num_parts


def _parse_args() -> argparse.Namespace:
    """Parse and validate the command-line arguments for this script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--image_path", type=str, required=True)
    parser.add_argument("--num_parts", type=int, required=True, help="number of parts to generate")
    parser.add_argument("--output_dir", type=str, default="./results")
    parser.add_argument("--tag", type=str, default=None)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--num_tokens", type=int, default=1024)
    parser.add_argument("--num_inference_steps", type=int, default=50)
    parser.add_argument("--guidance_scale", type=float, default=7.0)
    # argparse applies ``type`` only to string defaults, so a bare ``1e9``
    # would reach the pipeline as a float; spell the default as an int.
    parser.add_argument("--max_num_expanded_coords", type=int, default=int(1e9))
    parser.add_argument("--use_flash_decoder", action="store_true")
    parser.add_argument("--rmbg", action="store_true")
    parser.add_argument("--render", action="store_true")
    args = parser.parse_args()
    # parser.error() prints usage and exits; an assert would vanish under ``python -O``.
    if not 1 <= args.num_parts <= MAX_NUM_PARTS:
        parser.error(f"num_parts must be in [1, {MAX_NUM_PARTS}]")
    return args


def _render_and_export(
    merged_mesh: Any,
    processed_image: Any,
    export_dir: str,
    num_views: int = 36,
    radius: float = 4,
    fps: int = 18,
) -> None:
    """Render ``num_views`` views around ``merged_mesh`` and save them to ``export_dir``.

    Writes ``rendering``, ``rendering_normal`` and ``rendering_grid`` as a
    ``.gif`` (all views) and a ``.png`` (first view each). The grid combines
    the repeated input image with the color and normal renders (``nrow=3``).
    """
    print("Start rendering...")
    rendered_images = render_views_around_mesh(merged_mesh, num_views=num_views, radius=radius)
    rendered_normals = render_normal_views_around_mesh(merged_mesh, num_views=num_views, radius=radius)
    rendered_grids = make_grid_for_images_or_videos(
        [
            [processed_image] * num_views,
            rendered_images,
            rendered_normals,
        ],
        nrow=3,
    )
    outputs_by_name = (
        ("rendering", rendered_images),
        ("rendering_normal", rendered_normals),
        ("rendering_grid", rendered_grids),
    )
    for name, frames in outputs_by_name:
        export_renderings(frames, os.path.join(export_dir, f"{name}.gif"), fps=fps)
    for name, frames in outputs_by_name:
        frames[0].save(os.path.join(export_dir, f"{name}.png"))
    print("Rendering done.")


def main() -> None:
    """CLI entry point: generate part meshes from one image and export them (optionally rendered)."""
    device = "cuda"
    dtype = torch.float16
    args = _parse_args()

    # download pretrained weights
    partcrafter_weights_dir = "pretrained_weights/PartCrafter"
    snapshot_download(repo_id="wgsxm/PartCrafter", local_dir=partcrafter_weights_dir)

    # run_triposg only consults rmbg_net when rmbg=True, so skip downloading
    # and loading the background-removal model unless --rmbg was given.
    rmbg_net = None
    if args.rmbg:
        rmbg_weights_dir = "pretrained_weights/RMBG-1.4"
        snapshot_download(repo_id="briaai/RMBG-1.4", local_dir=rmbg_weights_dir)
        rmbg_net = BriaRMBG.from_pretrained(rmbg_weights_dir).to(device)
        rmbg_net.eval()

    # init PartCrafter pipeline
    pipe: PartCrafterPipeline = PartCrafterPipeline.from_pretrained(partcrafter_weights_dir).to(device, dtype)
    set_seed(args.seed)

    # run inference
    outputs, processed_image = run_triposg(
        pipe,
        image_input=args.image_path,
        num_parts=args.num_parts,
        rmbg_net=rmbg_net,
        seed=args.seed,
        num_tokens=args.num_tokens,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        max_num_expanded_coords=args.max_num_expanded_coords,
        use_flash_decoder=args.use_flash_decoder,
        rmbg=args.rmbg,
        dtype=dtype,
        device=device,
    )

    if args.tag is None:
        args.tag = time.strftime("%Y%m%d_%H_%M_%S")
    export_dir = os.path.join(args.output_dir, args.tag)
    # makedirs also creates output_dir, so no separate existence check is needed.
    os.makedirs(export_dir, exist_ok=True)
    for i, mesh in enumerate(outputs):
        mesh.export(os.path.join(export_dir, f"part_{i:02}.glb"))
    merged_mesh = get_colored_mesh_composition(outputs)
    merged_mesh.export(os.path.join(export_dir, "object.glb"))
    print(f"Generated {len(outputs)} parts and saved to {export_dir}")

    if args.render:
        _render_and_export(merged_mesh, processed_image, export_dir)


if __name__ == "__main__":
    main()