import gc
import json
import os
import random
import re
import subprocess
import sys
from types import MethodType

import numpy as np
import torch

import folder_paths
import comfy.model_management as mm

def chatglm3_text_encode(chatglm3_model, prompt):
    device = mm.get_torch_device()
    offload_device = mm.unet_offload_device()
    mm.unload_all_models()
    mm.soft_empty_cache()

    # Randomly select one option from each {a|b|c} group in the prompt
    def choose_random_option(match):
        options = match.group(1).split('|')
        return random.choice(options)

    prompt = re.sub(r'\{([^{}]*)\}', choose_random_option, prompt)

    # Tokenizer and text encoder come from MZ_ChatGLM3Loader_call
    tokenizer = chatglm3_model['tokenizer']
    text_encoder = chatglm3_model['text_encoder']

    text_encoder.to(device)
    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=256,
        truncation=True,
        return_tensors="pt",
    ).to(device)

    output = text_encoder(
        input_ids=text_inputs['input_ids'],
        attention_mask=text_inputs['attention_mask'],
        position_ids=text_inputs['position_ids'],
        output_hidden_states=True)

    # ChatGLM hidden states are [seq_len, batch, hidden]; permute to
    # [batch_size, 256, 4096]
    prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone()
    # Last token of the last hidden state: [batch_size, 4096]
    text_proj = output.hidden_states[-1][-1, :, :].clone()

    # The repeat/view pairs below are no-ops here (single image per prompt),
    # kept in the shape of the reference pipeline code
    bs_embed, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, 1, 1)
    prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)

    bs_embed = text_proj.shape[0]
    text_proj = text_proj.repeat(1, 1).view(bs_embed, -1)

    text_encoder.to(offload_device)
    mm.soft_empty_cache()
    gc.collect()
    return prompt_embeds, text_proj
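

# Illustrative only: a minimal sketch of calling the encoder above. The
# chatglm3_model dict and the prompt are placeholders; a real dict comes from
# MZ_ChatGLM3Loader_call below. Never called by the node graph itself.
def _example_chatglm3_encode(chatglm3_model):
    embeds, pooled = chatglm3_text_encode(chatglm3_model, "a photo of a cat")
    # embeds: [1, 256, 4096] token embeddings; pooled: [1, 4096] vector
    return embeds, pooled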

def MZ_ChatGLM3Loader_call(args):
    # from .mz_kolors_utils import Utils
    # llm_dir = os.path.join(Utils.get_models_path(), "LLM")
    chatglm3_checkpoint = args.get("chatglm3_checkpoint")
    chatglm3_checkpoint_path = folder_paths.get_full_path(
        'LLM', chatglm3_checkpoint)
    if not os.path.exists(chatglm3_checkpoint_path):
        raise RuntimeError(
            f"ERROR: Could not find chatglm3 checkpoint: {chatglm3_checkpoint_path}")

    from .chatglm3.configuration_chatglm import ChatGLMConfig
    from .chatglm3.modeling_chatglm import ChatGLMModel
    from .chatglm3.tokenization_chatglm import ChatGLMTokenizer

    offload_device = mm.unet_offload_device()

    text_encoder_config = os.path.join(
        os.path.dirname(__file__), 'configs', 'text_encoder_config.json')
    with open(text_encoder_config, 'r') as file:
        config = json.load(file)
    text_encoder_config = ChatGLMConfig(**config)

    from comfy.utils import load_torch_file
    from contextlib import nullcontext

    is_accelerate_available = False
    try:
        from accelerate import init_empty_weights
        from accelerate.utils import set_module_tensor_to_device
        is_accelerate_available = True
    except ImportError:
        pass

    with (init_empty_weights() if is_accelerate_available else nullcontext()):
        with torch.no_grad():
            # Print the torch version number
            print("torch version:", torch.__version__)
            text_encoder = ChatGLMModel(text_encoder_config).eval()
            if '4bit' in chatglm3_checkpoint:
                try:
                    import cpm_kernels  # required by quantize()
                except ImportError:
                    print("Installing cpm_kernels...")
                    subprocess.run(
                        [sys.executable, "-m", "pip", "install", "cpm_kernels"],
                        check=True)
                text_encoder.quantize(4)
            elif '8bit' in chatglm3_checkpoint:
                text_encoder.quantize(8)

    text_encoder_sd = load_torch_file(chatglm3_checkpoint_path)

    if is_accelerate_available:
        for key in text_encoder_sd:
            set_module_tensor_to_device(
                text_encoder, key, device=offload_device,
                value=text_encoder_sd[key])
    else:
        print("WARNING: accelerate not available, falling back to load_state_dict")
        text_encoder.load_state_dict(text_encoder_sd)

    tokenizer_path = os.path.join(
        os.path.dirname(__file__), 'configs', "tokenizer")
    tokenizer = ChatGLMTokenizer.from_pretrained(tokenizer_path)

    return ({"text_encoder": text_encoder, "tokenizer": tokenizer},)

def MZ_ChatGLM3TextEncodeV2_call(args):
    text = args.get("text")
    chatglm3_model = args.get("chatglm3_model")
    prompt_embeds, pooled_output = chatglm3_text_encode(
        chatglm3_model,
        text,
    )

    extra_kwargs = {
        "pooled_output": pooled_output,
    }
    extra_cond_keys = [
        "width",
        "height",
        "crop_w",
        "crop_h",
        "target_width",
        "target_height",
    ]
    for key, value in args.items():
        if key in extra_cond_keys:
            extra_kwargs[key] = value

    return ([[
        prompt_embeds,
        extra_kwargs,
    ]], )
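

# The return value above follows ComfyUI's CONDITIONING convention: a list of
# [embeddings, extra_dict] pairs. A sketch of what a downstream node receives:
def _example_read_conditioning(conditioning):
    embeds, extra = conditioning[0]
    pooled = extra["pooled_output"]  # the text_proj tensor from the encoder
    return embeds, pooled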

def MZ_ChatGLM3Embeds2Conditioning_call(args):
    kolors_embeds = args.get("kolors_embeds")
    # kolors_embeds = {
    #     'prompt_embeds': prompt_embeds,
    #     'negative_prompt_embeds': negative_prompt_embeds,
    #     'pooled_prompt_embeds': text_proj,
    #     'negative_pooled_prompt_embeds': negative_text_proj,
    # }
    positive = [[
        kolors_embeds['prompt_embeds'],
        {
            "pooled_output": kolors_embeds['pooled_prompt_embeds'],
            "width": args.get("width"),
            "height": args.get("height"),
            "crop_w": args.get("crop_w"),
            "crop_h": args.get("crop_h"),
            "target_width": args.get("target_width"),
            "target_height": args.get("target_height"),
        }
    ]]
    negative = [[
        kolors_embeds['negative_prompt_embeds'],
        {
            "pooled_output": kolors_embeds['negative_pooled_prompt_embeds'],
        }
    ]]
    return (positive, negative)

def MZ_KolorsUNETLoaderV2_call(kwargs):
    from . import hook_comfyui_kolors_v2
    import comfy.sd
    with hook_comfyui_kolors_v2.apply_kolors():
        unet_name = kwargs.get("unet_name")
        unet_path = folder_paths.get_full_path("unet", unet_name)
        import comfy.utils
        sd = comfy.utils.load_torch_file(unet_path)
        model = comfy.sd.load_unet_state_dict(sd)
        if model is None:
            raise RuntimeError(
                "ERROR: Could not detect model type of: {}".format(unet_path))
        return (model, )

def MZ_KolorsCheckpointLoaderSimple_call(kwargs):
    checkpoint_name = kwargs.get("ckpt_name")
    ckpt_path = folder_paths.get_full_path("checkpoints", checkpoint_name)
    from . import hook_comfyui_kolors_v2
    import comfy.sd
    with hook_comfyui_kolors_v2.apply_kolors():
        out = comfy.sd.load_checkpoint_guess_config(
            ckpt_path, output_vae=True, output_clip=False,
            embedding_directory=folder_paths.get_folder_paths("embeddings"))
        # out is (model, clip, vae, ...); clip is None since output_clip=False
        unet, _, vae = out[:3]
        return (unet, vae)

from comfy.cldm.cldm import ControlNet
from comfy.controlnet import ControlLora


def MZ_KolorsControlNetLoader_call(kwargs):
    control_net_name = kwargs.get("control_net_name")
    controlnet_path = folder_paths.get_full_path(
        "controlnet", control_net_name)
    from . import hook_comfyui_kolors_v2
    import comfy.controlnet
    with hook_comfyui_kolors_v2.apply_kolors():
        control_net = comfy.controlnet.load_controlnet(controlnet_path)
    return (control_net, )

def MZ_KolorsControlNetPatch_call(kwargs):
    import copy
    from . import hook_comfyui_kolors_v2
    import comfy.model_management
    import comfy.model_patcher

    model = kwargs.get("model")
    control_net = kwargs.get("control_net")

    # Already patched: the control model carries the Kolors encoder_hid_proj
    if hasattr(control_net, "control_model") and hasattr(control_net.control_model, "encoder_hid_proj"):
        return (control_net,)

    control_net = copy.deepcopy(control_net)

    import comfy.controlnet
    if isinstance(control_net, ControlLora):
        # Drop label_emb weights; the control model reuses the UNet's own
        del_keys = []
        for k in control_net.control_weights:
            if k.startswith("label_emb.0.0."):
                del_keys.append(k)
        for k in del_keys:
            control_net.control_weights.pop(k)

        super_pre_run = ControlLora.pre_run
        super_forward = ControlNet.forward

        def KolorsControlNet_forward(self, x, hint, timesteps, context, **kwargs):
            # Project the ChatGLM context through the UNet's encoder_hid_proj
            # before running the stock ControlNet forward
            with torch.cuda.amp.autocast(enabled=True):
                context = self.encoder_hid_proj(context)
            return super_forward(self, x, hint, timesteps, context, **kwargs)

        def KolorsControlLora_pre_run(self, *args, **kwargs):
            result = super_pre_run(self, *args, **kwargs)
            if hasattr(self, "control_model"):
                if hasattr(self.control_model, "encoder_hid_proj"):
                    return result
                setattr(self.control_model, "encoder_hid_proj",
                        model.model.diffusion_model.encoder_hid_proj)
                self.control_model.forward = MethodType(
                    KolorsControlNet_forward, self.control_model)
            return result

        control_net.pre_run = MethodType(
            KolorsControlLora_pre_run, control_net)

        super_copy = ControlLora.copy

        def KolorsControlLora_copy(self, *args, **kwargs):
            c = super_copy(self, *args, **kwargs)
            c.pre_run = MethodType(
                KolorsControlLora_pre_run, c)
            return c

        control_net.copy = MethodType(
            KolorsControlLora_copy, control_net)

        control_net = copy.deepcopy(control_net)

    elif isinstance(control_net, comfy.controlnet.ControlNet):
        # Share the UNet's label_emb and encoder_hid_proj with the control model
        model_label_emb = model.model.diffusion_model.label_emb
        control_net.control_model.label_emb = model_label_emb

        setattr(control_net.control_model, "encoder_hid_proj",
                model.model.diffusion_model.encoder_hid_proj)

        control_net.control_model_wrapped = comfy.model_patcher.ModelPatcher(
            control_net.control_model, load_device=control_net.load_device,
            offload_device=comfy.model_management.unet_offload_device())

        super_forward = ControlNet.forward

        def KolorsControlNet_forward(self, x, hint, timesteps, context, **kwargs):
            with torch.cuda.amp.autocast(enabled=True):
                context = self.encoder_hid_proj(context)
            return super_forward(self, x, hint, timesteps, context, **kwargs)

        control_net.control_model.forward = MethodType(
            KolorsControlNet_forward, control_net.control_model)
    else:
        raise NotImplementedError(
            f"Type {control_net} not supported for KolorsControlNetPatch")

    return (control_net,)
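

# Usage sketch (hypothetical variable names): the patch must be applied before
# the ControlNet is used, so its control model routes the 4096-dim ChatGLM
# context through the UNet's encoder_hid_proj:
def _example_patch_controlnet(model, control_net):
    (patched,) = MZ_KolorsControlNetPatch_call(
        {"model": model, "control_net": control_net})
    return patched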

def MZ_KolorsCLIPVisionLoader_call(kwargs):
    import comfy.clip_vision
    from . import hook_comfyui_kolors_v2
    clip_name = kwargs.get("clip_name")
    clip_path = folder_paths.get_full_path("clip_vision", clip_name)
    with hook_comfyui_kolors_v2.apply_kolors():
        clip_vision = comfy.clip_vision.load(clip_path)
    return (clip_vision,)

def MZ_ApplySDXLSamplingSettings_call(kwargs):
    model = kwargs.get("model").clone()

    import comfy.model_sampling
    sampling_base = comfy.model_sampling.ModelSamplingDiscrete
    sampling_type = comfy.model_sampling.EPS

    class SDXLSampling(sampling_base, sampling_type):
        pass

    # Standard SDXL discrete-sampling settings (scaled-linear beta schedule)
    model.model.model_config.sampling_settings["beta_schedule"] = "linear"
    model.model.model_config.sampling_settings["linear_start"] = 0.00085
    model.model.model_config.sampling_settings["linear_end"] = 0.012
    model.model.model_config.sampling_settings["timesteps"] = 1000

    model_sampling = SDXLSampling(model.model.model_config)
    model.add_object_patch("model_sampling", model_sampling)
    return (model,)
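

# For reference, a sketch of how the "linear" schedule configured above is
# conventionally built (the standard scaled-linear DDPM betas; this does not
# call into comfy internals and is illustrative only):
def _example_linear_betas(linear_start=0.00085, linear_end=0.012, timesteps=1000):
    # interpolate in sqrt-space, then square
    return torch.linspace(linear_start ** 0.5, linear_end ** 0.5, timesteps) ** 2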

def MZ_ApplyCUDAGenerator_call(kwargs):
    model = kwargs.get("model")

    def prepare_noise(latent_image, seed, noise_inds=None):
        """
        Creates random noise for a latent image and seed, using a CUDA
        generator so results match GPU-seeded implementations.
        noise_inds can be used to skip and discard a number of noise
        generations for a given seed.
        """
        generator = torch.Generator(device="cuda").manual_seed(seed)
        if noise_inds is None:
            return torch.randn(latent_image.size(), dtype=latent_image.dtype,
                               layout=latent_image.layout, generator=generator,
                               device="cuda")

        unique_inds, inverse = np.unique(noise_inds, return_inverse=True)
        noises = []
        for i in range(unique_inds[-1] + 1):
            noise = torch.randn([1] + list(latent_image.size())[1:],
                                dtype=latent_image.dtype,
                                layout=latent_image.layout,
                                generator=generator, device="cuda")
            if i in unique_inds:
                noises.append(noise)
        noises = [noises[i] for i in inverse]
        noises = torch.cat(noises, axis=0)
        return noises

    import comfy.sample
    # Note: this monkey-patches comfy.sample.prepare_noise globally for the
    # whole session, not just for this model
    comfy.sample.prepare_noise = prepare_noise
    return (model,)
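

# The patch above replaces comfy.sample.prepare_noise globally, affecting every
# sampler in the session. A sketch for restoring stock CPU noise (assumes
# reloading comfy.sample has no other side effects in your setup):
def _example_restore_cpu_noise():
    import importlib
    import comfy.sample
    importlib.reload(comfy.sample)  # re-binds the original prepare_noise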