# KeyframesAI / test3.py
# acmyu's picture
# initial commit
# 3366cca verified
import os
import json
import cv2
import torch
from torch import nn
from PIL import Image
import numpy as np
from diffusers import UniPCMultistepScheduler
import torch.nn.functional as F
from torchvision import transforms
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from transformers import CLIPImageProcessor
from src.pipelines.stage3_refined_pipeline import Stage3_RefinedPipeline
import argparse
from transformers import Dinov2Model
from typing import Any, Dict, List, Optional, Tuple, Union
from skimage.metrics import structural_similarity as compare_ssim
import torch
import torch.nn as nn
import torch.multiprocessing as mp
import json
import time
def split_list_into_chunks(lst, n):
    """Split ``lst`` into at most ``n`` contiguous chunks.

    Every chunk except the last holds ``len(lst) // n`` items (at least one).
    The last chunk additionally absorbs whatever is left over, so no element
    is dropped and the original order is preserved.

    Args:
        lst: Sliceable sequence to partition. It is not modified.
        n: Maximum number of chunks to produce; must be positive.

    Returns:
        A list of slices of ``lst``. Fewer than ``n`` chunks are returned when
        ``lst`` has fewer than ``n`` items; an empty ``lst`` yields ``[]``.

    Raises:
        ValueError: If ``n`` is not positive.
    """
    if n <= 0:
        raise ValueError(f"n must be a positive integer, got {n}")

    # Clamp to 1 so a short input (len(lst) < n) never produces a zero step.
    chunk_size = max(1, len(lst) // n)

    # Cut at most n - 1 fixed-size pieces; everything from ``last_start`` on
    # goes into the final chunk, so the result can never exceed n chunks no
    # matter how large the remainder is.
    last_start = chunk_size * (n - 1)
    head_end = min(last_start, len(lst))
    chunks = [lst[i:i + chunk_size] for i in range(0, head_end, chunk_size)]

    tail = lst[last_start:]
    if tail:
        chunks.append(tail)
    return chunks
def image_grid(imgs, rows, cols):
    """Tile images into a single ``rows`` x ``cols`` RGB image.

    Images are placed row by row, left to right. The cell size is taken from
    the first image in ``imgs``.

    Args:
        imgs: Sequence of PIL images; exactly ``rows * cols`` of them.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        A new RGB image of size ``(cols * w, rows * h)`` where ``(w, h)`` is
        the size of ``imgs[0]``.

    Raises:
        ValueError: If ``imgs`` is empty or its length is not ``rows * cols``.
    """
    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if len(imgs) == 0 or len(imgs) != rows * cols:
        raise ValueError(
            f"expected rows * cols = {rows * cols} images (at least one), got {len(imgs)}"
        )

    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        row, col = divmod(i, cols)
        grid.paste(img, box=(col * w, row * h))
    return grid
def zero_module(module):
    """Fill every parameter of ``module`` with zeros in place and return it.

    Args:
        module: Object exposing ``parameters()`` (e.g. an ``nn.Module``).

    Returns:
        The same ``module`` object that was passed in.
    """
    zero_fill = nn.init.zeros_
    for tensor in module.parameters():
        zero_fill(tensor)
    return module
class ImageProjModel_p(torch.nn.Module):
    """Small MLP that maps features of width ``in_dim`` to width ``out_dim``.

    Layer order: Linear -> GELU -> Dropout -> LayerNorm -> Linear -> Dropout.
    The layers are held in ``self.net`` (an ``nn.Sequential``), so state-dict
    entries are named ``net.<index>.<param>``.

    Args:
        in_dim: Size of the last dimension of the input.
        hidden_dim: Width of the intermediate representation.
        out_dim: Size of the last dimension of the output.
        dropout: Dropout probability used after the activation and after the
            final linear layer.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, dropout=0.):
        super().__init__()
        # Keep this exact order: the positions become the state-dict indices.
        layers = [
            nn.Linear(in_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, out_dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Project ``x`` through the MLP.

        The original author's note gives the input as (b, 257, 1280); the
        last dimension has to equal ``in_dim`` for the first linear layer.
        """
        projected = self.net(x)
        return projected
def inference(
    model_ckpt="s3_512.pt",
    s_img_path="imgs/sm.png",
    gen_t_img_path="imgs/coarse.png",
    device="cuda",
    seed=42,
):
    """Run ``Stage3_RefinedPipeline`` on one source/coarse image pair and save the results.

    Steps:
      1. Build the DINOv2 image encoder, the projection MLP and an 8-channel
         UNet, loading the latter two from ``model_ckpt``.
      2. Encode the source image and project its features.
      3. Call the pipeline with the coarse image and the projected features.
      4. Write every returned image to ``out<i>.png`` and a 1x3 comparison
         grid (source, coarse, crop of the first result) to ``out.png``.

    Args:
        model_ckpt: Path to a checkpoint whose ``"module"`` entry maps
            parameter names prefixed with ``image_proj_model_p.`` / ``unet.``
            to tensors.
        s_img_path: Path of the source image.
        gen_t_img_path: Path of the coarse generated image.
        device: Torch device string used for all models and the generator.
        seed: Seed for the sampling generator.

    Returns:
        None. Results are written to the current working directory.
    """
    generator = torch.Generator(device=device).manual_seed(seed)
    clip_image_processor = CLIPImageProcessor()
    img_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

    # model define
    proj_prefix = "image_proj_model_p."
    unet_prefix = "unet."
    image_proj_model_p_dict = {}
    unet_dict = {}
    image_encoder_p = Dinov2Model.from_pretrained('facebook/dinov2-giant').to(device).eval()
    image_proj_model_p = ImageProjModel_p(in_dim=1536, hidden_dim=768, out_dim=1024).to(device).eval()

    # Inference only: nothing below needs an autograd graph.
    with torch.no_grad():
        # SECURITY: torch.load unpickles the file; only load checkpoints from a
        # trusted source. map_location="cpu" makes loading independent of the
        # device the checkpoint was saved from; load_state_dict then copies the
        # tensors into the already-placed modules.
        model_sd = torch.load(model_ckpt, map_location="cpu")["module"]
        for key, tensor in model_sd.items():
            # Strip only the leading prefix (str.replace would also rewrite
            # any later occurrence inside the key).
            if key.startswith(proj_prefix):
                image_proj_model_p_dict[key[len(proj_prefix):]] = tensor
            elif key.startswith(unet_prefix):
                unet_dict[key[len(unet_prefix):]] = tensor
            else:
                # Surface checkpoint entries that belong to neither sub-model.
                print(key)
        image_proj_model_p.load_state_dict(image_proj_model_p_dict)

        pipe = Stage3_RefinedPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-1-base", torch_dtype=torch.float16
        ).to(device)
        # NOTE(review): the pipeline is loaded with torch_dtype=float16 but no
        # torch_dtype is passed for this replacement UNet — confirm the mixed
        # precision is intended.
        pipe.unet = UNet2DConditionModel.from_pretrained(
            "stabilityai/stable-diffusion-2-1-base",
            subfolder="unet",
            in_channels=8,
            low_cpu_mem_usage=False,
            ignore_mismatched_sizes=True,
        ).to(device)
        pipe.unet.load_state_dict(unet_dict)
        pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
        pipe.enable_xformers_memory_efficient_attention()

        s_img = Image.open(s_img_path).convert("RGB").resize((512, 512), Image.BICUBIC)
        gen_t_img = Image.open(gen_t_img_path).convert("RGB").resize((512, 512), Image.BICUBIC)

        # Source image -> encoder hidden states -> projected conditioning.
        clip_processor_s_img = clip_image_processor(images=s_img, return_tensors="pt").pixel_values
        s_img_f = image_encoder_p(clip_processor_s_img.to(device)).last_hidden_state
        s_img_proj_f = image_proj_model_p(s_img_f)  # s_img

        # Coarse image as a normalized tensor with a leading batch dimension.
        vae_gen_t_image = torch.unsqueeze(img_transform(gen_t_img), 0)

        output = pipe(
            height=512,
            width=512,
            guidance_rescale=2.0,
            vae_gen_t_image=vae_gen_t_image,
            s_img_proj_f=s_img_proj_f,
            num_images_per_prompt=4,
            guidance_scale=1.0,
            generator=generator,
            num_inference_steps=20,
        )

    for i, r in enumerate(output.images):
        r.save(f"out{i}.png")

    # NOTE(review): the crop takes x in [512, 1024) although width=512 is
    # requested above — confirm the pipeline really returns a 1024-wide image.
    result = output.images[0].crop((512, 0, 512 * 2, 512))
    save_output = [
        s_img.resize((352, 512), Image.BICUBIC),
        gen_t_img.resize((352, 512), Image.BICUBIC),
        result.resize((352, 512), Image.BICUBIC),
    ]
    grid = image_grid(save_output, 1, 3)
    grid.save("out.png")
# Run the single-pair inference when this file is executed as a script.
if __name__ == "__main__":
    inference()