File size: 8,089 Bytes
476e0f0
 
 
 
 
 
 
 
 
 
 
 
b823627
 
 
 
 
 
 
 
 
 
476e0f0
 
 
0e03f38
7760d2d
 
 
 
 
 
 
 
 
 
 
 
 
476e0f0
7760d2d
 
 
476e0f0
 
7760d2d
 
 
 
 
476e0f0
 
 
7760d2d
 
 
476e0f0
 
 
 
f3d0960
7760d2d
b823627
7760d2d
 
 
476e0f0
3a3602f
 
0e03f38
476e0f0
7760d2d
 
476e0f0
7760d2d
 
 
476e0f0
 
 
 
 
7760d2d
476e0f0
 
 
 
 
7760d2d
 
 
476e0f0
 
 
 
 
 
 
 
 
 
 
91544fb
 
 
476e0f0
 
 
 
 
 
 
 
 
 
 
 
 
aeba29c
476e0f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0e03f38
476e0f0
 
 
 
 
7760d2d
b823627
476e0f0
 
 
 
 
 
7760d2d
476e0f0
 
 
 
 
 
 
 
 
 
 
 
 
aeba29c
 
476e0f0
aeba29c
476e0f0
 
 
 
 
aeba29c
 
 
476e0f0
aeba29c
476e0f0
 
 
 
 
 
7560047
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
import os
import shlex
import subprocess
import imageio
import numpy as np

import gradio as gr
import spaces
import sys
from loguru import logger
# Absolute directory that holds this script; output paths are built from it.
current_path = os.path.dirname(os.path.abspath(__file__))

# try:
#     import diff_gaussian_rasterization  # noqa: F401
# except ImportError:
#     @spaces.GPU
#     def install_diff_gaussian_rasterization():
#         os.system("pip install ./extensions/RaDe-GS/submodules/diff-gaussian-rasterization")
#     install_diff_gaussian_rasterization()



# Largest value of a signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max

# Root folder for uploaded images and inference results; created eagerly so
# later writes never have to check for it.
TMP_DIR = os.path.join(current_path, "out")
os.makedirs(TMP_DIR, exist_ok=True)

# Experiment tags per backbone. In each pair, index 0 is the tag used for
# text-only requests and index 1 the tag used when an image is supplied.
TAG = {
    # Best efficiency
    "SD15": [
        "gsdiff_gobj83k_sd15__render",
        "gsdiff_gobj83k_sd15_image__render",
    ],
    "PixArt-Sigma": [
        "gsdiff_gobj83k_pas_fp16__render",
        "gsdiff_gobj83k_pas_fp16_image__render",
    ],
    # Best performance
    "SD3": [
        "gsdiff_gobj83k_sd35m__render",
        "gsdiff_gobj83k_sd35m_image__render",
    ],
}

# Key into TAG selecting the backbone this app serves.
MODEL_TYPE = "PixArt-Sigma"

# for PixArt-Sigma
# Download the text- and image-conditioned checkpoints when the module is
# loaded. A non-zero exit status does not abort startup, but it is logged so
# that a later inference failure can be traced back to a missing checkpoint.
for _ckpt_cmd in (
    "python3 download_ckpt.py --model_type pas",  # for txt condition
    "python3 download_ckpt.py --model_type pas --image_cond",  # for img condition
):
    _ckpt_result = subprocess.run(shlex.split(_ckpt_cmd))
    if _ckpt_result.returncode != 0:
        logger.warning(
            "Checkpoint download exited with status {}: {}",
            _ckpt_result.returncode,
            _ckpt_cmd,
        )

# Shell command template for image-conditioned inference. Placeholders, in
# order: tag, guidance scale, image path, elevation, prompt, seed.
img_commands = "PYTHONPATH=./ bash scripts/infer.sh src/infer_gsdiff_pas.py configs/gsdiff_pas.yaml {} \
--rembg_and_center --triangle_cfg_scaling --save_ply --output_video_type mp4 --guidance_scale {} \
--image_path {} --elevation {} --prompt {}  --seed {}"

# Shell command template for text-only inference. Placeholders, in order:
# tag, prompt, seed.
txt_commands = "PYTHONPATH=./ bash scripts/infer.sh src/infer_gsdiff_pas.py  configs/gsdiff_pas.yaml {} \
 --save_ply --output_video_type mp4 \
 --prompt {} --seed {}"


# for SD1.5
# subprocess.run(shlex.split("python3 download_ckpt.py --model_type sd15")) # for txt condition
# subprocess.run(shlex.split("python3 download_ckpt.py --model_type sd15 --image_cond")) # for img condition

