Spaces:
Runtime error
Runtime error
File size: 6,681 Bytes
c19ca42 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
import os
import time
from typing import Union
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, MultiAdapter, StableDiffusionAdapterPipeline, StableDiffusionXLAdapterPipeline # pylint: disable=unused-import
from modules.shared import log
from modules import errors, sd_models
from modules.control.units import detect
what = 'T2I-Adapter'  # unit name interpolated into every log/status message below
# Any value of SD_CONTROL_DEBUG in the environment routes debug() to log.trace; otherwise debug() is a no-op.
debug = log.trace if 'SD_CONTROL_DEBUG' in os.environ else lambda *args, **kwargs: None
debug('Trace: CONTROL')

# Display name -> repo id passed to T2IAdapter.from_pretrained, for SD 1.x base models.
predefined_sd15 = {
    'Segment': 'TencentARC/t2iadapter_seg_sd14v1',
    'Zoe Depth': 'TencentARC/t2iadapter_zoedepth_sd15v1',
    'OpenPose': 'TencentARC/t2iadapter_openpose_sd14v1',
    'KeyPose': 'TencentARC/t2iadapter_keypose_sd14v1',
    'Color': 'TencentARC/t2iadapter_color_sd14v1',
    'Depth v1': 'TencentARC/t2iadapter_depth_sd14v1',
    'Depth v2': 'TencentARC/t2iadapter_depth_sd15v2',
    'Canny v1': 'TencentARC/t2iadapter_canny_sd14v1',
    'Canny v2': 'TencentARC/t2iadapter_canny_sd15v2',
    'Sketch v1': 'TencentARC/t2iadapter_sketch_sd14v1',
    'Sketch v2': 'TencentARC/t2iadapter_sketch_sd15v2',
}

# Display name -> repo id passed to T2IAdapter.from_pretrained, for SDXL base models.
predefined_sdxl = {
    'Canny XL': 'TencentARC/t2i-adapter-canny-sdxl-1.0',
    'LineArt XL': 'TencentARC/t2i-adapter-lineart-sdxl-1.0',
    'Sketch XL': 'TencentARC/t2i-adapter-sketch-sdxl-1.0',
    'Zoe Depth XL': 'TencentARC/t2i-adapter-depth-zoe-sdxl-1.0',
    'OpenPose XL': 'TencentARC/t2i-adapter-openpose-sdxl-1.0',
    'Midas Depth XL': 'TencentARC/t2i-adapter-depth-midas-sdxl-1.0',
}

models = {}  # cache filled by list_models(); it stores a list of display names there
all_models = {**predefined_sd15, **predefined_sdxl}  # every known display name -> repo id (SD 1.x first, then SDXL)
cache_dir = 'models/control/adapter'  # default cache_dir handed to from_pretrained by Adapter
def list_models(refresh=False):
    """Return the adapter display names selectable for the current base model.

    The result is cached in the module-level ``models``; pass ``refresh=True`` to
    rebuild it. The list always begins with 'None' and continues with sorted display
    names chosen by ``modules.shared.sd_model_type``:
    'none' -> nothing more, 'sdxl' -> SDXL adapters, 'sd' -> SD 1.x adapters,
    anything else -> both sets (a warning is logged).
    """
    import modules.shared
    global models # pylint: disable=global-statement
    if models and not refresh:
        return models
    models = {}  # drop the stale cache before rebuilding
    model_type = modules.shared.sd_model_type
    if model_type == 'none':
        choices = []
    elif model_type == 'sdxl':
        choices = sorted(predefined_sdxl)
    elif model_type == 'sd':
        choices = sorted(predefined_sd15)
    else:
        log.warning(f'Control {what} model list failed: unknown model type')
        choices = sorted([*predefined_sd15, *predefined_sdxl])
    models = ['None', *choices]
    debug(f'Control list {what}: path={cache_dir} models={models}')
    return models
class AdapterModel(T2IAdapter):
    # Thin subclass of diffusers' T2IAdapter that adds no behavior of its own.
    # NOTE(review): within this file it is only used as the type annotation of
    # Adapter.model; Adapter.load() actually stores a plain T2IAdapter instance.
    pass
