import copy
from typing import Callable, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from diffusers.models.attention_processor import Attention, IPAdapterAttnProcessor2_0
from diffusers.models.embeddings import (
    ImageProjection,
    IPAdapterPlusImageProjection,
    MultiIPAdapterImageProjection,
)
from diffusers.models.normalization import RMSNorm
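
# `apply_rope` applies rotary position embeddings in their real-valued 2x2 form, as in the
# reference Flux implementation: each adjacent pair of head-dim channels in q/k is rotated
# by the per-position cos/sin factors packed in `freqs_cis`.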
def apply_rope(xq, xk, freqs_cis):
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
class IPAdapterFluxSingleAttnProcessor2_0(nn.Module):
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
"""
def __init__(
self, cross_attention_dim, hidden_size, scale=1.0, num_text_tokens=512
):
super().__init__()
self.scale = scale
self.to_k_ip = nn.Linear(cross_attention_dim, hidden_size, bias=True)
self.to_v_ip = nn.Linear(cross_attention_dim, hidden_size, bias=True)
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
        self.ip_hidden_states = None
        self.num_text_tokens = num_text_tokens
        # Zero-init the IP projections so a freshly attached adapter is a no-op
        # until trained weights are loaded.
        nn.init.zeros_(self.to_k_ip.weight)
        nn.init.zeros_(self.to_k_ip.bias)
        nn.init.zeros_(self.to_v_ip.weight)
        nn.init.zeros_(self.to_v_ip.bias)
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(
batch_size, channel, height * width
).transpose(1, 2)
batch_size, _, _ = (
hidden_states.shape
if encoder_hidden_states is None
else encoder_hidden_states.shape
)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
ip_query = query[:, :, self.num_text_tokens :].clone()
# Apply RoPE if needed
if image_rotary_emb is not None:
query, key = apply_rope(query, key, image_rotary_emb)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim
)
hidden_states = hidden_states.to(query.dtype)
        # IP-Adapter branch: image-token queries attend to the projected image-prompt features.
ip_key = self.to_k_ip(self.ip_hidden_states)
ip_value = self.to_v_ip(self.ip_hidden_states)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_hidden_states = F.scaled_dot_product_attention(
ip_query,
ip_key,
ip_value,
dropout_p=0.0,
is_causal=False,
)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim
)
ip_hidden_states = ip_hidden_states.to(query.dtype)
hidden_states[:, self.num_text_tokens :] += self.scale * ip_hidden_states
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(
batch_size, channel, height, width
)
return hidden_states
class IPAdapterFluxAttnProcessor2_0(nn.Module):
"""Attention processor used typically in processing the SD3-like self-attention projections."""
def __init__(self, cross_attention_dim, hidden_size, scale=1.0):
super().__init__()
self.scale = scale
self.to_k_ip = nn.Linear(cross_attention_dim, hidden_size, bias=True)
self.to_v_ip = nn.Linear(cross_attention_dim, hidden_size, bias=True)
self.ip_hidden_states = None
nn.init.zeros_(self.to_k_ip.weight)
nn.init.zeros_(self.to_k_ip.bias)
nn.init.zeros_(self.to_v_ip.weight)
nn.init.zeros_(self.to_v_ip.bias)
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(
batch_size, channel, height * width
).transpose(1, 2)
context_input_ndim = encoder_hidden_states.ndim
if context_input_ndim == 4:
batch_size, channel, height, width = encoder_hidden_states.shape
encoder_hidden_states = encoder_hidden_states.view(
batch_size, channel, height * width
).transpose(1, 2)
batch_size = encoder_hidden_states.shape[0]
# `sample` projections.
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(
encoder_hidden_states_query_proj
)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(
encoder_hidden_states_key_proj
)
ip_query = query.clone()
# attention
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
if image_rotary_emb is not None:
query, key = apply_rope(query, key, image_rotary_emb)
hidden_states = F.scaled_dot_product_attention(
query, key, value, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim
)
hidden_states = hidden_states.to(query.dtype)
encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)
        # IP-Adapter branch: image-stream queries attend to the projected image-prompt features.
ip_key = self.to_k_ip(self.ip_hidden_states)
ip_value = self.to_v_ip(self.ip_hidden_states)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_hidden_states = F.scaled_dot_product_attention(
ip_query,
ip_key,
ip_value,
dropout_p=0.0,
is_causal=False,
)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim
)
ip_hidden_states = ip_hidden_states.to(query.dtype)
hidden_states = hidden_states + self.scale * ip_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(
batch_size, channel, height, width
)
if context_input_ndim == 4:
encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(
batch_size, channel, height, width
)
return hidden_states, encoder_hidden_states
def save_ip_adapter(dit, path):
    # Save a flat state dict so `load_ip_adapter` can restore it via
    # `dit.load_state_dict(..., strict=False)`.
    state_dict = {}
    for key, value in dit.encoder_hid_proj.state_dict().items():
        state_dict[f"encoder_hid_proj.{key}"] = value
    for name, module in dit.named_modules():
        if isinstance(
            module,
            (IPAdapterFluxAttnProcessor2_0, IPAdapterFluxSingleAttnProcessor2_0),
        ):
            for key, value in module.state_dict().items():
                state_dict[f"{name}.{key}"] = value
    torch.save(state_dict, path)
def load_ip_adapter(
dit,
path=None,
clip_embeddings_dim=1024,
cross_attention_dim=3072,
num_image_text_embeds=8,
attn_blocks=["single", "double"],
):
if path is not None:
state_dict = torch.load(path, map_location="cpu")
clip_embeddings_dim = state_dict["encoder_hid_proj.image_embeds.weight"].shape[
1
]
num_image_text_embeds = (
state_dict["encoder_hid_proj.image_embeds.weight"].shape[0]
// cross_attention_dim
)
dit.encoder_hid_proj = ImageProjection(
cross_attention_dim=cross_attention_dim,
image_embed_dim=clip_embeddings_dim,
num_image_text_embeds=num_image_text_embeds,
).to(dit.device, dit.dtype)
for name, module in dit.named_modules():
if isinstance(module, Attention):
if "single" in name:
if "single" in attn_blocks:
module.set_processor(
IPAdapterFluxSingleAttnProcessor2_0(
hidden_size=module.query_dim,
cross_attention_dim=cross_attention_dim,
).to(dit.device, dit.dtype)
)
elif "double" in attn_blocks:
module.set_processor(
IPAdapterFluxAttnProcessor2_0(
hidden_size=module.query_dim,
cross_attention_dim=cross_attention_dim,
).to(dit.device, dit.dtype)
)
if path is not None:
dit.load_state_dict(state_dict, strict=False)
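
# The processors read the image-prompt features from their `ip_hidden_states` attribute
# rather than from the forward signature, so the helpers below push the projected image
# embeddings onto every installed processor before denoising and clear them afterwards.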
def set_ip_hidden_states(dit, image_embeds):
    for name, module in dit.named_modules():
        if isinstance(
            module,
            (IPAdapterFluxSingleAttnProcessor2_0, IPAdapterFluxAttnProcessor2_0),
        ):
            module.ip_hidden_states = image_embeds.clone()


def clear_ip_hidden_states(dit):
    for name, module in dit.named_modules():
        if isinstance(
            module,
            (IPAdapterFluxSingleAttnProcessor2_0, IPAdapterFluxAttnProcessor2_0),
        ):
            module.ip_hidden_states = None
def set_ip_adapter_scale(dit, scale=1.0):
for name, module in dit.named_modules():
if isinstance(module, IPAdapterFluxSingleAttnProcessor2_0) or isinstance(
module, IPAdapterFluxAttnProcessor2_0
):
module.scale = scale
def load_ip_adapter_plus(
dit,
path=None,
embed_dims=1280,
output_dims=2048,
hidden_dims=1280,
depth=4,
dim_head=64,
heads=20,
num_queries=16,
ffn_ratio=4,
cross_attention_dim=2048,
):
if path is not None:
state_dict = torch.load(path)
else:
state_dict = None
if not hasattr(dit, "encoder_hid_proj") or dit.encoder_hid_proj is None:
dit.encoder_hid_proj = MultiIPAdapterImageProjection(
[
IPAdapterPlusImageProjection(
embed_dims=embed_dims,
output_dims=output_dims,
hidden_dims=hidden_dims,
depth=depth,
dim_head=dim_head,
heads=heads,
num_queries=num_queries,
ffn_ratio=ffn_ratio,
)
]
).to(dit.device, dit.dtype)
if state_dict is not None:
dit.encoder_hid_proj.load_state_dict(state_dict["encoder_hid_proj"])
dit.config.encoder_hid_dim_type = "ip_image_proj"
for name, module in dit.named_modules():
if "attn2" in name and isinstance(module, Attention):
if not isinstance(module.processor, IPAdapterAttnProcessor2_0):
module.set_processor(
IPAdapterAttnProcessor2_0(
hidden_size=module.query_dim,
cross_attention_dim=cross_attention_dim,
).to(dit.device, dit.dtype)
)
                if state_dict is not None:
                    module.processor.load_state_dict(state_dict[f"{name}.processor"])
                else:
                    # diffusers' IPAdapterAttnProcessor2_0 keeps `to_k_ip`/`to_v_ip` as
                    # ModuleLists (one entry per adapter), so initialize the first entry
                    # from the attention layer's own key/value projections.
                    module.processor.to_k_ip[0].load_state_dict(module.to_k.state_dict())
                    module.processor.to_v_ip[0].load_state_dict(module.to_v.state_dict())
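

# A minimal usage sketch (the names below are illustrative assumptions: a loaded
# `FluxTransformer2DModel` as `dit`, CLIP image embeddings `clip_image_embeds` of shape
# (batch, clip_embeddings_dim), and an optional checkpoint written by `save_ip_adapter`):
#
#     load_ip_adapter(dit, path="flux_ip_adapter.bin", cross_attention_dim=3072)
#     set_ip_adapter_scale(dit, 0.8)
#     ip_tokens = dit.encoder_hid_proj(clip_image_embeds)
#     set_ip_hidden_states(dit, ip_tokens)
#     ...  # run the usual Flux denoising loop with `dit`
#     clear_ip_hidden_states(dit)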