|
import os |
|
import pathlib |
|
import tempfile |
|
|
|
# Pin the SPCONV_ALGO environment variable to "native". The assignment sits
# ahead of the third-party imports below, presumably because the sparse-conv
# backend reads it at import time (TODO confirm against trellis/spconv).
os.environ["SPCONV_ALGO"] = "native"
|
|
|
import gradio as gr |
|
import imageio |
|
import numpy as np |
|
import spaces |
|
import torch |
|
from easydict import EasyDict |
|
from PIL import Image |
|
|
|
from trellis.pipelines import TrellisImageTo3DPipeline |
|
from trellis.representations import Gaussian, MeshExtractResult |
|
from trellis.utils import postprocessing_utils, render_utils |
|
|
|
# Markdown rendered at the top of the demo page (passed to gr.Markdown).
DESCRIPTION = """\
# Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)

- Upload an image and click "Generate" to create a 3D asset. If the image has an alpha channel, it will be used as the mask. Otherwise, we use `rembg` to remove the background.
- If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
"""
|
|
|
# Largest value representable by a signed 32-bit integer; used as the upper
# limit of the seed slider and as the (exclusive) bound for random seeds.
MAX_SEED = np.iinfo(np.int32).max

# Directory reported by Gradio as its upload folder; every generated artifact
# (.pth state, .mp4 preview, .glb, .ply) is created inside it.
TEMP_DIR = gr.utils.get_upload_folder()
|
|
|
# Build the image-to-3D pipeline once at import time and move it to the GPU;
# every handler below shares this single module-level instance.
pipeline = TrellisImageTo3DPipeline.from_pretrained("microsoft/TRELLIS-image-large")
pipeline.cuda()

# Run preprocessing once on an all-black 512x512 RGB image. Presumably a
# warm-up so that whatever preprocessing loads lazily is ready before the
# first real request (TODO confirm); the result is discarded.
pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))
|
|
|
|
|
def preprocess_image(image: Image.Image) -> Image.Image:
    """Prepare an input image for 3D asset generation.

    All of the work is delegated to the shared pipeline's own
    ``preprocess_image`` method. In this app the function is wired as the
    handler for image uploads and for example clicks, so the image shown in
    the prompt component is replaced by its preprocessed version.

    As stated in the app description, an alpha channel present in the input
    is used as the mask; otherwise ``rembg`` is used to remove the background.

    Args:
        image (Image.Image): The image to preprocess.

    Returns:
        Image.Image: The image returned by the pipeline's preprocessing step.

    Note:
        Cropping, centering and the final output size are decided inside the
        pipeline and cannot be verified from this file. Any exception raised
        by the pipeline propagates to the caller unchanged.
    """
    processed = pipeline.preprocess_image(image)
    return processed
|
|
|
|
|
def save_state_to_file(gs: Gaussian, mesh: MeshExtractResult, output_path: str) -> None:
    """Serialize a generated Gaussian and mesh to a single file.

    The file holds a dict with two entries:

    - ``"gaussian"``: the Gaussian's ``init_params`` plus its ``_xyz``,
      ``_features_dc``, ``_scaling``, ``_rotation`` and ``_opacity`` attributes.
    - ``"mesh"``: the mesh's ``vertices`` and ``faces``.

    ``load_state_from_file`` reads this layout back.

    Args:
        gs (Gaussian): The Gaussian representation to store.
        mesh (MeshExtractResult): The mesh whose vertices and faces are stored.
        output_path (str): Destination path handed to ``torch.save``.
    """
    # Start from the constructor parameters, then add the learned attributes
    # under their attribute names.
    gaussian_state = dict(gs.init_params)
    for attr in ("_xyz", "_features_dc", "_scaling", "_rotation", "_opacity"):
        gaussian_state[attr] = getattr(gs, attr)

    mesh_state = {"vertices": mesh.vertices, "faces": mesh.faces}

    torch.save({"gaussian": gaussian_state, "mesh": mesh_state}, output_path)
|
|
|
|
|
def load_state_from_file(state_path: str) -> tuple[Gaussian, EasyDict]:
    """Rebuild the Gaussian and mesh stored by ``save_state_to_file``.

    Args:
        state_path (str): Path to a file written by ``save_state_to_file``.

    Returns:
        tuple[Gaussian, EasyDict]: The reconstructed Gaussian, and an
        ``EasyDict`` exposing the stored ``vertices`` and ``faces``.
    """
    # NOTE(security): torch.load is pickle-based and ``state_path`` ultimately
    # comes from the caller of extract_glb / extract_gaussian. Depending on the
    # installed PyTorch version's default for ``weights_only``, loading an
    # untrusted file may execute arbitrary code. Consider passing
    # ``weights_only=True`` once confirmed compatible with the stored state.
    state = torch.load(state_path)
    gaussian_state = state["gaussian"]

    # Constructor arguments, keyed exactly as stored. "mininum_kernel_size" is
    # the spelling used by both the stored key and the keyword argument.
    init_keys = (
        "aabb",
        "sh_degree",
        "mininum_kernel_size",
        "scaling_bias",
        "opacity_bias",
        "scaling_activation",
    )
    gs = Gaussian(**{key: gaussian_state[key] for key in init_keys})

    # Restore the saved attributes onto the fresh instance.
    for attr in ("_xyz", "_features_dc", "_scaling", "_rotation", "_opacity"):
        setattr(gs, attr, gaussian_state[attr])

    mesh_state = state["mesh"]
    mesh = EasyDict(vertices=mesh_state["vertices"], faces=mesh_state["faces"])

    return gs, mesh