class Adapter():
    """Holder for a single T2I-Adapter model, loaded by display name.

    Display names are looked up in ``all_models`` and the resulting repo id is loaded
    with ``T2IAdapter.from_pretrained(**self.load_config)``. ``self.load_config``
    always contains ``cache_dir`` unless the caller's ``load_config`` overrides it.
    """

    def __init__(self, model_id: str = None, device = None, dtype = None, load_config = None):
        """Create the holder and, if ``model_id`` is given, load it immediately.

        Args:
            model_id: display name (a key of ``all_models``), or None to start empty.
            device: when not None, the loaded model is moved there with ``.to(device)``.
            dtype: when not None, the loaded model is cast with ``.to(dtype)``.
            load_config: extra ``from_pretrained`` keyword arguments, merged over
                the default ``{'cache_dir': cache_dir}``.
        """
        # NOTE(review): load() stores a plain T2IAdapter here, not an AdapterModel.
        self.model: AdapterModel = None
        self.model_id: str = model_id
        self.device = device
        self.dtype = dtype
        self.load_config = {'cache_dir': cache_dir}
        if load_config is not None:
            self.load_config.update(load_config)
        if model_id is not None:
            self.load()

    def reset(self):
        """Drop the loaded model reference and forget the current model id."""
        if self.model is not None:
            debug(f'Control {what} model unloaded')
        self.model = None
        self.model_id = None

    def load(self, model_id: str = None) -> Union[str, None]:
        """Load the adapter named ``model_id``, or unload when it is None/'None'.

        Args:
            model_id: display name (a key of ``all_models``). Falls back to
                ``self.model_id`` when falsy. None or the string 'None' unloads
                via ``reset()``.

        Returns:
            A status message on success or failure; None when the call only unloaded.
            Exceptions are not propagated: they are logged and shown via ``errors.display``.
        """
        try:
            t0 = time.time()
            model_id = model_id or self.model_id
            if model_id is None or model_id == 'None':
                self.reset()
                return None
            # .get() so an unknown id reaches the explicit error below instead of raising
            # KeyError (which previously made this branch unreachable).
            model_path = all_models.get(model_id)
            if model_path is None:
                log.error(f'Control {what} model load failed: id="{model_id}" error=unknown model id')
                # Same failure message callers previously got via the KeyError path.
                return f'{what} failed to load model: {model_id}'
            log.debug(f'Control {what} model loading: id="{model_id}" path="{model_path}"')
            self.model = T2IAdapter.from_pretrained(model_path, **self.load_config)
            if self.device is not None:
                self.model.to(self.device)
            if self.dtype is not None:
                self.model.to(self.dtype)
            t1 = time.time()
            self.model_id = model_id
            log.debug(f'Control {what} loaded: id="{model_id}" path="{model_path}" time={t1-t0:.2f}')
            return f'{what} loaded model: {model_id}'
        except Exception as e:
            log.error(f'Control {what} model load failed: id="{model_id}" error={e}')
            errors.display(e, f'Control {what} load')
            return f'{what} failed to load model: {model_id}'
class AdapterPipeline():
    """Wrap a loaded SD 1.5 / SDXL pipeline in the matching diffusers T2I-Adapter pipeline.

    The adapter pipeline is constructed from the base pipeline's own components
    (vae, text encoder(s), tokenizer(s), unet, scheduler, optional feature_extractor),
    so both pipelines share those module objects. The base pipeline is kept in
    ``orig_pipeline`` and is returned by ``restore()``. On failure an error is
    logged and ``self.pipeline`` stays None.
    """

    def __init__(self, adapter: Union[T2IAdapter, list[T2IAdapter]], pipeline: Union[StableDiffusionXLPipeline, StableDiffusionPipeline], dtype = None):
        """Build the adapter pipeline.

        Args:
            adapter: one adapter, or a list of adapters. Two or more are combined
                with MultiAdapter, and a one-element list is unwrapped.
            pipeline: the loaded base pipeline. ``detect.is_sdxl`` / ``detect.is_sd15``
                choose which adapter pipeline class is built.
            dtype: optional dtype to cast the resulting pipeline to.
        """
        t0 = time.time()
        self.orig_pipeline = pipeline
        self.pipeline: Union[StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline] = None
        if pipeline is None:
            log.error(f'Control {what} pipeline: model not loaded')
            return
        if isinstance(adapter, list):
            # MultiAdapter refuses fewer than two adapters, and a bare list has no .to(),
            # so a single-element list must be unwrapped rather than passed through.
            if len(adapter) > 1:
                adapter = MultiAdapter(adapter)
            elif len(adapter) == 1:
                adapter = adapter[0]
            else:
                adapter = None
        if adapter is None:
            log.error(f'Control {what} pipeline: no adapter provided')
            return
        adapter.to(device=pipeline.device, dtype=pipeline.dtype)
        if detect.is_sdxl(pipeline):
            self.pipeline = StableDiffusionXLAdapterPipeline(
                vae=pipeline.vae,
                text_encoder=pipeline.text_encoder,
                text_encoder_2=pipeline.text_encoder_2,
                tokenizer=pipeline.tokenizer,
                tokenizer_2=pipeline.tokenizer_2,
                unet=pipeline.unet,
                scheduler=pipeline.scheduler,
                feature_extractor=getattr(pipeline, 'feature_extractor', None),
                adapter=adapter,
            )
            sd_models.move_model(self.pipeline, pipeline.device)
        elif detect.is_sd15(pipeline):
            self.pipeline = StableDiffusionAdapterPipeline(
                vae=pipeline.vae,
                text_encoder=pipeline.text_encoder,
                tokenizer=pipeline.tokenizer,
                unet=pipeline.unet,
                scheduler=pipeline.scheduler,
                feature_extractor=getattr(pipeline, 'feature_extractor', None),
                requires_safety_checker=False,
                safety_checker=None,
                adapter=adapter,
            )
            sd_models.move_model(self.pipeline, pipeline.device)
        else:
            log.error(f'Control {what} pipeline: class={pipeline.__class__.__name__} unsupported model type')
            return
        if dtype is not None and self.pipeline is not None:
            # DiffusionPipeline.dtype is a read-only property, so assigning to it raises
            # AttributeError. Cast through .to() instead.
            # NOTE: the components are shared with orig_pipeline, so they are cast too.
            self.pipeline = self.pipeline.to(dtype)
        t1 = time.time()
        if self.pipeline is not None:
            log.debug(f'Control {what} pipeline: class={self.pipeline.__class__.__name__} time={t1-t0:.2f}')
        else:
            log.error(f'Control {what} pipeline: not initialized')

    def restore(self):
        """Drop the adapter pipeline and return the original base pipeline."""
        self.pipeline = None
        return self.orig_pipeline
|