Spaces:
Running
on
A100
Running
on
A100
File size: 13,444 Bytes
43c5292 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 |
from typing import Optional, Tuple
import torch
import torch.nn as nn
from einops import rearrange
from hyimage.models.hunyuan.modules.flash_attn_no_pad import flash_attn_no_pad
from .activation_layers import get_activation_layer
from .mlp_layers import MLP, LinearWarpforSingle
from .modulate_layers import ModulateDiT, apply_gate, modulate
from .norm_layers import get_norm_layer
from .posemb_layers import apply_rotary_emb
@torch.compiler.disable
def attention(
    q,
    k,
    v,
    attn_mode="flash",
    text_mask=None,
):
    """Multi-modal attention over concatenated image and text sequences.

    Args:
        q, k, v: Pairs ``(image_tensor, text_tensor)``. Each tensor is indexed
            ``(batch, seq, ...)``; the image and text parts are concatenated
            along dim 1 (image first) before attention. Presumably
            ``(B, L, heads, head_dim)`` - TODO confirm against callers.
        attn_mode: Attention backend. Only ``"flash"`` is implemented.
        text_mask: Key-padding mask for the text tokens, indexed
            ``(batch, text_seq)``. It is left-padded with ``True`` so image
            tokens are always attended to. ``None`` treats every token as valid.

    Returns:
        Attention output for the joint sequence (image tokens first, then text
        tokens) with all dims after the sequence dim flattened:
        ``(batch, image_seq + text_seq, -1)``.

    Raises:
        NotImplementedError: if ``attn_mode`` is not ``"flash"``.
    """
    if attn_mode != "flash":
        # Only flash attention is implemented for now.
        raise NotImplementedError(f"attn_mode={attn_mode!r} is not supported; only 'flash' is implemented")
    query, encoder_query = q
    key, encoder_key = k
    value, encoder_value = v
    sequence_length = query.size(1)

    # Joint sequence: image tokens first, text tokens second (dim 1).
    query = torch.cat([query, encoder_query], dim=1)
    key = torch.cat([key, encoder_key], dim=1)
    value = torch.cat([value, encoder_value], dim=1)
    # Stack query, key, value on a new dim 2: (B, S, 3, ...).
    qkv = torch.stack([query, key, value], dim=2)

    if text_mask is None:
        # No mask supplied: every image and text token is a valid key.
        attn_mask = torch.ones(query.size(0), query.size(1), dtype=torch.bool, device=query.device)
    else:
        # Left-pad with True so the image tokens are never masked out.
        attn_mask = torch.nn.functional.pad(text_mask, (sequence_length, 0), value=True)

    hidden_states = flash_attn_no_pad(qkv, attn_mask, causal=False, dropout_p=0.0, softmax_scale=None)
    # Image and text outputs receive the same cast, so cast the whole joint
    # sequence once rather than splitting and re-concatenating it.
    hidden_states = hidden_states.to(query.dtype)
    batch_size, total_length = hidden_states.shape[:2]
    # Merge every trailing dim (presumably heads x head_dim) into one feature dim.
    return hidden_states.reshape(batch_size, total_length, -1)
