File size: 5,395 Bytes
8866644
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
# This is for loading the CLIP (bert?) + mT5 encoder for HunYuanDiT
import os
import torch
from transformers import AutoTokenizer, modeling_utils
from transformers import T5Config, T5EncoderModel, BertConfig, BertModel

from comfy import model_management
import comfy.model_patcher
import comfy.utils

class mT5Model(torch.nn.Module):
	"""Wrapper around a Hugging Face T5EncoderModel (mT5) used as a HunYuanDiT text encoder.

	The transformer is built from a local JSON config with *uninitialized*
	weights; real weights are expected to be supplied later via load_sd().
	"""
	def __init__(self, textmodel_json_config=None, device="cpu", max_length=256, freeze=True, dtype=None):
		super().__init__()
		self.device = device
		self.dtype = dtype
		self.max_length = max_length
		if textmodel_json_config is None:
			# Default to the config file shipped next to this module.
			textmodel_json_config = os.path.join(
				os.path.dirname(os.path.realpath(__file__)),
				"config_mt5.json"  # plain literal; the f-string prefix here was a no-op
			)
		config = T5Config.from_json_file(textmodel_json_config)
		# Skip random weight initialization — the state dict is loaded separately.
		with modeling_utils.no_init_weights():
			self.transformer = T5EncoderModel(config)
		self.to(dtype)
		if freeze:
			self.freeze()

	def freeze(self):
		"""Put the encoder in eval mode and disable gradients on all parameters."""
		self.transformer = self.transformer.eval()
		for param in self.parameters():
			param.requires_grad = False

	def load_sd(self, sd):
		# strict=False: checkpoints may carry extra or missing keys; caller inspects the result.
		return self.transformer.load_state_dict(sd, strict=False)

	def to(self, *args, **kwargs):
		# Only the wrapped transformer holds tensors, so delegate moves/casts to it.
		return self.transformer.to(*args, **kwargs)

class hyCLIPModel(torch.nn.Module):
	"""Wrapper around a Hugging Face BertModel used as the HunYuanDiT "CLIP" text encoder.

	Built from a local JSON config with *uninitialized* weights; real weights
	are expected to be supplied later via load_sd().
	"""
	def __init__(self, textmodel_json_config=None, device="cpu", max_length=77, freeze=True, dtype=None):
		super().__init__()
		self.device = device
		self.dtype = dtype
		self.max_length = max_length
		if textmodel_json_config is None:
			# Default to the config file shipped next to this module.
			textmodel_json_config = os.path.join(
				os.path.dirname(os.path.realpath(__file__)),
				"config_clip.json"  # plain literal; the f-string prefix here was a no-op
			)
		config = BertConfig.from_json_file(textmodel_json_config)
		# Skip random weight initialization — the state dict is loaded separately.
		with modeling_utils.no_init_weights():
			self.transformer = BertModel(config)
		self.to(dtype)
		if freeze:
			self.freeze()

	def freeze(self):
		"""Put the encoder in eval mode and disable gradients on all parameters."""
		self.transformer = self.transformer.eval()
		for param in self.parameters():
			param.requires_grad = False

	def load_sd(self, sd):
		# strict=False: checkpoints may carry extra or missing keys; caller inspects the result.
		return self.transformer.load_state_dict(sd, strict=False)

	def to(self, *args, **kwargs):
		# Only the wrapped transformer holds tensors, so delegate moves/casts to it.
		return self.transformer.to(*args, **kwargs)

