Load tensors directly on device (#1028)
* Load tensors directly on device

* Update src/accelerate/utils/modeling.py

Co-authored-by: Zachary Mueller <muellerzr@gmail.com>

---------

Co-authored-by: Zachary Mueller <muellerzr@gmail.com>
sgugger and muellerzr authored Feb 7, 2023
1 parent 5002e56 commit 978dfc3
Showing 5 changed files with 85 additions and 2 deletions.
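
In short, checkpoints saved in the safetensors format can now be materialized straight onto their target devices instead of passing through CPU first. A minimal usage sketch of the new helper — the checkpoint path and module names here are hypothetical, and a CUDA device is assumed:

from accelerate.utils import load_state_dict

# Hypothetical module-level placement; keys name submodules, not individual parameters.
device_map = {"encoder": 0, "decoder": "cpu"}
state_dict = load_state_dict("model.safetensors", device_map=device_map)
# Tensors under "encoder" arrive already on cuda:0; "decoder" tensors stay on CPU.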
1 change: 1 addition & 0 deletions src/accelerate/test_utils/__init__.py
@@ -6,6 +6,7 @@
    require_huggingface_suite,
    require_mps,
    require_multi_gpu,
    require_safetensors,
    require_single_gpu,
    require_torch_min_version,
    require_tpu,
9 changes: 9 additions & 0 deletions src/accelerate/test_utils/testing.py
@@ -33,6 +33,7 @@
    is_comet_ml_available,
    is_datasets_available,
    is_deepspeed_available,
    is_safetensors_available,
    is_tensorboard_available,
    is_torch_version,
    is_tpu_available,
@@ -128,6 +129,14 @@ def require_multi_gpu(test_case):
    return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple GPUs")(test_case)


def require_safetensors(test_case):
    """
    Decorator marking a test that requires safetensors to be installed. These tests are skipped when safetensors
    isn't installed.
    """
    return unittest.skipUnless(is_safetensors_available(), "test requires safetensors")(test_case)


def require_deepspeed(test_case):
    """
    Decorator marking a test that requires DeepSpeed installed. These tests are skipped when DeepSpeed isn't installed
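
A sketch of how the new decorator composes with a regular unittest case; the test class and body below are illustrative only:

import os
import tempfile
import unittest

import torch

from accelerate.test_utils import require_safetensors


class ExampleTest(unittest.TestCase):
    @require_safetensors
    def test_safetensors_roundtrip(self):
        # Skipped automatically when safetensors isn't installed.
        from safetensors.torch import load_file, save_file

        with tempfile.TemporaryDirectory() as tmp_dir:
            path = os.path.join(tmp_dir, "tiny.safetensors")
            save_file({"w": torch.zeros(2, 2)}, path)
            self.assertIn("w", load_file(path))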
1 change: 1 addition & 0 deletions src/accelerate/utils/__init__.py
@@ -51,6 +51,7 @@
    infer_auto_device_map,
    load_checkpoint_in_model,
    load_offloaded_weights,
    load_state_dict,
    named_module_tensors,
    retie_parameters,
    set_module_tensor_to_device,
56 changes: 55 additions & 1 deletion src/accelerate/utils/modeling.py
@@ -24,9 +24,14 @@
import torch
import torch.nn as nn

from .imports import is_safetensors_available
from .offload import load_offloaded_weight, offload_weight, save_offload_index


if is_safetensors_available():
    from safetensors import safe_open
    from safetensors.torch import load_file as safe_load_file

WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
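
For context, the direct-on-device loading below builds on safe_open's device argument, which materializes each tensor on the requested device without a CPU round-trip. A minimal sketch, assuming safetensors is installed, a GPU is available, and a local model.safetensors file exists:

from safetensors import safe_open

tensors = {}
with safe_open("model.safetensors", framework="pt", device="cuda:0") as f:
    for key in f.keys():
        tensors[key] = f.get_tensor(key)  # each tensor is created directly on cuda:0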


@@ -643,6 +648,55 @@ def check_device_map(model: nn.Module, device_map: Dict[str, Union[int, str, tor
)


def load_state_dict(checkpoint_file, device_map=None):
    """
    Load a checkpoint from a given file. If the checkpoint is in the safetensors format and a device map is passed,
    the weights can be loaded directly on the target devices.

    Args:
        checkpoint_file (`str`): The path to the checkpoint to load.
        device_map (`Dict[str, Union[int, str, torch.device]]`, *optional*):
            A map that specifies where each submodule should go. It doesn't need to be refined to each
            parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
            same device.
    """
    if checkpoint_file.endswith(".safetensors"):
        if not is_safetensors_available():
            raise ImportError(
                f"To load {checkpoint_file}, the `safetensors` library is necessary: `pip install safetensors`."
            )
        with safe_open(checkpoint_file, framework="pt") as f:
            metadata = f.metadata()
            weight_names = f.keys()
        if metadata is None or metadata.get("format") not in ["pt", "tf", "flax"]:
            raise OSError(
                f"The safetensors archive passed at {checkpoint_file} does not contain valid metadata. Make sure "
                "you save your model with the `save_pretrained` method."
            )
        elif metadata["format"] != "pt":
            raise ValueError(f"The checkpoint passed was saved with {metadata['format']}, we need the pt format.")
        if device_map is None:
            return safe_load_file(checkpoint_file)
        else:
            # Collect the target devices, leaving disk-offloaded weights aside; CPU always needs to exist as
            # the fallback device for those.
            devices = list({device for device in device_map.values() if device != "disk"})
            if "cpu" not in devices:
                devices.append("cpu")

            # For each device, get the weights that go there
            device_weights = {device: [] for device in devices}
            for module_name, device in device_map.items():
                if device in devices:
                    device_weights[device].extend(
                        [k for k in weight_names if k == module_name or k.startswith(module_name + ".")]
                    )
            # Weights that were not assigned to a device (e.g. offloaded to disk) are loaded on CPU for now.
            assigned_weights = {k for weights in device_weights.values() for k in weights}
            device_weights["cpu"].extend([k for k in weight_names if k not in assigned_weights])

            tensors = {}
            for device in devices:
                with safe_open(checkpoint_file, framework="pt", device=device) as f:
                    for key in device_weights[device]:
                        tensors[key] = f.get_tensor(key)

            return tensors
    else:
        return torch.load(checkpoint_file)


def load_checkpoint_in_model(
    model: nn.Module,
    checkpoint: Union[str, os.PathLike],
@@ -737,7 +791,7 @@ def load_checkpoint_in_model(
    buffer_names = [name for name, _ in model.named_buffers()]

    for checkpoint_file in checkpoint_files:
        checkpoint = torch.load(checkpoint_file)
        checkpoint = load_state_dict(checkpoint_file, device_map=device_map)
        if device_map is None:
            model.load_state_dict(checkpoint, strict=False)
        else:
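
To make the grouping step in load_state_dict concrete, here is roughly what the intermediate device_weights mapping works out to for the small checkpoint used in the test below (flat keys "a", "b", "c"; "c" offloaded to disk falls back to CPU):

weight_names = ["a", "b", "c"]
device_map = {"a": "cpu", "b": 0, "c": "disk"}

devices = list({device for device in device_map.values() if device != "disk"})
if "cpu" not in devices:
    devices.append("cpu")
device_weights = {device: [] for device in devices}
for module_name, device in device_map.items():
    if device in devices:
        device_weights[device].extend(
            [k for k in weight_names if k == module_name or k.startswith(module_name + ".")]
        )
assigned = {k for weights in device_weights.values() for k in weights}
device_weights["cpu"].extend([k for k in weight_names if k not in assigned])
# device_weights is now {"cpu": ["a", "c"], 0: ["b"]} (up to dict ordering):
# "b" will be fetched straight onto GPU 0, while "a" and the disk-offloaded "c" load on CPU.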
20 changes: 19 additions & 1 deletion tests/test_modeling_utils.py
@@ -20,7 +20,7 @@
import torch
import torch.nn as nn

from accelerate.test_utils import require_cuda, require_multi_gpu
from accelerate.test_utils import require_cuda, require_multi_gpu, require_safetensors
from accelerate.test_utils.testing import require_torch_min_version
from accelerate.utils.modeling import (
    check_device_map,
@@ -30,6 +30,7 @@
    get_balanced_memory,
    infer_auto_device_map,
    load_checkpoint_in_model,
    load_state_dict,
    named_module_tensors,
    set_module_tensor_to_device,
)
@@ -413,3 +414,20 @@ def test_get_balanced_memory(self):
        # If we set a device to 0, it's not counted.
        max_memory = get_balanced_memory(model, max_memory={0: 0, 1: 300, 2: 300})
        self.assertDictEqual({0: 0, 1: 215, 2: 300}, max_memory)

    @require_cuda
    @require_safetensors
    def test_load_state_dict(self):
        from safetensors.torch import save_file

        state_dict = {k: torch.randn(4, 5) for k in ["a", "b", "c"]}
        device_map = {"a": "cpu", "b": 0, "c": "disk"}
        with tempfile.TemporaryDirectory() as tmp_dir:
            checkpoint_file = os.path.join(tmp_dir, "model.safetensors")
            save_file(state_dict, checkpoint_file, metadata={"format": "pt"})

            loaded_state_dict = load_state_dict(checkpoint_file, device_map=device_map)

        self.assertEqual(loaded_state_dict["a"].device, torch.device("cpu"))
        self.assertEqual(loaded_state_dict["b"].device, torch.device(0))
        self.assertEqual(loaded_state_dict["c"].device, torch.device("cpu"))
