# FooocusEnhanced / ComfyUI-Kolors-MZ / mz_kolors_core.py
# Uploaded by JasonSmithSO ("Upload 578 files", commit 8866644, verified).
import gc
import json
import os
import random
import re
import subprocess
import sys
from types import MethodType
import torch
import folder_paths
import comfy.model_management as mm
def chatglm3_text_encode(chatglm3_model, prompt):
    """Encode ``prompt`` with the ChatGLM3 text encoder used by Kolors.

    Every non-nested ``{a|b|c}`` group in the prompt is first replaced by one
    randomly chosen option. The text encoder is moved to the compute device
    for the forward pass and always moved back to the offload device
    afterwards, even if encoding fails.

    Args:
        chatglm3_model: dict with ``"tokenizer"`` and ``"text_encoder"``
            entries (as returned by ``MZ_ChatGLM3Loader_call``).
        prompt: prompt text.

    Returns:
        ``(prompt_embeds, text_proj)``:
        ``prompt_embeds`` is the second-to-last hidden state, permuted to
        batch-first (``[batch, 256, hidden]`` since inputs are padded and
        truncated to 256 tokens); ``text_proj`` is the last hidden state of
        the final token position (``[batch, hidden]``).
    """
    device = mm.get_torch_device()
    offload_device = mm.unet_offload_device()
    # Free VRAM held by other models before loading the text encoder.
    mm.unload_all_models()
    mm.soft_empty_cache()

    def choose_random_option(match):
        # Pick one of the '|'-separated options inside a {...} group.
        options = match.group(1).split('|')
        return random.choice(options)

    prompt = re.sub(r'\{([^{}]*)\}', choose_random_option, prompt)

    tokenizer = chatglm3_model['tokenizer']
    text_encoder = chatglm3_model['text_encoder']
    text_encoder.to(device)
    try:
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=256,
            truncation=True,
            return_tensors="pt",
        ).to(device)
        # Inference only: do not build an autograd graph for the encoder pass.
        with torch.no_grad():
            output = text_encoder(
                input_ids=text_inputs['input_ids'],
                attention_mask=text_inputs['attention_mask'],
                position_ids=text_inputs['position_ids'],
                output_hidden_states=True)
        # The permute assumes ChatGLM hidden states are sequence-first
        # ([seq, batch, hidden]). Result: [batch, 256, hidden], copied into a
        # fresh contiguous tensor.
        prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone(
            memory_format=torch.contiguous_format)
        # Last layer, final token position: [batch, hidden].
        text_proj = output.hidden_states[-1][-1, :, :].clone(
            memory_format=torch.contiguous_format)
    finally:
        # Always release the compute device, even if tokenizing/encoding raised.
        text_encoder.to(offload_device)
        mm.soft_empty_cache()
        gc.collect()
    return prompt_embeds, text_proj
