Spaces:
Running
on
Zero
Running
on
Zero
File size: 8,089 Bytes
476e0f0 b823627 476e0f0 0e03f38 7760d2d 476e0f0 7760d2d 476e0f0 7760d2d 476e0f0 7760d2d 476e0f0 f3d0960 7760d2d b823627 7760d2d 476e0f0 3a3602f 0e03f38 476e0f0 7760d2d 476e0f0 7760d2d 476e0f0 7760d2d 476e0f0 7760d2d 476e0f0 91544fb 476e0f0 aeba29c 476e0f0 0e03f38 476e0f0 7760d2d b823627 476e0f0 7760d2d 476e0f0 aeba29c 476e0f0 aeba29c 476e0f0 aeba29c 476e0f0 aeba29c 476e0f0 7560047 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 |
import os
import shlex
import subprocess
import imageio
import numpy as np
import gradio as gr
import spaces
import sys
from loguru import logger
# Directory containing this script; generated artifacts are stored beneath it.
current_path = os.path.dirname(os.path.abspath(__file__))

# NOTE: a lazy, GPU-scoped install of RaDe-GS's diff-gaussian-rasterization
# extension used to live here and is currently disabled:
#
#     try:
#         import diff_gaussian_rasterization  # noqa: F401
#     except ImportError:
#         @spaces.GPU
#         def install_diff_gaussian_rasterization():
#             os.system("pip install ./extensions/RaDe-GS/submodules/diff-gaussian-rasterization")
#         install_diff_gaussian_rasterization()

# Largest value of a signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max

# Root for uploaded images and inference outputs; created eagerly at import time.
TMP_DIR = os.path.join(current_path, "out")
os.makedirs(TMP_DIR, exist_ok=True)

# Experiment tag per backbone.
# Index 0 is the text-conditioned run; index 1 is the image-conditioned run.
TAG = {
    "SD15": [
        "gsdiff_gobj83k_sd15__render",
        "gsdiff_gobj83k_sd15_image__render",
    ],  # Best efficiency
    "PixArt-Sigma": [
        "gsdiff_gobj83k_pas_fp16__render",
        "gsdiff_gobj83k_pas_fp16_image__render",
    ],
    "SD3": [
        "gsdiff_gobj83k_sd35m__render",
        "gsdiff_gobj83k_sd35m_image__render",
    ],  # Best performance
}
# Backbone served by this Space; selects the experiment tags in TAG.
MODEL_TYPE = "PixArt-Sigma"


def _download_ckpt(*extra_args):
    """Run download_ckpt.py with the current interpreter; log (not raise) on failure."""
    # sys.executable guarantees the same Python environment as the app,
    # unlike whatever `python3` happens to be first on PATH.
    cmd = [sys.executable, "download_ckpt.py", *extra_args]
    result = subprocess.run(cmd)
    if result.returncode != 0:
        logger.warning(
            "Checkpoint download failed (exit code {}): {}",
            result.returncode,
            shlex.join(cmd),
        )


# --- PixArt-Sigma ---
_download_ckpt("--model_type", "pas")  # for txt condition
_download_ckpt("--model_type", "pas", "--image_cond")  # for img condition

# Shell command templates filled by process().
# Image-conditioned placeholders, in order: tag, guidance_scale, image_path,
# elevation, prompt, seed.
img_commands = (
    "PYTHONPATH=./ bash scripts/infer.sh src/infer_gsdiff_pas.py configs/gsdiff_pas.yaml {} "
    "--rembg_and_center --triangle_cfg_scaling --save_ply --output_video_type mp4 --guidance_scale {} "
    "--image_path {} --elevation {} --prompt {} --seed {}"
)
# Text-only placeholders, in order: tag, prompt, seed.
# Guidance scale and elevation are not passed.
txt_commands = (
    "PYTHONPATH=./ bash scripts/infer.sh src/infer_gsdiff_pas.py configs/gsdiff_pas.yaml {} "
    "--save_ply --output_video_type mp4 "
    "--prompt {} --seed {}"
)

