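"""Gradio demo for Dereflection Any Image (DAI).

Loads the DAI dereflection pipeline and serves a web UI that removes glass
reflections from an uploaded photo.
"""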
from __future__ import annotations

import functools
import os
import tempfile
from pathlib import Path

import gradio
import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import (
    AutoencoderKL,
    UNet2DConditionModel,
)
from gradio.utils import get_cache_folder
from gradio_imageslider import ImageSlider
from PIL import Image
from transformers import AutoTokenizer, CLIPTextModel

from DAI.controlnetvae import ControlNetVAEModel
from DAI.decoder import CustomAutoencoderKL
from DAI.pipeline_all import DAIPipeline

# TF32 matmuls are faster on Ampere and newer GPUs at negligible precision cost.
torch.backends.cuda.matmul.allow_tf32 = True


class Examples(gradio.helpers.Examples):
    """Examples helper whose cache can live in a named subdirectory of
    Gradio's cache folder (via ``directory_name``)."""

    def __init__(self, *args, directory_name=None, **kwargs):
        super().__init__(*args, **kwargs, _initiated_directly=False)
        if directory_name is not None:
            self.cached_folder = get_cache_folder() / directory_name
            self.cached_file = Path(self.cached_folder) / "log.csv"
        self.create()


default_seed = 2024
default_batch_size = 1


def process_image_check(path_input):
    # Fail fast with a user-facing error when no image has been uploaded.
    if path_input is None:
        raise gr.Error(
            "Missing image in the first pane: upload a file or use one from the gallery below."
        )
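

# A worked example of the resizing below (values are illustrative): with
# resolution=512, a 600x800 (HxW) input gives k = 512/600, i.e. 512.0x682.7,
# which snaps to 512x704 after rounding each side to a multiple of 64.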
def resize_image(input_image, resolution):
    if not isinstance(input_image, Image.Image):
        raise ValueError("input_image should be a PIL Image object")

    input_image_np = np.asarray(input_image)
    H, W, C = input_image_np.shape
    H = float(H)
    W = float(W)

    # Scale so the shorter side matches the target resolution.
    k = float(resolution) / min(H, W)
    H *= k
    W *= k

    # Round each side to the nearest multiple of 64.
    H = int(np.round(H / 64.0)) * 64
    W = int(np.round(W / 64.0)) * 64

    img = input_image.resize((W, H), Image.Resampling.LANCZOS)
    return img
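

# process_image runs the DAI pipeline on a single input file with the fixed
# prompt "remove glass reflection", rescales the [-1, 1] prediction to 8-bit
# RGB, and yields the (input, output) pair consumed by the before/after slider.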
def process_image(
    pipe,
    vae_2,
    path_input,
):
    name_base, name_ext = os.path.splitext(os.path.basename(path_input))
    print(f"Processing image {name_base}{name_ext}")

    path_output_dir = tempfile.mkdtemp()
    path_out_png = os.path.join(path_output_dir, f"{name_base}_delight.png")
    input_image = Image.open(path_input)
    resolution = None  # defer resolution handling to the pipeline

    pipe_out = pipe(
        image=input_image,
        prompt="remove glass reflection",
        vae_2=vae_2,
        processing_resolution=resolution,
    )

    # Map the prediction from [-1, 1] to an 8-bit image and save it.
    processed_frame = (pipe_out.prediction.clip(-1, 1) + 1) / 2
    processed_frame = (processed_frame[0] * 255).astype(np.uint8)
    processed_frame = Image.fromarray(processed_frame)
    processed_frame.save(path_out_png)
    yield [input_image, path_out_png]
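

# run_demo_server builds the Gradio Blocks UI around the pipeline.
# functools.partial binds the loaded models to process_image, and spaces.GPU
# wraps the call so a GPU is attached when running on Hugging Face ZeroGPU.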
def run_demo_server(pipe, vae_2):
    process_pipe_image = spaces.GPU(functools.partial(process_image, pipe, vae_2))

    gradio_theme = gr.themes.Default()

    with gr.Blocks(
        theme=gradio_theme,
        title="Dereflection Any Image",
        css="""
            #download {
                height: 118px;
            }
            .slider .inner {
                width: 5px;
                background: #FFF;
            }
            .viewport {
                aspect-ratio: 4/3;
            }
            .tabs button.selected {
                font-size: 20px !important;
                color: crimson !important;
            }
            h1, h2, h3 {
                text-align: center;
                display: block;
            }
            .md_feedback li {
                margin-bottom: 0px !important;
            }
        """,
        head="""
            <script async src="https://www.googletagmanager.com/gtag/js?id=G-1FWSVCGZTG"></script>
            <script>
                window.dataLayer = window.dataLayer || [];
                function gtag() {dataLayer.push(arguments);}
                gtag('js', new Date());
                gtag('config', 'G-1FWSVCGZTG');
            </script>
        """,
    ) as demo:
        gr.Markdown(
            """
            # Dereflection Any Image
            <p align="center">
            """
        )

        with gr.Tabs(elem_classes=["tabs"]):
            with gr.Tab("Image"):
                with gr.Row():
                    with gr.Column():
                        image_input = gr.Image(
                            label="Input Image",
                            type="filepath",
                        )
                        with gr.Row():
                            image_submit_btn = gr.Button(
                                value="Remove Reflection", variant="primary"
                            )
                            image_reset_btn = gr.Button(value="Reset")
                    with gr.Column():
                        image_output_slider = ImageSlider(
                            label="Outputs",
                            type="filepath",
                            show_download_button=True,
                            show_share_button=True,
                            interactive=False,
                            elem_classes="slider",
                        )

                Examples(
                    fn=process_pipe_image,
                    examples=sorted([
                        os.path.join("files", "image", name)
                        for name in os.listdir(os.path.join("files", "image"))
                    ]),
                    inputs=[image_input],
                    outputs=[image_output_slider],
                    cache_examples=False,
                    directory_name="examples_image",
                )

        # Validate the input first; the pipeline only runs if the check passes.
        image_submit_btn.click(
            fn=process_image_check,
            inputs=image_input,
            outputs=None,
            preprocess=False,
            queue=False,
        ).success(
            fn=process_pipe_image,
            inputs=[image_input],
            outputs=[image_output_slider],
            concurrency_limit=1,
        )

        # Reset clears the input pane and the output slider; one value is
        # returned per output component.
        image_reset_btn.click(
            fn=lambda: (None, None),
            inputs=[],
            outputs=[
                image_input,
                image_output_slider,
            ],
            queue=False,
        )

    demo.queue(
        api_open=False,
    ).launch(
        server_name="0.0.0.0",
        server_port=7860,
    )
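

# main loads every model component once at startup: the DAI-specific weights
# (controlnet, unet, vae_2) from JichenHu/dereflection-any-image-v0 and the
# remaining Stable Diffusion 2.1 parts from stabilityai/stable-diffusion-2-1.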
def main():
    os.system("pip freeze")  # log installed packages for debugging

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    weight_dtype = torch.float32
    pretrained_model_name_or_path = "JichenHu/dereflection-any-image-v0"
    pretrained_model_name_or_path2 = "stabilityai/stable-diffusion-2-1"
    revision = None
    variant = None

    # DAI-specific components.
    controlnet = ControlNetVAEModel.from_pretrained(
        pretrained_model_name_or_path, subfolder="controlnet", torch_dtype=weight_dtype
    ).to(device)
    unet = UNet2DConditionModel.from_pretrained(
        pretrained_model_name_or_path, subfolder="unet", torch_dtype=weight_dtype
    ).to(device)
    vae_2 = CustomAutoencoderKL.from_pretrained(
        pretrained_model_name_or_path, subfolder="vae_2", torch_dtype=weight_dtype
    ).to(device)

    # Stock Stable Diffusion 2.1 components.
    vae = AutoencoderKL.from_pretrained(
        pretrained_model_name_or_path2, subfolder="vae", revision=revision, variant=variant
    ).to(device)
    text_encoder = CLIPTextModel.from_pretrained(
        pretrained_model_name_or_path2, subfolder="text_encoder", revision=revision, variant=variant
    ).to(device)
    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path2,
        subfolder="tokenizer",
        revision=revision,
        use_fast=False,
    )

    pipe = DAIPipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        controlnet=controlnet,
        safety_checker=None,
        scheduler=None,
        feature_extractor=None,
        t_start=0,
    ).to(device)

    # Use memory-efficient attention when xformers is installed.
    try:
        import xformers  # noqa: F401

        pipe.enable_xformers_memory_efficient_attention()
    except Exception:
        pass

    run_demo_server(pipe, vae_2)


if __name__ == "__main__":
    main()