# Spaces: Running on Zero | File size: 6,268 Bytes | bef5729
import argparse
import os
import sys
from glob import glob
import time
from typing import Any, Union
import numpy as np
import torch
import trimesh
from huggingface_hub import snapshot_download
from PIL import Image
from accelerate.utils import set_seed
from src.utils.data_utils import get_colored_mesh_composition, scene_to_parts, load_surfaces
from src.utils.render_utils import render_views_around_mesh, render_normal_views_around_mesh, make_grid_for_images_or_videos, export_renderings
from src.pipelines.pipeline_partcrafter import PartCrafterPipeline
from src.utils.image_utils import prepare_image
from src.models.briarmbg import BriaRMBG
@torch.no_grad()
def run_triposg(
    pipe: Any,
    image_input: Union[str, Image.Image],
    num_parts: int,
    rmbg_net: Any,
    seed: int,
    num_tokens: int = 1024,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.0,
    max_num_expanded_coords: int = 1_000_000_000,
    use_flash_decoder: bool = False,
    rmbg: bool = False,
    dtype: torch.dtype = torch.float16,
    device: str = "cuda",
) -> "tuple[list[trimesh.Trimesh], Image.Image]":
    """Generate part meshes for one input image with the given pipeline.

    Args:
        pipe: Callable pipeline. It is invoked with the keyword arguments
            below and must return an object exposing ``.meshes``; its
            ``.device`` attribute is used to place the random generator.
        image_input: Path to an image file, or an already-loaded PIL image.
        num_parts: Number of parts to generate. The image is replicated
            this many times in the batch handed to ``pipe``.
        rmbg_net: Background-removal network forwarded to ``prepare_image``
            (only used when ``rmbg`` is true).
        seed: Seed for the ``torch.Generator`` passed to ``pipe``.
        num_tokens: Forwarded to ``pipe``.
        num_inference_steps: Forwarded to ``pipe``.
        guidance_scale: Forwarded to ``pipe``.
        max_num_expanded_coords: Forwarded to ``pipe``.
        use_flash_decoder: Forwarded to ``pipe``.
        rmbg: If true, preprocess the input with ``prepare_image`` on a
            white background; otherwise use the image as-is.
        dtype: Accepted for interface compatibility; not read by this
            function.
        device: Accepted for interface compatibility; not read by this
            function (the generator is created on ``pipe.device``).

    Returns:
        A ``(meshes, image)`` tuple: the sequence of meshes produced by the
        pipeline, with any ``None`` entry replaced by a degenerate
        placeholder mesh, and the PIL image that was fed to the pipeline.
    """
    if rmbg:
        img_pil = prepare_image(
            image_input,
            bg_color=np.array([1.0, 1.0, 1.0]),
            rmbg_net=rmbg_net,
        )
    elif isinstance(image_input, Image.Image):
        # Already a loaded image: Image.open() only accepts paths/file objects.
        img_pil = image_input
    else:
        img_pil = Image.open(image_input)

    # perf_counter is monotonic, so the elapsed time cannot go negative if
    # the wall clock is adjusted while the pipeline runs.
    start_time = time.perf_counter()
    outputs = pipe(
        image=[img_pil] * num_parts,
        attention_kwargs={"num_parts": num_parts},
        num_tokens=num_tokens,
        generator=torch.Generator(device=pipe.device).manual_seed(seed),
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        max_num_expanded_coords=max_num_expanded_coords,
        use_flash_decoder=use_flash_decoder,
    ).meshes
    end_time = time.perf_counter()
    print(f"Time elapsed: {end_time - start_time:.2f} seconds")

    for i, mesh in enumerate(outputs):
        if mesh is None:
            # If the generated mesh is None (decoding error), use a dummy
            # mesh so downstream export code still receives a mesh object.
            outputs[i] = trimesh.Trimesh(vertices=[[0, 0, 0]], faces=[[0, 0, 0]])
    return outputs, img_pil
