Load tensors directly on device #1028

Merged · 2 commits · Feb 7, 2023
1 change: 1 addition & 0 deletions src/accelerate/test_utils/__init__.py
@@ -6,6 +6,7 @@
require_huggingface_suite,
require_mps,
require_multi_gpu,
require_safetensors,
require_single_gpu,
require_torch_min_version,
require_tpu,
9 changes: 9 additions & 0 deletions src/accelerate/test_utils/testing.py
@@ -33,6 +33,7 @@
is_comet_ml_available,
is_datasets_available,
is_deepspeed_available,
is_safetensors_available,
is_tensorboard_available,
is_torch_version,
is_tpu_available,
@@ -128,6 +129,14 @@ def require_multi_gpu(test_case):
return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple GPUs")(test_case)


def require_safetensors(test_case):
"""
    Decorator marking a test that requires safetensors installed. These tests are skipped when safetensors isn't
    installed.
"""
return unittest.skipUnless(is_safetensors_available(), "test requires safetensors")(test_case)


def require_deepspeed(test_case):
"""
Decorator marking a test that requires DeepSpeed installed. These tests are skipped when DeepSpeed isn't installed
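
For reference, the new decorator composes like the existing `require_*` helpers. A minimal sketch of a test that opts in (the test class, file name, and tensor contents here are illustrative, not part of this PR):

import tempfile
import unittest

import torch

from accelerate.test_utils import require_safetensors


class CheckpointFormatTest(unittest.TestCase):
    @require_safetensors
    def test_safetensors_roundtrip(self):
        # The body only runs when `safetensors` is importable; otherwise the test is skipped.
        from safetensors.torch import load_file, save_file

        with tempfile.TemporaryDirectory() as tmp_dir:
            path = f"{tmp_dir}/weights.safetensors"
            save_file({"w": torch.ones(2, 2)}, path)
            self.assertIn("w", load_file(path))
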
1 change: 1 addition & 0 deletions src/accelerate/utils/__init__.py
@@ -51,6 +51,7 @@
infer_auto_device_map,
load_checkpoint_in_model,
load_offloaded_weights,
load_state_dict,
named_module_tensors,
retie_parameters,
set_module_tensor_to_device,
56 changes: 55 additions & 1 deletion src/accelerate/utils/modeling.py
@@ -24,9 +24,14 @@
import torch
import torch.nn as nn

from .imports import is_safetensors_available
from .offload import load_offloaded_weight, offload_weight, save_offload_index


if is_safetensors_available():
from safetensors import safe_open
from safetensors.torch import load_file as safe_load_file

WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"


@@ -643,6 +648,55 @@ def check_device_map(model: nn.Module, device_map: Dict[str, Union[int, str, torch.device]]):
)


def load_state_dict(checkpoint_file, device_map=None):
"""
Load a checkpoint from a given file. If the checkpoint is in the safetensors format and a device map is passed, the
weights can be fast-loaded directly on the GPU.

Args:
checkpoint_file (`str`): The path to the checkpoint to load.
device_map (`Dict[str, Union[int, str, torch.device]]`, *optional*):
            A map that specifies where each submodule should go. It doesn't need to be refined down to each
            parameter/buffer name; once a given module name is in the device map, every submodule of it will be sent to
            the same device.
"""
if checkpoint_file.endswith(".safetensors"):
if not is_safetensors_available():
raise ImportError(
f"To load {checkpoint_file}, the `safetensors` library is necessary `pip install safetensors`."
)
with safe_open(checkpoint_file, framework="pt") as f:
metadata = f.metadata()
weight_names = f.keys()
        if metadata is None or metadata.get("format") not in ["pt", "tf", "flax"]:
            raise OSError(
                f"The safetensors archive passed at {checkpoint_file} does not contain valid metadata. Make sure "
                "you save your model with the `save_pretrained` method."
            )
        elif metadata["format"] != "pt":
            raise ValueError(f"The checkpoint passed was saved with {metadata['format']}, we need the pt format.")
if device_map is None:
return safe_load_file(checkpoint_file)
else:
            # Devices weights will be loaded on; skip "disk", dedupe, and keep "cpu" as a fallback.
            devices = [device for device in dict.fromkeys(device_map.values()) if device != "disk"]
            if "cpu" not in devices:
                devices.append("cpu")

            # For each device, get the weights that go there (match the module itself or its
            # submodules, so that e.g. "fc1" does not also capture "fc10.weight").
            device_weights = {device: [] for device in devices}
            for module_name, device in device_map.items():
                if device in devices:
                    device_weights[device].extend(
                        [k for k in weight_names if k == module_name or k.startswith(module_name + ".")]
                    )
            # Weights that were not assigned to a device (e.g. those offloaded to disk) are loaded on CPU.
            assigned_weights = {k for weights in device_weights.values() for k in weights}
            device_weights["cpu"].extend([k for k in weight_names if k not in assigned_weights])

tensors = {}
for device in devices:
with safe_open(checkpoint_file, framework="pt", device=device) as f:
for key in device_weights[device]:
tensors[key] = f.get_tensor(key)

return tensors
else:
return torch.load(checkpoint_file)


def load_checkpoint_in_model(
model: nn.Module,
checkpoint: Union[str, os.PathLike],
@@ -737,7 +791,7 @@ def load_checkpoint_in_model(
buffer_names = [name for name, _ in model.named_buffers()]

for checkpoint_file in checkpoint_files:
-        checkpoint = torch.load(checkpoint_file)
+        checkpoint = load_state_dict(checkpoint_file, device_map=device_map)
if device_map is None:
model.load_state_dict(checkpoint, strict=False)
else:
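
To make the new `load_state_dict` behavior concrete, here is a small end-to-end sketch mirroring the test added below; it assumes a machine with at least one CUDA GPU, and the checkpoint path and module names are illustrative:

import torch
from safetensors.torch import save_file

from accelerate.utils import load_state_dict

# Save a toy checkpoint in the safetensors format ("pt" marks it as PyTorch tensors).
state_dict = {"encoder.weight": torch.randn(4, 4), "head.weight": torch.randn(4, 2)}
save_file(state_dict, "model.safetensors", metadata={"format": "pt"})

# With a device map, each tensor is materialized directly on its target device,
# skipping the usual CPU round-trip for GPU-bound weights.
loaded = load_state_dict("model.safetensors", device_map={"encoder": 0, "head": "cpu"})
print(loaded["encoder.weight"].device)  # cuda:0
print(loaded["head.weight"].device)  # cpu
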
20 changes: 19 additions & 1 deletion tests/test_modeling_utils.py
@@ -20,7 +20,7 @@
import torch
import torch.nn as nn

-from accelerate.test_utils import require_cuda, require_multi_gpu
+from accelerate.test_utils import require_cuda, require_multi_gpu, require_safetensors
from accelerate.test_utils.testing import require_torch_min_version
from accelerate.utils.modeling import (
check_device_map,
@@ -30,6 +30,7 @@
get_balanced_memory,
infer_auto_device_map,
load_checkpoint_in_model,
load_state_dict,
named_module_tensors,
set_module_tensor_to_device,
)
@@ -413,3 +414,20 @@ def test_get_balanced_memory(self):
# If we set a device to 0, it's not counted.
max_memory = get_balanced_memory(model, max_memory={0: 0, 1: 300, 2: 300})
self.assertDictEqual({0: 0, 1: 215, 2: 300}, max_memory)

@require_cuda
@require_safetensors
def test_load_state_dict(self):
from safetensors.torch import save_file

state_dict = {k: torch.randn(4, 5) for k in ["a", "b", "c"]}
device_map = {"a": "cpu", "b": 0, "c": "disk"}
with tempfile.TemporaryDirectory() as tmp_dir:
checkpoint_file = os.path.join(tmp_dir, "model.safetensors")
save_file(state_dict, checkpoint_file, metadata={"format": "pt"})

loaded_state_dict = load_state_dict(checkpoint_file, device_map=device_map)

self.assertEqual(loaded_state_dict["a"].device, torch.device("cpu"))
self.assertEqual(loaded_state_dict["b"].device, torch.device(0))
self.assertEqual(loaded_state_dict["c"].device, torch.device("cpu"))