# Akash Garg
# handling zerogpu in prompt handling
# 0964e01
# raw
# history blame
# 4.42 kB
import spaces
import argparse
import gradio as gr
import os
import torch
import trimesh
import sys
from pathlib import Path
# Make the ``cube`` checkout that sits next to this file importable, so the
# ``cube3d`` package inside it can be imported below. The directory is
# appended (not prepended), so an already-installed ``cube3d`` would take
# precedence. Skip the append if the entry is already present, e.g. when the
# module is executed more than once in the same interpreter.
pathdir = Path(__file__).parent / 'cube'
if pathdir.as_posix() not in sys.path:
    sys.path.append(pathdir.as_posix())
from cube3d.inference.engine import EngineFast, Engine
from pathlib import Path
import uuid
import shutil
from huggingface_hub import snapshot_download
# Process-wide mutable state shared by the functions below. It holds the run
# configuration written by generate() ("config_path", "SAVE_DIR") and the
# inference engine that the prompt handler builds lazily ("engine_fast").
GLOBAL_STATE = dict()
def gen_save_folder(max_size=200, save_dir=None):
    """Create a fresh, uniquely named output folder, evicting the oldest ones.

    The save directory keeps at most ``max_size`` sub-folders, counting the one
    created by this call. When that limit would be exceeded, the oldest
    sub-folders (ordered by ``st_ctime``) are deleted first. The new folder is
    always created, so with ``max_size <= 1`` exactly one folder remains.

    Args:
        max_size: Maximum number of sub-folders to keep, including the new one.
        save_dir: Directory in which to create the folder. Defaults to
            ``GLOBAL_STATE["SAVE_DIR"]``.

    Returns:
        Path (str) of the newly created folder.
    """
    if save_dir is None:
        save_dir = GLOBAL_STATE["SAVE_DIR"]
    os.makedirs(save_dir, exist_ok=True)

    existing = sorted(
        (entry for entry in Path(save_dir).iterdir() if entry.is_dir()),
        key=lambda entry: entry.stat().st_ctime,
    )
    # Evict as many of the oldest folders as needed to make room for the new
    # one. Removing only a single folder per call would leave an over-full
    # directory permanently above the limit. The clamp keeps an empty
    # directory (or a non-positive max_size) from erroring.
    n_evict = min(len(existing), max(0, len(existing) - max_size + 1))
    for old_dir in existing[:n_evict]:
        shutil.rmtree(old_dir)
        print(f"Removed the oldest folder: {old_dir}")

    new_folder = os.path.join(save_dir, str(uuid.uuid4()))
    os.makedirs(new_folder, exist_ok=True)
    print(f"Created new folder: {new_folder}")
    return new_folder
@spaces.GPU
def handle_text_prompt(input_prompt, variance=0):
    """Generate a mesh for a text prompt and return the path to the exported GLB.

    Decorated with ``spaces.GPU`` for Hugging Face ZeroGPU, which attaches a GPU
    only while the decorated call runs. That is why the CUDA engine is created
    lazily inside this function rather than at startup. It is cached in
    ``GLOBAL_STATE["engine_fast"]`` for later calls.

    Args:
        input_prompt: Text description of the shape to generate.
        variance: Slider value (0-99 in the UI). ``0`` passes ``top_p=None``
            to the engine; any other value maps to
            ``top_p = (100 - variance) / 100``.

    Returns:
        Path (str) to ``output.glb`` inside a freshly created save folder.

    Raises:
        gr.Error: If the prompt is missing or contains only whitespace.
    """
    print(f"prompt: {input_prompt}, variance: {variance}")
    # Reject empty prompts up front instead of running the model on "".
    if not input_prompt or not input_prompt.strip():
        raise gr.Error("Please enter a prompt.")

    if "engine_fast" not in GLOBAL_STATE:
        # Honour checkpoint paths recorded in GLOBAL_STATE (if any), falling
        # back to the locations that snapshot_download populates.
        engine_fast = EngineFast(
            GLOBAL_STATE["config_path"],
            GLOBAL_STATE.get("gpt_ckpt_path", "./model_weights/shape_gpt.safetensors"),
            GLOBAL_STATE.get("shape_ckpt_path", "./model_weights/shape_tokenizer.safetensors"),
            device=torch.device("cuda"),
        )
        GLOBAL_STATE["engine_fast"] = engine_fast

    # NOTE(review): under nucleus sampling, a smaller top_p is *less* random.
    # This mapping therefore makes high "variance" values less varied, with
    # variance=1 the most permissive. Confirm the intended direction against
    # EngineFast.t2s.
    top_p = None if variance == 0 else (100 - variance) / 100.0
    mesh_v_f = GLOBAL_STATE["engine_fast"].t2s(
        [input_prompt], use_kv_cache=True, resolution_base=8.0, top_p=top_p
    )

    # t2s is given a one-element batch; take that result's vertices and faces.
    vertices, faces = mesh_v_f[0][0], mesh_v_f[0][1]
    save_folder = gen_save_folder()
    output_path = os.path.join(save_folder, "output.glb")
    trimesh.Trimesh(vertices=vertices, faces=faces).export(output_path)
    return output_path
