|
import json
from dataclasses import dataclass
from typing import Optional
|
|
|
# Module-level defaults reused by MMDiTConfig: ENCODED_TEXT_DIM seeds
# ``modulation_dim``, POOLED_TEXT_DIM seeds ``pooled_text_dim`` and
# VAE_COMPRESSION_RATIO seeds ``vae_compression``.
ENCODED_TEXT_DIM: int = 4_096
POOLED_TEXT_DIM: int = 2_048
VAE_COMPRESSION_RATIO: int = 8
|
|
|
|
|
@dataclass
class MMDiTConfig:
    """Hyperparameter container for an MMDiT model.

    Every field has a default, so ``MMDiTConfig()`` is valid on its own.
    Use :meth:`from_json_file` to build an instance from a JSON file whose
    top-level keys are field names.

    Field order defines the positional order of the ``__init__`` that
    ``@dataclass`` generates. Append new fields at the end rather than
    inserting them, so positional callers keep working.
    """

    # Backbone size and input geometry.
    num_layers: int = 12
    hidden_dim: int = 768
    patch_size: int = 2
    image_dim: int = 224
    in_channel: int = 4
    out_channel: int = 4
    modulation_dim: int = ENCODED_TEXT_DIM
    height: int = 1024
    width: int = 1024
    vae_compression: int = VAE_COMPRESSION_RATIO
    vae_type: str = "SD3"
    # Defaults to None, so the annotation must admit None.
    pos_emb_size: Optional[int] = None
    conv_header: bool = False

    # Text / timestep conditioning sizes.
    time_embed_dim: int = 2048
    pooled_text_dim: int = POOLED_TEXT_DIM
    text_emb_dim: int = 768

    # Block internals.
    t_emb_dim: int = 256
    attn_embed_dim: int = 768
    mlp_hidden_dim: int = 2048
    attn_mode: Optional[str] = None
    use_final_layer_norm: bool = False
    use_time_token_in_attn: bool = False

    # Attention heads.
    num_attention_heads: int = 12
    num_key_value_heads: int = 6
    use_scaled_dot_product_attention: bool = True
    dropout: float = 0.0

    # Modulation.
    use_modulation: bool = True
    modulation_type: str = "film"

    # Register tokens.
    register_token_num: int = 4
    additional_register_token_num: int = 12

    # DINOv2 feature-alignment auxiliary loss.
    dinov2_feature_align_loss: bool = False
    feature_align_loss_weight: float = 0.5
    num_feature_align_layers: int = 8

    # Image encoder / backbone freezing.
    image_encoder_name: Optional[str] = None
    freeze_dit_backbone: bool = False

    # Preference training / LoRA.
    preference_train: bool = False
    lora_rank: int = 64
    lora_alpha: int = 8

    skip_register_token_num: int = 0

    @classmethod
    def from_json_file(cls, json_file):
        """Instantiate this config class from a JSON file of field values.

        Args:
            json_file (`str` or `os.PathLike`):
                Path to a UTF-8 JSON file whose top-level value is an object
                mapping field names to values. Fields missing from the file
                keep their defaults.

        Returns:
            An instance of ``cls`` populated from the file.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
            TypeError: If the top-level JSON value is not an object, or if it
                contains keys that are not init fields of ``cls``.
        """
        from dataclasses import fields

        config_dict = cls._dict_from_json_file(json_file)
        if not isinstance(config_dict, dict):
            raise TypeError(
                f"Expected a JSON object at the top level of {json_file!r}, "
                f"got {type(config_dict).__name__}."
            )

        # Report every unexpected key at once; ``cls(**config_dict)`` would
        # name only the first. Still a TypeError, as ``__init__`` would raise.
        known = {f.name for f in fields(cls) if f.init}
        unknown = sorted(set(config_dict) - known)
        if unknown:
            raise TypeError(
                f"Unknown {cls.__name__} field(s) in {json_file!r}: "
                f"{', '.join(unknown)}"
            )
        return cls(**config_dict)

    @classmethod
    def _dict_from_json_file(cls, json_file):
        """Read ``json_file`` as UTF-8 and return the decoded JSON value."""
        with open(json_file, "r", encoding="utf-8") as reader:
            return json.load(reader)
|
|