Spaces:
Running
on
A100
Running
on
A100
File size: 4,430 Bytes
43c5292 1a63574 43c5292 1a63574 43c5292 1a63574 43c5292 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
import os
import copy
from hyimage.common.config import LazyCall as L
from hyimage.models.hunyuan.configs.hunyuanimage_config import (
hunyuanimage_v2_1_cfg,
hunyuanimage_v2_1_distilled_cfg,
hunyuanimage_refiner_cfg,
)
from hyimage.models.vae import load_refiner_vae, load_vae
from hyimage.common.config.base_config import (
DiTConfig,
RepromptConfig,
TextEncoderConfig,
VAEConfig,
)
from hyimage.models.text_encoder import TextEncoder
# Root directory for every checkpoint path built in this module. It is read
# once at import time from the HUNYUANIMAGE_V2_1_MODEL_ROOT environment
# variable and falls back to "./ckpts" when that variable is unset.
HUNYUANIMAGE_V2_1_MODEL_ROOT = os.getenv("HUNYUANIMAGE_V2_1_MODEL_ROOT", "./ckpts")
# =============================================================================
# MODEL CONFIGURATIONS
# =============================================================================
# =============================================================================
# V2.1 MODELS
# =============================================================================
def HUNYUANIMAGE_V2_1_TEXT_ENCODER(**kwargs):
    """Return the text-encoder configuration for the V2.1 models.

    ``**kwargs`` is accepted but never read in this function, so keyword
    arguments passed by a caller have no effect on the result.

    Returns:
        A ``TextEncoderConfig`` whose ``model`` is a ``LazyCall`` spec for
        ``TextEncoder``, with the "dit-llm-encode-v2" prompt template, a
        text length of 1000, and ``load_from`` pointing at
        ``<HUNYUANIMAGE_V2_1_MODEL_ROOT>/text_encoder``.
    """
    # Arguments recorded for TextEncoder through LazyCall.
    # NOTE(review): presumably the None placeholders (path, templates,
    # logger, device) are filled in by whoever instantiates the spec -- confirm.
    encoder_args = dict(
        text_encoder_type="llm",
        max_length=1000,
        text_encoder_precision="fp16",
        tokenizer_type="llm",
        text_encoder_path=None,
        prompt_template=None,
        prompt_template_video=None,
        hidden_state_skip_layer=2,
        apply_final_norm=False,
        reproduce=False,
        logger=None,
        device=None,
    )
    encoder_spec = L(TextEncoder)(**encoder_args)
    weights_dir = f"{HUNYUANIMAGE_V2_1_MODEL_ROOT}/text_encoder"

    return TextEncoderConfig(
        model=encoder_spec,
        prompt_template="dit-llm-encode-v2",
        load_from=weights_dir,
        text_len=1000,
    )
def HUNYUANIMAGE_V2_1_VAE_32x(**kwargs):
    """Return the VAE configuration for the V2.1 models.

    ``**kwargs`` is accepted but never read in this function, so keyword
    arguments passed by a caller have no effect on the result.

    Returns:
        A ``VAEConfig`` whose ``model`` is a ``LazyCall`` spec for
        ``load_vae`` targeting the "cuda" device, with CPU offload disabled
        and ``load_from`` pointing at
        ``<HUNYUANIMAGE_V2_1_MODEL_ROOT>/vae/vae_2_1``.
    """
    # vae_path is left as None in the spec; the checkpoint location is
    # carried separately in load_from.
    vae_spec = L(load_vae)(vae_path=None, device="cuda")
    weights_dir = f"{HUNYUANIMAGE_V2_1_MODEL_ROOT}/vae/vae_2_1"

    return VAEConfig(
        model=vae_spec,
        load_from=weights_dir,
        cpu_offload=False,
    )
def HUNYUANIMAGE_V2_1_DIT(**kwargs):
    """Return the DiT configuration for the base V2.1 model.

    ``**kwargs`` is accepted but never read in this function, so keyword
    arguments passed by a caller have no effect on the result.

    Returns:
        A ``DiTConfig`` built from a deep copy of ``hunyuanimage_v2_1_cfg``,
        with LoRA, CPU offload and compilation disabled, gradient
        checkpointing enabled, and ``load_from`` pointing at
        ``<HUNYUANIMAGE_V2_1_MODEL_ROOT>/dit/hunyuanimage2.1.safetensors``.
    """
    # Deep-copy so the returned config does not share the module-level
    # hunyuanimage_v2_1_cfg object with other callers.
    dit_spec = copy.deepcopy(hunyuanimage_v2_1_cfg)
    checkpoint = f"{HUNYUANIMAGE_V2_1_MODEL_ROOT}/dit/hunyuanimage2.1.safetensors"

    return DiTConfig(
        model=dit_spec,
        use_lora=False,
        use_cpu_offload=False,
        gradient_checkpointing=True,
        load_from=checkpoint,
        use_compile=False,
    )
