# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import importlib
import os
import re
import warnings
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import torch
from huggingface_hub import (
    model_info,
)
from packaging import version

from ..utils import (
    SAFETENSORS_WEIGHTS_NAME,
    WEIGHTS_NAME,
    get_class_from_dynamic_module,
    is_peft_available,
    is_transformers_available,
    logging,
)
from ..utils.torch_utils import is_compiled_module


if is_transformers_available():
    import transformers
    from transformers import PreTrainedModel
    from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
    from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
    from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME

from huggingface_hub.utils import validate_hf_hub_args

from ..utils import FLAX_WEIGHTS_NAME, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME


INDEX_FILE = "diffusion_pytorch_model.bin"
CUSTOM_PIPELINE_FILE_NAME = "pipeline.py"
DUMMY_MODULES_FOLDER = "diffusers.utils"
TRANSFORMERS_DUMMY_MODULES_FOLDER = "transformers.utils"
CONNECTED_PIPES_KEYS = ["prior"]

logger = logging.get_logger(__name__)

# Maps library name -> base class name -> [save method name, load method name].
LOADABLE_CLASSES = {
    "diffusers": {
        "ModelMixin": ["save_pretrained", "from_pretrained"],
        "SchedulerMixin": ["save_pretrained", "from_pretrained"],
        "DiffusionPipeline": ["save_pretrained", "from_pretrained"],
        "OnnxRuntimeModel": ["save_pretrained", "from_pretrained"],
    },
    "transformers": {
        "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
        "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"],
        "PreTrainedModel": ["save_pretrained", "from_pretrained"],
        "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"],
        "ProcessorMixin": ["save_pretrained", "from_pretrained"],
        "ImageProcessingMixin": ["save_pretrained", "from_pretrained"],
    },
    "onnxruntime.training": {
        "ORTModule": ["save_pretrained", "from_pretrained"],
    },
}

# Flattened view of LOADABLE_CLASSES: base class name -> [save method, load method].
ALL_IMPORTABLE_CLASSES = {}
for library in LOADABLE_CLASSES:
    ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])


def is_safetensors_compatible(filenames, variant=None, passed_components=None) -> bool:
    """
    Checking for safetensors compatibility:
    - By default, all models are saved with the default pytorch serialization, so we use the list of default pytorch
      files to know which safetensors files are needed.
    - The model is safetensors compatible only if there is a matching safetensors file for every default pytorch file.

    Converting default pytorch serialized filenames to safetensors serialized filenames:
    - For models from the diffusers library, just replace the ".bin" extension with ".safetensors"
    - For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin"
      extension is replaced with ".safetensors"

    Args:
        filenames: Iterable of "/"-separated file paths to inspect.
        variant: Accepted for interface compatibility; not read by this function.
        passed_components: Optional collection of component folder names. Files of the form
            `<component>/<file>` whose component is listed here are ignored.

    Returns:
        `True` if every considered `.bin` file has a `.safetensors` counterpart, `False` otherwise (a warning naming
        the first missing file is logged).
    """
    passed_components = passed_components or []

    pt_filenames = []
    sf_filenames = set()

    for filename in filenames:
        # Skip files that live directly inside a component folder the caller supplies themselves.
        parts = filename.split("/")
        if len(parts) == 2 and parts[0] in passed_components:
            continue

        _, extension = os.path.splitext(filename)
        if extension == ".bin":
            pt_filenames.append(os.path.normpath(filename))
        elif extension == ".safetensors":
            sf_filenames.add(os.path.normpath(filename))

    for pt_filename in pt_filenames:
        # pt_filename = 'foo/bar/baz.bam' -> path = 'foo/bar', stem = 'baz' (extension '.bam' is dropped)
        path, basename = os.path.split(pt_filename)
        stem, _ = os.path.splitext(basename)

        # transformers naming: "pytorch_model*.bin" corresponds to "model*.safetensors"
        if stem.startswith("pytorch_model"):
            stem = stem.replace("pytorch_model", "model")

        expected_sf_filename = os.path.normpath(os.path.join(path, stem))
        expected_sf_filename = f"{expected_sf_filename}.safetensors"
        if expected_sf_filename not in sf_filenames:
            logger.warning(f"{expected_sf_filename} not found")
            return False

    return True


