diff --git a/modeling_utils.py b/modeling_utils.py
index f4697636719e..9e05672bf163 100644
--- a/modeling_utils.py
+++ b/modeling_utils.py
@@ -35,6 +35,12 @@
 
 logger = logging.get_logger(__name__)
 
+if is_torch_version(">=", "1.9.0"):
+    _LOW_CPU_MEM_USAGE_DEFAULT = True
+else:
+    _LOW_CPU_MEM_USAGE_DEFAULT = False
+
+
 def get_parameter_device(parameter: torch.nn.Module):
     try:
         return next(parameter.parameters()).device
@@ -278,11 +284,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
                 more information about each option see [designing a device
                 map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
-            fast_load (`bool`, *optional*, defaults to `True`):
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
                 Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
                 also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
                 model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
-                this argument will be ignored and the model will be loaded normally.
+                setting this argument to `True` will raise an error.
@@ -311,16 +317,26 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         torch_dtype = kwargs.pop("torch_dtype", None)
         subfolder = kwargs.pop("subfolder", None)
         device_map = kwargs.pop("device_map", None)
-        fast_load = kwargs.pop("fast_load", True)
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
 
         # Check if we can handle device_map and dispatching the weights
         if device_map is not None and not is_torch_version(">=", "1.9.0"):
-            raise NotImplementedError("Loading and dispatching requires torch >= 1.9.0")
+            raise NotImplementedError(
+                "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `device_map=None`."
+            )
 
-        # Fast init is only possible if torch version is >= 1.9.0
-        _INIT_EMPTY_WEIGHTS = fast_load or device_map is not None
-        if _INIT_EMPTY_WEIGHTS and not is_torch_version(">=", "1.9.0"):
-            logger.warn("Loading with `fast_load` requires torch >= 1.9.0. Falling back to normal loading.")
+        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `low_cpu_mem_usage=False`."
+            )
+
+        if low_cpu_mem_usage is False and device_map is not None:
+            raise ValueError(
+                f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
+                " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
+            )
 
         user_agent = {
             "diffusers": __version__,
@@ -403,7 +419,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
 
         # restore default dtype
 
-        if _INIT_EMPTY_WEIGHTS:
+        if low_cpu_mem_usage:
             # Instantiate model with empty weights
             with accelerate.init_empty_weights():
                 model, unused_kwargs = cls.from_config(
diff --git a/pipeline_utils.py b/pipeline_utils.py
index 5c248ec1a956..36c2d5b888ef 100644
--- a/pipeline_utils.py
+++ b/pipeline_utils.py
@@ -25,6 +25,7 @@
 
 import diffusers
 import PIL
+from accelerate.utils.versions import is_torch_version
 from huggingface_hub import snapshot_download
 from packaging import version
 from PIL import Image
@@ -33,6 +34,7 @@
 from .configuration_utils import ConfigMixin
 from .dynamic_modules_utils import get_class_from_dynamic_module
 from .hub_utils import http_user_agent
+from .modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
 from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
 from .utils import (
     CONFIG_NAME,
@@ -328,6 +330,19 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 Mirror source to accelerate downloads in China. If you are from China and have an accessibility
                 problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
                 Please refer to the mirror site for more information. specify the folder name here.
+            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+                A map that specifies where each submodule should go. It doesn't need to be refined to each
+                parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
+                same device.
+
+                To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
+                more information about each option see [designing a device
+                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
+                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
+                model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
+                setting this argument to `True` will raise an error.
             kwargs (remaining dictionary of keyword arguments, *optional*):
                 Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
@@ -380,7 +395,25 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         provider = kwargs.pop("provider", None)
         sess_options = kwargs.pop("sess_options", None)
         device_map = kwargs.pop("device_map", None)
-        fast_load = kwargs.pop("fast_load", True)
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
+
+        if device_map is not None and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `device_map=None`."
+            )
+
+        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `low_cpu_mem_usage=False`."
+            )
+
+        if low_cpu_mem_usage is False and device_map is not None:
+            raise ValueError(
+                f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
+                " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
+            )
 
         # 1. Download the checkpoints and configs
         # use snapshot download here to get it working from from_pretrained
@@ -573,17 +606,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 and version.parse(version.parse(transformers.__version__).base_version) >= version.parse("4.20.0")
             )
 
-            if is_diffusers_model:
-                loading_kwargs["fast_load"] = fast_load
-
             # When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers.
-            # To make default loading faster we set the `low_cpu_mem_usage=fast_load` flag which is `True` by default.
+            # To make default loading faster we forward the `low_cpu_mem_usage` flag, which defaults to `True` for torch >= 1.9.0.
             # This makes sure that the weights won't be initialized which significantly speeds up loading.
-            if is_transformers_model and device_map is None:
-                loading_kwargs["low_cpu_mem_usage"] = fast_load
-
             if is_diffusers_model or is_transformers_model:
                 loading_kwargs["device_map"] = device_map
+                loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
 
             # check if the module is in a subdirectory
             if os.path.isdir(os.path.join(cached_folder, name)):
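
For context, a minimal usage sketch of the renamed flag, not part of the diff itself. It assumes the usual public entry points backed by these files (`DiffusionPipeline` and `UNet2DConditionModel`); the checkpoint id is purely illustrative:

```python
from diffusers import DiffusionPipeline, UNet2DConditionModel

# Default on torch >= 1.9.0: low_cpu_mem_usage=True, so modules are created
# with empty weights and the checkpoint is loaded in-place, keeping peak CPU
# memory close to 1x the model size.
pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")

# Opting out restores normal (slower) weight initialization. On torch < 1.9.0
# the default is False, and passing True raises NotImplementedError instead
# of silently falling back as the old `fast_load` did.
unet = UNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="unet", low_cpu_mem_usage=False
)

# device_map dispatching implies low_cpu_mem_usage=True (the default);
# combining device_map="auto" with low_cpu_mem_usage=False raises ValueError.
pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", device_map="auto")
```

Note that the pipeline loader now forwards `device_map` and `low_cpu_mem_usage` uniformly to every diffusers and transformers submodel, instead of special-casing transformers models.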