# NestedAttentionEncoder / nested_attention_pipeline.py
import os
from typing import List
import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from safetensors.torch import load_file
from nested_attention_processor import AttnProcessor, NestedAttnProcessor
from utils import get_generator
from resampler import Resampler


def add_special_token_to_tokenizer(pipe, placeholder_token, initializer_token):
    # Register the placeholder token in both SDXL tokenizers.
    num_added_tokens1 = pipe.tokenizer.add_tokens([placeholder_token])
    num_added_tokens2 = pipe.tokenizer_2.add_tokens([placeholder_token])
    if num_added_tokens1 != 1 or num_added_tokens2 != 1:
        raise ValueError("Failed to add placeholder token to tokenizer")

    # The initializer token must map to exactly one token id in each tokenizer.
    token_ids1 = pipe.tokenizer.encode(initializer_token, add_special_tokens=False)
    token_ids2 = pipe.tokenizer_2.encode(initializer_token, add_special_tokens=False)
    if len(token_ids1) > 1 or len(token_ids2) > 1:
        raise ValueError("The initializer token must be a single token.")
    initializer_token_id1 = token_ids1[0]
    initializer_token_id2 = token_ids2[0]

    placeholder_token_ids1 = pipe.tokenizer.convert_tokens_to_ids([placeholder_token])
    placeholder_token_ids2 = pipe.tokenizer_2.convert_tokens_to_ids([placeholder_token])

    # Grow the text-encoder embedding tables and copy the initializer embedding
    # into the new placeholder rows.
    pipe.text_encoder.resize_token_embeddings(len(pipe.tokenizer))
    pipe.text_encoder_2.resize_token_embeddings(len(pipe.tokenizer_2))
    token_embeds1 = pipe.text_encoder.get_input_embeddings().weight.data
    token_embeds2 = pipe.text_encoder_2.get_input_embeddings().weight.data
    with torch.no_grad():
        for token_id in placeholder_token_ids1:
            token_embeds1[token_id] = token_embeds1[initializer_token_id1].clone()
        for token_id in placeholder_token_ids2:
            token_embeds2[token_id] = token_embeds2[initializer_token_id2].clone()
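
# Hypothetical usage sketch (not from this repository): register a "<person>" placeholder
# whose embedding starts from the single-token word "person", then look up its id so it
# can be passed to NestedAdapterInference.generate(). The pipeline checkpoint and the
# token strings below are assumptions chosen for illustration.
#
#   from diffusers import StableDiffusionXLPipeline
#   pipe = StableDiffusionXLPipeline.from_pretrained(
#       "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
#   )
#   add_special_token_to_tokenizer(pipe, placeholder_token="<person>", initializer_token="person")
#   placeholder_token_ids = pipe.tokenizer.convert_tokens_to_ids(["<person>"])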


class NestedAdapterInference:
    def __init__(
        self,
        sd_pipe,
        image_encoder_path,
        adapter_ckpt,
        resampler_num_queries,
        vq_normalize_factor,
        device,
    ):
        self.device = device
        self.image_encoder_path = image_encoder_path
        self.adapter_ckpt = adapter_ckpt
        self.vq_normalize_factor = vq_normalize_factor

        self.pipe = sd_pipe.to(self.device)
        self.set_nested_adapter()

        # load image encoder
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            self.image_encoder_path, use_safetensors=True
        ).to(self.device, dtype=torch.float16)
        self.clip_image_processor = CLIPImageProcessor()

        # spatial features model: a resampler (q-former) that maps CLIP patch features
        # to tokens in the UNet cross-attention dimension
        self.qformer = Resampler(
            dim=self.pipe.unet.config.cross_attention_dim,
            depth=4,
            dim_head=64,
            heads=12,
            num_queries=resampler_num_queries,
            embedding_dim=self.image_encoder.config.hidden_size,
            output_dim=self.pipe.unet.config.cross_attention_dim,
            ff_mult=4,
        ).to(self.device, dtype=torch.float16)

        if adapter_ckpt is not None:
            self.load_nested_adapter()

    def set_nested_adapter(self):
        # Swap in attention processors: self-attention layers ("attn1") keep the plain
        # processor, while cross-attention layers ("attn2") get NestedAttnProcessor.
        unet = self.pipe.unet
        attn_procs = {}
        for name in unet.attn_processors.keys():
            cross_attention_dim = (
                None
                if name.endswith("attn1.processor")
                else unet.config.cross_attention_dim
            )
            # Recover the hidden size of the UNet block this processor belongs to.
            if name.startswith("mid_block"):
                hidden_size = unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = unet.config.block_out_channels[block_id]
            if cross_attention_dim is None:
                attn_procs[name] = AttnProcessor()
            else:
                attn_procs[name] = NestedAttnProcessor(
                    hidden_size=hidden_size,
                    cross_attention_dim=cross_attention_dim,
                    normalize_factor=self.vq_normalize_factor,
                ).to(self.device, dtype=torch.float16)
        unet.set_attn_processor(attn_procs)

    def load_nested_adapter(self):
        # The checkpoint stores two groups of weights: the nested attention processors
        # (prefix "adapter_modules.") and the spatial features model
        # (prefix "spatial_features_model.", loaded into self.qformer).
        state_dict = {"adapter_modules": {}, "qformer": {}}
        f = load_file(self.adapter_ckpt)
        for key in f.keys():
            if key.startswith("adapter_modules."):
                state_dict["adapter_modules"][key.replace("adapter_modules.", "")] = f[key]
            elif key.startswith("spatial_features_model."):
                state_dict["qformer"][key.replace("spatial_features_model.", "")] = f[key]

        self.qformer.load_state_dict(state_dict["qformer"])
        adapter_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
        adapter_layers.load_state_dict(state_dict["adapter_modules"])

    @torch.inference_mode()
    def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
        # Encode the subject image(s) with CLIP when no precomputed embeddings are given.
        if pil_image is not None:
            if isinstance(pil_image, Image.Image):
                pil_image = [pil_image]
            clip_image = self.clip_image_processor(
                images=pil_image, return_tensors="pt"
            ).pixel_values
            clip_image_embeds = self.image_encoder(
                clip_image.to(self.device, dtype=torch.float16)
            )
        # Keep the per-patch features and drop the CLS token.
        spatial_clip_image_embeds = clip_image_embeds.last_hidden_state
        spatial_clip_image_embeds = spatial_clip_image_embeds[:, 1:]  # remove CLS token
        return spatial_clip_image_embeds
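
    # Hypothetical example: for a single reference image and a ViT-H/14-style CLIP
    # encoder, the returned tensor has shape (1, 256, 1280), i.e. 16x16 patch tokens
    # with the CLS token removed. The file path below is an assumption.
    #
    #   ref_image = Image.open("subject.jpg").convert("RGB")
    #   spatial_feats = adapter.get_image_embeds(pil_image=ref_image)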

    def generate(
        self,
        pil_image=None,
        clip_image_embeds=None,
        prompt=None,
        placeholder_token_ids=None,
        negative_prompt=None,
        scale=1.0,
        num_samples=4,
        seed=None,
        guidance_scale=5.0,
        num_inference_steps=30,
        multiple_images=False,
        special_token_weight=1.0,
        **kwargs,
    ):
        # With multiple_images=True, a list of reference images is treated as one prompt
        # (their q-former tokens are concatenated later on).
        if pil_image is not None:
            num_prompts = (
                1
                if isinstance(pil_image, Image.Image) or multiple_images
                else len(pil_image)
            )
        else:
            num_prompts = clip_image_embeds.size(0)

        if prompt is None:
            prompt = "best quality, high quality"
        if negative_prompt is None:
            negative_prompt = (
                "monochrome, lowres, bad anatomy, worst quality, low quality"
            )

        if not isinstance(prompt, List):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, List):
            negative_prompt = [negative_prompt] * num_prompts

        # Locate the placeholder token in the (first) tokenizer's input ids; these
        # positions are passed to the nested attention processors.
        text_input_ids = self.pipe.tokenizer(
            prompt,
            max_length=self.pipe.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).input_ids
        special_token_indices = (text_input_ids == placeholder_token_ids[0]).nonzero()[:, 1]

        spatial_clip_image_embeds = self.get_image_embeds(
            pil_image=pil_image, clip_image_embeds=clip_image_embeds
        )  # (bs, 256, 1280)

        with torch.no_grad():
            (
                prompt_embeds,
                negative_prompt_embeds,
                pooled_prompt_embeds,
                negative_pooled_prompt_embeds,
            ) = self.pipe.encode_prompt(
                prompt,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )

        with torch.no_grad():
            # Map the CLIP patch features to cross-attention tokens.
            qformer_tokens_out = self.qformer(spatial_clip_image_embeds)
            if multiple_images:
                # Concatenate the tokens of all reference images into one sequence.
                b, num_tokens, d = qformer_tokens_out.shape
                qformer_tokens_out = qformer_tokens_out.reshape(1, num_tokens * b, d)
            # Duplicate the image tokens for each generated sample, then once more
            # along the batch dimension for classifier-free guidance.
            bs_embed, num_tokens, _ = qformer_tokens_out.shape
            qformer_tokens_out = qformer_tokens_out.repeat(1, num_samples, 1, 1)
            qformer_tokens_out = qformer_tokens_out.view(
                bs_embed * num_samples, num_tokens, -1
            )
            qformer_tokens_out = qformer_tokens_out.repeat_interleave(2, dim=0)

        cross_attention_kwargs = {
            "qformer_tokens_out": qformer_tokens_out,
            "special_token_indices": special_token_indices,
            "special_token_weight": special_token_weight,
            "inference_mode": True,
        }

        generator = get_generator(seed, self.device)

        images = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            cross_attention_kwargs=cross_attention_kwargs,
            **kwargs,
        ).images

        return images
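

# Hypothetical end-to-end sketch (not part of the original file). The model ids, file
# paths, and hyperparameter values below are assumptions chosen for illustration; the
# adapter checkpoint must match the key layout expected by load_nested_adapter().
if __name__ == "__main__":
    from diffusers import StableDiffusionXLPipeline

    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    )
    add_special_token_to_tokenizer(pipe, placeholder_token="<person>", initializer_token="person")
    placeholder_token_ids = pipe.tokenizer.convert_tokens_to_ids(["<person>"])

    adapter = NestedAdapterInference(
        sd_pipe=pipe,
        image_encoder_path="image_encoder",          # assumed local CLIP vision checkpoint
        adapter_ckpt="nested_adapter.safetensors",   # assumed adapter checkpoint path
        resampler_num_queries=128,                   # assumed value
        vq_normalize_factor=1.0,                     # assumed value
        device="cuda",
    )

    images = adapter.generate(
        pil_image=Image.open("subject.jpg").convert("RGB"),  # assumed input image
        prompt="a photo of <person> hiking in the mountains",
        placeholder_token_ids=placeholder_token_ids,
        num_samples=2,
        seed=0,
    )
    for i, img in enumerate(images):
        img.save(f"output_{i}.png")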