def variant_compatible_siblings(filenames, variant=None) -> Tuple[Set[str], Set[str]]:
    """
    Select the weight/index files usable for `variant` out of `filenames`.

    Args:
        filenames: Iterable of "/"-separated file paths; only the last path segment is pattern-matched.
        variant: Optional variant name such as `"fp16"`. It is matched literally.

    Returns:
        A tuple `(usable_filenames, variant_filenames)` of sets. `variant_filenames` holds the files that carry the
        variant marker (empty when `variant` is `None`). `usable_filenames` holds all variant files plus every
        non-variant weight/index file whose variant counterpart is absent.
    """
    weight_names = [
        WEIGHTS_NAME,
        SAFETENSORS_WEIGHTS_NAME,
        FLAX_WEIGHTS_NAME,
        ONNX_WEIGHTS_NAME,
        ONNX_EXTERNAL_WEIGHTS_NAME,
    ]

    if is_transformers_available():
        weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]

    # model_pytorch, diffusion_model_pytorch, ...
    weight_prefixes = [w.split(".")[0] for w in weight_names]
    # .bin, .safetensors, ...
    weight_suffixs = [w.split(".")[-1] for w in weight_names]
    # -00001-of-00002
    transformers_index_format = r"\d{5}-of-\d{5}"

    prefixes = "|".join(weight_prefixes)
    suffixes = "|".join(weight_suffixs)

    if variant is not None:
        # The variant is user input: escape it so regex metacharacters are matched literally.
        escaped_variant = re.escape(str(variant))

        # `diffusion_pytorch_model.fp16.bin` as well as `model.fp16-00001-of-00002.safetensors`
        variant_file_re = re.compile(
            rf"({prefixes})\.({escaped_variant}|{escaped_variant}-{transformers_index_format})\.({suffixes})$"
        )
        # `text_encoder/pytorch_model.bin.index.fp16.json`
        variant_index_re = re.compile(rf"({prefixes})\.({suffixes})\.index\.{escaped_variant}\.json$")

    # `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors`
    non_variant_file_re = re.compile(rf"({prefixes})(-{transformers_index_format})?\.({suffixes})$")
    # `text_encoder/pytorch_model.bin.index.json` (anchored at the end like the other three patterns)
    non_variant_index_re = re.compile(rf"({prefixes})\.({suffixes})\.index\.json$")

    if variant is not None:
        variant_weights = {f for f in filenames if variant_file_re.match(f.split("/")[-1]) is not None}
        variant_indexes = {f for f in filenames if variant_index_re.match(f.split("/")[-1]) is not None}
        variant_filenames = variant_weights | variant_indexes
    else:
        variant_filenames = set()

    non_variant_weights = {f for f in filenames if non_variant_file_re.match(f.split("/")[-1]) is not None}
    non_variant_indexes = {f for f in filenames if non_variant_index_re.match(f.split("/")[-1]) is not None}
    non_variant_filenames = non_variant_weights | non_variant_indexes

    # all variant filenames will be used by default
    usable_filenames = set(variant_filenames)

    # Compiled once instead of on every `convert_to_variant` call.
    sharded_name_re = re.compile(f"^(.*?){transformers_index_format}")

    def convert_to_variant(filename):
        """Return the name `filename` would have if it were stored for `variant`."""
        if "index" in filename:
            variant_filename = filename.replace("index", f"index.{variant}")
        elif sharded_name_re.match(filename) is not None:
            variant_filename = f"{filename.split('-')[0]}.{variant}-{'-'.join(filename.split('-')[1:])}"
        else:
            variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}"
        return variant_filename

    # Fall back to the non-variant file only when its variant counterpart is not available.
    for f in non_variant_filenames:
        variant_filename = convert_to_variant(f)
        if variant_filename not in usable_filenames:
            usable_filenames.add(f)

    return usable_filenames, variant_filenames


