# Last commit: 978db53 by captaincobb — "removing shared url on launch" (file size 6.79 kB)
import spaces
import argparse
import gradio as gr
import os
import torch
import trimesh
import sys
from pathlib import Path
import numpy as np
# Make the bundled `cube` checkout next to this file importable, so the
# `cube3d` imports below resolve. resolve() makes the entry absolute, so it
# stays valid even if the working directory changes later; the membership
# check avoids piling up duplicate sys.path entries on re-import.
pathdir = (Path(__file__).parent / 'cube').resolve()
if pathdir.as_posix() not in sys.path:
    sys.path.append(pathdir.as_posix())
from cube3d.inference.engine import EngineFast, Engine
from cube3d.inference.utils import normalize_bbox
from pathlib import Path
import uuid
import shutil
from huggingface_hub import snapshot_download
from cube3d.mesh_utils.postprocessing import (
PYMESHLAB_AVAILABLE,
create_pymeshset,
postprocess_mesh,
save_mesh,
)
# Process-wide mutable state shared by the Gradio handlers.
# Keys used in this file: "config_path" and "SAVE_DIR" (set by generate()),
# "engine_fast" (the EngineFast instance, created lazily and cached by
# handle_text_prompt()).
GLOBAL_STATE: dict = {}
def gen_save_folder(max_size=200, save_dir=None):
    """Create a new, uniquely named output folder under the save directory.

    At most ``max_size`` sub-folders are kept: before the new folder is
    created, the oldest existing sub-folders (ordered by ``st_ctime``) are
    deleted until there is room for it. Non-directory entries are ignored.

    Args:
        max_size: Maximum number of sub-folders to keep, counting the new one.
            Values below 1 clear every existing sub-folder.
        save_dir: Root directory to create the folder in. Defaults to
            ``GLOBAL_STATE["SAVE_DIR"]``.

    Returns:
        str: Path of the newly created folder, named by a random UUID4.
    """
    if save_dir is None:
        save_dir = GLOBAL_STATE["SAVE_DIR"]
    root = Path(save_dir)
    root.mkdir(parents=True, exist_ok=True)

    # Oldest first. NOTE: st_ctime is creation time on Windows but the last
    # metadata change on Unix; each folder is only written right after it is
    # created, so the ordering approximates creation order.
    dirs = sorted(
        (d for d in root.iterdir() if d.is_dir()),
        key=lambda d: d.stat().st_ctime,
    )
    # Evict as many folders as needed (not just one) so the cap holds even if
    # the directory already exceeds it, e.g. after max_size was lowered.
    excess = len(dirs) - max_size + 1
    for oldest_dir in dirs[:max(excess, 0)]:
        shutil.rmtree(oldest_dir)
        print(f"Removed the oldest folder: {oldest_dir}")

    new_folder = os.path.join(save_dir, str(uuid.uuid4()))
    os.makedirs(new_folder, exist_ok=True)
    print(f"Created new folder: {new_folder}")
    return new_folder
@spaces.GPU
def handle_text_prompt(input_prompt, use_bbox = True, bbox_x=1.0, bbox_y=1.0, bbox_z=1.0, hi_res=False):
    """Generate a 3D shape from a text prompt and export it as a GLB file.

    The ``EngineFast`` instance is created on the first call and cached in
    ``GLOBAL_STATE["engine_fast"]`` so later calls reuse it.

    Args:
        input_prompt: Text description of the shape to generate.
        use_bbox: If True, pass ``normalize_bbox([bbox_x, bbox_y, bbox_z])``
            to the engine as ``bounding_box_xyz``; otherwise pass None.
        bbox_x, bbox_y, bbox_z: Bounding-box extents (labelled Length, Height
            and Depth in the UI).
        hi_res: Use ``resolution_base=9.0`` instead of ``8.0``.

    Returns:
        str: Path to ``output.glb`` inside a fresh folder from
        ``gen_save_folder()``.
    """
    print(f"prompt: {input_prompt}, use_bbox: {use_bbox}, bbox_x: {bbox_x}, bbox_y: {bbox_y}, bbox_z: {bbox_z}, hi_res: {hi_res}")
    if "engine_fast" not in GLOBAL_STATE:
        # Honour checkpoint paths recorded in GLOBAL_STATE (e.g. from CLI
        # flags); otherwise use the files placed under ./model_weights.
        gpt_ckpt_path = GLOBAL_STATE.get("gpt_ckpt_path", "./model_weights/shape_gpt.safetensors")
        shape_ckpt_path = GLOBAL_STATE.get("shape_ckpt_path", "./model_weights/shape_tokenizer.safetensors")
        GLOBAL_STATE["engine_fast"] = EngineFast(
            GLOBAL_STATE["config_path"],
            gpt_ckpt_path,
            shape_ckpt_path,
            device=torch.device("cuda"),
        )

    # None means "no bounding-box conditioning".
    normalized_bbox = normalize_bbox([bbox_x, bbox_y, bbox_z]) if use_bbox else None
    resolution_base = 9.0 if hi_res else 8.0
    mesh_v_f = GLOBAL_STATE["engine_fast"].t2s(
        [input_prompt],
        use_kv_cache=True,
        resolution_base=resolution_base,
        bounding_box_xyz=normalized_bbox,
    )
    # t2s receives a one-prompt batch; take (vertices, faces) of that entry.
    vertices, faces = mesh_v_f[0][0], mesh_v_f[0][1]

    if PYMESHLAB_AVAILABLE:
        # Target face count for postprocess_mesh: 10% of the raw count,
        # but never below 10k faces.
        target_face_num = max(10000, int(faces.shape[0] * 0.1))
        print(f"Postprocessing mesh to {target_face_num} faces")
        ms = create_pymeshset(vertices, faces)
        postprocess_mesh(ms, target_face_num)
        processed = ms.current_mesh()
        vertices, faces = processed.vertex_matrix(), processed.face_matrix()
    else:
        # Without PyMeshLab the postprocessing helpers cannot run; export the
        # raw engine output instead of failing the request.
        print("PyMeshLab is not available; exporting the raw mesh without postprocessing")

    scene = trimesh.scene.Scene()
    scene.add_geometry(trimesh.Trimesh(vertices=vertices, faces=faces))
    output_path = os.path.join(gen_save_folder(), "output.glb")
    scene.export(output_path)
    return output_path