# img_commands = "PYTHONPATH=./ bash scripts/infer.sh src/infer_gsdiff_sd.py configs/gsdiff_sd15.yaml {} \
# --rembg_and_center --triangle_cfg_scaling --save_ply --output_video_type mp4 --guidance_scale {} \
# --image_path {} --elevation {} --prompt {} --seed {}"

# txt_commands = "PYTHONPATH=./ bash scripts/infer.sh src/infer_gsdiff_sd.py  configs/gsdiff_sd15.yaml {} \
#  --save_ply --output_video_type mp4 --guidance_scale {} \
#  --elevation {} --prompt {} --seed {}"



# process function
@spaces.GPU(duration=120)
def process(input_image, prompt='a_high_quality_3D_asset', prompt_neg='poor_quality', input_elevation=20, guidance_scale=2., input_seed=0):
    """Run the inference script in a shell and return the generated artifacts.

    When ``input_image`` is given it is saved under ``TMP_DIR`` and the
    image-conditioned command template is used; otherwise the text-only
    template is used.

    Args:
        input_image: Object with a ``save(path)`` method (the UI passes a PIL
            image), or ``None`` for text-only generation.
        prompt: Text prompt. Underscores are turned into spaces when the
            output file name is derived.
        prompt_neg: Negative prompt. Accepted for interface compatibility but
            not forwarded to the inference command.
        input_elevation: Elevation passed as ``--elevation``; only used when
            an image is supplied.
        guidance_scale: Value passed as ``--guidance_scale``; only used when
            an image is supplied.
        input_seed: Value passed as ``--seed``.

    Returns:
        A tuple ``(image, video_path, ply_path)``: the ``*_gs.png`` result
        loaded with ``imageio``, and the paths of the ``.mp4`` and ``.ply``
        files.

    Raises:
        gr.Error: If the inference command exits with a non-zero status.
    """
    # fail to install RaDe-GS 
    # subprocess.run("cd extensions/RaDe-GS/submodules && pip3 install diff-gaussian-rasterization", shell=True)
    # subprocess.run("cd extensions/RaDe-GS/submodules/diff-gaussian-rasterization && python3 setup.py bdist_wheel ", shell=True)

    def _shell_arg(value):
        # Every interpolated value ends up on a shell command line, and the
        # prompt comes straight from the user, so quote all of them.
        return shlex.quote(str(value))

    if input_image is not None:
        import uuid
        # A random file name keeps concurrent requests from overwriting each other.
        image_path = os.path.join(TMP_DIR, f"{uuid.uuid4()}.png")
        image_name = os.path.splitext(os.path.basename(image_path))[0] + "_rgba"
        input_image.save(image_path)
        TAG_DEST = TAG[MODEL_TYPE][1]
        full_command = img_commands.format(
            _shell_arg(TAG_DEST),
            _shell_arg(guidance_scale),
            _shell_arg(image_path),
            _shell_arg(input_elevation),
            _shell_arg(prompt),
            _shell_arg(input_seed),
        )
    else:
        TAG_DEST = TAG[MODEL_TYPE][0]
        # without guidance_scale and input_elevation
        full_command = txt_commands.format(
            _shell_arg(TAG_DEST),
            _shell_arg(prompt),
            _shell_arg(input_seed),
        )
        image_name = ""

    # The template starts with an environment assignment (`PYTHONPATH=./`),
    # so it still has to go through a shell.
    completed = subprocess.run(full_command, shell=True)
    if completed.returncode != 0:
        # Stop here instead of returning missing or stale files.
        logger.error("Inference command exited with status {}: {}", completed.returncode, full_command)
        raise gr.Error(f"3D generation failed (exit code {completed.returncode}). Please try again.")

    # save video and ply files
    ckpt_dir = os.path.join(TMP_DIR, TAG_DEST, "checkpoints")
    # Last entry in sorted order; the entry names must parse as integers.
    infer_from_iter = int(sorted(os.listdir(ckpt_dir))[-1])
    MAX_NAME_LEN = 20  # TODO: make `20` configurable
    prompt = prompt.replace("_", " ")
    # NOTE(review): this naming presumably has to mirror the file names the
    # inference script writes, so the expression is kept exactly as it was.
    prompt_name = prompt[:MAX_NAME_LEN] + "..." if prompt[:MAX_NAME_LEN] != "" else prompt
    name = f"[{image_name}]_[{prompt_name}]_{infer_from_iter:06d}"
    inference_dir = os.path.join(TMP_DIR, TAG_DEST, "inference")
    output_video_path = os.path.join(inference_dir, name + ".mp4")
    output_ply_path = os.path.join(inference_dir, name + ".ply")
    output_img_path = os.path.join(inference_dir, name + "_gs.png")

    # loguru treats the first argument as a format string, so the values are
    # passed through explicit placeholders.
    logger.info("command: {} | video: {} | ply: {}", full_command, output_video_path, output_ply_path)

    output_image = imageio.imread(output_img_path)
    return output_image, output_video_path, output_ply_path


