Skip to content

Commit

Permalink
Add a disable_mmap option to the from_single_file loader to impro…
Browse files Browse the repository at this point in the history
…ve load performance on network mounts (#10305)

* Add no_mmap arg.

* Fix arg parsing.

* Update another method to force no mmap.

* logging

* logging2

* propagate no_mmap

* logging3

* propagate no_mmap

* logging4

* fix open call

* clean up logging

* cleanup

* fix missing arg

* update logging and comments

* Rename to disable_mmap and update other references.

* [Docs] Update ltx_video.md to remove generator from `from_pretrained()` (#10316)

Update ltx_video.md to remove generator from `from_pretrained()`

* docs: fix a mistake in docstring (#10319)

Update pipeline_hunyuan_video.py

docs: fix a mistake

* [BUG FIX] [Stable Audio Pipeline] Resolve torch.Tensor.new_zeros() TypeError in function prepare_latents caused by audio_vae_length (#10306)

[BUG FIX] [Stable Audio Pipeline] TypeError: new_zeros(): argument 'size' failed to unpack the object at pos 3 with error "type must be tuple of ints,but got float"

torch.Tensor.new_zeros() takes a single argument size (int...) – a list, tuple, or torch.Size of integers defining the shape of the output tensor.

in function prepare_latents:
audio_vae_length = self.transformer.config.sample_size * self.vae.hop_length
audio_shape = (batch_size // num_waveforms_per_prompt, audio_channels, audio_vae_length)
...
audio = initial_audio_waveforms.new_zeros(audio_shape)

audio_vae_length evaluates to float because self.transformer.config.sample_size returns a float

Co-authored-by: hlky <hlky@hlky.ac>

* [docs] Fix quantization links (#10323)

Update overview.md

* [Sana]add 2K related model for Sana (#10322)

add 2K related model for Sana

* Update src/diffusers/loaders/single_file_model.py

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* Update src/diffusers/loaders/single_file.py

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* make style

---------

Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Leojc <liao_junchao@outlook.com>
Co-authored-by: Aditya Raj <syntaxticsugr@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: Junsong Chen <cjs1020440147@icloud.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
  • Loading branch information
8 people authored Jan 10, 2025
1 parent a6f043a commit 52c05bd
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 5 deletions.
8 changes: 8 additions & 0 deletions src/diffusers/loaders/single_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def load_single_file_sub_model(
local_files_only=False,
torch_dtype=None,
is_legacy_loading=False,
disable_mmap=False,
**kwargs,
):
if is_pipeline_module:
Expand Down Expand Up @@ -106,6 +107,7 @@ def load_single_file_sub_model(
subfolder=name,
torch_dtype=torch_dtype,
local_files_only=local_files_only,
disable_mmap=disable_mmap,
**kwargs,
)

Expand Down Expand Up @@ -308,6 +310,9 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
hosted on the Hub.
- A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline
component configs in Diffusers format.
disable_mmap ('bool', *optional*, defaults to 'False'):
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
is on a network mount or hard drive.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
class). The overwritten components are passed directly to the pipelines `__init__` method. See example
Expand Down Expand Up @@ -355,6 +360,7 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
disable_mmap = kwargs.pop("disable_mmap", False)

is_legacy_loading = False

Expand Down Expand Up @@ -383,6 +389,7 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
disable_mmap=disable_mmap,
)

if config is None:
Expand Down Expand Up @@ -504,6 +511,7 @@ def load_module(name, value):
original_config=original_config,
local_files_only=local_files_only,
is_legacy_loading=is_legacy_loading,
disable_mmap=disable_mmap,
**kwargs,
)
except SingleFileComponentError as e:
Expand Down
5 changes: 5 additions & 0 deletions src/diffusers/loaders/single_file_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,9 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
disable_mmap ('bool', *optional*, defaults to 'False'):
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load and saveable variables (for example the pipeline components of the
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
Expand Down Expand Up @@ -234,6 +237,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
torch_dtype = kwargs.pop("torch_dtype", None)
quantization_config = kwargs.pop("quantization_config", None)
device = kwargs.pop("device", None)
disable_mmap = kwargs.pop("disable_mmap", False)

if isinstance(pretrained_model_link_or_path_or_dict, dict):
checkpoint = pretrained_model_link_or_path_or_dict
Expand All @@ -246,6 +250,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
disable_mmap=disable_mmap,
)
if quantization_config is not None:
hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
Expand Down
3 changes: 2 additions & 1 deletion src/diffusers/loaders/single_file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,7 @@ def load_single_file_checkpoint(
cache_dir=None,
local_files_only=None,
revision=None,
disable_mmap=False,
):
if os.path.isfile(pretrained_model_link_or_path):
pretrained_model_link_or_path = pretrained_model_link_or_path
Expand All @@ -404,7 +405,7 @@ def load_single_file_checkpoint(
revision=revision,
)

checkpoint = load_state_dict(pretrained_model_link_or_path)
checkpoint = load_state_dict(pretrained_model_link_or_path, disable_mmap=disable_mmap)

# some checkpoints contain the model state dict under a "state_dict" key
while "state_dict" in checkpoint:
Expand Down
9 changes: 7 additions & 2 deletions src/diffusers/models/model_loading_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,9 @@ def _fetch_remapped_cls_from_config(config, old_class):
return old_class


def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
def load_state_dict(
checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None, disable_mmap: bool = False
):
"""
Reads a checkpoint file, returning properly formatted errors if they arise.
"""
Expand All @@ -142,7 +144,10 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
try:
file_extension = os.path.basename(checkpoint_file).split(".")[-1]
if file_extension == SAFETENSORS_FILE_EXTENSION:
return safetensors.torch.load_file(checkpoint_file, device="cpu")
if disable_mmap:
return safetensors.torch.load(open(checkpoint_file, "rb").read())
else:
return safetensors.torch.load_file(checkpoint_file, device="cpu")
elif file_extension == GGUF_FILE_EXTENSION:
return load_gguf_checkpoint(checkpoint_file)
else:
Expand Down
8 changes: 6 additions & 2 deletions src/diffusers/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
`safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
weights. If set to `False`, `safetensors` weights are not loaded.
disable_mmap ('bool', *optional*, defaults to 'False'):
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
<Tip>
Expand Down Expand Up @@ -604,6 +607,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
variant = kwargs.pop("variant", None)
use_safetensors = kwargs.pop("use_safetensors", None)
quantization_config = kwargs.pop("quantization_config", None)
disable_mmap = kwargs.pop("disable_mmap", False)

allow_pickle = False
if use_safetensors is None:
Expand Down Expand Up @@ -883,7 +887,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
# TODO (sayakpaul, SunMarc): remove this after model loading refactor
else:
param_device = torch.device(torch.cuda.current_device())
state_dict = load_state_dict(model_file, variant=variant)
state_dict = load_state_dict(model_file, variant=variant, disable_mmap=disable_mmap)
model._convert_deprecated_attention_blocks(state_dict)

# move the params from meta device to cpu
Expand Down Expand Up @@ -979,7 +983,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
else:
model = cls.from_config(config, **unused_kwargs)

state_dict = load_state_dict(model_file, variant=variant)
state_dict = load_state_dict(model_file, variant=variant, disable_mmap=disable_mmap)
model._convert_deprecated_attention_blocks(state_dict)

model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
Expand Down

0 comments on commit 52c05bd

Please sign in to comment.