def MZ_ChatGLM3Loader_call(args):
    """Load the ChatGLM3 text encoder and its tokenizer for Kolors.

    Args:
        args: node inputs. ``args["chatglm3_checkpoint"]`` is a file name
            resolved in the ``LLM`` model folder. If the name contains
            ``"4bit"`` or ``"8bit"``, the model is quantized to that width
            before the checkpoint weights are loaded.

    Returns:
        A one-tuple ``({"text_encoder": ..., "tokenizer": ...},)``.

    Raises:
        RuntimeError: if the checkpoint cannot be resolved to an existing file.
        subprocess.CalledProcessError: if the on-demand ``pip install
            cpm_kernels`` (4-bit checkpoints only) fails.
    """
    chatglm3_checkpoint = args.get("chatglm3_checkpoint")
    chatglm3_checkpoint_path = folder_paths.get_full_path(
        'LLM', chatglm3_checkpoint)
    # get_full_path may return None for an unknown name. Check for that
    # before os.path.exists, which would otherwise raise a bare TypeError.
    if chatglm3_checkpoint_path is None or not os.path.exists(chatglm3_checkpoint_path):
        raise RuntimeError(
            f"ERROR: Could not find chatglm3 checkpoint: {chatglm3_checkpoint_path or chatglm3_checkpoint}")

    from .chatglm3.configuration_chatglm import ChatGLMConfig
    from .chatglm3.modeling_chatglm import ChatGLMModel
    from .chatglm3.tokenization_chatglm import ChatGLMTokenizer

    offload_device = mm.unet_offload_device()

    text_encoder_config = os.path.join(
        os.path.dirname(__file__), 'configs', 'text_encoder_config.json')
    with open(text_encoder_config, 'r', encoding='utf-8') as file:
        config = json.load(file)
    text_encoder_config = ChatGLMConfig(**config)

    from comfy.utils import load_torch_file
    from contextlib import nullcontext

    # accelerate is optional. With it, the model is constructed under
    # init_empty_weights and each checkpoint tensor is placed directly on the
    # offload device. Without it, a plain load_state_dict is used. Any import
    # failure (not only ImportError) falls back, as before.
    is_accelerate_available = False
    try:
        from accelerate import init_empty_weights
        from accelerate.utils import set_module_tensor_to_device
        is_accelerate_available = True
    except Exception:
        pass

    with (init_empty_weights() if is_accelerate_available else nullcontext()):
        with torch.no_grad():
            # Print the torch version (diagnostic output).
            print("torch version:", torch.__version__)
            text_encoder = ChatGLMModel(text_encoder_config).eval()
            if '4bit' in chatglm3_checkpoint:
                try:
                    import cpm_kernels  # noqa: F401 - availability check only
                except ImportError:
                    # NOTE(review): installs a package into the running
                    # interpreter at load time; consider declaring it as a
                    # regular dependency instead.
                    print("Installing cpm_kernels...")
                    subprocess.run(
                        [sys.executable, "-m", "pip", "install", "cpm_kernels"], check=True)
                text_encoder.quantize(4)
            elif '8bit' in chatglm3_checkpoint:
                text_encoder.quantize(8)

            text_encoder_sd = load_torch_file(chatglm3_checkpoint_path)
            if is_accelerate_available:
                for key, value in text_encoder_sd.items():
                    set_module_tensor_to_device(
                        text_encoder, key, device=offload_device, value=value)
            else:
                print("WARNING: Accelerate not available, use load_state_dict load model")
                text_encoder.load_state_dict(text_encoder_sd)

    tokenizer_path = os.path.join(
        os.path.dirname(__file__), 'configs', "tokenizer")
    tokenizer = ChatGLMTokenizer.from_pretrained(tokenizer_path)
    return ({"text_encoder": text_encoder, "tokenizer": tokenizer},)
def MZ_ChatGLM3TextEncodeV2_call(args):
    """Encode ``args["text"]`` with ChatGLM3 and return it as conditioning.

    The conditioning options hold ``pooled_output`` plus whichever of the
    size keys (width, height, crop_w, crop_h, target_width, target_height)
    are present in ``args``, in ``args`` order.

    Returns:
        A one-tuple holding ``[[prompt_embeds, options]]``.
    """
    size_keys = (
        "width",
        "height",
        "crop_w",
        "crop_h",
        "target_width",
        "target_height",
    )
    prompt_text = args.get("text")
    glm_model = args.get("chatglm3_model")
    embeds, pooled = chatglm3_text_encode(glm_model, prompt_text)

    options = {"pooled_output": pooled}
    options.update(
        (name, val) for name, val in args.items() if name in size_keys)
    conditioning = [[embeds, options]]
    return (conditioning, )
