sp3d / gradio_app.py
rgndgn's picture
r
1a20c54 verified
raw
history blame
3.93 kB
import os
import random
import tempfile
import time
import zipfile
from contextlib import nullcontext
from functools import lru_cache
from typing import Any
import cv2
import gradio as gr
import numpy as np
import torch
import trimesh
from gradio_litmodel3d import LitModel3D
from gradio_pointcloudeditor import PointCloudEditor
from PIL import Image
from transparent_background import Remover
# Build/install the local native packages at startup; this must happen before
# the `spar3d` imports that follow. Commands are passed as argument lists (no
# shell), run with the *current* interpreter's pip, and USE_CUDA is supplied
# through `env` instead of a shell-prefix assignment.
import subprocess
import sys


def _pip_install(*args, extra_env=None):
    """Run ``python -m pip install *args`` with the running interpreter.

    Args:
        *args: Extra arguments/targets appended after ``pip install``.
        extra_env: Optional mapping of environment variables added on top of
            the current environment for this command only.

    Returns:
        The pip process exit code. A non-zero code is reported with a warning
        rather than raised, preserving the original best-effort behaviour of
        ``os.system`` (which ignored the exit status).
    """
    env = {**os.environ, **extra_env} if extra_env else None
    cmd = [sys.executable, "-m", "pip", "install", *args]
    returncode = subprocess.run(cmd, env=env, check=False).returncode
    if returncode != 0:
        print(f"WARNING: {' '.join(cmd)} failed with exit code {returncode}")
    return returncode


_pip_install(
    "-vv",
    "--no-build-isolation",
    "./texture_baker",
    "./uv_unwrapper",
    extra_env={"USE_CUDA": "1"},
)
_pip_install("./deps/pynim-0.0.3-cp310-cp310-linux_x86_64.whl")
import spar3d.utils as spar3d_utils
from spar3d.models.mesh import QUAD_REMESH_AVAILABLE, TRIANGLE_REMESH_AVAILABLE
from spar3d.system import SPAR3D
# Point Gradio's temp folder at "<TMPDIR>/gradio", using /tmp when TMPDIR is unset.
os.environ.update(
    GRADIO_TEMP_DIR=os.path.join(os.getenv("TMPDIR", "/tmp"), "gradio")
)

# Background remover, constructed with the library's default settings.
bg_remover = Remover()
# Conditioning-view camera settings. They never change while the app runs, so
# the derived camera-to-world matrix and intrinsics are computed once here.
COND_WIDTH, COND_HEIGHT = 512, 512
COND_DISTANCE = 2.2
COND_FOVY = 0.591627  # presumably radians, given create_intrinsic_from_fov_rad below
BACKGROUND_COLOR = [0.5, 0.5, 0.5]

c2w_cond = spar3d_utils.default_cond_c2w(COND_DISTANCE)
intrinsic, intrinsic_normed_cond = spar3d_utils.create_intrinsic_from_fov_rad(
    COND_FOVY, COND_HEIGHT, COND_WIDTH
)

# NOTE(review): not referenced anywhere else in this file.
generated_files = []
# Remove whatever an earlier run left in the Gradio temp dir, if it exists.
if os.path.exists(os.environ["GRADIO_TEMP_DIR"]):
    import shutil

    print("Deleting " + os.environ["GRADIO_TEMP_DIR"])
    shutil.rmtree(os.environ["GRADIO_TEMP_DIR"])
# Choose the compute device, load the pretrained SPAR3D checkpoint
# (config.yaml + model.safetensors from "stabilityai/stable-point-aware-3d"),
# switch it to eval mode and move it onto that device.
device = spar3d_utils.get_device()
model = SPAR3D.from_pretrained(
    "stabilityai/stable-point-aware-3d",
    weight_name="model.safetensors",
    config_name="config.yaml",
)
model.eval()
model = model.to(device)
# Paths of the bundled example images, sorted so the order is deterministic:
# os.listdir returns entries in arbitrary, filesystem-dependent order.
# NOTE(review): not referenced anywhere else in this file (no gr.Examples).
example_files = sorted(
    os.path.join("demo_files/examples", name)
    for name in os.listdir("demo_files/examples")
)
def auto_process(input_image):
    """Run the whole image-to-3D pipeline for one uploaded image.

    Used as the ``input_img.change`` handler. It removes the background, crops
    to the foreground at the conditioning size, runs the model with fixed demo
    settings, and bundles the outputs into a ZIP.

    Args:
        input_image: The uploaded image, or ``None`` when there is no image.

    Returns:
        A 4-tuple ``(glb_file, illumination_file, zip_file, pc_list)``. All
        four entries are ``None`` when ``input_image`` is ``None``.
    """
    # No image (e.g. the input was cleared): leave every output empty.
    if input_image is None:
        return None, None, None, None

    # NOTE(review): remove_background, process_model_run and create_zip_file
    # are not defined anywhere before `demo.queue().launch()` in this file, so
    # as written any non-None image raises NameError here. Restore these
    # helpers (TODO).

    # Step 1: strip the background, then crop/resize to the conditioning size.
    cropped = spar3d_utils.foreground_crop(
        remove_background(input_image),
        crop_ratio=1.3,  # foreground_ratio
        newsize=(COND_WIDTH, COND_HEIGHT),
        no_crop=False,
    )

    # Step 2: run the model with the demo's fixed settings (positional).
    glb_file, pc_file, illumination_file, pc_list = process_model_run(
        cropped,
        3.0,  # guidance_scale
        0,  # random_seed
        None,  # pc_cond
        "None",  # remesh_option
        "Keep Vertex Count",  # vertex_count_type
        2000,  # vertex_count
        1024,  # texture_resolution
    )

    # Step 3: bundle the generated files for the single download component.
    return (
        glb_file,
        illumination_file,
        create_zip_file(glb_file, pc_file, illumination_file),
        pc_list,
    )
# Simplified interface: one image upload that triggers generation
# automatically, a 3D viewer for the result, and a ZIP download.
with gr.Blocks() as demo:
    gr.Markdown(
        """
# SPAR3D: Stable Point-Aware Reconstruction of 3D Objects from Single Images
Upload an image to generate a 3D model.
"""
    )
    with gr.Row():
        with gr.Column():
            input_img = gr.Image(
                type="pil",
                label="Upload Image",
                # gr.Image only accepts "upload", "webcam" and "clipboard" as
                # sources and raises ValueError for anything else. The former
                # "click" entry was invalid; presumably it meant click-to-pick
                # a file, which "upload" already covers.
                sources=["upload"],
                image_mode="RGBA",
            )
        with gr.Column():
            output_3d = LitModel3D(
                label="3D Model",
                clear_color=[0.0, 0.0, 0.0, 0.0],
                tonemapping="aces",
                contrast=1.0,
                scale=1.0,
            )
            download_all_btn = gr.File(
                label="Download Model (ZIP)",
                file_count="single",
                visible=True,
            )

    # Hidden state slots for the auto_process outputs that no visible
    # component displays.
    illumination_state = gr.State()  # illumination file
    pc_list_state = gr.State()  # point cloud list

    # Re-run generation whenever the input image changes.
    input_img.change(
        auto_process,
        inputs=[input_img],
        outputs=[
            output_3d,
            illumination_state,
            download_all_btn,
            pc_list_state,
        ],
    )

demo.queue().launch(share=False)