File size: 19,810 Bytes
43c5292
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
import os
from typing import Dict, List, Optional, Union

import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models import ModelMixin

from hyimage.models.hunyuan.modules.posemb_layers import get_nd_rotary_pos_embed
from hyimage.models.hunyuan.modules.flash_attn_no_pad import get_cu_seqlens

from .activation_layers import get_activation_layer
from .embed_layers import PatchEmbed, PatchEmbed2D, TextProjection, TimestepEmbedder
from .mlp_layers import FinalLayer
from .models import MMDoubleStreamBlock, MMSingleStreamBlock
from .token_refiner import SingleTokenRefiner

from hyimage.models.text_encoder.byT5 import ByT5Mapper


def convert_hunyuan_dict_for_tensor_parallel(state_dict):
    """
    Convert a Hunyuan model state dict to be compatible with tensor parallel architectures.

    Fused projections are split into separate entries:

    * ``double_blocks.*attn_qkv.weight`` -> ``*attn_q.weight``, ``*attn_k.weight``,
      ``*attn_v.weight``: the first, second and last ``hidden_size`` rows, where
      ``hidden_size = weight.shape[1]``.
    * ``double_blocks.*attn_qkv.bias`` -> ``*attn_q.bias``, ``*attn_k.bias``,
      ``*attn_v.bias``: three chunks of ``bias.shape[0] // 3`` (the last chunk is
      taken from the end).
    * ``single_blocks.*.linear1.{weight,bias}`` -> ``linear1_q``, ``linear1_k``,
      ``linear1_v`` (one ``hidden_size`` chunk each along dim 0) and ``linear1_mlp``
      (the remaining rows), where ``hidden_size`` is ``shape[0]`` of the matching
      ``linear2`` entry.
    * ``single_blocks.*.linear2.{weight,bias}`` -> ``linear2.fc.{weight,bias}``.

    Every other entry is copied unchanged. Entries that are already in the split
    layout (``linear1_q``, ``linear2.fc``, ``attn_q`` ...) are also copied
    unchanged, so applying the conversion twice yields the same result.

    Args:
        state_dict: Original state dict (parameter name -> tensor).

    Returns:
        new_dict: Converted state dict. Split entries are slices of the input tensors.

    Raises:
        KeyError: If a fused ``linear1`` entry has no matching ``linear2`` entry.
    """
    new_dict = {}
    for k, w in state_dict.items():
        # Owning module path, e.g. "single_blocks.0.linear1.weight" -> "single_blocks.0.linear1".
        # Matching the exact module name (instead of a substring) keeps already-split
        # keys such as "linear1_q.weight" or "linear2.fc.weight" from being converted again.
        module_path = k.rsplit(".", 1)[0]

        if k.startswith("double_blocks") and "attn_qkv.weight" in k:
            hidden_size = w.shape[1]
            new_dict[k.replace("attn_qkv.weight", "attn_q.weight")] = w[:hidden_size, :]
            new_dict[k.replace("attn_qkv.weight", "attn_k.weight")] = w[hidden_size : 2 * hidden_size, :]
            new_dict[k.replace("attn_qkv.weight", "attn_v.weight")] = w[-hidden_size:, :]
        elif k.startswith("double_blocks") and "attn_qkv.bias" in k:
            hidden_size = w.shape[0] // 3
            new_dict[k.replace("attn_qkv.bias", "attn_q.bias")] = w[:hidden_size]
            new_dict[k.replace("attn_qkv.bias", "attn_k.bias")] = w[hidden_size : 2 * hidden_size]
            new_dict[k.replace("attn_qkv.bias", "attn_v.bias")] = w[-hidden_size:]
        elif k.startswith("single_blocks") and module_path.endswith(".linear1"):
            # The chunk size comes from the sibling linear2 entry's leading dim.
            hidden_size = state_dict[k.replace("linear1", "linear2")].shape[0]
            new_dict[k.replace("linear1", "linear1_q")] = w[:hidden_size]
            new_dict[k.replace("linear1", "linear1_k")] = w[hidden_size : 2 * hidden_size]
            new_dict[k.replace("linear1", "linear1_v")] = w[2 * hidden_size : 3 * hidden_size]
            new_dict[k.replace("linear1", "linear1_mlp")] = w[3 * hidden_size :]
        elif k.startswith("single_blocks") and module_path.endswith(".linear2"):
            # The split layout nests linear2's parameters under "linear2.fc".
            new_dict[k.replace("linear2", "linear2.fc")] = w
        else:
            new_dict[k] = w
    return new_dict


