import os
from typing import Dict, List, Optional, Union
import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models import ModelMixin
from hyimage.models.hunyuan.modules.flash_attn_no_pad import get_cu_seqlens
from hyimage.models.hunyuan.modules.posemb_layers import get_nd_rotary_pos_embed
from hyimage.models.text_encoder.byT5 import ByT5Mapper

from .activation_layers import get_activation_layer
from .embed_layers import PatchEmbed, PatchEmbed2D, TextProjection, TimestepEmbedder
from .mlp_layers import FinalLayer
from .models import MMDoubleStreamBlock, MMSingleStreamBlock
from .token_refiner import SingleTokenRefiner


def convert_hunyuan_dict_for_tensor_parallel(state_dict):
    """
    Convert a Hunyuan model state dict to be compatible with tensor-parallel architectures.

    Fused projections are split into per-projection tensors: double-block
    ``attn_qkv`` entries become separate ``attn_q``/``attn_k``/``attn_v``
    entries, and single-block ``linear1`` entries become
    ``linear1_q``/``linear1_k``/``linear1_v``/``linear1_mlp`` entries.

    Args:
        state_dict: Original state dict.

    Returns:
        new_dict: Converted state dict.
    """
new_dict = {}
for k, w in state_dict.items():
if k.startswith("double_blocks") and "attn_qkv.weight" in k:
hidden_size = w.shape[1]
k1 = k.replace("attn_qkv.weight", "attn_q.weight")
w1 = w[:hidden_size, :]
new_dict[k1] = w1
k2 = k.replace("attn_qkv.weight", "attn_k.weight")
w2 = w[hidden_size : 2 * hidden_size, :]
new_dict[k2] = w2
k3 = k.replace("attn_qkv.weight", "attn_v.weight")
w3 = w[-hidden_size:, :]
new_dict[k3] = w3
elif k.startswith("double_blocks") and "attn_qkv.bias" in k:
hidden_size = w.shape[0] // 3
k1 = k.replace("attn_qkv.bias", "attn_q.bias")
w1 = w[:hidden_size]
new_dict[k1] = w1
k2 = k.replace("attn_qkv.bias", "attn_k.bias")
w2 = w[hidden_size : 2 * hidden_size]
new_dict[k2] = w2
k3 = k.replace("attn_qkv.bias", "attn_v.bias")
w3 = w[-hidden_size:]
new_dict[k3] = w3
elif k.startswith("single_blocks") and "linear1" in k:
hidden_size = state_dict[k.replace("linear1", "linear2")].shape[0]
k1 = k.replace("linear1", "linear1_q")
w1 = w[:hidden_size]
new_dict[k1] = w1
k2 = k.replace("linear1", "linear1_k")
w2 = w[hidden_size : 2 * hidden_size]
new_dict[k2] = w2
k3 = k.replace("linear1", "linear1_v")
w3 = w[2 * hidden_size : 3 * hidden_size]
new_dict[k3] = w3
k4 = k.replace("linear1", "linear1_mlp")
w4 = w[3 * hidden_size :]
new_dict[k4] = w4
elif k.startswith("single_blocks") and "linear2" in k:
k1 = k.replace("linear2", "linear2.fc")
new_dict[k1] = w
else:
new_dict[k] = w
return new_dict
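

# Example (sketch, hypothetical key name and a toy hidden size): splitting a fused
# double-block QKV weight of shape (3 * hidden, hidden) into Q/K/V pieces.
#
#     import torch
#     hidden = 8
#     sd = {"double_blocks.0.img_attn_qkv.weight": torch.randn(3 * hidden, hidden)}
#     tp_sd = convert_hunyuan_dict_for_tensor_parallel(sd)
#     assert tp_sd["double_blocks.0.img_attn_q.weight"].shape == (hidden, hidden)
#     assert tp_sd["double_blocks.0.img_attn_v.weight"].shape == (hidden, hidden)
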
def load_hunyuan_dit_state_dict(model, dit_model_name_or_path, strict=True, assign=False):
    """
    Load a state dict for a Hunyuan model, handling both safetensors and torch formats.

    Args:
        model: Model instance to load weights into.
        dit_model_name_or_path: Path to the checkpoint file.
        strict: Whether to strictly enforce that the keys in the state dict match the model's keys.
        assign: If True, assign weights directly without copying.

    Returns:
        model: The model with loaded weights.
    """
from safetensors.torch import load_file as safetensors_load_file
if not os.path.exists(dit_model_name_or_path):
raise FileNotFoundError(f"Checkpoint file not found: {dit_model_name_or_path}")
if dit_model_name_or_path.endswith(".safetensors"):
state_dict = safetensors_load_file(dit_model_name_or_path)
else:
state_dict = torch.load(
dit_model_name_or_path,
map_location="cpu",
weights_only=True,
)
    # Best-effort conversion for tensor-parallel layouts; checkpoints already in
    # the expected format are loaded unchanged.
    try:
        state_dict = convert_hunyuan_dict_for_tensor_parallel(state_dict)
    except Exception:
        pass
model.load_state_dict(state_dict, strict=strict, assign=assign)
return model
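

# Usage (sketch, hypothetical checkpoint path):
#
#     model = HYImageDiffusionTransformer()
#     model = load_hunyuan_dit_state_dict(model, "checkpoints/hunyuan_dit.safetensors")
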

class HYImageDiffusionTransformer(ModelMixin, ConfigMixin):
    """
    Multimodal diffusion transformer for Hunyuan image generation.

    Image and text tokens first pass through dual-stream blocks
    (``MMDoubleStreamBlock``), then are merged and processed by fused
    single-stream blocks (``MMSingleStreamBlock``), following the MMDiT design.
    """

@register_to_config
def __init__(
self,
patch_size: list = [1, 2, 2],
in_channels: int = 4,
        out_channels: Optional[int] = None,
hidden_size: int = 3072,
heads_num: int = 24,
mlp_width_ratio: float = 4.0,
mlp_act_type: str = "gelu_tanh",
mm_double_blocks_depth: int = 20,
mm_single_blocks_depth: int = 40,
rope_dim_list: List[int] = [16, 56, 56],
qkv_bias: bool = True,
qk_norm: bool = True,
qk_norm_type: str = "rms",
guidance_embed: bool = False,
text_projection: str = "single_refiner",
use_attention_mask: bool = True,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
text_states_dim: int = 4096,
rope_theta: int = 256,
glyph_byT5_v2: bool = False,
use_meanflow: bool = False,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.patch_size = patch_size
self.in_channels = in_channels
self.out_channels = in_channels if out_channels is None else out_channels
self.unpatchify_channels = self.out_channels
self.guidance_embed = guidance_embed
self.rope_dim_list = rope_dim_list
self.rope_theta = rope_theta
self.use_attention_mask = use_attention_mask
self.text_projection = text_projection
if hidden_size % heads_num != 0:
raise ValueError(f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}")
pe_dim = hidden_size // heads_num
        if sum(rope_dim_list) != pe_dim:
            raise ValueError(f"sum(rope_dim_list) must equal the per-head dim {pe_dim}, got {rope_dim_list}")
self.hidden_size = hidden_size
self.heads_num = heads_num
self.glyph_byT5_v2 = glyph_byT5_v2
        if self.glyph_byT5_v2:
            # Maps ByT5 glyph embeddings into the transformer's hidden space.
            self.byt5_in = ByT5Mapper(
                in_dim=1472,
                out_dim=2048,
                hidden_dim=2048,
                out_dim1=hidden_size,
                use_residual=False,
            )
# Image projection
if len(self.patch_size) == 3:
self.img_in = PatchEmbed(self.patch_size, self.in_channels, self.hidden_size, **factory_kwargs)
elif len(self.patch_size) == 2:
self.img_in = PatchEmbed2D(self.patch_size, self.in_channels, self.hidden_size, **factory_kwargs)
else:
raise ValueError(f"Unsupported patch_size: {self.patch_size}")
# Text projection
if self.text_projection == "linear":
self.txt_in = TextProjection(
text_states_dim,
self.hidden_size,
get_activation_layer("silu"),
**factory_kwargs,
)
elif self.text_projection == "single_refiner":
self.txt_in = SingleTokenRefiner(
text_states_dim,
hidden_size,
heads_num,
depth=2,
**factory_kwargs,
)
else:
raise NotImplementedError(f"Unsupported text_projection: {self.text_projection}")
# Time modulation
self.time_in = TimestepEmbedder(self.hidden_size, get_activation_layer("silu"), **factory_kwargs)
# MeanFlow support: only create time_r_in when needed
self.time_r_in = (
TimestepEmbedder(self.hidden_size, get_activation_layer("silu"), **factory_kwargs)
if use_meanflow
else None
)
self.use_meanflow = use_meanflow
# Guidance modulation
self.guidance_in = (
TimestepEmbedder(self.hidden_size, get_activation_layer("silu"), **factory_kwargs)
if guidance_embed
else None
)
# Double blocks
self.double_blocks = nn.ModuleList(
[
MMDoubleStreamBlock(
self.hidden_size,
self.heads_num,
mlp_width_ratio=mlp_width_ratio,
mlp_act_type=mlp_act_type,
qk_norm=qk_norm,
qk_norm_type=qk_norm_type,
qkv_bias=qkv_bias,
**factory_kwargs,
)
for _ in range(mm_double_blocks_depth)
]
)
# Single blocks
self.single_blocks = nn.ModuleList(
[
MMSingleStreamBlock(
self.hidden_size,
self.heads_num,
mlp_width_ratio=mlp_width_ratio,
mlp_act_type=mlp_act_type,
qk_norm=qk_norm,
qk_norm_type=qk_norm_type,
**factory_kwargs,
)
for _ in range(mm_single_blocks_depth)
]
)
self.final_layer = FinalLayer(
self.hidden_size,
self.patch_size,
self.out_channels,
get_activation_layer("silu"),
**factory_kwargs,
)

    def enable_deterministic(self):
        """Enable deterministic mode for all transformer blocks."""
        for block in self.double_blocks:
            block.enable_deterministic()
        for block in self.single_blocks:
            block.enable_deterministic()

    def disable_deterministic(self):
        """Disable deterministic mode for all transformer blocks."""
        for block in self.double_blocks:
            block.disable_deterministic()
        for block in self.single_blocks:
            block.disable_deterministic()

    def get_rotary_pos_embed(self, rope_sizes):
        """
        Get rotary position embeddings for the given sizes.

        Args:
            rope_sizes: Token-grid size along each rotary axis.

        Returns:
            freqs_cos, freqs_sin: Cosine and sine frequencies for the rotary embedding.
        """
        target_ndim = 3
        head_dim = self.hidden_size // self.heads_num
        rope_dim_list = self.rope_dim_list
        if rope_dim_list is None:
            rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
        assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) must equal the attention head_dim"
freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
rope_dim_list,
rope_sizes,
theta=self.rope_theta,
use_real=True,
theta_rescale_factor=1,
)
return freqs_cos, freqs_sin
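
    # Worked example (sketch): with the default config, head_dim = 3072 // 24 = 128
    # and rope_dim_list = [16, 56, 56] (summing to 128), so the (t, h, w) axes of
    # the token grid get 16, 56, and 56 rotary channels respectively. A 1x64x64
    # latent patchified at (1, 2, 2) gives rope_sizes = (1, 32, 32).
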
    def reorder_txt_token(self, byt5_txt, txt, byt5_text_mask, text_mask):
        """
        Reorder text tokens for ByT5 integration.

        Per sample, valid ByT5 tokens come first, then valid text tokens,
        followed by the padding tokens of both streams.

        Args:
            byt5_txt: ByT5 text embeddings.
            txt: Text embeddings.
            byt5_text_mask: Mask for ByT5 tokens.
            text_mask: Mask for text tokens.

        Returns:
            reorder_txt: Reordered text embeddings.
            reorder_mask: Reordered mask.
        """
reorder_txt = []
reorder_mask = []
for i in range(text_mask.shape[0]):
byt5_text_mask_i = byt5_text_mask[i].bool()
text_mask_i = text_mask[i].bool()
byt5_txt_i = byt5_txt[i]
txt_i = txt[i]
reorder_txt_i = torch.cat([
byt5_txt_i[byt5_text_mask_i],
txt_i[text_mask_i],
byt5_txt_i[~byt5_text_mask_i],
txt_i[~text_mask_i]
], dim=0)
reorder_mask_i = torch.cat([
byt5_text_mask_i[byt5_text_mask_i],
text_mask_i[text_mask_i],
byt5_text_mask_i[~byt5_text_mask_i],
text_mask_i[~text_mask_i]
], dim=0)
reorder_txt.append(reorder_txt_i)
reorder_mask.append(reorder_mask_i)
reorder_txt = torch.stack(reorder_txt)
reorder_mask = torch.stack(reorder_mask).to(dtype=torch.int64)
return reorder_txt, reorder_mask
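
    # Example (sketch): for a sample with byt5_text_mask = [1, 0] and
    # text_mask = [1, 1, 0], tokens are concatenated as
    # [byt5_valid, txt_valid, txt_valid, byt5_pad, txt_pad], so the reordered
    # mask is [1, 1, 1, 0, 0].
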
def forward(
self,
hidden_states: torch.Tensor,
timestep: torch.LongTensor,
text_states: torch.Tensor,
encoder_attention_mask: torch.Tensor,
output_features: bool = False,
output_features_stride: int = 8,
freqs_cos: Optional[torch.Tensor] = None,
freqs_sin: Optional[torch.Tensor] = None,
return_dict: bool = False,
guidance=None,
extra_kwargs=None,
*,
timesteps_r: Optional[torch.LongTensor] = None,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
"""
Forward pass for the transformer.
Parameters
----------
hidden_states : torch.Tensor
Input image tensor.
timestep : torch.LongTensor
Timestep tensor.
text_states : torch.Tensor
Text embeddings.
encoder_attention_mask : torch.Tensor
Attention mask for text.
output_features : bool, optional
Whether to output intermediate features.
output_features_stride : int, optional
Stride for outputting features.
freqs_cos, freqs_sin : torch.Tensor, optional
Precomputed rotary embeddings.
return_dict : bool, optional
Not supported.
guidance : torch.Tensor, optional
Guidance vector for distillation.
extra_kwargs : dict, optional
Extra arguments for ByT5.
timesteps_r : torch.LongTensor, optional
Additional timestep for MeanFlow.
Returns
-------
tuple
(img, features_list, shape)
"""
        if guidance is None:
            # Default guidance strength for the guidance-distilled model.
            guidance = torch.tensor([6016.0], device=hidden_states.device, dtype=torch.bfloat16)
img = x = hidden_states
text_mask = encoder_attention_mask
t = timestep
txt = text_states
input_shape = x.shape
# Calculate spatial dimensions and get rotary embeddings
if len(input_shape) == 5:
_, _, ot, oh, ow = x.shape
tt, th, tw = (
ot // self.patch_size[0],
oh // self.patch_size[1],
ow // self.patch_size[2],
)
if freqs_cos is None or freqs_sin is None:
freqs_cos, freqs_sin = self.get_rotary_pos_embed((tt, th, tw))
elif len(input_shape) == 4:
_, _, oh, ow = x.shape
th, tw = (
oh // self.patch_size[0],
ow // self.patch_size[1],
)
if freqs_cos is None or freqs_sin is None:
assert freqs_cos is None and freqs_sin is None, "freqs_cos and freqs_sin must be both None or both not None"
freqs_cos, freqs_sin = self.get_rotary_pos_embed((th, tw))
else:
raise ValueError(f"Unsupported hidden_states shape: {x.shape}")
img = self.img_in(img)
# Prepare modulation vectors
vec = self.time_in(t)
        # MeanFlow support: average the embeddings of the primary and auxiliary
        # timesteps when an auxiliary timestep is provided.
        if self.use_meanflow:
            assert self.time_r_in is not None, "use_meanflow is True but time_r_in is None"
        if timesteps_r is not None:
            assert self.time_r_in is not None, "timesteps_r is not None but time_r_in is None"
            vec_r = self.time_r_in(timesteps_r)
            vec = (vec + vec_r) / 2
# Guidance modulation
if self.guidance_embed:
if guidance is None:
raise ValueError("Didn't get guidance strength for guidance distilled model.")
vec = vec + self.guidance_in(guidance)
# Embed image and text
if self.text_projection == "linear":
txt = self.txt_in(txt)
elif self.text_projection == "single_refiner":
txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None)
else:
raise NotImplementedError(f"Unsupported text_projection: {self.text_projection}")
if self.glyph_byT5_v2:
byt5_text_states = extra_kwargs["byt5_text_states"]
byt5_text_mask = extra_kwargs["byt5_text_mask"]
byt5_txt = self.byt5_in(byt5_text_states)
txt, text_mask = self.reorder_txt_token(byt5_txt, txt, byt5_text_mask, text_mask)
txt_seq_len = txt.shape[1]
img_seq_len = img.shape[1]
# Calculate cu_seqlens and max_s for flash attention
cu_seqlens, max_s = get_cu_seqlens(text_mask, img_seq_len)
freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
# Pass through double stream blocks
for block in self.double_blocks:
double_block_args = [img, txt, vec, freqs_cis, text_mask, cu_seqlens, max_s]
img, txt = block(*double_block_args)
# Merge txt and img to pass through single stream blocks
x = torch.cat((img, txt), 1)
features_list = [] if output_features else None
if len(self.single_blocks) > 0:
for index, block in enumerate(self.single_blocks):
single_block_args = [
x,
vec,
txt_seq_len,
                freqs_cis,
text_mask,
cu_seqlens,
max_s,
]
x = block(*single_block_args)
if output_features and index % output_features_stride == 0:
features_list.append(x[:, :img_seq_len, ...])
img = x[:, :img_seq_len, ...]
# Final layer
img = self.final_layer(img, vec)
# Unpatchify based on input shape
if len(input_shape) == 5:
img = self.unpatchify(img, tt, th, tw)
shape = (tt, th, tw)
elif len(input_shape) == 4:
img = self.unpatchify_2d(img, th, tw)
shape = (th, tw)
else:
raise ValueError(f"Unsupported input_shape: {input_shape}")
assert not return_dict, "return_dict is not supported."
if output_features:
features_list = torch.stack(features_list, dim=0)
else:
features_list = None
return (img, features_list, shape)
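
    # Usage (sketch; shapes and arguments are illustrative, assuming a CUDA
    # device with flash attention available):
    #
    #     model = HYImageDiffusionTransformer(
    #         patch_size=[2, 2], in_channels=16, rope_dim_list=[64, 64]
    #     ).to("cuda", torch.bfloat16)
    #     latents = torch.randn(1, 16, 64, 64, device="cuda", dtype=torch.bfloat16)
    #     t = torch.tensor([500], device="cuda")
    #     txt = torch.randn(1, 77, 4096, device="cuda", dtype=torch.bfloat16)
    #     mask = torch.ones(1, 77, dtype=torch.int64, device="cuda")
    #     img, feats, shape = model(latents, t, txt, mask)  # shape == (32, 32)
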
    def unpatchify(self, x, t, h, w):
        """
        Unpatchify a 3D token sequence.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (N, t*h*w, pt*ph*pw * C).
        t, h, w : int
            Temporal and spatial token-grid dimensions.

        Returns
        -------
        torch.Tensor
            Unpatchified tensor of shape (N, C, t*pt, h*ph, w*pw).
        """
c = self.unpatchify_channels
pt, ph, pw = self.patch_size
assert t * h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw))
x = torch.einsum("nthwcopq->nctohpwq", x)
imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
return imgs
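
    # Worked example (sketch): with patch_size (1, 2, 2) and C = 4, a token grid
    # of t=1, h=2, w=2 means x has shape (N, 4, 1*2*2*4) = (N, 4, 16) and the
    # output has shape (N, 4, 1, 4, 4).
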
    def unpatchify_2d(self, x, h, w):
        """
        Unpatchify a 2D token sequence.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (N, h*w, ph*pw * C).
        h, w : int
            Spatial token-grid dimensions.

        Returns
        -------
        torch.Tensor
            Unpatchified tensor of shape (N, C, h*ph, w*pw).
        """
c = self.unpatchify_channels
ph, pw = self.patch_size
assert h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], h, w, c, ph, pw))
x = torch.einsum('nhwcpq->nchpwq', x)
imgs = x.reshape(shape=(x.shape[0], c, h * ph, w * pw))
return imgs
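
    # Worked example (sketch): with patch_size (2, 2) and C = 4, a token grid of
    # h=2, w=2 means x has shape (N, 4, 2*2*4) = (N, 4, 16) and the output has
    # shape (N, 4, 4, 4).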