class EXM_HyDiT_Tenc_Temp:
	"""Bundles one HunYuanDiT text encoder (mT5 or CLIP/BERT), its tokenizer,
	and a ComfyUI ModelPatcher for device placement/offloading.

	device: "auto" (let ComfyUI pick), "cpu", an explicit "cuda*" string,
	or anything else (falls back to the default torch device).
	"""
	def __init__(self, no_init=False, device="cpu", dtype=None, model_class="mT5", **kwargs):
		# BUGFIX: was `*kwargs`, which collected extra *positional* args; the
		# `load_clip`/`load_t5` wrappers pass keyword args through, and any
		# unrecognized keyword would have raised TypeError. `**kwargs` now
		# absorbs (and ignores) extra keywords, as intended.
		if no_init:
			# Bare instance for clone(); attributes are filled in by the caller.
			return

		# Rough VRAM estimate in bytes for the patcher's memory accounting.
		# mT5 ~8 GiB, CLIP/BERT ~2 GiB at fp16; doubled for fp32.
		size = 8 if model_class == "mT5" else 2
		if dtype == torch.float32:
			size *= 2
		size *= (1024**3)

		if device == "auto":
			self.load_device = model_management.text_encoder_device()
			self.offload_device = model_management.text_encoder_offload_device()
			self.init_device = "cpu"
		elif device == "cpu":
			size = 0 # doesn't matter
			self.load_device = "cpu"
			self.offload_device = "cpu"
			self.init_device="cpu"
		elif device.startswith("cuda"):
			print("Direct CUDA device override!\nVRAM will not be freed by default.")
			size = 0 # not used
			self.load_device = device
			self.offload_device = device
			self.init_device = device
		else:
			self.load_device = model_management.get_torch_device()
			self.offload_device = "cpu"
			self.init_device="cpu"

		self.dtype = dtype
		self.device = self.load_device
		if model_class == "mT5":
			self.cond_stage_model = mT5Model(
				device         = self.load_device,
				dtype          = self.dtype,
			)
			tokenizer_args = {"subfolder": "t2i/mt5"} # web
			tokenizer_path = os.path.join( # local
				os.path.dirname(os.path.realpath(__file__)),
				"mt5_tokenizer",
			)
		else:
			self.cond_stage_model = hyCLIPModel(
				device         = self.load_device,
				dtype          = self.dtype,
			)
			tokenizer_args = {"subfolder": "t2i/tokenizer",} # web
			tokenizer_path = os.path.join( # local
				os.path.dirname(os.path.realpath(__file__)),
				"tokenizer",
			)
		# self.tokenizer = AutoTokenizer.from_pretrained(
			# "Tencent-Hunyuan/HunyuanDiT",
			# **tokenizer_args
		# )
		# Tokenizer files are bundled locally; the commented block above is the
		# hub-download alternative (tokenizer_args is only used by that path).
		self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
		self.patcher = comfy.model_patcher.ModelPatcher(
			self.cond_stage_model,
			load_device    = self.load_device,
			offload_device = self.offload_device,
			current_device = self.load_device,
			size           = size,
		)

	def clone(self):
		"""Shallow clone: new patcher, shared encoder and tokenizer."""
		n = EXM_HyDiT_Tenc_Temp(no_init=True)
		n.patcher = self.patcher.clone()
		n.cond_stage_model = self.cond_stage_model
		n.tokenizer = self.tokenizer
		return n

	def load_sd(self, sd):
		return self.cond_stage_model.load_sd(sd)

	def get_sd(self):
		return self.cond_stage_model.state_dict()

	def load_model(self):
		# On CPU there is nothing to page in; otherwise let ComfyUI manage VRAM.
		if self.load_device != "cpu":
			model_management.load_model_gpu(self.patcher)
		return self.patcher

	def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
		return self.patcher.add_patches(patches, strength_patch, strength_model)

	def get_key_patches(self):
		return self.patcher.get_key_patches()

def load_clip(model_path, **kwargs):
	"""Load the HunYuanDiT CLIP/BERT text encoder from a checkpoint file.

	Strips the leading "bert." prefix from checkpoint keys so they match the
	bare BertModel state dict, loads with strict=False, and prints a summary
	when keys are missing or unexpected.
	"""
	model = EXM_HyDiT_Tenc_Temp(model_class="clip", **kwargs)
	sd = comfy.utils.load_torch_file(model_path)

	prefix = "bert."
	# Remap "bert.xxx" -> "xxx"; keys without the prefix pass through unchanged.
	# (Replaces a manual loop with mangled indentation.)
	state_dict = {
		(key[len(prefix):] if key.startswith(prefix) else key): value
		for key, value in sd.items()
	}

	m, e = model.load_sd(state_dict)
	if len(m) > 0 or len(e) > 0:
		print(f"HYDiT: clip missing {len(m)} keys ({len(e)} extra)")
	return model

def load_t5(model_path, **kwargs):
	"""Load the HunYuanDiT mT5 text encoder from a checkpoint file and report
	any missing/unexpected state-dict keys."""
	model = EXM_HyDiT_Tenc_Temp(model_class="mT5", **kwargs)
	state_dict = comfy.utils.load_torch_file(model_path)
	m, e = model.load_sd(state_dict)
	if m or e:
		print(f"HYDiT: mT5 missing {len(m)} keys ({len(e)} extra)")
	return model