|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import importlib
import os
import re
import warnings
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
|
|
import torch |
|
from huggingface_hub import ( |
|
model_info, |
|
) |
|
from packaging import version |
|
|
|
from ..utils import ( |
|
SAFETENSORS_WEIGHTS_NAME, |
|
WEIGHTS_NAME, |
|
get_class_from_dynamic_module, |
|
is_peft_available, |
|
is_transformers_available, |
|
logging, |
|
) |
|
from ..utils.torch_utils import is_compiled_module |
|
|
|
|
|
if is_transformers_available(): |
|
import transformers |
|
from transformers import PreTrainedModel |
|
from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME |
|
from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME |
|
from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME |
|
from huggingface_hub.utils import validate_hf_hub_args |
|
|
|
from ..utils import FLAX_WEIGHTS_NAME, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME |
|
|
|
|
|
# NOTE(review): not referenced in this chunk. Despite its name, the value is the plain (non-sharded)
# diffusers PyTorch weights filename rather than an ".index.json" file — confirm intended use with callers.
INDEX_FILE = "diffusion_pytorch_model.bin"
# Module file name `_get_pipeline_class` loads a custom pipeline from when `custom_pipeline` is neither
# a path ending in ".py" nor combined with a `repo_id`.
CUSTOM_PIPELINE_FILE_NAME = "pipeline.py"
# Module-path prefixes `load_sub_model` checks to recognise dummy placeholder classes (placeholders
# presumably exported when an optional dependency is missing).
DUMMY_MODULES_FOLDER = "diffusers.utils"
TRANSFORMERS_DUMMY_MODULES_FOLDER = "transformers.utils"
# NOTE(review): not referenced in this chunk; presumably the config keys under which connected
# pipelines (e.g. a "prior") are stored — confirm against callers.
CONNECTED_PIPES_KEYS = ["prior"]

# Module-level logger shared by all helpers in this file.
logger = logging.get_logger(__name__)
|
|
|
# Registry of base classes whose subclasses can be saved/loaded as pipeline components, grouped by the
# library that defines them. Each value is [save method name, load method name]; `load_sub_model`
# reads index 1 to pick the loader of a component class.
LOADABLE_CLASSES = {
    "diffusers": {
        "ModelMixin": ["save_pretrained", "from_pretrained"],
        "SchedulerMixin": ["save_pretrained", "from_pretrained"],
        "DiffusionPipeline": ["save_pretrained", "from_pretrained"],
        "OnnxRuntimeModel": ["save_pretrained", "from_pretrained"],
    },
    "transformers": {
        "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
        "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"],
        "PreTrainedModel": ["save_pretrained", "from_pretrained"],
        "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"],
        "ProcessorMixin": ["save_pretrained", "from_pretrained"],
        "ImageProcessingMixin": ["save_pretrained", "from_pretrained"],
    },
    "onnxruntime.training": {
        "ORTModule": ["save_pretrained", "from_pretrained"],
    },
}

# Flattened view {base class name: [save, load]} across all libraries. With dict.update, a later
# library would override an earlier one on a duplicate class name (there are none today).
# NOTE: the loop variable `library` stays bound at module level after this loop.
ALL_IMPORTABLE_CLASSES = {}
for library in LOADABLE_CLASSES:
    ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