# gradio UI
_TITLE = '''DiffSplat: Repurposing Image Diffusion Models for Scalable Gaussian Splat Generation'''

_DESCRIPTION = '''

<strong style="color:red;">This space is currently under maintenance. If you experience slow or failed generations, please try again. We apologize for any inconvenience.</strong>

### If you find our work helpful, please consider citing our paper πŸ“š or giving the repo a star 🌟
<div>
<a style="display:inline-block; margin-left: .5em" href="https://chenguolin.github.io/projects/DiffSplat"><img src='https://img.shields.io/badge/Project-Page-brightgreen'/></a>  
<a style="display:inline-block; margin-left: .5em" href="https://arxiv.org/abs/2501.16764"><img src='https://img.shields.io/badge/arXiv-2501.16764-b31b1b.svg?logo=arXiv'/></a>  
<a style="display:inline-block; margin-left: .5em" href="https://github.com/chenguolin/DiffSplat"><img src='https://img.shields.io/github/stars/chenguolin/DiffSplat?style=social'/></a>  
<a style="display:inline-block; margin-left: .5em" href="https://huggingface.co/chenguolin/DiffSplat"><img src='https://img.shields.io/badge/HF-Model-yellow'/></a>  
</div>

* Input can be only text, only image, or both image and text. 
* If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
* Upload an image and click "Generate" to create a 3D asset. If the image has alpha channel, it be used as the mask. Otherwise, we use `rembg` to remove the background.
'''

# Page layout: heading and description on top, then a panel row with the
# input controls on the left and the result tabs on the right, followed by
# two sets of clickable examples.
block = gr.Blocks(title=_TITLE)
with block:
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown('# ' + _TITLE)
    gr.Markdown(_DESCRIPTION)

    with gr.Row(variant='panel'):
        # `scale` is documented as an integer; 5:4 keeps the former 1:0.8
        # width ratio between the input and output columns.
        with gr.Column(scale=5):
            # input image
            input_image = gr.Image(label="image", type='pil')
            
            # input prompt
            input_text = gr.Textbox(label="prompt",value="a_high_quality_3D_asset")
            
            # negative prompt
            input_neg_text = gr.Textbox(label="negative prompt", value="ugly, blurry, pixelated obscure, unnatural colors, poor lighting, dull, unclear, cropped, lowres, low quality, artifacts, duplicate")

            # guidance_scale
            guidance_scale = gr.Slider(label="guidance scale", minimum=1., maximum=7.5, step=0.5, value=2.0)
            
            # elevation
            input_elevation = gr.Slider(label="elevation", minimum=-90, maximum=90, step=1, value=10)

            # random seed
            input_seed = gr.Slider(label="random seed", minimum=0, maximum=100000, step=1, value=0)
            # gen button
            button_gen = gr.Button("Generate")


        with gr.Column(scale=4):
            with gr.Tab("Video"):
                # final video results
                output_video = gr.Video(label="video")
                # ply file
                output_file = gr.File(label="3D Gaussians (ply format)")
            with gr.Tab("Splatter Images"):
                output_image = gr.Image(interactive=False, show_label=False)
                

        # The input order must match the positional parameters of `process`.
        button_gen.click(process, inputs=[input_image, input_text, input_neg_text, input_elevation, guidance_scale, input_seed], outputs=[output_image, output_video, output_file])

    gr.Examples(
        examples=[
            [f'assets/diffsplat/{image}', "a_high_quality_3D_asset"]
            # Sorted so the examples appear in a stable order; os.listdir
            # returns entries in arbitrary order.
            for image in sorted(os.listdir("assets/diffsplat")) if image.endswith('.png')
        ],
        inputs=[input_image, input_text],
        label='Image-to-3D Examples'
    )

    gr.Examples(
        examples=[
            ["a_toy_robot", None],
            ["a_cute_panda", None],
            ["an_ancient_leather-bound_book", None]
        ],
        inputs=[input_text, input_image],
        label='Text-to-3D Examples'
    )
    

# Launch the Gradio app
# Guarded so the server only starts when this file is run as a script, not
# when it is imported.
if __name__ == "__main__":
    # share=True requests a public share link from Gradio in addition to the
    # local server -- NOTE(review): confirm this is wanted on the hosting platform.
    block.launch(share=True)