def build_interface():
    """Build the Gradio Blocks UI for the Cube 3D text-to-shape demo.

    Left column: prompt box, a "Use Bbox" checkbox with three extent sliders
    (editable only while the checkbox is ticked), a "Hi-Res" checkbox and a
    Submit button. Right column: a read-only 3D viewer. Submitting calls
    ``handle_text_prompt`` and shows the file path it returns.

    Returns:
        gr.Blocks: The assembled, not-yet-launched interface.
    """
    title = "Cube 3D"
    header_md = f"""
# {title}
# Check out our [Github](https://github.com/Roblox/cube) to try it on your own machine!
"""
    with gr.Blocks(theme=gr.themes.Soft(), title=title, fill_width=True) as interface:
        gr.Markdown(header_md)
        with gr.Row():
            with gr.Column(scale=2):
                with gr.Group():
                    input_text_box = gr.Textbox(value=None, label="Prompt", lines=2)
                    use_bbox = gr.Checkbox(label="Use Bbox", value=False)
                    with gr.Group():
                        # Length / Height / Depth extents; locked until "Use Bbox" is ticked.
                        bbox_x, bbox_y, bbox_z = (
                            gr.Slider(
                                minimum=0.1,
                                maximum=2.0,
                                step=0.1,
                                value=1.0,
                                label=axis_label,
                                interactive=False,
                            )
                            for axis_label in ("Length", "Height", "Depth")
                        )
                    bbox_sliders = [bbox_x, bbox_y, bbox_z]

                    def toggle_bbox_interactivity(enabled):
                        # One component update per slider, in `outputs` order.
                        return tuple(gr.Slider(interactive=enabled) for _ in bbox_sliders)

                    use_bbox.change(
                        toggle_bbox_interactivity,
                        inputs=[use_bbox],
                        outputs=bbox_sliders,
                    )
                    hi_res = gr.Checkbox(label="Hi-Res", value=False)
                with gr.Row():
                    submit_button = gr.Button("Submit", variant="primary")
            with gr.Column(scale=3):
                model3d = gr.Model3D(label="Output", height="45em", interactive=False)
        submit_button.click(
            handle_text_prompt,
            inputs=[input_text_box, use_bbox, *bbox_sliders, hi_res],
            outputs=[model3d],
        )
    return interface
def generate(args):
    """Populate GLOBAL_STATE from parsed CLI args, then build and launch the app.

    Args:
        args: ``argparse.Namespace`` providing ``config_path`` and
            ``save_dir``; ``gpt_ckpt_path`` / ``shape_ckpt_path`` are recorded
            as well when present and non-empty.
    """
    GLOBAL_STATE["config_path"] = args.config_path
    GLOBAL_STATE["SAVE_DIR"] = args.save_dir
    # Record the checkpoint flags too (previously parsed but discarded) so
    # the code that builds the engine can honour non-default locations.
    for key in ("gpt_ckpt_path", "shape_ckpt_path"):
        value = getattr(args, key, None)
        if value:
            GLOBAL_STATE[key] = value
    os.makedirs(GLOBAL_STATE["SAVE_DIR"], exist_ok=True)
    demo = build_interface()
    # Default concurrency limit of 1 per event: generation requests are
    # processed one at a time.
    demo.queue(default_concurrency_limit=1)
    demo.launch()
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # (flag, help text, default) for each string-valued option, in --help order.
    cli_options = (
        ("--config_path", "Path to the config file", "cube/cube3d/configs/open_model_v0.5.yaml"),
        ("--gpt_ckpt_path", "Path to the gpt ckpt path", "model_weights/shape_gpt.safetensors"),
        ("--shape_ckpt_path", "Path to the shape ckpt path", "model_weights/shape_tokenizer.safetensors"),
        ("--save_dir", None, "gradio_save_dir"),
    )
    for flag, help_text, default in cli_options:
        parser.add_argument(flag, type=str, help=help_text, default=default)
    args = parser.parse_args()
    # Download the Roblox/cube3d-v0.5 repo snapshot into ./model_weights,
    # where the default checkpoint paths point.
    snapshot_download(repo_id="Roblox/cube3d-v0.5", local_dir="./model_weights")
    generate(args)