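"""
Transformer blocks for SD3-style MMDiT models, extended with a zero-initialized "control" branch:
control tokens are concatenated with the sample tokens along the sequence dimension and processed
through zero-initialized projections (`AttentionZero`), a depthwise-convolution feed-forward branch
(`FeedForwardControl`), and a joint attention processor (`JointAttnProcessor2_0`) that attends over
the sample, control, and text tokens together.
"""
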
from typing import Any, Dict, List, Optional, Tuple |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from torch import nn |
|
import numpy as np |
|
from diffusers.utils import deprecate, logging |
|
from diffusers.utils.torch_utils import maybe_allow_in_graph |
|
from diffusers.models.activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU |
|
from diffusers.models.attention_processor import Attention |
|
from diffusers.models.embeddings import SinusoidalPositionalEmbedding |
|
from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX |
|
|
|
def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
    # "feed_forward_chunk_size" can be used to save memory: the feed-forward is applied to smaller
    # chunks of `hidden_states` and the results are concatenated afterwards.
    if hidden_states.shape[chunk_dim] % chunk_size != 0:
|
raise ValueError( |
|
f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." |
|
) |
|
|
|
num_chunks = hidden_states.shape[chunk_dim] // chunk_size |
|
ff_output = torch.cat( |
|
[ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)], |
|
dim=chunk_dim, |
|
) |
|
return ff_output |
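
# Note: chunking is numerically equivalent to a single forward pass (up to floating-point noise).
# Illustrative check, using the `FeedForward` module defined below with placeholder sizes:
#
#     ff = FeedForward(dim=64).eval()  # eval() so dropout does not affect the comparison
#     x = torch.randn(2, 8, 64)
#     assert torch.allclose(_chunked_feed_forward(ff, x, chunk_dim=1, chunk_size=4), ff(x), atol=1e-6)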
|
|
|
class FeedForward(nn.Module): |
|
r""" |
|
A feed-forward layer. |
|
|
|
Parameters: |
|
dim (`int`): The number of channels in the input. |
|
dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. |
|
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. |
|
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. |
|
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. |
|
final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. |
|
bias (`bool`, defaults to True): Whether to use a bias in the linear layer. |
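
    Example (illustrative; `dim` and the input shape below are placeholders):

        >>> import torch
        >>> ff = FeedForward(dim=64, activation_fn="geglu")
        >>> out = ff(torch.randn(2, 16, 64))  # output keeps the input shape: (2, 16, 64)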
|
""" |
|
|
|
def __init__( |
|
self, |
|
dim: int, |
|
dim_out: Optional[int] = None, |
|
mult: int = 4, |
|
dropout: float = 0.0, |
|
activation_fn: str = "geglu", |
|
final_dropout: bool = False, |
|
        inner_dim: Optional[int] = None,
|
bias: bool = True, |
|
): |
|
super().__init__() |
|
if inner_dim is None: |
|
inner_dim = int(dim * mult) |
|
dim_out = dim_out if dim_out is not None else dim |
|
|
|
        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim, bias=bias)
        elif activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim, bias=bias)
        elif activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
        elif activation_fn == "swiglu":
            act_fn = SwiGLU(dim, inner_dim, bias=bias)
        else:
            raise ValueError(f"Unsupported activation function: {activation_fn}")
|
|
|
self.net = nn.ModuleList([]) |
|
|
|
self.net.append(act_fn) |
|
|
|
self.net.append(nn.Dropout(dropout)) |
|
|
|
self.net.append(nn.Linear(inner_dim, dim_out, bias=bias)) |
|
|
|
if final_dropout: |
|
self.net.append(nn.Dropout(dropout)) |
|
|
|
def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: |
|
if len(args) > 0 or kwargs.get("scale", None) is not None: |
|
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." |
|
deprecate("scale", "1.0.0", deprecation_message) |
|
for module in self.net: |
|
hidden_states = module(hidden_states) |
|
return hidden_states |
|
|
|
class FeedForwardControl(nn.Module): |
|
r""" |
|
    A feed-forward layer with an additional zero-initialized depthwise-convolution branch for control
    tokens. The input is expected to be sample tokens concatenated with control tokens along the
    sequence dimension, and the number of control tokens must form a square spatial grid.
|
|
|
Parameters: |
|
dim (`int`): The number of channels in the input. |
|
dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. |
|
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. |
|
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. |
|
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. |
|
final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. |
|
bias (`bool`, defaults to True): Whether to use a bias in the linear layer. |
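
    Example (illustrative; the token counts and `dim` below are placeholders, and the control tokens
    must form a square grid):

        >>> import torch
        >>> ff = FeedForwardControl(dim=64, dim_out=64, activation_fn="gelu-approximate")
        >>> sample = torch.randn(2, 256, 64)   # 16 x 16 sample tokens
        >>> control = torch.randn(2, 256, 64)  # 16 x 16 control tokens
        >>> out = ff(torch.cat([sample, control], dim=1))  # shape: (2, 512, 64)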
|
""" |
|
|
|
def __init__( |
|
self, |
|
dim: int, |
|
dim_out: Optional[int] = None, |
|
mult: int = 4, |
|
dropout: float = 0.0, |
|
activation_fn: str = "geglu", |
|
final_dropout: bool = False, |
|
        inner_dim: Optional[int] = None,
|
bias: bool = True, |
|
): |
|
super().__init__() |
|
if inner_dim is None: |
|
inner_dim = int(dim * mult) |
|
dim_out = dim_out if dim_out is not None else dim |
|
|
|
        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim, bias=bias)
        elif activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim, bias=bias)
        elif activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
        elif activation_fn == "swiglu":
            act_fn = SwiGLU(dim, inner_dim, bias=bias)
        else:
            raise ValueError(f"Unsupported activation function: {activation_fn}")
|
|
|
self.net = nn.ModuleList([]) |
|
|
|
self.net.append(act_fn) |
|
|
|
self.net.append(nn.Dropout(dropout)) |
|
|
|
        self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))

        # Zero-initialized depthwise conv applied to the control tokens (reshaped to a square feature
        # map), so the control branch contributes nothing at initialization.
        self.control_conv = zero_module(nn.Conv2d(inner_dim, inner_dim, 3, stride=1, padding=1, groups=inner_dim))
|
|
|
if final_dropout: |
|
self.net.append(nn.Dropout(dropout)) |
|
|
|
def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: |
|
if len(args) > 0 or kwargs.get("scale", None) is not None: |
|
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." |
|
deprecate("scale", "1.0.0", deprecation_message) |
|
        for i, module in enumerate(self.net):
            hidden_states = module(hidden_states)
            if i == 1:
                # After the activation and dropout, split the sequence into sample tokens and
                # control tokens (the two halves are concatenated along the sequence dimension).
                hidden_states, hidden_states_control_org = hidden_states.chunk(2, dim=1)
                B, N, C = hidden_states.shape
                h = w = int(np.sqrt(N))
                assert h * w == N, "the number of control tokens must form a square grid"
                # Run the zero-initialized depthwise conv over the control tokens reshaped to a
                # (B, C, h, w) feature map and inject the scaled result into the sample tokens.
                hidden_states_control = hidden_states_control_org.reshape(B, h, w, C).permute(0, 3, 1, 2)
                hidden_states_control = self.control_conv(hidden_states_control)
                hidden_states_control = hidden_states_control.reshape(B, C, N).permute(0, 2, 1)
                hidden_states = hidden_states + 1.2 * hidden_states_control
                hidden_states = torch.cat([hidden_states, hidden_states_control_org], dim=1)
        return hidden_states
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
|
|
@maybe_allow_in_graph |
|
class JointTransformerBlock(nn.Module): |
|
r""" |
|
    A Transformer block following the MMDiT architecture introduced in Stable Diffusion 3, extended here
    with a zero-initialized control branch (`AttentionZero` attention and a `FeedForwardControl`
    feed-forward).
|
|
|
Reference: https://arxiv.org/abs/2403.03206 |
|
|
|
Parameters: |
|
dim (`int`): The number of channels in the input and output. |
|
num_attention_heads (`int`): The number of heads to use for multi-head attention. |
|
attention_head_dim (`int`): The number of channels in each head. |
|
        context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
            processing of `context` conditions.
        qk_norm (`str`, *optional*): Normalization applied to the query and key projections, passed through
            to the attention layers.
        use_dual_attention (`bool`, defaults to `False`): Whether to add a second self-attention over the
            sample stream (the SD3.5-style dual-attention block).
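
    Example (illustrative; the sizes below are placeholders chosen so that `hidden_states` carries a
    16 x 16 grid of sample tokens followed by a 16 x 16 grid of control tokens):

        >>> import torch
        >>> block = JointTransformerBlock(dim=1536, num_attention_heads=24, attention_head_dim=64)
        >>> hidden_states = torch.randn(1, 512, 1536)          # sample tokens + control tokens
        >>> encoder_hidden_states = torch.randn(1, 154, 1536)  # text (context) tokens
        >>> temb = torch.randn(1, 1536)                        # pooled timestep embedding
        >>> encoder_hidden_states, hidden_states = block(hidden_states, encoder_hidden_states, temb)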
|
""" |
|
|
|
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        context_pre_only: bool = False,
        qk_norm: Optional[str] = None,
        use_dual_attention: bool = False,
    ):
|
super().__init__() |
|
|
|
self.use_dual_attention = use_dual_attention |
|
self.context_pre_only = context_pre_only |
|
        context_norm_type = "ada_norm_continuous" if context_pre_only else "ada_norm_zero"
|
|
|
if use_dual_attention: |
|
self.norm1 = SD35AdaLayerNormZeroX(dim) |
|
else: |
|
self.norm1 = AdaLayerNormZero(dim) |
|
|
|
        if context_norm_type == "ada_norm_continuous":
|
self.norm1_context = AdaLayerNormContinuous( |
|
dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm" |
|
) |
|
elif context_norm_type == "ada_norm_zero": |
|
self.norm1_context = AdaLayerNormZero(dim) |
|
else: |
|
            raise ValueError(
                f"Unknown context_norm_type: {context_norm_type}; currently only `ada_norm_continuous` and"
                " `ada_norm_zero` are supported."
            )
|
if hasattr(F, "scaled_dot_product_attention"): |
|
processor = JointAttnProcessor2_0() |
|
else: |
|
raise ValueError( |
|
"The current PyTorch version does not support the `scaled_dot_product_attention` function." |
|
) |
|
self.attn = AttentionZero( |
|
query_dim=dim, |
|
cross_attention_dim=None, |
|
added_kv_proj_dim=dim, |
|
dim_head=attention_head_dim, |
|
heads=num_attention_heads, |
|
out_dim=dim, |
|
context_pre_only=context_pre_only, |
|
bias=True, |
|
processor=processor, |
|
qk_norm=qk_norm, |
|
eps=1e-6, |
|
) |
|
|
|
if use_dual_attention: |
|
self.attn2 = AttentionZero( |
|
query_dim=dim, |
|
cross_attention_dim=None, |
|
added_kv_proj_dim=dim, |
|
dim_head=attention_head_dim, |
|
heads=num_attention_heads, |
|
out_dim=dim, |
|
context_pre_only=context_pre_only, |
|
bias=True, |
|
processor=processor, |
|
qk_norm=qk_norm, |
|
eps=1e-6, |
|
) |
|
else: |
|
self.attn2 = None |
|
|
|
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) |
|
self.ff = FeedForwardControl(dim=dim, dim_out=dim, activation_fn="gelu-approximate") |
|
|
|
if not context_pre_only: |
|
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) |
|
self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") |
|
else: |
|
self.norm2_context = None |
|
self.ff_context = None |
|
|
|
|
|
self._chunk_size = None |
|
self._chunk_dim = 0 |
|
|
|
|
|
    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        self._chunk_dim = dim
|
|
|
def forward( |
|
self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor |
|
): |
|
if self.use_dual_attention: |
|
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1( |
|
hidden_states, emb=temb |
|
) |
|
else: |
|
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) |
|
|
|
if self.context_pre_only: |
|
norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb) |
|
else: |
|
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( |
|
encoder_hidden_states, emb=temb |
|
) |
|
|
|
|
|
attn_output, context_attn_output = self.attn( |
|
hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, |
|
) |
|
|
|
|
|
attn_output = gate_msa.unsqueeze(1) * attn_output |
|
hidden_states = hidden_states + attn_output |
|
if self.use_dual_attention: |
|
attn_output2 = self.attn2(hidden_states=norm_hidden_states2) |
|
attn_output2 = gate_msa2.unsqueeze(1) * attn_output2 |
|
hidden_states = hidden_states + attn_output2 |
|
|
|
norm_hidden_states = self.norm2(hidden_states) |
|
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] |
|
if self._chunk_size is not None: |
|
|
|
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) |
|
else: |
|
ff_output = self.ff(norm_hidden_states) |
|
ff_output = gate_mlp.unsqueeze(1) * ff_output |
|
|
|
hidden_states = hidden_states + ff_output |
|
|
|
|
|
if self.context_pre_only: |
|
encoder_hidden_states = None |
|
else: |
|
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output |
|
encoder_hidden_states = encoder_hidden_states + context_attn_output |
|
|
|
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) |
|
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] |
|
if self._chunk_size is not None: |
|
|
|
context_ff_output = _chunked_feed_forward( |
|
self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size |
|
) |
|
else: |
|
context_ff_output = self.ff_context(norm_encoder_hidden_states) |
|
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output |
|
|
|
return encoder_hidden_states, hidden_states |
|
|
|
class AttentionZero(Attention):
    r"""
    An `Attention` variant for the control branch: it adds zero-initialized projections for the control
    tokens (`to_q_control`, `to_k_control`, `to_v_control`) and a control output projection initialized
    as a copy of the regular output projection `to_out[0]`.
    """

    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int],
        added_kv_proj_dim: int,
        dim_head: int,
        heads: int,
        out_dim: int,
        context_pre_only: bool,
        bias: bool,
        processor,
        qk_norm: Optional[str],
        eps: float,
    ):
        super().__init__(
            query_dim=query_dim,
            cross_attention_dim=cross_attention_dim,
            added_kv_proj_dim=added_kv_proj_dim,
            dim_head=dim_head,
            heads=heads,
            out_dim=out_dim,
            context_pre_only=context_pre_only,
            bias=bias,
            processor=processor,
            qk_norm=qk_norm,
            eps=eps,
        )
        # Zero-initialized control projections: the control branch contributes nothing at the start of
        # training.
        self.to_q_control = zero_module(nn.Linear(self.query_dim, self.inner_dim, bias=self.use_bias))
        self.to_k_control = zero_module(nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=self.use_bias))
        self.to_v_control = zero_module(nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=self.use_bias))
        # Control output projection, initialized from the regular output projection.
        self.to_out_control = nn.Linear(self.inner_dim, self.out_dim, bias=True)
        self.to_out_control.weight.data.copy_(self.to_out[0].weight.data)
        self.to_out_control.bias.data.copy_(self.to_out[0].bias.data)
|
|
|
def zero_module(module: nn.Module) -> nn.Module:
    """Zero-initializes all parameters of `module` and returns it."""
    for p in module.parameters():
        nn.init.zeros_(p)
    return module
|
|
|
class JointAttnProcessor2_0:
    """
    Attention processor for SD3-style joint self-attention, extended with a control-token stream: the
    sample hidden states are expected to carry sample tokens and control tokens concatenated along the
    sequence dimension, and both attend jointly together with the text (context) tokens.
    """
|
|
|
def __init__(self): |
|
if not hasattr(F, "scaled_dot_product_attention"): |
|
            raise ImportError("JointAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
|
|
|
def __call__( |
|
self, |
|
attn: Attention, |
|
hidden_states: torch.FloatTensor, |
|
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
attention_mask: Optional[torch.FloatTensor] = None, |
|
*args, |
|
**kwargs, |
|
) -> torch.FloatTensor: |
|
|
|
residual = hidden_states |
|
|
|
batch_size = hidden_states.shape[0] |
|
|
|
        # The incoming hidden states carry sample tokens and control tokens concatenated along the
        # sequence dimension; keep the control half for its own residual connection.
        hidden_states, hidden_states_control = hidden_states.chunk(2, dim=1)
        hidden_states_control_res = hidden_states_control
|
|
|
|
|
query = attn.to_q(hidden_states) |
|
key = attn.to_k(hidden_states) |
|
value = attn.to_v(hidden_states) |
|
|
|
inner_dim = key.shape[-1] |
|
head_dim = inner_dim // attn.heads |
|
|
|
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) |
|
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) |
|
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) |
|
|
|
|
|
        # The control tokens reuse the base projections, followed by the zero-initialized control
        # projections, so the control contribution starts at zero.
        query_control = attn.to_q_control(attn.to_q(hidden_states_control))
        key_control = attn.to_k_control(attn.to_k(hidden_states_control))
        value_control = attn.to_v_control(attn.to_v(hidden_states_control))
|
|
|
query_control = query_control.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) |
|
key_control = key_control.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) |
|
value_control = value_control.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) |
|
|
|
|
|
        if attn.norm_q is not None:
            query = attn.norm_q(query)
            query_control = attn.norm_q(query_control)
        if attn.norm_k is not None:
            key = attn.norm_k(key)
            key_control = attn.norm_k(key_control)
|
|
|
if encoder_hidden_states is not None: |
|
|
|
|
|
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) |
|
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) |
|
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) |
|
|
|
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) |
|
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) |
|
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) |
|
|
|
if attn.norm_added_q is not None: |
|
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) |
|
if attn.norm_added_k is not None: |
|
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) |
|
|
|
|
|
            # Joint attention over the sample, control, and text tokens.
            query = torch.cat([query, query_control, encoder_hidden_states_query_proj], dim=2)
            key = torch.cat([key, key_control, encoder_hidden_states_key_proj], dim=2)
            value = torch.cat([value, value_control, encoder_hidden_states_value_proj], dim=2)
|
|
|
        else:
            query = torch.cat([query, query_control], dim=2)
            key = torch.cat([key, key_control], dim=2)
            value = torch.cat([value, value_control], dim=2)
|
|
|
hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False) |
|
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) |
|
hidden_states = hidden_states.to(query.dtype) |
|
|
|
if encoder_hidden_states is not None: |
|
|
|
            # Split the text (context) tokens back off; `residual` still contains both the sample and
            # control tokens, so its sequence length marks the boundary.
            hidden_states, encoder_hidden_states = (
                hidden_states[:, : residual.shape[1]],
                hidden_states[:, residual.shape[1] :],
            )
|
if not attn.context_pre_only: |
|
encoder_hidden_states = attn.to_add_out(encoder_hidden_states) |
|
|
|
|
|
        # Separate the sample and control tokens again; the control tokens get their pre-attention
        # residual added back.
        hidden_states, hidden_states_control = hidden_states.chunk(2, dim=1)
        hidden_states_control = hidden_states_control + hidden_states_control_res
|
|
|
|
|
        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        # The control tokens use their dedicated output projection, then the two streams are
        # re-concatenated along the sequence dimension.
        hidden_states_control = attn.to_out_control(hidden_states_control)
        hidden_states = torch.cat([hidden_states, hidden_states_control], dim=1)
|
|
|
if encoder_hidden_states is not None: |
|
return hidden_states, encoder_hidden_states |
|
        else:
            return hidden_states