from typing import Optional, Tuple

import torch
import torch.nn as nn
from einops import rearrange

from hyimage.models.hunyuan.modules.flash_attn_no_pad import flash_attn_no_pad

from .activation_layers import get_activation_layer
from .mlp_layers import MLP, LinearWarpforSingle
from .modulate_layers import ModulateDiT, apply_gate, modulate
from .norm_layers import get_norm_layer
from .posemb_layers import apply_rotary_emb


@torch.compiler.disable
def attention(
    q,
    k,
    v,
    attn_mode="flash",
    text_mask=None,
):
    """Multi-modal attention function that processes image and text sequences."""
    query, encoder_query = q
    key, encoder_key = k
    value, encoder_value = v

    assert attn_mode == "flash"  # Only flash attention is implemented for now

    sequence_length = query.size(1)
    encoder_sequence_length = encoder_query.size(1)

    # Concatenate image and text tokens along the sequence dimension
    query = torch.cat([query, encoder_query], dim=1)
    key = torch.cat([key, encoder_key], dim=1)
    value = torch.cat([value, encoder_value], dim=1)

    # Stack query, key, value: B, S, 3, H, D
    qkv = torch.stack([query, key, value], dim=2)
    # Image tokens are always valid; text tokens follow text_mask
    attn_mask = torch.nn.functional.pad(text_mask, (sequence_length, 0), value=True)
    hidden_states = flash_attn_no_pad(qkv, attn_mask, causal=False, dropout_p=0.0, softmax_scale=None)

    hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
        (sequence_length, encoder_sequence_length), dim=1
    )
    hidden_states = hidden_states.to(query.dtype)
    encoder_hidden_states = encoder_hidden_states.to(query.dtype)

    attn = torch.cat([hidden_states, encoder_hidden_states], dim=1)
    b, s, a, d = attn.shape
    attn = attn.reshape(b, s, -1)
    return attn
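

# The helper below is a minimal reference sketch, not part of the model: it mirrors
# the joint image+text attention above using torch.nn.functional.scaled_dot_product_attention
# so the masking semantics can be checked on CPU without flash-attn. The mask convention
# (True = valid token, key-side only) is assumed; numerics for padded text positions may
# differ from the flash-attn path, which drops padded tokens entirely.
def _sdpa_reference_attention(q, k, v, text_mask=None):
    query, encoder_query = q
    key, encoder_key = k
    value, encoder_value = v

    img_len = query.size(1)

    # Concatenate image and text tokens along the sequence dimension: B, S, H, D
    query = torch.cat([query, encoder_query], dim=1)
    key = torch.cat([key, encoder_key], dim=1)
    value = torch.cat([value, encoder_value], dim=1)

    # Image tokens are always attended to; text tokens follow text_mask
    keep = torch.nn.functional.pad(text_mask.bool(), (img_len, 0), value=True)
    attn_mask = keep[:, None, None, :]  # B, 1, 1, S, broadcast over heads and queries

    # SDPA expects B, H, S, D
    out = torch.nn.functional.scaled_dot_product_attention(
        query.transpose(1, 2),
        key.transpose(1, 2),
        value.transpose(1, 2),
        attn_mask=attn_mask,
    )
    out = out.transpose(1, 2)  # back to B, S, H, D
    b, s, h, d = out.shape
    return out.reshape(b, s, h * d)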


class MMDoubleStreamBlock(nn.Module):
    """
    A multimodal DiT block with separate modulation for text and image/video.
    """

    def __init__(
        self,
        hidden_size: int,
        heads_num: int,
        mlp_width_ratio: float,
        mlp_act_type: str = "gelu_tanh",
        qk_norm: bool = True,
        qk_norm_type: str = "rms",
        qkv_bias: bool = False,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.deterministic = False
        self.heads_num = heads_num
        head_dim = hidden_size // heads_num
        mlp_hidden_dim = int(hidden_size * mlp_width_ratio)

        # Image stream components
        self.img_mod = ModulateDiT(
            hidden_size,
            factor=6,
            act_layer=get_activation_layer("silu"),
            **factory_kwargs,
        )
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)

        self.img_attn_q = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
        self.img_attn_k = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
        self.img_attn_v = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)

        qk_norm_layer = get_norm_layer(qk_norm_type)
        self.img_attn_q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.img_attn_k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.img_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)

        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.img_mlp = MLP(
            hidden_size,
            mlp_hidden_dim,
            act_layer=get_activation_layer(mlp_act_type),
            bias=True,
            **factory_kwargs,
        )

        # Text stream components
        self.txt_mod = ModulateDiT(
            hidden_size,
            factor=6,
            act_layer=get_activation_layer("silu"),
            **factory_kwargs,
        )
        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)

        self.txt_attn_q = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
        self.txt_attn_k = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
        self.txt_attn_v = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)

        self.txt_attn_q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.txt_attn_k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.txt_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)

        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.txt_mlp = MLP(
            hidden_size,
            mlp_hidden_dim,
            act_layer=get_activation_layer(mlp_act_type),
            bias=True,
            **factory_kwargs,
        )

        self.core_attn = attention

    def enable_deterministic(self):
        self.deterministic = True

    def disable_deterministic(self):
        self.deterministic = False

    def forward(
        self,
        img: torch.Tensor,
        txt: torch.Tensor,
        vec: torch.Tensor,
        freqs_cis: tuple = None,
        text_mask: torch.Tensor = None,
        cu_seqlens=None,
        max_s=None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Extract modulation parameters for image and text streams
        (
            img_mod1_shift,
            img_mod1_scale,
            img_mod1_gate,
            img_mod2_shift,
            img_mod2_scale,
            img_mod2_gate,
        ) = self.img_mod(vec).chunk(6, dim=-1)
        (
            txt_mod1_shift,
            txt_mod1_scale,
            txt_mod1_gate,
            txt_mod2_shift,
            txt_mod2_scale,
            txt_mod2_gate,
        ) = self.txt_mod(vec).chunk(6, dim=-1)

        # Process image stream for attention
        img_modulated = self.img_norm1(img)
        img_modulated = modulate(img_modulated, shift=img_mod1_shift, scale=img_mod1_scale)

        img_q = self.img_attn_q(img_modulated)
        img_k = self.img_attn_k(img_modulated)
        img_v = self.img_attn_v(img_modulated)

        img_q = rearrange(img_q, "B L (H D) -> B L H D", H=self.heads_num)
        img_k = rearrange(img_k, "B L (H D) -> B L H D", H=self.heads_num)
        img_v = rearrange(img_v, "B L (H D) -> B L H D", H=self.heads_num)

        # Apply QK-Norm if enabled
        img_q = self.img_attn_q_norm(img_q).to(img_v)
        img_k = self.img_attn_k_norm(img_k).to(img_v)

        # Apply RoPE if provided
        if freqs_cis is not None:
            img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
            assert (
                img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
            ), f"img_qq: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}"
            img_q, img_k = img_qq, img_kk

        # Process text stream for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = modulate(txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale)

        txt_q = self.txt_attn_q(txt_modulated)
        txt_k = self.txt_attn_k(txt_modulated)
        txt_v = self.txt_attn_v(txt_modulated)

        txt_q = rearrange(txt_q, "B L (H D) -> B L H D", H=self.heads_num)
        txt_k = rearrange(txt_k, "B L (H D) -> B L H D", H=self.heads_num)
        txt_v = rearrange(txt_v, "B L (H D) -> B L H D", H=self.heads_num)

        # Apply QK-Norm if enabled
        txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
        txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)

        # Compute cross-modal attention
        attn = self.core_attn(
            (img_q, txt_q),
            (img_k, txt_k),
            (img_v, txt_v),
            text_mask=text_mask,
        )

        # Split attention outputs for image and text streams
        img_attn, txt_attn = (
            attn[:, : img_q.shape[1]].contiguous(),
            attn[:, img_q.shape[1] :].contiguous(),
        )

        # Apply attention projection and residual connection for image stream
        img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate)

        # Apply MLP and residual connection for image stream
        img = img + apply_gate(
            self.img_mlp(modulate(self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale)),
            gate=img_mod2_gate,
        )

        # Apply attention projection and residual connection for text stream
        txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate)

        # Apply MLP and residual connection for text stream
        txt = txt + apply_gate(
            self.txt_mlp(modulate(self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale)),
            gate=txt_mod2_gate,
        )

        return img, txt
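

# A minimal usage sketch for MMDoubleStreamBlock. It assumes vec is a per-sample
# conditioning vector of shape (B, hidden_size), text_mask is a (B, L_txt) bool
# tensor with True marking valid tokens, and a CUDA device with flash-attn
# installed (the attention path needs fp16/bf16 on GPU). Defined for illustration
# only; it is not executed at import time.
def _example_mm_double_stream_forward():
    device = torch.device("cuda")
    dtype = torch.bfloat16
    hidden_size, heads_num = 1024, 16

    block = MMDoubleStreamBlock(
        hidden_size=hidden_size,
        heads_num=heads_num,
        mlp_width_ratio=4.0,
        device=device,
        dtype=dtype,
    )

    batch, img_len, txt_len = 2, 256, 77
    img = torch.randn(batch, img_len, hidden_size, device=device, dtype=dtype)
    txt = torch.randn(batch, txt_len, hidden_size, device=device, dtype=dtype)
    vec = torch.randn(batch, hidden_size, device=device, dtype=dtype)
    text_mask = torch.ones(batch, txt_len, device=device, dtype=torch.bool)

    # freqs_cis=None skips RoPE in this block, keeping the sketch self-contained.
    img_out, txt_out = block(img, txt, vec, freqs_cis=None, text_mask=text_mask)
    assert img_out.shape == img.shape and txt_out.shape == txt.shape
    return img_out, txt_out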


class MMSingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers for multimodal processing.
    """

    def __init__(
        self,
        hidden_size: int,
        heads_num: int,
        mlp_width_ratio: float = 4.0,
        mlp_act_type: str = "gelu_tanh",
        qk_norm: bool = True,
        qk_norm_type: str = "rms",
        qk_scale: float = None,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.deterministic = False
        self.hidden_size = hidden_size
        self.heads_num = heads_num
        head_dim = hidden_size // heads_num
        mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
        self.mlp_hidden_dim = mlp_hidden_dim
        self.scale = qk_scale or head_dim**-0.5

        # Separate linear layers for Q, K, V, and MLP input
        self.linear1_q = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
        self.linear1_k = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
        self.linear1_v = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
        self.linear1_mlp = nn.Linear(hidden_size, mlp_hidden_dim, **factory_kwargs)

        # Output projection layer
        self.linear2 = LinearWarpforSingle(hidden_size + mlp_hidden_dim, hidden_size, bias=True, **factory_kwargs)

        # QK normalization layers
        qk_norm_layer = get_norm_layer(qk_norm_type)
        self.q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )

        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.mlp_act = get_activation_layer(mlp_act_type)()
        self.modulation = ModulateDiT(
            hidden_size,
            factor=3,
            act_layer=get_activation_layer("silu"),
            **factory_kwargs,
        )
        self.core_attn = attention

    def enable_deterministic(self):
        self.deterministic = True

    def disable_deterministic(self):
        self.deterministic = False

    def forward(
        self,
        x: torch.Tensor,
        vec: torch.Tensor,
        txt_len: int,
        freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
        text_mask: torch.Tensor = None,
        cu_seqlens=None,
        max_s=None,
    ) -> torch.Tensor:
        # Extract modulation parameters
        mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)
        x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale)

        # Compute Q, K, V, and MLP input
        q = self.linear1_q(x_mod)
        k = self.linear1_k(x_mod)
        v = self.linear1_v(x_mod)

        q = rearrange(q, "B L (H D) -> B L H D", H=self.heads_num)
        k = rearrange(k, "B L (H D) -> B L H D", H=self.heads_num)
        v = rearrange(v, "B L (H D) -> B L H D", H=self.heads_num)

        mlp = self.linear1_mlp(x_mod)

        # Apply QK-Norm if enabled
        q = self.q_norm(q).to(v)
        k = self.k_norm(k).to(v)

        # Split into image and text sequences
        img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
        img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
        img_v, txt_v = v[:, :-txt_len, :, :], v[:, -txt_len:, :, :]

        # Apply RoPE to image sequence
        img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
        assert (
            img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
        ), f"img_qq: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}"
        img_q, img_k = img_qq, img_kk

        # Compute cross-modal attention
        attn = self.core_attn(
            (img_q, txt_q),
            (img_k, txt_k),
            (img_v, txt_v),
            text_mask=text_mask,
        )

        # Combine attention output with MLP activation and apply final projection
        output = self.linear2(attn, self.mlp_act(mlp))
        return x + apply_gate(output, gate=mod_gate)
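

# A minimal usage sketch for MMSingleStreamBlock under the same assumptions as the
# double-stream sketch above (CUDA + flash-attn, bf16, vec of shape (B, hidden_size)).
# Unlike the double-stream block, RoPE is applied unconditionally here, so freqs_cis
# must be provided. The identity rotation below (cos = 1, sin = 0) only assumes the
# common (cos, sin) per-token convention of apply_rotary_emb; in real use, pass the
# embeddings produced by the repo's posemb_layers helpers instead.
def _example_mm_single_stream_forward():
    device = torch.device("cuda")
    dtype = torch.bfloat16
    hidden_size, heads_num = 1024, 16
    head_dim = hidden_size // heads_num

    block = MMSingleStreamBlock(
        hidden_size=hidden_size,
        heads_num=heads_num,
        device=device,
        dtype=dtype,
    )

    batch, img_len, txt_len = 2, 256, 77
    # Single-stream input packs image tokens first, then text tokens
    x = torch.randn(batch, img_len + txt_len, hidden_size, device=device, dtype=dtype)
    vec = torch.randn(batch, hidden_size, device=device, dtype=dtype)
    text_mask = torch.ones(batch, txt_len, device=device, dtype=torch.bool)

    # Identity-RoPE placeholder: assumed (cos, sin) layout with one row per image token
    freqs_cis = (
        torch.ones(img_len, head_dim, device=device, dtype=dtype),
        torch.zeros(img_len, head_dim, device=device, dtype=dtype),
    )

    out = block(x, vec, txt_len=txt_len, freqs_cis=freqs_cis, text_mask=text_mask)
    assert out.shape == x.shape
    return out


if __name__ == "__main__":
    # Run the sketches only when a GPU is available, since the flash-attn path requires one.
    if torch.cuda.is_available():
        _example_mm_double_stream_forward()
        _example_mm_single_stream_forward()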