# NOTE: removed non-code scrape artifacts that preceded this script
# (a "File size" banner, commit-hash fragments, and a copied line-number
# gutter 1..135) — they were not part of the source and broke parsing.
import numpy as np
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler, AutoencoderKL
from PIL import Image
torch.backends.cuda.enable_cudnn_sdp(False)  # a fix for torch 2.5.0

from ip_adapter import IPAdapterPlus
from ip_adapter import IPAdapter
# %%
def image_grid(imgs, rows, cols):
    """Paste ``rows * cols`` equally-sized PIL images into one grid image.

    Args:
        imgs: sequence of PIL images, all the same size; ``len(imgs)`` must
            equal ``rows * cols``.
        rows: number of grid rows.
        cols: number of grid columns.

    Returns:
        A new ``PIL.Image`` of size ``(cols * w, rows * h)``.

    Raises:
        ValueError: if ``len(imgs) != rows * cols``.
    """
    # Explicit exception instead of `assert` so the check survives `python -O`.
    if len(imgs) != rows * cols:
        raise ValueError(f"expected {rows * cols} images, got {len(imgs)}")

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))

    # Row-major placement: image i lands at column i % cols, row i // cols.
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid


@torch.inference_mode()
def extract_clip_embedding_pil(pil_image, ip_model):
    """Encode a PIL image with the IP-Adapter's CLIP image encoder.

    Runs the model's preprocessing, encodes on the model's device in fp16,
    and returns the penultimate hidden state upcast to float32.
    """
    pixels = ip_model.clip_image_processor(
        images=pil_image, return_tensors="pt"
    ).pixel_values.to(ip_model.device, dtype=torch.float16)
    # hidden_states[-2]: the penultimate layer is what this file feeds to
    # the adapter's projection model elsewhere.
    hidden = ip_model.image_encoder(pixels, output_hidden_states=True).hidden_states[-2]
    return hidden.float()

def extract_clip_embedding_pil_batch(pil_images, ip_model):
    """Encode a sequence of PIL images and stack the embeddings along dim 0.

    Each image is encoded independently via ``extract_clip_embedding_pil``
    (which already runs under ``torch.inference_mode``), then the per-image
    results are concatenated on the batch dimension.

    Raises:
        RuntimeError: if ``pil_images`` is empty (``torch.cat`` on an empty
            list), same as the original loop-based version.
    """
    # Comprehension instead of a manual append loop (same order, same result).
    return torch.cat(
        [extract_clip_embedding_pil(img, ip_model) for img in pil_images],
        dim=0,
    )

@torch.inference_mode()
def extract_clip_embedding_tensor(tensor_image, ip_model):
    tensor_image = tensor_image.to(ip_model.device, dtype=torch.float16)
    tensor_image = torch.nn.functional.interpolate(tensor_image, size=(224, 224), mode="bilinear", align_corners=False)
    clip_image_embeds = ip_model.image_encoder(tensor_image, output_hidden_states=True).hidden_states[-2]
    clip_image_embeds = clip_image_embeds.float()
    return clip_image_embeds


@torch.inference_mode()
def _myheck_ipadapter_get_image_embeds(self, pil_image=None, clip_image_embeds=None):
    """Replacement for IPAdapter.get_image_embeds (installed in load_ipadapter)
    that also accepts precomputed CLIP hidden states via `clip_image_embeds`.

    When `pil_image` is given it takes precedence: it is preprocessed,
    encoded, and its penultimate hidden state overwrites `clip_image_embeds`.
    Otherwise the caller-supplied `clip_image_embeds` is projected directly.

    Returns:
        (image_prompt_embeds, uncond_image_prompt_embeds), both produced by
        self.image_proj_model.
    """
    if pil_image is not None:
        # Accept a single image or a list of images.
        if isinstance(pil_image, Image.Image):
            pil_image = [pil_image]
        clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
        clip_image = clip_image.to(self.device, dtype=torch.float16)
        # Penultimate hidden layer — same layer the other helpers in this
        # file extract.
        clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
    image_prompt_embeds = self.image_proj_model(clip_image_embeds)
    # Unconditional branch: encode an all-zero 224x224 image.
    # NOTE(review): recomputed on every call; a cache would avoid the extra
    # encoder pass — confirm before changing.
    uncond_clip_image_embeds = self.image_encoder(
        torch.zeros(1, 3, 224, 224).to(self.device, dtype=torch.float16),
        output_hidden_states=True
    ).hidden_states[-2]
    uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
    return image_prompt_embeds, uncond_image_prompt_embeds


