# trellis-3d-api / app.py
# Author: Steven18 · commit f040f37 ("fix") · 21.6 kB
import gradio as gr
from gradio.events import SelectData
import spaces
from gradio_litmodel3d import LitModel3D
import json
import os
import shutil
os.environ['SPCONV_ALGO'] = 'native'
from typing import *
import torch
import numpy as np
import imageio
from pathlib import Path
from easydict import EasyDict as edict
from PIL import Image
from trellis.pipelines import TrellisImageTo3DPipeline
from trellis.representations import Gaussian, MeshExtractResult
from trellis.utils import render_utils, postprocessing_utils
# Upper bound for seeds (largest int32); used by the seed slider and get_seed.
MAX_SEED = np.iinfo(np.int32).max
# Scratch space for per-session outputs, placed next to this file.
TMP_DIR = os.path.join(
    os.path.abspath(os.path.dirname(__file__)),
    'tmp',
)
os.makedirs(TMP_DIR, exist_ok=True)
def to_pil_list(
    multiimages: List[
        Union[
            Image.Image,
            Tuple[Image.Image, str],
            gr.File,
            Tuple[gr.File, str],
            str,  # fallback: plain path
            Path
        ]
    ]
) -> List[Image.Image]:
    """
    Convert a heterogeneous `multiimages` list into a homogeneous
    `List[Image.Image]`.
    Accepts elements in any of the following forms:
    • PIL.Image
    • (PIL.Image, caption)
    • gr.File (gr.File.name is the temp-file path)
    • (gr.File, caption)
    • str / os.PathLike such as pathlib.Path (direct file path)
    Images read from disk are fully loaded into memory so no file handle
    stays open after this function returns.
    Returns:
        List[Image.Image] -- guaranteed PIL images
    Raises:
        TypeError: if an element is of none of the supported forms.
    """

    def _load(path) -> Image.Image:
        # Image.open is lazy and keeps the file open until pixel data is read;
        # load eagerly and copy so the handle is released immediately.
        with Image.open(path) as opened:
            opened.load()
            return opened.copy()

    pil_imgs: List[Image.Image] = []
    for item in multiimages:
        # Unpack tuple/list (e.g. gallery (image, caption) pairs), keep first element
        if isinstance(item, (tuple, list)):
            item = item[0]
        if isinstance(item, Image.Image):  # already PIL
            pil_imgs.append(item)
        elif isinstance(item, (str, os.PathLike)):  # file path
            # Checked before the generic `.name` test: pathlib.Path also has a
            # `.name` attribute, but it holds only the basename, not the path.
            pil_imgs.append(_load(item))
        elif hasattr(item, "name"):  # gr.File-like object carrying a temp path
            pil_imgs.append(_load(item.name))
        else:
            raise TypeError(
                f"Unsupported element in multiimages: {type(item)}"
            )
    return pil_imgs
def start_session(req: gr.Request):
    """
    Create the scratch directory for a newly connected browser session.
    Args:
        req (gr.Request): Incoming request; its `session_hash` names the directory.
    """
    Path(TMP_DIR, str(req.session_hash)).mkdir(parents=True, exist_ok=True)
