# KeyframesAI / test2.py
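# Stage-2 inpainting inference for pose transfer: given a source person image and
# source/target pose maps, the pipeline paints the right half of a side-by-side
# canvas (source | target) so the subject appears in the target pose. Global
# appearance comes from a CLIP image embedding, fine-grained appearance from
# DINOv2 features, and pose guidance from a ControlNet-style conditioning
# embedding. With --calculate_metrics, the best-of-4 SSIM against the ground-truth
# target image is reported.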
import os
from PIL import Image
import numpy as np
from diffusers import UniPCMultistepScheduler
from src.models.stage2_inpaint_unet_2d_condition import Stage2_InapintUNet2DConditionModel
from src.pipelines.stage2_inpaint_pipeline import Stage2_InpaintDiffusionPipeline
import torch.nn.functional as F
from torchvision import transforms
from diffusers.models.controlnet import ControlNetConditioningEmbedding
from transformers import (
    CLIPVisionModelWithProjection,
    CLIPImageProcessor,
)
import argparse
from transformers import Dinov2Model
from typing import Any, Dict, List, Optional, Tuple, Union
from skimage.metrics import structural_similarity as compare_ssim
import torch
import torch.nn as nn
import torch.multiprocessing as mp
import json
import time
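
# Split lst into n roughly equal chunks; any remainder is folded into the last
# chunk. Only used by the multi-GPU driver that is commented out at the bottom
# of this file (assumes len(lst) >= n).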
def split_list_into_chunks(lst, n):
    chunk_size = len(lst) // n
    chunks = [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]
    if len(chunks) > n:
        last_chunk = chunks.pop()
        chunks[-1].extend(last_chunk)
    return chunks
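
# Paste a list of equally sized PIL images into a rows x cols grid (row-major).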
def image_grid(imgs, rows, cols):
    assert len(imgs) == rows * cols
    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    grid_w, grid_h = grid.size
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
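
# MLP that projects DINOv2 patch features (1536-d for dinov2-giant) into the
# 1024-d conditioning space consumed by the Stage-2 pipeline.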
class ImageProjModel_p(torch.nn.Module):
    """Projection MLP for the image prompt (DINOv2 features -> UNet conditioning)."""

    def __init__(self, in_dim, hidden_dim, out_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, out_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)
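
# Run a single source -> target pose-transfer example end to end: load the
# Stage-2 inpainting pipeline, build the (source | target) conditioning inputs,
# sample 4 candidates, and either save a visualization grid or the best-SSIM crop.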
def inference(args):
device = torch.device("cuda")
generator = torch.Generator(device=device).manual_seed(args.seed_number)
# save path
save_dir = "{}/show_guidancescale{}_seed{}_numsteps{}/".format(args.save_path, args.guidance_scale, args.seed_number, args.num_inference_steps)
save_dir_metric = "{}/guidancescale{}_seed{}_numsteps{}/".format(args.save_path, args.guidance_scale, args.seed_number, args.num_inference_steps)
if not os.path.exists(save_dir):
os.makedirs(save_dir, exist_ok=True)
if not os.path.exists(save_dir_metric):
os.makedirs(save_dir_metric, exist_ok=True)
clip_image_processor = CLIPImageProcessor()
img_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
])
    # model definitions
    image_proj_model_p_dict = {}
    pose_proj_dict = {}
    unet_dict = {}
    image_encoder_g = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_g_path).to(device).eval()
    image_encoder_p = Dinov2Model.from_pretrained(args.image_encoder_p_path).to(device).eval()
    image_proj_model_p = ImageProjModel_p(in_dim=1536, hidden_dim=768, out_dim=1024).to(device).eval()
    pose_proj = ControlNetConditioningEmbedding(320, 3, (16, 32, 96, 256)).to(device).eval()

    # split the combined checkpoint into per-module state dicts
    model_ckpt = args.weights_name
    model_sd = torch.load(model_ckpt, map_location="cpu")["module"]
    for k in model_sd.keys():
        if k.startswith("pose_proj"):
            pose_proj_dict[k.replace("pose_proj.", "")] = model_sd[k]
        elif k.startswith("image_proj_model_p"):
            image_proj_model_p_dict[k.replace("image_proj_model_p.", "")] = model_sd[k]
        elif k.startswith("unet"):
            unet_dict[k.replace("unet.", "")] = model_sd[k]
        else:
            print(k)
    pose_proj.load_state_dict(pose_proj_dict)
    image_proj_model_p.load_state_dict(image_proj_model_p_dict)

    pipe = Stage2_InpaintDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, torch_dtype=torch.float16).to(device)
    pipe.unet = Stage2_InapintUNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet",
                                                                   in_channels=9, class_embed_type="projection",
                                                                   projection_class_embeddings_input_dim=1024, torch_dtype=torch.float16,
                                                                   low_cpu_mem_usage=False, ignore_mismatched_sizes=True).to(device)
    pipe.unet.load_state_dict(unet_dict)
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.enable_xformers_memory_efficient_attention()
    # print('====================== json_data: {}, model load finish ==================='.format((args.json_path).split('/')[-1]))

    # single hard-coded example pair: source appearance image and target pose image
    data = {
        'source_image': 'sm.png',
        'target_image': 'pose2.png',
    }
    s_img_path = args.img_path + data["source_image"].replace('.jpg', '.png')
    s_pose_path = args.pose_path + data['source_image'].replace('.jpg', '_pose.jpg')
    t_img_path = args.img_path + data["target_image"].replace('.jpg', '.png')
    t_pose_path = args.pose_path + data["target_image"].replace(".jpg", "_pose.jpg")

    s_img = Image.open(s_img_path).convert("RGB").resize((args.img_width, args.img_height), Image.BICUBIC)
    t_img = Image.open(t_img_path).convert("RGB").resize((args.img_width, args.img_height), Image.BICUBIC)

    # inpainting canvas: source image on the left, black (to be generated) on the right
    black_image = Image.new("RGB", s_img.size, (0, 0, 0))
    s_img_t_mask = Image.new("RGB", (s_img.width * 2, s_img.height))
    s_img_t_mask.paste(s_img, (0, 0))
    s_img_t_mask.paste(black_image, (s_img.width, 0))

    # pose conditioning: source pose on the left, target pose on the right
    s_pose = Image.open(s_pose_path).convert("RGB").resize((args.img_width, args.img_height), Image.BICUBIC)
    t_pose = Image.open(t_pose_path).convert("RGB").resize((args.img_width, args.img_height), Image.BICUBIC)
    st_pose = Image.new("RGB", (s_pose.width * 2, s_pose.height))
    st_pose.paste(s_pose, (0, 0))
    st_pose.paste(t_pose, (s_pose.width, 0))

    # fine-grained appearance features from the source image (DINOv2 -> projection MLP)
    clip_processor_s_img = clip_image_processor(images=s_img, return_tensors="pt").pixel_values
    s_img_f = image_encoder_p(clip_processor_s_img.to(device)).last_hidden_state
    s_img_proj_f = image_proj_model_p(s_img_f)  # s_img

    vae_image = torch.unsqueeze(img_transform(s_img_t_mask), 0)
    cond_st_pose = torch.unsqueeze(img_transform(st_pose), 0)
    st_pose_f = pose_proj(cond_st_pose.to(device=device))  # t_pose

    # global target embedding: in 'train' mode it comes from the ground-truth target
    # image via the CLIP image encoder; in 'test' mode it is loaded from a
    # precomputed embedding (embed.npy), e.g. a Stage-1 prediction
    mode = 'train'  # args.json_path.split('/')[-1].split('_')[0]
    if mode == "train":
        clip_processor_s_img = clip_image_processor(images=t_img, return_tensors="pt").pixel_values
        pred_t_img_embed = (image_encoder_g(clip_processor_s_img.to(device)).image_embeds).unsqueeze(1)
    elif mode == "test":
        pred_t_img_embed = torch.tensor(np.load('embed.npy')).to(device)
        pred_t_img_embed = pred_t_img_embed.unsqueeze(1)
    else:
        raise ValueError("mode must be 'train' or 'test'")
    # generate 4 candidates; the canvas is twice the width (source | target)
    output = pipe(
        height=args.img_height,
        width=args.img_width * 2,
        guidance_rescale=0.0,
        vae_image=vae_image,
        s_img_proj_f=s_img_proj_f,
        st_pose_f=st_pose_f,
        pred_t_img_embed=pred_t_img_embed,
        num_images_per_prompt=4,
        guidance_scale=args.guidance_scale,
        generator=generator,
        num_inference_steps=args.num_inference_steps,
    )

    # reference panels used in the visualization grid
    vis_st_pose = Image.new("RGB", (args.img_width * 2, args.img_height))
    vis_st_pose.paste(Image.open(s_pose_path).convert("RGB").resize((args.img_width, args.img_height), Image.BICUBIC), (0, 0))
    vis_st_pose.paste(Image.open(t_pose_path).convert("RGB").resize((args.img_width, args.img_height), Image.BICUBIC), (args.img_width, 0))
    vis_st_image = Image.new("RGB", (args.img_width * 2, args.img_height))
    vis_st_image.paste(Image.open(s_img_path).convert("RGB").resize((args.img_width, args.img_height), Image.BICUBIC), (0, 0))
    vis_st_image.paste(Image.open(t_img_path).convert("RGB").resize((args.img_width, args.img_height), Image.BICUBIC), (args.img_width, 0))
    if args.calculate_metrics:
        # keep the candidate (out of 4) with the highest SSIM against the ground truth
        all_ssim = []
        ssim_values = []
        for gen_img in output.images:
            gen_img = gen_img.crop((args.img_width, 0, args.img_width * 2, args.img_height))
            # channel_axis replaces the deprecated multichannel flag in newer scikit-image
            ssim_values.append(compare_ssim(np.array(t_img) * 255.0, np.array(gen_img) * 255.0,
                                            gaussian_weights=True, sigma=1.2,
                                            use_sample_covariance=False, channel_axis=2,
                                            data_range=(np.array(gen_img) * 255.0).max() - (np.array(gen_img) * 255.0).min()
                                            ))
        max_value = max(ssim_values)
        all_ssim.append(max_value)
        max_index = ssim_values.index(max_value)
        grid_metric = output.images[max_index].crop((args.img_width, 0, args.img_width * 2, args.img_height))
        grid_metric.save(save_dir_metric + s_img_path.split("/")[-1].replace(".png", "") + "_to_" + t_img_path.split("/")[-1])
    else:
        # 2x3 grid: reference image pair, pose pair, then the 4 generated candidates
        output.images.insert(0, vis_st_pose)
        output.images.insert(0, vis_st_image)
        grid = image_grid(output.images, 2, 3)
        grid.save('coarse.png')

    if args.calculate_metrics:
        print(sum(all_ssim) / len(all_ssim))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Simple example of an inpaint model of stage2 script.")
parser.add_argument("--pretrained_model_name_or_path", type=str,
default="stabilityai/stable-diffusion-2-1-base",
help="Path to pretrained model or model identifier from huggingface.co/models.", )
parser.add_argument("--image_encoder_g_path",type=str,default="laion/CLIP-ViT-H-14-laion2B-s32B-b79K", # openai/clip-vit-base-patch32
help="Path to pretrained model or model identifier from huggingface.co/models.",)
parser.add_argument("--image_encoder_p_path",type=str,default="facebook/dinov2-giant",
help="Path to pretrained model or model identifier from huggingface.co/models.",)
parser.add_argument("--img_path", type=str,default="imgs/", help="image path", )
parser.add_argument("--pose_path", type=str,default="imgs/",help="pose path", )
parser.add_argument("--json_path", type=str,default="./datasets/deepfashing/test_data.json",help="json path", )
parser.add_argument("--target_embed_path", type=str,default="./logs/view_stage1/512_512/",help="t_img_embed path", )
parser.add_argument("--save_path", type=str, default="./save_data/stage2", help="save path", ) # ./logs/view_stage2/512_512
parser.add_argument("--guidance_scale",type=int,default=2.0,help="guidance_scale",)
parser.add_argument("--seed_number",type=int,default=42,help="seed number",)
parser.add_argument("--num_inference_steps",type=int,default=20,help="num_inference_steps",)
parser.add_argument("--img_width",type=int,default=512,help="image width",)
parser.add_argument("--img_height",type=int,default=512,help="image height",)
parser.add_argument("--calculate_metrics", action='store_true', help="caculate ssim", )
parser.add_argument("--weights_name", type=str, default="s2_512.pt",help="weights number", )
args = parser.parse_args()
print(args)
inference(args)
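
# Example invocation (values follow the argparse defaults; adjust to your setup):
#   python test2.py --weights_name s2_512.pt --img_path imgs/ --pose_path imgs/ --calculate_metrics
#
# The block below is the original multi-GPU driver, kept disabled for reference;
# note it calls inference(args, rank, data_chunk), which no longer matches the
# single-example signature inference(args) used above.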
"""
num_devices = torch.cuda.device_count()
print("using {} num_processes inference".format(num_devices))
test_data = json.load(open(args.json_path))
select_test_datas = test_data
print(len(select_test_datas))
mp.set_start_method("spawn")
data_list = split_list_into_chunks(select_test_datas, num_devices)
processes = []
for rank in range(num_devices):
    p = mp.Process(target=inference, args=(args, rank, data_list[rank]))
    processes.append(p)
    p.start()
for rank, p in enumerate(processes):
    p.join()
"""