def load_hunyuan_dit_state_dict(model, dit_model_name_or_path, strict=True, assign=False):
    """
    Load a state dict for a Hunyuan model, handling both safetensors and torch formats.

    Files ending in ``.safetensors`` are read with ``safetensors.torch.load_file``;
    anything else is read with ``torch.load(map_location="cpu", weights_only=True)``.
    The state dict is then passed through ``convert_hunyuan_dict_for_tensor_parallel``.
    If that conversion fails with ``KeyError`` or ``IndexError``, the checkpoint is
    loaded exactly as read.

    Args:
        model: Model instance to load weights into
        dit_model_name_or_path: Path to the checkpoint file (``str`` or ``os.PathLike``)
        strict: Whether to strictly enforce that the keys in state_dict match the model's keys
        assign: If True, assign weights directly without copying

    Returns:
        model: The model with loaded weights

    Raises:
        FileNotFoundError: If the checkpoint path does not exist.
    """
    from safetensors.torch import load_file as safetensors_load_file

    # Normalize PathLike inputs so the suffix check below works for pathlib.Path too.
    checkpoint_path = os.fspath(dit_model_name_or_path)
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint file not found: {dit_model_name_or_path}")

    if checkpoint_path.endswith(".safetensors"):
        state_dict = safetensors_load_file(checkpoint_path)
    else:
        state_dict = torch.load(
            checkpoint_path,
            map_location="cpu",
            weights_only=True,
        )

    try:
        state_dict = convert_hunyuan_dict_for_tensor_parallel(state_dict)
    except (KeyError, IndexError):
        # The converter looks up sibling entries by key and indexes tensor shapes; these
        # errors mean the checkpoint is not in the fused layout it expects, so load it
        # unchanged. Any other error is a real problem and propagates.
        pass
    model.load_state_dict(state_dict, strict=strict, assign=assign)
    return model


