File size: 5,745 Bytes
daa6779
 
 
 
 
 
 
 
 
5eaf820
daa6779
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3cb3552
 
daa6779
 
3cb3552
daa6779
3cb3552
917d511
4c2da87
 
 
 
 
 
 
b703ea2
 
 
02a7da8
3cb3552
daa6779
3cb3552
098a71e
917d511
3cb3552
 
 
 
 
 
 
 
 
917d511
ab069cf
daa6779
3cb3552
daa6779
 
 
8ba0984
daa6779
 
 
 
 
 
 
 
 
 
 
 
8ba0984
 
 
daa6779
 
 
8ba0984
daa6779
 
 
 
 
8ba0984
daa6779
 
 
 
 
 
 
 
 
 
 
5eaf820
3cb3552
daa6779
5eaf820
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import os
import numpy as np
import cv2
import kiui
import trimesh
import torch
import rembg
from datetime import datetime
import subprocess
import argparse

try:
    # running on Hugging Face Spaces
    import spaces

except ImportError:
    # running locally, use a dummy space
    class spaces:
        """Local stand-in used when the `spaces` package cannot be imported."""

        class GPU:
            """No-op replacement for the `spaces.GPU` decorator.

            Supports both usages seen with the real decorator:

            * ``@spaces.GPU(duration=90)`` -- builds an instance whose
              ``__call__`` hands the decorated function back unchanged.
            * ``@spaces.GPU`` (bare) -- the decorated function itself arrives
              as the first argument and is returned unchanged.
            """

            def __new__(cls, duration=60):
                # Bare `@spaces.GPU`: the "duration" slot actually holds the
                # decorated function. Returning it directly (a non-instance)
                # also makes Python skip __init__.
                if callable(duration):
                    return duration
                return super().__new__(cls)

            def __init__(self, duration=60):
                # Stored only for parity with the real API; never used locally.
                self.duration = duration

            def __call__(self, func):
                # No GPU scheduling locally: leave the function untouched.
                return func


from flow.model import Model
from flow.configs.schema import ModelConfig
from flow.utils import get_random_color, recenter_foreground
from vae.utils import postprocess_mesh

# download checkpoints
# NOTE: everything below runs at import time: it fetches the checkpoints from
# the Hugging Face Hub, then builds the model and moves it to CUDA.
from huggingface_hub import hf_hub_download
flow_ckpt_path = hf_hub_download(repo_id="nvidia/PartPacker", filename="flow.pt")
vae_ckpt_path = hf_hub_download(repo_id="nvidia/PartPacker", filename="vae.pt")

# 3x3 axis permutation applied to mesh vertices in process_3d as
# `vertices @ TRIMESH_GLB_EXPORT.T`, i.e. (x, y, z) -> (y, z, x).
# NOTE(review): the name suggests this matches the GLB export axis convention -- confirm.
TRIMESH_GLB_EXPORT = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]]).astype(np.float32)
# largest int32 value; used as the (exclusive) upper bound in get_random_seed
MAX_SEED = np.iinfo(np.int32).max
# single rembg session, created once and reused by process_image for background removal
bg_remover = rembg.new_session()

# model config
# NOTE(review): these values presumably have to match the settings the
# "flow.pt" checkpoint was trained with (it is loaded with strict=True below) -- confirm.
model_config = ModelConfig(
    vae_conf="vae.configs.part_woenc",
    vae_ckpt_path=vae_ckpt_path,
    qknorm=True,
    qknorm_type="RMSNorm",
    use_pos_embed=False,
    dino_model="dinov2_vitg14",
    hidden_dim=1536,
    flow_shift=3.0,
    logitnorm_mean=1.0,
    logitnorm_std=1.0,
    latent_size=4096,
    use_parts=True,
)

# instantiate model
# eval mode, moved to the GPU, parameters cast to bfloat16
model = Model(model_config).eval().cuda().bfloat16()

# load weight
# weights_only=True restricts unpickling to tensors/plain containers;
# strict=True makes any missing or unexpected key in the checkpoint an error.
ckpt_dict = torch.load(flow_ckpt_path, weights_only=True)
model.load_state_dict(ckpt_dict, strict=True)

# get random seed
def get_random_seed(randomize_seed, seed):
    """Pick the seed to use for a generation run.

    Args:
        randomize_seed: When truthy, ``seed`` is ignored and a new one is drawn.
        seed: Seed handed back unchanged when ``randomize_seed`` is falsy.

    Returns:
        ``seed`` itself, or a random integer drawn from ``[0, MAX_SEED)``.
    """
    if not randomize_seed:
        return seed
    return np.random.randint(0, MAX_SEED)

# process image
@spaces.GPU(duration=10)
def process_image(image_path):
    """Load an image from disk and prepare it as conditioning input.

    The image is converted to RGBA (running background removal when the file
    has no alpha channel), recentered on its foreground and resized to
    518x518.

    Args:
        image_path: Path of the image file to read with OpenCV.

    Returns:
        The processed 518x518 RGBA image as a numpy array
        (presumably uint8, as process_3d assumes -- TODO confirm for
        non-8-bit inputs).

    Raises:
        FileNotFoundError: If ``image_path`` does not point to a file.
        ValueError: If the file exists but OpenCV cannot decode it.
    """
    image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
    if image is None:
        # cv2.imread signals failure by returning None rather than raising;
        # fail here with a clear message instead of on `.shape` below.
        if not os.path.isfile(image_path):
            raise FileNotFoundError(f"Image file not found: {image_path}")
        raise ValueError(f"Could not decode image file: {image_path}")

    if image.ndim == 3 and image.shape[-1] == 4:
        # already has an alpha channel: only fix the channel order
        image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
    else:
        if image.ndim == 2:
            # single-channel (grayscale) file: expand to 3 channels first
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        else:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # bg removal if there is no alpha channel
        image = rembg.remove(image, session=bg_remover)  # [H, W, 4]

    # foreground = every pixel with non-zero alpha
    mask = image[..., -1] > 0
    image = recenter_foreground(image, mask, border_ratio=0.1)
    image = cv2.resize(image, (518, 518), interpolation=cv2.INTER_AREA)
    return image

