Spaces:
Configuration error
Configuration error
import gc | |
import json | |
import os | |
import random | |
import re | |
import torch | |
import folder_paths | |
import comfy.model_management as mm | |
from . import mz_kolors_core | |
def MZ_ChatGLM3TextEncode_call(args):
    """Encode ``text`` with ChatGLM3 and project the embeddings through ``hid_proj``.

    Args:
        args: Node keyword arguments. Reads ``"text"``, ``"chatglm3_model"``
            and ``"hid_proj"`` (the ``nn.Linear`` projection layer).

    Returns:
        A one-element tuple holding a conditioning list of the form
        ``[[prompt_embeds, {"pooled_output": pooled_output}]]``, where
        ``prompt_embeds`` has been passed through ``hid_proj`` and
        ``pooled_output`` is returned as produced by the encoder.
    """
    text = args.get("text")
    chatglm3_model = args.get("chatglm3_model")
    prompt_embeds, pooled_output = mz_kolors_core.chatglm3_text_encode(
        chatglm3_model,
        text,
    )

    from torch import nn
    hid_proj: nn.Linear = args.get("hid_proj")

    # Bring the embeddings onto the projection's dtype/device explicitly
    # instead of using torch.cuda.amp.autocast: that API is deprecated and
    # only influences CUDA ops, so a dtype mismatch on CPU/MPS would still
    # fail. A Linear's output dtype follows its weight, so the result matches
    # what autocast produced on CUDA.
    weight = hid_proj.weight
    if prompt_embeds.dtype != weight.dtype or prompt_embeds.device != weight.device:
        prompt_embeds = prompt_embeds.to(device=weight.device, dtype=weight.dtype)
    prompt_embeds = hid_proj(prompt_embeds)

    return ([[
        prompt_embeds,
        {"pooled_output": pooled_output},
    ]], )
def _pop_encoder_hid_proj(sd, device):
    """Remove the ``encoder_hid_proj`` tensors from ``sd`` and wrap them in an ``nn.Linear``.

    The layer is built with ``in_features = weight.shape[1]`` and
    ``out_features = weight.shape[0]`` and then moved to ``device``.

    Raises:
        KeyError: if ``sd`` lacks ``encoder_hid_proj.weight`` or
            ``encoder_hid_proj.bias``.
    """
    from torch import nn

    weight = sd.pop("encoder_hid_proj.weight")
    bias = sd.pop("encoder_hid_proj.bias")
    hid_proj = nn.Linear(weight.shape[1], weight.shape[0])
    hid_proj.weight.data = weight
    hid_proj.bias.data = bias
    return hid_proj.to(device)


