# Scraped from a Hugging Face Space file view (Space status: Paused).
# File size: 5,190 bytes; revision 3366cca.
import os
import json
import cv2
import torch
from torch import nn
from PIL import Image
import numpy as np
from diffusers import UniPCMultistepScheduler
import torch.nn.functional as F
from torchvision import transforms
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from transformers import CLIPImageProcessor
from src.pipelines.stage3_refined_pipeline import Stage3_RefinedPipeline
import argparse
from transformers import Dinov2Model
from typing import Any, Dict, List, Optional, Tuple, Union
from skimage.metrics import structural_similarity as compare_ssim
import torch
import torch.nn as nn
import torch.multiprocessing as mp
import json
import time
def split_list_into_chunks(lst, n):
    """Split ``lst`` into at most ``n`` contiguous chunks.

    Every chunk except the last holds ``len(lst) // n`` items (at least one).
    The last chunk additionally absorbs all leftover items, so the result never
    contains more than ``n`` chunks.  When ``lst`` has fewer than ``n`` items,
    each item gets its own chunk; an empty ``lst`` yields ``[]``.

    Args:
        lst: Sequence to split.  It is only sliced, never mutated.
        n: Maximum number of chunks; must be a positive integer.

    Returns:
        A list of lists that, concatenated in order, reproduce ``lst``.

    Raises:
        ValueError: If ``n`` is not positive.
    """
    if n <= 0:
        raise ValueError(f"n must be a positive integer, got {n!r}")

    # Clamp to 1 so range() never receives a zero step when len(lst) < n.
    chunk_size = max(1, len(lst) // n)
    chunks = [list(lst[i:i + chunk_size]) for i in range(0, len(lst), chunk_size)]

    # Fold *every* overflow chunk into the n-th one; merging only a single
    # trailing chunk is not enough when the remainder exceeds chunk_size.
    for extra in chunks[n:]:
        chunks[n - 1].extend(extra)
    return chunks[:n]
def image_grid(imgs, rows, cols):
    """Paste ``imgs`` row by row into a single ``rows`` x ``cols`` RGB sheet.

    The cell size is taken from the first image; image ``k`` lands in row
    ``k // cols`` and column ``k % cols``.

    Args:
        imgs: Sequence of PIL images; its length must equal ``rows * cols``.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        A new RGB ``PIL.Image`` of size ``(cols * w, rows * h)``.
    """
    assert len(imgs) == rows * cols

    cell_w, cell_h = imgs[0].size
    sheet = Image.new("RGB", size=(cols * cell_w, rows * cell_h))
    for index, tile in enumerate(imgs):
        row, col = divmod(index, cols)
        sheet.paste(tile, box=(col * cell_w, row * cell_h))
    return sheet
def zero_module(module):
    """Fill every parameter of ``module`` with zeros in place.

    Args:
        module: Any ``nn.Module``; all tensors from ``module.parameters()``
            are overwritten.

    Returns:
        The same ``module`` object, to allow inline use at construction time.
    """
    fill_with_zeros = nn.init.zeros_
    for weight in module.parameters():
        fill_with_zeros(weight)
    return module
class ImageProjModel_p(torch.nn.Module):
    """Two-layer MLP that projects image features from ``in_dim`` to ``out_dim``.

    Layer order: Linear -> GELU -> Dropout -> LayerNorm -> Linear -> Dropout.
    The stack is stored as ``self.net`` (an ``nn.Sequential``), so checkpoint
    keys take the form ``net.<index>.<param>``.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, dropout=0.0):
        """Build the projection stack.

        Args:
            in_dim: Size of the last dimension of the input features.
            hidden_dim: Width of the hidden layer (also the LayerNorm size).
            out_dim: Size of the last dimension of the output.
            dropout: Dropout probability used after the activation and after
                the final linear layer.
        """
        super().__init__()
        # Keep this exact order: Sequential indices are part of the state_dict keys.
        layers = [
            nn.Linear(in_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, out_dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Project ``x`` (last dimension ``in_dim``) to last dimension ``out_dim``.

        The original annotation mentioned a (b, 257, 1280) input; the actual
        feature width is whatever ``in_dim`` was set to.
        """
        projected = self.net(x)
        return projected
def inference():
    """Refine one coarse image with the stage-3 pipeline and save the results.

    All paths and hyper-parameters are hard-coded:

    1. Load the DINOv2-giant image encoder and the ``ImageProjModel_p``
       projection head, then restore the projection and UNet weights from the
       ``"module"`` entry of ``s3_512.pt``.
    2. Build ``Stage3_RefinedPipeline`` on top of
       ``stabilityai/stable-diffusion-2-1-base`` with an 8-input-channel UNet
       and a UniPC scheduler.
    3. Encode ``imgs/sm.png`` (source image) into projected features and pass
       them, together with ``imgs/coarse.png`` (coarse result), to the pipeline.
    4. Save every returned image as ``out<i>.png`` and a 1x3 comparison grid
       (source | coarse | refined) as ``out.png``.

    Requires a CUDA device and xformers (memory-efficient attention is enabled
    unconditionally).  Returns ``None``; the outputs are the written files.
    """
    device = "cuda"
    generator = torch.Generator(device=device).manual_seed(42)
    clip_image_processor = CLIPImageProcessor()
    img_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

    # Model definition.
    image_encoder_p = Dinov2Model.from_pretrained('facebook/dinov2-giant').to(device).eval()
    image_proj_model_p = ImageProjModel_p(in_dim=1536, hidden_dim=768, out_dim=1024).to(device).eval()

    model_ckpt = "s3_512.pt"
    proj_prefix = "image_proj_model_p."
    unet_prefix = "unet."
    image_proj_model_p_dict = {}
    unet_dict = {}

    # Pure inference: keep all model work out of autograd.
    with torch.no_grad():
        # Load onto the CPU first so the checkpoint does not depend on the GPU
        # index it was saved from; load_state_dict copies onto each module's device.
        model_sd = torch.load(model_ckpt, map_location="cpu")["module"]
        for key, tensor in model_sd.items():
            # Strip only the *leading* prefix; str.replace would also remove
            # any later occurrence of the same text inside the key.
            if key.startswith(proj_prefix):
                image_proj_model_p_dict[key[len(proj_prefix):]] = tensor
            elif key.startswith(unet_prefix):
                unet_dict[key[len(unet_prefix):]] = tensor
            else:
                # Surface checkpoint entries that belong to neither sub-model.
                print(key)
        image_proj_model_p.load_state_dict(image_proj_model_p_dict)

        pipe = Stage3_RefinedPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-1-base", torch_dtype=torch.float16
        ).to(device)
        # Swap in a UNet with 8 input channels; the first conv no longer
        # matches the pretrained shape, hence ignore_mismatched_sizes, and the
        # real weights come from the checkpoint right after.
        pipe.unet = UNet2DConditionModel.from_pretrained(
            "stabilityai/stable-diffusion-2-1-base",
            subfolder="unet",
            in_channels=8,
            low_cpu_mem_usage=False,
            ignore_mismatched_sizes=True,
        ).to(device)
        pipe.unet.load_state_dict(unet_dict)
        pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
        pipe.enable_xformers_memory_efficient_attention()

        s_img_path = 'imgs/sm.png'
        gen_t_img_path = 'imgs/coarse.png'
        s_img = Image.open(s_img_path).convert("RGB").resize((512, 512), Image.BICUBIC)
        gen_t_img = Image.open(gen_t_img_path).convert("RGB").resize((512, 512), Image.BICUBIC)

        # Source image -> DINOv2 token features -> projected conditioning.
        clip_processor_s_img = clip_image_processor(images=s_img, return_tensors="pt").pixel_values
        s_img_f = image_encoder_p(clip_processor_s_img.to(device)).last_hidden_state
        s_img_proj_f = image_proj_model_p(s_img_f)

        # Coarse image as a [-1, 1] tensor with a leading batch dimension.
        vae_gen_t_image = torch.unsqueeze(img_transform(gen_t_img), 0)

        output = pipe(
            height=512,
            width=512,
            guidance_rescale=2.0,
            vae_gen_t_image=vae_gen_t_image,
            s_img_proj_f=s_img_proj_f,
            num_images_per_prompt=4,
            guidance_scale=1.0,
            generator=generator,
            num_inference_steps=20,
        )

    for i, refined in enumerate(output.images):
        refined.save(f"out{i}.png")

    # NOTE(review): the crop box starts at x=512 although the pipeline was
    # asked for width=512.  If the returned image really is 512 px wide, this
    # region lies outside it -- confirm the pipeline's output layout.
    result = output.images[0].crop((512, 0, 512 * 2, 512))
    save_output = [
        s_img.resize((352, 512), Image.BICUBIC),
        gen_t_img.resize((352, 512), Image.BICUBIC),
        result.resize((352, 512), Image.BICUBIC),
    ]
    grid = image_grid(save_output, 1, 3)
    grid.save("out.png")
# Script entry point: run the single-image refinement demo only when this file
# is executed directly, not when it is imported.
if __name__ == "__main__":
    inference()
|