From 52c05bd4cd583ae4f07b5856dc25ba6c56e74ebf Mon Sep 17 00:00:00 2001
From: Daniel Hipke
Date: Fri, 10 Jan 2025 02:11:04 -0800
Subject: [PATCH] Add a `disable_mmap` option to the `from_single_file` loader
 to improve load performance on network mounts (#10305)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Add no_mmap arg.

* Fix arg parsing.

* Update another method to force no mmap.

* logging

* logging2

* propagate no_mmap

* logging3

* propagate no_mmap

* logging4

* fix open call

* clean up logging

* cleanup

* fix missing arg

* update logging and comments

* Rename to disable_mmap and update other references.

* [Docs] Update ltx_video.md to remove generator from `from_pretrained()` (#10316)

Update ltx_video.md to remove generator from `from_pretrained()`

* docs: fix a mistake in docstring (#10319)

Update pipeline_hunyuan_video.py

docs: fix a mistake

* [BUG FIX] [Stable Audio Pipeline] Resolve torch.Tensor.new_zeros() TypeError in function prepare_latents caused by audio_vae_length (#10306)

[BUG FIX] [Stable Audio Pipeline] TypeError: new_zeros(): argument 'size' failed to unpack the object at pos 3 with error "type must be tuple of ints, but got float"

torch.Tensor.new_zeros() takes a single argument size (int...) – a list, tuple, or torch.Size of integers defining the shape of the output tensor.

In function prepare_latents:

audio_vae_length = self.transformer.config.sample_size * self.vae.hop_length
audio_shape = (batch_size // num_waveforms_per_prompt, audio_channels, audio_vae_length)
...
audio = initial_audio_waveforms.new_zeros(audio_shape)

audio_vae_length evaluates to a float because self.transformer.config.sample_size returns a float.

Co-authored-by: hlky

* [docs] Fix quantization links (#10323)

Update overview.md

* [Sana] add 2K related model for Sana (#10322)

add 2K related model for Sana

* Update src/diffusers/loaders/single_file_model.py

Co-authored-by: Dhruv Nair

* Update src/diffusers/loaders/single_file.py

Co-authored-by: Dhruv Nair

* make style

---------

Co-authored-by: hlky
Co-authored-by: Sayak Paul
Co-authored-by: Leojc
Co-authored-by: Aditya Raj
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: Junsong Chen
Co-authored-by: Dhruv Nair
---
 src/diffusers/loaders/single_file.py        | 8 ++++++++
 src/diffusers/loaders/single_file_model.py  | 5 +++++
 src/diffusers/loaders/single_file_utils.py  | 3 ++-
 src/diffusers/models/model_loading_utils.py | 9 +++++++--
 src/diffusers/models/modeling_utils.py      | 8 ++++++--
 5 files changed, 28 insertions(+), 5 deletions(-)

diff --git a/src/diffusers/loaders/single_file.py b/src/diffusers/loaders/single_file.py
index c5c9bea29b8a..007332f73409 100644
--- a/src/diffusers/loaders/single_file.py
+++ b/src/diffusers/loaders/single_file.py
@@ -60,6 +60,7 @@ def load_single_file_sub_model(
     local_files_only=False,
     torch_dtype=None,
     is_legacy_loading=False,
+    disable_mmap=False,
     **kwargs,
 ):
     if is_pipeline_module:
@@ -106,6 +107,7 @@ def load_single_file_sub_model(
             subfolder=name,
             torch_dtype=torch_dtype,
             local_files_only=local_files_only,
+            disable_mmap=disable_mmap,
             **kwargs,
         )
 
@@ -308,6 +310,9 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
                       hosted on the Hub.
                     - A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline
                       component configs in Diffusers format.
+            disable_mmap (`bool`, *optional*, defaults to `False`):
+                Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
+                is on a network mount or hard drive.
             kwargs (remaining dictionary of keyword arguments, *optional*):
                 Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
                 class). The overwritten components are passed directly to the pipelines `__init__` method. See example
@@ -355,6 +360,7 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
         local_files_only = kwargs.pop("local_files_only", False)
         revision = kwargs.pop("revision", None)
         torch_dtype = kwargs.pop("torch_dtype", None)
+        disable_mmap = kwargs.pop("disable_mmap", False)
 
         is_legacy_loading = False
 
@@ -383,6 +389,7 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
             cache_dir=cache_dir,
             local_files_only=local_files_only,
             revision=revision,
+            disable_mmap=disable_mmap,
         )
 
         if config is None:
@@ -504,6 +511,7 @@ def load_module(name, value):
                     original_config=original_config,
                     local_files_only=local_files_only,
                     is_legacy_loading=is_legacy_loading,
+                    disable_mmap=disable_mmap,
                     **kwargs,
                 )
             except SingleFileComponentError as e:
diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py
index b65069e60d50..69ab8b6bad20 100644
--- a/src/diffusers/loaders/single_file_model.py
+++ b/src/diffusers/loaders/single_file_model.py
@@ -187,6 +187,9 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
             revision (`str`, *optional*, defaults to `"main"`):
                 The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                 allowed by Git.
+            disable_mmap (`bool`, *optional*, defaults to `False`):
+                Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
+                is on a network mount or hard drive, which may not handle the random-access pattern of mmap very well.
             kwargs (remaining dictionary of keyword arguments, *optional*):
                 Can be used to overwrite load and saveable variables (for example the pipeline components of the
                 specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
@@ -234,6 +237,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
         torch_dtype = kwargs.pop("torch_dtype", None)
         quantization_config = kwargs.pop("quantization_config", None)
         device = kwargs.pop("device", None)
+        disable_mmap = kwargs.pop("disable_mmap", False)
 
         if isinstance(pretrained_model_link_or_path_or_dict, dict):
             checkpoint = pretrained_model_link_or_path_or_dict
@@ -246,6 +250,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
             cache_dir=cache_dir,
             local_files_only=local_files_only,
             revision=revision,
+            disable_mmap=disable_mmap,
         )
         if quantization_config is not None:
             hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py
index cefba48275cf..b2b21675054c 100644
--- a/src/diffusers/loaders/single_file_utils.py
+++ b/src/diffusers/loaders/single_file_utils.py
@@ -387,6 +387,7 @@ def load_single_file_checkpoint(
     cache_dir=None,
     local_files_only=None,
     revision=None,
+    disable_mmap=False,
 ):
     if os.path.isfile(pretrained_model_link_or_path):
         pretrained_model_link_or_path = pretrained_model_link_or_path
@@ -404,7 +405,7 @@ def load_single_file_checkpoint(
             revision=revision,
         )
 
-    checkpoint = load_state_dict(pretrained_model_link_or_path)
+    checkpoint = load_state_dict(pretrained_model_link_or_path, disable_mmap=disable_mmap)
 
     # some checkpoints contain the model state dict under a "state_dict" key
     while "state_dict" in checkpoint:
diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py
index 5f5ea2351709..a3d006f18994 100644
--- a/src/diffusers/models/model_loading_utils.py
+++ b/src/diffusers/models/model_loading_utils.py
@@ -131,7 +131,9 @@ def _fetch_remapped_cls_from_config(config, old_class):
     return old_class
 
 
-def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
+def load_state_dict(
+    checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None, disable_mmap: bool = False
+):
     """
     Reads a checkpoint file, returning properly formatted errors if they arise.
     """
@@ -142,7 +144,10 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
     try:
         file_extension = os.path.basename(checkpoint_file).split(".")[-1]
         if file_extension == SAFETENSORS_FILE_EXTENSION:
-            return safetensors.torch.load_file(checkpoint_file, device="cpu")
+            if disable_mmap:
+                return safetensors.torch.load(open(checkpoint_file, "rb").read())
+            else:
+                return safetensors.torch.load_file(checkpoint_file, device="cpu")
         elif file_extension == GGUF_FILE_EXTENSION:
             return load_gguf_checkpoint(checkpoint_file)
         else:
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 789aeccf9b7f..17e9d2043150 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -559,6 +559,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
                 `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
                 weights. If set to `False`, `safetensors` weights are not loaded.
+            disable_mmap (`bool`, *optional*, defaults to `False`):
+                Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
+                is on a network mount or hard drive, which may not handle the random-access pattern of mmap very well.
@@ -604,6 +607,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         variant = kwargs.pop("variant", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
         quantization_config = kwargs.pop("quantization_config", None)
+        disable_mmap = kwargs.pop("disable_mmap", False)
 
         allow_pickle = False
         if use_safetensors is None:
@@ -883,7 +887,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 # TODO (sayakpaul, SunMarc): remove this after model loading refactor
                 else:
                     param_device = torch.device(torch.cuda.current_device())
-                state_dict = load_state_dict(model_file, variant=variant)
+                state_dict = load_state_dict(model_file, variant=variant, disable_mmap=disable_mmap)
                 model._convert_deprecated_attention_blocks(state_dict)
 
                 # move the params from meta device to cpu
@@ -979,7 +983,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             else:
                 model = cls.from_config(config, **unused_kwargs)
 
-            state_dict = load_state_dict(model_file, variant=variant)
+            state_dict = load_state_dict(model_file, variant=variant, disable_mmap=disable_mmap)
             model._convert_deprecated_attention_blocks(state_dict)
 
             model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
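--
A minimal usage sketch of the new option, assuming this patch is applied. The pipeline class and checkpoint path are placeholders chosen for illustration; any loader exposing `from_single_file` should accept the flag the same way:

    from diffusers import StableDiffusionPipeline

    # With disable_mmap=True the Safetensors file is read into memory in one pass
    # instead of being memory-mapped, which can be faster on network mounts.
    pipe = StableDiffusionPipeline.from_single_file(
        "/mnt/nfs/checkpoints/model.safetensors",  # placeholder path on a network mount
        disable_mmap=True,
    )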