# Inpaint the person region of FACET images with BrushNet, generating
# gender-swapped and same-gender variants from occupation-based captions.
import ast
import os

import cv2
import numpy as np
import pandas as pd
import torch
from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
from PIL import Image
from tqdm import tqdm

# Person ids selected for inpainting in the paper.
imagelist = pd.read_csv("examples/brushnet/paper_imagelist_for_inpainting.csv").values


def image_resize(image, width=None, height=None, inter=cv2.INTER_AREA):
    """Resize an image to the given width or height, preserving aspect ratio.

    If neither dimension is given, the image is returned unchanged; if both
    are given, `width` takes precedence.
    """
    (h, w) = image.shape[:2]

    if width is None and height is None:
        return image

    if width is None:
        # Scale to the requested height.
        r = height / float(h)
        dim = (int(w * r), height)
    else:
        # Scale to the requested width.
        r = width / float(w)
        dim = (width, int(h * r))

    return cv2.resize(image, dim, interpolation=inter)
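
# Example: image_resize(img, width=512) maps a 1024x768 (WxH) image to 512x384.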


# Checkpoints: Realistic Vision as the base Stable Diffusion model, plus the
# BrushNet checkpoint trained on segmentation masks.
base_model_path = "data/ckpt/realisticVisionV60B1_v51VAE"
brushnet_path = "data/ckpt/segmentation_mask_brushnet_ckpt"

# If True, paste the generated pixels back onto the original image through a
# blurred mask; if False, keep the raw pipeline output.
blended = False

occupations = ['backpacker', 'ballplayer', 'bartender', 'basketball_player', 'boatman', 'carpenter', 'cheerleader', 'climber', 'computer_user', 'craftsman', 'dancer', 'disk_jockey', 'doctor', 'drummer', 'electrician', 'farmer', 'fireman', 'flutist', 'gardener', 'guard', 'guitarist', 'gymnast', 'hairdresser', 'horseman', 'judge', 'laborer', 'lawman', 'lifeguard', 'machinist', 'motorcyclist', 'nurse', 'painter', 'patient', 'prayer', 'referee', 'repairman', 'reporter', 'retailer', 'runner', 'sculptor', 'seller', 'singer', 'skateboarder', 'soccer_player', 'soldier', 'speaker', 'student', 'teacher', 'tennis_player', 'trumpeter', 'waiter']

# FACET annotations, per-person image crops, and precomputed person masks.
facet = pd.read_csv("../../datasets/facet/annotations/annotations.csv", header=0).rename(columns={'Unnamed: 0': 'sample_idx'})
root = "../../datasets/facet/images_bb"
mask_root = "../Color-Invariant-Skin-Segmentation/FCN/output/person_masks"

output_dir = "/home/kis/datasets/facet_paper_whole_body_occupation_prompt_filelist/"
os.makedirs(output_dir, exist_ok=True)

# Inference settings.
steps = 10
brushnet_conditioning_scale = 1.0
seed = 12345

# Load BrushNet and the inpainting pipeline in fp16. The safety checker is
# disabled, and model CPU offload keeps GPU memory usage low.
brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch.float16)
pipe = StableDiffusionBrushNetPipeline.from_pretrained(
    base_model_path, brushnet=brushnet, torch_dtype=torch.float16,
    low_cpu_mem_usage=False, safety_checker=None, requires_safety_checker=False
)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()
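
# For each occupation, every qualifying person crop is inpainted twice (sharing
# one seeded generator): once with a female prompt and once with a male prompt,
# producing <person_id>_<original_gender>_to_<target_gender>.png files
# alongside the saved original image and mask.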

for category in occupations:
    subset = facet[facet['class1'] == category]
    n_imgs = subset.shape[0]

    for id_ in tqdm(range(n_imgs)):
        img = subset.iloc[id_]

        # Keep only annotations with a visible face and a clear gender presentation.
        if int(img['visible_face']) != 1:
            continue
        if int(img['gender_presentation_masc']) == 1:
            gender = 'male'
        elif int(img['gender_presentation_fem']) == 1:
            gender = 'female'
        else:
            continue

        if gender == 'male':
            # FACET stores the bounding box as a dict literal; parse it safely
            # instead of using eval().
            bb = ast.literal_eval(img["bounding_box"])
            input_box = np.array([int(bb['x']), int(bb['y']), int(bb['x']) + int(bb['width']), int(bb['y']) + int(bb['height'])])

            img_id = str(img['filename']).replace(".jpg", "")
            bb_id = str(img['person_id'])

            if int(bb_id) not in imagelist:
                print("File not in the inpainting list, skipping")
                continue

            if os.path.exists(f"{output_dir}/{bb_id}_original.png"):
                print(f"Skipping image {bb_id}: already processed")
                continue

            image_path = f"{root}/{bb_id}.jpg"
            mask_path = f"{mask_root}/{bb_id}.jpg"

            if not os.path.exists(mask_path):
                print(f"No mask found for image id {bb_id}")
                continue
            if not os.path.exists(image_path):
                print(f"No image found for image id {bb_id}")
                continue

            init_image = cv2.imread(image_path)
            (h, w) = init_image.shape[:2]
            if h < 224 or w < 224:
                print(f"Skipping image as it is too small: {h}x{w}")
                continue

            # Cap the working width at 512 pixels.
            init_image = image_resize(init_image, width=min(512, init_image.shape[1]))

            # OpenCV loads BGR; flip to RGB, then build a binary (H, W, 1)
            # person mask from the segmentation output.
            init_image = init_image[:, :, ::-1]
            mask_image = 1. * (cv2.resize(cv2.imread(mask_path), (init_image.shape[1], init_image.shape[0])).sum(-1) > 255)[:, :, np.newaxis]

            # Skip images whose person mask covers less than 30% of the frame.
            mask = Image.fromarray(mask_image.astype(np.uint8).repeat(3, -1) * 255).convert("RGB")
            size = np.asarray(mask).shape
            ret, output = cv2.threshold(np.asarray(mask), 127, 255, cv2.THRESH_BINARY)
            is_empty = np.sum(output > 127) / (size[0] * size[1]) < 0.30
            if is_empty:
                print(f"Skipping image {bb_id}: mask percentage: {np.sum(output > 127) / (size[0] * size[1])}")
                continue

            cv2.imwrite(f"{output_dir}/{bb_id}_original.png", cv2.imread(image_path))

            # Black out the masked region and convert image and mask to the PIL
            # inputs the pipeline expects.
            init_image = init_image * (1 - mask_image)
            init_image = Image.fromarray(init_image.astype(np.uint8)).convert("RGB")
            mask_image = Image.fromarray(mask_image.astype(np.uint8).repeat(3, -1) * 255).convert("RGB")
            cv2.imwrite(f"{output_dir}/{bb_id}_mask.png", np.asarray(mask_image))

            # Seeded generator for reproducible sampling.
            generator = torch.Generator("cuda").manual_seed(seed)

            # Male annotation, female prompt.
            caption = f"A photo of a woman who is a {category}"
            print(f"Image {bb_id}: {caption}")
            image = pipe(
                caption,
                init_image,
                mask_image,
                num_inference_steps=steps,
                generator=generator,
                brushnet_conditioning_scale=brushnet_conditioning_scale
            ).images[0]

            if blended:
                # Paste the generated pixels back through a blurred mask so the
                # edit blends into the original. The original and mask are
                # resized to the generated image's size, since the input may
                # have been downscaled above.
                image_np = np.array(image)
                init_image_np = cv2.resize(cv2.imread(image_path), (image_np.shape[1], image_np.shape[0]))[:, :, ::-1]
                mask_np = 1. * (cv2.resize(cv2.imread(mask_path), (image_np.shape[1], image_np.shape[0])).sum(-1) > 255)[:, :, np.newaxis]

                mask_blurred = cv2.GaussianBlur(mask_np * 255, (21, 21), 0) / 255
                mask_blurred = mask_blurred[:, :, np.newaxis]
                mask_np = 1 - (1 - mask_np) * (1 - mask_blurred)

                image_pasted = init_image_np * (1 - mask_np) + image_np * mask_np
                image_pasted = image_pasted.astype(image_np.dtype)
                image = Image.fromarray(image_pasted)

            image.save(f"{output_dir}/{bb_id}_male_to_female.png")

            # Male annotation, male prompt.
            caption = f"A photo of a man who is a {category}"
            print(f"Image {bb_id}: {caption}")
            image = pipe(
                caption,
                init_image,
                mask_image,
                num_inference_steps=steps,
                generator=generator,
                brushnet_conditioning_scale=brushnet_conditioning_scale
            ).images[0]

            if blended:
                # Same blended paste-back as above.
                image_np = np.array(image)
                init_image_np = cv2.resize(cv2.imread(image_path), (image_np.shape[1], image_np.shape[0]))[:, :, ::-1]
                mask_np = 1. * (cv2.resize(cv2.imread(mask_path), (image_np.shape[1], image_np.shape[0])).sum(-1) > 255)[:, :, np.newaxis]

                mask_blurred = cv2.GaussianBlur(mask_np * 255, (21, 21), 0) / 255
                mask_blurred = mask_blurred[:, :, np.newaxis]
                mask_np = 1 - (1 - mask_np) * (1 - mask_blurred)

                image_pasted = init_image_np * (1 - mask_np) + image_np * mask_np
                image_pasted = image_pasted.astype(image_np.dtype)
                image = Image.fromarray(image_pasted)

            image.save(f"{output_dir}/{bb_id}_male_to_male.png")

        elif gender == 'female':
            bb = ast.literal_eval(img["bounding_box"])
            input_box = np.array([int(bb['x']), int(bb['y']), int(bb['x']) + int(bb['width']), int(bb['y']) + int(bb['height'])])

            img_id = str(img['filename']).replace(".jpg", "")
            bb_id = str(img['person_id'])

            if int(bb_id) not in imagelist:
                print("File not in the inpainting list, skipping")
                continue
            if os.path.exists(f"{output_dir}/{bb_id}_original.png"):
                continue

            image_path = f"{root}/{bb_id}.jpg"
            mask_path = f"{mask_root}/{bb_id}.jpg"

            if not os.path.exists(mask_path):
                print(f"No mask found for image id {bb_id}")
                continue
            if not os.path.exists(image_path):
                print(f"No image found for image id {bb_id}")
                continue

            init_image = cv2.imread(image_path)
            (h, w) = init_image.shape[:2]
            if h < 224 or w < 224:
                print(f"Skipping image as it is too small: {h}x{w}")
                continue

            init_image = image_resize(init_image, width=min(512, init_image.shape[1]))

            init_image = init_image[:, :, ::-1]
            mask_image = 1. * (cv2.resize(cv2.imread(mask_path), (init_image.shape[1], init_image.shape[0])).sum(-1) > 255)[:, :, np.newaxis]

            # Note: this branch keeps masks covering at least 10% of the frame,
            # versus the 30% minimum applied to male annotations above.
            mask = Image.fromarray(mask_image.astype(np.uint8).repeat(3, -1) * 255).convert("RGB")
            size = np.asarray(mask).shape
            ret, output = cv2.threshold(np.asarray(mask), 127, 255, cv2.THRESH_BINARY)
            is_empty = np.sum(output > 127) / (size[0] * size[1]) < 0.10
            if is_empty:
                print(f"Skipping image {bb_id}: mask percentage: {np.sum(output > 127) / (size[0] * size[1])}")
                continue

            cv2.imwrite(f"{output_dir}/{bb_id}_original.png", cv2.imread(image_path))

            init_image = init_image * (1 - mask_image)
            init_image = Image.fromarray(init_image.astype(np.uint8)).convert("RGB")
            mask_image = Image.fromarray(mask_image.astype(np.uint8).repeat(3, -1) * 255).convert("RGB")
            cv2.imwrite(f"{output_dir}/{bb_id}_mask.png", np.asarray(mask_image))

            generator = torch.Generator("cuda").manual_seed(seed)

            # Female annotation, female prompt.
            caption = f"A photo of a woman who is a {category}"
            print(f"Image {bb_id}: {caption}")
            image = pipe(
                caption,
                init_image,
                mask_image,
                num_inference_steps=steps,
                generator=generator,
                brushnet_conditioning_scale=brushnet_conditioning_scale
            ).images[0]

            if blended:
                # Same blended paste-back as in the male branch.
                image_np = np.array(image)
                init_image_np = cv2.resize(cv2.imread(image_path), (image_np.shape[1], image_np.shape[0]))[:, :, ::-1]
                mask_np = 1. * (cv2.resize(cv2.imread(mask_path), (image_np.shape[1], image_np.shape[0])).sum(-1) > 255)[:, :, np.newaxis]

                mask_blurred = cv2.GaussianBlur(mask_np * 255, (21, 21), 0) / 255
                mask_blurred = mask_blurred[:, :, np.newaxis]
                mask_np = 1 - (1 - mask_np) * (1 - mask_blurred)

                image_pasted = init_image_np * (1 - mask_np) + image_np * mask_np
                image_pasted = image_pasted.astype(image_np.dtype)
                image = Image.fromarray(image_pasted)

            image.save(f"{output_dir}/{bb_id}_female_to_female.png")

            # Female annotation, male prompt.
            caption = f"A photo of a man who is a {category}"
            print(f"Image {bb_id}: {caption}")
            image = pipe(
                caption,
                init_image,
                mask_image,
                num_inference_steps=steps,
                generator=generator,
                brushnet_conditioning_scale=brushnet_conditioning_scale
            ).images[0]

            if blended:
                image_np = np.array(image)
                init_image_np = cv2.resize(cv2.imread(image_path), (image_np.shape[1], image_np.shape[0]))[:, :, ::-1]
                mask_np = 1. * (cv2.resize(cv2.imread(mask_path), (image_np.shape[1], image_np.shape[0])).sum(-1) > 255)[:, :, np.newaxis]

                mask_blurred = cv2.GaussianBlur(mask_np * 255, (21, 21), 0) / 255
                mask_blurred = mask_blurred[:, :, np.newaxis]
                mask_np = 1 - (1 - mask_np) * (1 - mask_blurred)

                image_pasted = init_image_np * (1 - mask_np) + image_np * mask_np
                image_pasted = image_pasted.astype(image_np.dtype)
                image = Image.fromarray(image_pasted)

            image.save(f"{output_dir}/{bb_id}_female_to_male.png")
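
# Each processed person id yields four files in output_dir: the saved original
# (<id>_original.png), the inpainting mask (<id>_mask.png), and two generated
# variants (<id>_<gender>_to_female.png and <id>_<gender>_to_male.png).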