import torch
import comfy
import comfy.sd1_clip
from torch.nn.functional import silu
from types import MethodType
from comfy.sd import CLIP
from comfy import ldm
# Import the submodules through the comfy package: ComfyUI bundles ldm as comfy.ldm,
# so a bare `import ldm...` would not resolve. Attribute access via `ldm.modules...`
# below keeps working because the name `ldm` is bound to comfy.ldm above.
import comfy.ldm.modules.diffusionmodules.model
import comfy.ldm.modules.diffusionmodules.openaimodel
import comfy.ldm.modules.attention
from . import devices, shared, sd_hijack_unet, sd_hijack_optimizations, script_callbacks, errors
from .textual_inversion import textual_inversion
from ..smZNodes import FrozenCLIPEmbedderWithCustomWordsCustom, FrozenOpenCLIPEmbedder2WithCustomWordsCustom, get_learned_conditioning
from functools import partial
if not hasattr(ldm.modules.diffusionmodules.model, "nonlinearity_orig"):
ldm.modules.diffusionmodules.model.nonlinearity_orig = ldm.modules.diffusionmodules.model.nonlinearity
if not hasattr(ldm.modules.diffusionmodules.openaimodel, "th_orig"):
ldm.modules.diffusionmodules.openaimodel.th_orig = ldm.modules.diffusionmodules.openaimodel.th
# Guard these saves like the ones above so a module reload never records an already-patched forward.
if not hasattr(ldm.modules.attention.CrossAttention, "forward_orig"):
    ldm.modules.attention.CrossAttention.forward_orig = ldm.modules.attention.CrossAttention.forward
if not hasattr(ldm.modules.diffusionmodules.model.AttnBlock, "forward_orig"):
    ldm.modules.diffusionmodules.model.AttnBlock.forward_orig = ldm.modules.diffusionmodules.model.AttnBlock.forward
optimizers = []
current_optimizer: sd_hijack_optimizations.SdOptimization = None
already_optimized = False # temp fix for displaying info, since two CLIPTextEncode nodes will run
def list_optimizers():
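    """Rebuild the global `optimizers` list from the registered script callbacks,
    keeping only optimizations that report themselves available, sorted by
    priority (highest first)."""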
script_callbacks.on_list_optimizers(sd_hijack_optimizations.list_optimizers)
new_optimizers = script_callbacks.list_optimizers_callback()
new_optimizers = [x for x in new_optimizers if x.is_available()]
new_optimizers = sorted(new_optimizers, key=lambda x: x.priority, reverse=True)
optimizers.clear()
optimizers.extend(new_optimizers)
def apply_optimizations(option=None):
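    """Patch the ldm attention/nonlinearity implementations with the selected
    cross-attention optimization. `option` overrides
    shared.opts.cross_attention_optimization; returns the applied optimizer,
    or '' when nothing is applied."""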
    global already_optimized
    # `display` suppresses the info prints on repeated runs; see the comment next to already_optimized.
    display = not already_optimized
list_optimizers()
global current_optimizer
undo_optimizations()
if len(optimizers) == 0:
# a script can access the model very early, and optimizations would not be filled by then
current_optimizer = None
return ''
ldm.modules.diffusionmodules.model.nonlinearity = silu
ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
# sgm.modules.diffusionmodules.model.nonlinearity = silu
# sgm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
if current_optimizer is not None:
current_optimizer.undo()
current_optimizer = None
selection = option or shared.opts.cross_attention_optimization
if selection == "Automatic" and len(optimizers) > 0:
matching_optimizer = next(iter([x for x in optimizers if x.cmd_opt and getattr(shared.cmd_opts, x.cmd_opt, False)]), optimizers[0])
else:
matching_optimizer = next(iter([x for x in optimizers if x.cmd_opt == selection]), None)
if selection == "None":
matching_optimizer = None
elif selection == "Automatic" and shared.cmd_opts.disable_opt_split_attention:
matching_optimizer = None
elif matching_optimizer is None:
matching_optimizer = optimizers[0]
if matching_optimizer is not None:
        if shared.opts.debug and display:
            print(f"Applying attention optimization: {matching_optimizer.name}... ", end='')
matching_optimizer.apply()
already_optimized = True
        if shared.opts.debug and display:
            print("done.")
current_optimizer = matching_optimizer
return current_optimizer
else:
# if shared.opts.debug:
# print("Disabling attention optimization")
return ''
def undo_optimizations():
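    """Restore the original attention forwards and the original `nonlinearity`
    and `th` module attributes patched by apply_optimizations()."""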
sd_hijack_optimizations.undo()
ldm.modules.diffusionmodules.model.nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity_orig
ldm.modules.diffusionmodules.openaimodel.th = ldm.modules.diffusionmodules.openaimodel.th_orig
class StableDiffusionModelHijack:
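    """Patches a ComfyUI CLIP text encoder the way A1111's webui does: the token
    embedding is wrapped in EmbeddingsWithFixes (textual inversion support) and
    the encoder is wrapped in a FrozenCLIPEmbedder*WithCustomWords class that
    implements custom prompt parsing and weighting."""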
fixes = None
comments = []
layers = None
circular_enabled = False
clip = None
tokenizer = None
optimization_method = None
embedding_db = textual_inversion.EmbeddingDatabase()
def apply_optimizations(self, option=None):
try:
self.optimization_method = apply_optimizations(option)
except Exception as e:
errors.display(e, "applying optimizations")
undo_optimizations()
def hijack(self, m: comfy.sd1_clip.SD1ClipModel):
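        """Wrap the clip model `m`: attach its tokenizer, wrap the token embedding
        with EmbeddingsWithFixes, swap in the custom-words encoder and apply the
        selected attention optimization."""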
tokenizer_parent = m.tokenizer # SD1Tokenizer
# SDTokenizer
tokenizer_parent2 = getattr(tokenizer_parent, tokenizer_parent.clip) if hasattr(tokenizer_parent, 'clip') else tokenizer_parent
tokenizer = getattr(tokenizer_parent, tokenizer_parent.clip).tokenizer if hasattr(tokenizer_parent, 'clip') else tokenizer_parent.tokenizer
if hasattr(m, 'clip'):
m = getattr(m, m.clip)
model_embeddings = m.transformer.text_model.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
model_embeddings.token_embedding.weight = model_embeddings.token_embedding.wrapped._parameters.get('weight').to(device=devices.device)
m.tokenizer_parent0 = tokenizer_parent
m.tokenizer_parent = tokenizer_parent2
m.tokenizer = tokenizer
m = FrozenOpenCLIPEmbedder2WithCustomWordsCustom(m, self) if "SDXLClipG" in type(m).__name__ else FrozenCLIPEmbedderWithCustomWordsCustom(m, self)
m.clip_layer = getattr(m.wrapped, "clip_layer", None)
m.reset_clip_layer = getattr(m.wrapped, "reset_clip_layer", None)
m.transformer = getattr(m.wrapped, "transformer", None)
self.cond_stage_model = m
self.clip = m
apply_weighted_forward(self.clip)
self.apply_optimizations()
def undo_hijack(self, m):
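        """Revert hijack(): restore the original token embedding and undo the
        optimization and weighted-forward patches."""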
try:
m = m.wrapped
model_embeddings = m.transformer.text_model.embeddings
if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
undo_optimizations()
undo_weighted_forward(m)
self.apply_circular(False)
# self.layers = None
self.clip = None
self.cond_stage_model = None
except Exception as err:
print(err)
def apply_circular(self, enable):
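        """Toggle 'circular' padding on every Conv2d layer (used for seamless,
        tileable image generation)."""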
if self.circular_enabled == enable:
return
self.circular_enabled = enable
for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:
layer.padding_mode = 'circular' if enable else 'zeros'
def clear_comments(self):
self.comments = []
def get_prompt_lengths(self, text):
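        """Return (token_count, target_token_count) for `text`, or (0, 0) when no
        clip model has been hijacked yet."""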
if self.clip is None:
return 0, 0
_, token_count = self.clip.process_texts([text])
return token_count, self.clip.get_target_prompt_token_count(token_count)
model_hijack = StableDiffusionModelHijack()
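# Minimal usage sketch (hedged: `clip` stands for a comfy.sd.CLIP instance obtained
# elsewhere, e.g. from a checkpoint loader; it is not defined in this module):
#
#     clip.cond_stage_model.tokenizer = clip.tokenizer   # hijack() reads m.tokenizer
#     model_hijack.hijack(clip.cond_stage_model)
#     token_count, target_len = model_hijack.get_prompt_lengths("a photo of a cat")
#     model_hijack.undo_hijack(model_hijack.clip)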
def weighted_loss(sd_model, pred, target, mean=True):
#Calculate the weight normally, but ignore the mean
loss = sd_model._old_get_loss(pred, target, mean=False) # pylint: disable=protected-access
#Check if we have weights available
weight = getattr(sd_model, '_custom_loss_weight', None)
if weight is not None:
loss *= weight
#Return the loss, as mean if specified
return loss.mean() if mean else loss
def weighted_forward(sd_model, x, c, w, *args, **kwargs):
try:
#Temporarily append weights to a place accessible during loss calc
sd_model._custom_loss_weight = w # pylint: disable=protected-access
#Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely
#Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set
if not hasattr(sd_model, '_old_get_loss'):
sd_model._old_get_loss = sd_model.get_loss # pylint: disable=protected-access
sd_model.get_loss = MethodType(weighted_loss, sd_model)
#Run the standard forward function, but with the patched 'get_loss'
return sd_model.forward(x, c, *args, **kwargs)
finally:
try:
#Delete temporary weights if appended
del sd_model._custom_loss_weight
except AttributeError:
pass
#If we have an old loss function, reset the loss function to the original one
if hasattr(sd_model, '_old_get_loss'):
sd_model.get_loss = sd_model._old_get_loss # pylint: disable=protected-access
del sd_model._old_get_loss
def apply_weighted_forward(sd_model):
#Add new function 'weighted_forward' that can be called to calc weighted loss
sd_model.weighted_forward = MethodType(weighted_forward, sd_model)
def undo_weighted_forward(sd_model):
try:
del sd_model.weighted_forward
except AttributeError:
pass
class EmbeddingsWithFixes(torch.nn.Module):
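    """Wrapper around a token-embedding layer that splices textual inversion
    vectors, recorded in `embeddings.fixes` as (offset, embedding) pairs, into
    the embedded prompt."""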
    def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'):
        super().__init__()
        self.wrapped = wrapped
        self.embeddings = embeddings
        # Store the key so forward() can pick the right tensor out of dict-style
        # embedding vectors (e.g. separate clip_l/clip_g vectors for SDXL).
        self.textual_inversion_key = textual_inversion_key
def forward(self, input_ids):
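        """Embed `input_ids`, then overwrite the positions reserved for textual
        inversion tokens with the corresponding embedding vectors."""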
batch_fixes = self.embeddings.fixes
self.embeddings.fixes = None
        try:
            inputs_embeds = self.wrapped(input_ids)
        except Exception:
            # Retry with CPU ids in case the wrapped embedding's weights still live on the CPU.
            inputs_embeds = self.wrapped(input_ids.cpu())
if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
return inputs_embeds
vecs = []
for fixes, tensor in zip(batch_fixes, inputs_embeds):
for offset, embedding in fixes:
vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec
emb = devices.cond_cast_unet(vec)
if emb.device != tensor.device:
emb = emb.to(device=tensor.device)
emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
try:
tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
                except Exception as err:
                    print("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored",
                          tensor.shape[-1], emb.shape[-1])
# raise err
vecs.append(tensor)
return torch.stack(vecs)