SeedVR2-3B / models /dit /blocks /mmdit_window_block.py
IceClear
upload files
42f2c22
# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# // http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.
from typing import Tuple, Union
import torch
from einops import rearrange
from torch import nn
from torch.nn import functional as F
from torch.nn.modules.utils import _triple
from common.distributed.ops import (
gather_heads,
gather_heads_scatter_seq,
gather_seq_scatter_heads_qkv,
scatter_heads,
)
from ..attention import TorchAttention
from ..mlp import get_mlp
from ..mm import MMArg, MMModule
from ..modulation import ada_layer_type
from ..normalization import norm_layer_type
from ..rope import RotaryEmbedding3d
class MMWindowAttention(nn.Module):
def __init__(
self,
vid_dim: int,
txt_dim: int,
heads: int,
head_dim: int,
qk_bias: bool,
qk_rope: bool,
qk_norm: norm_layer_type,
qk_norm_eps: float,
window: Union[int, Tuple[int, int, int]],
window_method: str,
shared_qkv: bool,
):
super().__init__()
dim = MMArg(vid_dim, txt_dim)
inner_dim = heads * head_dim
qkv_dim = inner_dim * 3
self.window = _triple(window)
self.window_method = window_method
assert all(map(lambda v: isinstance(v, int) and v >= 0, self.window))
self.head_dim = head_dim
self.proj_qkv = MMModule(nn.Linear, dim, qkv_dim, bias=qk_bias, shared_weights=shared_qkv)
self.proj_out = MMModule(nn.Linear, inner_dim, dim, shared_weights=shared_qkv)
self.norm_q = MMModule(qk_norm, dim=head_dim, eps=qk_norm_eps, elementwise_affine=True)
self.norm_k = MMModule(qk_norm, dim=head_dim, eps=qk_norm_eps, elementwise_affine=True)
self.rope = RotaryEmbedding3d(dim=head_dim // 2) if qk_rope else None
self.attn = TorchAttention()
def forward(
self,
vid: torch.FloatTensor, # b T H W c
txt: torch.FloatTensor, # b L c
txt_mask: torch.BoolTensor, # b L
) -> Tuple[
torch.FloatTensor,
torch.FloatTensor,
]:
# Project q, k, v.
vid_qkv, txt_qkv = self.proj_qkv(vid, txt)
vid_qkv = gather_seq_scatter_heads_qkv(vid_qkv, seq_dim=2)
_, T, H, W, _ = vid_qkv.shape
_, L, _ = txt.shape
if self.window_method == "win":
nt, nh, nw = self.window
tt, hh, ww = T // nt, H // nh, W // nw
elif self.window_method == "win_by_size":
tt, hh, ww = self.window
tt, hh, ww = (
tt if tt > 0 else T,
hh if hh > 0 else H,
ww if ww > 0 else W,
)
nt, nh, nw = T // tt, H // hh, W // ww
else:
raise NotImplementedError
vid_qkv = rearrange(vid_qkv, "b T H W (o h d) -> o b h (T H W) d", o=3, d=self.head_dim)
txt_qkv = rearrange(txt_qkv, "b L (o h d) -> o b h L d", o=3, d=self.head_dim)
txt_qkv = scatter_heads(txt_qkv, dim=2)
vid_q, vid_k, vid_v = vid_qkv.unbind()
txt_q, txt_k, txt_v = txt_qkv.unbind()
vid_q, txt_q = self.norm_q(vid_q, txt_q)
vid_k, txt_k = self.norm_k(vid_k, txt_k)
if self.rope:
vid_q, vid_k = self.rope(vid_q, vid_k, (T, H, W))
def vid_window(v):
return rearrange(
v,
"b h (nt tt nh hh nw ww) d -> b h (nt nh nw) (tt hh ww) d",
hh=hh,
ww=ww,
tt=tt,
nh=nh,
nw=nw,
nt=nt,
)
def txt_window(t):
return rearrange(t, "b h L d -> b h 1 L d").expand(-1, -1, nt * nh * nw, -1, -1)
# Process video attention.
vid_msk = F.pad(txt_mask, (tt * hh * ww, 0), value=True)
vid_msk = rearrange(vid_msk, "b l -> b 1 1 1 l").expand(-1, 1, 1, tt * hh * ww, -1)
vid_out = self.attn(
vid_window(vid_q),
torch.cat([vid_window(vid_k), txt_window(txt_k)], dim=-2),
torch.cat([vid_window(vid_v), txt_window(txt_v)], dim=-2),
vid_msk,
)
vid_out = rearrange(
vid_out,
"b h (nt nh nw) (tt hh ww) d -> b (nt tt) (nh hh) (nw ww) (h d)",
hh=hh,
ww=ww,
tt=tt,
nh=nh,
nw=nw,
)
vid_out = gather_heads_scatter_seq(vid_out, head_dim=4, seq_dim=2)
# Process text attention.
txt_msk = F.pad(txt_mask, (T * H * W, 0), value=True)
txt_msk = rearrange(txt_msk, "b l -> b 1 1 l").expand(-1, 1, L, -1)
txt_out = self.attn(
txt_q,
torch.cat([vid_k, txt_k], dim=-2),
torch.cat([vid_v, txt_v], dim=-2),
txt_msk,
)
txt_out = rearrange(txt_out, "b h L d -> b L (h d)")
txt_out = gather_heads(txt_out, dim=2)
# Project output.
vid_out, txt_out = self.proj_out(vid_out, txt_out)
return vid_out, txt_out
class MMWindowTransformerBlock(nn.Module):
def __init__(
self,
*,
vid_dim: int,
txt_dim: int,
emb_dim: int,
heads: int,
head_dim: int,
expand_ratio: int,
norm: norm_layer_type,
norm_eps: float,
ada: ada_layer_type,
qk_bias: bool,
qk_rope: bool,
qk_norm: norm_layer_type,
window: Union[int, Tuple[int, int, int]],
window_method: str,
shared_qkv: bool,
shared_mlp: bool,
mlp_type: str,
**kwargs,
):
super().__init__()
dim = MMArg(vid_dim, txt_dim)
self.attn_norm = MMModule(norm, dim=dim, eps=norm_eps, elementwise_affine=False)
self.attn = MMWindowAttention(
vid_dim=vid_dim,
txt_dim=txt_dim,
heads=heads,
head_dim=head_dim,
qk_bias=qk_bias,
qk_rope=qk_rope,
qk_norm=qk_norm,
qk_norm_eps=norm_eps,
window=window,
window_method=window_method,
shared_qkv=shared_qkv,
)
self.mlp_norm = MMModule(norm, dim=dim, eps=norm_eps, elementwise_affine=False)
self.mlp = MMModule(
get_mlp(mlp_type),
dim=dim,
expand_ratio=expand_ratio,
shared_weights=shared_mlp,
)
self.ada = MMModule(ada, dim=dim, emb_dim=emb_dim, layers=["attn", "mlp"])
def forward(
self,
vid: torch.FloatTensor,
txt: torch.FloatTensor,
txt_mask: torch.BoolTensor,
emb: torch.FloatTensor,
) -> Tuple[
torch.FloatTensor,
torch.FloatTensor,
]:
vid_attn, txt_attn = self.attn_norm(vid, txt)
vid_attn, txt_attn = self.ada(vid_attn, txt_attn, emb=emb, layer="attn", mode="in")
vid_attn, txt_attn = self.attn(vid_attn, txt_attn, txt_mask=txt_mask)
vid_attn, txt_attn = self.ada(vid_attn, txt_attn, emb=emb, layer="attn", mode="out")
vid_attn, txt_attn = (vid_attn + vid), (txt_attn + txt)
vid_mlp, txt_mlp = self.mlp_norm(vid_attn, txt_attn)
vid_mlp, txt_mlp = self.ada(vid_mlp, txt_mlp, emb=emb, layer="mlp", mode="in")
vid_mlp, txt_mlp = self.mlp(vid_mlp, txt_mlp)
vid_mlp, txt_mlp = self.ada(vid_mlp, txt_mlp, emb=emb, layer="mlp", mode="out")
vid_mlp, txt_mlp = (vid_mlp + vid_attn), (txt_mlp + txt_attn)
return vid_mlp, txt_mlp