def MZ_ChatGLM3Embeds2Conditioning_call(args):
    """Turn a Kolors embeds dict into (positive, negative) conditioning.

    ``args["kolors_embeds"]`` must provide ``prompt_embeds``,
    ``pooled_prompt_embeds``, ``negative_prompt_embeds`` and
    ``negative_pooled_prompt_embeds``. The positive conditioning also gets
    the six size entries read from ``args`` (``None`` for any that are
    absent). The negative conditioning carries only its pooled output.

    Returns:
        ``(positive, negative)``, each a one-element conditioning list.
    """
    embeds = args.get("kolors_embeds")
    size_keys = ("width", "height", "crop_w", "crop_h",
                 "target_width", "target_height")

    pos_tensor = embeds['prompt_embeds']
    pos_options = {"pooled_output": embeds['pooled_prompt_embeds']}
    for size_key in size_keys:
        pos_options[size_key] = args.get(size_key)

    neg_tensor = embeds['negative_prompt_embeds']
    neg_options = {"pooled_output": embeds['negative_pooled_prompt_embeds']}

    return ([[pos_tensor, pos_options]], [[neg_tensor, neg_options]])
def MZ_KolorsUNETLoaderV2_call(kwargs):
    """Load a Kolors UNet state dict from the ``unet`` model folder.

    Path resolution and loading run inside
    ``hook_comfyui_kolors_v2.apply_kolors()``, so whatever patches that hook
    installs are active while ComfyUI detects the model type.

    Returns:
        A one-tuple with the loaded model.

    Raises:
        RuntimeError: if ``unet_name`` cannot be resolved to a file, or if
            the model type of the state dict cannot be detected.
    """
    from . import hook_comfyui_kolors_v2
    import comfy.sd
    import comfy.utils

    with hook_comfyui_kolors_v2.apply_kolors():
        unet_name = kwargs.get("unet_name")
        unet_path = folder_paths.get_full_path("unet", unet_name)
        # get_full_path may return None; fail clearly instead of handing None
        # to the file loader.
        if unet_path is None:
            raise RuntimeError(
                "ERROR: Could not find unet: {}".format(unet_name))
        sd = comfy.utils.load_torch_file(unet_path)
        model = comfy.sd.load_unet_state_dict(sd)
        if model is None:
            raise RuntimeError(
                "ERROR: Could not detect model type of: {}".format(unet_path))
    return (model, )
def MZ_KolorsCheckpointLoaderSimple_call(kwargs):
    """Load a Kolors checkpoint (UNet and VAE, no CLIP) with the Kolors hooks active.

    Returns:
        ``(model, vae)``: items 0 and 2 of
        ``comfy.sd.load_checkpoint_guess_config``'s result.
    """
    ckpt_path = folder_paths.get_full_path(
        "checkpoints", kwargs.get("ckpt_name"))

    from . import hook_comfyui_kolors_v2
    import comfy.sd

    with hook_comfyui_kolors_v2.apply_kolors():
        loaded = comfy.sd.load_checkpoint_guess_config(
            ckpt_path,
            output_vae=True,
            output_clip=False,
            embedding_directory=folder_paths.get_folder_paths("embeddings"),
        )
    model, _clip, vae = loaded[:3]
    return (model, vae)
from comfy.cldm.cldm import ControlNet
from comfy.controlnet import ControlLora
def MZ_KolorsControlNetLoader_call(kwargs):
    """Load a ControlNet from the ``controlnet`` folder with the Kolors hooks active.

    Returns:
        A one-tuple with the loaded control net.

    Raises:
        RuntimeError: if ``control_net_name`` cannot be resolved to a file.
    """
    from . import hook_comfyui_kolors_v2
    import comfy.controlnet

    control_net_name = kwargs.get("control_net_name")
    controlnet_path = folder_paths.get_full_path(
        "controlnet", control_net_name)
    # get_full_path may return None; fail clearly instead of handing None
    # to load_controlnet.
    if controlnet_path is None:
        raise RuntimeError(
            "ERROR: Could not find controlnet: {}".format(control_net_name))
    with hook_comfyui_kolors_v2.apply_kolors():
        control_net = comfy.controlnet.load_controlnet(controlnet_path)
    return (control_net, )