# --- SD1.5 (alternative; presumably MODEL_TYPE must also be set to "SD15") ---
# _download_ckpt("--model_type", "sd15")  # for txt condition
# _download_ckpt("--model_type", "sd15", "--image_cond")  # for img condition
# img_commands = "PYTHONPATH=./ bash scripts/infer.sh src/infer_gsdiff_sd.py configs/gsdiff_sd15.yaml {} \
# --rembg_and_center --triangle_cfg_scaling --save_ply --output_video_type mp4 --guidance_scale {} \
# --image_path {} --elevation {} --prompt {} --seed {}"
# txt_commands = "PYTHONPATH=./ bash scripts/infer.sh src/infer_gsdiff_sd.py configs/gsdiff_sd15.yaml {} \
# --save_ply --output_video_type mp4 --guidance_scale {} \
# --elevation {} --prompt {} --seed {}"
# NOTE(review): this SD1.5 txt template has five placeholders, but process()
# fills txt_commands with only three (tag, prompt, seed). Adjust process()
# before switching.
# Inference entry point, wired to the "Generate" button.
@spaces.GPU(duration=120)
def process(input_image, prompt='a_high_quality_3D_asset', prompt_neg='poor_quality', input_elevation=20, guidance_scale=2., input_seed=0):
    """Generate 3D Gaussians from an image and/or a text prompt.

    Fills the module-level template and runs it through the shell. It uses
    ``img_commands`` when an image is given and ``txt_commands`` for text
    only. It then locates the outputs written under ``TMP_DIR/<tag>/inference``.

    Args:
        input_image: PIL image from the UI, or ``None`` for text-to-3D.
        prompt: Text prompt; underscores stand in for spaces.
        prompt_neg: Negative prompt. NOTE(review): not forwarded to either
            command template, so it currently has no effect.
        input_elevation: Elevation; passed only when an image is given.
        guidance_scale: Guidance scale; passed only when an image is given.
        input_seed: Random seed.

    Returns:
        Tuple ``(splatter_image, video_path, ply_path)``, where
        ``splatter_image`` is read from ``<name>_gs.png``.

    Raises:
        gr.Error: If the inference command fails or no numeric checkpoint
            directory is found.
    """
    import uuid

    def _q(value):
        # Shell-quote every interpolated value. The prompt comes straight
        # from the UI and would otherwise be interpreted by /bin/sh.
        return shlex.quote(str(value))

    if input_image is not None:
        image_path = os.path.join(TMP_DIR, f"{uuid.uuid4()}.png")
        # NOTE(review): presumably the name the inference script gives the
        # processed input; the output file names below depend on it.
        image_name = os.path.splitext(os.path.basename(image_path))[0] + "_rgba"
        input_image.save(image_path)
        TAG_DEST = TAG[MODEL_TYPE][1]
        full_command = img_commands.format(
            _q(TAG_DEST), _q(guidance_scale), _q(image_path),
            _q(input_elevation), _q(prompt), _q(input_seed),
        )
    else:
        TAG_DEST = TAG[MODEL_TYPE][0]
        # The text-only template takes neither guidance_scale nor elevation.
        full_command = txt_commands.format(_q(TAG_DEST), _q(prompt), _q(input_seed))
        image_name = ""

    # NOTE: an in-function install of RaDe-GS's diff-gaussian-rasterization
    # was attempted here and failed, so it is not done.

    # Runs in the same /bin/sh as os.system. The exit status is checked so a
    # failed run does not go on to return stale or missing output files.
    result = subprocess.run(full_command, shell=True)
    if result.returncode != 0:
        raise gr.Error(f"Inference failed (exit code {result.returncode}).")

    # Pick the newest checkpoint by numeric value. A lexicographic sort
    # would rank e.g. "900" above "1000".
    ckpt_dir = os.path.join(TMP_DIR, TAG_DEST, "checkpoints")
    iterations = [int(entry) for entry in os.listdir(ckpt_dir) if entry.isdecimal()]
    if not iterations:
        raise gr.Error(f"No checkpoint found in {ckpt_dir}.")
    infer_from_iter = max(iterations)

    # Rebuild the output file stem.
    # NOTE(review): this must stay in sync with how the inference script
    # names its outputs.
    MAX_NAME_LEN = 20  # TODO: make `20` configurable
    prompt = prompt.replace("_", " ")
    prompt_name = prompt[:MAX_NAME_LEN] + "..." if prompt[:MAX_NAME_LEN] != "" else prompt
    name = f"[{image_name}]_[{prompt_name}]_{infer_from_iter:06d}"
    inference_dir = os.path.join(TMP_DIR, TAG_DEST, "inference")
    output_video_path = os.path.join(inference_dir, name + ".mp4")
    output_ply_path = os.path.join(inference_dir, name + ".ply")
    output_img_path = os.path.join(inference_dir, name + "_gs.png")
    # loguru substitutes the arguments into the "{}" placeholders. The values
    # are inserted, not re-parsed, so braces in the prompt are harmless.
    logger.info("{} | video: {} | ply: {}", full_command, output_video_path, output_ply_path)
    output_image = imageio.imread(output_img_path)
    return output_image, output_video_path, output_ply_path