@validate_hf_hub_args
def warn_deprecated_model_variant(pretrained_model_name_or_path, token, variant, revision, model_filenames):
    """
    Emit a `FutureWarning` telling the user that loading a variant through `revision` is deprecated.

    The repo's file listing is fetched with `revision=None` (the warning text refers to this as the 'main' branch)
    and checked for files that would satisfy `variant=revision`. Which of the two warning texts is issued depends on
    whether `model_filenames` is covered by those files.

    Args:
        pretrained_model_name_or_path: Repo id passed to `model_info`.
        token: Token forwarded to `model_info`.
        variant: Not read by this function.
        revision: The revision the user requested; treated here as the candidate variant name.
        model_filenames: Iterable of filenames that are being loaded.
    """
    info = model_info(
        pretrained_model_name_or_path,
        token=token,
        revision=None,
    )
    filenames = {sibling.rfilename for sibling in info.siblings}
    comp_model_filenames, _ = variant_compatible_siblings(filenames, variant=revision)
    # Drop the second dot-separated component (the variant marker) from every name.
    comp_model_filenames = [".".join(f.split(".")[:1] + f.split(".")[2:]) for f in comp_model_filenames]

    if set(model_filenames).issubset(set(comp_model_filenames)):
        warnings.warn(
            f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` even though you can load it via `variant='{revision}'`. Loading model variants via `revision='{revision}'` is deprecated and will be removed in diffusers v1. Please use `variant='{revision}'` instead.",
            FutureWarning,
        )
    else:
        warnings.warn(
            f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have the required variant filenames in the 'main' branch. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {revision} files' so that the correct variant file can be added.",
            FutureWarning,
        )


def _unwrap_model(model):
    """Unwraps a model: strips the compiled-module wrapper and, if `peft` is available, the `PeftModel` wrapper."""
    if is_compiled_module(model):
        model = model._orig_mod

    if is_peft_available():
        from peft import PeftModel

        if isinstance(model, PeftModel):
            model = model.base_model.model

    return model


def maybe_raise_or_warn(
    library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module
):
    """
    Simple helper method to raise or warn in case incorrect module has been passed

    Args:
        library_name: Importable module name that provides `class_name`.
        library: Ignored; the module is re-imported from `library_name`.
        class_name: Name of the class the component is expected to be compatible with.
        importable_classes: Mapping whose keys are candidate base-class names.
        passed_class_obj: Mapping from component name to the object supplied by the user.
        name: Key of the component to check in `passed_class_obj`.
        is_pipeline_module: When true, no check is possible and only a warning is logged.

    Raises:
        ValueError: If the (unwrapped) passed object is not an instance of a subclass of the expected base class.
    """
    if not is_pipeline_module:
        library = importlib.import_module(library_name)
        class_obj = getattr(library, class_name)
        class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}

        # The last matching candidate wins, mirroring the iteration order of `importable_classes`.
        expected_class_obj = None
        for class_candidate in class_candidates.values():
            if class_candidate is not None and issubclass(class_obj, class_candidate):
                expected_class_obj = class_candidate

        # Dynamo wraps the original model in a private class.
        # I didn't find a public API to get the original class.
        sub_model = passed_class_obj[name]
        unwrapped_sub_model = _unwrap_model(sub_model)
        model_cls = unwrapped_sub_model.__class__

        # NOTE(review): if no candidate matched, `expected_class_obj` is still None here and
        # `issubclass` raises TypeError - confirm whether callers can reach that state.
        if not issubclass(model_cls, expected_class_obj):
            raise ValueError(
                f"{passed_class_obj[name]} is of type: {model_cls}, but should be" f" {expected_class_obj}"
            )
    else:
        logger.warning(
            f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
            " has the correct type"
        )


def get_class_obj_and_candidates(
    library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None
):
    """
    Simple helper method to retrieve class object of module as well as potential parent class objects

    Args:
        library_name: Pipeline sub-module name, custom component file stem, or importable library name.
        class_name: Name of the class to fetch.
        importable_classes: Mapping whose keys are candidate base-class names.
        pipelines: Object exposing pipeline sub-modules as attributes.
        is_pipeline_module: Whether `library_name` is an attribute of `pipelines`.
        component_name: Optional component sub-folder name inside `cache_dir`.
        cache_dir: Optional folder that may contain `<component_name>/<library_name>.py`.

    Returns:
        A tuple `(class_obj, class_candidates)` where `class_candidates` maps each key of `importable_classes` to a
        class object (or `None` when the library does not define it).
    """
    # Both values default to None; only build the folder path when both were provided,
    # otherwise `os.path.join` would raise a TypeError.
    if cache_dir is not None and component_name is not None:
        component_folder = os.path.join(cache_dir, component_name)
    else:
        component_folder = None

    if is_pipeline_module:
        pipeline_module = getattr(pipelines, library_name)

        class_obj = getattr(pipeline_module, class_name)
        class_candidates = {c: class_obj for c in importable_classes.keys()}
    elif component_folder is not None and os.path.isfile(os.path.join(component_folder, library_name + ".py")):
        # load custom component
        class_obj = get_class_from_dynamic_module(
            component_folder, module_file=library_name + ".py", class_name=class_name
        )
        class_candidates = {c: class_obj for c in importable_classes.keys()}
    else:
        # else we just import it from the library.
        library = importlib.import_module(library_name)

        class_obj = getattr(library, class_name)
        class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}

    return class_obj, class_candidates


def _get_pipeline_class(
    class_obj,
    config=None,
    load_connected_pipeline=False,
    custom_pipeline=None,
    repo_id=None,
    hub_revision=None,
    class_name=None,
    cache_dir=None,
    revision=None,
):
    """
    Resolve the pipeline class to instantiate.

    Resolution order:
    1. If `custom_pipeline` is given, the class is loaded through `get_class_from_dynamic_module`.
    2. If `class_obj` is not named `DiffusionPipeline`, it is returned unchanged.
    3. Otherwise the class named by `class_name` (or `config["_class_name"]`) is looked up on the top-level package
       of `class_obj`, with a leading "Flax" stripped from the name. With `load_connected_pipeline=True` the connected
       pipeline class is returned instead when one exists.

    Raises:
        ValueError: If neither `class_name` nor `config["_class_name"]` yields a non-empty name.
    """
    if custom_pipeline is not None:
        if custom_pipeline.endswith(".py"):
            path = Path(custom_pipeline)
            # decompose into folder & file
            file_name = path.name
            custom_pipeline = path.parent.absolute()
        elif repo_id is not None:
            file_name = f"{custom_pipeline}.py"
            custom_pipeline = repo_id
        else:
            file_name = CUSTOM_PIPELINE_FILE_NAME

        if repo_id is not None and hub_revision is not None:
            # if we load the pipeline code from the Hub
            # make sure to overwrite the `revision`
            revision = hub_revision

        return get_class_from_dynamic_module(
            custom_pipeline,
            module_file=file_name,
            class_name=class_name,
            cache_dir=cache_dir,
            revision=revision,
        )

    if class_obj.__name__ != "DiffusionPipeline":
        return class_obj

    diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0])
    class_name = class_name or config["_class_name"]
    if not class_name:
        raise ValueError(
            "The class name could not be found in the configuration file. Please make sure to pass the correct `class_name`."
        )

    class_name = class_name[4:] if class_name.startswith("Flax") else class_name

    pipeline_cls = getattr(diffusers_module, class_name)

    if load_connected_pipeline:
        from .auto_pipeline import _get_connected_pipeline

        connected_pipeline_cls = _get_connected_pipeline(pipeline_cls)
        if connected_pipeline_cls is not None:
            logger.info(
                f"Loading connected pipeline {connected_pipeline_cls.__name__} instead of {pipeline_cls.__name__} as specified via `load_connected_pipeline=True`"
            )
        else:
            logger.info(f"{pipeline_cls.__name__} has no connected pipeline class. Loading {pipeline_cls.__name__}.")

        pipeline_cls = connected_pipeline_cls or pipeline_cls

    return pipeline_cls


def load_sub_model(
    library_name: str,
    class_name: str,
    importable_classes: Dict[str, List[str]],
    pipelines: Any,
    is_pipeline_module: bool,
    pipeline_class: Any,
    torch_dtype: torch.dtype,
    provider: Any,
    sess_options: Any,
    device_map: Optional[Union[Dict[str, torch.device], str]],
    max_memory: Optional[Dict[Union[int, str], Union[int, str]]],
    offload_folder: Optional[Union[str, os.PathLike]],
    offload_state_dict: bool,
    model_variants: Dict[str, str],
    name: str,
    from_flax: bool,
    variant: str,
    low_cpu_mem_usage: bool,
    cached_folder: Union[str, os.PathLike],
):
    """
    Helper method to load the module `name` from `library_name` and `class_name`

    The component is loaded from `cached_folder/name` when that directory exists, otherwise from `cached_folder`.
    Note that the entry for `name` is popped from `model_variants`, i.e. the caller's dict is modified.

    Raises:
        ValueError: If the resolved class matches none of the candidates in `importable_classes`.
        ImportError: If a variant is requested for a transformers model with `transformers` older than 4.27.0.
    """
    # retrieve class candidates
    class_obj, class_candidates = get_class_obj_and_candidates(
        library_name,
        class_name,
        importable_classes,
        pipelines,
        is_pipeline_module,
        component_name=name,
        cache_dir=cached_folder,
    )

    load_method_name = None
    # retrieve load method name (index 1 of each entry; the last matching candidate wins)
    for candidate_name, class_candidate in class_candidates.items():
        if class_candidate is not None and issubclass(class_obj, class_candidate):
            load_method_name = importable_classes[candidate_name][1]

    # if load method name is None, then we have a dummy module -> raise Error
    if load_method_name is None:
        none_module = class_obj.__module__
        is_dummy_path = none_module.startswith(DUMMY_MODULES_FOLDER) or none_module.startswith(
            TRANSFORMERS_DUMMY_MODULES_FOLDER
        )
        if is_dummy_path and "dummy" in none_module:
            # call class_obj for nice error message of missing requirements
            class_obj()

        raise ValueError(
            f"The component {class_obj} of {pipeline_class} cannot be loaded as it does not seem to have"
            f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}."
        )

    load_method = getattr(class_obj, load_method_name)

    # add kwargs to loading method
    diffusers_module = importlib.import_module(__name__.split(".")[0])
    loading_kwargs = {}
    if issubclass(class_obj, torch.nn.Module):
        loading_kwargs["torch_dtype"] = torch_dtype
    if issubclass(class_obj, diffusers_module.OnnxRuntimeModel):
        loading_kwargs["provider"] = provider
        loading_kwargs["sess_options"] = sess_options

    is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)

    if is_transformers_available():
        transformers_version = version.parse(version.parse(transformers.__version__).base_version)
    else:
        transformers_version = "N/A"

    # The version comparison is only evaluated when transformers is available (short-circuit),
    # so the "N/A" placeholder is never compared against a parsed version.
    is_transformers_model = (
        is_transformers_available()
        and issubclass(class_obj, PreTrainedModel)
        and transformers_version >= version.parse("4.20.0")
    )

    # When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers.
    # To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default.
    # This makes sure that the weights won't be initialized which significantly speeds up loading.
    if is_diffusers_model or is_transformers_model:
        loading_kwargs["device_map"] = device_map
        loading_kwargs["max_memory"] = max_memory
        loading_kwargs["offload_folder"] = offload_folder
        loading_kwargs["offload_state_dict"] = offload_state_dict
        loading_kwargs["variant"] = model_variants.pop(name, None)

        if from_flax:
            loading_kwargs["from_flax"] = True

        # the following can be deleted once the minimum required `transformers` version
        # is higher than 4.27
        if (
            is_transformers_model
            and loading_kwargs["variant"] is not None
            and transformers_version < version.parse("4.27.0")
        ):
            raise ImportError(
                f"When passing `variant='{variant}'`, please make sure to upgrade your `transformers` version to at least 4.27.0.dev0"
            )
        elif is_transformers_model and loading_kwargs["variant"] is None:
            loading_kwargs.pop("variant")

        # if `from_flax` and model is transformer model, can currently not load with `low_cpu_mem_usage`
        if not (from_flax and is_transformers_model):
            loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
        else:
            loading_kwargs["low_cpu_mem_usage"] = False

    # check if the module is in a subdirectory
    component_path = os.path.join(cached_folder, name)
    if os.path.isdir(component_path):
        loaded_sub_model = load_method(component_path, **loading_kwargs)
    else:
        # else load from the root directory
        loaded_sub_model = load_method(cached_folder, **loading_kwargs)

    return loaded_sub_model


def _fetch_class_library_tuple(module):
    """
    Return the `(library, class_name)` pair describing where `module`'s class comes from.

    `library` is the pipeline folder name for pipeline modules, the top-level package name when that package is a key
    of `LOADABLE_CLASSES`, and the full `__module__` string otherwise.
    """
    # import it here to avoid circular import
    diffusers_module = importlib.import_module(__name__.split(".")[0])
    pipelines = getattr(diffusers_module, "pipelines")

    # register the config from the original module, not the dynamo compiled one
    not_compiled_module = _unwrap_model(module)
    module_path_items = not_compiled_module.__module__.split(".")
    library = module_path_items[0]

    # check if the module is a pipeline module
    # (`pipeline_dir` is taken from the module path itself, so only the None case needs guarding)
    pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None
    is_pipeline_module = pipeline_dir is not None and hasattr(pipelines, pipeline_dir)

    # if library is not in LOADABLE_CLASSES, then it is a custom module.
    # Or if it's a pipeline module, then the module is inside the pipeline
    # folder so we set the library to module name.
    if is_pipeline_module:
        library = pipeline_dir
    elif library not in LOADABLE_CLASSES:
        library = not_compiled_module.__module__

    # retrieve class_name
    class_name = not_compiled_module.__class__.__name__

    return (library, class_name)