JasonSmithSO's picture
Upload 578 files
8866644 verified
"""
List of all HYDiT model types / settings
"""
from argparse import Namespace
hydit_args = Namespace(**{ # normally from argparse
"infer_mode": "torch",
"norm": "layer",
"learn_sigma": True,
"text_states_dim": 1024,
"text_states_dim_t5": 2048,
"text_len": 77,
"text_len_t5": 256,
})
hydit_conf = {
"G/2": { # Seems to be the main one
"unet_config": {
"depth" : 40,
"num_heads" : 16,
"patch_size" : 2,
"hidden_size" : 1408,
"mlp_ratio" : 4.3637,
"input_size": (1024//8, 1024//8),
"args": hydit_args,
},
"sampling_settings" : {
"beta_schedule" : "linear",
"linear_start" : 0.00085,
"linear_end" : 0.03,
"timesteps" : 1000,
},
},
"G/2-1.2": {
"unet_config": {
"depth" : 40,
"num_heads" : 16,
"patch_size" : 2,
"hidden_size" : 1408,
"mlp_ratio" : 4.3637,
"input_size": (1024//8, 1024//8),
"cond_style": False,
"cond_res" : False,
"args": hydit_args,
},
"sampling_settings" : {
"beta_schedule" : "linear",
"linear_start" : 0.00085,
"linear_end" : 0.018,
"timesteps" : 1000,
},
}
}
# these are the same as regular DiT, I think
from ..DiT.conf import dit_conf
for name in ["XL/2", "L/2", "B/2"]:
hydit_conf[name] = {
"unet_config": dit_conf[name]["unet_config"].copy(),
"sampling_settings": hydit_conf["G/2"]["sampling_settings"],
}
hydit_conf[name]["unet_config"]["args"] = hydit_args