# gradio UI
_TITLE = '''DiffSplat: Repurposing Image Diffusion Models for Scalable Gaussian Splat Generation'''
_DESCRIPTION = '''
<strong style="color:red;">This space is currently under maintenance. If you experience slow or failed generations, please try again. We apologize for any inconvenience.</strong>
### If you find our work helpful, please consider citing our paper π or giving the repo a star π
<div>
<a style="display:inline-block; margin-left: .5em" href="https://chenguolin.github.io/projects/DiffSplat"><img src='https://img.shields.io/badge/Project-Page-brightgreen'/></a>
<a style="display:inline-block; margin-left: .5em" href="https://arxiv.org/abs/2501.16764"><img src='https://img.shields.io/badge/arXiv-2501.16764-b31b1b.svg?logo=arXiv'/></a>
<a style="display:inline-block; margin-left: .5em" href="https://github.com/chenguolin/DiffSplat"><img src='https://img.shields.io/github/stars/chenguolin/DiffSplat?style=social'/></a>
<a style="display:inline-block; margin-left: .5em" href="https://huggingface.co/chenguolin/DiffSplat"><img src='https://img.shields.io/badge/HF-Model-yellow'/></a>
</div>
* Input can be only text, only image, or both image and text.
* If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
* Upload an image and click "Generate" to create a 3D asset. If the image has alpha channel, it be used as the mask. Otherwise, we use `rembg` to remove the background.
'''
# Gradio layout: header, an input panel beside an output panel, then examples.
block = gr.Blocks(title=_TITLE)
with block:
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown('# ' + _TITLE)
            gr.Markdown(_DESCRIPTION)
    with gr.Row(variant='panel'):
        # Gradio expects an integer `scale`; 5:4 keeps the original 1:0.8 width ratio.
        with gr.Column(scale=5):
            # input image
            input_image = gr.Image(label="image", type='pil')
            # input prompt (underscores separate words, as in the examples)
            input_text = gr.Textbox(label="prompt", value="a_high_quality_3D_asset")
            # negative prompt (NOTE(review): process() does not forward it to inference)
            input_neg_text = gr.Textbox(label="negative prompt", value="ugly, blurry, pixelated obscure, unnatural colors, poor lighting, dull, unclear, cropped, lowres, low quality, artifacts, duplicate")
            # guidance_scale
            guidance_scale = gr.Slider(label="guidance scale", minimum=1., maximum=7.5, step=0.5, value=2.0)
            # elevation
            input_elevation = gr.Slider(label="elevation", minimum=-90, maximum=90, step=1, value=10)
            # random seed
            input_seed = gr.Slider(label="random seed", minimum=0, maximum=100000, step=1, value=0)
            # gen button
            button_gen = gr.Button("Generate")
        with gr.Column(scale=4):
            with gr.Tab("Video"):
                # final video results
                output_video = gr.Video(label="video")
                # ply file
                output_file = gr.File(label="3D Gaussians (ply format)")
            with gr.Tab("Splatter Images"):
                output_image = gr.Image(interactive=False, show_label=False)
        # Input order must match process(input_image, prompt, prompt_neg,
        # input_elevation, guidance_scale, input_seed).
        button_gen.click(process, inputs=[input_image, input_text, input_neg_text, input_elevation, guidance_scale, input_seed], outputs=[output_image, output_video, output_file])
    # sorted() makes the example order deterministic; os.listdir order is arbitrary.
    gr.Examples(
        examples=[
            [f'assets/diffsplat/{image}', "a_high_quality_3D_asset"]
            for image in sorted(os.listdir("assets/diffsplat")) if image.endswith('.png')
        ],
        inputs=[input_image, input_text],
        label='Image-to-3D Examples'
    )
    gr.Examples(
        examples=[
            ["a_toy_robot", None],
            ["a_cute_panda", None],
            ["an_ancient_leather-bound_book", None]
        ],
        inputs=[input_text, input_image],
        label='Text-to-3D Examples'
    )
def _launch():
    """Start the Gradio server for the demo UI with ``share=True``."""
    block.launch(share=True)


# Launch the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    _launch()
|