test / modules /control /units /t2iadapter.py
bilegentile's picture
Upload folder using huggingface_hub
c19ca42 verified
import os
import time
from typing import Union
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, MultiAdapter, StableDiffusionAdapterPipeline, StableDiffusionXLAdapterPipeline # pylint: disable=unused-import
from modules.shared import log
from modules import errors, sd_models
from modules.control.units import detect
# Label interpolated into every log and status message emitted by this module.
what = 'T2I-Adapter'

# debug() forwards to log.trace only when SD_CONTROL_DEBUG is present in the
# environment (any value); otherwise it is a no-op accepting any arguments.
debug = (lambda *args, **kwargs: None) if os.environ.get('SD_CONTROL_DEBUG', None) is None else log.trace
debug('Trace: CONTROL')

# Directory used as the default 'cache_dir' entry of Adapter.load_config.
cache_dir = 'models/control/adapter'

# Adapter name -> model path handed to T2IAdapter.from_pretrained (SD 1.x set).
predefined_sd15 = {
    'Segment': 'TencentARC/t2iadapter_seg_sd14v1',
    'Zoe Depth': 'TencentARC/t2iadapter_zoedepth_sd15v1',
    'OpenPose': 'TencentARC/t2iadapter_openpose_sd14v1',
    'KeyPose': 'TencentARC/t2iadapter_keypose_sd14v1',
    'Color': 'TencentARC/t2iadapter_color_sd14v1',
    'Depth v1': 'TencentARC/t2iadapter_depth_sd14v1',
    'Depth v2': 'TencentARC/t2iadapter_depth_sd15v2',
    'Canny v1': 'TencentARC/t2iadapter_canny_sd14v1',
    'Canny v2': 'TencentARC/t2iadapter_canny_sd15v2',
    'Sketch v1': 'TencentARC/t2iadapter_sketch_sd14v1',
    'Sketch v2': 'TencentARC/t2iadapter_sketch_sd15v2',
}

# Adapter name -> model path handed to T2IAdapter.from_pretrained (SDXL set).
predefined_sdxl = {
    'Canny XL': 'TencentARC/t2i-adapter-canny-sdxl-1.0',
    'LineArt XL': 'TencentARC/t2i-adapter-lineart-sdxl-1.0',
    'Sketch XL': 'TencentARC/t2i-adapter-sketch-sdxl-1.0',
    'Zoe Depth XL': 'TencentARC/t2i-adapter-depth-zoe-sdxl-1.0',
    'OpenPose XL': 'TencentARC/t2i-adapter-openpose-sdxl-1.0',
    'Midas Depth XL': 'TencentARC/t2i-adapter-depth-midas-sdxl-1.0',
}

# Every predefined adapter in one lookup table (SD 1.x entries first, then SDXL);
# Adapter.load resolves model ids against this mapping.
all_models = {**predefined_sd15, **predefined_sdxl}

# Cache of the last result of list_models(); empty until that function runs.
models = {}
def list_models(refresh=False):
    """Return the adapter names selectable for the current base model type.

    The result is cached in the module-level ``models``. A non-empty cached
    value is returned unchanged unless ``refresh`` is true. Every rebuilt list
    starts with ``'None'`` followed by the sorted predefined names that match
    ``modules.shared.sd_model_type`` (``'none'``, ``'sdxl'`` or ``'sd'``).
    Any other model type logs a warning and yields the sorted union of the
    SD 1.x and SDXL names.
    """
    import modules.shared
    global models # pylint: disable=global-statement
    if models and not refresh:
        return models
    # Drop the cached value before rebuilding it.
    models = {}
    model_type = modules.shared.sd_model_type
    # Model type -> collection of adapter names offered for it.
    names_by_type = {
        'none': (),
        'sdxl': predefined_sdxl,
        'sd': predefined_sd15,
    }
    if model_type in names_by_type:
        names = sorted(names_by_type[model_type])
    else:
        log.warning(f'Control {what} model list failed: unknown model type')
        names = sorted([*predefined_sd15, *predefined_sdxl])
    models = ['None'] + names
    debug(f'Control list {what}: path={cache_dir} models={models}')
    return models
class AdapterModel(T2IAdapter):
    """T2IAdapter subclass that adds no members; used as the annotation type of Adapter.model."""
