KevinNg99's picture
Initial commit.
43c5292
from typing import Optional
import torch
import torch.nn as nn
from einops import rearrange
from hyimage.models.hunyuan.modules.flash_attn_no_pad import flash_attn_no_pad
from .activation_layers import get_activation_layer
from .embed_layers import TextProjection, TimestepEmbedder
from .mlp_layers import MLP
from .modulate_layers import apply_gate
from .norm_layers import get_norm_layer
@torch.compiler.disable
def attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    drop_rate: float = 0.0,
    attn_mask: Optional[torch.Tensor] = None,
    causal: bool = False,
) -> torch.Tensor:
    """
    Run packed-QKV attention through flash_attn_no_pad and merge the heads.

    torch.compile is disabled for this function, so it always executes eagerly.

    Args:
        q: Query tensor of shape [B, L, H, D].
        k: Key tensor of shape [B, L, H, D].
        v: Value tensor of shape [B, L, H, D].
        drop_rate: Dropout probability forwarded to the kernel as ``dropout_p``.
        attn_mask: Optional mask of shape [B, L]; non-bool masks are cast to bool.
        causal: Whether to apply causal masking.
    Returns:
        Attention output of shape [B, L, H*D].
    """
    mask = attn_mask
    if mask is not None and mask.dtype != torch.bool:
        mask = mask.bool()
    # Pack as [B, L, 3, H, D] -- the layout flash_attn_no_pad is called with here.
    packed_qkv = torch.stack([q, k, v], dim=2)
    attended = flash_attn_no_pad(
        packed_qkv,
        mask,
        causal=causal,
        dropout_p=drop_rate,
        softmax_scale=None,
    )
    # Unpacking all four dims also asserts the kernel returned a 4-D tensor.
    batch, seq_len, _, _ = attended.shape
    return attended.reshape(batch, seq_len, -1)