def end_session(req: gr.Request):
    """
    Delete the scratch directory of a session that has ended.
    Cleanup is best-effort: the directory may never have been created or may
    already be gone, and raising from the unload handler would only produce an
    unhandled error, so errors from the removal are ignored.
    Args:
        req (gr.Request): Request of the ending session; its `session_hash`
            names the directory under TMP_DIR.
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    shutil.rmtree(user_dir, ignore_errors=True)
def preprocess_image(image: Image.Image) -> Image.Image:
    """
    Apply the TRELLIS pipeline's own preprocessing to a single image.
    Args:
        image (Image.Image): Image as provided by the UI or an example.
    Returns:
        Image.Image: The result of `pipeline.preprocess_image(image)`.
    """
    return pipeline.preprocess_image(image)
def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
    """
    Preprocess a list of input images (e.g. the contents of the gallery).
    Args:
        images (List[Tuple[Image.Image, str]]): Gallery items, normally
            (image, caption) pairs; bare images, file paths and file objects
            are accepted too (see `to_pil_list`).
    Returns:
        List[Image.Image]: The preprocessed images; empty when `images` is
        None or empty.
    """
    # The gallery can fire with no content; avoid iterating None.
    if not images:
        return []
    # Reuse the shared normalizer instead of assuming every item is a tuple.
    return [pipeline.preprocess_image(image) for image in to_pil_list(images)]
def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
    """
    Snapshot a generated Gaussian and mesh into a dict of NumPy arrays so it
    can be stored in a `gr.State` and rebuilt later by `unpack_state`.
    Args:
        gs (Gaussian): Generated Gaussian; its `init_params` plus the tensor
            attributes `_xyz`, `_features_dc`, `_scaling`, `_rotation`,
            `_opacity` are saved.
        mesh (MeshExtractResult): Generated mesh; `vertices` and `faces` are saved.
    Returns:
        dict: {'gaussian': {...init_params, tensor attrs}, 'mesh': {'vertices', 'faces'}}
    """
    def as_array(tensor):
        return tensor.cpu().numpy()

    gaussian_state = dict(gs.init_params)
    for attr_name in ('_xyz', '_features_dc', '_scaling', '_rotation', '_opacity'):
        gaussian_state[attr_name] = as_array(getattr(gs, attr_name))

    mesh_state = {
        'vertices': as_array(mesh.vertices),
        'faces': as_array(mesh.faces),
    }
    return {'gaussian': gaussian_state, 'mesh': mesh_state}
def unpack_state(state: dict, device: str = 'cuda') -> Tuple[Gaussian, edict]:
    """
    Rebuild the Gaussian and mesh from a dict produced by `pack_state`.
    Args:
        state (dict): Dict with a 'gaussian' entry (constructor params plus
            saved arrays) and a 'mesh' entry ('vertices', 'faces').
        device (str): Torch device for the restored tensors (default 'cuda').
    Returns:
        Gaussian: Gaussian with its saved parameters restored.
        edict: Mesh holding `vertices` and `faces` tensors.
    """
    g = state['gaussian']
    gs = Gaussian(
        aabb=g['aabb'],
        sh_degree=g['sh_degree'],
        # 'mininum' spelling matches both the stored key and the constructor keyword.
        mininum_kernel_size=g['mininum_kernel_size'],
        scaling_bias=g['scaling_bias'],
        opacity_bias=g['opacity_bias'],
        scaling_activation=g['scaling_activation'],
    )
    for attr_name in ('_xyz', '_features_dc', '_scaling', '_rotation', '_opacity'):
        setattr(gs, attr_name, torch.tensor(g[attr_name], device=device))
    mesh = edict(
        vertices=torch.tensor(state['mesh']['vertices'], device=device),
        faces=torch.tensor(state['mesh']['faces'], device=device),
    )
    return gs, mesh
def get_seed(randomize_seed: bool, seed: int) -> int:
    """
    Pick the seed for the next generation.
    Args:
        randomize_seed (bool): When true, draw a fresh seed in [0, MAX_SEED).
        seed (int): Seed returned unchanged when randomization is off.
    Returns:
        int: The seed to use.
    """
    if not randomize_seed:
        return seed
    return np.random.randint(0, MAX_SEED)
@spaces.GPU
def image_to_3d(
    image: Image.Image,
    multiimages: List[Tuple[Image.Image, str]],
    is_multiimage: Union[str, bool],
    seed: int,
    ss_guidance_strength: float,
    ss_sampling_steps: int,
    slat_guidance_strength: float,
    slat_sampling_steps: int,
    multiimage_algo: Literal["multidiffusion", "stochastic"],
    req: gr.Request,
) -> Tuple[dict, str]:
    """
    Convert an image (or multiple images) into a 3D model and return its state and video.
    Args:
        image (Image.Image): The input image for single-image mode.
        multiimages (List[Tuple[Image.Image, str]]): List of images with captions for multi-image mode.
        is_multiimage (str | bool): "true"/"false" (case-insensitive, as sent by
            the UI radio) or a bool (e.g. from API callers).
        seed (int): Random seed for reproducibility.
        ss_guidance_strength (float): Sparse structure guidance strength.
        ss_sampling_steps (int): Sparse structure sampling steps.
        slat_guidance_strength (float): SLAT guidance strength.
        slat_sampling_steps (int): SLAT sampling steps.
        multiimage_algo (str): Multi-image algorithm to use.
    Returns:
        dict: The information of the generated 3D model (see `pack_state`).
        str: The path to the video of the 3D model.
    Raises:
        gr.Error: if the selected mode has no input image(s).
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(user_dir, exist_ok=True)

    # str() makes real booleans work as well as the radio's "true"/"false".
    use_multi = str(is_multiimage).strip().lower() == "true"

    # Settings shared by both pipeline entry points. Numeric values are cast to
    # int because API clients may send e.g. 12.0 instead of 12.
    common_kwargs = dict(
        seed=int(seed),
        formats=["gaussian", "mesh"],
        # Inputs were already preprocessed on upload / example selection.
        preprocess_image=False,
        sparse_structure_sampler_params={
            "steps": int(ss_sampling_steps),
            "cfg_strength": ss_guidance_strength,
        },
        slat_sampler_params={
            "steps": int(slat_sampling_steps),
            "cfg_strength": slat_guidance_strength,
        },
    )

    # Run pipeline depending on mode
    if use_multi:
        if not multiimages:
            raise gr.Error("Multi-image mode requires at least one input image.")
        outputs = pipeline.run_multi_image(
            to_pil_list(multiimages),
            mode=multiimage_algo,
            **common_kwargs,
        )
    else:
        if image is None:
            raise gr.Error("Please provide an input image.")
        outputs = pipeline.run(image, **common_kwargs)

    # Render 120 frames: Gaussian colour and mesh normals side by side.
    video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
    video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
    video = [np.concatenate([frame, frame_geo], axis=1) for frame, frame_geo in zip(video, video_geo)]

    # Save the video
    video_path = os.path.join(user_dir, 'sample.mp4')
    imageio.mimsave(video_path, video, fps=15)

    # Pack state for downstream use
    state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
    torch.cuda.empty_cache()
    return state, video_path
