File size: 4,723 Bytes
616f571
 
 
 
 
 
 
ce16420
616f571
 
 
 
 
 
 
 
ce16420
616f571
 
 
 
 
 
 
f6a2f50
ce16420
616f571
 
 
 
 
f6a2f50
ce16420
616f571
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f6a2f50
 
 
 
616f571
ce16420
 
 
 
 
 
 
 
616f571
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ce16420
616f571
 
 
 
 
 
 
 
 
 
 
 
 
 
ce16420
 
 
 
616f571
 
 
 
 
 
 
 
f6a2f50
ce16420
616f571
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import argparse
import os

import torch
import trimesh

from cube3d.inference.engine import Engine, EngineFast
from cube3d.inference.utils import normalize_bbox, select_device
from cube3d.mesh_utils.postprocessing import (
    PYMESHLAB_AVAILABLE,
    create_pymeshset,
    postprocess_mesh,
    save_mesh,
)
from cube3d.renderer import renderer


def generate_mesh(
    engine,
    prompt,
    output_dir,
    output_name,
    resolution_base=8.0,
    disable_postprocess=False,
    top_p=None,
    bounding_box_xyz=None,
):
    """Generate a mesh for one text prompt and write it to ``<output_dir>/<output_name>.obj``.

    The prompt is sent to ``engine.t2s`` as a one-element batch with the KV
    cache enabled. The first result is unpacked as ``(vertices, faces)``.

    When pymeshlab is available, the mesh is wrapped in a pymeshlab mesh set.
    Unless ``disable_postprocess`` is set, it is post-processed toward a face
    budget of ``max(10000, 10% of the generated face count)`` and then saved.
    Without pymeshlab, the mesh is exported with trimesh and post-processing is
    always skipped, regardless of ``disable_postprocess``.

    Args:
        engine: Inference engine exposing ``t2s`` (``Engine`` / ``EngineFast``).
        prompt: Text prompt describing the shape.
        output_dir: Directory that receives the OBJ file.
        output_name: File stem of the OBJ file (``.obj`` is appended).
        resolution_base: Forwarded to ``engine.t2s``.
        disable_postprocess: Skip pymeshlab post-processing when True.
        top_p: Forwarded to ``engine.t2s``; ``None`` is passed through as-is.
        bounding_box_xyz: Forwarded to ``engine.t2s``.

    Returns:
        The path of the written OBJ file.
    """
    generated = engine.t2s(
        [prompt],
        use_kv_cache=True,
        resolution_base=resolution_base,
        top_p=top_p,
        bounding_box_xyz=bounding_box_xyz,
    )
    # A single prompt was submitted, so only the first result is relevant.
    first_result = generated[0]
    vertices, faces = first_result[0], first_result[1]
    destination = os.path.join(output_dir, f"{output_name}.obj")

    if not PYMESHLAB_AVAILABLE:
        # Fallback path: plain trimesh export, no post-processing at all.
        print(
            "WARNING: pymeshlab is not available, using trimesh to export obj and skipping optional post processing."
        )
        trimesh.Trimesh(vertices, faces).export(destination)
        return destination

    mesh_set = create_pymeshset(vertices, faces)
    if not disable_postprocess:
        # Face budget: 10% of the raw face count, but never below 10000.
        # `faces` is assumed to be array-like with `.shape` (rows = faces).
        face_budget = max(10000, int(faces.shape[0] * 0.1))
        print(f"Postprocessing mesh to {face_budget} faces")
        # NOTE(review): the exact operations (and why the output path is
        # passed) are defined in cube3d.mesh_utils.postprocessing — confirm there.
        postprocess_mesh(mesh_set, face_budget, destination)

    save_mesh(mesh_set, destination)
    return destination


def build_parser():
    """Build the command-line parser for the cube shape generation script.

    Every flag uses hyphenated spelling. ``--bounding_box_xyz`` is still
    accepted as an alias of ``--bounding-box-xyz`` for backward compatibility;
    both store into ``args.bounding_box_xyz``.

    Returns:
        argparse.ArgumentParser: The configured parser.
    """
    parser = argparse.ArgumentParser(description="cube shape generation script")
    parser.add_argument(
        "--config-path",
        type=str,
        default="cube3d/configs/open_model.yaml",
        help="Path to the configuration YAML file.",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="outputs/",
        help="Path to the output directory to store .obj and .gif files",
    )
    parser.add_argument(
        "--gpt-ckpt-path",
        type=str,
        required=True,
        help="Path to the main GPT checkpoint file.",
    )
    parser.add_argument(
        "--shape-ckpt-path",
        type=str,
        required=True,
        help="Path to the shape encoder/decoder checkpoint file.",
    )
    parser.add_argument(
        "--fast-inference",
        help="Use optimized inference",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        required=True,
        help="Text prompt for generating a 3D mesh",
    )
    parser.add_argument(
        "--top-p",
        type=float,
        default=None,
        help="Float < 1: Keep smallest set of tokens with cumulative probability ≥ top_p. Default None: deterministic generation.",
    )
    parser.add_argument(
        # Hyphenated form matches every other flag; the underscore spelling
        # is kept so existing command lines keep working.
        "--bounding-box-xyz",
        "--bounding_box_xyz",
        dest="bounding_box_xyz",
        nargs=3,
        type=float,
        metavar=("X", "Y", "Z"),
        help="Three float values for x, y, z bounding box",
        default=None,
        required=False,
    )
    parser.add_argument(
        "--render-gif",
        help="Render a turntable gif of the mesh",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--disable-postprocessing",
        help="Disable postprocessing on the mesh. This will result in a mesh with more faces.",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--resolution-base",
        type=float,
        default=8.0,
        help="Resolution base for the shape decoder.",
    )
    return parser


def main(argv=None):
    """Run the generation pipeline from command-line arguments.

    Steps: parse arguments, create the output directory, select a device,
    build ``EngineFast`` (with ``--fast-inference``) or ``Engine``, normalize
    the optional bounding box, and generate ``output.obj``. With
    ``--render-gif``, a turntable gif is also rendered.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The path of the generated OBJ file.
    """
    args = build_parser().parse_args(argv)
    os.makedirs(args.output_dir, exist_ok=True)

    device = select_device()
    print(f"Using device: {device}")

    if args.fast_inference:
        print(
            "Using cuda graphs, this will take some time to warmup and capture the graph."
        )
        engine = EngineFast(
            args.config_path, args.gpt_ckpt_path, args.shape_ckpt_path, device=device
        )
        print("Compiled the graph.")
    else:
        engine = Engine(
            args.config_path, args.gpt_ckpt_path, args.shape_ckpt_path, device=device
        )

    # Normalize into a local instead of overwriting the parsed namespace.
    bounding_box_xyz = None
    if args.bounding_box_xyz is not None:
        bounding_box_xyz = normalize_bbox(tuple(args.bounding_box_xyz))

    obj_path = generate_mesh(
        engine,
        args.prompt,
        args.output_dir,
        "output",
        resolution_base=args.resolution_base,
        disable_postprocess=args.disable_postprocessing,
        top_p=args.top_p,
        bounding_box_xyz=bounding_box_xyz,
    )
    if args.render_gif:
        gif_path = renderer.render_turntable(obj_path, args.output_dir)
        print(f"Rendered turntable gif for {args.prompt} at `{gif_path}`")
    print(f"Generated mesh for {args.prompt} at `{obj_path}`")
    return obj_path


if __name__ == "__main__":
    main()