from typing import Optional, Tuple

import torch
import torch.nn as nn
from einops import rearrange

from hyimage.models.hunyuan.modules.flash_attn_no_pad import flash_attn_no_pad

from .activation_layers import get_activation_layer
from .mlp_layers import MLP, LinearWarpforSingle
from .modulate_layers import ModulateDiT, apply_gate, modulate
from .norm_layers import get_norm_layer
from .posemb_layers import apply_rotary_emb


def attention(
    q,
    k,
    v,
    attn_mode="flash",
    text_mask=None,
):
    """Multi-modal attention over a joint image/text sequence.

    Args:
        q, k, v: Tuples of (image, text) tensors, each shaped [B, L, H, D].
        attn_mode: Attention backend; only "flash" is implemented for now.
        text_mask: Boolean padding mask for the text tokens, shaped [B, L_txt].

    Returns:
        Tensor of shape [B, L_img + L_txt, H * D] with image tokens first.
    """
    query, encoder_query = q
    key, encoder_key = k
    value, encoder_value = v
    assert attn_mode == "flash"  # Only flash attention is implemented for now

    sequence_length = query.size(1)
    encoder_sequence_length = encoder_query.size(1)

    # Concatenate image and text tokens into one sequence
    query = torch.cat([query, encoder_query], dim=1)
    key = torch.cat([key, encoder_key], dim=1)
    value = torch.cat([value, encoder_value], dim=1)

    # Stack query, key, value: [B, S, 3, H, D]
    qkv = torch.stack([query, key, value], dim=2)
    # Image tokens are never padded, so prepend `True` entries for them
    attn_mask = torch.nn.functional.pad(text_mask, (sequence_length, 0), value=True)
    hidden_states = flash_attn_no_pad(qkv, attn_mask, causal=False, dropout_p=0.0, softmax_scale=None)

    hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
        (sequence_length, encoder_sequence_length), dim=1
    )
    hidden_states = hidden_states.to(query.dtype)
    encoder_hidden_states = encoder_hidden_states.to(query.dtype)

    attn = torch.cat([hidden_states, encoder_hidden_states], dim=1)
    # Flatten the head axes: [B, S, H, D] -> [B, S, H * D]
    b, s, a, d = attn.shape
    attn = attn.reshape(b, s, -1)
    return attn
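
# A minimal sketch of how `attention` is invoked by the blocks below (names here are
# illustrative only; the real call sites follow later in this file):
#
#   attn = attention(
#       (img_q, txt_q),        # image / text queries, each [B, L, H, D]
#       (img_k, txt_k),
#       (img_v, txt_v),
#       text_mask=text_mask,   # boolean [B, L_txt]; image tokens are always kept
#   )
#   # attn: [B, L_img + L_txt, H * D], image tokens first, then text tokens.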


class MMDoubleStreamBlock(nn.Module):
    """
    A multimodal DiT block with separate modulation for text and image/video.
    """

    def __init__(
        self,
        hidden_size: int,
        heads_num: int,
        mlp_width_ratio: float,
        mlp_act_type: str = "gelu_tanh",
        qk_norm: bool = True,
        qk_norm_type: str = "rms",
        qkv_bias: bool = False,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.deterministic = False
        self.heads_num = heads_num
        head_dim = hidden_size // heads_num
        mlp_hidden_dim = int(hidden_size * mlp_width_ratio)

        # Image stream components
        self.img_mod = ModulateDiT(
            hidden_size,
            factor=6,
            act_layer=get_activation_layer("silu"),
            **factory_kwargs,
        )
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.img_attn_q = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
        self.img_attn_k = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
        self.img_attn_v = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
        qk_norm_layer = get_norm_layer(qk_norm_type)
        self.img_attn_q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.img_attn_k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.img_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.img_mlp = MLP(
            hidden_size,
            mlp_hidden_dim,
            act_layer=get_activation_layer(mlp_act_type),
            bias=True,
            **factory_kwargs,
        )

        # Text stream components
        self.txt_mod = ModulateDiT(
            hidden_size,
            factor=6,
            act_layer=get_activation_layer("silu"),
            **factory_kwargs,
        )
        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.txt_attn_q = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
        self.txt_attn_k = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
        self.txt_attn_v = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
        self.txt_attn_q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.txt_attn_k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.txt_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.txt_mlp = MLP(
            hidden_size,
            mlp_hidden_dim,
            act_layer=get_activation_layer(mlp_act_type),
            bias=True,
            **factory_kwargs,
        )

        self.core_attn = attention

    def enable_deterministic(self):
        self.deterministic = True

    def disable_deterministic(self):
        self.deterministic = False

    def forward(
        self,
        img: torch.Tensor,
        txt: torch.Tensor,
        vec: torch.Tensor,
        freqs_cis: tuple = None,
        text_mask: torch.Tensor = None,
        cu_seqlens=None,
        max_s=None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Extract modulation parameters for image and text streams
        (
            img_mod1_shift,
            img_mod1_scale,
            img_mod1_gate,
            img_mod2_shift,
            img_mod2_scale,
            img_mod2_gate,
        ) = self.img_mod(vec).chunk(6, dim=-1)
        (
            txt_mod1_shift,
            txt_mod1_scale,
            txt_mod1_gate,
            txt_mod2_shift,
            txt_mod2_scale,
            txt_mod2_gate,
        ) = self.txt_mod(vec).chunk(6, dim=-1)

        # Process image stream for attention
        img_modulated = self.img_norm1(img)
        img_modulated = modulate(img_modulated, shift=img_mod1_shift, scale=img_mod1_scale)
        img_q = self.img_attn_q(img_modulated)
        img_k = self.img_attn_k(img_modulated)
        img_v = self.img_attn_v(img_modulated)
        img_q = rearrange(img_q, "B L (H D) -> B L H D", H=self.heads_num)
        img_k = rearrange(img_k, "B L (H D) -> B L H D", H=self.heads_num)
        img_v = rearrange(img_v, "B L (H D) -> B L H D", H=self.heads_num)

        # Apply QK-Norm if enabled
        img_q = self.img_attn_q_norm(img_q).to(img_v)
        img_k = self.img_attn_k_norm(img_k).to(img_v)

        # Apply RoPE if provided
        if freqs_cis is not None:
            img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
            assert (
                img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
            ), f"img_qq: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}"
            img_q, img_k = img_qq, img_kk

        # Process text stream for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = modulate(txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale)
        txt_q = self.txt_attn_q(txt_modulated)
        txt_k = self.txt_attn_k(txt_modulated)
        txt_v = self.txt_attn_v(txt_modulated)
        txt_q = rearrange(txt_q, "B L (H D) -> B L H D", H=self.heads_num)
        txt_k = rearrange(txt_k, "B L (H D) -> B L H D", H=self.heads_num)
        txt_v = rearrange(txt_v, "B L (H D) -> B L H D", H=self.heads_num)

        # Apply QK-Norm if enabled
        txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
        txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)

        # Compute cross-modal attention
        attn = self.core_attn(
            (img_q, txt_q),
            (img_k, txt_k),
            (img_v, txt_v),
            text_mask=text_mask,
        )

        # Split attention outputs for image and text streams
        img_attn, txt_attn = (
            attn[:, : img_q.shape[1]].contiguous(),
            attn[:, img_q.shape[1] :].contiguous(),
        )

        # Apply attention projection and residual connection for image stream
        img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate)
        # Apply MLP and residual connection for image stream
        img = img + apply_gate(
            self.img_mlp(modulate(self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale)),
            gate=img_mod2_gate,
        )

        # Apply attention projection and residual connection for text stream
        txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate)
        # Apply MLP and residual connection for text stream
        txt = txt + apply_gate(
            self.txt_mlp(modulate(self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale)),
            gate=txt_mod2_gate,
        )
        return img, txt
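
# A minimal usage sketch for MMDoubleStreamBlock (hidden_size / heads_num below are
# illustrative assumptions, not values taken from this file):
#
#   block = MMDoubleStreamBlock(hidden_size=3072, heads_num=24, mlp_width_ratio=4.0)
#   img, txt = block(img, txt, vec, freqs_cis=freqs_cis, text_mask=text_mask)
#
# where img is [B, L_img, hidden_size], txt is [B, L_txt, hidden_size], vec is the
# conditioning embedding consumed by the ModulateDiT layers, and freqs_cis is the
# rotary-embedding input expected by apply_rotary_emb (optional here; RoPE is
# skipped when it is None).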


class MMSingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers for multimodal processing.
    """

    def __init__(
        self,
        hidden_size: int,
        heads_num: int,
        mlp_width_ratio: float = 4.0,
        mlp_act_type: str = "gelu_tanh",
        qk_norm: bool = True,
        qk_norm_type: str = "rms",
        qk_scale: float = None,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.deterministic = False
        self.hidden_size = hidden_size
        self.heads_num = heads_num
        head_dim = hidden_size // heads_num
        mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
        self.mlp_hidden_dim = mlp_hidden_dim
        self.scale = qk_scale or head_dim**-0.5

        # Separate linear layers for Q, K, V, and MLP input
        self.linear1_q = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
        self.linear1_k = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
        self.linear1_v = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
        self.linear1_mlp = nn.Linear(hidden_size, mlp_hidden_dim, **factory_kwargs)

        # Output projection layer
        self.linear2 = LinearWarpforSingle(hidden_size + mlp_hidden_dim, hidden_size, bias=True, **factory_kwargs)

        # QK normalization layers
        qk_norm_layer = get_norm_layer(qk_norm_type)
        self.q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )

        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.mlp_act = get_activation_layer(mlp_act_type)()
        self.modulation = ModulateDiT(
            hidden_size,
            factor=3,
            act_layer=get_activation_layer("silu"),
            **factory_kwargs,
        )
        self.core_attn = attention

    def enable_deterministic(self):
        self.deterministic = True

    def disable_deterministic(self):
        self.deterministic = False

    def forward(
        self,
        x: torch.Tensor,
        vec: torch.Tensor,
        txt_len: int,
        freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
        text_mask: torch.Tensor = None,
        cu_seqlens=None,
        max_s=None,
    ) -> torch.Tensor:
        # Extract modulation parameters
        mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)
        x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale)

        # Compute Q, K, V, and MLP input
        q = self.linear1_q(x_mod)
        k = self.linear1_k(x_mod)
        v = self.linear1_v(x_mod)
        q = rearrange(q, "B L (H D) -> B L H D", H=self.heads_num)
        k = rearrange(k, "B L (H D) -> B L H D", H=self.heads_num)
        v = rearrange(v, "B L (H D) -> B L H D", H=self.heads_num)
        mlp = self.linear1_mlp(x_mod)

        # Apply QK-Norm if enabled
        q = self.q_norm(q).to(v)
        k = self.k_norm(k).to(v)

        # Split into image and text sequences
        img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
        img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
        img_v, txt_v = v[:, :-txt_len, :, :], v[:, -txt_len:, :, :]

        # Apply RoPE to image sequence
        img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
        assert (
            img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
        ), f"img_qq: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}"
        img_q, img_k = img_qq, img_kk

        # Compute cross-modal attention
        attn = self.core_attn(
            (img_q, txt_q),
            (img_k, txt_k),
            (img_v, txt_v),
            text_mask=text_mask,
        )

        # Combine attention output with MLP activation and apply final projection
        output = self.linear2(attn, self.mlp_act(mlp))
        return x + apply_gate(output, gate=mod_gate)
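
# A minimal usage sketch for MMSingleStreamBlock (sizes are illustrative assumptions):
#
#   block = MMSingleStreamBlock(hidden_size=3072, heads_num=24)
#   x = torch.cat([img, txt], dim=1)   # single stream: image tokens first, text tokens last
#   x = block(x, vec, txt_len=txt.shape[1], freqs_cis=freqs_cis, text_mask=text_mask)
#
# Unlike MMDoubleStreamBlock, freqs_cis is needed here: RoPE is applied
# unconditionally to the image portion of the sequence before attention.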