@spaces.GPU(duration=90)
def extract_glb(
    state: dict,
    mesh_simplify: float,
    texture_size: int,
    req: gr.Request,
) -> Tuple[str, str]:
    """
    Extract a GLB file from the 3D model.
    Args:
        state (dict): The state of the generated 3D model (from `image_to_3d`).
        mesh_simplify (float): The mesh simplification factor.
        texture_size (int): The texture resolution.
    Returns:
        str: The path to the extracted GLB file (for the 3D viewer).
        str: The same path again (for the download button).
    Raises:
        gr.Error: if no generated model state is available.
    """
    if state is None:
        raise gr.Error("No generated 3D model found; run Generate first.")
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    # API callers may never have triggered start_session; mirror image_to_3d.
    os.makedirs(user_dir, exist_ok=True)
    gs, mesh = unpack_state(state)
    glb = postprocessing_utils.to_glb(
        gs,
        mesh,
        simplify=mesh_simplify,
        texture_size=int(texture_size),  # API clients may send a float
        verbose=False,
    )
    glb_path = os.path.join(user_dir, 'sample.glb')
    glb.export(glb_path)
    torch.cuda.empty_cache()
    return glb_path, glb_path
@spaces.GPU
def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
    """
    Extract a Gaussian (.ply) file from the 3D model.
    Args:
        state (dict): The state of the generated 3D model (from `image_to_3d`).
    Returns:
        str: The path to the extracted Gaussian file (for the 3D viewer).
        str: The same path again (for the download button).
    Raises:
        gr.Error: if no generated model state is available.
    """
    if state is None:
        raise gr.Error("No generated 3D model found; run Generate first.")
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    # API callers may never have triggered start_session; mirror image_to_3d.
    os.makedirs(user_dir, exist_ok=True)
    gs, _ = unpack_state(state)
    gaussian_path = os.path.join(user_dir, 'sample.ply')
    gs.save_ply(gaussian_path)
    torch.cuda.empty_cache()
    return gaussian_path, gaussian_path
def prepare_multi_example() -> List[Image.Image]:
    """
    Build the multi-image example thumbnails.
    Each case in assets/example_multi_image is stored as `<case>_1.png`,
    `<case>_2.png`, `<case>_3.png`. The three views are resized to a height of
    512 px and concatenated horizontally into one image.
    Returns:
        List[Image.Image]: One concatenated image per case, sorted by case name.
    """
    example_dir = "assets/example_multi_image"
    # Sorted for a stable example order across restarts; only `<case>_<i>.png`
    # files define cases (stray files are ignored). rsplit keeps underscores
    # inside case names intact.
    cases = sorted({
        name.rsplit('_', 1)[0]
        for name in os.listdir(example_dir)
        if name.endswith('.png') and '_' in name
    })
    images = []
    for case in cases:
        views = []
        for i in range(1, 4):
            # Context manager closes the file once the resized copy is made.
            with Image.open(f'{example_dir}/{case}_{i}.png') as img:
                W, H = img.size
                views.append(np.array(img.resize((int(W / H * 512), 512))))
        images.append(Image.fromarray(np.concatenate(views, axis=1)))
    return images
