import gc
import json
import os
import random
import re
import subprocess
import sys
from types import MethodType

import numpy as np
import torch

import folder_paths
import comfy.model_management as mm


def chatglm3_text_encode(chatglm3_model, prompt):
    device = mm.get_torch_device()
    offload_device = mm.unet_offload_device()
    mm.unload_all_models()
    mm.soft_empty_cache()

    # Function to randomly select an option from the brackets
    def choose_random_option(match):
        options = match.group(1).split('|')
        return random.choice(options)

    prompt = re.sub(r'\{([^{}]*)\}', choose_random_option, prompt)

    # Define tokenizers and text encoders
    tokenizer = chatglm3_model['tokenizer']
    text_encoder = chatglm3_model['text_encoder']

    text_encoder.to(device)

    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=256,
        truncation=True,
        return_tensors="pt",
    ).to(device)

    output = text_encoder(
        input_ids=text_inputs['input_ids'],
        attention_mask=text_inputs['attention_mask'],
        position_ids=text_inputs['position_ids'],
        output_hidden_states=True)

    # ChatGLM hidden states are seq-first; permute to [batch_size, 256, 4096]
    prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone()
    # Last token of the last hidden state: [batch_size, 4096]
    text_proj = output.hidden_states[-1][-1, :, :].clone()

    bs_embed, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, 1, 1)
    prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)

    bs_embed = text_proj.shape[0]
    text_proj = text_proj.repeat(1, 1).view(bs_embed, -1)

    text_encoder.to(offload_device)
    mm.soft_empty_cache()
    gc.collect()

    return prompt_embeds, text_proj


def MZ_ChatGLM3Loader_call(args):
    # from .mz_kolors_utils import Utils
    # llm_dir = os.path.join(Utils.get_models_path(), "LLM")
    chatglm3_checkpoint = args.get("chatglm3_checkpoint")

    chatglm3_checkpoint_path = folder_paths.get_full_path(
        'LLM', chatglm3_checkpoint)
    if not os.path.exists(chatglm3_checkpoint_path):
        raise RuntimeError(
            f"ERROR: Could not find chatglm3 checkpoint: {chatglm3_checkpoint_path}")

    from .chatglm3.configuration_chatglm import ChatGLMConfig
    from .chatglm3.modeling_chatglm import ChatGLMModel
    from .chatglm3.tokenization_chatglm import ChatGLMTokenizer

    offload_device = mm.unet_offload_device()

    text_encoder_config = os.path.join(
        os.path.dirname(__file__), 'configs', 'text_encoder_config.json')
    with open(text_encoder_config, 'r') as file:
        config = json.load(file)
    text_encoder_config = ChatGLMConfig(**config)

    from comfy.utils import load_torch_file
    from contextlib import nullcontext

    is_accelerate_available = False
    try:
        from accelerate import init_empty_weights
        from accelerate.utils import set_module_tensor_to_device
        is_accelerate_available = True
    except ImportError:
        pass

    with (init_empty_weights() if is_accelerate_available else nullcontext()):
        with torch.no_grad():
            # Print the torch version
            print("torch version:", torch.__version__)
            text_encoder = ChatGLMModel(text_encoder_config).eval()
            if '4bit' in chatglm3_checkpoint:
                try:
                    import cpm_kernels
                except ImportError:
                    print("Installing cpm_kernels...")
                    subprocess.run(
                        [sys.executable, "-m", "pip", "install", "cpm_kernels"],
                        check=True)
                text_encoder.quantize(4)
            elif '8bit' in chatglm3_checkpoint:
                text_encoder.quantize(8)

    text_encoder_sd = load_torch_file(chatglm3_checkpoint_path)

    if is_accelerate_available:
        for key in text_encoder_sd:
            set_module_tensor_to_device(
                text_encoder, key, device=offload_device,
                value=text_encoder_sd[key])
    else:
        print("WARNING: accelerate is not available, falling back to load_state_dict")
        text_encoder.load_state_dict(text_encoder_sd)

    tokenizer_path = os.path.join(
        os.path.dirname(__file__), 'configs', "tokenizer")
    tokenizer = ChatGLMTokenizer.from_pretrained(tokenizer_path)

    return ({"text_encoder": text_encoder, "tokenizer": tokenizer},)
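# A minimal usage sketch (illustrative only, not part of the node API): load the
# ChatGLM3 text encoder via MZ_ChatGLM3Loader_call, then encode a prompt. The
# checkpoint filename below is hypothetical.
#
#   chatglm3_model, = MZ_ChatGLM3Loader_call(
#       {"chatglm3_checkpoint": "chatglm3-8bit.safetensors"})
#   prompt_embeds, text_proj = chatglm3_text_encode(chatglm3_model, "a photo of a cat")
#   # prompt_embeds: [1, 256, 4096], text_proj: [1, 4096]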
def MZ_ChatGLM3TextEncodeV2_call(args):
    text = args.get("text")
    chatglm3_model = args.get("chatglm3_model")

    prompt_embeds, pooled_output = chatglm3_text_encode(
        chatglm3_model, text,
    )

    extra_kwargs = {
        "pooled_output": pooled_output,
    }

    extra_cond_keys = [
        "width", "height", "crop_w", "crop_h", "target_width", "target_height"
    ]
    for key, value in args.items():
        if key in extra_cond_keys:
            extra_kwargs[key] = value

    return ([[
        prompt_embeds,
        # {"pooled_output": pooled_output},
        extra_kwargs
    ]], )


def MZ_ChatGLM3Embeds2Conditioning_call(args):
    kolors_embeds = args.get("kolors_embeds")
    # kolors_embeds = {
    #     'prompt_embeds': prompt_embeds,
    #     'negative_prompt_embeds': negative_prompt_embeds,
    #     'pooled_prompt_embeds': text_proj,
    #     'negative_pooled_prompt_embeds': negative_text_proj
    # }

    positive = [[
        kolors_embeds['prompt_embeds'],
        {
            "pooled_output": kolors_embeds['pooled_prompt_embeds'],
            "width": args.get("width"),
            "height": args.get("height"),
            "crop_w": args.get("crop_w"),
            "crop_h": args.get("crop_h"),
            "target_width": args.get("target_width"),
            "target_height": args.get("target_height"),
        }
    ]]

    negative = [[
        kolors_embeds['negative_prompt_embeds'],
        {
            "pooled_output": kolors_embeds['negative_pooled_prompt_embeds'],
        }
    ]]

    return (positive, negative)


def MZ_KolorsUNETLoaderV2_call(kwargs):
    from . import hook_comfyui_kolors_v2
    import comfy.sd

    with hook_comfyui_kolors_v2.apply_kolors():
        unet_name = kwargs.get("unet_name")
        unet_path = folder_paths.get_full_path("unet", unet_name)

        import comfy.utils
        sd = comfy.utils.load_torch_file(unet_path)
        model = comfy.sd.load_unet_state_dict(sd)
        if model is None:
            raise RuntimeError(
                "ERROR: Could not detect model type of: {}".format(unet_path))
        return (model, )


def MZ_KolorsCheckpointLoaderSimple_call(kwargs):
    checkpoint_name = kwargs.get("ckpt_name")
    ckpt_path = folder_paths.get_full_path("checkpoints", checkpoint_name)

    from . import hook_comfyui_kolors_v2
    import comfy.sd

    with hook_comfyui_kolors_v2.apply_kolors():
        out = comfy.sd.load_checkpoint_guess_config(
            ckpt_path, output_vae=True, output_clip=False,
            embedding_directory=folder_paths.get_folder_paths("embeddings"))
        unet, _, vae = out[:3]
        return (unet, vae)


from comfy.cldm.cldm import ControlNet
from comfy.controlnet import ControlLora


def MZ_KolorsControlNetLoader_call(kwargs):
    control_net_name = kwargs.get("control_net_name")
    controlnet_path = folder_paths.get_full_path(
        "controlnet", control_net_name)

    from . import hook_comfyui_kolors_v2
    import comfy.controlnet

    with hook_comfyui_kolors_v2.apply_kolors():
        control_net = comfy.controlnet.load_controlnet(controlnet_path)

    return (control_net, )
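# Typical pairing (a sketch; file and variable names are hypothetical): load a
# Kolors ControlNet with the loader above, then run it through
# MZ_KolorsControlNetPatch_call below so its context passes through the Kolors
# UNet's encoder_hid_proj.
#
#   control_net, = MZ_KolorsControlNetLoader_call(
#       {"control_net_name": "kolors_controlnet_canny.safetensors"})
#   control_net, = MZ_KolorsControlNetPatch_call(
#       {"model": kolors_unet, "control_net": control_net})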
def MZ_KolorsControlNetPatch_call(kwargs):
    import copy
    from . import hook_comfyui_kolors_v2
    import comfy.model_management
    import comfy.model_patcher

    model = kwargs.get("model")
    control_net = kwargs.get("control_net")

    # Already patched: the control model carries an encoder_hid_proj.
    if hasattr(control_net, "control_model") and hasattr(control_net.control_model, "encoder_hid_proj"):
        return (control_net,)

    control_net = copy.deepcopy(control_net)

    import comfy.controlnet
    if isinstance(control_net, ControlLora):
        # The Kolors label_emb has a different shape; drop it from the LoRA
        # weights so the UNet's own label_emb is used instead.
        del_keys = []
        for k in control_net.control_weights:
            if k.startswith("label_emb.0.0."):
                del_keys.append(k)
        for k in del_keys:
            control_net.control_weights.pop(k)

        super_pre_run = ControlLora.pre_run
        super_forward = ControlNet.forward

        def KolorsControlNet_forward(self, x, hint, timesteps, context, **kwargs):
            with torch.cuda.amp.autocast(enabled=True):
                context = self.encoder_hid_proj(context)
                return super_forward(self, x, hint, timesteps, context, **kwargs)

        def KolorsControlLora_pre_run(self, *args, **kwargs):
            result = super_pre_run(self, *args, **kwargs)

            if hasattr(self, "control_model"):
                if hasattr(self.control_model, "encoder_hid_proj"):
                    return result
                setattr(self.control_model, "encoder_hid_proj",
                        model.model.diffusion_model.encoder_hid_proj)
                self.control_model.forward = MethodType(
                    KolorsControlNet_forward, self.control_model)

            return result

        control_net.pre_run = MethodType(
            KolorsControlLora_pre_run, control_net)

        super_copy = ControlLora.copy

        def KolorsControlLora_copy(self, *args, **kwargs):
            c = super_copy(self, *args, **kwargs)
            c.pre_run = MethodType(
                KolorsControlLora_pre_run, c)
            return c

        control_net.copy = MethodType(
            KolorsControlLora_copy, control_net)

        control_net = copy.deepcopy(control_net)
    elif isinstance(control_net, comfy.controlnet.ControlNet):
        # Share the UNet's label_emb and encoder_hid_proj with the control model.
        model_label_emb = model.model.diffusion_model.label_emb
        control_net.control_model.label_emb = model_label_emb
        setattr(control_net.control_model, "encoder_hid_proj",
                model.model.diffusion_model.encoder_hid_proj)

        control_net.control_model_wrapped = comfy.model_patcher.ModelPatcher(
            control_net.control_model,
            load_device=control_net.load_device,
            offload_device=comfy.model_management.unet_offload_device())

        super_forward = ControlNet.forward

        def KolorsControlNet_forward(self, x, hint, timesteps, context, **kwargs):
            with torch.cuda.amp.autocast(enabled=True):
                context = self.encoder_hid_proj(context)
                return super_forward(self, x, hint, timesteps, context, **kwargs)

        control_net.control_model.forward = MethodType(
            KolorsControlNet_forward, control_net.control_model)
    else:
        raise NotImplementedError(
            f"Type {type(control_net)} not supported for KolorsControlNetPatch")

    return (control_net,)
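# Why the patch above matters (a sketch; the dimensions are assumptions taken
# from the reference Kolors config, not read from the checkpoint): the Kolors
# UNet's encoder_hid_proj maps ChatGLM3 hidden states to the cross-attention
# width, while stock ComfyUI ControlNets expect already-projected context.
#
#   context = torch.randn(1, 256, 4096)   # ChatGLM3 prompt_embeds
#   context = encoder_hid_proj(context)   # -> [1, 256, 2048]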
def MZ_KolorsCLIPVisionLoader_call(kwargs):
    import comfy.clip_vision
    from . import hook_comfyui_kolors_v2

    clip_name = kwargs.get("clip_name")
    clip_path = folder_paths.get_full_path("clip_vision", clip_name)
    with hook_comfyui_kolors_v2.apply_kolors():
        clip_vision = comfy.clip_vision.load(clip_path)
    return (clip_vision,)


def MZ_ApplySDXLSamplingSettings_call(kwargs):
    model = kwargs.get("model").clone()

    import comfy.model_sampling
    sampling_base = comfy.model_sampling.ModelSamplingDiscrete
    sampling_type = comfy.model_sampling.EPS

    class SDXLSampling(sampling_base, sampling_type):
        pass

    # Standard SDXL linear beta schedule.
    model.model.model_config.sampling_settings["beta_schedule"] = "linear"
    model.model.model_config.sampling_settings["linear_start"] = 0.00085
    model.model.model_config.sampling_settings["linear_end"] = 0.012
    model.model.model_config.sampling_settings["timesteps"] = 1000

    model_sampling = SDXLSampling(model.model.model_config)
    model.add_object_patch("model_sampling", model_sampling)

    return (model,)


def MZ_ApplyCUDAGenerator_call(kwargs):
    model = kwargs.get("model")

    def prepare_noise(latent_image, seed, noise_inds=None):
        """
        Creates random noise given a latent image and a seed.
        The optional arg noise_inds can be used to skip and discard
        a number of noise generations for a given seed.
        """
        generator = torch.Generator(device="cuda").manual_seed(seed)
        if noise_inds is None:
            return torch.randn(latent_image.size(), dtype=latent_image.dtype,
                               layout=latent_image.layout, generator=generator,
                               device="cuda")

        unique_inds, inverse = np.unique(noise_inds, return_inverse=True)
        noises = []
        for i in range(unique_inds[-1] + 1):
            noise = torch.randn([1] + list(latent_image.size())[1:],
                                dtype=latent_image.dtype,
                                layout=latent_image.layout,
                                generator=generator, device="cuda")
            if i in unique_inds:
                noises.append(noise)
        noises = [noises[i] for i in inverse]
        noises = torch.cat(noises, axis=0)
        return noises

    import comfy.sample
    comfy.sample.prepare_noise = prepare_noise

    return (model,)
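# Note: MZ_ApplyCUDAGenerator_call monkey-patches comfy.sample.prepare_noise
# globally, so every downstream sampler draws noise on the GPU. A hedged sketch
# for undoing the patch (the saved reference name is hypothetical; capture it
# before calling the node):
#
#   import comfy.sample
#   _original_prepare_noise = comfy.sample.prepare_noise
#   MZ_ApplyCUDAGenerator_call({"model": model})
#   ...
#   comfy.sample.prepare_noise = _original_prepare_noise  # restore CPU noise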