File size: 6,681 Bytes
c19ca42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import os
import time
from typing import Union
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, MultiAdapter, StableDiffusionAdapterPipeline, StableDiffusionXLAdapterPipeline # pylint: disable=unused-import
from modules.shared import log
from modules import errors, sd_models
from modules.control.units import detect


what = 'T2I-Adapter'
# debug() forwards to log.trace only when SD_CONTROL_DEBUG is present in the environment;
# otherwise it is a no-op that accepts and discards any arguments.
if 'SD_CONTROL_DEBUG' in os.environ:
    debug = log.trace
else:
    debug = lambda *args, **kwargs: None # pylint: disable=unnecessary-lambda-assignment
debug('Trace: CONTROL')
# Display name -> Hugging Face repository id for SD 1.x adapters.
predefined_sd15 = {
    'Segment': 'TencentARC/t2iadapter_seg_sd14v1',
    'Zoe Depth': 'TencentARC/t2iadapter_zoedepth_sd15v1',
    'OpenPose': 'TencentARC/t2iadapter_openpose_sd14v1',
    'KeyPose': 'TencentARC/t2iadapter_keypose_sd14v1',
    'Color': 'TencentARC/t2iadapter_color_sd14v1',
    'Depth v1': 'TencentARC/t2iadapter_depth_sd14v1',
    'Depth v2': 'TencentARC/t2iadapter_depth_sd15v2',
    'Canny v1': 'TencentARC/t2iadapter_canny_sd14v1',
    'Canny v2': 'TencentARC/t2iadapter_canny_sd15v2',
    'Sketch v1': 'TencentARC/t2iadapter_sketch_sd14v1',
    'Sketch v2': 'TencentARC/t2iadapter_sketch_sd15v2',
}
# Display name -> Hugging Face repository id for SDXL adapters.
predefined_sdxl = {
    'Canny XL': 'TencentARC/t2i-adapter-canny-sdxl-1.0',
    'LineArt XL': 'TencentARC/t2i-adapter-lineart-sdxl-1.0',
    'Sketch XL': 'TencentARC/t2i-adapter-sketch-sdxl-1.0',
    'Zoe Depth XL': 'TencentARC/t2i-adapter-depth-zoe-sdxl-1.0',
    'OpenPose XL': 'TencentARC/t2i-adapter-openpose-sdxl-1.0',
    'Midas Depth XL': 'TencentARC/t2i-adapter-depth-midas-sdxl-1.0',
}
# Cache of the name list produced by list_models(); starts empty.
models = {}
# Combined name -> repository lookup (SD 1.x entries first, then SDXL).
all_models = {**predefined_sd15, **predefined_sdxl}
# Default download/cache directory passed to from_pretrained.
cache_dir = 'models/control/adapter'


def list_models(refresh=False):
    """Return the adapter names selectable for the currently loaded base model.

    The list is cached in the module-level ``models``; it is rebuilt only when
    the cache is empty or ``refresh`` is true. ``'None'`` is always the first
    entry, followed by the sorted names matching ``modules.shared.sd_model_type``
    (SDXL, SD 1.x, nothing at all, or both families for an unrecognized type).
    """
    import modules.shared
    global models # pylint: disable=global-statement
    if models and not refresh:
        return models
    models = {}
    model_type = modules.shared.sd_model_type
    if model_type == 'none':
        names = []
    elif model_type == 'sdxl':
        names = sorted(predefined_sdxl)
    elif model_type == 'sd':
        names = sorted(predefined_sd15)
    else:
        # unknown base model: warn, then offer every known adapter
        log.warning(f'Control {what} model list failed: unknown model type')
        names = sorted([*predefined_sd15, *predefined_sdxl])
    models = ['None'] + names
    debug(f'Control list {what}: path={cache_dir} models={models}')
    return models


class AdapterModel(T2IAdapter):
    """Subclass of diffusers' ``T2IAdapter`` that adds no behavior; used as the type of ``Adapter.model``."""


