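"""
PixArt model loader for ComfyUI.

Wraps the PixArt / PixArtMS diffusion models (and their ControlNet variants)
in ComfyUI's BaseModel / ModelPatcher machinery so a checkpoint can be loaded
and sampled like any other model.
"""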
import comfy.supported_models_base
import comfy.latent_formats
import comfy.model_patcher
import comfy.model_base
import comfy.utils
import comfy.conds
import torch
import math
from comfy import model_management
from .diffusers_convert import convert_state_dict

class EXM_PixArt(comfy.supported_models_base.BASE):
    unet_config = {}
    unet_extra_config = {}
    latent_format = comfy.latent_formats.SD15

    def __init__(self, model_conf):
        self.model_target = model_conf.get("target")
        self.unet_config = model_conf.get("unet_config", {})
        self.sampling_settings = model_conf.get("sampling_settings", {})
        self.latent_format = self.latent_format()
        # UNET is handled by extension
        self.unet_config["disable_unet_model_creation"] = True

    def model_type(self, state_dict, prefix=""):
        return comfy.model_base.ModelType.EPS
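
# A hand-written model_conf for the wrapper above follows the same shape that
# guess_pixart_config() returns further down (values here are illustrative,
# not tied to a specific checkpoint):
#   {
#       "target": "PixArtMS",
#       "unet_config": {"input_size": 1024 // 8, "depth": 28},
#       "sampling_settings": {"beta_schedule": "sqrt_linear", "timesteps": 1000},
#   }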

class EXM_PixArt_Model(comfy.model_base.BaseModel):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def extra_conds(self, **kwargs):
        out = super().extra_conds(**kwargs)

        img_hw = kwargs.get("img_hw", None)
        if img_hw is not None:
            out["img_hw"] = comfy.conds.CONDRegular(torch.tensor(img_hw))

        aspect_ratio = kwargs.get("aspect_ratio", None)
        if aspect_ratio is not None:
            out["aspect_ratio"] = comfy.conds.CONDRegular(torch.tensor(aspect_ratio))

        cn_hint = kwargs.get("cn_hint", None)
        if cn_hint is not None:
            out["cn_hint"] = comfy.conds.CONDRegular(cn_hint)

        return out
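
# The extra keys above are expected to be attached to the conditioning by the
# accompanying nodes; e.g. a single square 1024x1024 image might arrive as
# (illustrative values): img_hw=[[1024, 1024]], aspect_ratio=[[1.0]].
# extra_conds() wraps them as CONDRegular tensors so ComfyUI batches them
# alongside the usual cross-attention conditioning.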

def load_pixart(model_path, model_conf=None):
    state_dict = comfy.utils.load_torch_file(model_path)
    state_dict = state_dict.get("model", state_dict)

    # strip any known state dict prefix
    for prefix in ["model.diffusion_model."]:
        if any(x.startswith(prefix) for x in state_dict):
            # only strip keys that actually carry the prefix
            state_dict = {
                (k[len(prefix):] if k.startswith(prefix) else k): v
                for k, v in state_dict.items()
            }

    # diffusers
    if "adaln_single.linear.weight" in state_dict:
        state_dict = convert_state_dict(state_dict)  # Diffusers

    # guess config automatically if none was provided
    if model_conf is None:
        model_conf = guess_pixart_config(state_dict)

    parameters = comfy.utils.calculate_parameters(state_dict)
    unet_dtype = model_management.unet_dtype(model_params=parameters)
    load_device = comfy.model_management.get_torch_device()
    offload_device = comfy.model_management.unet_offload_device()

    # ignore fp8/etc and use directly for now
    manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
    if manual_cast_dtype:
        print(f"PixArt: falling back to {manual_cast_dtype}")
        unet_dtype = manual_cast_dtype

    model_conf = EXM_PixArt(model_conf)  # convert dict to config object
    model = EXM_PixArt_Model(  # same as comfy.model_base.BaseModel
        model_conf,
        model_type=comfy.model_base.ModelType.EPS,
        device=model_management.get_torch_device()
    )

    if model_conf.model_target == "PixArtMS":
        from .models.PixArtMS import PixArtMS
        model.diffusion_model = PixArtMS(**model_conf.unet_config)
    elif model_conf.model_target == "PixArt":
        from .models.PixArt import PixArt
        model.diffusion_model = PixArt(**model_conf.unet_config)
    elif model_conf.model_target == "PixArtMSSigma":
        from .models.PixArtMS import PixArtMS
        model.diffusion_model = PixArtMS(**model_conf.unet_config)
        model.latent_format = comfy.latent_formats.SDXL()
    elif model_conf.model_target == "ControlPixArtMSHalf":
        from .models.PixArtMS import PixArtMS
        from .models.pixart_controlnet import ControlPixArtMSHalf
        model.diffusion_model = PixArtMS(**model_conf.unet_config)
        model.diffusion_model = ControlPixArtMSHalf(model.diffusion_model)
    elif model_conf.model_target == "ControlPixArtHalf":
        from .models.PixArt import PixArt
        from .models.pixart_controlnet import ControlPixArtHalf
        model.diffusion_model = PixArt(**model_conf.unet_config)
        model.diffusion_model = ControlPixArtHalf(model.diffusion_model)
    else:
        raise NotImplementedError(f"Unknown model target '{model_conf.model_target}'")

    m, u = model.diffusion_model.load_state_dict(state_dict, strict=False)
    if len(m) > 0:
        print("Missing UNET keys", m)
    if len(u) > 0:
        print("Leftover UNET keys", u)
    model.diffusion_model.dtype = unet_dtype
    model.diffusion_model.eval()
    model.diffusion_model.to(unet_dtype)

    model_patcher = comfy.model_patcher.ModelPatcher(
        model,
        load_device=load_device,
        offload_device=offload_device,
        current_device="cpu",
    )
    return model_patcher
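
# Example usage (a minimal sketch; the checkpoint path below is an assumption
# and must point at an actual PixArt .pth/.safetensors file):
#
#   patcher = load_pixart("models/checkpoints/PixArt-XL-2-1024-MS.pth")
#   # `patcher` is a regular comfy ModelPatcher; pass it to samplers the same
#   # way as a patcher returned by ComfyUI's built-in checkpoint loaders.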

def guess_pixart_config(sd):
    """
    Guess config based on converted state dict.
    """
    # Shared settings based on DiT_XL_2 - could be enumerated
    config = {
        "num_heads": 16,      # get from attention
        "patch_size": 2,      # final layer I guess?
        "hidden_size": 1152,  # pos_embed.shape[2]
    }
    config["depth"] = sum(key.endswith(".attn.proj.weight") for key in sd.keys()) or 28

    try:
        # this is not present in the diffusers version for sigma?
        config["model_max_length"] = sd["y_embedder.y_embedding"].shape[0]
    except KeyError:
        # need better logic to guess this
        config["model_max_length"] = 300

    if "pos_embed" in sd:
        config["input_size"] = int(math.sqrt(sd["pos_embed"].shape[1])) * config["patch_size"]
        config["pe_interpolation"] = config["input_size"] // (512 // 8)  # dumb guess

    target_arch = "PixArtMS"
    if config["model_max_length"] == 300:
        # Sigma
        target_arch = "PixArtMSSigma"
        config["micro_condition"] = False
        if "input_size" not in config:
            # The diffusers weights for 1K/2K are exactly the same...?
            # replace patch embed logic with HyDiT?
            print("PixArt: diffusers weights - 2K model will be broken, use manual loading!")
            config["input_size"] = 1024 // 8
    else:
        # Alpha
        if "csize_embedder.mlp.0.weight" in sd:
            # MS (micro-conditioning)
            target_arch = "PixArtMS"
            config["micro_condition"] = True
            if "input_size" not in config:
                config["input_size"] = 1024 // 8
                config["pe_interpolation"] = 2
        else:
            # base PixArt
            target_arch = "PixArt"
            if "input_size" not in config:
                config["input_size"] = 512 // 8
                config["pe_interpolation"] = 1

    print("PixArt guessed config:", target_arch, config)
    return {
        "target": target_arch,
        "unet_config": config,
        "sampling_settings": {
            "beta_schedule": "sqrt_linear",
            "linear_start": 0.0001,
            "linear_end": 0.02,
            "timesteps": 1000,
        }
    }
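
# For reference, an original (non-diffusers) PixArt-alpha 1024 MS checkpoint
# should resolve through the logic above to roughly the following (values
# inferred from the defaults in this function, not verified against every
# released checkpoint):
#   target_arch = "PixArtMS"
#   config = {"num_heads": 16, "patch_size": 2, "hidden_size": 1152,
#             "depth": 28, "model_max_length": 120, "micro_condition": True,
#             "input_size": 128, "pe_interpolation": 2}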