from typing import Optional, Tuple
import torch
import torch.nn as nn
from einops import rearrange
from hyimage.models.hunyuan.modules.flash_attn_no_pad import flash_attn_no_pad
from .activation_layers import get_activation_layer
from .mlp_layers import MLP, LinearWarpforSingle
from .modulate_layers import ModulateDiT, apply_gate, modulate
from .norm_layers import get_norm_layer
from .posemb_layers import apply_rotary_emb
@torch.compiler.disable
def attention(
q,
k,
v,
attn_mode="flash",
text_mask=None,
):
"""Multi-modal attention function that processes image and text sequences."""
query, encoder_query = q
key, encoder_key = k
value, encoder_value = v
assert attn_mode == "flash" # Only flash attention is implemented for now
sequence_length = query.size(1)
encoder_sequence_length = encoder_query.size(1)
query = torch.cat([query, encoder_query], dim=1)
key = torch.cat([key, encoder_key], dim=1)
value = torch.cat([value, encoder_value], dim=1)
# Stack query, key, value: B, S, 3, H, D
qkv = torch.stack([query, key, value], dim=2)
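    # Prepend True for every image token so only padded text positions are masked out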
attn_mask = torch.nn.functional.pad(text_mask, (sequence_length, 0), value=True)
hidden_states = flash_attn_no_pad(qkv, attn_mask, causal=False, dropout_p=0.0, softmax_scale=None)
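    # Split the fused output back into image and text segments and restore the input dtype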
hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
(sequence_length, encoder_sequence_length), dim=1
)
hidden_states = hidden_states.to(query.dtype)
encoder_hidden_states = encoder_hidden_states.to(query.dtype)
attn = torch.cat([hidden_states, encoder_hidden_states], dim=1)
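    # Merge attention heads: (B, S_total, H, D) -> (B, S_total, H * D)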
b, s, a, d = attn.shape
attn = attn.reshape(b, s, -1)
return attn
class MMDoubleStreamBlock(nn.Module):
"""
A multimodal DiT block with separate modulation for text and image/video.
"""
def __init__(
self,
hidden_size: int,
heads_num: int,
mlp_width_ratio: float,
mlp_act_type: str = "gelu_tanh",
qk_norm: bool = True,
qk_norm_type: str = "rms",
qkv_bias: bool = False,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.deterministic = False
self.heads_num = heads_num
head_dim = hidden_size // heads_num
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
# Image stream components
self.img_mod = ModulateDiT(
hidden_size,
factor=6,
act_layer=get_activation_layer("silu"),
**factory_kwargs,
)
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.img_attn_q = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
self.img_attn_k = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
self.img_attn_v = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
qk_norm_layer = get_norm_layer(qk_norm_type)
self.img_attn_q_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
)
self.img_attn_k_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
)
self.img_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.img_mlp = MLP(
hidden_size,
mlp_hidden_dim,
act_layer=get_activation_layer(mlp_act_type),
bias=True,
**factory_kwargs,
)
# Text stream components
self.txt_mod = ModulateDiT(
hidden_size,
factor=6,
act_layer=get_activation_layer("silu"),
**factory_kwargs,
)
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.txt_attn_q = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
self.txt_attn_k = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
self.txt_attn_v = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
self.txt_attn_q_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
)
self.txt_attn_k_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
)
self.txt_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.txt_mlp = MLP(
hidden_size,
mlp_hidden_dim,
act_layer=get_activation_layer(mlp_act_type),
bias=True,
**factory_kwargs,
)
self.core_attn = attention
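        # Shared joint image-text flash attention kernel defined above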
def enable_deterministic(self):
self.deterministic = True
def disable_deterministic(self):
self.deterministic = False
def forward(
self,
img: torch.Tensor,
txt: torch.Tensor,
vec: torch.Tensor,
        freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
text_mask: torch.Tensor = None,
cu_seqlens=None,
max_s=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
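        """Run joint image-text attention, then per-stream gated MLPs.

        Args:
            img: Image tokens, shape (B, L_img, hidden_size).
            txt: Text tokens, shape (B, L_txt, hidden_size).
            vec: Conditioning vector used to produce the modulation parameters.
            freqs_cis: Rotary embedding frequencies applied to the image tokens.
            text_mask: Boolean mask of valid text tokens, shape (B, L_txt).

        Returns:
            Updated (img, txt) tensors with residual connections applied.
        """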
# Extract modulation parameters for image and text streams
(
img_mod1_shift,
img_mod1_scale,
img_mod1_gate,
img_mod2_shift,
img_mod2_scale,
img_mod2_gate,
) = self.img_mod(vec).chunk(6, dim=-1)
(
txt_mod1_shift,
txt_mod1_scale,
txt_mod1_gate,
txt_mod2_shift,
txt_mod2_scale,
txt_mod2_gate,
) = self.txt_mod(vec).chunk(6, dim=-1)
# Process image stream for attention
img_modulated = self.img_norm1(img)
img_modulated = modulate(img_modulated, shift=img_mod1_shift, scale=img_mod1_scale)
img_q = self.img_attn_q(img_modulated)
img_k = self.img_attn_k(img_modulated)
img_v = self.img_attn_v(img_modulated)
img_q = rearrange(img_q, "B L (H D) -> B L H D", H=self.heads_num)
img_k = rearrange(img_k, "B L (H D) -> B L H D", H=self.heads_num)
img_v = rearrange(img_v, "B L (H D) -> B L H D", H=self.heads_num)
# Apply QK-Norm if enabled
img_q = self.img_attn_q_norm(img_q).to(img_v)
img_k = self.img_attn_k_norm(img_k).to(img_v)
# Apply RoPE if provided
if freqs_cis is not None:
img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
assert (
img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
), f"img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}"
img_q, img_k = img_qq, img_kk
# Process text stream for attention
txt_modulated = self.txt_norm1(txt)
txt_modulated = modulate(txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale)
txt_q = self.txt_attn_q(txt_modulated)
txt_k = self.txt_attn_k(txt_modulated)
txt_v = self.txt_attn_v(txt_modulated)
txt_q = rearrange(txt_q, "B L (H D) -> B L H D", H=self.heads_num)
txt_k = rearrange(txt_k, "B L (H D) -> B L H D", H=self.heads_num)
txt_v = rearrange(txt_v, "B L (H D) -> B L H D", H=self.heads_num)
# Apply QK-Norm if enabled
txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
# Compute cross-modal attention
attn = self.core_attn(
(img_q, txt_q),
(img_k, txt_k),
(img_v, txt_v),
text_mask=text_mask,
)
# Split attention outputs for image and text streams
img_attn, txt_attn = (
attn[:, : img_q.shape[1]].contiguous(),
attn[:, img_q.shape[1] :].contiguous(),
)
# Apply attention projection and residual connection for image stream
img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate)
# Apply MLP and residual connection for image stream
img = img + apply_gate(
self.img_mlp(modulate(self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale)),
gate=img_mod2_gate,
)
# Apply attention projection and residual connection for text stream
txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate)
# Apply MLP and residual connection for text stream
txt = txt + apply_gate(
self.txt_mlp(modulate(self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale)),
gate=txt_mod2_gate,
)
return img, txt
class MMSingleStreamBlock(nn.Module):
"""
A DiT block with parallel linear layers for multimodal processing.
"""
def __init__(
self,
hidden_size: int,
heads_num: int,
mlp_width_ratio: float = 4.0,
mlp_act_type: str = "gelu_tanh",
qk_norm: bool = True,
qk_norm_type: str = "rms",
qk_scale: float = None,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.deterministic = False
self.hidden_size = hidden_size
self.heads_num = heads_num
head_dim = hidden_size // heads_num
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
self.mlp_hidden_dim = mlp_hidden_dim
self.scale = qk_scale or head_dim**-0.5
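        # Softmax scale defaults to 1 / sqrt(head_dim) unless qk_scale is provided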
# Separate linear layers for Q, K, V, and MLP input
self.linear1_q = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
self.linear1_k = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
self.linear1_v = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
self.linear1_mlp = nn.Linear(hidden_size, mlp_hidden_dim, **factory_kwargs)
# Output projection layer
self.linear2 = LinearWarpforSingle(hidden_size + mlp_hidden_dim, hidden_size, bias=True, **factory_kwargs)
# QK normalization layers
qk_norm_layer = get_norm_layer(qk_norm_type)
self.q_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
)
self.k_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
)
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.mlp_act = get_activation_layer(mlp_act_type)()
self.modulation = ModulateDiT(
hidden_size,
factor=3,
act_layer=get_activation_layer("silu"),
**factory_kwargs,
)
self.core_attn = attention
def enable_deterministic(self):
self.deterministic = True
def disable_deterministic(self):
self.deterministic = False
def forward(
self,
x: torch.Tensor,
vec: torch.Tensor,
txt_len: int,
freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
text_mask: torch.Tensor = None,
cu_seqlens=None,
max_s=None,
) -> torch.Tensor:
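        """Apply fused attention + MLP to a concatenated [image, text] sequence.

        Args:
            x: Concatenated image and text tokens, shape (B, L_img + L_txt, hidden_size).
            vec: Conditioning vector used to produce the modulation parameters.
            txt_len: Number of text tokens at the end of the sequence.
            freqs_cis: Rotary embedding frequencies applied to the image tokens.
            text_mask: Boolean mask of valid text tokens, shape (B, L_txt).

        Returns:
            Updated token tensor with the same shape as `x`.
        """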
# Extract modulation parameters
mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)
x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale)
# Compute Q, K, V, and MLP input
q = self.linear1_q(x_mod)
k = self.linear1_k(x_mod)
v = self.linear1_v(x_mod)
q = rearrange(q, "B L (H D) -> B L H D", H=self.heads_num)
k = rearrange(k, "B L (H D) -> B L H D", H=self.heads_num)
v = rearrange(v, "B L (H D) -> B L H D", H=self.heads_num)
mlp = self.linear1_mlp(x_mod)
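        # The MLP branch shares the same modulated input as Q/K/V, so attention and MLP run in parallel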
# Apply QK-Norm if enabled
q = self.q_norm(q).to(v)
k = self.k_norm(k).to(v)
# Split into image and text sequences
img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
img_v, txt_v = v[:, :-txt_len, :, :], v[:, -txt_len:, :, :]
# Apply RoPE to image sequence
img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
assert (
img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
), f"img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}"
img_q, img_k = img_qq, img_kk
# Compute cross-modal attention
attn = self.core_attn(
(img_q, txt_q),
(img_k, txt_k),
(img_v, txt_v),
text_mask=text_mask,
)
# Combine attention output with MLP activation and apply final projection
output = self.linear2(attn, self.mlp_act(mlp))
return x + apply_gate(output, gate=mod_gate)