import argparse
import os
import sys
import time
from glob import glob
from typing import Any, List, Tuple, Union

import numpy as np
import torch
import trimesh
from accelerate.utils import set_seed
from huggingface_hub import snapshot_download
from PIL import Image

from src.models.briarmbg import BriaRMBG
from src.pipelines.pipeline_partcrafter import PartCrafterPipeline
from src.utils.data_utils import get_colored_mesh_composition, scene_to_parts, load_surfaces
from src.utils.image_utils import prepare_image
from src.utils.render_utils import (
    render_views_around_mesh,
    render_normal_views_around_mesh,
    make_grid_for_images_or_videos,
    export_renderings,
)


@torch.no_grad()
def run_triposg(
    pipe: Any,
    image_input: Union[str, Image.Image],
    num_parts: int,
    rmbg_net: Any,
    seed: int,
    num_tokens: int = 1024,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.0,
    max_num_expanded_coords: int = int(1e9),
    use_flash_decoder: bool = False,
    rmbg: bool = False,
    dtype: torch.dtype = torch.float16,
    device: str = "cuda",
) -> Tuple[List[trimesh.Trimesh], Image.Image]:
    """Generate `num_parts` part meshes from a single image and return them
    together with the preprocessed input image."""
    if rmbg:
        # Remove the background and composite the object onto white
        img_pil = prepare_image(image_input, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net)
    else:
        img_pil = Image.open(image_input)

    start_time = time.time()
    outputs = pipe(
        image=[img_pil] * num_parts,
        attention_kwargs={"num_parts": num_parts},
        num_tokens=num_tokens,
        generator=torch.Generator(device=pipe.device).manual_seed(seed),
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        max_num_expanded_coords=max_num_expanded_coords,
        use_flash_decoder=use_flash_decoder,
    ).meshes
    end_time = time.time()
    print(f"Time elapsed: {end_time - start_time:.2f} seconds")

    for i in range(len(outputs)):
        if outputs[i] is None:
            # If the generated mesh is None (decoding error), use a dummy mesh
            outputs[i] = trimesh.Trimesh(vertices=[[0, 0, 0]], faces=[[0, 0, 0]])
    return outputs, img_pil


MAX_NUM_PARTS = 16

if __name__ == "__main__":
    device = "cuda"
    dtype = torch.float16

    parser = argparse.ArgumentParser()
    parser.add_argument("--image_path", type=str, required=True)
    parser.add_argument("--num_parts", type=int, required=True, help="number of parts to generate")
    parser.add_argument("--output_dir", type=str, default="./results")
    parser.add_argument("--tag", type=str, default=None)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--num_tokens", type=int, default=1024)
    parser.add_argument("--num_inference_steps", type=int, default=50)
    parser.add_argument("--guidance_scale", type=float, default=7.0)
    parser.add_argument("--max_num_expanded_coords", type=int, default=int(1e9))
    parser.add_argument("--use_flash_decoder", action="store_true")
    parser.add_argument("--rmbg", action="store_true")
    parser.add_argument("--render", action="store_true")
    args = parser.parse_args()

    assert 1 <= args.num_parts <= MAX_NUM_PARTS, f"num_parts must be in [1, {MAX_NUM_PARTS}]"

    # download pretrained weights
    partcrafter_weights_dir = "pretrained_weights/PartCrafter"
    rmbg_weights_dir = "pretrained_weights/RMBG-1.4"
    snapshot_download(repo_id="wgsxm/PartCrafter", local_dir=partcrafter_weights_dir)
    snapshot_download(repo_id="briaai/RMBG-1.4", local_dir=rmbg_weights_dir)

    # init RMBG model for background removal
    rmbg_net = BriaRMBG.from_pretrained(rmbg_weights_dir).to(device)
    rmbg_net.eval()

    # init PartCrafter pipeline
    pipe: PartCrafterPipeline = PartCrafterPipeline.from_pretrained(partcrafter_weights_dir).to(device, dtype)

    set_seed(args.seed)

    # run inference
    outputs, processed_image = run_triposg(
        pipe,
        image_input=args.image_path,
        num_parts=args.num_parts,
        rmbg_net=rmbg_net,
        seed=args.seed,
        num_tokens=args.num_tokens,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        max_num_expanded_coords=args.max_num_expanded_coords,
        use_flash_decoder=args.use_flash_decoder,
        rmbg=args.rmbg,
        dtype=dtype,
        device=device,
    )

    # prepare the export directory (default tag: current timestamp)
    os.makedirs(args.output_dir, exist_ok=True)
    if args.tag is None:
        args.tag = time.strftime("%Y%m%d_%H_%M_%S")
    export_dir = os.path.join(args.output_dir, args.tag)
    os.makedirs(export_dir, exist_ok=True)

    # export each generated part, then a color-coded merged mesh
    for i, mesh in enumerate(outputs):
        mesh.export(os.path.join(export_dir, f"part_{i:02}.glb"))
    merged_mesh = get_colored_mesh_composition(outputs)
    merged_mesh.export(os.path.join(export_dir, "object.glb"))
    print(f"Generated {len(outputs)} parts and saved to {export_dir}")

    if args.render:
        print("Start rendering...")
        num_views = 36
        radius = 4
        fps = 18
        rendered_images = render_views_around_mesh(
            merged_mesh,
            num_views=num_views,
            radius=radius,
        )
        rendered_normals = render_normal_views_around_mesh(
            merged_mesh,
            num_views=num_views,
            radius=radius,
        )
        # grid rows: input image, RGB renderings, normal renderings
        rendered_grids = make_grid_for_images_or_videos(
            [
                [processed_image] * num_views,
                rendered_images,
                rendered_normals,
            ],
            nrow=3,
        )
        export_renderings(
            rendered_images,
            os.path.join(export_dir, "rendering.gif"),
            fps=fps,
        )
        export_renderings(
            rendered_normals,
            os.path.join(export_dir, "rendering_normal.gif"),
            fps=fps,
        )
        export_renderings(
            rendered_grids,
            os.path.join(export_dir, "rendering_grid.gif"),
            fps=fps,
        )
        # save the first frame of each sequence as a still image
        rendered_image, rendered_normal, rendered_grid = rendered_images[0], rendered_normals[0], rendered_grids[0]
        rendered_image.save(os.path.join(export_dir, "rendering.png"))
        rendered_normal.save(os.path.join(export_dir, "rendering_normal.png"))
        rendered_grid.save(os.path.join(export_dir, "rendering_grid.png"))
        print("Rendering done.")
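
# Example invocation, shown as a sketch: the script filename and input image
# path below are placeholders, not part of the repository; the flags match the
# argparse definitions above.
#
#   python inference_partcrafter.py \
#       --image_path assets/chair.png \
#       --num_parts 4 \
#       --rmbg --render
#
# This removes the input image's background, generates four part meshes
# (part_00.glb ... part_03.glb) plus a color-coded merged object.glb under
# ./results/<timestamp>/, and exports turntable GIF and PNG renderings.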