|
|
|
|
|
def get_seed(randomize_seed: bool, seed: int) -> int:
    """Choose the random seed to use for generation.

    Args:
        randomize_seed (bool): If True, draw a fresh seed uniformly from
            [0, MAX_SEED) with NumPy's default random number generator. If
            False, return ``seed`` unchanged.
        seed (int): The seed to use when ``randomize_seed`` is False.

    Returns:
        int: The drawn seed when ``randomize_seed`` is True, otherwise ``seed``.
    """
    # Pass-through case first, so no generator is built when it is not needed.
    if not randomize_seed:
        return seed

    # Generator.integers excludes the upper bound, so MAX_SEED itself is never
    # returned. MAX_SEED is the int32 maximum defined at module level.
    return int(np.random.default_rng().integers(0, MAX_SEED))
|
|
|
|
|
@spaces.GPU
def image_to_3d(
    image: Image.Image,
    seed: int,
    ss_guidance_strength: float,
    ss_sampling_steps: int,
    slat_guidance_strength: float,
    slat_sampling_steps: int,
) -> tuple[str, str]:
    """Convert an image to a 3D model and render a preview video.

    The pipeline is run with its own preprocessing disabled (in the UI the
    image is already preprocessed on upload or example click) and asked for
    the ``"gaussian"`` and ``"mesh"`` formats. The first Gaussian is rendered
    as a color video and the first mesh as a normal-map video, 120 frames
    each. Matching frames are joined along axis 1 (side by side, assuming
    H x W x C frames) and written as an MP4 at 15 fps.

    Args:
        image (Image.Image): The input image.
        seed (int): The random seed.
        ss_guidance_strength (float): The guidance strength for sparse structure generation.
        ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
        slat_guidance_strength (float): The guidance strength for structured latent generation.
        slat_sampling_steps (int): The number of sampling steps for structured latent generation.

    Returns:
        tuple[str, str]: A tuple containing:
            - str: Path to the state file (.pth) containing the 3D model data
            - str: Path to the preview video file (.mp4) showing the 3D model

    Raises:
        gr.Error: If ``image`` is None (no image was provided).

    Note:
        Both files are created with ``delete=False`` inside ``TEMP_DIR``, so
        this function never removes them itself.
    """
    # Fail fast with a user-facing message instead of handing None to the pipeline.
    if image is None:
        raise gr.Error("Please provide an input image before generating.")

    outputs = pipeline.run(
        image,
        seed=seed,
        formats=["gaussian", "mesh"],
        preprocess_image=False,
        sparse_structure_sampler_params={
            "steps": ss_sampling_steps,
            "cfg_strength": ss_guidance_strength,
        },
        slat_sampler_params={
            "steps": slat_sampling_steps,
            "cfg_strength": slat_guidance_strength,
        },
    )
    gaussian = outputs["gaussian"][0]
    mesh = outputs["mesh"][0]

    color_frames = render_utils.render_video(gaussian, num_frames=120)["color"]
    normal_frames = render_utils.render_video(mesh, num_frames=120)["normal"]
    # strict=True surfaces a frame-count mismatch instead of silently truncating.
    video = [
        np.concatenate([color, normal], axis=1)
        for color, normal in zip(color_frames, normal_frames, strict=True)
    ]

    with (
        tempfile.NamedTemporaryFile(suffix=".pth", dir=TEMP_DIR, delete=False) as state_file,
        tempfile.NamedTemporaryFile(suffix=".mp4", dir=TEMP_DIR, delete=False) as video_file,
    ):
        save_state_to_file(gaussian, mesh, state_file.name)
        # Release cached GPU memory once the state has been written out.
        torch.cuda.empty_cache()
        imageio.mimsave(video_file.name, video, fps=15)
        return state_file.name, video_file.name
|
|
|
|
|
@spaces.GPU(duration=90)
def extract_glb(
    state_path: str,
    mesh_simplify: float,
    texture_size: int,
) -> str:
    """Extract a GLB file from a previously generated 3D model.

    Args:
        state_path (str): The path to the state file (written with ``torch.save``) of the generated 3D model.
        mesh_simplify (float): The mesh simplification factor.
        texture_size (int): The texture resolution.

    Returns:
        str: The path to the extracted GLB file.
    """
    gaussian, mesh = load_state_from_file(state_path)
    glb = postprocessing_utils.to_glb(
        gaussian,
        mesh,
        simplify=mesh_simplify,
        texture_size=texture_size,
        verbose=False,
    )
    # Release cached GPU memory before writing the result to disk.
    torch.cuda.empty_cache()

    # delete=False keeps the file on disk after the handle is closed, so the
    # returned path stays valid for the caller.
    with tempfile.NamedTemporaryFile(suffix=".glb", dir=TEMP_DIR, delete=False) as glb_file:
        output_path = glb_file.name
        glb.export(output_path)
    return output_path