class Adapter():
    """Holder for one T2I-Adapter model plus the settings used to load it.

    Attributes:
        model: the loaded adapter, or None when nothing is loaded.
        model_id: key into ``all_models`` of the loaded/requested model, or None.
        device: when not None, the model is moved to it after loading.
        dtype: when not None, the model is converted to it after loading.
        load_config: keyword arguments forwarded to ``T2IAdapter.from_pretrained``.
    """

    def __init__(self, model_id: str = None, device = None, dtype = None, load_config = None):
        """Store the settings and, when ``model_id`` is given, load the model immediately.

        Args:
            model_id: name of a predefined adapter (key of ``all_models``).
            device: optional target device for the loaded model.
            dtype: optional target dtype for the loaded model.
            load_config: optional dict merged over the default ``{'cache_dir': cache_dir}``.
        """
        self.model: AdapterModel = None
        self.model_id: str = model_id
        self.device = device
        self.dtype = dtype
        self.load_config = { 'cache_dir': cache_dir }
        if load_config is not None:
            self.load_config.update(load_config)
        if model_id is not None:
            self.load()

    def reset(self):
        """Drop the loaded model and forget its id."""
        if self.model is not None:
            debug(f'Control {what} model unloaded')
        self.model = None
        self.model_id = None

    def load(self, model_id: str = None) -> Union[str, None]:
        """Load the adapter named ``model_id`` (defaults to ``self.model_id``).

        Args:
            model_id: key of ``all_models``; ``None`` or ``'None'`` unloads the
                current model instead of loading one.

        Returns:
            A status string describing success or failure, or ``None`` when the
            call only reset the adapter. Exceptions raised while loading are
            logged and reported through the returned string, not propagated.
        """
        try:
            t0 = time.time()
            model_id = model_id or self.model_id
            if model_id is None or model_id == 'None':
                self.reset()
                return None
            # .get() rather than indexing: an unknown id must reach the explicit
            # error below instead of surfacing as a KeyError with a traceback.
            model_path = all_models.get(model_id)
            if model_path is None:
                log.error(f'Control {what} model load failed: id="{model_id}" error=unknown model id')
                # Same status string callers received before for an unknown id.
                return f'{what} failed to load model: {model_id}'
            log.debug(f'Control {what} model loading: id="{model_id}" path="{model_path}"')
            self.model = T2IAdapter.from_pretrained(model_path, **self.load_config)
            if self.device is not None:
                self.model.to(self.device)
            if self.dtype is not None:
                self.model.to(self.dtype)
            t1 = time.time()
            # Only record the id once the model has actually been loaded.
            self.model_id = model_id
            log.debug(f'Control {what} loaded: id="{model_id}" path="{model_path}" time={t1-t0:.2f}')
            return f'{what} loaded model: {model_id}'
        except Exception as e:
            log.error(f'Control {what} model load failed: id="{model_id}" error={e}')
            errors.display(e, f'Control {what} load')
            return f'{what} failed to load model: {model_id}'
class AdapterPipeline():
    """Wrap an existing Stable Diffusion pipeline in its T2I-Adapter variant.

    The adapter pipeline is built from the components of the given pipeline
    (vae, text encoder(s), tokenizer(s), unet, scheduler, optional feature
    extractor) plus the adapter. The original pipeline is kept so that
    ``restore()`` can hand it back.

    Attributes:
        orig_pipeline: the pipeline passed to the constructor.
        pipeline: the adapter pipeline, or None when construction was skipped
            (no pipeline given, or unsupported model type).
    """

    def __init__(self, adapter: Union[T2IAdapter, list[T2IAdapter]], pipeline: Union[StableDiffusionXLPipeline, StableDiffusionPipeline], dtype = None):
        """Build the adapter pipeline.

        Args:
            adapter: a single adapter or a list of adapters. A list with more
                than one entry is combined into a ``MultiAdapter``; a list with
                exactly one entry is unwrapped to that adapter.
            pipeline: base pipeline whose components are reused; when None an
                error is logged and ``self.pipeline`` stays None.
            dtype: when not None, assigned to the new pipeline's ``dtype``.
        """
        t0 = time.time()
        self.orig_pipeline = pipeline
        self.pipeline: Union[StableDiffusionXLPipeline, StableDiffusionPipeline] = None
        if pipeline is None:
            log.error(f'Control {what} pipeline: model not loaded')
            return
        if isinstance(adapter, list):
            if len(adapter) == 1:
                # A one-element list is not wrapped in MultiAdapter; unwrap it so
                # the .to() call below runs on the adapter instead of a plain list.
                adapter = adapter[0]
            elif len(adapter) > 1:
                adapter = MultiAdapter(adapter)
        adapter.to(device=pipeline.device, dtype=pipeline.dtype)
        if detect.is_sdxl(pipeline):
            self.pipeline = StableDiffusionXLAdapterPipeline(
                vae=pipeline.vae,
                text_encoder=pipeline.text_encoder,
                text_encoder_2=pipeline.text_encoder_2,
                tokenizer=pipeline.tokenizer,
                tokenizer_2=pipeline.tokenizer_2,
                unet=pipeline.unet,
                scheduler=pipeline.scheduler,
                feature_extractor=getattr(pipeline, 'feature_extractor', None),
                adapter=adapter,
            )
        elif detect.is_sd15(pipeline):
            self.pipeline = StableDiffusionAdapterPipeline(
                vae=pipeline.vae,
                text_encoder=pipeline.text_encoder,
                tokenizer=pipeline.tokenizer,
                unet=pipeline.unet,
                scheduler=pipeline.scheduler,
                feature_extractor=getattr(pipeline, 'feature_extractor', None),
                requires_safety_checker=False,
                safety_checker=None,
                adapter=adapter,
            )
        else:
            log.error(f'Control {what} pipeline: class={pipeline.__class__.__name__} unsupported model type')
            return
        # Both supported branches end by moving the new pipeline to the base pipeline's device.
        sd_models.move_model(self.pipeline, pipeline.device)
        if dtype is not None:
            # NOTE(review): diffusers appears to expose DiffusionPipeline.dtype as a
            # read-only property; if so this assignment raises AttributeError whenever
            # a dtype is passed - confirm against the installed diffusers version.
            self.pipeline.dtype = dtype
        t1 = time.time()
        log.debug(f'Control {what} pipeline: class={self.pipeline.__class__.__name__} time={t1-t0:.2f}')

    def restore(self):
        """Discard the adapter pipeline and return the original pipeline."""
        self.pipeline = None
        return self.orig_pipeline