Hunyuan-Avatar / hymm_sp /modules /models_audio.py
rahul7star's picture
Update hymm_sp/modules/models_audio.py
0c66538 verified
from typing import List, Tuple, Optional, Union, Dict
from einops import rearrange
import torch, os
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models import ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config
from flash_attn.flash_attn_interface import flash_attn_varlen_func
from .activation_layers import get_activation_layer
from .norm_layers import get_norm_layer
from .embed_layers import TimestepEmbedder, PatchEmbed, TextProjection
from .attn_layers import apply_rotary_emb
from .mlp_layers import MLP, MLPEmbedder, FinalLayer
from .modulate_layers import ModulateDiT, modulate, apply_gate
from .token_refiner import SingleTokenRefiner
from .audio_adapters import AudioProjNet2, PerceiverAttentionCA
from .parallel_states import (
nccl_info,
get_cu_seqlens,
get_sequence_parallel_state,
parallel_attention,
all_gather,
)
CPU_OFFLOAD = int(os.environ.get("CPU_OFFLOAD", 0))
DISABLE_SP = int(os.environ.get("DISABLE_SP", 0))
print(f'models: cpu_offload={CPU_OFFLOAD}, DISABLE_SP={DISABLE_SP}')
class DoubleStreamBlock(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
mlp_width_ratio: float,
mlp_act_type: str = 'gelu_tanh',
qk_norm: bool = True,
qk_norm_type: str = 'rms',
qkv_bias: bool = False,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.deterministic = False
self.num_heads = num_heads
head_dim = hidden_size // num_heads
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
self.img_mod = ModulateDiT(hidden_size, factor=6, act_layer=get_activation_layer("silu"), **factory_kwargs)
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.img_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
qk_norm_layer = get_norm_layer(qk_norm_type)
self.img_attn_q_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.img_attn_k_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.img_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.img_mlp = MLP(
hidden_size,
mlp_hidden_dim,
act_layer=get_activation_layer(mlp_act_type),
bias=True,
**factory_kwargs
)
self.txt_mod = ModulateDiT(hidden_size, factor=6, act_layer=get_activation_layer("silu"), **factory_kwargs)
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.txt_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
qk_norm_layer = get_norm_layer(qk_norm_type)
self.txt_attn_q_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.txt_attn_k_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.txt_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.txt_mlp = MLP(
hidden_size,
mlp_hidden_dim,
act_layer=get_activation_layer(mlp_act_type),
bias=True,
**factory_kwargs
)
def enable_deterministic(self):
self.deterministic = True
def disable_deterministic(self):
self.deterministic = False
def forward(
self,
img: torch.Tensor,
txt: torch.Tensor,
vec: torch.Tensor,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_kv: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
max_seqlen_kv: Optional[int] = None,
freqs_cis: tuple = None
) -> Tuple[torch.Tensor, torch.Tensor]:
img_mod1_shift, img_mod1_scale, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate = (
self.img_mod(vec).chunk(6, dim=-1)
)
txt_mod1_shift, txt_mod1_scale, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate = (
self.txt_mod(vec).chunk(6, dim=-1)
)
if CPU_OFFLOAD: torch.cuda.empty_cache()
# Prepare image for attention.
img_modulated = self.img_norm1(img)
img_modulated = modulate(img_modulated, shift=img_mod1_shift, scale=img_mod1_scale)
img_qkv = self.img_attn_qkv(img_modulated)
if CPU_OFFLOAD: torch.cuda.empty_cache()
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
# Apply QK-Norm if needed
img_q = self.img_attn_q_norm(img_q).to(img_v)
img_k = self.img_attn_k_norm(img_k).to(img_v)
if CPU_OFFLOAD: torch.cuda.empty_cache()
# Apply RoPE if needed.
if freqs_cis is not None:
img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
assert img_qq.shape == img_q.shape and img_kk.shape == img_k.shape, \
f'img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}'
img_q, img_k = img_qq, img_kk
# Prepare txt for attention.
txt_modulated = self.txt_norm1(txt)
txt_modulated = modulate(txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale)
if CPU_OFFLOAD: torch.cuda.empty_cache()
txt_qkv = self.txt_attn_qkv(txt_modulated)
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
# Apply QK-Norm if needed.
txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
if CPU_OFFLOAD: torch.cuda.empty_cache()
# Run actual attention.
q = torch.cat((img_q, txt_q), dim=1)
k = torch.cat((img_k, txt_k), dim=1)
v = torch.cat((img_v, txt_v), dim=1)
# Compute attention.
if CPU_OFFLOAD or DISABLE_SP:
assert cu_seqlens_q.shape[0] == 2 * img.shape[0] + 1
q, k, v = [
x.view(x.shape[0] * x.shape[1], *x.shape[2:])
for x in [q, k, v]
]
attn = flash_attn_varlen_func(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
)
attn = attn.view(img_k.shape[0], max_seqlen_q, -1).contiguous()
else:
attn, _ = parallel_attention(
(img_q, txt_q),
(img_k, txt_k),
(img_v, txt_v),
img_q_len=img_q.shape[1],
img_kv_len=img_k.shape[1],
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
)
img_attn, txt_attn = attn[:, :img.shape[1]], attn[:, img.shape[1]:]
if CPU_OFFLOAD: torch.cuda.empty_cache()
# Calculate the img bloks.
img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate)
img = img + apply_gate(self.img_mlp(modulate(self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale)), gate=img_mod2_gate)
if CPU_OFFLOAD: torch.cuda.empty_cache()
# Calculate the txt bloks.
txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate)
txt = txt + apply_gate(self.txt_mlp(modulate(self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale)), gate=txt_mod2_gate)
if CPU_OFFLOAD: torch.cuda.empty_cache()
return img, txt
class SingleStreamBlock(nn.Module):
"""
A DiT block with parallel linear layers as described in
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
"""
def __init__(
self,
hidden_size: int,
num_heads: int,
mlp_width_ratio: float = 4.0,
mlp_act_type: str = 'gelu_tanh',
qk_norm: bool = True,
qk_norm_type: str = 'rms',
qk_scale: float = None,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.deterministic = False
self.hidden_size = hidden_size
self.num_heads = num_heads
head_dim = hidden_size // num_heads
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
self.mlp_hidden_dim = mlp_hidden_dim
self.scale = qk_scale or head_dim**-0.5
# qkv and mlp_in
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim, **factory_kwargs)
# proj and mlp_out
self.linear2 = nn.Linear(hidden_size + mlp_hidden_dim, hidden_size, **factory_kwargs)
qk_norm_layer = get_norm_layer(qk_norm_type)
self.q_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.k_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.mlp_act = get_activation_layer(mlp_act_type)()
self.modulation = ModulateDiT(hidden_size, factor=3, act_layer=get_activation_layer("silu"), **factory_kwargs)
def enable_deterministic(self):
self.deterministic = True
def disable_deterministic(self):
self.deterministic = False
def forward(
self,
x: torch.Tensor,
vec: torch.Tensor,
txt_len: int,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_kv: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
max_seqlen_kv: Optional[int] = None,
freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
) -> torch.Tensor:
mod_shift, mod_scale, mod_gate = (
self.modulation(vec).chunk(3, dim=-1)
)
x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale)
if CPU_OFFLOAD: torch.cuda.empty_cache()
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
if CPU_OFFLOAD: torch.cuda.empty_cache()
# Apply QK-Norm if needed.
q = self.q_norm(q).to(v)
k = self.k_norm(k).to(v)
if CPU_OFFLOAD: torch.cuda.empty_cache()
# Apply RoPE if needed.
if freqs_cis is not None:
img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
assert img_qq.shape == img_q.shape and img_kk.shape == img_k.shape, \
f'img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}'
img_q, img_k = img_qq, img_kk
q = torch.cat((img_q, txt_q), dim=1)
k = torch.cat((img_k, txt_k), dim=1)
if CPU_OFFLOAD: torch.cuda.empty_cache()
# Compute attention.
if CPU_OFFLOAD or DISABLE_SP:
assert cu_seqlens_q.shape[0] == 2 * x.shape[0] + 1, f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, x.shape[0]:{x.shape[0]}"
# [b, s+l, a, d] -> [s+l, b, a, d]
q, k, v = [
x.view(x.shape[0] * x.shape[1], *x.shape[2:])
for x in [q, k, v]
]
attn = flash_attn_varlen_func(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
)
attn = attn.view(x.shape[0], max_seqlen_q, -1).contiguous()
else:
img_v, txt_v = v[:, :-txt_len, :, :], v[:, -txt_len:, :, :]
attn, _ = parallel_attention(
(img_q, txt_q),
(img_k, txt_k),
(img_v, txt_v),
img_q_len=img_q.shape[1],
img_kv_len=img_k.shape[1],
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
)
if CPU_OFFLOAD:
torch.cuda.empty_cache()
tmp = torch.cat((attn, self.mlp_act(mlp)), 2)
torch.cuda.empty_cache()
output = self.linear2(tmp)
else:
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
return x + apply_gate(output, gate=mod_gate)
class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
"""
HunyuanVideo Transformer backbone
Inherited from ModelMixin and ConfigMixin for compatibility with diffusers' sampler StableDiffusionPipeline.
Reference:
[1] Flux.1: https://github.com/black-forest-labs/flux
[2] MMDiT: http://arxiv.org/abs/2403.03206,
https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
"""
@register_to_config
def __init__(
self,
args,
patch_size: list = [1,2,2],
in_channels: int = 4, # Should be VAE.config.latent_channels.
out_channels: int = None,
hidden_size: int = 3072,
mlp_width_ratio: float = 4.0,
mlp_act_type: str = 'gelu_tanh',
num_heads: int = 24,
depth_double_blocks: int = 19,
depth_single_blocks: int = 38,
rope_dim_list: List[int] = [16, 56, 56],
qkv_bias: bool = True,
qk_norm: bool = True,
qk_norm_type: str = 'rms',
guidance_embed: bool = False, # For modulation.
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
# Text projection. Default to linear projection.
# Alternative: TokenRefiner. See more details (LI-DiT): http://arxiv.org/abs/2406.11831
self.text_projection = args.text_projection
self.text_states_dim = args.text_states_dim
self.use_attention_mask = args.use_attention_mask
self.text_states_dim_2 = args.text_states_dim_2
# Now we only use above configs from args.
self.patch_size = patch_size
self.in_channels = in_channels
self.out_channels = in_channels if out_channels is None else out_channels
self.unpatchify_channels = self.out_channels
self.guidance_embed = guidance_embed
self.rope_dim_list = rope_dim_list
if hidden_size % num_heads != 0:
raise ValueError(
f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}"
)
pe_dim = hidden_size // num_heads
if sum(rope_dim_list) != pe_dim:
raise ValueError(f"Got {rope_dim_list} but expected positional dim {pe_dim}")
self.hidden_size = hidden_size
self.num_heads = num_heads
# image projection
self.img_in = PatchEmbed(
self.patch_size, self.in_channels, self.hidden_size, **factory_kwargs
)
self.ref_in = PatchEmbed(
self.patch_size, self.in_channels, self.hidden_size, **factory_kwargs
)
# text projection
if self.text_projection == "linear":
self.txt_in = TextProjection(
self.text_states_dim,
self.hidden_size,
get_activation_layer("silu"),
**factory_kwargs
)
elif self.text_projection == "single_refiner":
self.txt_in = SingleTokenRefiner(
self.text_states_dim, hidden_size, num_heads, depth=2, **factory_kwargs
)
else:
raise NotImplementedError(f"Unsupported text_projection: {self.text_projection}")
# time modulation
self.time_in = TimestepEmbedder(
self.hidden_size, get_activation_layer("silu"), **factory_kwargs
)
# text modulation
self.vector_in = MLPEmbedder(
self.text_states_dim_2, self.hidden_size, **factory_kwargs
)
# guidance modulation
self.guidance_in = TimestepEmbedder(
self.hidden_size, get_activation_layer("silu"), **factory_kwargs
) if guidance_embed else None
# double blocks
self.double_blocks = nn.ModuleList(
[
DoubleStreamBlock(
self.hidden_size,
self.num_heads,
mlp_width_ratio=mlp_width_ratio,
mlp_act_type=mlp_act_type,
qk_norm=qk_norm,
qk_norm_type=qk_norm_type,
qkv_bias=qkv_bias,
**factory_kwargs
)
for _ in range(depth_double_blocks)
]
)
# single blocks
self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(
self.hidden_size,
self.num_heads,
mlp_width_ratio=mlp_width_ratio,
mlp_act_type=mlp_act_type,
qk_norm=qk_norm,
qk_norm_type=qk_norm_type,
**factory_kwargs
)
for _ in range(depth_single_blocks)
]
)
self.final_layer = FinalLayer(
self.hidden_size,
self.patch_size,
self.out_channels,
get_activation_layer("silu"),
**factory_kwargs
)
# -------------------- audio_proj_model --------------------
self.audio_proj = AudioProjNet2(seq_len=10, blocks=5, channels=384, intermediate_dim=1024, output_dim=3072, context_tokens=4)
# -------------------- motion-embeder --------------------
self.motion_exp = TimestepEmbedder(
self.hidden_size // 4,
get_activation_layer("silu"),
**factory_kwargs
)
self.motion_pose = TimestepEmbedder(
self.hidden_size // 4,
get_activation_layer("silu"),
**factory_kwargs
)
self.fps_proj = TimestepEmbedder(
self.hidden_size,
get_activation_layer("silu"),
**factory_kwargs
)
self.before_proj = nn.Linear(self.hidden_size, self.hidden_size)
# -------------------- audio_insert_model --------------------
self.double_stream_list = [1, 3, 5, 7, 9, 11, 13, 15, 17, 19]
self.single_stream_list = []
self.double_stream_map = {str(i): j for j, i in enumerate(self.double_stream_list)}
self.single_stream_map = {str(i): j+len(self.double_stream_list) for j, i in enumerate(self.single_stream_list)}
self.audio_adapter_blocks = nn.ModuleList([
PerceiverAttentionCA(dim=3072, dim_head=1024, heads=33) for _ in range(len(self.double_stream_list) + len(self.single_stream_list))
])
def enable_deterministic(self):
for block in self.double_blocks:
block.enable_deterministic()
for block in self.single_blocks:
block.enable_deterministic()
def disable_deterministic(self):
for block in self.double_blocks:
block.disable_deterministic()
for block in self.single_blocks:
block.disable_deterministic()
def forward(
self,
x: torch.Tensor,
t: torch.Tensor, # Should be in range(0, 1000).
ref_latents: torch.Tensor=None,
text_states: torch.Tensor = None,
text_mask: torch.Tensor = None, # Now we don't use it.
text_states_2: Optional[torch.Tensor] = None, # Text embedding for modulation.
freqs_cos: Optional[torch.Tensor] = None,
freqs_sin: Optional[torch.Tensor] = None,
guidance: torch.Tensor = None, # Guidance for modulation, should be cfg_scale x 1000.
return_dict: bool = True,
is_cache: bool = False,
**additional_kwargs,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
out = {}
img = x
txt = text_states
bsz, _, ot, oh, ow = x.shape
tt, th, tw = ot // self.patch_size[0], oh // self.patch_size[1], ow // self.patch_size[2]
# Prepare modulation vectors.
vec = self.time_in(t)
motion_exp_vec = self.motion_exp(additional_kwargs["motion_exp"].view(-1)).view(x.shape[0], -1) # (b, 3072)
vec = vec + motion_exp_vec
motion_pose_vec = self.motion_pose(additional_kwargs["motion_pose"].view(-1)).view(x.shape[0], -1) # (b, 3072)
vec = vec + motion_pose_vec
fps_vec = self.fps_proj(additional_kwargs["fps"]) # (b, 3072)
vec = vec + fps_vec
audio_feature_all = self.audio_proj(additional_kwargs["audio_prompts"])
# text modulation
vec = vec + self.vector_in(text_states_2)
# guidance modulation
if self.guidance_embed:
if guidance is None:
raise ValueError("Didn't get guidance strength for guidance distilled model.")
else:
# our timestep_embedding is merged into guidance_in(TimestepEmbedder)
vec = vec + self.guidance_in(guidance)
if CPU_OFFLOAD: torch.cuda.empty_cache()
# Embed image and text.
ref_latents_first = ref_latents[:, :, :1].clone()
img, shape_mask = self.img_in(img)
ref_latents,_ = self.ref_in(ref_latents)
ref_latents_first,_ = self.img_in(ref_latents_first)
if self.text_projection == "linear":
txt = self.txt_in(txt)
elif self.text_projection == "single_refiner":
# [b, l, h]
txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None)
else:
raise NotImplementedError(f"Unsupported text_projection: {self.text_projection}")
img = self.before_proj(ref_latents) + img
if CPU_OFFLOAD: torch.cuda.empty_cache()
ref_length = ref_latents_first.shape[-2] # [b s c]
img = torch.cat([ref_latents_first, img], dim=-2) # t c
img_len = img.shape[1]
mask_len = img_len - ref_length
if additional_kwargs["face_mask"].shape[2] == 1:
face_mask = additional_kwargs["face_mask"].repeat(1,1,ot,1,1) # repeat if number of mask frame is 1
else:
face_mask = additional_kwargs["face_mask"]
face_mask = torch.nn.functional.interpolate(face_mask, size=[ot, shape_mask[-2], shape_mask[-1]], mode="nearest")
face_mask = face_mask.view(-1,mask_len,1).repeat(1,1,img.shape[-1]).type_as(img)
txt_seq_len = txt.shape[1]
img_seq_len = img.shape[1]
cu_seqlens_q = get_cu_seqlens(text_mask, img_seq_len)
cu_seqlens_kv = cu_seqlens_q
max_seqlen_q = img_seq_len + txt_seq_len
max_seqlen_kv = max_seqlen_q
if get_sequence_parallel_state():
sp_size = nccl_info.sp_size
sp_rank = nccl_info.rank_within_group
assert img.shape[1] % sp_size == 0, f"Cannot split video sequence into ulysses SP ({sp_size}) parts evenly"
img = torch.chunk(img, sp_size, dim=1)[sp_rank]
freqs_cos = torch.chunk(freqs_cos, sp_size, dim=0)[sp_rank]
freqs_sin = torch.chunk(freqs_sin, sp_size, dim=0)[sp_rank]
if CPU_OFFLOAD: torch.cuda.empty_cache()
freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
# --------------------- Pass through DiT blocks ------------------------
if not is_cache:
for layer_num, block in enumerate(self.double_blocks):
double_block_args = [img, txt, vec, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, freqs_cis]
img, txt = block(*double_block_args)
if CPU_OFFLOAD: torch.cuda.empty_cache()
""" insert audio feature to img """
if layer_num in self.double_stream_list:
if get_sequence_parallel_state():
img = all_gather(img, dim=1)
real_img = img[:,ref_length:].clone().view(bsz, ot, -1, 3072)
real_ref_img = torch.zeros_like(img[:,:ref_length].clone())
audio_feature_pad = audio_feature_all[:,:1].repeat(1,3,1,1)
audio_feature_all_insert = torch.cat([audio_feature_pad, audio_feature_all], dim=1).view(bsz, ot, 16, 3072)
double_idx = self.double_stream_map[str(layer_num)]
real_img = self.audio_adapter_blocks[double_idx](audio_feature_all_insert, real_img).view(bsz, -1, 3072)
img = img + torch.cat((real_ref_img, real_img * face_mask), dim=1)
if get_sequence_parallel_state():
sp_size = nccl_info.sp_size
sp_rank = nccl_info.rank_within_group
assert img.shape[1] % sp_size == 0, f"Cannot split video sequence into ulysses SP ({sp_size}) parts evenly"
img = torch.chunk(img, sp_size, dim=1)[sp_rank]
# Merge txt and img to pass through single stream blocks.
x = torch.cat((img, txt), 1)
# Compatible with MMDiT.
if len(self.single_blocks) > 0:
for layer_num, block in enumerate(self.single_blocks):
if layer_num == (len(self.single_blocks) - 1):
# self.cache_out = x
tmp = x[:, :-txt_seq_len, ...]
if get_sequence_parallel_state():
tmp = all_gather(tmp, dim=1)
self.cache_out = torch.cat([tmp, x[:, -txt_seq_len:, ...]], dim=1)
single_block_args = [x, vec, txt_seq_len, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, (freqs_cos, freqs_sin)]
x = block(*single_block_args)
if CPU_OFFLOAD: torch.cuda.empty_cache()
else:
if get_sequence_parallel_state():
sp_size = nccl_info.sp_size
sp_rank = nccl_info.rank_within_group
tmp, txt = self.cache_out[:, :-txt_seq_len], self.cache_out[:, -txt_seq_len:]
tmp = torch.chunk(tmp, sp_size, dim=1)[sp_rank]
x = torch.cat([tmp, txt], dim=1)
else:
x = self.cache_out
if len(self.single_blocks) > 0:
for layer_num, block in enumerate(self.single_blocks):
if layer_num < (len(self.single_blocks) - 1):
continue
single_block_args = [x, vec, txt_seq_len, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, (freqs_cos, freqs_sin)]
x = block(*single_block_args)
if CPU_OFFLOAD: torch.cuda.empty_cache()
img = x[:, :-txt_seq_len, ...]
if get_sequence_parallel_state():
img = all_gather(img, dim=1)
img = img[:, ref_length:]
# ---------------------------- Final layer ------------------------------
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
img = self.unpatchify(img, tt, th, tw)
if return_dict:
out['x'] = img
return out
return img
def unpatchify(self, x, t, h, w):
"""
x: (N, T, patch_size**2 * C)
imgs: (N, H, W, C)
"""
c = self.unpatchify_channels
pt, ph, pw = self.patch_size
assert t * h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw))
x = torch.einsum('nthwcopq->nctohpwq', x)
imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
return imgs
def params_count(self):
counts = {
"double": sum([
sum(p.numel() for p in block.img_attn_qkv.parameters()) +
sum(p.numel() for p in block.img_attn_proj.parameters()) +
sum(p.numel() for p in block.img_mlp.parameters()) +
sum(p.numel() for p in block.txt_attn_qkv.parameters()) +
sum(p.numel() for p in block.txt_attn_proj.parameters()) +
sum(p.numel() for p in block.txt_mlp.parameters())
for block in self.double_blocks
]),
"single": sum([
sum(p.numel() for p in block.linear1.parameters()) +
sum(p.numel() for p in block.linear2.parameters())
for block in self.single_blocks
]),
"total": sum(p.numel() for p in self.parameters()),
}
counts["attn+mlp"] = counts["double"] + counts["single"]
return counts
#################################################################################
# HunyuanVideo Configs #
#################################################################################
HUNYUAN_VIDEO_CONFIG = { # Attn+MLP / Total
'HYVideo-T/2': { # 9.0B / 12.5B
'depth_double_blocks': 20,
'depth_single_blocks': 40,
'rope_dim_list': [16, 56, 56],
'hidden_size': 3072,
'num_heads': 24,
'mlp_width_ratio': 4,
},
}