# Loads the CLIP (BERT-based) and mT5 text encoders for HunYuanDiT
import os
import torch
from transformers import AutoTokenizer, modeling_utils
from transformers import T5Config, T5EncoderModel, BertConfig, BertModel

from comfy import model_management
import comfy.model_patcher
import comfy.utils

# mT5 encoder wrapper; weights are created uninitialized and loaded later from a state dict.
class mT5Model(torch.nn.Module):
    def __init__(self, textmodel_json_config=None, device="cpu", max_length=256, freeze=True, dtype=None):
        super().__init__()
        self.device = device
        self.dtype = dtype
        self.max_length = max_length
        if textmodel_json_config is None:
            textmodel_json_config = os.path.join(
                os.path.dirname(os.path.realpath(__file__)),
                "config_mt5.json"
            )
        config = T5Config.from_json_file(textmodel_json_config)
        with modeling_utils.no_init_weights():
            self.transformer = T5EncoderModel(config)
        self.to(dtype)
        if freeze:
            self.freeze()

    def freeze(self):
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False

    def load_sd(self, sd):
        return self.transformer.load_state_dict(sd, strict=False)

    def to(self, *args, **kwargs):
        return self.transformer.to(*args, **kwargs)

# CLIP text encoder wrapper (BERT architecture); weights are created uninitialized and loaded later from a state dict.
class hyCLIPModel(torch.nn.Module):
    def __init__(self, textmodel_json_config=None, device="cpu", max_length=77, freeze=True, dtype=None):
        super().__init__()
        self.device = device
        self.dtype = dtype
        self.max_length = max_length
        if textmodel_json_config is None:
            textmodel_json_config = os.path.join(
                os.path.dirname(os.path.realpath(__file__)),
                "config_clip.json"
            )
        config = BertConfig.from_json_file(textmodel_json_config)
        with modeling_utils.no_init_weights():
            self.transformer = BertModel(config)
        self.to(dtype)
        if freeze:
            self.freeze()

    def freeze(self):
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False

    def load_sd(self, sd):
        return self.transformer.load_state_dict(sd, strict=False)

    def to(self, *args, **kwargs):
        return self.transformer.to(*args, **kwargs)

class EXM_HyDiT_Tenc_Temp:
    def __init__(self, no_init=False, device="cpu", dtype=None, model_class="mT5", **kwargs):
        if no_init:
            return

        # Rough model footprint in GiB (8 for mT5, 2 for CLIP; doubled for fp32),
        # converted to bytes and passed to the ModelPatcher for memory management.
        size = 8 if model_class == "mT5" else 2
        if dtype == torch.float32:
            size *= 2
        size *= (1024**3)

        if device == "auto":
            self.load_device = model_management.text_encoder_device()
            self.offload_device = model_management.text_encoder_offload_device()
            self.init_device = "cpu"
        elif device == "cpu":
            size = 0  # doesn't matter
            self.load_device = "cpu"
            self.offload_device = "cpu"
            self.init_device = "cpu"
        elif device.startswith("cuda"):
            print("Direct CUDA device override!\nVRAM will not be freed by default.")
            size = 0  # not used
            self.load_device = device
            self.offload_device = device
            self.init_device = device
        else:
            self.load_device = model_management.get_torch_device()
            self.offload_device = "cpu"
            self.init_device = "cpu"

        self.dtype = dtype
        self.device = self.load_device

        if model_class == "mT5":
            self.cond_stage_model = mT5Model(
                device = self.load_device,
                dtype = self.dtype,
            )
            tokenizer_args = {"subfolder": "t2i/mt5"}  # web
            tokenizer_path = os.path.join(  # local
                os.path.dirname(os.path.realpath(__file__)),
                "mt5_tokenizer",
            )
        else:
            self.cond_stage_model = hyCLIPModel(
                device = self.load_device,
                dtype = self.dtype,
            )
            tokenizer_args = {"subfolder": "t2i/tokenizer"}  # web
            tokenizer_path = os.path.join(  # local
                os.path.dirname(os.path.realpath(__file__)),
                "tokenizer",
            )

        # self.tokenizer = AutoTokenizer.from_pretrained(
        #     "Tencent-Hunyuan/HunyuanDiT",
        #     **tokenizer_args
        # )
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

        self.patcher = comfy.model_patcher.ModelPatcher(
            self.cond_stage_model,
            load_device = self.load_device,
            offload_device = self.offload_device,
            current_device = self.load_device,
            size = size,
        )
    def clone(self):
        n = EXM_HyDiT_Tenc_Temp(no_init=True)
        n.patcher = self.patcher.clone()
        n.cond_stage_model = self.cond_stage_model
        n.tokenizer = self.tokenizer
        return n

    def load_sd(self, sd):
        return self.cond_stage_model.load_sd(sd)

    def get_sd(self):
        return self.cond_stage_model.state_dict()

    def load_model(self):
        if self.load_device != "cpu":
            model_management.load_model_gpu(self.patcher)
        return self.patcher

    def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
        return self.patcher.add_patches(patches, strength_patch, strength_model)

    def get_key_patches(self):
        return self.patcher.get_key_patches()

def load_clip(model_path, **kwargs):
    model = EXM_HyDiT_Tenc_Temp(model_class="clip", **kwargs)
    sd = comfy.utils.load_torch_file(model_path)
    # Strip the "bert." prefix from checkpoint keys so they match the wrapped BertModel.
    prefix = "bert."
    state_dict = {}
    for key in sd:
        nkey = key
        if key.startswith(prefix):
            nkey = key[len(prefix):]
        state_dict[nkey] = sd[key]
    m, e = model.load_sd(state_dict)  # missing keys, unexpected keys
    if len(m) > 0 or len(e) > 0:
        print(f"HYDiT: clip missing {len(m)} keys ({len(e)} extra)")
    return model

def load_t5(model_path, **kwargs):
    model = EXM_HyDiT_Tenc_Temp(model_class="mT5", **kwargs)
    sd = comfy.utils.load_torch_file(model_path)
    m, e = model.load_sd(sd)
    if len(m) > 0 or len(e) > 0:
        print(f"HYDiT: mT5 missing {len(m)} keys ({len(e)} extra)")
    return model
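

if __name__ == "__main__":
    # Usage sketch (illustrative addition, not part of the original node code):
    # shows how the two loaders might be combined to produce text conditioning for
    # HunYuanDiT. The checkpoint paths below are placeholder assumptions; fp16 also
    # assumes a GPU is available, so pass dtype=torch.float32 when running on CPU.
    clip = load_clip("path/to/hunyuan_clip_text_encoder.safetensors", device="auto", dtype=torch.float16)
    mt5 = load_t5("path/to/hunyuan_mt5_encoder.safetensors", device="auto", dtype=torch.float16)

    prompt = "a photo of a cat"
    tokens = clip.tokenizer(
        prompt,
        padding="max_length",
        max_length=clip.cond_stage_model.max_length,
        truncation=True,
        return_tensors="pt",
    )
    clip.load_model()  # let ComfyUI move the weights onto the load device
    out = clip.cond_stage_model.transformer(
        input_ids=tokens["input_ids"].to(clip.load_device),
        attention_mask=tokens["attention_mask"].to(clip.load_device),
    )
    # e.g. [1, 77, hidden_dim]; the mT5 branch is used the same way via mt5.tokenizer / mt5.cond_stage_model
    print(out.last_hidden_state.shape)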