# NOTE(review): lines below replace non-code residue from the Hugging Face Spaces
# page this file was copied from ("Spaces: Running on Zero", file size 4,915 bytes,
# commits 456aee9 / 5cd47b3 / 04c943b, and a line-number gutter). Kept as a comment
# so the module parses.
import numpy as np
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler, AutoencoderKL
from PIL import Image
torch.backends.cuda.enable_cudnn_sdp(False) # a fix for torch 2.5.0
from ip_adapter import IPAdapterPlus
from ip_adapter import IPAdapter
# %%
def image_grid(imgs, rows, cols):
    """Paste ``rows * cols`` equally sized PIL images into one grid image.

    Args:
        imgs: list of PIL images; all are assumed to share ``imgs[0].size``.
        rows: number of grid rows.
        cols: number of grid columns.

    Returns:
        A new RGB ``PIL.Image`` of size ``(cols * w, rows * h)``.
    """
    assert len(imgs) == rows * cols
    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        # Fill left-to-right, top-to-bottom.
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
@torch.inference_mode()
def extract_clip_embedding_pil(pil_image, ip_model):
    """Encode a PIL image into penultimate-layer CLIP hidden states.

    Preprocesses with ``ip_model.clip_image_processor``, runs the image
    encoder in float16 on ``ip_model.device``, and returns the
    second-to-last hidden state cast back to float32.
    """
    pixels = ip_model.clip_image_processor(
        images=pil_image, return_tensors="pt"
    ).pixel_values
    pixels = pixels.to(ip_model.device, dtype=torch.float16)
    hidden = ip_model.image_encoder(pixels, output_hidden_states=True).hidden_states
    return hidden[-2].float()
def extract_clip_embedding_pil_batch(pil_images, ip_model):
    """Encode each PIL image and stack the CLIP embeddings along dim 0."""
    per_image = [extract_clip_embedding_pil(img, ip_model) for img in pil_images]
    return torch.cat(per_image, dim=0)
@torch.inference_mode()
def extract_clip_embedding_tensor(tensor_image, ip_model):
tensor_image = tensor_image.to(ip_model.device, dtype=torch.float16)
tensor_image = torch.nn.functional.interpolate(tensor_image, size=(224, 224), mode="bilinear", align_corners=False)
clip_image_embeds = ip_model.image_encoder(tensor_image, output_hidden_states=True).hidden_states[-2]
clip_image_embeds = clip_image_embeds.float()
return clip_image_embeds
@torch.inference_mode()
def _myheck_ipadapter_get_image_embeds(self, pil_image=None, clip_image_embeds=None):
    """Monkey-patched replacement for ``IPAdapter.get_image_embeds``.

    When ``pil_image`` is given it is encoded here; otherwise the caller must
    pass ``clip_image_embeds`` already in float16 on ``self.device``.

    Returns:
        Tuple of (image_prompt_embeds, uncond_image_prompt_embeds).
    """
    if pil_image is not None:
        images = [pil_image] if isinstance(pil_image, Image.Image) else pil_image
        pixels = self.clip_image_processor(images=images, return_tensors="pt").pixel_values
        pixels = pixels.to(self.device, dtype=torch.float16)
        clip_image_embeds = self.image_encoder(
            pixels, output_hidden_states=True
        ).hidden_states[-2]
    image_prompt_embeds = self.image_proj_model(clip_image_embeds)
    # Unconditional branch: encode an all-black 224x224 image.
    black = torch.zeros(1, 3, 224, 224).to(self.device, dtype=torch.float16)
    uncond_clip = self.image_encoder(black, output_hidden_states=True).hidden_states[-2]
    uncond_image_prompt_embeds = self.image_proj_model(uncond_clip)
    return image_prompt_embeds, uncond_image_prompt_embeds
@torch.inference_mode()
def load_sdxl():
    """Build the Realistic Vision pipeline with the ft-MSE VAE in float16.

    NOTE(review): despite the name, this loads an SD 1.5 checkpoint
    ("SG161222/Realistic_Vision_V4.0_noVAE"), not SDXL.
    """
    scheduler = DDIMScheduler(
        num_train_timesteps=1000,
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        clip_sample=False,
        set_alpha_to_one=False,
        steps_offset=1,
    )
    vae = AutoencoderKL.from_pretrained(
        "stabilityai/sd-vae-ft-mse"
    ).to(dtype=torch.float16)
    # Safety checker and feature extractor are disabled on purpose.
    return StableDiffusionPipeline.from_pretrained(
        "SG161222/Realistic_Vision_V4.0_noVAE",
        torch_dtype=torch.float16,
        scheduler=scheduler,
        vae=vae,
        feature_extractor=None,
        safety_checker=None,
    )
@torch.inference_mode()
def load_ipadapter(device="cuda"):
    """Build the SD 1.5 pipeline and wrap it in an IP-Adapter Plus model.

    Args:
        device: torch device string for the adapter ("cuda" by default).

    Returns:
        An ``IPAdapterPlus`` whose class-level ``get_image_embeds`` is
        replaced by ``_myheck_ipadapter_get_image_embeds``.
    """
    scheduler = DDIMScheduler(
        num_train_timesteps=1000,
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        clip_sample=False,
        set_alpha_to_one=False,
        steps_offset=1,
    )
    vae = AutoencoderKL.from_pretrained(
        "stabilityai/sd-vae-ft-mse"
    ).to(dtype=torch.float16)
    pipe = StableDiffusionPipeline.from_pretrained(
        "SG161222/Realistic_Vision_V4.0_noVAE",
        torch_dtype=torch.float16,
        scheduler=scheduler,
        vae=vae,
        feature_extractor=None,
        safety_checker=None,
    )
    ip_model = IPAdapterPlus(
        pipe,
        "./downloads/models/image_encoder",
        "./downloads/models/ip-adapter-plus_sd15.bin",
        device,
        num_tokens=16,
    )
    # NOTE(review): patching the CLASS affects every IPAdapterPlus instance,
    # not just this one — intentional here, but worth knowing.
    setattr(ip_model.__class__, "get_image_embeds", _myheck_ipadapter_get_image_embeds)
    return ip_model
@torch.inference_mode()
def generate(ip_model, clip_embeds, num_samples=4, num_inference_steps=50, seed=42):
    """Sample images from a single CLIP embedding via the IP-Adapter pipeline.

    Args:
        ip_model: loaded IP-Adapter wrapper exposing ``device`` and ``generate``.
        clip_embeds: CLIP hidden states, shape (tokens, dim) or (1, tokens, dim).
        num_samples: number of images to draw.
        num_inference_steps: diffusion steps per image.
        seed: RNG seed forwarded to the pipeline.

    Returns:
        Whatever ``ip_model.generate`` returns (a list of PIL images for
        IPAdapterPlus).
    """
    if clip_embeds.ndim == 2:
        # Promote a single embedding to a batch of one.
        clip_embeds = clip_embeds.unsqueeze(0)
    assert clip_embeds.ndim == 3
    assert clip_embeds.shape[0] == 1
    clip_embeds = clip_embeds.half().to(ip_model.device)
    # Fixed: the original line ended with a stray "|" (copy/paste residue),
    # which is a syntax error.
    return ip_model.generate(
        clip_image_embeds=clip_embeds,
        pil_image=None,
        num_samples=num_samples,
        num_inference_steps=num_inference_steps,
        seed=seed,
    )