class Adapter():
    """Holder that loads and unloads a single T2I-Adapter model by display name.

    ``model_id`` is a key of ``all_models``; the mapped repository is loaded with
    ``T2IAdapter.from_pretrained`` using ``load_config``, which defaults to
    ``{'cache_dir': cache_dir}`` and is extended/overridden by the caller's
    ``load_config``. When ``device``/``dtype`` are given the model is moved/cast.
    """
    def __init__(self, model_id: str = None, device = None, dtype = None, load_config = None):
        self.model: AdapterModel = None
        self.model_id: str = model_id
        self.device = device
        self.dtype = dtype
        # defaults first so caller-supplied entries take precedence
        self.load_config = { 'cache_dir': cache_dir }
        if load_config is not None:
            self.load_config.update(load_config)
        if model_id is not None:
            self.load()

    def reset(self):
        """Drop the loaded model (if any) and clear the stored model id."""
        if self.model is not None:
            debug(f'Control {what} model unloaded')
        self.model = None
        self.model_id = None

    def load(self, model_id: str = None) -> Union[str, None]:
        """Load ``model_id`` (or the stored id when omitted).

        Returns:
            None when the id is missing or ``'None'`` (the adapter is reset instead);
            a success message when loading works; a failure message when the id is
            unknown or loading raises. Load errors are logged and displayed, not raised.
        """
        try:
            t0 = time.time()
            model_id = model_id or self.model_id
            if model_id is None or model_id == 'None':
                self.reset()
                return None
            # .get() so an unknown id reaches the explicit error below; indexing raised
            # KeyError here, which made that branch unreachable
            model_path = all_models.get(model_id)
            if model_path is None:
                log.error(f'Control {what} model load failed: id="{model_id}" error=unknown model id')
                # same message the exception path returns, so callers still receive a string
                return f'{what} failed to load model: {model_id}'
            log.debug(f'Control {what} model loading: id="{model_id}" path="{model_path}"')
            self.model = T2IAdapter.from_pretrained(model_path, **self.load_config)
            if self.device is not None:
                self.model.to(self.device)
            if self.dtype is not None:
                self.model.to(self.dtype)
            t1 = time.time()
            self.model_id = model_id
            log.debug(f'Control {what} loaded: id="{model_id}" path="{model_path}" time={t1-t0:.2f}')
            return f'{what} loaded model: {model_id}'
        except Exception as e:
            log.error(f'Control {what} model load failed: id="{model_id}" error={e}')
            errors.display(e, f'Control {what} load')
            return f'{what} failed to load model: {model_id}'


class AdapterPipeline():
    """Wrap a loaded SD 1.x or SDXL pipeline in the matching diffusers adapter pipeline.

    The wrapped pipeline reuses the original's vae, text encoder(s), tokenizer(s),
    unet and scheduler; for SD 1.x the safety checker is disabled. On any failure
    ``self.pipeline`` stays ``None`` and an error is logged. ``restore()`` returns
    the original pipeline.
    """
    def __init__(self, adapter: Union[T2IAdapter, list[T2IAdapter]], pipeline: Union[StableDiffusionXLPipeline, StableDiffusionPipeline], dtype = None):
        t0 = time.time()
        self.orig_pipeline = pipeline
        self.pipeline: Union[StableDiffusionXLPipeline, StableDiffusionPipeline] = None
        if pipeline is None:
            log.error(f'Control {what} pipeline: model not loaded')
            return
        if isinstance(adapter, list):
            # a one-element list must be unwrapped: a plain list has no .to(), so the
            # previous code raised AttributeError below; only 2+ adapters are combined
            if len(adapter) == 1:
                adapter = adapter[0]
            elif len(adapter) > 1:
                adapter = MultiAdapter(adapter)
            else:
                adapter = None
        if adapter is None:
            log.error(f'Control {what} pipeline: adapter not loaded')
            return
        adapter.to(device=pipeline.device, dtype=pipeline.dtype)
        if detect.is_sdxl(pipeline):
            self.pipeline = StableDiffusionXLAdapterPipeline(
                vae=pipeline.vae,
                text_encoder=pipeline.text_encoder,
                text_encoder_2=pipeline.text_encoder_2,
                tokenizer=pipeline.tokenizer,
                tokenizer_2=pipeline.tokenizer_2,
                unet=pipeline.unet,
                scheduler=pipeline.scheduler,
                feature_extractor=getattr(pipeline, 'feature_extractor', None),
                adapter=adapter,
            )
            sd_models.move_model(self.pipeline, pipeline.device)
        elif detect.is_sd15(pipeline):
            self.pipeline = StableDiffusionAdapterPipeline(
                vae=pipeline.vae,
                text_encoder=pipeline.text_encoder,
                tokenizer=pipeline.tokenizer,
                unet=pipeline.unet,
                scheduler=pipeline.scheduler,
                feature_extractor=getattr(pipeline, 'feature_extractor', None),
                requires_safety_checker=False,
                safety_checker=None,
                adapter=adapter,
            )
            sd_models.move_model(self.pipeline, pipeline.device)
        else:
            log.error(f'Control {what} pipeline: class={pipeline.__class__.__name__} unsupported model type')
            return
        if dtype is not None and self.pipeline is not None:
            # NOTE(review): upstream diffusers exposes DiffusionPipeline.dtype as a read-only
            # property; if so this assignment raises AttributeError — confirm, and consider
            # casting via the pipeline's .to() instead
            self.pipeline.dtype = dtype
        t1 = time.time()
        if self.pipeline is not None:
            log.debug(f'Control {what} pipeline: class={self.pipeline.__class__.__name__} time={t1-t0:.2f}')
        else:
            log.error(f'Control {what} pipeline: not initialized')

    def restore(self):
        """Discard the adapter pipeline and return the original pipeline."""
        self.pipeline = None
        return self.orig_pipeline