def split_image(image: Image.Image) -> List[Image.Image]:
    """
    Split a horizontally concatenated image into its individual views.
    Views are assumed to be separated by columns whose alpha is zero in every
    row. Each run of columns with any non-zero alpha becomes one view, which is
    then passed through `preprocess_image`.
    Images without an alpha channel are converted to RGBA first (fully
    opaque), so they yield a single view instead of raising.
    Args:
        image (Image.Image): The concatenated example image.
    Returns:
        List[Image.Image]: The preprocessed views, left to right.
    """
    pixels = np.array(image.convert("RGBA"))
    column_has_content = np.any(pixels[..., 3] > 0, axis=0)
    # Pad with an empty column on each side so views that touch the left or
    # right border still produce a start/end transition.
    padded = np.concatenate([[False], column_has_content, [False]])
    transitions = np.diff(padded.astype(np.int8))
    starts = np.flatnonzero(transitions == 1)   # first content column of a view
    ends = np.flatnonzero(transitions == -1)    # one past its last content column
    views = [Image.fromarray(pixels[:, s:e]) for s, e in zip(starts, ends)]
    return [preprocess_image(view) for view in views]
# `api_name` is a Gradio event-listener option (set where the button's
# `.click(...)` is registered), not a `spaces.GPU` option, so it is not passed here.
@spaces.GPU
def quick_generate_glb(
    image: Image.Image,
    multiimages: List[Tuple[Image.Image, str]],
    is_multiimage: str,
    seed: int,
    ss_guidance_strength: float,
    ss_sampling_steps: int,
    slat_guidance_strength: float,
    slat_sampling_steps: int,
    multiimage_algo: Literal["multidiffusion", "stochastic"],
    mesh_simplify: float,
    texture_size: int,
    req: gr.Request,
) -> Tuple[str, str]:
    """
    One-shot endpoint: generate a 3D model and export it directly as GLB.
    Args:
        image, multiimages, is_multiimage, seed, ss_guidance_strength,
        ss_sampling_steps, slat_guidance_strength, slat_sampling_steps,
        multiimage_algo: Forwarded unchanged to `image_to_3d`.
        mesh_simplify (float): Forwarded to `extract_glb`.
        texture_size (int): Forwarded to `extract_glb`.
    Returns:
        Tuple[str, str]: The GLB path twice (viewer, download), as returned by `extract_glb`.
    """
    # NOTE(review): this call runs both generation and GLB extraction, while
    # `extract_glb` alone requests duration=90; the duration of this outer
    # GPU allocation may need raising — confirm against ZeroGPU limits.
    state, _ = image_to_3d(
        image=image,
        multiimages=multiimages,
        is_multiimage=is_multiimage,
        seed=seed,
        ss_guidance_strength=ss_guidance_strength,
        ss_sampling_steps=ss_sampling_steps,
        slat_guidance_strength=slat_guidance_strength,
        slat_sampling_steps=slat_sampling_steps,
        multiimage_algo=multiimage_algo,
        req=req,
    )
    return extract_glb(state, mesh_simplify=mesh_simplify, texture_size=texture_size, req=req)
# `api_name` is a Gradio event-listener option (set where the button's
# `.click(...)` is registered), not a `spaces.GPU` option, so it is not passed here.
@spaces.GPU
def quick_generate_gs(
    image: Image.Image,
    multiimages: List[Tuple[Image.Image, str]],
    is_multiimage: str,
    seed: int,
    ss_guidance_strength: float,
    ss_sampling_steps: int,
    slat_guidance_strength: float,
    slat_sampling_steps: int,
    multiimage_algo: Literal["multidiffusion", "stochastic"],
    req: gr.Request,
) -> Tuple[str, str]:
    """
    One-shot endpoint: generate a 3D model and export it directly as a Gaussian .ply.
    Args:
        image, multiimages, is_multiimage, seed, ss_guidance_strength,
        ss_sampling_steps, slat_guidance_strength, slat_sampling_steps,
        multiimage_algo: Forwarded unchanged to `image_to_3d`.
    Returns:
        Tuple[str, str]: The .ply path twice (viewer, download), as returned by `extract_gaussian`.
    """
    state, _ = image_to_3d(
        image=image,
        multiimages=multiimages,
        is_multiimage=is_multiimage,
        seed=seed,
        ss_guidance_strength=ss_guidance_strength,
        ss_sampling_steps=ss_sampling_steps,
        slat_guidance_strength=slat_guidance_strength,
        slat_sampling_steps=slat_sampling_steps,
        multiimage_algo=multiimage_algo,
        req=req,
    )
    return extract_gaussian(state, req=req)