class MMDoubleStreamBlock(nn.Module):
    """
    A multimodal DiT block with separate modulation for text and image/video.

    Image and text tokens each have their own modulation (``img_mod`` /
    ``txt_mod``), Q/K/V projections, QK-norms, attention output projection,
    LayerNorms and MLP. The two streams interact only through one joint
    attention call (``core_attn``) over the concatenated token sequence.
    """

    def __init__(
        self,
        hidden_size: int,
        heads_num: int,
        mlp_width_ratio: float,
        mlp_act_type: str = "gelu_tanh",
        qk_norm: bool = True,
        qk_norm_type: str = "rms",
        qkv_bias: bool = False,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        # Flag toggled by enable_/disable_deterministic; not read inside this class.
        self.deterministic = False
        self.heads_num = heads_num
        head_dim = hidden_size // heads_num
        mlp_hidden_dim = int(hidden_size * mlp_width_ratio)

        # Image stream components
        self.img_mod = ModulateDiT(
            hidden_size,
            factor=6,
            act_layer=get_activation_layer("silu"),
            **factory_kwargs,
        )
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.img_attn_q = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
        self.img_attn_k = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
        self.img_attn_v = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
        qk_norm_layer = get_norm_layer(qk_norm_type)
        self.img_attn_q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.img_attn_k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.img_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.img_mlp = MLP(
            hidden_size,
            mlp_hidden_dim,
            act_layer=get_activation_layer(mlp_act_type),
            bias=True,
            **factory_kwargs,
        )

        # Text stream components
        self.txt_mod = ModulateDiT(
            hidden_size,
            factor=6,
            act_layer=get_activation_layer("silu"),
            **factory_kwargs,
        )
        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.txt_attn_q = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
        self.txt_attn_k = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
        self.txt_attn_v = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
        self.txt_attn_q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.txt_attn_k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.txt_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.txt_mlp = MLP(
            hidden_size,
            mlp_hidden_dim,
            act_layer=get_activation_layer(mlp_act_type),
            bias=True,
            **factory_kwargs,
        )

        # Joint image+text attention function.
        self.core_attn = attention

    def enable_deterministic(self):
        """Set the ``deterministic`` flag to True."""
        self.deterministic = True

    def disable_deterministic(self):
        """Set the ``deterministic`` flag to False."""
        self.deterministic = False

    def _project_qkv(self, x, q_proj, k_proj, v_proj, q_norm, k_norm):
        """Project ``x`` to per-head Q/K/V and apply QK-norm.

        Each projection is reshaped ``B L (H D) -> B L H D`` with
        ``H = self.heads_num``. Q and K are cast to V's dtype and device after
        normalization.
        """
        q = rearrange(q_proj(x), "B L (H D) -> B L H D", H=self.heads_num)
        k = rearrange(k_proj(x), "B L (H D) -> B L H D", H=self.heads_num)
        v = rearrange(v_proj(x), "B L (H D) -> B L H D", H=self.heads_num)
        q = q_norm(q).to(v)
        k = k_norm(k).to(v)
        return q, k, v

    def forward(
        self,
        img: torch.Tensor,
        txt: torch.Tensor,
        vec: torch.Tensor,
        freqs_cis: Optional[tuple] = None,
        text_mask: Optional[torch.Tensor] = None,
        cu_seqlens=None,
        max_s=None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run one double-stream block.

        Args:
            img: Image tokens ``(B, L_img, hidden_size)``. The last dim feeds
                ``hidden_size``-input Linear layers.
            txt: Text tokens ``(B, L_txt, hidden_size)``.
            vec: Conditioning input to ``img_mod`` / ``txt_mod``. Each output
                is chunked into six modulation tensors along the last dim.
            freqs_cis: Rotary embedding frequencies for the image tokens. RoPE
                is skipped when ``None``; text tokens never receive RoPE here.
            text_mask: Key-padding mask for the text tokens, forwarded to
                ``core_attn``.
            cu_seqlens, max_s: Unused; accepted for call-site compatibility.

        Returns:
            ``(img, txt)`` after the gated attention and gated MLP residual
            updates.
        """
        # Extract modulation parameters for image and text streams
        (
            img_mod1_shift,
            img_mod1_scale,
            img_mod1_gate,
            img_mod2_shift,
            img_mod2_scale,
            img_mod2_gate,
        ) = self.img_mod(vec).chunk(6, dim=-1)
        (
            txt_mod1_shift,
            txt_mod1_scale,
            txt_mod1_gate,
            txt_mod2_shift,
            txt_mod2_scale,
            txt_mod2_gate,
        ) = self.txt_mod(vec).chunk(6, dim=-1)

        # Image stream: norm -> modulate -> Q/K/V (+ QK-norm) -> optional RoPE
        img_modulated = modulate(self.img_norm1(img), shift=img_mod1_shift, scale=img_mod1_scale)
        img_q, img_k, img_v = self._project_qkv(
            img_modulated,
            self.img_attn_q,
            self.img_attn_k,
            self.img_attn_v,
            self.img_attn_q_norm,
            self.img_attn_k_norm,
        )
        if freqs_cis is not None:
            img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
            assert (
                img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
            ), f"img_qq: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}"
            img_q, img_k = img_qq, img_kk

        # Text stream: norm -> modulate -> Q/K/V (+ QK-norm); no RoPE
        txt_modulated = modulate(self.txt_norm1(txt), shift=txt_mod1_shift, scale=txt_mod1_scale)
        txt_q, txt_k, txt_v = self._project_qkv(
            txt_modulated,
            self.txt_attn_q,
            self.txt_attn_k,
            self.txt_attn_v,
            self.txt_attn_q_norm,
            self.txt_attn_k_norm,
        )

        # Joint attention over the concatenated (image, text) sequence
        attn = self.core_attn(
            (img_q, txt_q),
            (img_k, txt_k),
            (img_v, txt_v),
            text_mask=text_mask,
        )

        # core_attn returns image tokens first, then text tokens
        img_len = img_q.shape[1]
        img_attn = attn[:, :img_len].contiguous()
        txt_attn = attn[:, img_len:].contiguous()

        # Image stream: gated attention residual, then gated MLP residual
        img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate)
        img = img + apply_gate(
            self.img_mlp(modulate(self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale)),
            gate=img_mod2_gate,
        )

        # Text stream: gated attention residual, then gated MLP residual
        txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate)
        txt = txt + apply_gate(
            self.txt_mlp(modulate(self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale)),
            gate=txt_mod2_gate,
        )
        return img, txt
class MMSingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers for multimodal processing.

    A single token sequence (image tokens followed by ``txt_len`` text
    tokens) is normalized and modulated once. Separate Linear layers then
    produce Q, K, V and an MLP input from that same modulated tensor. The
    attention output and the activated MLP branch are combined by ``linear2``
    (input width ``hidden_size + mlp_hidden_dim``) into a single gated
    residual update.
    """

    def __init__(
        self,
        hidden_size: int,
        heads_num: int,
        mlp_width_ratio: float = 4.0,
        mlp_act_type: str = "gelu_tanh",
        qk_norm: bool = True,
        qk_norm_type: str = "rms",
        qk_scale: Optional[float] = None,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        # Flag toggled by enable_/disable_deterministic; not read inside this class.
        self.deterministic = False
        self.hidden_size = hidden_size
        self.heads_num = heads_num
        head_dim = hidden_size // heads_num
        mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
        self.mlp_hidden_dim = mlp_hidden_dim
        # Stored for external use; core_attn below is called with softmax_scale=None.
        self.scale = qk_scale or head_dim**-0.5

        # Separate linear layers for Q, K, V, and MLP input
        self.linear1_q = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
        self.linear1_k = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
        self.linear1_v = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
        self.linear1_mlp = nn.Linear(hidden_size, mlp_hidden_dim, **factory_kwargs)

        # Output projection over the attention and MLP branches
        self.linear2 = LinearWarpforSingle(hidden_size + mlp_hidden_dim, hidden_size, bias=True, **factory_kwargs)

        # QK normalization layers
        qk_norm_layer = get_norm_layer(qk_norm_type)
        self.q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )

        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.mlp_act = get_activation_layer(mlp_act_type)()
        self.modulation = ModulateDiT(
            hidden_size,
            factor=3,
            act_layer=get_activation_layer("silu"),
            **factory_kwargs,
        )
        # Joint image+text attention function.
        self.core_attn = attention

    def enable_deterministic(self):
        """Set the ``deterministic`` flag to True."""
        self.deterministic = True

    def disable_deterministic(self):
        """Set the ``deterministic`` flag to False."""
        self.deterministic = False

    def forward(
        self,
        x: torch.Tensor,
        vec: torch.Tensor,
        txt_len: int,
        freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        text_mask: Optional[torch.Tensor] = None,
        cu_seqlens=None,
        max_s=None,
    ) -> torch.Tensor:
        """Run one single-stream block.

        Args:
            x: Joint token sequence ``(B, L, hidden_size)``. The last
                ``txt_len`` tokens are text; the rest are image tokens.
            vec: Conditioning input to ``modulation``. Its output is chunked
                into shift, scale and gate along the last dim.
            txt_len: Number of trailing text tokens in ``x``. ``0`` means the
                whole sequence is image tokens. Values larger than the sequence
                length treat the whole sequence as text.
            freqs_cis: Rotary embedding frequencies for the image tokens. RoPE
                is skipped when ``None``; text tokens never receive RoPE here.
            text_mask: Key-padding mask for the text tokens, forwarded to
                ``core_attn``.
            cu_seqlens, max_s: Unused; accepted for call-site compatibility.

        Returns:
            ``x`` plus the gated output of ``linear2``.
        """
        # Extract modulation parameters
        mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)
        x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale)

        # Compute Q, K, V (per head) and the MLP branch input
        q = rearrange(self.linear1_q(x_mod), "B L (H D) -> B L H D", H=self.heads_num)
        k = rearrange(self.linear1_k(x_mod), "B L (H D) -> B L H D", H=self.heads_num)
        v = rearrange(self.linear1_v(x_mod), "B L (H D) -> B L H D", H=self.heads_num)
        mlp = self.linear1_mlp(x_mod)

        # Apply QK-Norm if enabled; cast Q/K to V's dtype and device
        q = self.q_norm(q).to(v)
        k = self.k_norm(k).to(v)

        # Split into image and text sequences. Use an explicit split point:
        # negative slicing with txt_len == 0 (`:-0` / `-0:`) would give an
        # empty image part and put the whole sequence on the text side.
        img_len = max(q.shape[1] - txt_len, 0)
        img_q, txt_q = q[:, :img_len], q[:, img_len:]
        img_k, txt_k = k[:, :img_len], k[:, img_len:]
        img_v, txt_v = v[:, :img_len], v[:, img_len:]

        # Apply RoPE to the image sequence when provided
        if freqs_cis is not None:
            img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
            assert (
                img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
            ), f"img_qq: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}"
            img_q, img_k = img_qq, img_kk

        # Joint attention over the (image, text) sequence
        attn = self.core_attn(
            (img_q, txt_q),
            (img_k, txt_k),
            (img_v, txt_v),
            text_mask=text_mask,
        )

        # Combine attention output with MLP activation and apply final projection
        output = self.linear2(attn, self.mlp_act(mlp))
        return x + apply_gate(output, gate=mod_gate)
|