import inspect
import math
from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
from diffusers.image_processor import IPAdapterMaskProcessor
from diffusers.utils import deprecate, logging
from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
from diffusers.utils.torch_utils import is_torch_version, maybe_allow_in_graph
from diffusers.models.attention import Attention
from diffusers.models.embeddings import Timesteps, TimestepEmbedding, PixArtAlphaTextProjection


class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
    """Combines timestep, guidance-scale, and pooled text-projection embeddings into one conditioning vector."""

    def __init__(self, embedding_dim, pooled_projection_dim):
        super().__init__()

        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")

    def forward(self, timestep, guidance, pooled_projection):
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))  # (N, D)

        if (guidance >= 0).all():
            guidance_proj = self.time_proj(guidance)
            guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype))  # (N, D)
            time_guidance_emb = timesteps_emb + guidance_emb

            pooled_projections = self.text_embedder(pooled_projection)
            conditioning = time_guidance_emb + pooled_projections
        else:
            # Negative guidance values are treated as "no guidance embedding".
            pooled_projections = self.text_embedder(pooled_projection)
            conditioning = timesteps_emb + pooled_projections

        return conditioning
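

# Illustrative usage sketch (not part of the original module): the dimensions below are assumptions
# chosen only to show the expected input/output shapes of the combined embedder.
def _example_combined_embedding():
    embedder = CombinedTimestepGuidanceTextProjEmbeddings(embedding_dim=128, pooled_projection_dim=64)
    timestep = torch.tensor([500.0, 250.0])
    guidance = torch.tensor([3.5, 3.5])  # all >= 0, so the guidance embedding is added
    pooled = torch.randn(2, 64)
    conditioning = embedder(timestep, guidance, pooled)
    return conditioning.shape  # torch.Size([2, 128])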


def apply_rotary_emb(
    x: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    use_real: bool = True,
    use_real_unbind_dim: int = -1,
) -> torch.Tensor:
    """
    Apply rotary positional embeddings to the input tensor using the given frequency tensor. The query or key tensor
    `x` is split into real/imaginary halves (or viewed as complex numbers), rotated by the precomputed frequencies in
    `freqs_cis` (which are broadcast for compatibility), and returned as a real tensor with the shape and dtype of `x`.

    Args:
        x (`torch.Tensor`):
            Query or key tensor to apply rotary embeddings to, of shape `[B, H, S, D]`.
        freqs_cis (`Tuple[torch.Tensor]` or `torch.Tensor`):
            Precomputed frequency tensors: a `(cos, sin)` tuple of shape `[S, D]` each when `use_real=True`, otherwise
            a complex frequency tensor.

    Returns:
        `torch.Tensor`: The input tensor with rotary embeddings applied.
    """
    if use_real:
        cos, sin = freqs_cis  # [S, D]
        if cos.ndim == 2:
            cos = cos[None, None]
        else:
            cos = cos.unsqueeze(1)
        if sin.ndim == 2:
            sin = sin[None, None]
        else:
            sin = sin.unsqueeze(1)
        cos, sin = cos.to(x.device), sin.to(x.device)

        if use_real_unbind_dim == -1:
            # Used for flux, cogvideox, hunyuan-dit
            x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, H, S, D//2]
            x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        elif use_real_unbind_dim == -2:
            # Used for Stable Audio
            x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, H, S, D//2]
            x_rotated = torch.cat([-x_imag, x_real], dim=-1)
        else:
            raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")

        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

        return out
    else:
        # Used for lumina
        x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        freqs_cis = freqs_cis.unsqueeze(2)
        x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)

        return x_out.type_as(x)
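

# Illustrative usage sketch (not part of the original module): builds hypothetical (cos, sin) frequency
# tables for a [B, H, S, D] query and applies them with `apply_rotary_emb`. The frequency construction is
# a plain RoPE-style assumption, not taken from any particular pipeline in this repo.
def _example_apply_rotary_emb():
    B, H, S, D = 2, 4, 16, 32
    query = torch.randn(B, H, S, D)
    positions = torch.arange(S, dtype=torch.float32)[:, None]                     # [S, 1]
    inv_freq = 1.0 / (10000 ** (torch.arange(0, D, 2, dtype=torch.float32) / D))  # [D//2]
    angles = positions * inv_freq                                                 # [S, D//2]
    cos = angles.cos().repeat_interleave(2, dim=-1)                               # [S, D]
    sin = angles.sin().repeat_interleave(2, dim=-1)                               # [S, D]
    rotated = apply_rotary_emb(query, (cos, sin))                                 # same shape/dtype as `query`
    return rotated.shape  # torch.Size([2, 4, 16, 32])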


class FluxAttnSharedProcessor2_0:
    """Attention processor based on PyTorch 2.0's `scaled_dot_product_attention`, used for Flux-style joint
    text/image attention. It can additionally mix condition latents and/or group members into a shared
    attention pass."""

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "FluxAttnSharedProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        data_num_per_group: Optional[int] = 1,
        max_sequence_length: Optional[int] = 512,
        mix_attention: bool = True,
        cond_latents=None,
        cond_image_rotary_emb=None,
        work_mode=None,
        mask_cond=None,
    ) -> torch.FloatTensor:
        with_cond = cond_latents is not None and mix_attention

        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
        if encoder_hidden_states is not None:
            # `context` projections.
            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)

            if attn.norm_added_q is not None:
                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
            if attn.norm_added_k is not None:
                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

            # attention: text tokens are prepended to the image tokens along the sequence dimension
            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        if with_cond:
            cond_bs = cond_latents.shape[0]

            # update condition: project the condition latents and fold their batch dimension into the
            # sequence dimension so they can participate in the shared attention pass
            cond_query = attn.to_q(cond_latents)
            cond_query = cond_query.view(cond_bs, -1, attn.heads, head_dim).transpose(1, 2)
            if attn.norm_q is not None:
                cond_query = attn.norm_q(cond_query)
            cond_query = apply_rotary_emb(cond_query, cond_image_rotary_emb)
            cond_query = torch.cat(cond_query.chunk(len(cond_query), dim=0), dim=2)

            cond_key = attn.to_k(cond_latents)
            cond_value = attn.to_v(cond_latents)
            cond_key = cond_key.view(cond_bs, -1, attn.heads, head_dim).transpose(1, 2)
            cond_value = cond_value.view(cond_bs, -1, attn.heads, head_dim).transpose(1, 2)
            if attn.norm_k is not None:
                cond_key = attn.norm_k(cond_key)
            cond_key = apply_rotary_emb(cond_key, cond_image_rotary_emb)
            cond_key = torch.cat(cond_key.chunk(len(cond_key), dim=0), dim=2)
            cond_value = torch.cat(cond_value.chunk(len(cond_value), dim=0), dim=2)

        if data_num_per_group > 1 and mix_attention:
            E = max_sequence_length  # according to text len
            key_enc, key_hid = key[:, :, :E], key[:, :, E:]
            value_enc, value_hid = value[:, :, :E], value[:, :, E:]

            # share the image keys/values of every group member with every other member
            key_layer = key_hid.chunk(data_num_per_group, dim=0)
            key_layer = torch.cat(key_layer, dim=2).repeat(data_num_per_group, 1, 1, 1)

            value_layer = value_hid.chunk(data_num_per_group, dim=0)
            value_layer = torch.cat(value_layer, dim=2).repeat(data_num_per_group, 1, 1, 1)

            key = torch.cat([key_enc, key_layer], dim=2)
            value = torch.cat([value_enc, value_layer], dim=2)
        elif data_num_per_group == 1 and mix_attention and with_cond:
            E = max_sequence_length  # according to text len
            key_enc, key_hid = key[:, :, :E], key[:, :, E:]
            value_enc, value_hid = value[:, :, :E], value[:, :, E:]

            # todo: support bs != 1
            key_layer = torch.cat([key_hid, cond_key], dim=2)
            value_layer = torch.cat([value_hid, cond_value], dim=2)

            key = torch.cat([key_enc, key_layer], dim=2)
            value = torch.cat([value_enc, value_layer], dim=2)

            # concat query
            query_enc, query_hid = query[:, :, :E], query[:, :, E:]
            query_layer = torch.cat([query_hid, cond_query], dim=2)
            query = torch.cat([query_enc, query_layer], dim=2)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if encoder_hidden_states is not None:
            if with_cond:
                # split the joint output back into text, image, and condition segments
                encoder_hidden_states, hidden_states, cond_latents = (
                    hidden_states[:, : encoder_hidden_states.shape[1]],
                    hidden_states[:, encoder_hidden_states.shape[1] : -cond_latents.shape[1] * cond_bs],
                    hidden_states[:, -cond_latents.shape[1] * cond_bs :],
                )
                cond_latents = cond_latents.view(cond_bs, cond_latents.shape[1] // cond_bs, cond_latents.shape[2])
                cond_latents = attn.to_out[0](cond_latents)
                cond_latents = attn.to_out[1](cond_latents)
            else:
                encoder_hidden_states, hidden_states = (
                    hidden_states[:, : encoder_hidden_states.shape[1]],
                    hidden_states[:, encoder_hidden_states.shape[1] :],
                )

            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)

            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

            if with_cond:
                return hidden_states, encoder_hidden_states, cond_latents
            return hidden_states, encoder_hidden_states
        else:
            if with_cond:
                hidden_states, cond_latents = (
                    hidden_states[:, : -cond_latents.shape[1] * cond_bs],
                    hidden_states[:, -cond_latents.shape[1] * cond_bs :],
                )
                cond_latents = cond_latents.view(cond_bs, cond_latents.shape[1] // cond_bs, cond_latents.shape[2])
                return hidden_states, cond_latents
            return hidden_states
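

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module. It assumes a recent diffusers version whose
    # `Attention` accepts the kwargs used by the Flux transformer blocks (added_kv_proj_dim, context_pre_only,
    # qk_norm, ...); the dimensions and dummy tensors below are illustrative only.
    heads, head_dim = 4, 32
    dim = heads * head_dim
    attn = Attention(
        query_dim=dim,
        cross_attention_dim=None,
        added_kv_proj_dim=dim,
        dim_head=head_dim,
        heads=heads,
        out_dim=dim,
        context_pre_only=False,
        bias=True,
        processor=FluxAttnSharedProcessor2_0(),
        qk_norm="rms_norm",
        eps=1e-6,
    )

    text_len, image_len = 8, 16
    hidden_states = torch.randn(1, image_len, dim)
    encoder_hidden_states = torch.randn(1, text_len, dim)
    # Identity (cos, sin) rotary tables covering the concatenated text + image sequence.
    image_rotary_emb = (
        torch.ones(text_len + image_len, head_dim),
        torch.zeros(text_len + image_len, head_dim),
    )

    out_hidden, out_encoder = attn(
        hidden_states=hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        image_rotary_emb=image_rotary_emb,
        max_sequence_length=text_len,
    )
    print(out_hidden.shape, out_encoder.shape)  # torch.Size([1, 16, 128]) torch.Size([1, 8, 128])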