|
|
|
|
|
@spaces.GPU
def extract_gaussian(state_path: str) -> str:
    """Extract a Gaussian (.ply) file from a previously generated 3D model.

    Args:
        state_path (str): The path to the state file (written with ``torch.save``) of the generated 3D model.

    Returns:
        str: The path to the extracted Gaussian file.
    """
    # Only the Gaussian half of the stored state is needed here.
    gaussian, _unused_mesh = load_state_from_file(state_path)

    # delete=False keeps the file on disk after the handle is closed, so the
    # returned path stays valid for the caller.
    with tempfile.NamedTemporaryFile(suffix=".ply", dir=TEMP_DIR, delete=False) as gaussian_file:
        output_path = gaussian_file.name
        gaussian.save_ply(output_path)
    return output_path
|
|
|
|
|
# UI definition and event wiring.
# NOTE(review): the nesting of the layout containers below was reconstructed
# from a flattened source; confirm it matches the intended layout.
with gr.Blocks(css_paths="style.css", delete_cache=(600, 600)) as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        # Left column: input image, settings and action buttons.
        with gr.Column():
            image_prompt = gr.Image(
                label="Image Prompt",
                format="png",
                image_mode="RGBA",
                type="pil",
                height=300,
            )

            with gr.Accordion(label="Generation Settings", open=False):
                seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                gr.Markdown("Stage 1: Sparse Structure Generation")
                with gr.Row():
                    ss_guidance_strength = gr.Slider(
                        label="Guidance Strength", minimum=0.0, maximum=10.0, step=0.1, value=7.5
                    )
                    ss_sampling_steps = gr.Slider(label="Sampling Steps", minimum=1, maximum=50, step=1, value=12)
                gr.Markdown("Stage 2: Structured Latent Generation")
                with gr.Row():
                    slat_guidance_strength = gr.Slider(
                        label="Guidance Strength", minimum=0.0, maximum=10.0, step=0.1, value=3.0
                    )
                    slat_sampling_steps = gr.Slider(label="Sampling Steps", minimum=1, maximum=50, step=1, value=12)

            generate_btn = gr.Button("Generate")

            with gr.Accordion(label="GLB Extraction Settings", open=False):
                mesh_simplify = gr.Slider(label="Simplify", minimum=0.9, maximum=0.98, step=0.01, value=0.95)
                texture_size = gr.Slider(label="Texture Size", minimum=512, maximum=2048, step=512, value=1024)

            with gr.Row():
                # Disabled until a generation has completed successfully.
                extract_glb_btn = gr.Button("Extract GLB", interactive=False)
                extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
            gr.Markdown("""
                *NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.*
                """)

        # Right column: preview video and extracted model viewer.
        with gr.Column():
            video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
            model_output = gr.Model3D(label="Extracted GLB/Gaussian", height=300)

    # Hidden component carrying the path of the saved state between the
    # generate step and the extract steps.
    state_file = gr.File(visible=False)

    examples = gr.Examples(
        examples=sorted(pathlib.Path("assets/example_image").glob("*.png")),
        fn=preprocess_image,
        inputs=image_prompt,
        outputs=image_prompt,
        run_on_click=True,
        examples_per_page=64,
    )

    # Replace an uploaded image with its preprocessed version.
    image_prompt.upload(
        fn=preprocess_image,
        inputs=image_prompt,
        outputs=image_prompt,
    )

    # Generate: resolve the seed, run generation, then unlock the extract buttons.
    generate_btn.click(
        fn=get_seed,
        inputs=[randomize_seed, seed],
        outputs=seed,
    ).then(
        fn=image_to_3d,
        inputs=[
            image_prompt,
            seed,
            ss_guidance_strength,
            ss_sampling_steps,
            slat_guidance_strength,
            slat_sampling_steps,
        ],
        outputs=[state_file, video_output],
    ).success(
        # .success() instead of .then(): the extract buttons must only be
        # enabled when generation actually produced a state file. .then()
        # would also fire after a failed generation.
        fn=lambda: (gr.Button(interactive=True), gr.Button(interactive=True)),
        outputs=[extract_glb_btn, extract_gs_btn],
        api_name=False,
    )

    # Clearing the preview locks the extract buttons again.
    video_output.clear(
        fn=lambda: (gr.Button(interactive=False), gr.Button(interactive=False)),
        outputs=[extract_glb_btn, extract_gs_btn],
        api_name=False,
    )

    extract_glb_btn.click(
        fn=extract_glb,
        inputs=[state_file, mesh_simplify, texture_size],
        outputs=model_output,
    )
    extract_gs_btn.click(
        fn=extract_gaussian,
        inputs=state_file,
        outputs=model_output,
    )
|
|
|
# Script entry point: start the Gradio app with its MCP server option enabled.
if __name__ == "__main__":
    demo.launch(mcp_server=True)
|
|