def MZ_KolorsControlNetPatch_call(kwargs):
    """Patch a ControlNet / ControlLora so it can be used with a Kolors model.

    The control model's ``forward`` is wrapped so that the incoming
    ``context`` is first run through the Kolors UNet's ``encoder_hid_proj``
    (shared by reference from ``model``).

    - Full ControlNet: the UNet's ``label_emb`` and ``encoder_hid_proj`` are
      attached immediately, and the control model is re-wrapped in a
      ``ModelPatcher``.
    - ControlLora: weights whose keys start with ``label_emb.0.0.`` are
      dropped. The patch is applied lazily in ``pre_run`` (once the control
      model exists) and re-installed on every ``copy()``.

    Args:
        kwargs: ``model`` (the Kolors model patcher) and ``control_net``.

    Returns:
        A one-tuple with a patched deep copy, or the input unchanged if its
        control model already has ``encoder_hid_proj``.

    Raises:
        NotImplementedError: for any other control-net type.
    """
    import copy
    from . import hook_comfyui_kolors_v2  # noqa: F401 - kept from original; may have import side effects
    import comfy.model_management
    import comfy.model_patcher
    import comfy.controlnet

    model = kwargs.get("model")
    control_net = kwargs.get("control_net")

    # Already patched: nothing to do.
    if hasattr(control_net, "control_model") and hasattr(control_net.control_model, "encoder_hid_proj"):
        return (control_net,)

    # Patch a copy so the caller's control net object is left untouched.
    control_net = copy.deepcopy(control_net)

    # `ControlNet` is comfy.cldm.cldm.ControlNet (module-level import): the
    # network whose forward is wrapped. comfy.controlnet.ControlNet, checked
    # below, is the wrapper class holding `control_model`.
    super_forward = ControlNet.forward

    def KolorsControlNet_forward(self, x, hint, timesteps, context, **kwargs):
        # Run the Kolors projection on the text context, then the stock forward.
        # Equivalent to the deprecated torch.cuda.amp.autocast(enabled=True),
        # whose default dtype is float16.
        with torch.autocast("cuda", dtype=torch.float16, enabled=True):
            context = self.encoder_hid_proj(context)
        return super_forward(self, x, hint, timesteps, context, **kwargs)

    # ControlLora must be tested first: it is checked ahead of the generic
    # comfy.controlnet.ControlNet branch.
    if isinstance(control_net, ControlLora):
        del_keys = [k for k in control_net.control_weights
                    if k.startswith("label_emb.0.0.")]
        for k in del_keys:
            control_net.control_weights.pop(k)

        super_pre_run = ControlLora.pre_run

        def KolorsControlLora_pre_run(self, *args, **kwargs):
            result = super_pre_run(self, *args, **kwargs)
            if hasattr(self, "control_model"):
                if hasattr(self.control_model, "encoder_hid_proj"):
                    return result
                setattr(self.control_model, "encoder_hid_proj",
                        model.model.diffusion_model.encoder_hid_proj)
                self.control_model.forward = MethodType(
                    KolorsControlNet_forward, self.control_model)
            return result

        control_net.pre_run = MethodType(
            KolorsControlLora_pre_run, control_net)

        super_copy = ControlLora.copy

        def KolorsControlLora_copy(self, *args, **kwargs):
            # Copies must keep the Kolors pre_run hook.
            c = super_copy(self, *args, **kwargs)
            c.pre_run = MethodType(
                KolorsControlLora_pre_run, c)
            return c

        control_net.copy = MethodType(
            KolorsControlLora_copy, control_net)
        # NOTE(review): second deepcopy kept from the original. deepcopy
        # rebinds the instance-level bound methods above to the new object.
        control_net = copy.deepcopy(control_net)
    elif isinstance(control_net, comfy.controlnet.ControlNet):
        control_net.control_model.label_emb = model.model.diffusion_model.label_emb
        setattr(control_net.control_model, "encoder_hid_proj",
                model.model.diffusion_model.encoder_hid_proj)
        control_net.control_model_wrapped = comfy.model_patcher.ModelPatcher(
            control_net.control_model,
            load_device=control_net.load_device,
            offload_device=comfy.model_management.unet_offload_device())
        control_net.control_model.forward = MethodType(
            KolorsControlNet_forward, control_net.control_model)
    else:
        raise NotImplementedError(
            f"Type {control_net} not supported for KolorsControlNetPatch")
    return (control_net,)