def build_interface():
    """Assemble the Gradio Blocks UI for the app.

    The left column holds a prompt textbox, a variance slider and a Submit
    button. The right column holds a read-only 3D model viewer. Clicking Submit
    calls ``handle_text_prompt`` with the prompt and slider value and shows the
    returned file in the viewer.

    Returns:
        The constructed ``gr.Blocks`` app (not yet queued or launched).
    """
    title = "Cube 3D"
    header = (
        "\n"
        f"# {title}\n"
        "# Check out our [Github](https://github.com/Roblox/cube) to try it on your own machine!\n"
    )
    with gr.Blocks(theme=gr.themes.Soft(), title=title, fill_width=True) as interface:
        gr.Markdown(header)
        with gr.Row():
            with gr.Column(scale=2):
                with gr.Group():
                    prompt_box = gr.Textbox(value=None, label="Prompt", lines=2)
                    variance_slider = gr.Slider(
                        minimum=0, maximum=99, step=1, value=0, label="Variance"
                    )
                with gr.Row():
                    submit_button = gr.Button("Submit", variant="primary")
            with gr.Column(scale=3):
                viewer = gr.Model3D(label="Output", height="45em", interactive=False)
        # Event listeners must be registered inside the Blocks context.
        submit_button.click(
            handle_text_prompt,
            inputs=[prompt_box, variance_slider],
            outputs=[viewer],
        )
    return interface
def generate(args):
    """Record the run configuration in GLOBAL_STATE, then build and launch the app.

    Args:
        args: Parsed CLI namespace. ``config_path`` and ``save_dir`` are
            required. ``gpt_ckpt_path`` and ``shape_ckpt_path`` are stored
            too when present.
    """
    GLOBAL_STATE["config_path"] = args.config_path
    # Keep the parsed checkpoint paths available to callbacks instead of
    # silently dropping them. getattr keeps hand-built namespaces that lack
    # these attributes working.
    for key in ("gpt_ckpt_path", "shape_ckpt_path"):
        value = getattr(args, key, None)
        if value is not None:
            GLOBAL_STATE[key] = value
    GLOBAL_STATE["SAVE_DIR"] = args.save_dir
    os.makedirs(GLOBAL_STATE["SAVE_DIR"], exist_ok=True)

    demo = build_interface()
    # Process one event at a time (presumably to serialize use of the single
    # engine cached in GLOBAL_STATE).
    demo.queue(default_concurrency_limit=1)
    demo.launch()
if __name__ == "__main__":
    # (flag, help text, default) for each CLI option. All options take strings.
    cli_options = (
        ("--config_path", "Path to the config file", "cube/cube3d/configs/open_model.yaml"),
        ("--gpt_ckpt_path", "Path to the gpt ckpt path", "model_weights/shape_gpt.safetensors"),
        ("--shape_ckpt_path", "Path to the shape ckpt path", "model_weights/shape_tokenizer.safetensors"),
        ("--save_dir", None, "gradio_save_dir"),
    )
    parser = argparse.ArgumentParser()
    for flag, help_text, default in cli_options:
        parser.add_argument(flag, type=str, help=help_text, default=default)
    args = parser.parse_args()

    # Fetch the published weights into ./model_weights before starting the app.
    snapshot_download(repo_id="Roblox/cube3d-v0.1", local_dir="./model_weights")
    generate(args)