import json
from dataclasses import dataclass
from typing import Optional

ENCODED_TEXT_DIM = 4096
POOLED_TEXT_DIM = 2048
VAE_COMPRESSION_RATIO = 8


@dataclass
class MMDiTConfig:
    # General
    num_layers: int = 12
    hidden_dim: int = 768  # common hidden dimension for the transformer arch
    patch_size: int = 2
    image_dim: int = 224
    in_channel: int = 4
    out_channel: int = 4
    modulation_dim: int = ENCODED_TEXT_DIM  # input dimension for modulation layer (shifting and scaling)
    height: int = 1024
    width: int = 1024
    vae_compression: int = VAE_COMPRESSION_RATIO  # reducing resolution with the VAE
    vae_type: str = "SD3"  # SDXL or SD3
    pos_emb_size: Optional[int] = None
    conv_header: bool = False

    # Outside of the MMDiT block
    time_embed_dim: int = 2048  # Initial projection (discrete_time embedding) output dim
    pooled_text_dim: int = POOLED_TEXT_DIM
    text_emb_dim: int = 768

    # MMDiTBlock
    t_emb_dim: int = 256
    attn_embed_dim: int = 768  # hidden dimension during the attention
    mlp_hidden_dim: int = 2048
    attn_mode: Optional[str] = None  # {'flash', 'sdpa', None}
    use_final_layer_norm: bool = False
    use_time_token_in_attn: bool = False

    # GroupedQueryAttention
    num_attention_heads: int = 12
    num_key_value_heads: int = 6
    use_scaled_dot_product_attention: bool = True
    dropout: float = 0.0

    # Modulation
    use_modulation: bool = True
    modulation_type: str = "film"  # Choose from 'film', 'adain', or 'spade'

    # Register tokens
    register_token_num: int = 4
    additional_register_token_num: int = 12

    # DINOv2 feature-align loss
    dinov2_feature_align_loss: bool = False
    feature_align_loss_weight: float = 0.5
    num_feature_align_layers: int = 8  # number of transformer layers to calculate feature-align loss

    # Personalization related
    image_encoder_name: Optional[str] = None  # if set, the personalized image encoder will be loaded
    freeze_dit_backbone: bool = False

    # Preference optimization
    preference_train: bool = False
    lora_rank: int = 64
    lora_alpha: int = 8

    skip_register_token_num: int = 0

    @classmethod
    def from_json_file(cls, json_file):
        """
        Instantiates an [`MMDiTConfig`] from the path to a JSON file of parameters.

        Args:
            json_file (`str` or `os.PathLike`):
                Path to the JSON file containing the parameters.

        Returns:
            [`MMDiTConfig`]: The configuration object instantiated from that JSON file.
        """
        config_dict = cls._dict_from_json_file(json_file)
        return cls(**config_dict)

    @classmethod
    def _dict_from_json_file(cls, json_file):
        with open(json_file, "r", encoding="utf-8") as reader:
            text = reader.read()
        return json.loads(text)
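

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the model code): build a
# config, derive the image-token count an MMDiT-style patchifier would
# typically produce from it, and round-trip the config through a JSON file.
# The path "mmdit_config.json" is a hypothetical example.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from dataclasses import asdict

    config = MMDiTConfig(num_layers=24, hidden_dim=1024)

    # Latent resolution after VAE compression, then the number of image
    # tokens after patchification (1024 // 8 // 2 = 64 per side -> 4096).
    latent_size = config.height // config.vae_compression
    num_image_tokens = (latent_size // config.patch_size) ** 2
    print(f"image tokens: {num_image_tokens}")

    # Serialize the dataclass fields and reload them via from_json_file.
    with open("mmdit_config.json", "w", encoding="utf-8") as f:
        json.dump(asdict(config), f, indent=2)

    reloaded = MMDiTConfig.from_json_file("mmdit_config.json")
    assert reloaded == config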