import torch
import comfy
import comfy.sd1_clip
from torch.nn.functional import silu
from types import MethodType
from comfy.sd import CLIP
from comfy import ldm
import ldm.modules.diffusionmodules
import ldm.modules.diffusionmodules.model
import ldm.modules.diffusionmodules.openaimodel
import ldm.modules.attention
from . import devices, shared, sd_hijack_unet, sd_hijack_optimizations, script_callbacks, errors
from .textual_inversion import textual_inversion
from ..smZNodes import FrozenCLIPEmbedderWithCustomWordsCustom, FrozenOpenCLIPEmbedder2WithCustomWordsCustom, get_learned_conditioning
from functools import partial
if not hasattr(ldm.modules.diffusionmodules.model, "nonlinearity_orig"):
    ldm.modules.diffusionmodules.model.nonlinearity_orig = ldm.modules.diffusionmodules.model.nonlinearity
if not hasattr(ldm.modules.diffusionmodules.openaimodel, "th_orig"):
    ldm.modules.diffusionmodules.openaimodel.th_orig = ldm.modules.diffusionmodules.openaimodel.th

if not hasattr(ldm.modules.attention.CrossAttention, "forward_orig"):
    ldm.modules.attention.CrossAttention.forward_orig = ldm.modules.attention.CrossAttention.forward
if not hasattr(ldm.modules.diffusionmodules.model.AttnBlock, "forward_orig"):
    ldm.modules.diffusionmodules.model.AttnBlock.forward_orig = ldm.modules.diffusionmodules.model.AttnBlock.forward

optimizers = []
current_optimizer: sd_hijack_optimizations.SdOptimization = None
already_optimized = False  # temporary fix: avoid printing the optimization info twice when two CLIPTextEncode nodes run

def list_optimizers():
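    """Query the registered optimizer-listing callbacks, keep only optimizations whose
    is_available() returns True, and refill `optimizers` sorted by descending priority."""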
    script_callbacks.on_list_optimizers(sd_hijack_optimizations.list_optimizers)
    new_optimizers = script_callbacks.list_optimizers_callback()

    new_optimizers = [x for x in new_optimizers if x.is_available()]

    new_optimizers = sorted(new_optimizers, key=lambda x: x.priority, reverse=True)

    optimizers.clear()
    optimizers.extend(new_optimizers)


def apply_optimizations(option=None):
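    """Undo any previous optimization, then pick and apply a cross-attention optimization:
    the one named by `option` / shared.opts.cross_attention_optimization, or the
    highest-priority available one when set to "Automatic". Returns the applied optimizer,
    or '' if none was applied."""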
    global already_optimized, current_optimizer
    # print the info message only on the first application; a repeated CLIPTextEncode run would otherwise duplicate it
    display = not already_optimized
    list_optimizers()

    undo_optimizations()

    if len(optimizers) == 0:
        # a script can access the model very early, and optimizations would not be filled by then
        current_optimizer = None
        return ''

    ldm.modules.diffusionmodules.model.nonlinearity = silu
    ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th

    # sgm.modules.diffusionmodules.model.nonlinearity = silu
    # sgm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th

    if current_optimizer is not None:
        current_optimizer.undo()
        current_optimizer = None

    selection = option or shared.opts.cross_attention_optimization
    if selection == "Automatic" and len(optimizers) > 0:
        matching_optimizer = next(iter([x for x in optimizers if x.cmd_opt and getattr(shared.cmd_opts, x.cmd_opt, False)]), optimizers[0])
    else:
        matching_optimizer = next(iter([x for x in optimizers if x.cmd_opt == selection]), None)
    if selection == "None":
        matching_optimizer = None
    elif selection == "Automatic" and shared.cmd_opts.disable_opt_split_attention:
        matching_optimizer = None
    elif matching_optimizer is None:
        matching_optimizer = optimizers[0]

    if matching_optimizer is not None:
        if shared.opts.debug and display:
            print(f"Applying attention optimization: {matching_optimizer.name}... ", end='')
        matching_optimizer.apply()
        already_optimized = True
        if shared.opts.debug and display:
            print("done.")
        current_optimizer = matching_optimizer
        return current_optimizer
    else:
        # if shared.opts.debug:
            # print("Disabling attention optimization")
        return ''

def undo_optimizations():
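    """Revert any applied attention optimization and restore the original nonlinearity and
    th helpers saved at import time."""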
    sd_hijack_optimizations.undo()
    ldm.modules.diffusionmodules.model.nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity_orig
    ldm.modules.diffusionmodules.openaimodel.th = ldm.modules.diffusionmodules.openaimodel.th_orig

class StableDiffusionModelHijack:
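    """Wraps a ComfyUI CLIP text encoder so that prompt conditioning goes through the custom
    FrozenCLIPEmbedder*WithCustomWordsCustom classes, textual-inversion embeddings can be
    injected via EmbeddingsWithFixes, and the selected cross-attention optimization is applied."""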
    fixes = None
    comments = []
    layers = None
    circular_enabled = False
    clip = None
    tokenizer = None
    optimization_method = None
    embedding_db = textual_inversion.EmbeddingDatabase()

    def apply_optimizations(self, option=None):
        try:
            self.optimization_method = apply_optimizations(option)
        except Exception as e:
            errors.display(e, "applying optimizations")
            undo_optimizations()

    def hijack(self, m: comfy.sd1_clip.SD1ClipModel):
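        """Resolve the inner tokenizer and CLIP model (SD1 or SDXL layout), wrap its token
        embedding with EmbeddingsWithFixes, replace the model with the matching custom-words
        embedder, then install the weighted forward helper and apply attention optimizations."""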
        tokenizer_parent = m.tokenizer # SD1Tokenizer
        # SDTokenizer
        tokenizer_parent2 = getattr(tokenizer_parent, tokenizer_parent.clip) if hasattr(tokenizer_parent, 'clip') else tokenizer_parent
        tokenizer = getattr(tokenizer_parent, tokenizer_parent.clip).tokenizer if hasattr(tokenizer_parent, 'clip') else tokenizer_parent.tokenizer
        if hasattr(m, 'clip'):
            m = getattr(m, m.clip)
        model_embeddings = m.transformer.text_model.embeddings
        model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
        model_embeddings.token_embedding.weight = model_embeddings.token_embedding.wrapped._parameters.get('weight').to(device=devices.device)
        m.tokenizer_parent0 = tokenizer_parent
        m.tokenizer_parent = tokenizer_parent2
        m.tokenizer = tokenizer
        m = FrozenOpenCLIPEmbedder2WithCustomWordsCustom(m, self) if "SDXLClipG" in type(m).__name__ else FrozenCLIPEmbedderWithCustomWordsCustom(m, self)
        m.clip_layer = getattr(m.wrapped, "clip_layer", None)
        m.reset_clip_layer = getattr(m.wrapped, "reset_clip_layer", None)
        m.transformer = getattr(m.wrapped, "transformer", None)
        self.cond_stage_model = m
        self.clip = m

        apply_weighted_forward(self.clip)
        self.apply_optimizations()

    def undo_hijack(self, m):
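        """Unwrap the hijacked CLIP model, restore its original token embedding, remove the
        weighted forward helper, and revert optimizations and circular padding."""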
        try:
            m = m.wrapped
            model_embeddings = m.transformer.text_model.embeddings
            if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
                model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
            undo_optimizations()
            undo_weighted_forward(m)
            self.apply_circular(False)
            # self.layers = None
            self.clip = None
            self.cond_stage_model = None
        except Exception as err:
            print("Failed to undo CLIP hijack:", err)

    def apply_circular(self, enable):
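        """Switch the padding mode of every Conv2d layer in self.layers between 'circular'
        and 'zeros'; does nothing if the requested state is already active."""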
        if self.circular_enabled == enable:
            return

        self.circular_enabled = enable

        for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:
            layer.padding_mode = 'circular' if enable else 'zeros'

    def clear_comments(self):
        self.comments = []

    def get_prompt_lengths(self, text):
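        """Return (token_count, target_token_count) for `text`, or (0, 0) when no CLIP model
        has been hijacked yet."""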
        if self.clip is None:
            return 0, 0
        _, token_count = self.clip.process_texts([text])
        return token_count, self.clip.get_target_prompt_token_count(token_count)

model_hijack = StableDiffusionModelHijack()

def weighted_loss(sd_model, pred, target, mean=True):
    #Calculate the loss normally, but ignore the mean
    loss = sd_model._old_get_loss(pred, target, mean=False) # pylint: disable=protected-access

    #Check if we have weights available
    weight = getattr(sd_model, '_custom_loss_weight', None)
    if weight is not None:
        loss *= weight

    #Return the loss, as mean if specified
    return loss.mean() if mean else loss

def weighted_forward(sd_model, x, c, w, *args, **kwargs):
    try:
        #Temporarily append weights to a place accessible during loss calc
        sd_model._custom_loss_weight = w # pylint: disable=protected-access

        #Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely
        #Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set
        if not hasattr(sd_model, '_old_get_loss'):
            sd_model._old_get_loss = sd_model.get_loss # pylint: disable=protected-access
        sd_model.get_loss = MethodType(weighted_loss, sd_model)

        #Run the standard forward function, but with the patched 'get_loss'
        return sd_model.forward(x, c, *args, **kwargs)
    finally:
        try:
            #Delete temporary weights if appended
            del sd_model._custom_loss_weight
        except AttributeError:
            pass

        #If we have an old loss function, reset the loss function to the original one
        if hasattr(sd_model, '_old_get_loss'):
            sd_model.get_loss = sd_model._old_get_loss # pylint: disable=protected-access
            del sd_model._old_get_loss

def apply_weighted_forward(sd_model):
    #Add new function 'weighted_forward' that can be called to calc weighted loss
    sd_model.weighted_forward = MethodType(weighted_forward, sd_model)

def undo_weighted_forward(sd_model):
    try:
        del sd_model.weighted_forward
    except AttributeError:
        pass


class EmbeddingsWithFixes(torch.nn.Module):
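    """Wraps a token embedding layer so that textual-inversion vectors queued on the hijack
    object (`embeddings.fixes`) are spliced into the embedded prompt at their recorded offsets."""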
    def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'):
        super().__init__()
        self.wrapped = wrapped
        self.embeddings = embeddings
        self.textual_inversion_key = textual_inversion_key

    def forward(self, input_ids):
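        """Embed the token ids, then, for each queued (offset, embedding) fix, overwrite the
        corresponding slice of the result with the textual-inversion vectors."""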
        batch_fixes = self.embeddings.fixes
        self.embeddings.fixes = None

        try:
            inputs_embeds = self.wrapped(input_ids)
        except Exception:
            # fall back to CPU token ids if the wrapped embedding layer rejects them on the current device
            inputs_embeds = self.wrapped(input_ids.cpu())

        if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
            return inputs_embeds

        vecs = []
        for fixes, tensor in zip(batch_fixes, inputs_embeds):
            for offset, embedding in fixes:
                vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec
                emb = devices.cond_cast_unet(vec)
                if emb.device != tensor.device:
                    emb = emb.to(device=tensor.device)
                emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
                try:
                    tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
                except Exception as err:
                    print("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored", tensor.shape[0], emb.shape[1])
                    # raise err
            vecs.append(tensor)

        return torch.stack(vecs)