MAX_NUM_PARTS = 16
if __name__ == "__main__":
device = "cuda"
dtype = torch.float16
parser = argparse.ArgumentParser()
parser.add_argument("--image_path", type=str, required=True)
parser.add_argument("--num_parts", type=int, required=True, help="number of parts to generate")
parser.add_argument("--output_dir", type=str, default="./results")
parser.add_argument("--tag", type=str, default=None)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--num_tokens", type=int, default=1024)
parser.add_argument("--num_inference_steps", type=int, default=50)
parser.add_argument("--guidance_scale", type=float, default=7.0)
parser.add_argument("--max_num_expanded_coords", type=int, default=1e9)
parser.add_argument("--use_flash_decoder", action="store_true")
parser.add_argument("--rmbg", action="store_true")
parser.add_argument("--render", action="store_true")
args = parser.parse_args()
assert 1 <= args.num_parts <= MAX_NUM_PARTS, f"num_parts must be in [1, {MAX_NUM_PARTS}]"
# download pretrained weights
partcrafter_weights_dir = "pretrained_weights/PartCrafter"
rmbg_weights_dir = "pretrained_weights/RMBG-1.4"
snapshot_download(repo_id="wgsxm/PartCrafter", local_dir=partcrafter_weights_dir)
snapshot_download(repo_id="briaai/RMBG-1.4", local_dir=rmbg_weights_dir)
# init rmbg model for background removal
rmbg_net = BriaRMBG.from_pretrained(rmbg_weights_dir).to(device)
rmbg_net.eval()
# init tripoSG pipeline
pipe: PartCrafterPipeline = PartCrafterPipeline.from_pretrained(partcrafter_weights_dir).to(device, dtype)
set_seed(args.seed)
# run inference
outputs, processed_image = run_triposg(
pipe,
image_input=args.image_path,
num_parts=args.num_parts,
rmbg_net=rmbg_net,
seed=args.seed,
num_tokens=args.num_tokens,
num_inference_steps=args.num_inference_steps,
guidance_scale=args.guidance_scale,
max_num_expanded_coords=args.max_num_expanded_coords,
use_flash_decoder=args.use_flash_decoder,
rmbg=args.rmbg,
dtype=dtype,
device=device,
)
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
if args.tag is None:
args.tag = time.strftime("%Y%m%d_%H_%M_%S")
export_dir = os.path.join(args.output_dir, args.tag)
os.makedirs(export_dir, exist_ok=True)
for i, mesh in enumerate(outputs):
mesh.export(os.path.join(export_dir, f"part_{i:02}.glb"))
merged_mesh = get_colored_mesh_composition(outputs)
merged_mesh.export(os.path.join(export_dir, "object.glb"))
print(f"Generated {len(outputs)} parts and saved to {export_dir}")
if args.render:
print("Start rendering...")
num_views = 36
radius = 4
fps = 18
rendered_images = render_views_around_mesh(
merged_mesh,
num_views=num_views,
radius=radius,
)
rendered_normals = render_normal_views_around_mesh(
merged_mesh,
num_views=num_views,
radius=radius,
)
rendered_grids = make_grid_for_images_or_videos(
[
[processed_image] * num_views,
rendered_images,
rendered_normals,
],
nrow=3
)
export_renderings(
rendered_images,
os.path.join(export_dir, "rendering.gif"),
fps=fps,
)
export_renderings(
rendered_normals,
os.path.join(export_dir, "rendering_normal.gif"),
fps=fps,
)
export_renderings(
rendered_grids,
os.path.join(export_dir, "rendering_grid.gif"),
fps=fps,
)
rendered_image, rendered_normal, rendered_grid = rendered_images[0], rendered_normals[0], rendered_grids[0]
rendered_image.save(os.path.join(export_dir, "rendering.png"))
rendered_normal.save(os.path.join(export_dir, "rendering_normal.png"))
rendered_grid.save(os.path.join(export_dir, "rendering_grid.png"))
print("Rendering done.")
|