import os
import folder_paths
from copy import deepcopy

from .conf import hydit_conf
from .loader import load_hydit
class HYDiTCheckpointLoader:
    # ComfyUI calls INPUT_TYPES on the class itself, so it must be a classmethod
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "ckpt_name": (folder_paths.get_filename_list("checkpoints"),),
                "model": (list(hydit_conf.keys()), {"default":"G/2"}),
            }
        }

    RETURN_TYPES = ("MODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "load_checkpoint"
    CATEGORY = "ExtraModels/HunyuanDiT"
    TITLE = "Hunyuan DiT Checkpoint Loader"

    def load_checkpoint(self, ckpt_name, model):
        ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
        model_conf = hydit_conf[model]
        model = load_hydit(
            model_path = ckpt_path,
            model_conf = model_conf,
        )
        return (model,)

#### temp stuff for the text encoder ####
import torch

from .tenc import load_clip, load_t5
from ..utils.dtype import string_to_dtype

dtypes = [
    "default",
    "auto (comfy)",
    "FP32",
    "FP16",
    "BF16",
]
class HYDiTTextEncoderLoader:
    @classmethod
    def INPUT_TYPES(s):
        devices = ["auto", "cpu", "gpu"]
        # hack to allow using a second GPU as the offload device
        for k in range(1, torch.cuda.device_count()):
            devices.append(f"cuda:{k}")
        return {
            "required": {
                "clip_name": (folder_paths.get_filename_list("clip"),),
                "mt5_name": (folder_paths.get_filename_list("t5"),),
                "device": (devices, {"default":"cpu"}),
                "dtype": (dtypes,),
            }
        }

    RETURN_TYPES = ("CLIP", "T5")
    FUNCTION = "load_model"
    CATEGORY = "ExtraModels/HunyuanDiT"
    TITLE = "Hunyuan DiT Text Encoder Loader"

    def load_model(self, clip_name, mt5_name, device, dtype):
        dtype = string_to_dtype(dtype, "text_encoder")
        if device == "cpu":
            assert dtype in [None, torch.float32, torch.bfloat16], f"Can't use dtype '{dtype}' with CPU! Set dtype to 'default' or 'bf16'."

        clip = load_clip(
            model_path = folder_paths.get_full_path("clip", clip_name),
            device = device,
            dtype = dtype,
        )
        t5 = load_t5(
            model_path = folder_paths.get_full_path("t5", mt5_name),
            device = device,
            dtype = dtype,
        )
        return (clip, t5)
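
# Note: the "CLIP" and "T5" wrappers returned above are consumed by
# HYDiTTextEncode below, which expects each of them to expose .load_model(),
# .tokenizer, .cond_stage_model (with .max_length and .transformer) and
# .load_device. This interface is inferred from encode() rather than
# documented by the loaders themselves.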

class HYDiTTextEncode:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "text": ("STRING", {"multiline": True}),
                "text_t5": ("STRING", {"multiline": True}),
                "CLIP": ("CLIP",),
                "T5": ("T5",),
            }
        }

    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "encode"
    CATEGORY = "ExtraModels/HunyuanDiT"
    TITLE = "Hunyuan DiT Text Encode"

    def encode(self, text, text_t5, CLIP, T5):
        # mT5 text encoder
        T5.load_model()
        t5_pre = T5.tokenizer(
            text_t5,
            max_length = T5.cond_stage_model.max_length,
            padding = 'max_length',
            truncation = True,
            return_attention_mask = True,
            add_special_tokens = True,
            return_tensors = 'pt'
        )
        t5_mask = t5_pre["attention_mask"]
        with torch.no_grad():
            t5_outs = T5.cond_stage_model.transformer(
                input_ids = t5_pre["input_ids"].to(T5.load_device),
                attention_mask = t5_mask.to(T5.load_device),
                output_hidden_states = True,
            )
            # to-do: replace -1 for clip skip
            t5_embs = t5_outs["hidden_states"][-1].float().cpu()

        # "clip" text encoder
        CLIP.load_model()
        clip_pre = CLIP.tokenizer(
            text,
            max_length = CLIP.cond_stage_model.max_length,
            padding = 'max_length',
            truncation = True,
            return_attention_mask = True,
            add_special_tokens = True,
            return_tensors = 'pt'
        )
        clip_mask = clip_pre["attention_mask"]
        with torch.no_grad():
            clip_outs = CLIP.cond_stage_model.transformer(
                input_ids = clip_pre["input_ids"].to(CLIP.load_device),
                attention_mask = clip_mask.to(CLIP.load_device),
            )
            # to-do: add hidden states
            clip_embs = clip_outs[0].float().cpu()

        # combined cond: CLIP embeddings as the main tensor, with the mT5
        # embeddings and both attention masks in the extra-conditioning dict
        return ([[
            clip_embs, {
                "context_t5": t5_embs,
                "context_mask": clip_mask.float(),
                "context_t5_mask": t5_mask.float()
            }
        ]],)

class HYDiTTextEncodeSimple(HYDiTTextEncode):
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "text": ("STRING", {"multiline": True}),
                "CLIP": ("CLIP",),
                "T5": ("T5",),
            }
        }

    FUNCTION = "encode_simple"
    TITLE = "Hunyuan DiT Text Encode (simple)"

    def encode_simple(self, text, **args):
        # single prompt passed to both the CLIP and mT5 encoders
        return self.encode(text=text, text_t5=text, **args)

class HYDiTSrcSizeCond:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "cond": ("CONDITIONING",),
                "width": ("INT", {"default": 1024, "min": 0, "max": 8192, "step": 16}),
                "height": ("INT", {"default": 1024, "min": 0, "max": 8192, "step": 16}),
            }
        }

    RETURN_TYPES = ("CONDITIONING",)
    RETURN_NAMES = ("cond",)
    FUNCTION = "add_cond"
    CATEGORY = "ExtraModels/HunyuanDiT"
    TITLE = "Hunyuan DiT Size Conditioning (advanced)"

    def add_cond(self, cond, width, height):
        cond = deepcopy(cond)
        for c in range(len(cond)):
            cond[c][1].update({
                "src_size_cond": [[height, width]],
            })
        return (cond,)

NODE_CLASS_MAPPINGS = {
    "HYDiTCheckpointLoader": HYDiTCheckpointLoader,
    "HYDiTTextEncoderLoader": HYDiTTextEncoderLoader,
    "HYDiTTextEncode": HYDiTTextEncode,
    "HYDiTTextEncodeSimple": HYDiTTextEncodeSimple,
    "HYDiTSrcSizeCond": HYDiTSrcSizeCond,
}
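
# Minimal smoke-test sketch (hypothetical): assumes these nodes run inside a
# working ComfyUI install with a HunyuanDiT checkpoint plus CLIP/mT5 weights
# already present in the "checkpoints", "clip" and "t5" model folders.
# All file names below are placeholders, not files shipped with this module.
if __name__ == "__main__":
    (model,) = HYDiTCheckpointLoader().load_checkpoint(
        "hunyuan_dit.safetensors", "G/2"
    )
    clip, t5 = HYDiTTextEncoderLoader().load_model(
        clip_name = "clip_text_encoder.safetensors",
        mt5_name = "mt5_xl.safetensors",
        device = "cpu",
        dtype = "default",
    )
    (cond,) = HYDiTTextEncode().encode(
        text = "a photo of a cat",
        text_t5 = "a photo of a cat",
        CLIP = clip,
        T5 = t5,
    )
    (cond,) = HYDiTSrcSizeCond().add_cond(cond, width=1024, height=1024)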