import gc
import json
import os
import random
import re

import torch

import folder_paths
import comfy.model_management as mm

from . import mz_kolors_core


def MZ_ChatGLM3TextEncode_call(args):
    """Encode a prompt with ChatGLM3 and project it through ``hid_proj``.

    Args:
        args: Node keyword arguments. Reads ``"text"``, ``"chatglm3_model"``
            and ``"hid_proj"`` (a ``torch.nn.Linear``).

    Returns:
        A 1-tuple holding a conditioning list of the form
        ``[[prompt_embeds, {"pooled_output": pooled_output}]]``.
    """
    text = args.get("text")
    chatglm3_model = args.get("chatglm3_model")

    prompt_embeds, pooled_output = mz_kolors_core.chatglm3_text_encode(
        chatglm3_model,
        text,
    )

    hid_proj: torch.nn.Linear = args.get("hid_proj")

    if hid_proj.weight.dtype != prompt_embeds.dtype:
        # The projection weights and the embeddings disagree on dtype, so run
        # the projection under autocast using the weights' dtype.
        with torch.cuda.amp.autocast(dtype=hid_proj.weight.dtype):
            prompt_embeds = hid_proj(prompt_embeds)
    else:
        prompt_embeds = hid_proj(prompt_embeds)

    return ([[prompt_embeds, {"pooled_output": pooled_output}]],)


def load_unet_state_dict(sd):
    """Build a model patcher and the ``encoder_hid_proj`` layer from ``sd``.

    The state dict is treated as a diffusers-format UNet. Keys consumed while
    building the model are popped from the working dict; whatever remains is
    printed as "left over keys".

    Args:
        sd: State dict of the UNet. If it carries a diffusion-model prefix
            (i.e. it comes from a full checkpoint), the prefix is stripped and
            only those keys are used.

    Returns:
        ``None`` when no model config can be detected from ``sd``; otherwise a
        tuple ``(model_patcher, hid_proj)`` where ``hid_proj`` is a
        ``torch.nn.Linear`` built from ``encoder_hid_proj.weight`` /
        ``encoder_hid_proj.bias`` and moved to the load device.

    Raises:
        KeyError: If ``sd`` has no ``encoder_hid_proj.weight`` or
            ``encoder_hid_proj.bias`` entry.
    """
    from comfy import model_management, model_detection
    import comfy.utils
    # ModelPatcher is referenced below; import its module explicitly instead of
    # relying on some other module having loaded it as a side effect.
    import comfy.model_patcher

    # Allow loading unets from checkpoint files.
    diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
    temp_sd = comfy.utils.state_dict_prefix_replace(
        sd, {diffusion_model_prefix: ""}, filter_keys=True)
    if len(temp_sd) > 0:
        sd = temp_sd

    parameters = comfy.utils.calculate_parameters(sd)
    load_device = model_management.get_torch_device()

    model_config = model_detection.model_config_from_diffusers_unet(sd)
    if model_config is None:
        return None

    diffusers_keys = comfy.utils.unet_to_diffusers(model_config.unet_config)

    # Rename diffusers keys to the names given by the mapping; mapping entries
    # that are absent from the state dict are only reported.
    new_sd = {}
    for k in diffusers_keys:
        if k in sd:
            new_sd[diffusers_keys[k]] = sd.pop(k)
        else:
            print("{} {}".format(diffusers_keys[k], k))

    # The projection layer is not part of the converted UNet weights, so it is
    # rebuilt separately and returned next to the model.
    encoder_hid_proj_weight = sd.pop("encoder_hid_proj.weight")
    encoder_hid_proj_bias = sd.pop("encoder_hid_proj.bias")
    hid_proj = torch.nn.Linear(
        encoder_hid_proj_weight.shape[1], encoder_hid_proj_weight.shape[0])
    hid_proj.weight.data = encoder_hid_proj_weight
    hid_proj.bias.data = encoder_hid_proj_bias
    hid_proj = hid_proj.to(load_device)

    offload_device = model_management.unet_offload_device()
    unet_dtype = model_management.unet_dtype(
        model_params=parameters,
        supported_dtypes=model_config.supported_inference_dtypes)
    manual_cast_dtype = model_management.unet_manual_cast(
        unet_dtype, load_device, model_config.supported_inference_dtypes)
    model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)

    model = model_config.get_model(new_sd, "")
    model = model.to(offload_device)
    model.load_model_weights(new_sd, "")

    left_over = sd.keys()
    if len(left_over) > 0:
        print("left over keys in unet: {}".format(left_over))

    patcher = comfy.model_patcher.ModelPatcher(
        model, load_device=load_device, offload_device=offload_device)
    return patcher, hid_proj


