import os
from typing import List

import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from safetensors.torch import load_file

from nested_attention_processor import AttnProcessor, NestedAttnProcessor
from utils import get_generator
from resampler import Resampler
def add_special_token_to_tokenizer(
    pipe,
    placeholder_token,
    initializer_token,
):
    # Register the placeholder token with both SDXL tokenizers.
    num_added_tokens1 = pipe.tokenizer.add_tokens([placeholder_token])
    num_added_tokens2 = pipe.tokenizer_2.add_tokens([placeholder_token])
    if num_added_tokens1 != 1 or num_added_tokens2 != 1:
        raise ValueError("Failed to add placeholder token to tokenizer")

    token_ids1 = pipe.tokenizer.encode(initializer_token, add_special_tokens=False)
    token_ids2 = pipe.tokenizer_2.encode(initializer_token, add_special_tokens=False)
    if len(token_ids1) > 1 or len(token_ids2) > 1:
        raise ValueError("The initializer token must be a single token.")
    initializer_token_id1 = token_ids1[0]
    initializer_token_id2 = token_ids2[0]

    placeholder_token_ids1 = pipe.tokenizer.convert_tokens_to_ids([placeholder_token])
    placeholder_token_ids2 = pipe.tokenizer_2.convert_tokens_to_ids([placeholder_token])

    # Resize the embedding matrices and copy the initializer token's embedding
    # into the slot of the new placeholder token.
    pipe.text_encoder.resize_token_embeddings(len(pipe.tokenizer))
    pipe.text_encoder_2.resize_token_embeddings(len(pipe.tokenizer_2))
    token_embeds1 = pipe.text_encoder.get_input_embeddings().weight.data
    token_embeds2 = pipe.text_encoder_2.get_input_embeddings().weight.data
    with torch.no_grad():
        for token_id in placeholder_token_ids1:
            token_embeds1[token_id] = token_embeds1[initializer_token_id1].clone()
        for token_id in placeholder_token_ids2:
            token_embeds2[token_id] = token_embeds2[initializer_token_id2].clone()
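

# A minimal usage sketch for the helper above (illustrative only: the SDXL
# checkpoint id and the "<person>"/"person" token pair are assumptions, not
# defined in this file):
#
#   from diffusers import StableDiffusionXLPipeline
#   pipe = StableDiffusionXLPipeline.from_pretrained(
#       "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
#   )
#   add_special_token_to_tokenizer(
#       pipe, placeholder_token="<person>", initializer_token="person"
#   )
#   placeholder_token_ids = pipe.tokenizer.convert_tokens_to_ids(["<person>"])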


class NestedAdapterInference:
    def __init__(
        self,
        sd_pipe,
        image_encoder_path,
        adapter_ckpt,
        resampler_num_queries,
        vq_normalize_factor,
        device,
    ):
        self.device = device
        self.image_encoder_path = image_encoder_path
        self.adapter_ckpt = adapter_ckpt
        self.vq_normalize_factor = vq_normalize_factor

        self.pipe = sd_pipe.to(self.device)
        self.set_nested_adapter()

        # load image encoder
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            self.image_encoder_path, use_safetensors=True
        ).to(self.device, dtype=torch.float16)
        self.clip_image_processor = CLIPImageProcessor()

        # spatial features model
        self.qformer = Resampler(
            dim=self.pipe.unet.config.cross_attention_dim,
            depth=4,
            dim_head=64,
            heads=12,
            num_queries=resampler_num_queries,
            embedding_dim=self.image_encoder.config.hidden_size,
            output_dim=self.pipe.unet.config.cross_attention_dim,
            ff_mult=4,
        ).to(self.device, dtype=torch.float16)

        if adapter_ckpt is not None:
            self.load_nested_adapter()
    def set_nested_adapter(self):
        # Replace the UNet attention processors: self-attention layers keep the
        # default processor, cross-attention layers get the nested attention processor.
        unet = self.pipe.unet
        attn_procs = {}
        for name in unet.attn_processors.keys():
            cross_attention_dim = (
                None
                if name.endswith("attn1.processor")
                else unet.config.cross_attention_dim
            )
            if name.startswith("mid_block"):
                hidden_size = unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = unet.config.block_out_channels[block_id]
            if cross_attention_dim is None:
                attn_procs[name] = AttnProcessor()
            else:
                attn_procs[name] = NestedAttnProcessor(
                    hidden_size=hidden_size,
                    cross_attention_dim=cross_attention_dim,
                    normalize_factor=self.vq_normalize_factor,
                ).to(self.device, dtype=torch.float16)
        unet.set_attn_processor(attn_procs)
    def load_nested_adapter(self):
        # Split the checkpoint into the attention-adapter weights and the
        # spatial-features (Resampler) weights, then load each part.
        state_dict = {"adapter_modules": {}, "qformer": {}}
        f = load_file(self.adapter_ckpt)
        for key in f.keys():
            if key.startswith("adapter_modules."):
                state_dict["adapter_modules"][key.replace("adapter_modules.", "")] = f[key]
            elif key.startswith("spatial_features_model."):
                state_dict["qformer"][key.replace("spatial_features_model.", "")] = f[key]
        self.qformer.load_state_dict(state_dict["qformer"])
        adapter_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
        adapter_layers.load_state_dict(state_dict["adapter_modules"])
    def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
        if isinstance(pil_image, Image.Image):
            pil_image = [pil_image]
        clip_image = self.clip_image_processor(
            images=pil_image, return_tensors="pt"
        ).pixel_values
        clip_image_embeds = self.image_encoder(
            clip_image.to(self.device, dtype=torch.float16)
        )
        spatial_clip_image_embeds = clip_image_embeds.last_hidden_state
        spatial_clip_image_embeds = spatial_clip_image_embeds[:, 1:]  # remove CLS token
        return spatial_clip_image_embeds
    def generate(
        self,
        pil_image=None,
        clip_image_embeds=None,
        prompt=None,
        placeholder_token_ids=None,
        negative_prompt=None,
        scale=1.0,
        num_samples=4,
        seed=None,
        guidance_scale=5.0,
        num_inference_steps=30,
        multiple_images=False,
        special_token_weight=1.0,
        **kwargs,
    ):
        if pil_image is not None:
            num_prompts = (
                1
                if isinstance(pil_image, Image.Image) or multiple_images
                else len(pil_image)
            )
        else:
            num_prompts = clip_image_embeds.size(0)

        if prompt is None:
            prompt = "best quality, high quality"
        if negative_prompt is None:
            negative_prompt = (
                "monochrome, lowres, bad anatomy, worst quality, low quality"
            )
        if not isinstance(prompt, List):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, List):
            negative_prompt = [negative_prompt] * num_prompts

        # Locate the placeholder token inside the tokenized prompt so the nested
        # attention processors know which text-token positions to modulate.
        text_input_ids = self.pipe.tokenizer(
            prompt,
            max_length=self.pipe.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).input_ids
        special_token_indices = (text_input_ids == placeholder_token_ids[0]).nonzero()[
            :, 1
        ]

        spatial_clip_image_embeds = self.get_image_embeds(
            pil_image=pil_image, clip_image_embeds=clip_image_embeds
        )  # (bs, 256, 1280)

        with torch.no_grad():
            (
                prompt_embeds,
                negative_prompt_embeds,
                pooled_prompt_embeds,
                negative_pooled_prompt_embeds,
            ) = self.pipe.encode_prompt(
                prompt,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )

        with torch.no_grad():
            qformer_tokens_out = self.qformer(spatial_clip_image_embeds)
            if multiple_images:
                # Concatenate the per-image query tokens into a single sequence.
                b, num_tokens, d = qformer_tokens_out.shape
                qformer_tokens_out = qformer_tokens_out.reshape(1, num_tokens * b, d)
            # Duplicate the query tokens for each sample and for the two
            # classifier-free-guidance branches.
            bs_embed, num_tokens, _ = qformer_tokens_out.shape
            qformer_tokens_out = qformer_tokens_out.repeat(1, num_samples, 1, 1)
            qformer_tokens_out = qformer_tokens_out.view(
                bs_embed * num_samples, num_tokens, -1
            )
            qformer_tokens_out = qformer_tokens_out.repeat_interleave(2, dim=0)

        cross_attention_kwargs = {
            "qformer_tokens_out": qformer_tokens_out,
            "special_token_indices": special_token_indices,
            "special_token_weight": special_token_weight,
            "inference_mode": True,
        }

        generator = get_generator(seed, self.device)

        images = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            cross_attention_kwargs=cross_attention_kwargs,
            **kwargs,
        ).images

        return images
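

# Usage sketch for the full inference wrapper (illustrative only: the paths,
# the image-encoder id, and the hyperparameter values below are assumptions,
# not taken from this file):
#
#   model = NestedAdapterInference(
#       sd_pipe=pipe,  # SDXL pipeline prepared with add_special_token_to_tokenizer
#       image_encoder_path="path/to/clip-vision-encoder",
#       adapter_ckpt="path/to/nested_adapter.safetensors",
#       resampler_num_queries=128,
#       vq_normalize_factor=1.0,
#       device="cuda",
#   )
#   images = model.generate(
#       pil_image=Image.open("subject.jpg"),
#       prompt="a photo of <person> at the beach",
#       placeholder_token_ids=placeholder_token_ids,
#       num_samples=2,
#       seed=0,
#   )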