|
|
|
|
|
def is_safetensors_compatible(filenames, variant=None, passed_components=None) -> bool:
    """
    Return True when every PyTorch ``.bin`` weight file has a matching ``.safetensors`` file.

    The ``.bin`` files act as the reference list of weights a repository needs; each one is mapped to
    the safetensors name it should have:

    - diffusers weights keep their stem and only swap ".bin" for ".safetensors";
    - transformers weights whose stem starts with "pytorch_model" use "model" instead, e.g.
      "pytorch_model-00001-of-00002.bin" -> "model-00001-of-00002.safetensors".

    Args:
        filenames: repository file paths ("/"-separated, "component/file").
        variant: accepted for API compatibility; not used by this check.
        passed_components: component folder names whose files are ignored (only paths exactly one
            folder deep are skipped).

    Returns:
        False (after logging a warning naming the first missing file) if a ``.safetensors`` counterpart
        is missing, else True.
    """
    skipped_components = passed_components or []

    bin_files = []
    safetensors_files = set()
    for candidate in filenames:
        parts = candidate.split("/")
        if len(parts) == 2 and parts[0] in skipped_components:
            continue

        extension = os.path.splitext(candidate)[1]
        if extension == ".bin":
            bin_files.append(os.path.normpath(candidate))
        elif extension == ".safetensors":
            safetensors_files.add(os.path.normpath(candidate))

    for bin_file in bin_files:
        folder, base = os.path.split(bin_file)
        stem = os.path.splitext(base)[0]
        if stem.startswith("pytorch_model"):
            # transformers names its safetensors weights "model*" rather than "pytorch_model*"
            stem = stem.replace("pytorch_model", "model")

        expected = os.path.normpath(os.path.join(folder, stem)) + ".safetensors"
        if expected not in safetensors_files:
            logger.warning(f"{expected} not found")
            return False

    return True
|
|
|
|
|
def variant_compatible_siblings(filenames, variant=None) -> Tuple[Set[str], Set[str]]:
    """
    Select the weight / weight-index files to use for an optional model ``variant``.

    Only files whose basename (the part after the last "/") matches a known weight file name are
    considered: diffusers / ONNX / Flax weight names, plus the transformers ones when transformers is
    installed, optionally sharded ("-00001-of-00002") and their ".index.json" files.

    Args:
        filenames: "/"-separated repository file paths.
        variant: variant tag inserted into the weight names (e.g. "fp16" in
            "diffusion_pytorch_model.fp16.bin"), or None.

    Returns:
        ``(usable_filenames, variant_filenames)`` where ``variant_filenames`` are the files carrying the
        variant tag, and ``usable_filenames`` holds all of those plus every non-variant weight/index file
        that has no variant counterpart in ``filenames``.
    """
    filenames = list(filenames)  # iterated several times below

    weight_names = [
        WEIGHTS_NAME,
        SAFETENSORS_WEIGHTS_NAME,
        FLAX_WEIGHTS_NAME,
        ONNX_WEIGHTS_NAME,
        ONNX_EXTERNAL_WEIGHTS_NAME,
    ]
    if is_transformers_available():
        weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]

    # e.g. "diffusion_pytorch_model.bin" -> prefix "diffusion_pytorch_model", suffix "bin"
    weight_prefixes = [w.split(".")[0] for w in weight_names]
    weight_suffixs = [w.split(".")[-1] for w in weight_names]
    prefixes_re = "|".join(re.escape(p) for p in weight_prefixes)
    suffixes_re = "|".join(re.escape(s) for s in weight_suffixs)

    transformers_index_format = r"\d{5}-of-\d{5}"
    sharded_re = re.compile(rf"^(.*?){transformers_index_format}")

    # `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors`
    non_variant_file_re = re.compile(rf"({prefixes_re})(-{transformers_index_format})?\.({suffixes_re})$")
    non_variant_index_re = re.compile(rf"({prefixes_re})\.({suffixes_re})\.index\.json$")

    def _basename(path):
        return path.split("/")[-1]

    if variant is not None:
        escaped_variant = re.escape(variant)
        # `diffusion_pytorch_model.fp16.bin` as well as `model.fp16-00001-of-00002.safetensors`
        variant_file_re = re.compile(
            rf"({prefixes_re})\.({escaped_variant}|{escaped_variant}-{transformers_index_format})\.({suffixes_re})$"
        )
        variant_index_re = re.compile(rf"({prefixes_re})\.({suffixes_re})\.index\.{escaped_variant}\.json$")
        variant_filenames = {
            f
            for f in filenames
            if variant_file_re.match(_basename(f)) is not None or variant_index_re.match(_basename(f)) is not None
        }
    else:
        variant_filenames = set()

    non_variant_filenames = {
        f
        for f in filenames
        if non_variant_file_re.match(_basename(f)) is not None
        or non_variant_index_re.match(_basename(f)) is not None
    }

    # all variant filenames will be used by default
    usable_filenames = set(variant_filenames)

    def convert_to_variant(filename):
        # Rewrite only the basename: component folder names may themselves contain ".", "-" or "index".
        folder, sep, name = filename.rpartition("/")
        if "index" in name:
            variant_name = name.replace("index", f"index.{variant}")
        elif sharded_re.match(name) is not None:
            stem, shard = name.split("-", 1)
            variant_name = f"{stem}.{variant}-{shard}"
        else:
            stem, suffix = name.split(".")[:2]
            variant_name = f"{stem}.{variant}.{suffix}"
        return f"{folder}{sep}{variant_name}"

    for f in non_variant_filenames:
        # Fall back to the non-variant file only when its variant counterpart is not available.
        if variant is None or convert_to_variant(f) not in usable_filenames:
            usable_filenames.add(f)

    return usable_filenames, variant_filenames
