diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py new file mode 100644 index 000000000000..30c17eec119d --- /dev/null +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -0,0 +1,508 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import importlib +import os +import re +import warnings +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +import torch +from huggingface_hub import ( + model_info, +) +from packaging import version + +from ..utils import ( + SAFETENSORS_WEIGHTS_NAME, + WEIGHTS_NAME, + get_class_from_dynamic_module, + is_peft_available, + is_transformers_available, + logging, +) +from ..utils.torch_utils import is_compiled_module + + +if is_transformers_available(): + import transformers + from transformers import PreTrainedModel + from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME + from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME + from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME +from huggingface_hub.utils import validate_hf_hub_args + +from ..utils import FLAX_WEIGHTS_NAME, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME + + +INDEX_FILE = "diffusion_pytorch_model.bin" +CUSTOM_PIPELINE_FILE_NAME = "pipeline.py" +DUMMY_MODULES_FOLDER = "diffusers.utils" +TRANSFORMERS_DUMMY_MODULES_FOLDER = "transformers.utils" +CONNECTED_PIPES_KEYS = ["prior"] + +logger = logging.get_logger(__name__) + +LOADABLE_CLASSES = { + "diffusers": { + "ModelMixin": ["save_pretrained", "from_pretrained"], + "SchedulerMixin": ["save_pretrained", "from_pretrained"], + "DiffusionPipeline": ["save_pretrained", "from_pretrained"], + "OnnxRuntimeModel": ["save_pretrained", "from_pretrained"], + }, + "transformers": { + "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"], + "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"], + "PreTrainedModel": ["save_pretrained", "from_pretrained"], + "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"], + "ProcessorMixin": ["save_pretrained", "from_pretrained"], + "ImageProcessingMixin": ["save_pretrained", "from_pretrained"], + }, + "onnxruntime.training": { + "ORTModule": ["save_pretrained", "from_pretrained"], + }, +} + +ALL_IMPORTABLE_CLASSES = {} +for library in LOADABLE_CLASSES: + ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library]) + + +def is_safetensors_compatible(filenames, variant=None, passed_components=None) -> bool: + """ + Checking for safetensors compatibility: + - By default, all models are saved with the default pytorch serialization, so we use the list of default pytorch + files to know which safetensors files are needed. + - The model is safetensors compatible only if there is a matching safetensors file for every default pytorch file. 
+ + Converting default pytorch serialized filenames to safetensors serialized filenames: + - For models from the diffusers library, just replace the ".bin" extension with ".safetensors" + - For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin" + extension is replaced with ".safetensors" + """ + pt_filenames = [] + + sf_filenames = set() + + passed_components = passed_components or [] + + for filename in filenames: + _, extension = os.path.splitext(filename) + + if len(filename.split("/")) == 2 and filename.split("/")[0] in passed_components: + continue + + if extension == ".bin": + pt_filenames.append(os.path.normpath(filename)) + elif extension == ".safetensors": + sf_filenames.add(os.path.normpath(filename)) + + for filename in pt_filenames: + # filename = 'foo/bar/baz.bam' -> path = 'foo/bar', filename = 'baz', extension = '.bam' + path, filename = os.path.split(filename) + filename, extension = os.path.splitext(filename) + + if filename.startswith("pytorch_model"): + filename = filename.replace("pytorch_model", "model") + else: + filename = filename + + expected_sf_filename = os.path.normpath(os.path.join(path, filename)) + expected_sf_filename = f"{expected_sf_filename}.safetensors" + if expected_sf_filename not in sf_filenames: + logger.warning(f"{expected_sf_filename} not found") + return False + + return True + + +def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLike], str]: + weight_names = [ + WEIGHTS_NAME, + SAFETENSORS_WEIGHTS_NAME, + FLAX_WEIGHTS_NAME, + ONNX_WEIGHTS_NAME, + ONNX_EXTERNAL_WEIGHTS_NAME, + ] + + if is_transformers_available(): + weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME] + + # model_pytorch, diffusion_model_pytorch, ... + weight_prefixes = [w.split(".")[0] for w in weight_names] + # .bin, .safetensors, ... 
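# A minimal sketch (not part of the diff) of the naming rules `is_safetensors_compatible`
# implements above; the repo file listing is hypothetical.
from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible

filenames = [
    "unet/diffusion_pytorch_model.bin",
    "unet/diffusion_pytorch_model.safetensors",  # diffusers: same stem, new extension
    "text_encoder/pytorch_model.bin",
    "text_encoder/model.safetensors",            # transformers: "pytorch_model" -> "model"
]
assert is_safetensors_compatible(filenames)
# Dropping the text encoder's safetensors twin breaks compatibility:
assert not is_safetensors_compatible(filenames[:3])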
+ weight_suffixs = [w.split(".")[-1] for w in weight_names] + # -00001-of-00002 + transformers_index_format = r"\d{5}-of-\d{5}" + + if variant is not None: + # `diffusion_pytorch_model.fp16.bin` as well as `model.fp16-00001-of-00002.safetensors` + variant_file_re = re.compile( + rf"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$" + ) + # `text_encoder/pytorch_model.bin.index.fp16.json` + variant_index_re = re.compile( + rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$" + ) + + # `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors` + non_variant_file_re = re.compile( + rf"({'|'.join(weight_prefixes)})(-{transformers_index_format})?\.({'|'.join(weight_suffixs)})$" + ) + # `text_encoder/pytorch_model.bin.index.json` + non_variant_index_re = re.compile(rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json") + + if variant is not None: + variant_weights = {f for f in filenames if variant_file_re.match(f.split("/")[-1]) is not None} + variant_indexes = {f for f in filenames if variant_index_re.match(f.split("/")[-1]) is not None} + variant_filenames = variant_weights | variant_indexes + else: + variant_filenames = set() + + non_variant_weights = {f for f in filenames if non_variant_file_re.match(f.split("/")[-1]) is not None} + non_variant_indexes = {f for f in filenames if non_variant_index_re.match(f.split("/")[-1]) is not None} + non_variant_filenames = non_variant_weights | non_variant_indexes + + # all variant filenames will be used by default + usable_filenames = set(variant_filenames) + + def convert_to_variant(filename): + if "index" in filename: + variant_filename = filename.replace("index", f"index.{variant}") + elif re.compile(f"^(.*?){transformers_index_format}").match(filename) is not None: + variant_filename = f"{filename.split('-')[0]}.{variant}-{'-'.join(filename.split('-')[1:])}" + else: + variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}" + return variant_filename + + for f in non_variant_filenames: + variant_filename = convert_to_variant(f) + if variant_filename not in usable_filenames: + usable_filenames.add(f) + + return usable_filenames, variant_filenames + + +@validate_hf_hub_args +def warn_deprecated_model_variant(pretrained_model_name_or_path, token, variant, revision, model_filenames): + info = model_info( + pretrained_model_name_or_path, + token=token, + revision=None, + ) + filenames = {sibling.rfilename for sibling in info.siblings} + comp_model_filenames, _ = variant_compatible_siblings(filenames, variant=revision) + comp_model_filenames = [".".join(f.split(".")[:1] + f.split(".")[2:]) for f in comp_model_filenames] + + if set(model_filenames).issubset(set(comp_model_filenames)): + warnings.warn( + f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` even though you can load it via `variant='{revision}'`. Loading model variants via `revision='{revision}'` is deprecated and will be removed in diffusers v1. Please use `variant='{revision}'` instead.", + FutureWarning, + ) + else: + warnings.warn( + f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. Please use `variant='{revision}'` instead.
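# A sketch of `variant_compatible_siblings` on a hypothetical listing: variant
# files win, and a non-variant file is only kept when no variant counterpart exists.
from diffusers.pipelines.pipeline_loading_utils import variant_compatible_siblings

filenames = {
    "unet/diffusion_pytorch_model.fp16.safetensors",
    "unet/diffusion_pytorch_model.safetensors",
    "vae/diffusion_pytorch_model.safetensors",  # no fp16 counterpart in the repo
}
usable, variant_files = variant_compatible_siblings(filenames, variant="fp16")
# usable        -> the unet's fp16 weights plus the vae's non-variant fallback
# variant_files -> {"unet/diffusion_pytorch_model.fp16.safetensors"}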
However, it appears that {pretrained_model_name_or_path} currently does not have the required variant filenames in the 'main' branch. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {revision} files' so that the correct variant file can be added.", + FutureWarning, + ) + + +def _unwrap_model(model): + """Unwraps a model.""" + if is_compiled_module(model): + model = model._orig_mod + + if is_peft_available(): + from peft import PeftModel + + if isinstance(model, PeftModel): + model = model.base_model.model + + return model + + +def maybe_raise_or_warn( + library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module +): + """Simple helper method to raise or warn in case an incorrect module has been passed""" + if not is_pipeline_module: + library = importlib.import_module(library_name) + class_obj = getattr(library, class_name) + class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()} + + expected_class_obj = None + for class_name, class_candidate in class_candidates.items(): + if class_candidate is not None and issubclass(class_obj, class_candidate): + expected_class_obj = class_candidate + + # Dynamo wraps the original model in a private class. + # I didn't find a public API to get the original class. + sub_model = passed_class_obj[name] + unwrapped_sub_model = _unwrap_model(sub_model) + model_cls = unwrapped_sub_model.__class__ + + if not issubclass(model_cls, expected_class_obj): + raise ValueError( + f"{passed_class_obj[name]} is of type: {model_cls}, but should be" f" {expected_class_obj}" + ) + else: + logger.warning( + f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it" + " has the correct type" + ) + + +def get_class_obj_and_candidates( + library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None +): + """Simple helper method to retrieve the class object of a module as well as potential parent class objects""" + component_folder = os.path.join(cache_dir, component_name) + + if is_pipeline_module: + pipeline_module = getattr(pipelines, library_name) + + class_obj = getattr(pipeline_module, class_name) + class_candidates = {c: class_obj for c in importable_classes.keys()} + elif os.path.isfile(os.path.join(component_folder, library_name + ".py")): + # load custom component + class_obj = get_class_from_dynamic_module( + component_folder, module_file=library_name + ".py", class_name=class_name + ) + class_candidates = {c: class_obj for c in importable_classes.keys()} + else: + # else we just import it from the library.
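# A sketch of `_unwrap_model`: it peels off the torch.compile wrapper (and a
# PeftModel wrapper when peft is installed). Assumes torch >= 2.0.
import torch

from diffusers.pipelines.pipeline_loading_utils import _unwrap_model

module = torch.nn.Linear(4, 4)  # stand-in for a real sub-model
compiled = torch.compile(module)
assert compiled is not module
assert _unwrap_model(compiled) is module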
+ library = importlib.import_module(library_name) + + class_obj = getattr(library, class_name) + class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()} + + return class_obj, class_candidates + + +def _get_pipeline_class( + class_obj, + config=None, + load_connected_pipeline=False, + custom_pipeline=None, + repo_id=None, + hub_revision=None, + class_name=None, + cache_dir=None, + revision=None, +): + if custom_pipeline is not None: + if custom_pipeline.endswith(".py"): + path = Path(custom_pipeline) + # decompose into folder & file + file_name = path.name + custom_pipeline = path.parent.absolute() + elif repo_id is not None: + file_name = f"{custom_pipeline}.py" + custom_pipeline = repo_id + else: + file_name = CUSTOM_PIPELINE_FILE_NAME + + if repo_id is not None and hub_revision is not None: + # if we load the pipeline code from the Hub + # make sure to overwrite the `revision` + revision = hub_revision + + return get_class_from_dynamic_module( + custom_pipeline, + module_file=file_name, + class_name=class_name, + cache_dir=cache_dir, + revision=revision, + ) + + if class_obj.__name__ != "DiffusionPipeline": + return class_obj + + diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0]) + class_name = class_name or config["_class_name"] + if not class_name: + raise ValueError( + "The class name could not be found in the configuration file. Please make sure to pass the correct `class_name`." + ) + + class_name = class_name[4:] if class_name.startswith("Flax") else class_name + + pipeline_cls = getattr(diffusers_module, class_name) + + if load_connected_pipeline: + from .auto_pipeline import _get_connected_pipeline + + connected_pipeline_cls = _get_connected_pipeline(pipeline_cls) + if connected_pipeline_cls is not None: + logger.info( + f"Loading connected pipeline {connected_pipeline_cls.__name__} instead of {pipeline_cls.__name__} as specified via `load_connected_pipeline=True`" + ) + else: + logger.info(f"{pipeline_cls.__name__} has no connected pipeline class. 
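# A sketch of the plain-library branch above: the class object is resolved with
# importlib and checked against every importable base class. Assumes transformers
# is installed; CLIPTextModel is just an example component class.
import importlib

from diffusers.pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES

library = importlib.import_module("transformers")
class_obj = getattr(library, "CLIPTextModel")
class_candidates = {c: getattr(library, c, None) for c in ALL_IMPORTABLE_CLASSES.keys()}
# `class_obj` subclasses the "PreTrainedModel" candidate, whose registry entry
# supplies the ["save_pretrained", "from_pretrained"] method pair used for loading.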
Loading {pipeline_cls.__name__}.") + + pipeline_cls = connected_pipeline_cls or pipeline_cls + + return pipeline_cls + + +def load_sub_model( + library_name: str, + class_name: str, + importable_classes: List[Any], + pipelines: Any, + is_pipeline_module: bool, + pipeline_class: Any, + torch_dtype: torch.dtype, + provider: Any, + sess_options: Any, + device_map: Optional[Union[Dict[str, torch.device], str]], + max_memory: Optional[Dict[Union[int, str], Union[int, str]]], + offload_folder: Optional[Union[str, os.PathLike]], + offload_state_dict: bool, + model_variants: Dict[str, str], + name: str, + from_flax: bool, + variant: str, + low_cpu_mem_usage: bool, + cached_folder: Union[str, os.PathLike], +): + """Helper method to load the module `name` from `library_name` and `class_name`""" + # retrieve class candidates + class_obj, class_candidates = get_class_obj_and_candidates( + library_name, + class_name, + importable_classes, + pipelines, + is_pipeline_module, + component_name=name, + cache_dir=cached_folder, + ) + + load_method_name = None + # retrieve load method name + for class_name, class_candidate in class_candidates.items(): + if class_candidate is not None and issubclass(class_obj, class_candidate): + load_method_name = importable_classes[class_name][1] + + # if load method name is None, then we have a dummy module -> raise Error + if load_method_name is None: + none_module = class_obj.__module__ + is_dummy_path = none_module.startswith(DUMMY_MODULES_FOLDER) or none_module.startswith( + TRANSFORMERS_DUMMY_MODULES_FOLDER + ) + if is_dummy_path and "dummy" in none_module: + # call class_obj for nice error message of missing requirements + class_obj() + + raise ValueError( + f"The component {class_obj} of {pipeline_class} cannot be loaded as it does not seem to have" + f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}." + ) + + load_method = getattr(class_obj, load_method_name) + + # add kwargs to loading method + diffusers_module = importlib.import_module(__name__.split(".")[0]) + loading_kwargs = {} + if issubclass(class_obj, torch.nn.Module): + loading_kwargs["torch_dtype"] = torch_dtype + if issubclass(class_obj, diffusers_module.OnnxRuntimeModel): + loading_kwargs["provider"] = provider + loading_kwargs["sess_options"] = sess_options + + is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin) + + if is_transformers_available(): + transformers_version = version.parse(version.parse(transformers.__version__).base_version) + else: + transformers_version = "N/A" + + is_transformers_model = ( + is_transformers_available() + and issubclass(class_obj, PreTrainedModel) + and transformers_version >= version.parse("4.20.0") + ) + + # Unlike diffusers models, transformers models initialize their weights when `device_map` is None. + # To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag, which is `True` by default. + # This makes sure that the weights won't be initialized, which significantly speeds up loading.
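# A sketch of `_get_pipeline_class` resolving a concrete pipeline class from a
# config dict, as happens when `DiffusionPipeline.from_pretrained` reads
# model_index.json; the config shown here is hypothetical.
from diffusers import DiffusionPipeline
from diffusers.pipelines.pipeline_loading_utils import _get_pipeline_class

pipeline_cls = _get_pipeline_class(
    DiffusionPipeline, config={"_class_name": "StableDiffusionPipeline"}
)
# pipeline_cls is diffusers.StableDiffusionPipeline; a "Flax" prefix in
# `_class_name` would have been stripped first.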
+ if is_diffusers_model or is_transformers_model: + loading_kwargs["device_map"] = device_map + loading_kwargs["max_memory"] = max_memory + loading_kwargs["offload_folder"] = offload_folder + loading_kwargs["offload_state_dict"] = offload_state_dict + loading_kwargs["variant"] = model_variants.pop(name, None) + + if from_flax: + loading_kwargs["from_flax"] = True + + # the following can be deleted once the minimum required `transformers` version + # is higher than 4.27 + if ( + is_transformers_model + and loading_kwargs["variant"] is not None + and transformers_version < version.parse("4.27.0") + ): + raise ImportError( + f"When passing `variant='{variant}'`, please make sure to upgrade your `transformers` version to at least 4.27.0.dev0" + ) + elif is_transformers_model and loading_kwargs["variant"] is None: + loading_kwargs.pop("variant") + + # if `from_flax` and the model is a transformers model, it cannot currently be loaded with `low_cpu_mem_usage` + if not (from_flax and is_transformers_model): + loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage + else: + loading_kwargs["low_cpu_mem_usage"] = False + + # check if the module is in a subdirectory + if os.path.isdir(os.path.join(cached_folder, name)): + loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs) + else: + # else load from the root directory + loaded_sub_model = load_method(cached_folder, **loading_kwargs) + + return loaded_sub_model + + +def _fetch_class_library_tuple(module): + # import it here to avoid circular import + diffusers_module = importlib.import_module(__name__.split(".")[0]) + pipelines = getattr(diffusers_module, "pipelines") + + # register the config from the original module, not the dynamo compiled one + not_compiled_module = _unwrap_model(module) + library = not_compiled_module.__module__.split(".")[0] + + # check if the module is a pipeline module + module_path_items = not_compiled_module.__module__.split(".") + pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None + + path = not_compiled_module.__module__.split(".") + is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir) + + # if the library is not in LOADABLE_CLASSES, then it is a custom module. + # If it's a pipeline module, the module lives inside the pipeline + # folder, so we set the library to the module name.
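# A sketch of the kwargs `load_sub_model` would assemble for a hypothetical fp16
# transformers text encoder (transformers >= 4.27 assumed, so `variant` is kept).
import torch

loading_kwargs = {
    "torch_dtype": torch.float16,  # set because the class is a torch.nn.Module
    "device_map": None,
    "max_memory": None,
    "offload_folder": None,
    "offload_state_dict": False,
    "variant": "fp16",             # popped from model_variants["text_encoder"]
    "low_cpu_mem_usage": True,     # skip weight init to speed up loading
}
# The component is then loaded from "<cached_folder>/text_encoder" via
# CLIPTextModel.from_pretrained(..., **loading_kwargs).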
+ if is_pipeline_module: + library = pipeline_dir + elif library not in LOADABLE_CLASSES: + library = not_compiled_module.__module__ + + # retrieve class_name + class_name = not_compiled_module.__class__.__name__ + + return (library, class_name) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 8a441a12768d..e4ac56f9d011 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -19,7 +19,6 @@ import os import re import sys -import warnings from dataclasses import dataclass from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Union @@ -49,72 +48,44 @@ from ..utils import ( CONFIG_NAME, DEPRECATED_REVISION_ARGS, - SAFETENSORS_WEIGHTS_NAME, - WEIGHTS_NAME, BaseOutput, + PushToHubMixin, deprecate, - get_class_from_dynamic_module, is_accelerate_available, is_accelerate_version, - is_peft_available, is_torch_version, - is_transformers_available, logging, numpy_to_pil, ) from ..utils.hub_utils import load_or_create_model_card, populate_model_card from ..utils.torch_utils import is_compiled_module - - -if is_transformers_available(): - import transformers - from transformers import PreTrainedModel - from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME - from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME - from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME - -from ..utils import FLAX_WEIGHTS_NAME, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, PushToHubMixin +from .pipeline_loading_utils import ( + ALL_IMPORTABLE_CLASSES, + CONNECTED_PIPES_KEYS, + CUSTOM_PIPELINE_FILE_NAME, + LOADABLE_CLASSES, + _fetch_class_library_tuple, + _get_pipeline_class, + _unwrap_model, + is_safetensors_compatible, + load_sub_model, + maybe_raise_or_warn, + variant_compatible_siblings, + warn_deprecated_model_variant, +) if is_accelerate_available(): import accelerate -INDEX_FILE = "diffusion_pytorch_model.bin" -CUSTOM_PIPELINE_FILE_NAME = "pipeline.py" -DUMMY_MODULES_FOLDER = "diffusers.utils" -TRANSFORMERS_DUMMY_MODULES_FOLDER = "transformers.utils" -CONNECTED_PIPES_KEYS = ["prior"] - +LIBRARIES = [] +for library in LOADABLE_CLASSES: + LIBRARIES.append(library) logger = logging.get_logger(__name__) -LOADABLE_CLASSES = { - "diffusers": { - "ModelMixin": ["save_pretrained", "from_pretrained"], - "SchedulerMixin": ["save_pretrained", "from_pretrained"], - "DiffusionPipeline": ["save_pretrained", "from_pretrained"], - "OnnxRuntimeModel": ["save_pretrained", "from_pretrained"], - }, - "transformers": { - "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"], - "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"], - "PreTrainedModel": ["save_pretrained", "from_pretrained"], - "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"], - "ProcessorMixin": ["save_pretrained", "from_pretrained"], - "ImageProcessingMixin": ["save_pretrained", "from_pretrained"], - }, - "onnxruntime.training": { - "ORTModule": ["save_pretrained", "from_pretrained"], - }, -} - -ALL_IMPORTABLE_CLASSES = {} -for library in LOADABLE_CLASSES: - ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library]) - - @dataclass class ImagePipelineOutput(BaseOutput): """ @@ -142,432 +113,6 @@ class AudioPipelineOutput(BaseOutput): audios: np.ndarray -def is_safetensors_compatible(filenames, variant=None, passed_components=None) -> bool: - """ - Checking for safetensors compatibility: - - By default, all models are saved with the 
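# A sketch of `_fetch_class_library_tuple`, which derives the (library, class
# name) pair recorded in model_index.json; assumes transformers is installed.
from transformers import CLIPTextConfig, CLIPTextModel

from diffusers.pipelines.pipeline_loading_utils import _fetch_class_library_tuple

text_encoder = CLIPTextModel(CLIPTextConfig())  # randomly initialized, just for the demo
# __module__ is "transformers.models.clip.modeling_clip" -> library "transformers"
assert _fetch_class_library_tuple(text_encoder) == ("transformers", "CLIPTextModel")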
default pytorch serialization, so we use the list of default pytorch - files to know which safetensors files are needed. - - The model is safetensors compatible only if there is a matching safetensors file for every default pytorch file. - - Converting default pytorch serialized filenames to safetensors serialized filenames: - - For models from the diffusers library, just replace the ".bin" extension with ".safetensors" - - For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin" - extension is replaced with ".safetensors" - """ - pt_filenames = [] - - sf_filenames = set() - - passed_components = passed_components or [] - - for filename in filenames: - _, extension = os.path.splitext(filename) - - if len(filename.split("/")) == 2 and filename.split("/")[0] in passed_components: - continue - - if extension == ".bin": - pt_filenames.append(os.path.normpath(filename)) - elif extension == ".safetensors": - sf_filenames.add(os.path.normpath(filename)) - - for filename in pt_filenames: - # filename = 'foo/bar/baz.bam' -> path = 'foo/bar', filename = 'baz', extension = '.bam' - path, filename = os.path.split(filename) - filename, extension = os.path.splitext(filename) - - if filename.startswith("pytorch_model"): - filename = filename.replace("pytorch_model", "model") - else: - filename = filename - - expected_sf_filename = os.path.normpath(os.path.join(path, filename)) - expected_sf_filename = f"{expected_sf_filename}.safetensors" - if expected_sf_filename not in sf_filenames: - logger.warning(f"{expected_sf_filename} not found") - return False - - return True - - -def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLike], str]: - weight_names = [ - WEIGHTS_NAME, - SAFETENSORS_WEIGHTS_NAME, - FLAX_WEIGHTS_NAME, - ONNX_WEIGHTS_NAME, - ONNX_EXTERNAL_WEIGHTS_NAME, - ] - - if is_transformers_available(): - weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME] - - # model_pytorch, diffusion_model_pytorch, ... - weight_prefixes = [w.split(".")[0] for w in weight_names] - # .bin, .safetensors, ... 
- weight_suffixs = [w.split(".")[-1] for w in weight_names] - # -00001-of-00002 - transformers_index_format = r"\d{5}-of-\d{5}" - - if variant is not None: - # `diffusion_pytorch_model.fp16.bin` as well as `model.fp16-00001-of-00002.safetensors` - variant_file_re = re.compile( - rf"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$" - ) - # `text_encoder/pytorch_model.bin.index.fp16.json` - variant_index_re = re.compile( - rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$" - ) - - # `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors` - non_variant_file_re = re.compile( - rf"({'|'.join(weight_prefixes)})(-{transformers_index_format})?\.({'|'.join(weight_suffixs)})$" - ) - # `text_encoder/pytorch_model.bin.index.json` - non_variant_index_re = re.compile(rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json") - - if variant is not None: - variant_weights = {f for f in filenames if variant_file_re.match(f.split("/")[-1]) is not None} - variant_indexes = {f for f in filenames if variant_index_re.match(f.split("/")[-1]) is not None} - variant_filenames = variant_weights | variant_indexes - else: - variant_filenames = set() - - non_variant_weights = {f for f in filenames if non_variant_file_re.match(f.split("/")[-1]) is not None} - non_variant_indexes = {f for f in filenames if non_variant_index_re.match(f.split("/")[-1]) is not None} - non_variant_filenames = non_variant_weights | non_variant_indexes - - # all variant filenames will be used by default - usable_filenames = set(variant_filenames) - - def convert_to_variant(filename): - if "index" in filename: - variant_filename = filename.replace("index", f"index.{variant}") - elif re.compile(f"^(.*?){transformers_index_format}").match(filename) is not None: - variant_filename = f"{filename.split('-')[0]}.{variant}-{'-'.join(filename.split('-')[1:])}" - else: - variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}" - return variant_filename - - for f in non_variant_filenames: - variant_filename = convert_to_variant(f) - if variant_filename not in usable_filenames: - usable_filenames.add(f) - - return usable_filenames, variant_filenames - - -@validate_hf_hub_args -def warn_deprecated_model_variant(pretrained_model_name_or_path, token, variant, revision, model_filenames): - info = model_info( - pretrained_model_name_or_path, - token=token, - revision=None, - ) - filenames = {sibling.rfilename for sibling in info.siblings} - comp_model_filenames, _ = variant_compatible_siblings(filenames, variant=revision) - comp_model_filenames = [".".join(f.split(".")[:1] + f.split(".")[2:]) for f in comp_model_filenames] - - if set(model_filenames).issubset(set(comp_model_filenames)): - warnings.warn( - f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` even though you can load it via `variant=`{revision}`. Loading model variants via `revision='{revision}'` is deprecated and will be removed in diffusers v1. Please use `variant='{revision}'` instead.", - FutureWarning, - ) - else: - warnings.warn( - f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. 
However, it appears that {pretrained_model_name_or_path} currently does not have the required variant filenames in the 'main' branch. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {revision} files' so that the correct variant file can be added.", - FutureWarning, - ) - - -def _unwrap_model(model): - """Unwraps a model.""" - if is_compiled_module(model): - model = model._orig_mod - - if is_peft_available(): - from peft import PeftModel - - if isinstance(model, PeftModel): - model = model.base_model.model - - return model - - -def maybe_raise_or_warn( - library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module -): - """Simple helper method to raise or warn in case incorrect module has been passed""" - if not is_pipeline_module: - library = importlib.import_module(library_name) - class_obj = getattr(library, class_name) - class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()} - - expected_class_obj = None - for class_name, class_candidate in class_candidates.items(): - if class_candidate is not None and issubclass(class_obj, class_candidate): - expected_class_obj = class_candidate - - # Dynamo wraps the original model in a private class. - # I didn't find a public API to get the original class. - sub_model = passed_class_obj[name] - unwrapped_sub_model = _unwrap_model(sub_model) - model_cls = unwrapped_sub_model.__class__ - - if not issubclass(model_cls, expected_class_obj): - raise ValueError( - f"{passed_class_obj[name]} is of type: {model_cls}, but should be" f" {expected_class_obj}" - ) - else: - logger.warning( - f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it" - " has the correct type" - ) - - -def get_class_obj_and_candidates( - library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None -): - """Simple helper method to retrieve class object of module as well as potential parent class objects""" - component_folder = os.path.join(cache_dir, component_name) - - if is_pipeline_module: - pipeline_module = getattr(pipelines, library_name) - - class_obj = getattr(pipeline_module, class_name) - class_candidates = {c: class_obj for c in importable_classes.keys()} - elif os.path.isfile(os.path.join(component_folder, library_name + ".py")): - # load custom component - class_obj = get_class_from_dynamic_module( - component_folder, module_file=library_name + ".py", class_name=class_name - ) - class_candidates = {c: class_obj for c in importable_classes.keys()} - else: - # else we just import it from the library. 
- library = importlib.import_module(library_name) - - class_obj = getattr(library, class_name) - class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()} - - return class_obj, class_candidates - - -def _get_pipeline_class( - class_obj, - config=None, - load_connected_pipeline=False, - custom_pipeline=None, - repo_id=None, - hub_revision=None, - class_name=None, - cache_dir=None, - revision=None, -): - if custom_pipeline is not None: - if custom_pipeline.endswith(".py"): - path = Path(custom_pipeline) - # decompose into folder & file - file_name = path.name - custom_pipeline = path.parent.absolute() - elif repo_id is not None: - file_name = f"{custom_pipeline}.py" - custom_pipeline = repo_id - else: - file_name = CUSTOM_PIPELINE_FILE_NAME - - if repo_id is not None and hub_revision is not None: - # if we load the pipeline code from the Hub - # make sure to overwrite the `revision` - revision = hub_revision - - return get_class_from_dynamic_module( - custom_pipeline, - module_file=file_name, - class_name=class_name, - cache_dir=cache_dir, - revision=revision, - ) - - if class_obj != DiffusionPipeline: - return class_obj - - diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0]) - class_name = class_name or config["_class_name"] - if not class_name: - raise ValueError( - "The class name could not be found in the configuration file. Please make sure to pass the correct `class_name`." - ) - - class_name = class_name[4:] if class_name.startswith("Flax") else class_name - - pipeline_cls = getattr(diffusers_module, class_name) - - if load_connected_pipeline: - from .auto_pipeline import _get_connected_pipeline - - connected_pipeline_cls = _get_connected_pipeline(pipeline_cls) - if connected_pipeline_cls is not None: - logger.info( - f"Loading connected pipeline {connected_pipeline_cls.__name__} instead of {pipeline_cls.__name__} as specified via `load_connected_pipeline=True`" - ) - else: - logger.info(f"{pipeline_cls.__name__} has no connected pipeline class. 
Loading {pipeline_cls.__name__}.") - - pipeline_cls = connected_pipeline_cls or pipeline_cls - - return pipeline_cls - - -def load_sub_model( - library_name: str, - class_name: str, - importable_classes: List[Any], - pipelines: Any, - is_pipeline_module: bool, - pipeline_class: Any, - torch_dtype: torch.dtype, - provider: Any, - sess_options: Any, - device_map: Optional[Union[Dict[str, torch.device], str]], - max_memory: Optional[Dict[Union[int, str], Union[int, str]]], - offload_folder: Optional[Union[str, os.PathLike]], - offload_state_dict: bool, - model_variants: Dict[str, str], - name: str, - from_flax: bool, - variant: str, - low_cpu_mem_usage: bool, - cached_folder: Union[str, os.PathLike], -): - """Helper method to load the module `name` from `library_name` and `class_name`""" - # retrieve class candidates - class_obj, class_candidates = get_class_obj_and_candidates( - library_name, - class_name, - importable_classes, - pipelines, - is_pipeline_module, - component_name=name, - cache_dir=cached_folder, - ) - - load_method_name = None - # retrieve load method name - for class_name, class_candidate in class_candidates.items(): - if class_candidate is not None and issubclass(class_obj, class_candidate): - load_method_name = importable_classes[class_name][1] - - # if load method name is None, then we have a dummy module -> raise Error - if load_method_name is None: - none_module = class_obj.__module__ - is_dummy_path = none_module.startswith(DUMMY_MODULES_FOLDER) or none_module.startswith( - TRANSFORMERS_DUMMY_MODULES_FOLDER - ) - if is_dummy_path and "dummy" in none_module: - # call class_obj for nice error message of missing requirements - class_obj() - - raise ValueError( - f"The component {class_obj} of {pipeline_class} cannot be loaded as it does not seem to have" - f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}." - ) - - load_method = getattr(class_obj, load_method_name) - - # add kwargs to loading method - diffusers_module = importlib.import_module(__name__.split(".")[0]) - loading_kwargs = {} - if issubclass(class_obj, torch.nn.Module): - loading_kwargs["torch_dtype"] = torch_dtype - if issubclass(class_obj, diffusers_module.OnnxRuntimeModel): - loading_kwargs["provider"] = provider - loading_kwargs["sess_options"] = sess_options - - is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin) - - if is_transformers_available(): - transformers_version = version.parse(version.parse(transformers.__version__).base_version) - else: - transformers_version = "N/A" - - is_transformers_model = ( - is_transformers_available() - and issubclass(class_obj, PreTrainedModel) - and transformers_version >= version.parse("4.20.0") - ) - - # When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers. - # To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default. - # This makes sure that the weights won't be initialized which significantly speeds up loading. 
- if is_diffusers_model or is_transformers_model: - loading_kwargs["device_map"] = device_map - loading_kwargs["max_memory"] = max_memory - loading_kwargs["offload_folder"] = offload_folder - loading_kwargs["offload_state_dict"] = offload_state_dict - loading_kwargs["variant"] = model_variants.pop(name, None) - - if from_flax: - loading_kwargs["from_flax"] = True - - # the following can be deleted once the minimum required `transformers` version - # is higher than 4.27 - if ( - is_transformers_model - and loading_kwargs["variant"] is not None - and transformers_version < version.parse("4.27.0") - ): - raise ImportError( - f"When passing `variant='{variant}'`, please make sure to upgrade your `transformers` version to at least 4.27.0.dev0" - ) - elif is_transformers_model and loading_kwargs["variant"] is None: - loading_kwargs.pop("variant") - - # if `from_flax` and model is transformer model, can currently not load with `low_cpu_mem_usage` - if not (from_flax and is_transformers_model): - loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage - else: - loading_kwargs["low_cpu_mem_usage"] = False - - # check if the module is in a subdirectory - if os.path.isdir(os.path.join(cached_folder, name)): - loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs) - else: - # else load from the root directory - loaded_sub_model = load_method(cached_folder, **loading_kwargs) - - return loaded_sub_model - - -def _fetch_class_library_tuple(module): - # import it here to avoid circular import - diffusers_module = importlib.import_module(__name__.split(".")[0]) - pipelines = getattr(diffusers_module, "pipelines") - - # register the config from the original module, not the dynamo compiled one - not_compiled_module = _unwrap_model(module) - library = not_compiled_module.__module__.split(".")[0] - - # check if the module is a pipeline module - module_path_items = not_compiled_module.__module__.split(".") - pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None - - path = not_compiled_module.__module__.split(".") - is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir) - - # if library is not in LOADABLE_CLASSES, then it is a custom module. - # Or if it's a pipeline module, then the module is inside the pipeline - # folder so we set the library to module name. - if is_pipeline_module: - library = pipeline_dir - elif library not in LOADABLE_CLASSES: - library = not_compiled_module.__module__ - - # retrieve class_name - class_name = not_compiled_module.__class__.__name__ - - return (library, class_name) - - class DiffusionPipeline(ConfigMixin, PushToHubMixin): r""" Base class for all pipelines.
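# Because `pipeline_utils` now re-imports these helpers from
# `pipeline_loading_utils` (see the import block in this diff), code that
# referenced them under the old module path keeps resolving to the same objects:
from diffusers.pipelines import pipeline_loading_utils, pipeline_utils

assert pipeline_utils.is_safetensors_compatible is pipeline_loading_utils.is_safetensors_compatible
assert pipeline_utils.load_sub_model is pipeline_loading_utils.load_sub_model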