# process generation
def _extract_parts(vae_results, target_num_faces):
    """Convert one VAE decoding result into a list of connected-component meshes."""
    vertices, faces = vae_results["meshes"][0]
    mesh = trimesh.Trimesh(vertices, faces)
    # permute the axes before export (see TRIMESH_GLB_EXPORT)
    mesh.vertices = mesh.vertices @ TRIMESH_GLB_EXPORT.T
    mesh = postprocess_mesh(mesh, target_num_faces)
    # list(...) so the result can be safely extended/concatenated by the caller
    return list(mesh.split(only_watertight=False))


@spaces.GPU(duration=90)
def process_3d(input_image, num_steps=50, cfg_scale=7, grid_res=384, seed=42, simplify_mesh=False, target_num_faces=100000):
    """Generate a part-based 3D mesh from a processed image and save it as GLB.

    Args:
        input_image: RGBA image array of shape (H, W, 4) with values in
            0..255 (e.g. the output of ``process_image``).
        num_steps: Number of steps passed to the flow model.
        cfg_scale: CFG scale passed to the flow model.
        grid_res: Resolution passed to the VAE when extracting meshes.
        seed: Seed forwarded to ``kiui.seed_everything``.
        simplify_mesh: When false, ``target_num_faces`` is replaced by -1
            before being handed to ``postprocess_mesh``.
        target_num_faces: Face budget forwarded to ``postprocess_mesh`` when
            ``simplify_mesh`` is true.

    Returns:
        Path of the exported ``.glb`` file inside the ``output`` directory.

    Raises:
        ValueError: If ``input_image`` is not an (H, W, 4) array.
    """
    if input_image.ndim != 3 or input_image.shape[-1] != 4:
        raise ValueError(
            f"input_image must be an RGBA array of shape (H, W, 4), got {input_image.shape}"
        )

    # seed
    kiui.seed_everything(seed)

    # output path (microseconds included so that two runs started within the
    # same second do not overwrite each other's file)
    os.makedirs("output", exist_ok=True)
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f')
    output_glb_path = f"output/partpacker_{timestamp}.glb"

    # input image (assume processed to RGBA uint8)
    image = input_image.astype(np.float32) / 255.0
    alpha = image[..., 3:4]
    image = image[..., :3] * alpha + (1 - alpha)  # white background
    image_tensor = torch.from_numpy(image).permute(2, 0, 1).contiguous().unsqueeze(0).float().cuda()

    with torch.inference_mode():
        results = model({"cond_images": image_tensor}, num_steps=num_steps, cfg_scale=cfg_scale)
        latent = results["latent"]

        # query mesh: the latent is split along dim 1 at latent_size and each
        # slice is decoded separately by the VAE
        latent_size = model.config.latent_size
        results_part0 = model.vae({"latent": latent[:, :latent_size, :]}, resolution=grid_res)
        results_part1 = model.vae({"latent": latent[:, latent_size:, :]}, resolution=grid_res)

    if not simplify_mesh:
        target_num_faces = -1

    # split connected components of both decoded meshes
    parts = []
    for vae_results in (results_part0, results_part1):
        parts.extend(_extract_parts(vae_results, target_num_faces))

    # assign different colors
    for j, part in enumerate(parts):
        # each component uses a random color
        part.visual.vertex_colors = get_random_color(j, use_float=True)

    mesh = trimesh.Scene(parts)
    # export the whole mesh
    mesh.export(output_glb_path)

    print(f"Saved 3D model to {output_glb_path}")
    return output_glb_path

if __name__ == '__main__':
    # Command-line entry point: parse options, preprocess the input image,
    # then run the 3D generation and export.
    cli = argparse.ArgumentParser()
    cli.add_argument('--image_path', type=str, default='examples/robot.png', help='Path to input image')
    cli.add_argument('--num_steps', type=int, default=50, help='Inference steps')
    cli.add_argument('--cfg_scale', type=float, default=7.0, help='CFG scale')
    cli.add_argument('--grid_res', type=int, default=384, help='Grid resolution')
    cli.add_argument('--seed', type=int, default=42, help='Random seed')
    cli.add_argument('--randomize_seed', action='store_true', help='Randomize seed')
    cli.add_argument('--simplify_mesh', action='store_true', help='Simplify mesh')
    cli.add_argument('--target_num_faces', type=int, default=100000, help='Target number of faces for mesh simplification')
    options = cli.parse_args()

    # get_random_seed hands the given seed back untouched unless
    # --randomize_seed was passed, so no extra guard is needed here.
    run_seed = get_random_seed(options.randomize_seed, options.seed)

    rgba_image = process_image(options.image_path)
    process_3d(
        input_image=rgba_image,
        num_steps=options.num_steps,
        cfg_scale=options.cfg_scale,
        grid_res=options.grid_res,
        seed=run_seed,
        simplify_mesh=options.simplify_mesh,
        target_num_faces=options.target_num_faces,
    )