|
|
|
|
|
@validate_hf_hub_args
def warn_deprecated_model_variant(pretrained_model_name_or_path, token, variant, revision, model_filenames):
    """
    Emit a FutureWarning because a model variant is being loaded through ``revision=``.

    The default branch of the repository (``revision=None`` in ``model_info``) is inspected for
    ``revision``-variant weight files. Their names, with the ``.{revision}`` tag removed, are compared
    with ``model_filenames``, the weight files found on the ``revision`` branch. If every one has a
    counterpart, the warning suggests ``variant=``. Otherwise it asks the user to report the missing files.

    Args:
        pretrained_model_name_or_path: Hub repository id.
        token: authentication token forwarded to ``model_info``.
        variant: unused; kept for API compatibility.
        revision: the revision name that is being used as a variant tag.
        model_filenames: iterable of weight file names loaded from ``revision``.
    """
    info = model_info(
        pretrained_model_name_or_path,
        token=token,
        revision=None,
    )
    filenames = {sibling.rfilename for sibling in info.siblings}
    comp_model_filenames, comp_variant_filenames = variant_compatible_siblings(filenames, variant=revision)

    variant_marker = f".{revision}"

    def _strip_variant(filename):
        # The variant tag lives in the basename ("x.fp16.bin", "x.fp16-00001-of-00002.safetensors",
        # "x.bin.index.fp16.json"); leave folder names untouched.
        folder, sep, name = filename.rpartition("/")
        return f"{folder}{sep}{name.replace(variant_marker, '', 1)}"

    # Non-variant fallbacks are kept as-is; only variant files get their tag removed.
    comp_model_filenames = {
        _strip_variant(f) if f in comp_variant_filenames else f for f in comp_model_filenames
    }

    if set(model_filenames).issubset(comp_model_filenames):
        warnings.warn(
            f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` even though you can load it via `variant='{revision}'`. Loading model variants via `revision='{revision}'` is deprecated and will be removed in diffusers v1. Please use `variant='{revision}'` instead.",
            FutureWarning,
        )
    else:
        warnings.warn(
            f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have the required variant filenames in the 'main' branch. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {revision} files' so that the correct variant file can be added.",
            FutureWarning,
        )
|
|
|
|
|
def _unwrap_model(model):
    """
    Return the module wrapped by ``model``.

    A compiled-module wrapper (as detected by ``is_compiled_module``, presumably from ``torch.compile``)
    is replaced by its ``_orig_mod``. If PEFT is installed and the result is a ``PeftModel``, its
    ``base_model.model`` is returned instead.
    """
    inner = model._orig_mod if is_compiled_module(model) else model

    if not is_peft_available():
        return inner

    from peft import PeftModel

    return inner.base_model.model if isinstance(inner, PeftModel) else inner
|
|
|
|
|
def maybe_raise_or_warn(
    library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module
):
    """
    Check that a user-passed component has the type the pipeline config expects.

    For components from an importable library, the config class ``library_name.class_name`` is matched
    against the loadable base classes in ``importable_classes`` (the last matching one wins). The passed
    object, after unwrapping compiled/PEFT wrappers, must be a subclass of that base class.

    Args:
        library_name: module to import ``class_name`` from.
        library: unused on input; overwritten with the imported module.
        class_name: class name recorded in the pipeline config.
        importable_classes: mapping of loadable base class names (only the keys are used).
        passed_class_obj: mapping of component name -> object passed by the user.
        name: component name to check.
        is_pipeline_module: True when the class lives in a pipeline module; no check is possible then.

    Raises:
        ValueError: if the passed object is not a subclass of the expected loadable base class.
    """
    if is_pipeline_module:
        logger.warning(
            f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
            " has the correct type"
        )
        return

    library = importlib.import_module(library_name)
    class_obj = getattr(library, class_name)
    class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}

    expected_class_obj = None
    for class_candidate in class_candidates.values():
        if class_candidate is not None and issubclass(class_obj, class_candidate):
            expected_class_obj = class_candidate

    if expected_class_obj is None:
        # The config class derives from no loadable base class, so there is nothing to compare against
        # (issubclass(..., None) would raise a TypeError).
        logger.warning(
            f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
            " has the correct type"
        )
        return

    # Dynamo / PEFT wrap the original model; compare against the wrapped class.
    model_cls = _unwrap_model(passed_class_obj[name]).__class__

    if not issubclass(model_cls, expected_class_obj):
        raise ValueError(f"{passed_class_obj[name]} is of type: {model_cls}, but should be" f" {expected_class_obj}")
|
|
|
|
|
def get_class_obj_and_candidates(
    library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None
):
    """
    Retrieve the class object of a component and its candidate loadable base classes.

    Resolution order:
    1. ``is_pipeline_module``: ``getattr(getattr(pipelines, library_name), class_name)``.
    2. ``cache_dir/component_name/<library_name>.py`` exists: the class is loaded from that file with
       ``get_class_from_dynamic_module`` (only when both ``cache_dir`` and ``component_name`` are given).
    3. Otherwise ``library_name`` is imported and ``class_name`` read from it.

    Returns:
        ``(class_obj, class_candidates)`` where ``class_candidates`` maps every key of
        ``importable_classes`` to a class. For cases 1-2 every candidate is ``class_obj`` itself; for
        case 3 it is the attribute of that name in the library, or None if absent.
    """
    if is_pipeline_module:
        pipeline_module = getattr(pipelines, library_name)
        class_obj = getattr(pipeline_module, class_name)
        class_candidates = {c: class_obj for c in importable_classes.keys()}
        return class_obj, class_candidates

    # Both defaults are None; only look for a custom module file when a folder can actually be built.
    component_folder = None
    if cache_dir is not None and component_name is not None:
        component_folder = os.path.join(cache_dir, component_name)

    module_file = library_name + ".py"
    if component_folder is not None and os.path.isfile(os.path.join(component_folder, module_file)):
        # custom component code shipped inside the component folder
        class_obj = get_class_from_dynamic_module(component_folder, module_file=module_file, class_name=class_name)
        class_candidates = {c: class_obj for c in importable_classes.keys()}
    else:
        library = importlib.import_module(library_name)
        class_obj = getattr(library, class_name)
        class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}

    return class_obj, class_candidates
|
|
|
|
|
def _get_pipeline_class( |
|
class_obj, |
|
config=None, |
|
load_connected_pipeline=False, |
|
custom_pipeline=None, |
|
repo_id=None, |
|
hub_revision=None, |
|
class_name=None, |
|
cache_dir=None, |
|
revision=None, |
|
): |
|
if custom_pipeline is not None: |
|
if custom_pipeline.endswith(".py"): |
|
path = Path(custom_pipeline) |
|
|
|
file_name = path.name |
|
custom_pipeline = path.parent.absolute() |
|
elif repo_id is not None: |
|
file_name = f"{custom_pipeline}.py" |
|
custom_pipeline = repo_id |
|
else: |
|
file_name = CUSTOM_PIPELINE_FILE_NAME |
|
|
|
if repo_id is not None and hub_revision is not None: |
|
|
|
|
|
revision = hub_revision |
|
|
|
return get_class_from_dynamic_module( |
|
custom_pipeline, |
|
module_file=file_name, |
|
class_name=class_name, |
|
cache_dir=cache_dir, |
|
revision=revision, |
|
) |
|
|
|
if class_obj.__name__ != "DiffusionPipeline": |
|
return class_obj |
|
|
|
diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0]) |
|
class_name = class_name or config["_class_name"] |
|
if not class_name: |
|
raise ValueError( |
|
"The class name could not be found in the configuration file. Please make sure to pass the correct `class_name`." |
|
) |
|
|
|
class_name = class_name[4:] if class_name.startswith("Flax") else class_name |
|
|
|
pipeline_cls = getattr(diffusers_module, class_name) |
|
|
|
if load_connected_pipeline: |
|
from .auto_pipeline import _get_connected_pipeline |
|
|
|
connected_pipeline_cls = _get_connected_pipeline(pipeline_cls) |
|
if connected_pipeline_cls is not None: |
|
logger.info( |
|
f"Loading connected pipeline {connected_pipeline_cls.__name__} instead of {pipeline_cls.__name__} as specified via `load_connected_pipeline=True`" |
|
) |
|
else: |
|
logger.info(f"{pipeline_cls.__name__} has no connected pipeline class. Loading {pipeline_cls.__name__}.") |
|
|
|
pipeline_cls = connected_pipeline_cls or pipeline_cls |
|
|
|
return pipeline_cls |
|
|
|
|
|
def load_sub_model(
    library_name: str,
    class_name: str,
    importable_classes: Dict[str, List[str]],
    pipelines: Any,
    is_pipeline_module: bool,
    pipeline_class: Any,
    torch_dtype: torch.dtype,
    provider: Any,
    sess_options: Any,
    device_map: Optional[Union[Dict[str, torch.device], str]],
    max_memory: Optional[Dict[Union[int, str], Union[int, str]]],
    offload_folder: Optional[Union[str, os.PathLike]],
    offload_state_dict: bool,
    model_variants: Dict[str, str],
    name: str,
    from_flax: bool,
    variant: str,
    low_cpu_mem_usage: bool,
    cached_folder: Union[str, os.PathLike],
):
    """
    Helper method to load the module `name` from `library_name` and `class_name`.

    The class is resolved with ``get_class_obj_and_candidates``. Its load method is the second entry of
    the ``importable_classes`` value of the last loadable base class it subclasses. Keyword arguments are
    assembled by class kind:

    - ``torch.nn.Module``: ``torch_dtype``;
    - ``OnnxRuntimeModel``: ``provider`` and ``sess_options``;
    - diffusers ``ModelMixin`` / transformers ``PreTrainedModel`` (transformers >= 4.20): device map,
      memory/offload options, ``from_flax``, ``low_cpu_mem_usage``, and this component's variant.

    The load method is called on ``cached_folder/name`` if that directory exists, else on
    ``cached_folder``.

    Side effect: for diffusers/transformers models, ``name`` is popped from ``model_variants``.

    Raises:
        ValueError: if the class has no known loading method.
        ImportError: if a variant is requested for a transformers model with transformers < 4.27.0.
    """
    class_obj, class_candidates = get_class_obj_and_candidates(
        library_name,
        class_name,
        importable_classes,
        pipelines,
        is_pipeline_module,
        component_name=name,
        cache_dir=cached_folder,
    )

    # retrieve load method name; the last matching base class wins
    load_method_name = None
    for candidate_name, class_candidate in class_candidates.items():
        if class_candidate is not None and issubclass(class_obj, class_candidate):
            load_method_name = importable_classes[candidate_name][1]

    # no loadable base class: possibly a dummy placeholder for a missing dependency
    if load_method_name is None:
        none_module = class_obj.__module__
        is_dummy_path = none_module.startswith(DUMMY_MODULES_FOLDER) or none_module.startswith(
            TRANSFORMERS_DUMMY_MODULES_FOLDER
        )
        if is_dummy_path and "dummy" in none_module:
            # instantiate the dummy: presumably raises a descriptive missing-requirements error
            class_obj()

        raise ValueError(
            f"The component {class_obj} of {pipeline_class} cannot be loaded as it does not seem to have"
            f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}."
        )

    load_method = getattr(class_obj, load_method_name)

    # top-level package imported lazily (presumably to avoid a circular import)
    diffusers_module = importlib.import_module(__name__.split(".")[0])
    loading_kwargs = {}
    if issubclass(class_obj, torch.nn.Module):
        loading_kwargs["torch_dtype"] = torch_dtype
    if issubclass(class_obj, diffusers_module.OnnxRuntimeModel):
        loading_kwargs["provider"] = provider
        loading_kwargs["sess_options"] = sess_options

    is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)

    if is_transformers_available():
        transformers_version = version.parse(version.parse(transformers.__version__).base_version)
    else:
        transformers_version = "N/A"

    is_transformers_model = (
        is_transformers_available()
        and issubclass(class_obj, PreTrainedModel)
        and transformers_version >= version.parse("4.20.0")
    )

    # device_map / offloading / low_cpu_mem_usage are only understood by diffusers and transformers models
    if is_diffusers_model or is_transformers_model:
        loading_kwargs["device_map"] = device_map
        loading_kwargs["max_memory"] = max_memory
        loading_kwargs["offload_folder"] = offload_folder
        loading_kwargs["offload_state_dict"] = offload_state_dict
        loading_kwargs["variant"] = model_variants.pop(name, None)
        if from_flax:
            loading_kwargs["from_flax"] = True

        # the following can be deleted once the minimum required `transformers` version is >= 4.27
        requested_variant = loading_kwargs["variant"]
        if (
            is_transformers_model
            and requested_variant is not None
            and transformers_version < version.parse("4.27.0")
        ):
            raise ImportError(
                f"When passing `variant='{requested_variant}'`, please make sure to upgrade your `transformers` version to at least 4.27.0.dev0"
            )
        elif is_transformers_model and requested_variant is None:
            loading_kwargs.pop("variant")

        # flax -> transformers conversion cannot currently use `low_cpu_mem_usage`
        if not (from_flax and is_transformers_model):
            loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
        else:
            loading_kwargs["low_cpu_mem_usage"] = False

    # load from the component subdirectory if present, otherwise from the root folder
    if os.path.isdir(os.path.join(cached_folder, name)):
        loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
    else:
        loaded_sub_model = load_method(cached_folder, **loading_kwargs)

    return loaded_sub_model
|
|
|
|
|
def _fetch_class_library_tuple(module):
    """
    Return the ``(library, class_name)`` pair under which ``module`` is recorded.

    ``module`` is first unwrapped with ``_unwrap_model``, so the original class is used rather than a
    compiled/PEFT wrapper. ``library`` is determined as follows:

    - if the parent package of the class's module (second-to-last dotted part, only for module paths
      with more than two parts) is an attribute of ``diffusers.pipelines``, it is that parent package;
    - otherwise the top-level package, if it is a key of ``LOADABLE_CLASSES``;
    - otherwise the full module path (a custom module).
    """
    # imported here to avoid a circular import
    diffusers_module = importlib.import_module(__name__.split(".")[0])
    pipelines = getattr(diffusers_module, "pipelines")

    unwrapped = _unwrap_model(module)
    module_name = unwrapped.__module__
    parts = module_name.split(".")

    parent_package = parts[-2] if len(parts) > 2 else None
    lives_in_pipeline_dir = parent_package is not None and hasattr(pipelines, parent_package)

    if lives_in_pipeline_dir:
        library = parent_package
    elif parts[0] in LOADABLE_CLASSES:
        library = parts[0]
    else:
        library = module_name

    return (library, unwrapped.__class__.__name__)
|
|