File size: 4,758 Bytes
8c9048a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
import os
import torch
import cv2
import numpy as np
import pandas as pd
import random
from tqdm import tqdm

import webdataset as wds

from PIL import Image
from diffusers.utils import load_image, make_image_grid


def image_resize(image, width = None, height = None, inter = cv2.INTER_AREA):
    """Resize an image to a target width or height, keeping its aspect ratio.

    Args:
        image: array whose first two dimensions are (height, width).
        width: desired output width in pixels. When given, it takes
            precedence and ``height`` is ignored.
        height: desired output height in pixels, used only when ``width``
            is None.
        inter: OpenCV interpolation flag passed to ``cv2.resize``.

    Returns:
        The input unchanged when both ``width`` and ``height`` are None,
        otherwise the resized array. The computed side is truncated to int.
    """
    src_h, src_w = image.shape[:2]

    # Nothing requested: hand the image back as-is.
    if width is None and height is None:
        return image

    if width is not None:
        factor = width / float(src_w)
        target_size = (width, int(src_h * factor))
    else:
        factor = height / float(src_h)
        target_size = (int(src_w * factor), height)

    # cv2.resize expects dsize as (width, height).
    return cv2.resize(image, target_size, interpolation = inter)



# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# Base Stable Diffusion checkpoint (local directory or Hugging Face model id).
base_model_path = "data/ckpt/realisticVisionV60B1_v51VAE"
# base_model_path = "runwayml/stable-diffusion-v1-5"

# BrushNet checkpoint directory.
brushnet_path = "data/ckpt/segmentation_mask_brushnet_ckpt"

# Whether to use the blended operation (combine the generated image with the
# unmasked original using the mask).
blended = False

# Sub-directory (under examples/brushnet/) that receives the outputs.
out_dir = "cc3m_synthetic"
# Text prompts to run on every selected sample.
captions = ["A photo of a man",]


print("Generating images with captions:")
for caption in captions:
    print("\t", caption)


# Directory expected to hold one <key>.png mask per dataset sample.
mask_root = "../SemanticSegmentation/mask_skin"
# Directory containing the CC3M webdataset .tar shards (placeholder to fill in).
cc3m_original_path = "PASTE_HERE"
# Conditioning scale passed to the BrushNet pipeline.
brushnet_conditioning_scale = 1.0
# Seed for Python's `random` module. The per-image generator seeds are drawn
# from `random`, so seeding it here makes runs reproducible (previously the
# value was defined but never applied).
seed = 12345
random.seed(seed)

# ---------------------------------------------------------------------------
# Model / pipeline construction
# ---------------------------------------------------------------------------
brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch.float16)
pipe = StableDiffusionBrushNetPipeline.from_pretrained(
    base_model_path,
    brushnet=brushnet,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=False,
    safety_checker=None,
    requires_safety_checker=False,
)

# NOTE(review): n_steps and high_noise_frac do not appear to be used anywhere
# in this script — verify before relying on them.
n_steps = 50
high_noise_frac = 0.8
# Replace the default scheduler with UniPC, reusing the existing scheduler config.
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
# Optional: enable the line below only if xformers is installed; it is not
# needed with PyTorch 2.0's built-in attention.
# pipe.enable_xformers_memory_efficient_attention()
# Memory optimization: offload sub-models to the CPU while they are idle.
pipe.enable_model_cpu_offload()


# Outputs are written here. cv2.imwrite does not create missing directories
# (it just returns False), so create the directory up front.
save_dir = f"examples/brushnet/{out_dir}"
os.makedirs(save_dir, exist_ok=True)

# Number of images successfully written during this run.
generated = 0
# Stream (image, metadata) pairs from the webdataset shards; "rgb8" decodes
# images to uint8 RGB numpy arrays.
dataset = wds.WebDataset(cc3m_original_path + "/00{000..331}.tar").decode("rgb8").to_tuple("jpg;png", "json")
for sample in tqdm(dataset):
    img = sample[0]
    name = sample[1]["key"]
    mask_path = f"{mask_root}/{name}.png"
    out_path = f"{save_dir}/{name}_man.png"

    # Only process samples that have a mask and were not generated yet, so an
    # interrupted run can resume.
    if not os.path.exists(mask_path) or os.path.exists(out_path):
        continue

    mask_bgr = cv2.imread(mask_path)
    if mask_bgr is None:  # unreadable or corrupt mask file
        continue
    # A pixel is masked when the sum of its three colour channels exceeds 255.
    mask = 1. * (mask_bgr.sum(-1) > 255)

    # Resize so the shorter side becomes 768 px, preserving aspect ratio.
    h, w, _ = img.shape
    scale = 768 / min(h, w)
    new_h, new_w = int(h * scale), int(w * scale)
    original = cv2.resize(img, (new_w, new_h))
    # Nearest-neighbour keeps the mask strictly 0/1. Bilinear interpolation
    # left fractional edge values that darkened init_image but were truncated
    # to 0 by astype(np.uint8), so those pixels were blacked out yet excluded
    # from the inpainting mask.
    mask = cv2.resize(mask, (new_w, new_h), interpolation=cv2.INTER_NEAREST)[:, :, np.newaxis]

    # Skip samples whose mask covers fewer than 50 pixels. This must test the
    # 0/1 mask: on the uint8 image, `> 0.9` counts every non-black value.
    if np.count_nonzero(mask) < 50:
        continue

    # Black out the region to inpaint; the mask image is 255 where BrushNet
    # should generate and 0 elsewhere.
    init_image = Image.fromarray((original * (1 - mask)).astype(np.uint8))
    mask_image = Image.fromarray(mask.astype(np.uint8).repeat(3, -1) * 255)

    for caption in captions:
        # NOTE(review): every caption writes to the same `_man.png` path, so
        # with several captions only the last result is kept.
        generator = torch.Generator("cuda").manual_seed(random.randint(3000000000, 6000000000))
        image = pipe(
            caption,
            init_image,
            mask_image,
            num_inference_steps=10,
            generator=generator,
            brushnet_conditioning_scale=brushnet_conditioning_scale,
        ).images[0]
        result = np.asarray(image)

        if blended:
            # Keep original pixels outside the mask and generated ones inside.
            # The pipeline output may differ slightly in size from the input,
            # so match it first; cv2.resize takes dsize as (width, height).
            result = cv2.resize(result, (new_w, new_h), interpolation=cv2.INTER_AREA)
            result = (original * (1 - mask) + result * mask).astype(result.dtype)

        # Downscale to 256 px wide and convert RGB -> BGR for cv2.imwrite.
        result = image_resize(result, width=256)[:, :, ::-1]
        if not cv2.imwrite(out_path, result):
            raise OSError(f"failed to write {out_path}")
        generated += 1