from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
import ast
import os

import cv2
import numpy as np
import pandas as pd
import torch
from PIL import Image
from tqdm import tqdm

# person ids that should be in-painted for the paper
imagelist = pd.read_csv("examples/brushnet/paper_imagelist_for_inpainting.csv").values


def image_resize(image, width=None, height=None, inter=cv2.INTER_AREA):
    """Resize `image` to the given width or height, preserving the aspect ratio."""
    # grab the image size
    (h, w) = image.shape[:2]

    # if both the width and height are None, return the original image
    if width is None and height is None:
        return image

    if width is None:
        # calculate the ratio of the height and construct the dimensions
        r = height / float(h)
        dim = (int(w * r), height)
    else:
        # calculate the ratio of the width and construct the dimensions
        r = width / float(w)
        dim = (width, int(h * r))

    # resize the image and return it
    return cv2.resize(image, dim, interpolation=inter)


def paste_back(image, image_path, mask_path):
    """Blend the generated image back into the original outside the mask.

    You can adjust the blur parameters for better performance.
    """
    image_np = np.array(image)
    (h, w) = image_np.shape[:2]
    # resize the original image and its mask to the generated resolution;
    # they can differ when the input was downscaled before in-painting
    init_image_np = cv2.resize(cv2.imread(image_path), (w, h))[:, :, ::-1]
    mask_np = 1. * (cv2.resize(cv2.imread(mask_path), (w, h)).sum(-1) > 255)[:, :, np.newaxis]
    mask_blurred = cv2.GaussianBlur(mask_np * 255, (21, 21), 0) / 255
    mask_blurred = mask_blurred[:, :, np.newaxis]
    mask_np = 1 - (1 - mask_np) * (1 - mask_blurred)
    image_pasted = init_image_np * (1 - mask_np) + image_np * mask_np
    return Image.fromarray(image_pasted.astype(image_np.dtype))


# choose the base model here
base_model_path = "data/ckpt/realisticVisionV60B1_v51VAE"
# base_model_path = "runwayml/stable-diffusion-v1-5"

# input brushnet ckpt path
brushnet_path = "data/ckpt/segmentation_mask_brushnet_ckpt"

# choose whether to blend the generated region back into the original image
blended = False

occupations = ['backpacker', 'ballplayer', 'bartender', 'basketball_player', 'boatman',
               'carpenter', 'cheerleader', 'climber', 'computer_user', 'craftsman',
               'dancer', 'disk_jockey', 'doctor', 'drummer', 'electrician', 'farmer',
               'fireman', 'flutist', 'gardener', 'guard', 'guitarist', 'gymnast',
               'hairdresser', 'horseman', 'judge', 'laborer', 'lawman', 'lifeguard',
               'machinist', 'motorcyclist', 'nurse', 'painter', 'patient', 'prayer',
               'referee', 'repairman', 'reporter', 'retailer', 'runner', 'sculptor',
               'seller', 'singer', 'skateboarder', 'soccer_player', 'soldier',
               'speaker', 'student', 'teacher', 'tennis_player', 'trumpeter', 'waiter']

facet = pd.read_csv("../../datasets/facet/annotations/annotations.csv",
                    header=0).rename(columns={'Unnamed: 0': 'sample_idx'})

# bounding-box crops of the FACET images
root = "../../datasets/facet/images_bb"

# mask_root = "../Color-Invariant-Skin-Segmentation/FCN/paper_output/final_skin_mask"
mask_root = "../Color-Invariant-Skin-Segmentation/FCN/output/person_masks"
# mask_root = "../Color-Invariant-Skin-Segmentation/FCN/paper_output/clothes_mask"

# output_dir = "facet_paper_skin_ours"
# output_dir = "facet_paper_clothes_only"
output_dir = "/home/kis/datasets/facet_paper_whole_body_occupation_prompt_filelist/"
os.makedirs(output_dir, exist_ok=True)

steps = 10
# conditioning scale
brushnet_conditioning_scale = 1.0
seed = 12345

brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch.float16)
pipe = StableDiffusionBrushNetPipeline.from_pretrained(
    base_model_path,
    brushnet=brushnet,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=False,
    safety_checker=None,
    requires_safety_checker=False,
)

# speed up the diffusion process with a faster scheduler
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
# uncomment the following line if xformers is installed and you are not on Torch 2.0
# pipe.enable_xformers_memory_efficient_attention()
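# memory optimization: enable_model_cpu_offload() keeps only the sub-model that
# is currently running on the GPU (if GPU memory allows, `pipe.to("cuda")` is a
# faster alternative to the offload call below)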
pipe.enable_model_cpu_offload()

for category in occupations:
    rows = facet[facet['class1'] == category]
    n_imgs = rows['person_id'].shape[0]
    for id_ in tqdm(range(n_imgs)):
        img = rows.iloc[id_]

        # keep only crops with a visible face and an annotated gender presentation
        if int(img['visible_face']) != 1:
            continue
        if int(img['gender_presentation_masc']) == 1:
            gender = 'male'
        elif int(img['gender_presentation_fem']) == 1:
            gender = 'female'
        else:
            continue

        bb = ast.literal_eval(img["bounding_box"])  # safer than eval() on CSV fields
        input_box = np.array([int(bb['x']), int(bb['y']),
                              int(bb['x']) + int(bb['width']),
                              int(bb['y']) + int(bb['height'])])
        img_id = str(img['filename']).replace(".jpg", "")
        bb_id = str(img['person_id'])

        if int(bb_id) not in imagelist:
            print("file not in the in-painting list, skipping")
            continue
        if os.path.exists(f"{output_dir}/{bb_id}_original.png"):
            print(f"Skipping image {bb_id}: already processed")
            continue

        image_path = f"{root}/{bb_id}.jpg"
        mask_path = f"{mask_root}/{bb_id}.jpg"
        if not os.path.exists(mask_path):
            print(f"No mask found for image id {bb_id}")
            continue
        if not os.path.exists(image_path):
            print(f"No image found for image id {bb_id}")
            continue

        init_image = cv2.imread(image_path)
        (h, w) = init_image.shape[:2]
        if h < 224 or w < 224:
            print(f"Skipping image {bb_id} as it is too small: {h}x{w}")
            continue

        # downscale to at most 512 px wide and convert BGR -> RGB
        init_image = image_resize(init_image, width=min(512, init_image.shape[1]))
        init_image = init_image[:, :, ::-1]

        # binarize the person mask and resize it to the image resolution
        mask_image = 1. * (cv2.resize(cv2.imread(mask_path),
                                      (init_image.shape[1], init_image.shape[0])).sum(-1) > 255)[:, :, np.newaxis]

        mask = Image.fromarray(mask_image.astype(np.uint8).repeat(3, -1) * 255).convert("RGB")
        size = np.asarray(mask).shape
        _, output = cv2.threshold(np.asarray(mask), 127, 255, cv2.THRESH_BINARY)
        mask_ratio = np.sum(output > 127) / (size[0] * size[1])

        # require a minimum mask coverage (stricter for masculine-presenting crops)
        min_mask_ratio = 0.30 if gender == 'male' else 0.10
        if mask_ratio < min_mask_ratio:
            print(f"Skipping image {bb_id}: mask percentage: {mask_ratio}")
            continue

        cv2.imwrite(f"{output_dir}/{bb_id}_original.png", cv2.imread(image_path))

        # black out the masked region and convert both inputs to PIL
        init_image = init_image * (1 - mask_image)
        init_image = Image.fromarray(init_image.astype(np.uint8)).convert("RGB")
        mask_image = Image.fromarray(mask_image.astype(np.uint8).repeat(3, -1) * 255).convert("RGB")
        cv2.imwrite(f"{output_dir}/{bb_id}_mask.png", np.asarray(mask_image))

        generator = torch.Generator("cuda").manual_seed(seed)
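        # generate both counterfactuals for this crop: the generator is seeded
        # once per image, so the second pipe call continues the same random
        # stream rather than reusing the exact initial noise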
        # in-paint the person twice: once with a feminine and once with a masculine prompt
        for noun, target in (("woman", "female"), ("man", "male")):
            caption = f"A photo of a {noun} who is a {category}"
            print(f"Image {bb_id}: {caption}")
            image = pipe(
                caption,
                init_image,
                mask_image,
                num_inference_steps=steps,
                generator=generator,
                brushnet_conditioning_scale=brushnet_conditioning_scale
            ).images[0]
            if blended:
                # paste the generated pixels back over the original image
                image = paste_back(image, image_path, mask_path)
            image.save(f"{output_dir}/{bb_id}_{gender}_to_{target}.png")
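# Optional post-run summary (a sketch relying only on the file-naming scheme
# used above; drop it if a plain run is all you need): tally how many edited
# images were written per edit direction.
from collections import Counter

directions = ("male_to_female", "male_to_male", "female_to_female", "female_to_male")
counts = Counter(d for f in os.listdir(output_dir) for d in directions if f.endswith(f"_{d}.png"))
print({d: counts.get(d, 0) for d in directions})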