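"""Gradio app: prepare Dreambooth training images with Stable Diffusion inpainting.

The user selects a region of an uploaded photo; the app letterboxes that region
onto a 512x512 canvas and uses Stable Diffusion inpainting to replace the gray
border with a background described by a text prompt.
"""
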
import gradio as gr
import cv2
import numpy as np
import os
import torch
from PIL import Image
from diffusers import StableDiffusionInpaintPipeline


# Use the READ_TOKEN secret if set; otherwise True tells diffusers to use the locally cached Hugging Face token.
auth_token = os.environ.get("READ_TOKEN") or True

# The inpainting pipeline is loaded lazily on the first GPU request.
pipe = None
def preview(image, state):
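    """Letterbox the selected region onto a 512x512 canvas and build the matching inpainting mask."""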
    h, w = image.shape[:2]
    # Scale factor so that the longest side becomes exactly 512 pixels.
    scale = 512 / max(h, w)

    width = int(w * scale)
    height = int(h * scale)
    dim = (width, height)
    resized = cv2.resize(image, dim, interpolation=cv2.INTER_AREA)
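    # Offsets that center the resized selection on the 512x512 canvas.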
    yoff = round((512-height)/2)
    xoff = round((512-width)/2)

    # Gray 512x512 canvas with the resized selection pasted in the center.
    final_image = np.full((512, 512, 3), 120, dtype=np.uint8)
    final_image[yoff:yoff+height, xoff:xoff+width, :] = resized

    # Inpainting mask: white (255) where the background will be generated, black (0) over the photo.
    mask_image = np.full((512, 512, 3), 255, dtype=np.uint8)
    mask_image[yoff:yoff+height, xoff:xoff+width, :] = 0

    # Keep the mask, the paste geometry, and the resized photo for the inpainting step.
    state.clear()
    state.append(mask_image)
    state.append([yoff, xoff, height, width])
    state.append(resized)

    return final_image, state


def sd_inpaint(image, prompt, state):
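    """Replace the gray border of the 512x512 preview with a background generated from the prompt."""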
    global pipe
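    # Unpack the mask, paste geometry, and resized photo stored by preview().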
    mask = state[0]
    yoff, xoff, height, width = state[1]
    orig_image = state[2]

    device = "cuda" if torch.cuda.is_available() else "cpu"
    
    if device == "cuda":
        if pipe is None:
            pipe = StableDiffusionInpaintPipeline.from_pretrained(
                "runwayml/stable-diffusion-inpainting",
                revision="fp16",
                torch_dtype=torch.float16,
                use_auth_token=auth_token
            ).to(device)
        output = pipe(prompt=prompt, image=Image.fromarray(image), mask_image=Image.fromarray(mask)).images[0]
    else:
        output = image

    # Paste the untouched selection back over the generated background so the subject is not altered.
    result = np.array(output)
    result[yoff:yoff+height, xoff:xoff+width, :] = orig_image
    result = Image.fromarray(result)

    return result


with gr.Blocks(title='Dreambooth Image Editing and Stable Diffusion Inpainting') as demo:
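    # Shared state between the two steps: [mask_image, [yoff, xoff, height, width], resized_image].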
    state = gr.State([])
    gr.Markdown("# Dreambooth Image Editing and Stable Diffusion Inpainting")
    gr.Markdown("It's difficult to get a good image to use for dreambooth, I do not have many photograhps of myself alone and it's very slow to edit the images (crop the selection, scale it to 512x512 and solve the problem of the background somehow)")
    gr.Markdown("This app uses a combination of image selection, automatic scaling, and  stable diffusion inpainting to speed that process. Follow the next instructions:")
    gr.Markdown("""- Upload an image
- Use the select tool to select the area you want to use for dreambooth
- The image will be resized to 512x512 and fill the rest of with a gray background
- Then click the Inpaint button to use stable diffusion to inpaint the background
- Save the image and use it for dreambooth
    """)
    with gr.Row():
        with gr.Column():
            img_ctr = gr.Image(tool='select')
        with gr.Column():
            output = gr.Image(label="Selection with mask (512x512)")
    with gr.Row():
        greet_btn = gr.Button("Selection")
    with gr.Row():
        sd_prompt = gr.Textbox(lines=2, label="Stable diffusion prompt")
    with gr.Row():
        final_image = gr.Image(label='Generated Image (512x512)')
    with gr.Row():
        stab_btn = gr.Button("Inpaint")
    with gr.Row():
        gr.Examples([
            ['one.png', 'in the office'],
        ], inputs=[img_ctr, sd_prompt])
        gr.Examples([
            ['two.png', 'in the office'],
        ], inputs=[img_ctr, sd_prompt])

    # Wire up the buttons: "Selection" builds the 512x512 preview and mask, "Inpaint" fills the gray border.
    greet_btn.click(fn=preview, inputs=[img_ctr, state], outputs=[output, state])
    stab_btn.click(fn=sd_inpaint, inputs=[output, sd_prompt, state], outputs=final_image)


demo.launch()