def HUNYUANIMAGE_V2_1_DIT_CFG_DISTILL(**kwargs):
    """Return the DiT configuration for the distilled V2.1 model.

    ``**kwargs`` is accepted but never read in this function, so keyword
    arguments passed by a caller have no effect on the result.

    Returns:
        A ``DiTConfig`` built from a deep copy of
        ``hunyuanimage_v2_1_distilled_cfg``, with LoRA, CPU offload and
        compilation disabled, gradient checkpointing enabled, and
        ``load_from`` pointing at
        ``<HUNYUANIMAGE_V2_1_MODEL_ROOT>/dit/hunyuanimage2.1-distilled.safetensors``.
    """
    # Deep-copy so the returned config does not share the module-level
    # hunyuanimage_v2_1_distilled_cfg object with other callers.
    dit_spec = copy.deepcopy(hunyuanimage_v2_1_distilled_cfg)
    checkpoint = f"{HUNYUANIMAGE_V2_1_MODEL_ROOT}/dit/hunyuanimage2.1-distilled.safetensors"

    return DiTConfig(
        model=dit_spec,
        use_lora=False,
        use_cpu_offload=False,
        gradient_checkpointing=True,
        load_from=checkpoint,
        use_compile=False,
    )
# =============================================================================
# REFINER MODELS
# =============================================================================
def HUNYUANIMAGE_REFINER_DIT(**kwargs):
    """Return the DiT configuration for the refiner model.

    ``**kwargs`` is accepted but never read in this function, so keyword
    arguments passed by a caller have no effect on the result.

    Returns:
        A ``DiTConfig`` built from a deep copy of
        ``hunyuanimage_refiner_cfg``, with LoRA, CPU offload and compilation
        disabled, gradient checkpointing enabled, and ``load_from`` pointing
        at ``<HUNYUANIMAGE_V2_1_MODEL_ROOT>/dit/hunyuanimage-refiner.safetensors``.
    """
    # Deep-copy so the returned config does not share the module-level
    # hunyuanimage_refiner_cfg object with other callers.
    refiner_spec = copy.deepcopy(hunyuanimage_refiner_cfg)
    checkpoint = f"{HUNYUANIMAGE_V2_1_MODEL_ROOT}/dit/hunyuanimage-refiner.safetensors"

    return DiTConfig(
        model=refiner_spec,
        use_lora=False,
        use_cpu_offload=False,
        gradient_checkpointing=True,
        load_from=checkpoint,
        use_compile=False,
    )
def HUNYUANIMAGE_REFINER_VAE_16x(**kwargs):
    """Return the VAE configuration for the refiner model.

    ``**kwargs`` is accepted but never read in this function, so keyword
    arguments passed by a caller have no effect on the result.

    Returns:
        A ``VAEConfig`` whose ``model`` is a ``LazyCall`` spec for
        ``load_refiner_vae`` targeting the "cuda" device, with CPU offload
        disabled and ``load_from`` pointing at
        ``<HUNYUANIMAGE_V2_1_MODEL_ROOT>/vae/vae_refiner``.
    """
    # vae_path is left as None in the spec; the checkpoint location is
    # carried separately in load_from.
    refiner_vae_spec = L(load_refiner_vae)(vae_path=None, device="cuda")
    weights_dir = f"{HUNYUANIMAGE_V2_1_MODEL_ROOT}/vae/vae_refiner"

    return VAEConfig(
        model=refiner_vae_spec,
        load_from=weights_dir,
        cpu_offload=False,
    )
def HUNYUANIMAGE_REFINER_TEXT_ENCODER(**kwargs):
    """Return the text-encoder configuration for the refiner model.

    ``**kwargs`` is accepted but never read in this function, so keyword
    arguments passed by a caller have no effect on the result.

    Returns:
        A ``TextEncoderConfig`` whose ``model`` is a ``LazyCall`` spec for
        ``TextEncoder``, with the "dit-llm-encode" prompt template, a text
        length of 256, and ``load_from`` pointing at
        ``<HUNYUANIMAGE_V2_1_MODEL_ROOT>/text_encoder``.
    """
    # Arguments recorded for TextEncoder through LazyCall. max_length stays
    # at 1000 here even though the config-level text_len below is 256.
    # NOTE(review): confirm that this difference is intentional.
    encoder_args = dict(
        text_encoder_type="llm",
        max_length=1000,
        text_encoder_precision="fp16",
        tokenizer_type="llm",
        text_encoder_path=None,
        prompt_template=None,
        prompt_template_video=None,
        hidden_state_skip_layer=2,
        apply_final_norm=False,
        reproduce=False,
        logger=None,
        device=None,
    )
    encoder_spec = L(TextEncoder)(**encoder_args)
    weights_dir = f"{HUNYUANIMAGE_V2_1_MODEL_ROOT}/text_encoder"

    return TextEncoderConfig(
        model=encoder_spec,
        prompt_template="dit-llm-encode",
        load_from=weights_dir,
        text_len=256,
    )
# =============================================================================
# SPECIALIZED MODELS
# =============================================================================
def HUNYUANIMAGE_REPROMPT(**kwargs):
    """Return the configuration for the reprompt model.

    ``**kwargs`` is accepted but never read in this function, so keyword
    arguments passed by a caller have no effect on the result.

    Returns:
        A ``RepromptConfig`` whose ``model`` is a ``LazyCall`` spec for
        ``RePrompt`` with ``device_map="auto"``, and ``load_from`` pointing
        at ``<HUNYUANIMAGE_V2_1_MODEL_ROOT>/reprompt``.
    """
    # Imported at call time, so hyimage.models.reprompt is only loaded when
    # this config is actually requested.
    # NOTE(review): presumably this avoids a heavy or circular import -- confirm.
    from hyimage.models.reprompt import RePrompt

    reprompt_spec = L(RePrompt)(models_root_path=None, device_map="auto")
    weights_dir = f"{HUNYUANIMAGE_V2_1_MODEL_ROOT}/reprompt"

    return RepromptConfig(
        model=reprompt_spec,
        load_from=weights_dir,
    )