class IndividualTokenRefinerBlock(nn.Module):
    """
    A single token-refinement block: gated self-attention followed by a gated MLP.

    Both residual branches are combined with gates produced from the conditioning
    vector ``c`` by ``adaLN_modulation``. That projection is zero-initialized, so
    both gates are zero at construction. Assuming ``apply_gate`` scales its input
    by the gate (confirm against modulate_layers), each block therefore starts out
    as an identity mapping.

    Args:
        hidden_size: Hidden dimension size. Must be divisible by ``heads_num``.
        heads_num: Number of attention heads.
        mlp_width_ratio: Expansion ratio for MLP hidden size.
        mlp_drop_rate: Dropout rate for MLP.
        act_type: Activation function type.
        qk_norm: Whether to use QK normalization.
        qk_norm_type: Type of QK normalization.
        qkv_bias: Whether to use bias in QKV projections (and the output projection).
        dtype: Optional torch dtype.
        device: Optional torch device.
    Raises:
        ValueError: If ``hidden_size`` is not divisible by ``heads_num``.
    """

    def __init__(
        self,
        hidden_size: int,
        heads_num: int,
        mlp_width_ratio: float = 4.0,
        mlp_drop_rate: float = 0.0,
        act_type: str = "silu",
        qk_norm: bool = False,
        qk_norm_type: str = "layer",
        qkv_bias: bool = True,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        # forward() splits the fused QKV projection into heads_num heads of equal
        # size. A remainder here would size the q/k norms with a floored head_dim
        # and only fail later, at runtime, inside rearrange.
        if hidden_size % heads_num != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by heads_num ({heads_num})"
            )
        self.heads_num = heads_num
        head_dim = hidden_size // heads_num
        mlp_hidden_dim = int(hidden_size * mlp_width_ratio)

        # Attention branch: pre-norm, fused QKV projection, optional per-head q/k norm.
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs)
        self.self_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
        qk_norm_layer = get_norm_layer(qk_norm_type)
        self.self_attn_q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.self_attn_k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.self_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)

        # MLP branch: pre-norm followed by the MLP.
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs)
        act_layer = get_activation_layer(act_type)
        self.mlp = MLP(
            in_channels=hidden_size,
            hidden_channels=mlp_hidden_dim,
            act_layer=act_layer,
            drop=mlp_drop_rate,
            **factory_kwargs,
        )

        # Produces the two residual gates (attention, MLP) from the conditioning vector.
        self.adaLN_modulation = nn.Sequential(
            act_layer(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
        )
        # Zero-initialize the modulation so both gates start at zero.
        nn.init.zeros_(self.adaLN_modulation[1].weight)
        nn.init.zeros_(self.adaLN_modulation[1].bias)

    def forward(
        self,
        x: torch.Tensor,
        c: torch.Tensor,  # timestep_aware_representations + context_aware_representations
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass for IndividualTokenRefinerBlock.

        Args:
            x: Input tensor of shape [B, L, C].
            c: Conditioning tensor of shape [B, C].
            attn_mask: Optional attention mask of shape [B, L].
        Returns:
            Refined tensor of shape [B, L, C].
        """
        gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)

        norm_x = self.norm1(x)
        qkv = self.self_attn_qkv(norm_x)
        q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
        # Match v's dtype/device in case the norm layer returned a different dtype.
        q = self.self_attn_q_norm(q).to(v)
        k = self.self_attn_k_norm(k).to(v)
        attn = attention(q, k, v, attn_mask=attn_mask)

        x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
        x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp)
        return x
class IndividualTokenRefiner(nn.Module):
    """
    A stack of ``depth`` IndividualTokenRefinerBlock modules applied in sequence.

    Every block receives the same conditioning vector and the same attention mask.

    Args:
        hidden_size: Hidden dimension size.
        heads_num: Number of attention heads.
        depth: Number of blocks.
        mlp_width_ratio: Expansion ratio for MLP hidden size.
        mlp_drop_rate: Dropout rate for MLP.
        act_type: Activation function type.
        qk_norm: Whether to use QK normalization.
        qk_norm_type: Type of QK normalization.
        qkv_bias: Whether to use bias in QKV projections.
        dtype: Optional torch dtype.
        device: Optional torch device.
    """

    def __init__(
        self,
        hidden_size: int,
        heads_num: int,
        depth: int,
        mlp_width_ratio: float = 4.0,
        mlp_drop_rate: float = 0.0,
        act_type: str = "silu",
        qk_norm: bool = False,
        qk_norm_type: str = "layer",
        qkv_bias: bool = True,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.blocks = nn.ModuleList(
            [
                IndividualTokenRefinerBlock(
                    hidden_size=hidden_size,
                    heads_num=heads_num,
                    mlp_width_ratio=mlp_width_ratio,
                    mlp_drop_rate=mlp_drop_rate,
                    act_type=act_type,
                    qk_norm=qk_norm,
                    qk_norm_type=qk_norm_type,
                    qkv_bias=qkv_bias,
                    **factory_kwargs,
                )
                for _ in range(depth)
            ]
        )

    def forward(
        self,
        x: torch.Tensor,
        c: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass for IndividualTokenRefiner.

        Args:
            x: Input tensor of shape [B, L, C].
            c: Conditioning tensor of shape [B, C] (a float embedding, not indices).
            mask: Optional mask tensor of shape [B, L]. It is converted to bool, and
                the caller's tensor is never modified.
        Returns:
            Refined tensor of shape [B, L, C].
        """
        if mask is not None:
            # Work on a copy: the first column is overwritten below.
            mask = mask.clone().bool()
            # Keep at least one attendable token per row so a fully masked row
            # cannot make the attention weights NaN.
            mask[:, 0] = True
        for block in self.blocks:
            x = block(x, c, mask)
        return x
class SingleTokenRefiner(nn.Module):
    """
    Refine LLM text embeddings, conditioned on a timestep and a pooled text summary.

    The conditioning vector is the sum of a timestep embedding and a projection of
    the (masked) mean of the input tokens. The tokens are then projected to
    ``hidden_size`` and passed through an IndividualTokenRefiner.

    Args:
        in_channels: Input feature dimension.
        hidden_size: Hidden dimension size.
        heads_num: Number of attention heads.
        depth: Number of blocks.
        mlp_width_ratio: Expansion ratio for MLP hidden size.
        mlp_drop_rate: Dropout rate for MLP.
        act_type: Activation function type.
        qk_norm: Whether to use QK normalization.
        qk_norm_type: Type of QK normalization.
        qkv_bias: Whether to use bias in QKV projections.
        dtype: Optional torch dtype.
        device: Optional torch device.
    """

    def __init__(
        self,
        in_channels: int,
        hidden_size: int,
        heads_num: int,
        depth: int,
        mlp_width_ratio: float = 4.0,
        mlp_drop_rate: float = 0.0,
        act_type: str = "silu",
        qk_norm: bool = False,
        qk_norm_type: str = "layer",
        qkv_bias: bool = True,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.input_embedder = nn.Linear(in_channels, hidden_size, bias=True, **factory_kwargs)
        act_layer = get_activation_layer(act_type)
        self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs)
        self.c_embedder = TextProjection(in_channels, hidden_size, act_layer, **factory_kwargs)
        self.individual_token_refiner = IndividualTokenRefiner(
            hidden_size=hidden_size,
            heads_num=heads_num,
            depth=depth,
            mlp_width_ratio=mlp_width_ratio,
            mlp_drop_rate=mlp_drop_rate,
            act_type=act_type,
            qk_norm=qk_norm,
            qk_norm_type=qk_norm_type,
            qkv_bias=qkv_bias,
            **factory_kwargs,
        )

    def forward(
        self,
        x: torch.Tensor,
        t: torch.LongTensor,
        mask: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass for SingleTokenRefiner.

        Args:
            x: Input tensor of shape [B, L, in_channels].
            t: Timestep tensor of shape [B].
            mask: Optional mask tensor of shape [B, L]. Nonzero entries mark the
                tokens included in the pooled context. A row with no valid tokens
                pools to zeros instead of NaN.
        Returns:
            Refined tensor of shape [B, L, hidden_size].
        """
        timestep_aware_representations = self.t_embedder(t)
        if mask is None:
            context_aware_representations = x.mean(dim=1)
        else:
            mask_expanded = mask.unsqueeze(-1)  # [B, L, 1]
            valid_counts = mask_expanded.sum(dim=1)  # [B, 1]
            # An all-padding row would otherwise compute 0 / 0 = NaN, which would
            # propagate into c and from there into every token. Its numerator is
            # already zero (for finite x), so dividing by 1 pools it to zeros.
            # Rows with at least one valid token are unaffected.
            valid_counts = valid_counts.masked_fill(valid_counts == 0, 1)
            context_aware_representations = (x * mask_expanded).sum(dim=1) / valid_counts
        context_aware_representations = self.c_embedder(context_aware_representations)
        c = timestep_aware_representations + context_aware_representations

        x = self.input_embedder(x)
        x = self.individual_token_refiner(x, c, mask)
        return x