import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from transformers import CLIPImageProcessor, Dinov2Model
from diffusers import UniPCMultistepScheduler, UNet2DConditionModel
from src.pipelines.stage3_refined_pipeline import Stage3_RefinedPipeline
def split_list_into_chunks(lst, n):
    """Split lst into at most n contiguous chunks; any remainder is merged into the last chunk."""
    chunk_size = max(1, len(lst) // n)  # guard against a zero step when len(lst) < n
    chunks = [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]
    if len(chunks) > n:
        last_chunk = chunks.pop()
        chunks[-1].extend(last_chunk)
    return chunks
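# For example, splitting ten items across four workers merges the remainder
# into the final chunk:
#   split_list_into_chunks(list(range(10)), 4)
#   -> [[0, 1], [2, 3], [4, 5], [6, 7, 8, 9]]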
def image_grid(imgs, rows, cols):
    """Paste a list of equally sized PIL images into a rows x cols grid."""
    assert len(imgs) == rows * cols
    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
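# e.g. image_grid([a, b, c], rows=1, cols=3) tiles three images side by side,
# which is how the final comparison strip is assembled at the end of inference().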
def zero_module(module):
    """Zero out all parameters of a module and return it."""
    for p in module.parameters():
        nn.init.zeros_(p)
    return module
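# Zero-initialization like this is commonly used (e.g. in ControlNet) so that a
# newly added branch starts as a no-op at the beginning of training; the helper
# appears unused in this inference script.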
class ImageProjModel_p(torch.nn.Module):
    """MLP that projects DINOv2 image tokens into the SD cross-attention embedding space."""
    def __init__(self, in_dim, hidden_dim, out_dim, dropout=0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, out_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):  # x: (b, 257, 1536) DINOv2-giant tokens -> (b, 257, 1024)
        return self.net(x)
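# A minimal shape check, assuming DINOv2-giant's 257 patch+CLS tokens of width
# 1536 and SD-2.1's 1024-dim conditioning space (matching the instantiation below):
#   proj = ImageProjModel_p(in_dim=1536, hidden_dim=768, out_dim=1024)
#   assert proj(torch.randn(1, 257, 1536)).shape == (1, 257, 1024)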
def inference():
    device = "cuda"
    generator = torch.Generator(device=device).manual_seed(42)
    clip_image_processor = CLIPImageProcessor()
    img_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),  # map RGB from [0, 1] to [-1, 1] for the VAE
    ])
    # Model definitions: a frozen DINOv2-giant encoder plus the projection MLP.
    image_proj_model_p_dict = {}
    unet_dict = {}
    image_encoder_p = Dinov2Model.from_pretrained('facebook/dinov2-giant').to(device).eval()
    image_proj_model_p = ImageProjModel_p(in_dim=1536, hidden_dim=768, out_dim=1024).to(device).eval()
    #model_ckpt = "{}/mp_rank_00_model_states.pt".format('{save_ckpt}')
    model_ckpt = "s3_512.pt"
    with torch.no_grad():
        # Split the combined checkpoint into projection-model and UNet state dicts.
        model_sd = torch.load(model_ckpt, map_location="cpu")["module"]
        for k in model_sd.keys():
            if k.startswith("image_proj_model_p"):
                image_proj_model_p_dict[k.replace("image_proj_model_p.", "")] = model_sd[k]
            elif k.startswith("unet"):
                unet_dict[k.replace("unet.", "")] = model_sd[k]
            else:
                print(k)  # report any unexpected keys
    image_proj_model_p.load_state_dict(image_proj_model_p_dict)
    pipe = Stage3_RefinedPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base", torch_dtype=torch.float16).to(device)
    # Re-instantiate the UNet with 8 input channels (presumably the 4 noise-latent
    # channels concatenated with the 4 coarse-image latent channels), then load
    # the fine-tuned weights.
    pipe.unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="unet",
                                                     in_channels=8, low_cpu_mem_usage=False, ignore_mismatched_sizes=True).to(device)
    pipe.unet.load_state_dict(unet_dict)
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.enable_xformers_memory_efficient_attention()
    s_img_path = 'imgs/sm.png'
    #t_img_path = 'imgs/expected.png'
    gen_t_img_path = 'imgs/coarse.png'
    s_img = Image.open(s_img_path).convert("RGB").resize((512, 512), Image.BICUBIC)
    #t_img = Image.open(t_img_path).convert("RGB").resize((512,512), Image.BICUBIC)
    gen_t_img = Image.open(gen_t_img_path).convert("RGB").resize((512, 512), Image.BICUBIC)
    # Encode the source image with DINOv2 and project it into the conditioning space.
    clip_processor_s_img = clip_image_processor(images=s_img, return_tensors="pt").pixel_values
    s_img_f = image_encoder_p(clip_processor_s_img.to(device)).last_hidden_state
    s_img_proj_f = image_proj_model_p(s_img_f)
    # Normalize the coarse result to [-1, 1] and add a batch dimension for the VAE.
    vae_gen_t_image = torch.unsqueeze(img_transform(gen_t_img), 0)
    output = pipe(
        height=512,
        width=512,
        guidance_rescale=2.0,
        vae_gen_t_image=vae_gen_t_image,
        s_img_proj_f=s_img_proj_f,
        num_images_per_prompt=4,
        guidance_scale=1.0,  # note: values <= 1 disable classifier-free guidance in standard diffusers pipelines
        generator=generator,
        num_inference_steps=20,
    )
    for i, r in enumerate(output.images):
        r.save(f'out{i}.png')
    # Build a source / coarse / refined comparison strip; the refined result is
    # cropped from the second 512-px column of the first output image.
    save_output = []
    result = output.images[0].crop((512, 0, 512 * 2, 512))
    save_output.append(result.resize((352, 512), Image.BICUBIC))
    save_output.insert(0, gen_t_img.resize((352, 512), Image.BICUBIC))
    save_output.insert(0, s_img.resize((352, 512), Image.BICUBIC))
    grid = image_grid(save_output, 1, 3)
    grid.save("out.png")
if __name__ == "__main__":
    inference()
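# Expected inputs, per the paths above: imgs/sm.png (source image) and
# imgs/coarse.png (coarse stage output), with the s3_512.pt checkpoint in the
# working directory; outputs are out0.png..out3.png plus the out.png strip.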