# FooocusEnhanced / ComfyUI-Kolors-MZ / hook_comfyui_kolors_v2.py
import os
import logging
import warnings
from types import MethodType

import torch
from torch import nn

from comfy.model_detection import *
import comfy.model_detection as model_detection
import comfy.supported_models
import comfy.utils
from comfy import model_base
from comfy.model_base import sdxl_pooled, CLIPEmbeddingNoiseAugmentation, Timestep, ModelType
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel
from comfy.cldm.cldm import ControlNet
# Legacy patch of comfy.samplers.CFGGuider.set_conds, no longer maintained and
# left disabled for reference:
# try:
# import comfy.samplers as samplers
# original_CFGGuider_inner_set_conds = samplers.CFGGuider.set_conds
# def patched_set_conds(self, positive, negative):
# if isinstance(self.model_patcher.model, KolorsSDXL):
# import copy
# if "control" in positive[0][1]:
# if hasattr(positive[0][1]["control"], "control_model"):
# if positive[0][1]["control"].control_model.label_emb.shape[1] == 5632:
# return
#             warnings.warn("This method is no longer maintained")
# positive = copy.deepcopy(positive)
# negative = copy.deepcopy(negative)
# hid_proj = self.model_patcher.model.encoder_hid_proj
# if hid_proj is not None:
# positive[0][0] = hid_proj(positive[0][0])
# negative[0][0] = hid_proj(negative[0][0])
# if "control" in positive[0][1]:
# if hasattr(positive[0][1]["control"], "control_model"):
# positive[0][1]["control"].control_model.label_emb = self.model_patcher.model.diffusion_model.label_emb
# if "control" in negative[0][1]:
# if hasattr(negative[0][1]["control"], "control_model"):
# negative[0][1]["control"].control_model.label_emb = self.model_patcher.model.diffusion_model.label_emb
# return original_CFGGuider_inner_set_conds(self, positive, negative)
# samplers.CFGGuider.set_conds = patched_set_conds
# except ImportError:
# print("CFGGuider not found, skipping patching")
class KolorsUNetModel(UNetModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.encoder_hid_proj = nn.Linear(
4096, 2048, bias=True)
def forward(self, *args, **kwargs):
with torch.cuda.amp.autocast(enabled=True):
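            # Kolors passes 4096-dim ChatGLM3 hidden states as "context"; project
            # them down to the UNet's 2048-dim cross-attention width here
            # (e.g. [B, 256, 4096] -> [B, 256, 2048]).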
if "context" in kwargs:
kwargs["context"] = self.encoder_hid_proj(
kwargs["context"])
# if "y" in kwargs:
# if kwargs["y"].shape[1] == 2816:
            #         # pad "y" from 2816 up to 5632 channels
            #         kwargs["y"] = torch.cat(
            #             [torch.zeros(kwargs["y"].shape[0], 2816).to(kwargs["y"].device), kwargs["y"]], dim=1)
result = super().forward(*args, **kwargs)
return result
class KolorsSDXL(model_base.SDXL):
    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        # Kolors uses a linear beta schedule over 1100 timesteps, so override
        # the SDXL defaults before BaseModel builds its sampling object.
        model_config.sampling_settings["beta_schedule"] = "linear"
        model_config.sampling_settings["linear_start"] = 0.00085
        model_config.sampling_settings["linear_end"] = 0.014
        model_config.sampling_settings["timesteps"] = 1100
        # Kolors is always an epsilon-prediction model.
        model_type = ModelType.EPS
        model_base.BaseModel.__init__(
            self, model_config, model_type, device=device, unet_model=KolorsUNetModel)
self.embedder = Timestep(256)
self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(
**{"noise_schedule_config": {"timesteps": 1100, "beta_schedule": "linear", "linear_start": 0.00085, "linear_end": 0.014}, "timestep_dim": 1280})
def encode_adm(self, **kwargs):
clip_pooled = sdxl_pooled(kwargs, self.noise_augmentor)
width = kwargs.get("width", 768)
height = kwargs.get("height", 768)
crop_w = kwargs.get("crop_w", 0)
crop_h = kwargs.get("crop_h", 0)
target_width = kwargs.get("target_width", width)
target_height = kwargs.get("target_height", height)
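        # clip_pooled is the text encoder's pooled output; for Kolors it is
        # 4096-dim, so together with the six 256-dim size/crop embeddings below
        # the ADM vector is 4096 + 6 * 256 = 5632 channels, matching
        # adm_in_channels = 5632 in KolorsSupported.unet_config.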
out = []
out.append(self.embedder(torch.Tensor([height])))
out.append(self.embedder(torch.Tensor([width])))
out.append(self.embedder(torch.Tensor([crop_h])))
out.append(self.embedder(torch.Tensor([crop_w])))
out.append(self.embedder(torch.Tensor([target_height])))
out.append(self.embedder(torch.Tensor([target_width])))
flat = torch.flatten(torch.cat(out)).unsqueeze(
dim=0).repeat(clip_pooled.shape[0], 1)
return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
class KolorsSupported(comfy.supported_models.SDXL):
unet_config = {
"model_channels": 320,
"use_linear_in_transformer": True,
"transformer_depth": [0, 0, 2, 2, 10, 10],
"context_dim": 2048,
"adm_in_channels": 5632,
"use_temporal_attention": False,
}
    def get_model(self, state_dict, prefix="", device=None):
        out = KolorsSDXL(self, model_type=self.model_type(
            state_dict, prefix), device=device)
        # Re-brand the instance as plain model_base.SDXL, presumably so the rest
        # of ComfyUI treats it like a standard SDXL model.
        out.__class__ = model_base.SDXL
        if self.inpaint_model():
            out.set_inpaint()
        return out
def kolors_unet_config_from_diffusers_unet(state_dict, dtype=None):
    """Replacement for comfy.model_detection.unet_config_from_diffusers_unet that
    additionally recognizes Kolors UNets (adm_in_channels == 5632) in diffusers
    format."""
    match = {}
transformer_depth = []
attn_res = 1
down_blocks = count_blocks(state_dict, "down_blocks.{}")
for i in range(down_blocks):
attn_blocks = count_blocks(
state_dict, "down_blocks.{}.attentions.".format(i) + '{}')
res_blocks = count_blocks(
state_dict, "down_blocks.{}.resnets.".format(i) + '{}')
for ab in range(attn_blocks):
transformer_count = count_blocks(
state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.".format(i, ab) + '{}')
transformer_depth.append(transformer_count)
if transformer_count > 0:
match["context_dim"] = state_dict["down_blocks.{}.attentions.{}.transformer_blocks.0.attn2.to_k.weight".format(
i, ab)].shape[1]
attn_res *= 2
if attn_blocks == 0:
for i in range(res_blocks):
transformer_depth.append(0)
match["transformer_depth"] = transformer_depth
match["model_channels"] = state_dict["conv_in.weight"].shape[0]
match["in_channels"] = state_dict["conv_in.weight"].shape[1]
match["adm_in_channels"] = None
if "class_embedding.linear_1.weight" in state_dict:
match["adm_in_channels"] = state_dict["class_embedding.linear_1.weight"].shape[1]
elif "add_embedding.linear_1.weight" in state_dict:
match["adm_in_channels"] = state_dict["add_embedding.linear_1.weight"].shape[1]
Kolors = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 5632, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10,
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
'use_temporal_attention': False, 'use_temporal_resblock': False}
Kolors_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 5632, 'dtype': dtype, 'in_channels': 9, 'model_channels': 320,
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10,
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
'use_temporal_attention': False, 'use_temporal_resblock': False}
Kolors_ip2p = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 5632, 'dtype': dtype, 'in_channels': 8, 'model_channels': 320,
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10,
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
'use_temporal_attention': False, 'use_temporal_resblock': False}
SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10,
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
'use_temporal_attention': False, 'use_temporal_resblock': False}
SDXL_mid_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 0, 0, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 1,
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 0, 0, 0, 1, 1, 1],
'use_temporal_attention': False, 'use_temporal_resblock': False}
SDXL_small_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 0, 0, 0, 0], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 0,
'use_linear_in_transformer': True, 'num_head_channels': 64, 'context_dim': 1, 'transformer_depth_output': [0, 0, 0, 0, 0, 0, 0, 0, 0],
'use_temporal_attention': False, 'use_temporal_resblock': False}
supported_models = [Kolors, Kolors_inpaint,
Kolors_ip2p, SDXL, SDXL_mid_cnet, SDXL_small_cnet]
for unet_config in supported_models:
matches = True
for k in match:
if match[k] != unet_config[k]:
print("key {} does not match".format(
k), match[k], "||", unet_config[k])
matches = False
break
if matches:
return convert_config(unet_config)
return None
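# Rough usage sketch for the detector above (illustrative only; the file path is
# a placeholder, not part of this module):
#
#   sd = comfy.utils.load_torch_file("path/to/kolors/unet/diffusion_pytorch_model.safetensors")
#   cfg = kolors_unet_config_from_diffusers_unet(sd)
#   # a Kolors UNet should match the `Kolors` template above
#   # (adm_in_channels == 5632, context_dim == 2048)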
def load_clipvision_336_from_sd(sd, prefix="", convert_keys=False):
    """Drop-in replacement for comfy.clip_vision.load_clipvision_from_sd that
    builds the CLIP vision model from the 336px ViT config shipped alongside
    this file (clip_vit_336/config.json)."""
    from comfy.clip_vision import ClipVisionModel, convert_to_transformers
    json_config = os.path.join(os.path.dirname(
        os.path.realpath(__file__)), "clip_vit_336", "config.json")
    clip = ClipVisionModel(json_config)
m, u = clip.load_sd(sd)
if len(m) > 0:
logging.warning("missing clip vision: {}".format(m))
u = set(u)
keys = list(sd.keys())
for k in keys:
if k not in u:
t = sd.pop(k)
del t
# def vis_forward(self, pixel_values, attention_mask=None, intermediate_output=None):
# pixel_values = nn.functional.interpolate(
# pixel_values, size=(336, 336), mode='bilinear', align_corners=False)
# x = self.embeddings(pixel_values)
# x = self.pre_layrnorm(x)
# # TODO: attention_mask?
# x, i = self.encoder(
# x, mask=None, intermediate_output=intermediate_output)
# pooled_output = self.post_layernorm(x[:, 0, :])
# return x, i, pooled_output
# clip.model.vision_model.forward = MethodType(
# vis_forward, clip.model.vision_model
# )
return clip
class KolorsControlNet(ControlNet):
    """ControlNet variant for Kolors: like KolorsUNetModel, it projects the
    4096-dim text context down to 2048 dims before the regular forward pass."""

    def __init__(self, *args, **kwargs):
        adm_in_channels = kwargs["adm_in_channels"]
        if adm_in_channels == 2816:
            # SDXL-type ControlNets (2816 ADM channels) are not handled here; use
            # the standard ControlNet loader plus the KolorsControlNetPatch node.
            raise Exception(
                "This loader does not support SDXL type, please use ControlNet loader + KolorsControlNetPatch node")
super().__init__(*args, **kwargs)
self.encoder_hid_proj = nn.Linear(
4096, 2048, bias=True)
def forward(self, *args, **kwargs):
with torch.cuda.amp.autocast(enabled=True):
if "context" in kwargs:
kwargs["context"] = self.encoder_hid_proj(
kwargs["context"])
result = super().forward(*args, **kwargs)
return result
class apply_kolors:
    """Context manager that temporarily swaps ComfyUI's CLIP-vision loader, UNet
    key map, diffusers UNet detection, supported-model list and ControlNet class
    for their Kolors-aware versions, and restores the originals on exit."""

    def __enter__(self):
import comfy.ldm.modules.diffusionmodules.openaimodel
import comfy.cldm.cldm
import comfy.utils
import comfy.clip_vision
self.original_load_clipvision_from_sd = comfy.clip_vision.load_clipvision_from_sd
comfy.clip_vision.load_clipvision_from_sd = load_clipvision_336_from_sd
self.original_UNET_MAP_BASIC = comfy.utils.UNET_MAP_BASIC.copy()
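        # Register the Kolors projection weights in the key map so
        # encoder_hid_proj.* is carried through ComfyUI's diffusers/UNet key
        # conversion (the key name is identical on both sides).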
comfy.utils.UNET_MAP_BASIC.add(
("encoder_hid_proj.weight", "encoder_hid_proj.weight"),
)
comfy.utils.UNET_MAP_BASIC.add(
("encoder_hid_proj.bias", "encoder_hid_proj.bias"),
)
self.original_unet_config_from_diffusers_unet = model_detection.unet_config_from_diffusers_unet
model_detection.unet_config_from_diffusers_unet = kolors_unet_config_from_diffusers_unet
import comfy.supported_models
self.original_supported_models = comfy.supported_models.models
comfy.supported_models.models = [KolorsSupported]
self.original_controlnet = comfy.cldm.cldm.ControlNet
comfy.cldm.cldm.ControlNet = KolorsControlNet
def __exit__(self, type, value, traceback):
import comfy.ldm.modules.diffusionmodules.openaimodel
import comfy.cldm.cldm
import comfy.utils
comfy.utils.UNET_MAP_BASIC = self.original_UNET_MAP_BASIC
model_detection.unet_config_from_diffusers_unet = self.original_unet_config_from_diffusers_unet
import comfy.supported_models
comfy.supported_models.models = self.original_supported_models
import comfy.clip_vision
comfy.clip_vision.load_clipvision_from_sd = self.original_load_clipvision_from_sd
comfy.cldm.cldm.ControlNet = self.original_controlnet
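# Rough usage sketch (illustrative only; the loader call and path are placeholders
# and not part of this module; the node code applies the patch around ComfyUI's
# own model loading):
#
#   with apply_kolors():
#       sd = comfy.utils.load_torch_file("path/to/kolors_unet.safetensors")
#       model = comfy.sd.load_unet_state_dict(sd)
#   # on exit the original ComfyUI hooks are restored, even if loading raised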