@torch.inference_mode()
def load_sdxl():
    """Build the base Stable Diffusion pipeline (fp16, DDIM, ft-MSE VAE).

    NOTE(review): despite the name, this loads an SD 1.x-family checkpoint
    through StableDiffusionPipeline, not an SDXL model. The name is kept
    because other code may call it.
    """
    base_model_path = "SG161222/Realistic_Vision_V4.0_noVAE"
    vae_model_path = "stabilityai/sd-vae-ft-mse"

    # DDIM configuration matching the IP-Adapter setup used below.
    scheduler = DDIMScheduler(
        num_train_timesteps=1000,
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        clip_sample=False,
        set_alpha_to_one=False,
        steps_offset=1,
    )

    # ft-MSE VAE swapped in; feature extractor / safety checker disabled.
    vae = AutoencoderKL.from_pretrained(vae_model_path).to(dtype=torch.float16)
    pipeline = StableDiffusionPipeline.from_pretrained(
        base_model_path,
        torch_dtype=torch.float16,
        scheduler=scheduler,
        vae=vae,
        feature_extractor=None,
        safety_checker=None,
    )
    return pipeline

@torch.inference_mode()
def load_ipadapter(device="cuda"):
    """Load the SD pipeline plus the IP-Adapter-Plus image-prompt adapter.

    Args:
        device: torch device string for the adapter ("cuda" by default).

    Returns:
        An IPAdapterPlus whose ``get_image_embeds`` is replaced with
        ``_myheck_ipadapter_get_image_embeds`` so precomputed CLIP hidden
        states can be passed straight through.
    """
    image_encoder_path = "./downloads/models/image_encoder"
    ip_ckpt = "./downloads/models/ip-adapter-plus_sd15.bin"

    # Reuse the shared pipeline builder instead of duplicating its
    # scheduler/VAE/pipeline construction here (it was copied verbatim).
    pipe = load_sdxl()

    # load ip-adapter ("Plus" variant: 16 image tokens)
    ip_model = IPAdapterPlus(pipe, image_encoder_path, ip_ckpt, device, num_tokens=16)

    # NOTE(review): patching the CLASS affects every IPAdapterPlus instance,
    # not only this one — intentional here, but keep in mind if multiple
    # adapters are ever created.
    setattr(ip_model.__class__, "get_image_embeds", _myheck_ipadapter_get_image_embeds)

    return ip_model


@torch.inference_mode()
def generate(ip_model, clip_embeds, num_samples=4, num_inference_steps=50, seed=42):
    """Sample images from precomputed CLIP image embeddings.

    Args:
        ip_model: adapter whose ``generate()`` accepts ``clip_image_embeds``.
        clip_embeds: embeddings of shape (tokens, dim) or (1, tokens, dim);
            a 2-D input is promoted to a batch of one.
        num_samples: number of images to generate.
        num_inference_steps: diffusion steps.
        seed: RNG seed forwarded to ``ip_model.generate``.

    Returns:
        Whatever ``ip_model.generate`` returns (a list of images).

    Raises:
        ValueError: if the embeddings are not a single-item batch after
            promotion.
    """
    if clip_embeds.ndim == 2:
        # Promote a single unbatched embedding to a batch of one.
        clip_embeds = clip_embeds.unsqueeze(0)
    # Explicit errors instead of `assert` so validation survives `python -O`.
    if clip_embeds.ndim != 3 or clip_embeds.shape[0] != 1:
        raise ValueError(
            f"expected embeddings of shape (1, tokens, dim), got {tuple(clip_embeds.shape)}"
        )
    clip_embeds = clip_embeds.half().to(ip_model.device)
    return ip_model.generate(
        clip_image_embeds=clip_embeds,
        pil_image=None,
        num_samples=num_samples,
        num_inference_steps=num_inference_steps,
        seed=seed,
    )