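# ComfyUI nodes for Hunyuan DiT: a checkpoint loader, text encoder loaders,
# prompt encoders, and size conditioning.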
import os
import folder_paths
from copy import deepcopy
from .conf import hydit_conf
from .loader import load_hydit
class HYDiTCheckpointLoader:
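    """Load a Hunyuan DiT checkpoint from the "checkpoints" folder using the
    selected model config from hydit_conf and return a ComfyUI MODEL."""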
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "ckpt_name": (folder_paths.get_filename_list("checkpoints"),),
                "model": (list(hydit_conf.keys()), {"default": "G/2"}),
            }
        }

    RETURN_TYPES = ("MODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "load_checkpoint"
    CATEGORY = "ExtraModels/HunyuanDiT"
    TITLE = "Hunyuan DiT Checkpoint Loader"

    def load_checkpoint(self, ckpt_name, model):
        ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
        model_conf = hydit_conf[model]
        model = load_hydit(
            model_path=ckpt_path,
            model_conf=model_conf,
        )
        return (model,)

#### temp stuff for the text encoder ####
import torch
from .tenc import load_clip, load_t5
from ..utils.dtype import string_to_dtype
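# dtype options shown in the UI; string_to_dtype() resolves them to torch dtypes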
dtypes = [
    "default",
    "auto (comfy)",
    "FP32",
    "FP16",
    "BF16",
]

class HYDiTTextEncoderLoader:
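    """Load the CLIP and mT5 text encoders used by Hunyuan DiT onto the
    chosen device; extra CUDA devices are listed so a second GPU can serve
    as an offload target."""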
    @classmethod
    def INPUT_TYPES(s):
        devices = ["auto", "cpu", "gpu"]
        # hack: list extra CUDA devices so a second GPU can be used for offload
        for k in range(1, torch.cuda.device_count()):
            devices.append(f"cuda:{k}")
        return {
            "required": {
                "clip_name": (folder_paths.get_filename_list("clip"),),
                "mt5_name": (folder_paths.get_filename_list("t5"),),
                "device": (devices, {"default": "cpu"}),
                "dtype": (dtypes,),
            }
        }

    RETURN_TYPES = ("CLIP", "T5")
    FUNCTION = "load_model"
    CATEGORY = "ExtraModels/HunyuanDiT"
    TITLE = "Hunyuan DiT Text Encoder Loader"

    def load_model(self, clip_name, mt5_name, device, dtype):
        dtype = string_to_dtype(dtype, "text_encoder")
        if device == "cpu":
            assert dtype in [None, torch.float32, torch.bfloat16], \
                f"Can't use dtype '{dtype}' with CPU! Set dtype to 'default', 'FP32' or 'BF16'."
        clip = load_clip(
            model_path=folder_paths.get_full_path("clip", clip_name),
            device=device,
            dtype=dtype,
        )
        t5 = load_t5(
            model_path=folder_paths.get_full_path("t5", mt5_name),
            device=device,
            dtype=dtype,
        )
        return (clip, t5)

class HYDiTTextEncode:
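    """Encode separate CLIP and mT5 prompts into one CONDITIONING entry.
    The CLIP embeddings form the main tensor; the mT5 embeddings and both
    attention masks are carried in the accompanying cond dict."""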
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "text": ("STRING", {"multiline": True}),
                "text_t5": ("STRING", {"multiline": True}),
                "CLIP": ("CLIP",),
                "T5": ("T5",),
            }
        }

    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "encode"
    CATEGORY = "ExtraModels/HunyuanDiT"
    TITLE = "Hunyuan DiT Text Encode"
    def encode(self, text, text_t5, CLIP, T5):
        # mT5: tokenize, then keep the last hidden state of the transformer
        T5.load_model()
        t5_pre = T5.tokenizer(
            text_t5,
            max_length=T5.cond_stage_model.max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors='pt',
        )
        t5_mask = t5_pre["attention_mask"]
        with torch.no_grad():
            t5_outs = T5.cond_stage_model.transformer(
                input_ids=t5_pre["input_ids"].to(T5.load_device),
                attention_mask=t5_mask.to(T5.load_device),
                output_hidden_states=True,
            )
            # to-do: replace -1 with a configurable clip-skip value
            t5_embs = t5_outs["hidden_states"][-1].float().cpu()

        # "clip": tokenize, then take the transformer's pooled/sequence output
        CLIP.load_model()
        clip_pre = CLIP.tokenizer(
            text,
            max_length=CLIP.cond_stage_model.max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors='pt',
        )
        clip_mask = clip_pre["attention_mask"]
        with torch.no_grad():
            clip_outs = CLIP.cond_stage_model.transformer(
                input_ids=clip_pre["input_ids"].to(CLIP.load_device),
                attention_mask=clip_mask.to(CLIP.load_device),
            )
            # to-do: add hidden states
            clip_embs = clip_outs[0].float().cpu()

        # combined cond: CLIP embeddings as the main tensor, with the mT5
        # embeddings and both attention masks passed along in the cond dict
        return ([[
            clip_embs, {
                "context_t5": t5_embs,
                "context_mask": clip_mask.float(),
                "context_t5_mask": t5_mask.float(),
            }
        ]],)

class HYDiTTextEncodeSimple(HYDiTTextEncode):
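    """Single-prompt variant: feeds the same text to both encoders."""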
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "text": ("STRING", {"multiline": True}),
                "CLIP": ("CLIP",),
                "T5": ("T5",),
            }
        }

    FUNCTION = "encode_simple"
    TITLE = "Hunyuan DiT Text Encode (simple)"

    def encode_simple(self, text, **args):
        return self.encode(text=text, text_t5=text, **args)

class HYDiTSrcSizeCond:
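    """Attach a source-size (height, width) hint to existing conditioning.
    The input is deep-copied so upstream conditioning is left untouched."""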
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "cond": ("CONDITIONING",),
                "width": ("INT", {"default": 1024, "min": 0, "max": 8192, "step": 16}),
                "height": ("INT", {"default": 1024, "min": 0, "max": 8192, "step": 16}),
            }
        }

    RETURN_TYPES = ("CONDITIONING",)
    RETURN_NAMES = ("cond",)
    FUNCTION = "add_cond"
    CATEGORY = "ExtraModels/HunyuanDiT"
    TITLE = "Hunyuan DiT Size Conditioning (advanced)"

    def add_cond(self, cond, width, height):
        cond = deepcopy(cond)
        for c in range(len(cond)):
            cond[c][1].update({
                "src_size_cond": [[height, width]],
            })
        return (cond,)

NODE_CLASS_MAPPINGS = {
    "HYDiTCheckpointLoader": HYDiTCheckpointLoader,
    "HYDiTTextEncoderLoader": HYDiTTextEncoderLoader,
    "HYDiTTextEncode": HYDiTTextEncode,
    "HYDiTTextEncodeSimple": HYDiTTextEncodeSimple,
    "HYDiTSrcSizeCond": HYDiTSrcSizeCond,
}
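
# Assumption: display names are not mapped elsewhere in the package. ComfyUI
# also reads an optional NODE_DISPLAY_NAME_MAPPINGS export, which can reuse
# the TITLE attribute already defined on each node class.
NODE_DISPLAY_NAME_MAPPINGS = {k: v.TITLE for k, v in NODE_CLASS_MAPPINGS.items()}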