def MZ_KolorsUNETLoader_call(kwargs):
    """Load the UNet file named by ``kwargs["unet_name"]``.

    Args:
        kwargs: Node keyword arguments. Reads ``"unet_name"``, which is
            resolved through ``folder_paths.get_full_path("unet", ...)``.

    Returns:
        A tuple ``(model, hid_proj)`` as produced by ``load_unet_state_dict``.

    Raises:
        RuntimeError: If the model type cannot be detected from the file.
    """
    from . import hook_comfyui_kolors_v1
    import comfy.utils

    with hook_comfyui_kolors_v1.apply_kolors():
        unet_name = kwargs.get("unet_name")
        unet_path = folder_paths.get_full_path("unet", unet_name)

        sd = comfy.utils.load_torch_file(unet_path)
        loaded = load_unet_state_dict(sd)

        # load_unet_state_dict returns a bare None on detection failure, so the
        # result must be checked *before* it is unpacked; unpacking first would
        # raise a TypeError and hide the real cause.
        if loaded is None or loaded[0] is None:
            raise RuntimeError(
                "ERROR: Could not detect model type of: {}".format(unet_path))

        model, hid_proj = loaded
        return (model, hid_proj)


def MZ_FakeCond_call(kwargs):
    """Return an all-zero conditioning in the requested dtype.

    Args:
        kwargs: Node keyword arguments. Reads ``"dtype"``; ``"fp16"`` and
            ``"bf16"`` select half / bfloat16, any other value yields float32.

    Returns:
        A 1-tuple holding ``[[cond, {"pooled_output": pool}]]`` where ``cond``
        is a zero tensor of shape ``(2, 256, 4096)`` and ``pool`` a zero tensor
        of shape ``(2, 4096)``.
    """
    cond = torch.zeros(2, 256, 4096)
    pool = torch.zeros(2, 4096)

    dtype = kwargs.get("dtype")

    if dtype == "fp16":
        print("fp16")
        cond = cond.half()
        pool = pool.half()
    elif dtype == "bf16":
        print("bf16")
        cond = cond.bfloat16()
        pool = pool.bfloat16()
    else:
        print("fp32")
        cond = cond.float()
        pool = pool.float()

    return ([[cond, {"pooled_output": pool}]],)


# Registries filled in below, one entry per node class.
NODE_CLASS_MAPPINGS = {
}

NODE_DISPLAY_NAME_MAPPINGS = {
}

AUTHOR_NAME = "MinusZone"
CATEGORY_NAME = f"{AUTHOR_NAME} - Kolors"


class MZ_ChatGLM3TextEncode:
    """Node wrapper around ``MZ_ChatGLM3TextEncode_call``."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "chatglm3_model": ("CHATGLM3MODEL", ),
                "text": ("STRING", {"multiline": True, "dynamicPrompts": True}),
                "hid_proj": ("TorchLinear", ),
            }
        }

    RETURN_TYPES = ("CONDITIONING",)

    FUNCTION = "encode"
    CATEGORY = CATEGORY_NAME + "/Legacy"

    def encode(self, **kwargs):
        """Forward the node inputs to ``MZ_ChatGLM3TextEncode_call``."""
        return MZ_ChatGLM3TextEncode_call(kwargs)


NODE_CLASS_MAPPINGS["MZ_ChatGLM3"] = MZ_ChatGLM3TextEncode
NODE_DISPLAY_NAME_MAPPINGS[
    "MZ_ChatGLM3"] = f"{AUTHOR_NAME} - ChatGLM3TextEncode"


class MZ_KolorsUNETLoader():
    """Node wrapper around ``MZ_KolorsUNETLoader_call``."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "unet_name": (folder_paths.get_filename_list("unet"), ),
        }}

    RETURN_TYPES = ("MODEL", "TorchLinear")

    RETURN_NAMES = ("model", "hid_proj")

    FUNCTION = "load_unet"

    CATEGORY = CATEGORY_NAME + "/Legacy"

    def load_unet(self, **kwargs):
        """Forward the node inputs to ``MZ_KolorsUNETLoader_call``."""
        return MZ_KolorsUNETLoader_call(kwargs)


NODE_CLASS_MAPPINGS["MZ_KolorsUNETLoader"] = MZ_KolorsUNETLoader
NODE_DISPLAY_NAME_MAPPINGS[
    "MZ_KolorsUNETLoader"] = f"{AUTHOR_NAME} - Kolors UNET Loader"


class MZ_FakeCond:
    """Node wrapper around ``MZ_FakeCond_call``.

    The ``seed`` input is declared but not read by ``MZ_FakeCond_call``.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "seed": ("INT", {"default": 0}),
                "dtype": ([
                    "fp32",
                    "fp16",
                    "bf16",
                ],),
            }
        }

    RETURN_TYPES = ("CONDITIONING", )
    RETURN_NAMES = ("prompt", )
    FUNCTION = "encode"
    CATEGORY = CATEGORY_NAME

    def encode(self, **kwargs):
        """Forward the node inputs to ``MZ_FakeCond_call``."""
        return MZ_FakeCond_call(kwargs)


# The fake-conditioning node is only registered when the MZ_DEV environment
# variable is set (to any value).
if os.environ.get("MZ_DEV", None) is not None:
    NODE_CLASS_MAPPINGS["MZ_FakeCond"] = MZ_FakeCond
    NODE_DISPLAY_NAME_MAPPINGS[
        "MZ_FakeCond"] = f"{AUTHOR_NAME} - FakeCond"