class HYImageDiffusionTransformer(ModelMixin, ConfigMixin):
    """
    Diffusion transformer with double-stream and single-stream blocks.

    The input ``hidden_states`` are patch-embedded by ``img_in`` and the text states are
    projected by ``txt_in``. The two token streams are processed separately by the
    ``MMDoubleStreamBlock`` layers, then concatenated (image tokens first) and processed
    jointly by the ``MMSingleStreamBlock`` layers. ``final_layer`` maps the image tokens
    back to patch space and the result is unpatchified to the input's layout.

    The modulation vector ``vec`` is the timestep embedding. It is optionally averaged
    with a second timestep embedding (MeanFlow, ``time_r_in``) and optionally summed
    with a guidance embedding (``guidance_in``).
    """

    @register_to_config
    def __init__(
        self,
        # NOTE: these list defaults are never mutated in this class; they are left
        # as-is so the recorded config is unchanged.
        patch_size: list = [1, 2, 2],
        in_channels: int = 4,
        out_channels: int = None,
        hidden_size: int = 3072,
        heads_num: int = 24,
        mlp_width_ratio: float = 4.0,
        mlp_act_type: str = "gelu_tanh",
        mm_double_blocks_depth: int = 20,
        mm_single_blocks_depth: int = 40,
        rope_dim_list: List[int] = [16, 56, 56],
        qkv_bias: bool = True,
        qk_norm: bool = True,
        qk_norm_type: str = "rms",
        guidance_embed: bool = False,
        text_projection: str = "single_refiner",
        use_attention_mask: bool = True,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        text_states_dim: int = 4096,
        rope_theta: int = 256,
        glyph_byT5_v2: bool = False,
        use_meanflow: bool = False,
    ):
        """
        Build the transformer.

        Args:
            patch_size: Patch size per axis. Three entries select ``PatchEmbed`` (5-D
                input), two entries select ``PatchEmbed2D`` (4-D input).
            in_channels: Number of input channels.
            out_channels: Number of output channels; defaults to ``in_channels``.
            hidden_size: Transformer width; must be divisible by ``heads_num``.
            heads_num: Number of attention heads.
            mlp_width_ratio: MLP width multiplier, forwarded to every block.
            mlp_act_type: MLP activation name, forwarded to every block.
            mm_double_blocks_depth: Number of ``MMDoubleStreamBlock`` layers.
            mm_single_blocks_depth: Number of ``MMSingleStreamBlock`` layers.
            rope_dim_list: Rotary dims per axis; must sum to ``hidden_size // heads_num``.
            qkv_bias: QKV bias flag, forwarded to the double-stream blocks only.
            qk_norm: Whether blocks normalize Q/K.
            qk_norm_type: Q/K normalization type, forwarded to every block.
            guidance_embed: If True, create ``guidance_in`` and add its embedding to ``vec``.
            text_projection: ``"linear"`` (``TextProjection``) or ``"single_refiner"``
                (``SingleTokenRefiner`` with depth 2).
            use_attention_mask: If True, the text mask is passed to the text refiner.
            dtype: Parameter dtype passed to submodule constructors.
            device: Parameter device passed to submodule constructors.
            text_states_dim: Feature dim of the incoming text states.
            rope_theta: Base ``theta`` for rotary embeddings.
            glyph_byT5_v2: If True, create the ``byt5_in`` ``ByT5Mapper``.
            use_meanflow: If True, create ``time_r_in`` for a second timestep.

        Raises:
            ValueError: If ``hidden_size`` is not divisible by ``heads_num``, if
                ``rope_dim_list`` does not sum to the head dim, or if ``patch_size``
                has a length other than 2 or 3.
            NotImplementedError: If ``text_projection`` is not a supported value.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        self.unpatchify_channels = self.out_channels
        self.guidance_embed = guidance_embed
        self.rope_dim_list = rope_dim_list
        self.rope_theta = rope_theta
        self.use_attention_mask = use_attention_mask
        self.text_projection = text_projection

        if hidden_size % heads_num != 0:
            raise ValueError(f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}")
        pe_dim = hidden_size // heads_num
        if sum(rope_dim_list) != pe_dim:
            raise ValueError(f"Got {rope_dim_list} but expected positional dim {pe_dim}")
        self.hidden_size = hidden_size
        self.heads_num = heads_num

        self.glyph_byT5_v2 = glyph_byT5_v2
        if self.glyph_byT5_v2:
            self.byt5_in = ByT5Mapper(
                in_dim=1472,
                out_dim=2048,
                hidden_dim=2048,
                out_dim1=hidden_size,
                use_residual=False
            )

        # Image projection
        if len(self.patch_size) == 3:
            self.img_in = PatchEmbed(self.patch_size, self.in_channels, self.hidden_size, **factory_kwargs)
        elif len(self.patch_size) == 2:
            self.img_in = PatchEmbed2D(self.patch_size, self.in_channels, self.hidden_size, **factory_kwargs)
        else:
            raise ValueError(f"Unsupported patch_size: {self.patch_size}")

        # Text projection
        if self.text_projection == "linear":
            self.txt_in = TextProjection(
                text_states_dim,
                self.hidden_size,
                get_activation_layer("silu"),
                **factory_kwargs,
            )
        elif self.text_projection == "single_refiner":
            self.txt_in = SingleTokenRefiner(
                text_states_dim,
                hidden_size,
                heads_num,
                depth=2,
                **factory_kwargs,
            )
        else:
            raise NotImplementedError(f"Unsupported text_projection: {self.text_projection}")

        # Time modulation
        self.time_in = TimestepEmbedder(self.hidden_size, get_activation_layer("silu"), **factory_kwargs)

        # MeanFlow support: only create time_r_in when needed
        self.time_r_in = (
            TimestepEmbedder(self.hidden_size, get_activation_layer("silu"), **factory_kwargs)
            if use_meanflow
            else None
        )
        self.use_meanflow = use_meanflow

        # Guidance modulation
        self.guidance_in = (
            TimestepEmbedder(self.hidden_size, get_activation_layer("silu"), **factory_kwargs)
            if guidance_embed
            else None
        )

        # Double blocks
        self.double_blocks = nn.ModuleList(
            [
                MMDoubleStreamBlock(
                    self.hidden_size,
                    self.heads_num,
                    mlp_width_ratio=mlp_width_ratio,
                    mlp_act_type=mlp_act_type,
                    qk_norm=qk_norm,
                    qk_norm_type=qk_norm_type,
                    qkv_bias=qkv_bias,
                    **factory_kwargs,
                )
                for _ in range(mm_double_blocks_depth)
            ]
        )

        # Single blocks
        self.single_blocks = nn.ModuleList(
            [
                MMSingleStreamBlock(
                    self.hidden_size,
                    self.heads_num,
                    mlp_width_ratio=mlp_width_ratio,
                    mlp_act_type=mlp_act_type,
                    qk_norm=qk_norm,
                    qk_norm_type=qk_norm_type,
                    **factory_kwargs,
                )
                for _ in range(mm_single_blocks_depth)
            ]
        )

        self.final_layer = FinalLayer(
            self.hidden_size,
            self.patch_size,
            self.out_channels,
            get_activation_layer("silu"),
            **factory_kwargs,
        )

    def enable_deterministic(self):
        """Enable deterministic mode for all transformer blocks."""
        for block in self.double_blocks:
            block.enable_deterministic()
        for block in self.single_blocks:
            block.enable_deterministic()

    def disable_deterministic(self):
        """Disable deterministic mode for all transformer blocks."""
        for block in self.double_blocks:
            block.disable_deterministic()
        for block in self.single_blocks:
            block.disable_deterministic()

    def get_rotary_pos_embed(self, rope_sizes):
        """
        Get rotary position embeddings for the given sizes.

        Args:
            rope_sizes: Token-grid size along each rotary axis, e.g. ``(t, h, w)``
                for 5-D inputs or ``(h, w)`` for 4-D inputs.

        Returns:
            freqs_cos, freqs_sin: Cosine and sine frequencies for rotary embedding.

        Raises:
            AssertionError: If the per-axis rotary dims do not sum to the head dim.
        """
        head_dim = self.hidden_size // self.heads_num
        rope_dim_list = self.rope_dim_list
        if rope_dim_list is None:
            # Split the head dim evenly, one entry per axis in rope_sizes.
            target_ndim = len(rope_sizes)
            rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
        assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
        freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
            rope_dim_list,
            rope_sizes,
            theta=self.rope_theta,
            use_real=True,
            theta_rescale_factor=1,
        )
        return freqs_cos, freqs_sin

    def reorder_txt_token(self, byt5_txt, txt, byt5_text_mask, text_mask):
        """
        Reorder text tokens for ByT5 integration.

        For each sample, the output sequence is: ByT5 tokens where the mask is set,
        text tokens where the mask is set, then the masked-out ByT5 tokens, then the
        masked-out text tokens. Valid tokens from both sources therefore come first
        and padding comes last.

        Args:
            byt5_txt: ByT5 text embeddings.
            txt: Text embeddings.
            byt5_text_mask: Mask for ByT5 tokens.
            text_mask: Mask for text tokens.

        Returns:
            reorder_txt: Reordered text embeddings (per-sample concatenation, stacked).
            reorder_mask: Reordered mask as ``torch.int64`` (1 for the leading valid
                tokens, 0 for the trailing padding).
        """
        reorder_txt = []
        reorder_mask = []

        for i in range(text_mask.shape[0]):
            byt5_text_mask_i = byt5_text_mask[i].bool()
            text_mask_i = text_mask[i].bool()
            byt5_txt_i = byt5_txt[i]
            txt_i = txt[i]
            reorder_txt_i = torch.cat([
                byt5_txt_i[byt5_text_mask_i],
                txt_i[text_mask_i],
                byt5_txt_i[~byt5_text_mask_i],
                txt_i[~text_mask_i]
            ], dim=0)

            # Same reordering applied to the masks: all-True entries first, then all-False.
            reorder_mask_i = torch.cat([
                byt5_text_mask_i[byt5_text_mask_i],
                text_mask_i[text_mask_i],
                byt5_text_mask_i[~byt5_text_mask_i],
                text_mask_i[~text_mask_i]
            ], dim=0)

            reorder_txt.append(reorder_txt_i)
            reorder_mask.append(reorder_mask_i)

        reorder_txt = torch.stack(reorder_txt)
        reorder_mask = torch.stack(reorder_mask).to(dtype=torch.int64)

        return reorder_txt, reorder_mask

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        text_states: torch.Tensor,
        encoder_attention_mask: torch.Tensor,
        output_features: bool = False,
        output_features_stride: int = 8,
        freqs_cos: Optional[torch.Tensor] = None,
        freqs_sin: Optional[torch.Tensor] = None,
        return_dict: bool = False,
        guidance=None,
        extra_kwargs=None,
        *,
        timesteps_r: Optional[torch.LongTensor] = None,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Forward pass for the transformer.

        Parameters
        ----------
        hidden_states : torch.Tensor
            Input image tensor; 5-D (uses a 3-entry ``patch_size``) or 4-D
            (uses a 2-entry ``patch_size``).
        timestep : torch.LongTensor
            Timestep tensor.
        text_states : torch.Tensor
            Text embeddings.
        encoder_attention_mask : torch.Tensor
            Attention mask for text.
        output_features : bool, optional
            Whether to output intermediate features.
        output_features_stride : int, optional
            Collect image-token features after every single block whose index is a
            multiple of this stride (block 0 included).
        freqs_cos, freqs_sin : torch.Tensor, optional
            Precomputed rotary embeddings. Pass both or neither; when neither is
            given they are computed from the patchified grid size.
        return_dict : bool, optional
            Not supported; must be False.
        guidance : torch.Tensor, optional
            Guidance vector for distillation. Only used when ``guidance_embed`` is
            enabled; if omitted, a default ``[6016.0]`` bfloat16 tensor is used.
        extra_kwargs : dict, optional
            Extra arguments for ByT5. Required when ``glyph_byT5_v2`` is enabled and
            must contain ``"byt5_text_states"`` and ``"byt5_text_mask"``.
        timesteps_r : torch.LongTensor, optional
            Additional timestep for MeanFlow; its embedding is averaged with the
            ``timestep`` embedding.

        Returns
        -------
        tuple
            ``(img, features_list, shape)``. ``img`` is the unpatchified output.
            ``features_list`` is the stacked intermediate image-token features, or
            None when ``output_features`` is False or no single block ran. ``shape``
            is the patchified grid size, ``(t, h, w)`` or ``(h, w)``.

        Raises
        ------
        AssertionError
            If ``return_dict`` is True, or exactly one of ``freqs_cos``/``freqs_sin``
            is given.
        ValueError
            If ``hidden_states`` is neither 4-D nor 5-D, or ByT5 is enabled without
            ``extra_kwargs``.
        """
        # Only the tuple output is implemented; reject return_dict before doing any work.
        assert not return_dict, "return_dict is not supported."

        img = x = hidden_states
        text_mask = encoder_attention_mask
        t = timestep
        txt = text_states
        input_shape = x.shape

        # Grid size after patchification; also used as the per-axis RoPE sizes.
        if len(input_shape) == 5:
            _, _, ot, oh, ow = input_shape
            rope_sizes = (
                ot // self.patch_size[0],
                oh // self.patch_size[1],
                ow // self.patch_size[2],
            )
        elif len(input_shape) == 4:
            _, _, oh, ow = input_shape
            rope_sizes = (
                oh // self.patch_size[0],
                ow // self.patch_size[1],
            )
        else:
            raise ValueError(f"Unsupported hidden_states shape: {x.shape}")

        # Rotary tables must be supplied together, for 4-D and 5-D inputs alike.
        assert (freqs_cos is None) == (freqs_sin is None), "freqs_cos and freqs_sin must be both None or both not None"
        if freqs_cos is None:
            freqs_cos, freqs_sin = self.get_rotary_pos_embed(rope_sizes)
        freqs_cis = (freqs_cos, freqs_sin)

        img = self.img_in(img)

        # Prepare modulation vectors
        vec = self.time_in(t)

        # MeanFlow support: merge timestep and timestep_r if available
        if self.use_meanflow:
            assert self.time_r_in is not None, "use_meanflow is True but time_r_in is None"
        if timesteps_r is not None:
            assert self.time_r_in is not None, "timesteps_r is not None but time_r_in is None"
            vec_r = self.time_r_in(timesteps_r)
            vec = (vec + vec_r) / 2

        # Guidance modulation; the default tensor is only materialized when it is used.
        if self.guidance_embed:
            if guidance is None:
                # NOTE(review): meaning of the 6016.0 default is not documented here — confirm.
                guidance = torch.tensor([6016.0], device=hidden_states.device, dtype=torch.bfloat16)
            vec = vec + self.guidance_in(guidance)

        # Embed text
        if self.text_projection == "linear":
            txt = self.txt_in(txt)
        elif self.text_projection == "single_refiner":
            txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None)
        else:
            raise NotImplementedError(f"Unsupported text_projection: {self.text_projection}")

        if self.glyph_byT5_v2:
            if extra_kwargs is None:
                raise ValueError(
                    "extra_kwargs with 'byt5_text_states' and 'byt5_text_mask' is required "
                    "when glyph_byT5_v2 is enabled."
                )
            byt5_text_states = extra_kwargs["byt5_text_states"]
            byt5_text_mask = extra_kwargs["byt5_text_mask"]
            byt5_txt = self.byt5_in(byt5_text_states)
            txt, text_mask = self.reorder_txt_token(byt5_txt, txt, byt5_text_mask, text_mask)

        txt_seq_len = txt.shape[1]
        img_seq_len = img.shape[1]

        # Calculate cu_seqlens and max_s for flash attention
        cu_seqlens, max_s = get_cu_seqlens(text_mask, img_seq_len)

        # Pass through double stream blocks
        for block in self.double_blocks:
            img, txt = block(img, txt, vec, freqs_cis, text_mask, cu_seqlens, max_s)

        # Merge txt and img (image tokens first) to pass through single stream blocks
        x = torch.cat((img, txt), 1)
        features_list = [] if output_features else None

        for index, block in enumerate(self.single_blocks):
            x = block(x, vec, txt_seq_len, freqs_cis, text_mask, cu_seqlens, max_s)
            if output_features and index % output_features_stride == 0:
                features_list.append(x[:, :img_seq_len, ...])

        img = x[:, :img_seq_len, ...]

        # Final layer
        img = self.final_layer(img, vec)

        # Unpatchify based on input rank (validated above)
        if len(input_shape) == 5:
            img = self.unpatchify(img, *rope_sizes)
        else:
            img = self.unpatchify_2d(img, *rope_sizes)

        # torch.stack([]) raises, so report no features when no single block ran.
        if output_features and features_list:
            features_list = torch.stack(features_list, dim=0)
        else:
            features_list = None

        return (img, features_list, rope_sizes)

    def unpatchify(self, x, t, h, w):
        """
        Unpatchify 3D tensor.

        Parameters
        ----------
        x: torch.Tensor
            Input tensor of shape (N, t*h*w, C*pt*ph*pw), where (pt, ph, pw) is
            ``self.patch_size`` and C is ``self.unpatchify_channels``.
        t, h, w: int
            Temporal and spatial token-grid dimensions

        Returns
        -------
        torch.Tensor
            Unpatchified tensor of shape (N, C, t*pt, h*ph, w*pw)
        """
        c = self.unpatchify_channels
        pt, ph, pw = self.patch_size
        assert t * h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw))
        # Interleave each grid axis with its patch axis: (t, pt), (h, ph), (w, pw).
        x = torch.einsum("nthwcopq->nctohpwq", x)
        imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))

        return imgs

    def unpatchify_2d(self, x, h, w):
        """
        Unpatchify 2D tensor.

        Parameters
        ----------
        x: torch.Tensor
            Input tensor of shape (N, h*w, C*ph*pw), where (ph, pw) is
            ``self.patch_size`` and C is ``self.unpatchify_channels``.
        h, w: int
            Spatial token-grid dimensions

        Returns
        -------
        torch.Tensor
            Unpatchified tensor of shape (N, C, h*ph, w*pw)
        """
        c = self.unpatchify_channels
        ph, pw = self.patch_size
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, c, ph, pw))
        # Interleave each grid axis with its patch axis: (h, ph), (w, pw).
        x = torch.einsum('nhwcpq->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * ph, w * pw))
        return imgs