def test_for_api_gen(image: Image.Image) -> Image.Image:
    """
    Identity helper for API testing: returns its input unchanged.
    Not wired to any UI event in this file.
    Args:
        image (Image.Image): Any input image.
    Returns:
        Image.Image: The same image object, with no processing applied.
    """
    return image
def update_is_multiimage(event: SelectData):
    """
    Translate an input-tab selection into the `is_multiimage` radio value.
    Args:
        event (SelectData): Tab-selection event; index 1 is the
            "Multiple Images" tab (id=1).
    Returns:
        str: "true" for the multi-image tab, "false" for any other tab.
    """
    if event.index == 1:
        return "true"
    return "false"
# UI layout and event wiring. Requires Gradio 4+ (DownloadButton, delete_cache,
# returning component instances / gr.update from handlers).
with gr.Blocks(delete_cache=(600, 600)) as demo:
    gr.Markdown("""
## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
Thanks to the incredible work of [JeffreyXiang/TRELLIS-image-large](https://huggingface.co/JeffreyXiang/TRELLIS-image-large) for providing such a stunning implementation of the TRELLIS 3D pipeline.
During my usage, I noticed that many users had questions regarding API access. I've spent some time refactoring the `image_to_3d` pipeline and adding two new endpoints:
- 🔁 `quick_generate_glb`: Directly generate and download a `.glb` 3D asset.
- 🌐 `quick_generate_gs`: Directly generate and download the Gaussian `.ply` file.
- 🧩 Both functions are exposed as Hugging Face API endpoints and can be called via `gradio_client` or any HTTP client.
### How to Use:
- Upload an image and click **"Generate"** to create a 3D asset. If the image has an alpha channel, it will be used as a mask. Otherwise, `rembg` will automatically remove the background.
- If you're satisfied with the result, click **"Extract GLB"** or **"Extract Gaussian"** to download the 3D file.
### Features:
- ✅ Single-image and experimental multi-image generation
- ✅ `.glb` extraction with mesh simplification and texturing
- ✅ `.ply` (Gaussian) extraction
- ✅ Public API endpoints for one-click asset generation and download
Feel free to try it out and send feedback — I'm happy to keep improving it based on your suggestions!
""")
    with gr.Row():
        with gr.Column():
            with gr.Tabs() as input_tabs:
                with gr.Tab(label="Single Image", id=0) as single_image_input_tab:
                    image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=300)
                with gr.Tab(label="Multiple Images", id=1) as multiimage_input_tab:
                    multiimage_prompt = gr.Gallery(label="Image Prompt", format="png", type="pil", height=300, columns=3)
                    gr.Markdown("""
Input different views of the object in separate images.
*NOTE: this is an experimental algorithm without training a specialized model. It may not produce the best results for all images, especially those having different poses or inconsistent details.*
""")
            is_multiimage = gr.Radio(
                choices=["true", "false"],
                value="false",
                label="Use multi-image mode",
                visible=True
            )
            # Single source of truth for the mode flag: always "true"/"false" strings.
            input_tabs.select(
                fn=update_is_multiimage,
                outputs=is_multiimage
            )
            with gr.Accordion(label="Generation Settings", open=False):
                seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                gr.Markdown("Stage 1: Sparse Structure Generation")
                with gr.Row():
                    ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
                    ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
                gr.Markdown("Stage 2: Structured Latent Generation")
                with gr.Row():
                    slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
                    slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
                multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Multi-image Algorithm", value="stochastic")
            generate_btn = gr.Button("Generate")
            with gr.Accordion(label="GLB Extraction Settings", open=False):
                mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
                texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
            with gr.Row():
                extract_glb_btn = gr.Button("Extract GLB", interactive=False)
                extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
            with gr.Row():
                quick_generate_glb_btn = gr.Button("Quick Generate GLB")
                quick_generate_gs_btn = gr.Button("Quick Generate Gaussian")
            gr.Markdown("""
*NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.*
""")
        with gr.Column():
            video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
            model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=10.0, height=300)
            with gr.Row():
                download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
                download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
    output_buf = gr.State()

    # Example images at the bottom of the page (sorted for a stable order)
    with gr.Row() as single_image_example:
        examples = gr.Examples(
            examples=[
                f'assets/example_image/{image}'
                for image in sorted(os.listdir("assets/example_image"))
            ],
            inputs=[image_prompt],
            fn=preprocess_image,
            outputs=[image_prompt],
            run_on_click=True,
            examples_per_page=64,
        )
    with gr.Row(visible=False) as multiimage_example:
        examples_multi = gr.Examples(
            examples=prepare_multi_example(),
            inputs=[image_prompt],
            fn=split_image,
            outputs=[multiimage_prompt],
            run_on_click=True,
            examples_per_page=8,
        )

    # Handlers
    demo.load(start_session)
    demo.unload(end_session)

    # Tab switching only toggles which example row is shown; the mode flag is
    # handled by `input_tabs.select` above. `gr.Row.update` no longer exists in
    # Gradio 4+, so `gr.update` is used for the visibility changes.
    single_image_input_tab.select(
        lambda: tuple([gr.update(visible=True), gr.update(visible=False)]),
        outputs=[single_image_example, multiimage_example]
    )
    multiimage_input_tab.select(
        lambda: tuple([gr.update(visible=False), gr.update(visible=True)]),
        outputs=[single_image_example, multiimage_example]
    )

    image_prompt.upload(
        preprocess_image,
        inputs=[image_prompt],
        outputs=[image_prompt],
    )
    multiimage_prompt.upload(
        preprocess_images,
        inputs=[multiimage_prompt],
        outputs=[multiimage_prompt],
    )

    generate_btn.click(
        get_seed,
        inputs=[randomize_seed, seed],
        outputs=[seed],
    ).then(
        image_to_3d,
        inputs=[
            image_prompt, multiimage_prompt, is_multiimage, seed,
            ss_guidance_strength, ss_sampling_steps,
            slat_guidance_strength, slat_sampling_steps, multiimage_algo
        ],
        outputs=[output_buf, video_output],
    ).then(
        lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
        outputs=[extract_glb_btn, extract_gs_btn],
    )

    video_output.clear(
        lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
        outputs=[extract_glb_btn, extract_gs_btn],
    )

    extract_glb_btn.click(
        extract_glb,
        inputs=[output_buf, mesh_simplify, texture_size],
        outputs=[model_output, download_glb],
    ).then(
        lambda: gr.Button(interactive=True),
        outputs=[download_glb],
    )

    extract_gs_btn.click(
        extract_gaussian,
        inputs=[output_buf],
        outputs=[model_output, download_gs],
    ).then(
        lambda: gr.Button(interactive=True),
        outputs=[download_gs],
    )

    # The viewer shows either a GLB or a Gaussian, so clearing it disables both downloads.
    model_output.clear(
        lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
        outputs=[download_glb, download_gs],
    )

    # One-shot endpoints; api_name is set explicitly so the names advertised
    # in the page description are what the API actually exposes.
    quick_generate_glb_btn.click(
        fn=quick_generate_glb,
        inputs=[
            image_prompt,
            multiimage_prompt,
            is_multiimage,
            seed,
            ss_guidance_strength,
            ss_sampling_steps,
            slat_guidance_strength,
            slat_sampling_steps,
            multiimage_algo,
            mesh_simplify,
            texture_size,
        ],
        outputs=[model_output, download_glb],
        api_name="quick_generate_glb",
    ).then(
        # The download button starts non-interactive; enable it once a file exists.
        lambda: gr.Button(interactive=True),
        outputs=[download_glb],
    )

    quick_generate_gs_btn.click(
        fn=quick_generate_gs,
        inputs=[
            image_prompt,
            multiimage_prompt,
            is_multiimage,
            seed,
            ss_guidance_strength,
            ss_sampling_steps,
            slat_guidance_strength,
            slat_sampling_steps,
            multiimage_algo,
        ],
        outputs=[model_output, download_gs],
        api_name="quick_generate_gs",
    ).then(
        lambda: gr.Button(interactive=True),
        outputs=[download_gs],
    )
# Launch the Gradio app
if __name__ == "__main__":
    # `pipeline` is a module global read by the handlers above at call time.
    pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
    pipeline.cuda()
    # Best-effort warm-up of the background-removal path with a blank RGB image.
    # Catch Exception (not a bare except) so Ctrl-C / SystemExit still propagate,
    # and report the failure instead of hiding it.
    try:
        pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))  # Preload rembg
    except Exception as exc:
        print(f"Warning: rembg warm-up failed ({exc!r}); continuing without warm-up.")
    demo.launch(show_error=True, show_api=True)