def MZ_KolorsCLIPVisionLoader_call(kwargs):
    """Load a CLIP vision model from the ``clip_vision`` folder with the Kolors hooks active.

    Returns:
        A one-tuple with the object returned by ``comfy.clip_vision.load``.
    """
    import comfy.clip_vision
    from . import hook_comfyui_kolors_v2

    vision_path = folder_paths.get_full_path(
        "clip_vision", kwargs.get("clip_name"))
    with hook_comfyui_kolors_v2.apply_kolors():
        loaded_vision = comfy.clip_vision.load(vision_path)
    return (loaded_vision,)
def MZ_ApplySDXLSamplingSettings_call(kwargs):
    """Return a clone of ``kwargs["model"]`` that samples with an SDXL-style schedule.

    An EPS ``ModelSamplingDiscrete`` is built with a linear beta schedule
    (0.00085 -> 0.012 over 1000 timesteps) and installed on the clone as the
    ``model_sampling`` object patch. Other sampling settings in the model
    config are kept.

    The settings go into a shallow copy of the model config.
    ``ModelPatcher.clone()`` does not copy ``model.model``, so writing into
    ``model.model.model_config.sampling_settings`` would permanently alter
    the input model's config as well.

    Returns:
        A one-tuple with the patched clone.
    """
    import copy
    import comfy.model_sampling

    model = kwargs.get("model").clone()

    class SDXLSampling(comfy.model_sampling.ModelSamplingDiscrete,
                       comfy.model_sampling.EPS):
        pass

    model_config = copy.copy(model.model.model_config)
    model_config.sampling_settings = {
        **model_config.sampling_settings,
        "beta_schedule": "linear",
        "linear_start": 0.00085,
        "linear_end": 0.012,
        "timesteps": 1000,
    }
    model_sampling = SDXLSampling(model_config)
    model.add_object_patch("model_sampling", model_sampling)
    return (model,)
def MZ_ApplyCUDAGenerator_call(kwargs):
    """Make ComfyUI draw initial sampling noise from a CUDA ``torch.Generator``.

    Replaces ``comfy.sample.prepare_noise`` module-wide. The replacement stays
    in effect for every later sampling call in the process, not only those
    using ``model``. The model itself is returned unchanged.

    Returns:
        A one-tuple with the input model.
    """
    # numpy was used below but never imported, so any call with noise_inds
    # raised NameError.
    import numpy as np
    import comfy.sample

    model = kwargs.get("model")

    def prepare_noise(latent_image, seed, noise_inds=None):
        """Create random noise shaped like ``latent_image`` on CUDA, seeded by ``seed``.

        If ``noise_inds`` is given, one single-sample noise tensor is drawn
        for every index from 0 to ``max(noise_inds)`` in order. The batch is
        then assembled by taking the tensor for each entry of ``noise_inds``,
        so equal indices get identical noise.
        """
        generator = torch.Generator(device="cuda").manual_seed(seed)
        if noise_inds is None:
            return torch.randn(latent_image.size(), dtype=latent_image.dtype,
                               layout=latent_image.layout, generator=generator,
                               device="cuda")

        unique_inds, inverse = np.unique(noise_inds, return_inverse=True)
        noises = []
        for i in range(unique_inds[-1] + 1):
            # Draw every index up to the maximum so sample i does not depend on
            # which other indices were requested.
            noise = torch.randn([1] + list(latent_image.size())[1:],
                                dtype=latent_image.dtype,
                                layout=latent_image.layout,
                                generator=generator, device="cuda")
            if i in unique_inds:
                noises.append(noise)
        noises = [noises[i] for i in inverse]
        return torch.cat(noises, axis=0)

    comfy.sample.prepare_noise = prepare_noise
    return (model,)