def load_unet_state_dict(sd):  # load a Kolors unet stored with diffusers-style keys
    """Build a ComfyUI ``ModelPatcher`` for a UNet plus its ``encoder_hid_proj`` layer.

    ``sd`` is consumed: recognised keys are popped from it (or from its
    prefix-stripped copy), and any keys still present at the end are printed.

    Args:
        sd: State dict of the UNet, optionally carrying a checkpoint prefix.

    Returns:
        ``(model_patcher, hid_proj)`` on success, or ``(None, None)`` when the
        UNet configuration cannot be detected from ``sd``.

    Raises:
        KeyError: if the ``encoder_hid_proj`` weight or bias is missing.
    """
    from comfy import model_management, model_detection
    import comfy.model_patcher
    import comfy.utils

    # Allow loading unets from checkpoint files: strip the diffusion-model
    # prefix when the state dict carries one.
    diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
    temp_sd = comfy.utils.state_dict_prefix_replace(
        sd, {diffusion_model_prefix: ""}, filter_keys=True)
    if len(temp_sd) > 0:
        sd = temp_sd

    parameters = comfy.utils.calculate_parameters(sd)
    load_device = model_management.get_torch_device()

    model_config = model_detection.model_config_from_diffusers_unet(sd)
    if model_config is None:
        # Keep the 2-tuple shape so callers that unpack the result reach
        # their own "could not detect" handling instead of a TypeError.
        return None, None

    # Translate diffusers-style keys into ComfyUI's layout; keys the mapping
    # expects but the state dict lacks are only reported.
    diffusers_keys = comfy.utils.unet_to_diffusers(model_config.unet_config)
    new_sd = {}
    for k, target in diffusers_keys.items():
        if k in sd:
            new_sd[target] = sd.pop(k)
        else:
            print("{} {}".format(target, k))

    hid_proj = _pop_encoder_hid_proj(sd, load_device)

    offload_device = model_management.unet_offload_device()
    unet_dtype = model_management.unet_dtype(
        model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
    manual_cast_dtype = model_management.unet_manual_cast(
        unet_dtype, load_device, model_config.supported_inference_dtypes)
    model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
    model = model_config.get_model(new_sd, "")
    model = model.to(offload_device)
    model.load_model_weights(new_sd, "")

    left_over = sd.keys()
    if left_over:
        print("left over keys in unet: {}".format(left_over))
    patcher = comfy.model_patcher.ModelPatcher(
        model, load_device=load_device, offload_device=offload_device)
    return patcher, hid_proj
def MZ_KolorsUNETLoader_call(kwargs):
    """Load a Kolors UNet file while the Kolors hooks are applied.

    Args:
        kwargs: Node keyword arguments; reads ``"unet_name"``, a file name
            resolved through ``folder_paths`` in the ``"unet"`` folder.

    Returns:
        ``(model, hid_proj)`` as produced by ``load_unet_state_dict``.

    Raises:
        FileNotFoundError: if ``unet_name`` does not resolve to a file.
        RuntimeError: if the UNet type cannot be detected.
    """
    from . import hook_comfyui_kolors_v1
    with hook_comfyui_kolors_v1.apply_kolors():
        unet_name = kwargs.get("unet_name")
        unet_path = folder_paths.get_full_path("unet", unet_name)
        # get_full_path may yield None for an unknown name; fail clearly
        # instead of handing None to the file loader.
        if unet_path is None:
            raise FileNotFoundError(
                "ERROR: Could not find unet file: {}".format(unet_name))
        import comfy.utils
        sd = comfy.utils.load_torch_file(unet_path)
        result = load_unet_state_dict(sd)
        # Accept either a bare None or (None, None) as "not detected" so the
        # intended RuntimeError is raised rather than an unpacking TypeError.
        model, hid_proj = result if result is not None else (None, None)
        if model is None:
            raise RuntimeError(
                "ERROR: Could not detect model type of: {}".format(unet_path))
        return (model, hid_proj)
def MZ_FakeCond_call(kwargs):
    """Return an all-zero placeholder conditioning in the requested precision.

    Args:
        kwargs: Node keyword arguments; reads ``"dtype"`` ("fp16", "bf16";
            any other value, including a missing one, selects fp32). The
            chosen precision label is printed.

    Returns:
        A one-element tuple holding ``[[cond, {"pooled_output": pool}]]``
        with ``cond`` of shape (2, 256, 4096) and ``pool`` of shape
        (2, 4096), both filled with zeros.
    """
    import torch

    requested = kwargs.get("dtype")
    if requested == "fp16":
        label, target_dtype = "fp16", torch.float16
    elif requested == "bf16":
        label, target_dtype = "bf16", torch.bfloat16
    else:
        label, target_dtype = "fp32", torch.float32
    print(label)

    cond = torch.zeros(2, 256, 4096, dtype=target_dtype)
    pool = torch.zeros(2, 4096, dtype=target_dtype)
    conditioning = [[cond, {"pooled_output": pool}]]
    return (conditioning,)
# Registries populated by the node definitions that follow; presumably read
# by ComfyUI when it loads this package (confirm against the package __init__).
NODE_CLASS_MAPPINGS = dict()
NODE_DISPLAY_NAME_MAPPINGS = dict()

# Shared labels for node display names and menu categories.
AUTHOR_NAME = "MinusZone"
CATEGORY_NAME = AUTHOR_NAME + " - Kolors"
class MZ_ChatGLM3TextEncode:
    """Legacy node: encode text with ChatGLM3 and project it through ``hid_proj``."""

    @classmethod
    def INPUT_TYPES(s):
        # ComfyUI queries INPUT_TYPES on the node class itself, so this must
        # be a classmethod; as a plain function the call raises TypeError.
        return {
            "required": {
                "chatglm3_model": ("CHATGLM3MODEL", ),
                "text": ("STRING", {"multiline": True, "dynamicPrompts": True}),
                "hid_proj": ("TorchLinear", ),
            }
        }

    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "encode"
    CATEGORY = CATEGORY_NAME + "/Legacy"

    def encode(self, **kwargs):
        """Delegate to ``MZ_ChatGLM3TextEncode_call`` with the node inputs."""
        return MZ_ChatGLM3TextEncode_call(kwargs)


NODE_CLASS_MAPPINGS["MZ_ChatGLM3"] = MZ_ChatGLM3TextEncode
NODE_DISPLAY_NAME_MAPPINGS[
    "MZ_ChatGLM3"] = f"{AUTHOR_NAME} - ChatGLM3TextEncode"
class MZ_KolorsUNETLoader:
    """Legacy node: load a Kolors UNet and its ``hid_proj`` projection layer."""

    @classmethod
    def INPUT_TYPES(s):
        # ComfyUI queries INPUT_TYPES on the node class itself, so this must
        # be a classmethod; as a plain function the call raises TypeError.
        return {"required": {
            "unet_name": (folder_paths.get_filename_list("unet"), ),
        }}

    RETURN_TYPES = ("MODEL", "TorchLinear")
    RETURN_NAMES = ("model", "hid_proj")
    FUNCTION = "load_unet"
    CATEGORY = CATEGORY_NAME + "/Legacy"

    def load_unet(self, **kwargs):
        """Delegate to ``MZ_KolorsUNETLoader_call`` with the node inputs."""
        return MZ_KolorsUNETLoader_call(kwargs)


NODE_CLASS_MAPPINGS["MZ_KolorsUNETLoader"] = MZ_KolorsUNETLoader
NODE_DISPLAY_NAME_MAPPINGS[
    "MZ_KolorsUNETLoader"] = f"{AUTHOR_NAME} - Kolors UNET Loader"
class MZ_FakeCond:
    """Development node that emits an all-zero conditioning of a chosen dtype."""

    @classmethod
    def INPUT_TYPES(s):
        # ComfyUI queries INPUT_TYPES on the node class itself, so this must
        # be a classmethod; as a plain function the call raises TypeError.
        # NOTE(review): "seed" is not read by MZ_FakeCond_call; presumably it
        # exists only to change the node's inputs and force re-execution.
        return {
            "required": {
                "seed": ("INT", {"default": 0}),
                "dtype": ([
                    "fp32",
                    "fp16",
                    "bf16",
                ],),
            }
        }

    RETURN_TYPES = ("CONDITIONING", )
    RETURN_NAMES = ("prompt", )
    FUNCTION = "encode"
    CATEGORY = CATEGORY_NAME

    def encode(self, **kwargs):
        """Delegate to ``MZ_FakeCond_call`` with the node inputs."""
        return MZ_FakeCond_call(kwargs)


# Only expose this node when the MZ_DEV environment variable is set (to any
# value). Nothing here can raise ImportError, so no try/except is needed.
if "MZ_DEV" in os.environ:
    NODE_CLASS_MAPPINGS["MZ_FakeCond"] = MZ_FakeCond
    NODE_DISPLAY_NAME_MAPPINGS[
        "MZ_FakeCond"] = f"{AUTHOR_NAME} - FakeCond"