import copy
import math
import warnings
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from diffusers.image_processor import IPAdapterMaskProcessor
from diffusers.models.attention_processor import Attention
from diffusers.models.embeddings import ImageProjection, MultiIPAdapterImageProjection
from pulid.encoders_transformer import IDFormer

from .resampler import Resampler


def save_ip_adapter(unet, path):
    """Serialize the IP-Adapter weights attached to `unet` (the image-projection
    head plus every attention-processor module) to `path`."""
    state_dict = {}
    if (
        hasattr(unet, "encoder_hid_proj")
        and unet.encoder_hid_proj is not None
        and isinstance(unet.encoder_hid_proj, torch.nn.Module)
    ):
        state_dict["encoder_hid_proj"] = unet.encoder_hid_proj.state_dict()
    for name, module in unet.attn_processors.items():
        if isinstance(module, torch.nn.Module):
            state_dict[name] = module.state_dict()
    torch.save(state_dict, path)


def load_ip_adapter(
    unet,
    path=None,
    clip_embeddings_dim=1280,
    cross_attention_dim=2048,
    num_image_text_embeds=4,
    attn_blocks=["down", "mid", "up"],
):
    """Attach a single IP-Adapter to `unet`.

    If `path` is given, the projection dimensions are inferred from the checkpoint
    and its weights are loaded; otherwise the projection is freshly initialized and
    the IP key/value layers are copied from the UNet's own cross-attention weights.
    """
    if path is None:
        state_dict = None
    else:
        state_dict = torch.load(path, map_location="cpu")
        clip_embeddings_dim = state_dict["encoder_hid_proj"][
            "image_embeds.weight"
        ].shape[-1]
        num_image_text_embeds = (
            state_dict["encoder_hid_proj"]["image_embeds.weight"].shape[0]
            // cross_attention_dim
        )
    if not hasattr(unet, "encoder_hid_proj") or unet.encoder_hid_proj is None:
        unet.encoder_hid_proj = ImageProjection(
            cross_attention_dim=cross_attention_dim,
            image_embed_dim=clip_embeddings_dim,
            num_image_text_embeds=num_image_text_embeds,
        ).to(unet.device, unet.dtype)
    if state_dict is not None:
        unet.encoder_hid_proj.load_state_dict(state_dict["encoder_hid_proj"])
    for name, module in unet.named_modules():
        if (
            "attn2" in name
            and isinstance(module, Attention)
            and any(attn in name for attn in attn_blocks)
        ):
            if not isinstance(module.processor, IPAttnProcessor2_0):
                module.set_processor(
                    IPAttnProcessor2_0(
                        hidden_size=module.query_dim,
                        cross_attention_dim=cross_attention_dim,
                    ).to(unet.device, unet.dtype)
                )
            if state_dict is not None:
                module.processor.load_state_dict(state_dict[f"{name}.processor"])
            else:
                module.processor.to_k_ip.load_state_dict(module.to_k.state_dict())
                module.processor.to_v_ip.load_state_dict(module.to_v.state_dict())
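

# The helper below is an illustrative usage sketch added for documentation; it is not
# part of the original module. It assumes `unet` is a diffusers SDXL
# `UNet2DConditionModel` (cross_attention_dim == 2048) and that `path`, if given, was
# produced by `save_ip_adapter` above; the scale value is an arbitrary placeholder.
def _example_load_single_ip_adapter(unet, path=None):
    # Attach the image-projection head and swap in IPAttnProcessor2_0 on the
    # cross-attention ("attn2") layers.
    load_ip_adapter(unet, path=path)
    # Down-weight the image prompt relative to the text prompt on all blocks.
    set_ip_adapter_scale(unet, scale=0.6, attn_blocks=["down", "mid", "up"])
    return unet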


def parse_clip_embeddings_dim(path, state_dict):
    """Infer the image-embedding dimension from a checkpoint; PuLID checkpoints
    use an IDFormer and do not need one."""
    if "pulid" in path:
        return None
    return state_dict["encoder_hid_proj"]["image_embeds.weight"].shape[-1]


def parse_num_image_text_embeds(path, state_dict, cross_attention_dim=2048):
    """Infer the number of image tokens from a checkpoint (None for PuLID)."""
    if "pulid" in path:
        return None
    return (
        state_dict["encoder_hid_proj"]["image_embeds.weight"].shape[0]
        // cross_attention_dim
    )


def parse_encoder_hid_proj_module(
    path=None,
    cross_attention_dim=2048,
    image_embed_dim=None,
    num_image_text_embeds=None,
):
    """Build the projection module for one adapter: an IDFormer for PuLID
    checkpoints, a plain ImageProjection otherwise."""
    if "pulid" in path:
        return IDFormer()
    return ImageProjection(
        cross_attention_dim=cross_attention_dim,
        image_embed_dim=image_embed_dim,
        num_image_text_embeds=num_image_text_embeds,
    )


def load_multi_ip_adapter(
    unet,
    paths=None,
    clip_embeddings_dim=[1280],
    cross_attention_dim=2048,
    num_image_text_embeds=[4],
):
    """Attach several IP-Adapters at once.

    Each entry of `paths` contributes one image-projection layer and one pair of
    IP key/value projections per attention block; per-adapter settings are read
    from the checkpoints when available.
    """
    if paths is None:
        state_dict = None
    else:
        state_dict = [torch.load(path, map_location="cpu") for path in paths]
        clip_embeddings_dim = [
            parse_clip_embeddings_dim(path=single_path, state_dict=single_state_dict)
            for single_path, single_state_dict in zip(paths, state_dict)
        ]
        num_image_text_embeds = [
            parse_num_image_text_embeds(
                path=single_path,
                state_dict=single_state_dict,
                cross_attention_dim=unet.config.cross_attention_dim,
            )
            for single_path, single_state_dict in zip(paths, state_dict)
        ]
    if not hasattr(unet, "encoder_hid_proj") or unet.encoder_hid_proj is None:
        unet.encoder_hid_proj = MultiIPAdapterImageProjection(
            [
                parse_encoder_hid_proj_module(
                    path=single_path,
                    cross_attention_dim=unet.config.cross_attention_dim,
                    image_embed_dim=single_clip_embeddings_dim,
                    num_image_text_embeds=single_num_image_text_embeds,
                ).to(unet.device, unet.dtype)
                for single_path, single_clip_embeddings_dim, single_num_image_text_embeds in zip(
                    paths, clip_embeddings_dim, num_image_text_embeds
                )
            ]
        ).to(unet.device, unet.dtype)
    if state_dict is not None:
        for single_encoder_hid_proj, single_state_dict in zip(
            unet.encoder_hid_proj.image_projection_layers, state_dict
        ):
            single_encoder_hid_proj.load_state_dict(
                single_state_dict["encoder_hid_proj"]
            )
    for name, module in unet.named_modules():
        if "attn2" in name and isinstance(module, Attention):
            if not isinstance(module.processor, MultiIPAttnProcessor2_0):
                module.set_processor(
                    MultiIPAttnProcessor2_0(
                        hidden_size=module.query_dim,
                        cross_attention_dim=unet.config.cross_attention_dim,
                        num_tokens=num_image_text_embeds,
                    ).to(unet.device, unet.dtype)
                )
            if state_dict is not None:
                for (
                    to_k_ip,
                    to_v_ip,
                    single_state_dict,
                ) in zip(
                    module.processor.to_k_ip,
                    module.processor.to_v_ip,
                    state_dict,
                ):
                    if f"{name}.processor" in single_state_dict.keys():
                        to_k_ip.weight = nn.Parameter(
                            single_state_dict[f"{name}.processor"]["to_k_ip.weight"]
                        )
                        to_v_ip.weight = nn.Parameter(
                            single_state_dict[f"{name}.processor"]["to_v_ip.weight"]
                        )
                module.processor = module.processor.to(unet.device, unet.dtype)
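

# Illustrative sketch (not part of the original module): stacking several adapters.
# Assumes each entry of `paths` is a checkpoint saved with `save_ip_adapter`; a path
# containing "pulid" switches that slot to an IDFormer projection (see
# `parse_encoder_hid_proj_module`). The per-adapter values below are placeholders.
def _example_load_multi_ip_adapters(unet, paths):
    load_multi_ip_adapter(unet, paths=paths)
    # Per-adapter weights live on each MultiIPAttnProcessor2_0 as a list.
    for _, module in unet.attn_processors.items():
        if isinstance(module, MultiIPAttnProcessor2_0):
            module.scale = [1.0] * len(paths)
    # PuLID-style orthogonal merging and zero-token padding, one value per adapter
    # (None disables both).
    set_ortho(unet, [None] * len(paths))
    set_num_zero(unet, [None] * len(paths))
    return unet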


def load_ip_adapter_plus(
    unet,
    path=None,
    embed_dims=1664,
    depth=4,
    dim_head=64,
    heads=12,
    num_queries=32,
    ff_mult=4,
    attn_blocks=["down", "mid", "up"],
):
    """Attach an IP-Adapter Plus style adapter, which uses a Resampler over
    sequence image features instead of a linear projection of pooled embeddings."""
    if path is not None:
        state_dict = torch.load(path, map_location="cpu")
    else:
        state_dict = None
    if not hasattr(unet, "encoder_hid_proj") or unet.encoder_hid_proj is None:
        unet.encoder_hid_proj = Resampler(
            dim=unet.config.cross_attention_dim,
            depth=depth,
            dim_head=dim_head,
            heads=heads,
            num_queries=num_queries,
            embedding_dim=embed_dims,
            output_dim=unet.config.cross_attention_dim,
            ff_mult=ff_mult,
        ).to(unet.device, unet.dtype)
    if state_dict is not None:
        unet.encoder_hid_proj.load_state_dict(state_dict["encoder_hid_proj"])
    for name, module in unet.named_modules():
        if (
            "attn2" in name
            and isinstance(module, Attention)
            and any(attn in name for attn in attn_blocks)
        ):
            if not isinstance(module.processor, IPAttnProcessor2_0):
                module.set_processor(
                    IPAttnProcessor2_0(
                        hidden_size=module.query_dim,
                        cross_attention_dim=unet.config.cross_attention_dim,
                    ).to(unet.device, unet.dtype)
                )
            if state_dict is not None and f"{name}.processor" in state_dict.keys():
                module.processor.load_state_dict(state_dict[f"{name}.processor"])
            else:
                module.processor.to_k_ip.load_state_dict(module.to_k.state_dict())
                module.processor.to_v_ip.load_state_dict(module.to_v.state_dict())
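

# Illustrative sketch (not part of the original module): IP-Adapter Plus wiring.
# Assumes an SDXL UNet and OpenCLIP ViT-bigG image features (hidden size 1664, the
# default `embed_dims` above). Unlike the plain adapter, the Resampler expects the
# *sequence* of penultimate hidden states from the image encoder rather than the
# pooled embedding, so that is what gets projected and cached here.
def _example_load_ip_adapter_plus(unet, path, image_hidden_states):
    load_ip_adapter_plus(unet, path=path)
    ip_tokens = unet.encoder_hid_proj(image_hidden_states)
    set_ip_hidden_states(unet, ip_tokens)
    return unet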


def set_ip_hidden_states(unet, image_embeds):
    """Store (a copy of) the projected image tokens on every IP attention processor."""
    for name, module in unet.attn_processors.items():
        if isinstance(module, (IPAttnProcessor2_0, MultiIPAttnProcessor2_0)):
            module.ip_hidden_states = image_embeds.clone()


def set_multi_ip_hidden_states(unet, image_embeds):
    """Store a list of projected image tokens (one entry per adapter) on every
    IP attention processor."""
    for name, module in unet.attn_processors.items():
        if isinstance(module, (IPAttnProcessor2_0, MultiIPAttnProcessor2_0)):
            module.ip_hidden_states = image_embeds


def set_multi_ip_attn_masks(unet, attn_masks):
    """Store per-adapter attention masks on every IP attention processor.

    Note: the processors consume masks through the `ip_adapter_masks` call argument
    (diffusers' `cross_attention_kwargs`), so this helper only caches them under a
    separate attribute instead of overwriting the cached image tokens.
    """
    for name, module in unet.attn_processors.items():
        if isinstance(module, (IPAttnProcessor2_0, MultiIPAttnProcessor2_0)):
            module.ip_adapter_masks = attn_masks


def clear_ip_hidden_states(model):
    """Drop cached image tokens from the (single-adapter) IP attention processors."""
    for name, module in model.named_modules():
        if isinstance(module, IPAttnProcessor2_0):
            module.ip_hidden_states = None


def set_ip_adapter_scale(unet, scale=1.0, attn_blocks=["down", "mid", "up"]):
    """Set the image-prompt weight on the selected attention blocks."""
    for name, module in unet.named_modules():
        if isinstance(module, IPAttnProcessor2_0) and any(
            target_module in name for target_module in attn_blocks
        ):
            module.scale = scale
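

# Illustrative sketch (not part of the original module): wiring an image prompt into
# the processors at inference time. Assumes `load_ip_adapter` has been called and
# that `clip_image_embeds` are pooled CLIP image embeddings with the dimension the
# projection layer was built for (1280 by default above). For classifier-free
# guidance you would typically also concatenate zero embeddings for the
# unconditional half of the batch.
def _example_set_image_prompt(unet, clip_image_embeds):
    # Project pooled CLIP embeddings to `num_image_text_embeds` tokens in the UNet's
    # cross-attention dimension, then cache them on every IP attention processor.
    ip_tokens = unet.encoder_hid_proj(clip_image_embeds)
    set_ip_hidden_states(unet, ip_tokens)
    # ... run the denoising loop here, then drop the cached tokens afterwards.
    clear_ip_hidden_states(unet)
    return ip_tokens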


def downsample(
    mask: torch.Tensor, batch_size: int, num_queries: int, value_embed_dim: int
):
    """
    Downsamples the provided mask tensor to match the expected dimensions for scaled dot-product attention. If the
    aspect ratio of the mask does not match the aspect ratio of the output image, a warning is issued.

    Args:
        mask (`torch.Tensor`):
            The input mask tensor generated with `IPAdapterMaskProcessor.preprocess()`.
        batch_size (`int`):
            The batch size.
        num_queries (`int`):
            The number of queries.
        value_embed_dim (`int`):
            The dimensionality of the value embeddings.

    Returns:
        `torch.Tensor`:
            The downsampled mask tensor.
    """
    o_h = mask.shape[2]
    o_w = mask.shape[3]
    ratio = o_w / o_h
    mask_h = int(math.sqrt(num_queries / ratio))
    mask_h = mask_h + int((num_queries % mask_h) != 0)
    mask_w = num_queries // mask_h
    mask_downsample = F.interpolate(mask, size=(mask_h, mask_w), mode="bicubic")

    # Repeat batch_size times
    if mask_downsample.shape[0] < batch_size:
        mask_downsample = mask_downsample.repeat(batch_size, 1, 1, 1)

    mask_downsample = mask_downsample.view(mask_downsample.shape[0], -1)
    downsampled_area = mask_h * mask_w

    # If the output image and the mask do not have the same aspect ratio, tensor shapes will not match
    # Pad tensor if downsampled_mask.shape[1] is smaller than num_queries
    if downsampled_area < num_queries:
        warnings.warn(
            "The aspect ratio of the mask does not match the aspect ratio of the output image. "
            "Please update your masks or adjust the output size for optimal performance.",
            UserWarning,
        )
        mask_downsample = F.pad(
            mask_downsample, (0, num_queries - mask_downsample.shape[1]), value=0.0
        )

    # Discard last embeddings if downsampled_mask.shape[1] is bigger than num_queries
    if downsampled_area > num_queries:
        warnings.warn(
            "The aspect ratio of the mask does not match the aspect ratio of the output image. "
            "Please update your masks or adjust the output size for optimal performance.",
            UserWarning,
        )
        mask_downsample = mask_downsample[:, :num_queries]

    # Repeat last dimension to match SDPA output shape
    mask_downsample = mask_downsample.view(
        mask_downsample.shape[0], mask_downsample.shape[1], 1
    ).repeat(1, 1, value_embed_dim)

    return mask_downsample
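

# Illustrative shape check for `downsample` (not part of the original module): a
# single 512x512 mask, a batch of two, a 32x32 latent (1024 query positions), and a
# hypothetical value dimension of 1280.
def _example_downsample_shapes():
    mask = torch.rand(1, 1, 512, 512)
    out = downsample(mask, batch_size=2, num_queries=1024, value_embed_dim=1280)
    assert out.shape == (2, 1024, 1280)
    return out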


class IPAttnProcessor2_0(torch.nn.Module):
    r"""
    Attention processor for IP-Adapter using PyTorch 2.0's scaled dot-product attention.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            The weight scale of the image prompt.
        num_tokens (`int`, defaults to 4; should be 16 for IP-Adapter Plus):
            The context length of the image features.
    """

    def __init__(
        self,
        hidden_size,
        cross_attention_dim=None,
        scale=1.0,
        num_tokens=4,
        use_align_sem_and_layout_loss=False,
    ):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )
        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens
        self.to_k_ip = nn.Linear(
            cross_attention_dim or hidden_size, hidden_size, bias=False
        )
        self.to_v_ip = nn.Linear(
            cross_attention_dim or hidden_size, hidden_size, bias=False
        )
        # Projected image tokens are cached here (see `set_ip_hidden_states`)
        # instead of being passed through `encoder_hidden_states`.
        self.ip_hidden_states = None
        self.use_align_sem_and_layout_loss = use_align_sem_and_layout_loss
        if self.use_align_sem_and_layout_loss:
            self.align_sem_loss = None
            self.align_layout_loss = None
            self.cache_query = None
            self.cache_attn_weights = None

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        ip_adapter_masks: Optional[torch.FloatTensor] = None,
        *args,
        **kwargs,
    ):
        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)
        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)
        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )
        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(
                batch_size, attn.heads, -1, attention_mask.shape[-1]
            )
        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )
        query = attn.to_q(hidden_states)
        if attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states
            )
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        if self.use_align_sem_and_layout_loss:
            # Cache the query and text-attention weights from the first pass so the
            # semantic/layout alignment losses can compare later passes against them.
            if self.cache_query is None:
                self.cache_query = query.clone().detach()
                self.cache_attn_weights = (key @ query.transpose(-2, -1)) / math.sqrt(
                    query.size(-1)
                )
                self.cache_attn_weights = torch.softmax(self.cache_attn_weights, dim=-1)
            else:
                self.attn_weights = (key @ query.transpose(-2, -1)) / math.sqrt(
                    query.size(-1)
                )
                self.query = query
                self.attn_weights = torch.softmax(self.attn_weights, dim=-1)
        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim
        )
        hidden_states = hidden_states.to(query.dtype)
        if self.scale != 0.0:
            # for ip-adapter
            ip_key = self.to_k_ip(self.ip_hidden_states).to(dtype=query.dtype)
            ip_value = self.to_v_ip(self.ip_hidden_states).to(dtype=query.dtype)
            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(
                1, 2
            )
            # the output of sdp = (batch, num_heads, seq_len, head_dim)
            # TODO: add support for attn.scale when we move to Torch 2.1
            ip_hidden_states = F.scaled_dot_product_attention(
                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
            )
            # with torch.no_grad():
            #     self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1)
            #     print(self.attn_map.shape)
            ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(
                batch_size, -1, attn.heads * head_dim
            )
            ip_hidden_states = ip_hidden_states.to(query.dtype)
            if ip_adapter_masks is not None:
                mask_downsample = downsample(
                    ip_adapter_masks,
                    batch_size,
                    ip_hidden_states.shape[1],
                    ip_hidden_states.shape[2],
                )
                mask_downsample = mask_downsample.to(
                    dtype=query.dtype, device=query.device
                )
                ip_hidden_states = ip_hidden_states * mask_downsample
            hidden_states = hidden_states + self.scale * ip_hidden_states
        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )
        if attn.residual_connection:
            hidden_states = hidden_states + residual
        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states


def set_ortho(unet, ortho):
    for name, module in unet.attn_processors.items():
        if isinstance(module, (IPAttnProcessor2_0, MultiIPAttnProcessor2_0)):
            module.ortho = ortho


def set_num_zero(unet, num_zero):
    for name, module in unet.attn_processors.items():
        if isinstance(module, (IPAttnProcessor2_0, MultiIPAttnProcessor2_0)):
            module.num_zero = num_zero


class MultiIPAttnProcessor2_0(torch.nn.Module):
    r"""
    Attention processor for multiple IP-Adapters using PyTorch 2.0's scaled dot-product attention.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
            The context length of the image features, one entry per adapter.
        scale (`float` or `List[float]`, defaults to 1.0):
            The weight scale of the image prompt, one entry per adapter.
    """

    def __init__(
        self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0
    ):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )
        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        if not isinstance(num_tokens, (tuple, list)):
            num_tokens = [num_tokens]
        self.num_tokens = num_tokens
        if not isinstance(scale, list):
            scale = [scale] * len(num_tokens)
        if len(scale) != len(num_tokens):
            raise ValueError(
                "`scale` should be a list of floats with the same length as `num_tokens`."
            )
        self.scale = scale
        self.to_k_ip = nn.ModuleList(
            [
                nn.Linear(cross_attention_dim, hidden_size, bias=False)
                for _ in range(len(num_tokens))
            ]
        )
        self.to_v_ip = nn.ModuleList(
            [
                nn.Linear(cross_attention_dim, hidden_size, bias=False)
                for _ in range(len(num_tokens))
            ]
        )
        # Per-adapter state: projected image tokens, zero-token padding counts, and
        # the orthogonal-merging mode (None, "ortho" or "ortho_v2").
        self.ip_hidden_states = None
        self.num_zero = [None] * len(num_tokens)
        self.ortho = [None] * len(num_tokens)

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        scale: float = 1.0,
        ip_adapter_masks: Optional[torch.FloatTensor] = None,
    ):
        residual = hidden_states
        ip_hidden_states = self.ip_hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)
        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)
        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )
        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(
                batch_size, attn.heads, -1, attention_mask.shape[-1]
            )
        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )
        query = attn.to_q(hidden_states)
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states
            )
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim
        )
        hidden_states = hidden_states.to(query.dtype)
        if ip_adapter_masks is not None:
            if (
                not isinstance(ip_adapter_masks, torch.Tensor)
                or ip_adapter_masks.ndim != 4
            ):
                raise ValueError(
                    "ip_adapter_masks should be a tensor with shape [num_ip_adapter, 1, height, width]."
                    " Please use `IPAdapterMaskProcessor` to preprocess your mask"
                )
            if len(ip_adapter_masks) != len(self.scale):
                raise ValueError(
                    f"Number of ip_adapter_masks ({len(ip_adapter_masks)}) must match number of IP-Adapters ({len(self.scale)})"
                )
        else:
            ip_adapter_masks = [None] * len(self.scale)
        # for ip-adapter
        for (
            current_ip_hidden_states,
            scale,
            to_k_ip,
            to_v_ip,
            mask,
            num_zero,
            ortho,
        ) in zip(
            ip_hidden_states,
            self.scale,
            self.to_k_ip,
            self.to_v_ip,
            ip_adapter_masks,
            self.num_zero,
            self.ortho,
        ):
            if scale == 0:
                continue
            if num_zero is not None:
                # Append `num_zero` all-zero tokens so the adapter can attend to
                # "nothing" and soften its influence.
                zero_tensor = torch.zeros(
                    (
                        current_ip_hidden_states.size(0),
                        num_zero,
                        current_ip_hidden_states.size(-1),
                    ),
                    dtype=current_ip_hidden_states.dtype,
                    device=current_ip_hidden_states.device,
                )
                current_ip_hidden_states = torch.concat(
                    [current_ip_hidden_states, zero_tensor], dim=1
                )
            ip_key = to_k_ip(current_ip_hidden_states)
            ip_value = to_v_ip(current_ip_hidden_states)
            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(
                1, 2
            )
            # the output of sdp = (batch, num_heads, seq_len, head_dim)
            # TODO: add support for attn.scale when we move to Torch 2.1
            current_ip_hidden_states = F.scaled_dot_product_attention(
                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
            )
            current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
                batch_size, -1, attn.heads * head_dim
            )
            current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
            if mask is not None:
                mask_downsample = IPAdapterMaskProcessor.downsample(
                    mask,
                    batch_size,
                    current_ip_hidden_states.shape[1],
                    current_ip_hidden_states.shape[2],
                )
                mask_downsample = mask_downsample.to(
                    dtype=query.dtype, device=query.device
                )
                current_ip_hidden_states = current_ip_hidden_states * mask_downsample
            if ortho is None:
                hidden_states = hidden_states + scale * current_ip_hidden_states
            elif ortho == "ortho":
                # PuLID-style merge: add only the component of the image features that
                # is orthogonal to the text-attention output, weighted by `scale`.
                orig_dtype = hidden_states.dtype
                hidden_states = hidden_states.to(torch.float32)
                current_ip_hidden_states = current_ip_hidden_states.to(torch.float32)
                projection = (
                    torch.sum(
                        (hidden_states * current_ip_hidden_states), dim=-2, keepdim=True
                    )
                    / torch.sum((hidden_states * hidden_states), dim=-2, keepdim=True)
                    * hidden_states
                )
                orthogonal = current_ip_hidden_states - projection
                hidden_states = hidden_states + scale * orthogonal
                hidden_states = hidden_states.to(orig_dtype)
            elif ortho == "ortho_v2":
                orig_dtype = hidden_states.dtype
                hidden_states = hidden_states.to(torch.float32)
                current_ip_hidden_states = current_ip_hidden_states.to(torch.float32)
                attn_map = query @ ip_key.transpose(-2, -1)
                attn_mean = attn_map.softmax(dim=-1).mean(dim=1)
                attn_mean = attn_mean[:, :, :5].sum(dim=-1, keepdim=True)
                projection = (
                    torch.sum(
                        (hidden_states * current_ip_hidden_states), dim=-2, keepdim=True
                    )
                    / torch.sum((hidden_states * hidden_states), dim=-2, keepdim=True)
                    * hidden_states
                )
                orthogonal = current_ip_hidden_states + (attn_mean - 1) * projection
                hidden_states = hidden_states + scale * orthogonal
                hidden_states = hidden_states.to(orig_dtype)
            else:
                raise ValueError(f"{ortho} not supported")
        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )
        if attn.residual_connection:
            hidden_states = hidden_states + residual
        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states
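

# Illustrative sketch (not part of the original module): restricting each image prompt
# to a region of the output. Assumes `load_multi_ip_adapter` has been called,
# `ip_tokens` is the list of projected image-token tensors (one per adapter), `masks`
# come from `IPAdapterMaskProcessor.preprocess` with shape
# [num_adapters, 1, height, width], and `added_cond_kwargs` holds SDXL's pooled text
# embeddings and time ids. diffusers forwards `cross_attention_kwargs` to the
# attention processors, which receive the masks as `ip_adapter_masks`.
def _example_masked_image_prompts(
    unet, latents, timestep, text_embeds, added_cond_kwargs, ip_tokens, masks
):
    set_multi_ip_hidden_states(unet, ip_tokens)
    noise_pred = unet(
        latents,
        timestep,
        encoder_hidden_states=text_embeds,
        added_cond_kwargs=added_cond_kwargs,
        cross_attention_kwargs={"ip_adapter_masks": masks},
    ).sample
    return noise_pred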