import comfy.supported_models_base
import comfy.latent_formats
import comfy.model_patcher
import comfy.model_base
import comfy.utils
import comfy.conds
import torch
import math
from comfy import model_management
from .diffusers_convert import convert_state_dict

class EXM_PixArt(comfy.supported_models_base.BASE):
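	"""Model config for PixArt checkpoints, built on comfy.supported_models_base.BASE."""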
	unet_config = {}
	unet_extra_config = {}
	latent_format = comfy.latent_formats.SD15

	def __init__(self, model_conf):
		self.model_target = model_conf.get("target")
		self.unet_config = model_conf.get("unet_config", {})
		self.sampling_settings = model_conf.get("sampling_settings", {})
		self.latent_format = self.latent_format()
		# UNET is handled by extension
		self.unet_config["disable_unet_model_creation"] = True

	def model_type(self, state_dict, prefix=""):
		return comfy.model_base.ModelType.EPS

class EXM_PixArt_Model(comfy.model_base.BaseModel):
	def __init__(self, *args, **kwargs):
		super().__init__(*args, **kwargs)
	
	def extra_conds(self, **kwargs):
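		"""Pass PixArt-specific conditioning (image size, aspect ratio, ControlNet hint) through to the model."""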
		out = super().extra_conds(**kwargs)

		img_hw = kwargs.get("img_hw", None)
		if img_hw is not None:
			out["img_hw"] = comfy.conds.CONDRegular(torch.tensor(img_hw))
		
		aspect_ratio = kwargs.get("aspect_ratio", None)
		if aspect_ratio is not None:
			out["aspect_ratio"] = comfy.conds.CONDRegular(torch.tensor(aspect_ratio))

		cn_hint = kwargs.get("cn_hint", None)
		if cn_hint is not None:
			out["cn_hint"] = comfy.conds.CONDRegular(cn_hint)

		return out

def load_pixart(model_path, model_conf=None):
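	"""
	Load a PixArt checkpoint from `model_path` and wrap it in a ComfyUI ModelPatcher.
	If `model_conf` is None, the architecture is guessed from the state dict keys
	via guess_pixart_config().
	"""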
	state_dict = comfy.utils.load_torch_file(model_path)
	state_dict = state_dict.get("model", state_dict)

	# strip the training checkpoint prefix, but only from keys that actually carry it
	for prefix in ["model.diffusion_model.",]:
		if any(x.startswith(prefix) for x in state_dict):
			state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state_dict.items()}

	# diffusers checkpoints use the adaln_single naming; convert to the reference layout
	if "adaln_single.linear.weight" in state_dict:
		state_dict = convert_state_dict(state_dict)

	# guess the model config from the state dict when none is provided
	if model_conf is None:
		model_conf = guess_pixart_config(state_dict)

	parameters = comfy.utils.calculate_parameters(state_dict)
	unet_dtype = model_management.unet_dtype(model_params=parameters)
	load_device = comfy.model_management.get_torch_device()
	offload_device = comfy.model_management.unet_offload_device()

	# fp8 and similar dtypes aren't used directly for now; fall back to the manual cast dtype
	manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
	if manual_cast_dtype:
		print(f"PixArt: falling back to {manual_cast_dtype}")
		unet_dtype = manual_cast_dtype

	model_conf = EXM_PixArt(model_conf) # convert to object
	model = EXM_PixArt_Model( # same as comfy.model_base.BaseModel
		model_conf,
		model_type=comfy.model_base.ModelType.EPS,
		device=model_management.get_torch_device()
	)

	if model_conf.model_target == "PixArtMS":
		from .models.PixArtMS import PixArtMS
		model.diffusion_model = PixArtMS(**model_conf.unet_config)
	elif model_conf.model_target == "PixArt":
		from .models.PixArt import PixArt
		model.diffusion_model = PixArt(**model_conf.unet_config)
	elif model_conf.model_target == "PixArtMSSigma":
		from .models.PixArtMS import PixArtMS
		model.diffusion_model = PixArtMS(**model_conf.unet_config)
		model.latent_format = comfy.latent_formats.SDXL()
	elif model_conf.model_target == "ControlPixArtMSHalf":
		from .models.PixArtMS import PixArtMS
		from .models.pixart_controlnet import ControlPixArtMSHalf
		model.diffusion_model = PixArtMS(**model_conf.unet_config)
		model.diffusion_model = ControlPixArtMSHalf(model.diffusion_model)
	elif model_conf.model_target == "ControlPixArtHalf":
		from .models.PixArt import PixArt
		from .models.pixart_controlnet import ControlPixArtHalf
		model.diffusion_model = PixArt(**model_conf.unet_config)
		model.diffusion_model = ControlPixArtHalf(model.diffusion_model)
	else:
		raise NotImplementedError(f"Unknown model target '{model_conf.model_target}'")

	m, u = model.diffusion_model.load_state_dict(state_dict, strict=False)
	if len(m) > 0: print("PixArt: missing UNET keys", m)
	if len(u) > 0: print("PixArt: unexpected UNET keys", u)
	model.diffusion_model.dtype = unet_dtype
	model.diffusion_model.eval()
	model.diffusion_model.to(unet_dtype)

	model_patcher = comfy.model_patcher.ModelPatcher(
		model,
		load_device = load_device,
		offload_device = offload_device,
		current_device = "cpu",
	)
	return model_patcher

def guess_pixart_config(sd):
	"""
	Guess config based on converted state dict.
	"""
	# Shared settings based on DiT_XL_2 - could be enumerated
	config = {
		"num_heads"   :   16, # get from attention
		"patch_size"  :    2, # final layer I guess?
		"hidden_size" : 1152, # pos_embed.shape[2]
	}
	config["depth"] = sum([key.endswith(".attn.proj.weight") for key in sd.keys()]) or 28

	try:
		# this is not present in the diffusers version for sigma?
		config["model_max_length"] = sd["y_embedder.y_embedding"].shape[0]
	except KeyError:
		# need better logic to guess this
		config["model_max_length"] = 300

	if "pos_embed" in sd:
		config["input_size"] = int(math.sqrt(sd["pos_embed"].shape[1])) * config["patch_size"]
		config["pe_interpolation"] = config["input_size"] // (512//8) # dumb guess

	target_arch = "PixArtMS"
	if config["model_max_length"] == 300:
		# Sigma
		target_arch = "PixArtMSSigma"
		config["micro_condition"] = False
		if "input_size" not in config:
			# The diffusers weights for 1K/2K are exactly the same...?
			# replace patch embed logic with HyDiT?
			print(f"PixArt: diffusers weights - 2K model will be broken, use manual loading!")
			config["input_size"] = 1024//8
	else:
		# Alpha
		if "csize_embedder.mlp.0.weight" in sd:
			# MS (microconds)
			target_arch = "PixArtMS"
			config["micro_condition"] = True
			if "input_size" not in config:
				config["input_size"] = 1024//8
				config["pe_interpolation"] = 2
		else:
			# PixArt
			target_arch = "PixArt"
			if "input_size" not in config:
				config["input_size"] = 512//8
				config["pe_interpolation"] = 1

	print("PixArt guessed config:", target_arch, config)
	return {
		"target": target_arch,
		"unet_config": config,
		"sampling_settings": {
			"beta_schedule" : "sqrt_linear",
			"linear_start"  : 0.0001,
			"linear_end"    : 0.02,
			"timesteps"     : 1000,
		}
	}
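
# Example usage (a sketch, not part of the module): the checkpoint path below is
# an assumption for illustration, and the config keys shown mirror what
# guess_pixart_config() produces; they may not match every checkpoint.
#
#   # let the loader guess the architecture from the checkpoint:
#   model_patcher = load_pixart("models/checkpoints/pixart.safetensors")
#
#   # or bypass detection with an explicit config:
#   model_patcher = load_pixart(
#       "models/checkpoints/pixart.safetensors",
#       model_conf={
#           "target": "PixArtMS",
#           "unet_config": {"num_heads": 16, "patch_size": 2, "hidden_size": 1152, "depth": 28, "input_size": 128},
#           "sampling_settings": {"beta_schedule": "sqrt_linear", "linear_start": 0.0001, "linear_end": 0.02, "timesteps": 1000},
#       },
#   )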