From fa1b85fbbe75f058b39f1bcf027de42e6ddbd487 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Sat, 20 Mar 2021 17:15:30 -0400 Subject: [PATCH] [fix][FSDP] fix weight init when using apply() (fixes #490 and #444) (#543) * Add new test for weight init (fails) * Set FSDP.compute_device so summon_full_params works before module moves to CUDA * Override FSDP.apply to enable custom weight init --- .../fully_sharded_data_parallel.py | 133 +++++++++++++----- tests/ci_test_list_2.txt | 1 + tests/nn/data_parallel/test_fsdp.py | 86 +++++------ tests/nn/data_parallel/test_fsdp_apply.py | 65 +++++++++ 4 files changed, 204 insertions(+), 81 deletions(-) create mode 100644 tests/nn/data_parallel/test_fsdp_apply.py diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py index ac3647b4a..5e1e5c0de 100644 --- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py +++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py @@ -9,7 +9,7 @@ import functools from math import inf import traceback -from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, NamedTuple, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, NamedTuple, Optional, Set, Tuple, Union import torch from torch.autograd import Variable @@ -150,6 +150,11 @@ class FullyShardedDataParallel(nn.Module): based on world_size, so the max shard size is roughly ``bucket_cap_mb / world_size``. Values <= 0 disable bucketing. Default: 25. + compute_device (torch.device, Optional): + device for computation. If not given and module params are on a CUDA + device, the param's device will be used. If not given and module + params are on CPU, then the current CUDA device (as indicated by + ``torch.cuda.current_device()`` will be used. """ def __init__( @@ -165,6 +170,7 @@ def __init__( buffer_dtype: Optional[torch.dtype] = None, move_grads_to_cpu: Optional[bool] = None, bucket_cap_mb: int = 25, + compute_device: Optional[torch.device] = None, ): super().__init__() self.process_group = process_group or dist.new_group() @@ -179,14 +185,21 @@ def __init__( self.buffer_dtype = buffer_dtype or self.compute_dtype self.move_grads_to_cpu = cpu_offload if move_grads_to_cpu is None else move_grads_to_cpu self.bucket_cap_mb = bucket_cap_mb + self.compute_device = compute_device if self.fp32_reduce_scatter and not self.mixed_precision: raise ValueError("fp32_reduce_scatter requires mixed_precision=True") if self.cpu_offload and not self.mixed_precision: raise ValueError("cpu_offload requires mixed_precision=True") - compute_device = torch.device("cuda") if self.cpu_offload else next(module.parameters()).device - validate_process_group(compute_device, self.process_group) + if self.compute_device is None: + # Try to infer CUDA device from module parameters. + self.compute_device = next(module.parameters()).device + if self.compute_device.type != "cuda": + # Fall back to current CUDA device. + self.compute_device = torch.device("cuda") + + validate_process_group(self.compute_device, self.process_group) enable_pytorch_sync_bn(module) # Only handle params which are not already sharded. 
This enables @@ -239,11 +252,68 @@ def __init__( def module(self) -> nn.Module: return self._fsdp_wrapped_module # note: may be a FlattenParamsWrapper instance - @torch.no_grad() - def _all_buffers_to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None: - """Move all buffers to the specified device and dtype, recursively.""" - cast_fn = functools.partial(cast_buffers_, device=device, dtype=dtype) - self.apply(cast_fn) + def apply(self, fn: Callable[[nn.Module], None]) -> "FullyShardedDataParallel": + """ + Applies ``fn`` recursively to every submodule (as returned by + ``.children()``) as well as self. Typical use includes initializing the + parameters of a model. + + Compared to ``torch.nn.Module.apply``, this version additionally gathers + the full parameters before applying ``fn``. It should not be called from + within another ``summon_full_params`` context. + + Args: + fn (nn.Module): function to be applied to each submodule + + Returns: + Module: self + """ + is_uninitialized = self._is_root is None + self.assert_state(TrainingState.IDLE) + with self.summon_full_params(recurse=False): + return_value = super().apply(fn) + # summon_full_params will call _lazy_init, which sets _is_root. However, + # apply() may be called directly on children instances to do weight + # init, so we should reset the _is_root flag in this case. + if is_uninitialized and self._is_root: + for module in self.modules(): + if isinstance(module, FullyShardedDataParallel): + module._reset_lazy_init() + return return_value + + def _cast_buffers( + self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, memo: Optional[Set] = None + ) -> None: + """Move all buffers to the given *device* and *dtype*. + + If *device* or *dtype* are not given, then they will default to + ``self.compute_device`` and ``self.buffer_dtype``, respectively. In the + case of nested FSDP instances, we will respect the child instance's + ``compute_device`` and ``buffer_dtype`` configuration. + + Args: + device (torch.device, Optional): + device to cast buffers to (defaults to compute_device) + dtype (torch.dtype, Optional): + dtype to cast buffers to (defaults to buffer_dtype) + memo (Set, Optional): + set of modules that have already been processed + """ + if memo is None: + memo = set() + for module in self.modules(): + if module is not self and isinstance(module, FullyShardedDataParallel): + # Allow any child FSDP instances to handle their own buffers. 
+ module._cast_buffers(device=device, dtype=dtype, memo=memo) + elif module not in memo: + memo.add(module) + for name, buf in module.named_buffers(recurse=False): + if buf is None: + continue + buf = buf.to(device=device or self.compute_device) + if torch.is_floating_point(buf): + buf = buf.to(dtype=dtype or self.buffer_dtype) + setattr(module, name, buf) @property def params_with_grad(self) -> List[Parameter]: @@ -386,7 +456,10 @@ def extra_repr(self) -> str: f"flatten_parameters={self.flatten_parameters}, " f"cpu_offload={self.cpu_offload}, " f"compute_dtype={self.compute_dtype}, " - f"move_grads_to_cpu={self.move_grads_to_cpu}" + f"buffer_dtype={self.buffer_dtype}, " + f"move_grads_to_cpu={self.move_grads_to_cpu}, " + f"bucket_cap_mb={self.bucket_cap_mb}, " + f"compute_device={self.compute_device}" ) def __getattr__(self, name: str) -> Any: @@ -443,7 +516,7 @@ def state_dict(self, *args: Any, **kwargs: Any) -> "OrderedDict[str, torch.Tenso self._lazy_init() if self.mixed_precision: # Buffers dtype stays consistent with parameters. - self._all_buffers_to(dtype=torch.float32) + self._cast_buffers(dtype=torch.float32) if self._return_full_state_dict: if self.training_state != TrainingState.SUMMON_FULL_PARAMS: @@ -463,8 +536,8 @@ def state_dict(self, *args: Any, **kwargs: Any) -> "OrderedDict[str, torch.Tenso state_dict[k] = state_dict[k].cpu() if self.mixed_precision: - # In case we are in mixed precision, restore buffers back to fp16. - self._all_buffers_to(dtype=self.buffer_dtype) + # In case we are in mixed precision, restore buffers back to buffer_dtype. + self._cast_buffers() return state_dict # TODO (Min): figuring out how to do typing for this overloaded function. @@ -572,7 +645,7 @@ def summon_full_params(self, recurse: bool = True, volatile: bool = False) -> Ge recurse (bool, Optional): recursively summon all params for nested FSDP instances (default: True) volatile (bool, Optional): if ``True``, modifications to params are - not guaranteed persist after the context manager exists; + not guaranteed to persist after the context manager exists; enabling this can be slightly more efficient (default: False) """ if recurse: @@ -625,6 +698,9 @@ def _reset_lazy_init(self) -> None: self._queue_wait_for_post_backward_closure: Optional[Callable] = None self._streams: Dict[str, torch.cuda.Stream] = {} self._reducer: Optional[ReduceScatterBucketer] = None + for p in self.params: + if hasattr(p, "_fp32_shard"): + del p._fp32_shard # reset _init_param_attributes def _lazy_init(self) -> None: """Initialization steps that should happen lazily, typically right @@ -642,12 +718,11 @@ def _lazy_init(self) -> None: self._set_is_root() self._setup_streams() - if self.cpu_offload: # Buffers stay on GPU, and don't get sharded - self._all_buffers_to(device=torch.device("cuda"), dtype=self.buffer_dtype) - else: - self._all_buffers_to(dtype=self.buffer_dtype) - if self._is_root: + # Buffers stay on GPU, and don't get sharded. Since _cast_buffers + # applies recursively, we only call this from the root instance. + self._cast_buffers() + # Don't free the full params for the outer-most (root) instance, # since those params will be needed immediately after for the # backward pass. @@ -684,10 +759,6 @@ def _init_param_attributes(self, p: Parameter) -> None: if hasattr(p, "_fp32_shard"): return - # Compute device defaults to CUDA when *cpu_offload* is enabled, or the - # param's current device otherwise (could be CPU). 
- compute_device = torch.device("cuda") if self.cpu_offload else p.device - # A single shard of the parameters in full precision. p._fp32_shard = p.data @@ -707,7 +778,7 @@ def _init_param_attributes(self, p: Parameter) -> None: # the computation in the forward/backward pass. We resize the # storage to size 0 at init (here) and re-materialize (by copying # from _fp32_shard) as needed. - p._fp16_shard = torch.zeros_like(p._fp32_shard, device=compute_device, dtype=self.compute_dtype) + p._fp16_shard = torch.zeros_like(p._fp32_shard, device=self.compute_device, dtype=self.compute_dtype) free_storage_(p._fp16_shard) else: p._fp16_shard = None # use _fp32_shard @@ -720,7 +791,7 @@ def _init_param_attributes(self, p: Parameter) -> None: # relevant computation. if p._is_sharded: p._full_param_padded = torch.zeros( - p.data.numel() * self.world_size, device=compute_device, dtype=self.compute_dtype + p.data.numel() * self.world_size, device=self.compute_device, dtype=self.compute_dtype ) free_storage_(p._full_param_padded) @@ -1290,20 +1361,6 @@ def fn(x: torch.Tensor) -> torch.Tensor: return apply_to_tensors(fn, args), apply_to_tensors(fn, kwargs) -def cast_buffers_( - module: nn.Module, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None -) -> None: - """Cast all of module.named_buffers to device and floating point buffers to dtype.""" - # if buffers are already on the right device and/or dtype this is just python loop cost - assert dtype in {torch.float32, torch.float16} # assumes compute_dtype == float16 - for key, buf in module.named_buffers(recurse=False): - if buf is not None: - buf = buf.to(device=device) - if torch.is_floating_point(buf): - buf = buf.to(dtype=dtype) - setattr(module, key, buf) - - def free_storage_(data: torch.Tensor) -> None: """Free underlying storage of a Tensor.""" if data.storage().size() > 0: diff --git a/tests/ci_test_list_2.txt b/tests/ci_test_list_2.txt index 25d0ee0b3..7aaf33400 100644 --- a/tests/ci_test_list_2.txt +++ b/tests/ci_test_list_2.txt @@ -33,3 +33,4 @@ tests/nn/pipe/test_deferred_batch_norm.py tests/nn/pipe/test_dependency.py tests/nn/pipe/test_stream.py tests/experimental/nn/test_multiprocess_pipe.py +tests/nn/data_parallel/test_fsdp_apply.py diff --git a/tests/nn/data_parallel/test_fsdp.py b/tests/nn/data_parallel/test_fsdp.py index 9813d1142..8e5082f8d 100644 --- a/tests/nn/data_parallel/test_fsdp.py +++ b/tests/nn/data_parallel/test_fsdp.py @@ -77,6 +77,49 @@ def get_wrapped_model(group, cuda_first=False, config={}, **model_kwargs) -> Ful model = FullyShardedDataParallel(TransformerWithSharedParams(group, **model_kwargs), group, **config).cuda() return model + @classmethod + def _test_identical_outputs( + cls, model_init_fn, config, rank, group, num_steps=2, use_cuda=True, lr=0.01, ref_ddp_fn=None, norm_type=2, + ): + if config.get("mixed_precision", False): + autocast = True + # Force the compute dtype to be torch.float32 so that we get + # identical results as PyTorch DDP when using autocast. Note that + # this will cause the all-gather to happen in FP32, which is slower + # than necessary in most cases. + config["compute_dtype"] = torch.float32 + else: + autocast = False + + # Establish reference behavior with PyTorch DDP (+ optionally autocast). 
+ model = model_init_fn(group=group, wrapper_config=None).cuda() + if ref_ddp_fn is None: + model = nn.parallel.DistributedDataParallel( + model, device_ids=[rank], output_device=rank, process_group=group + ) + else: + model = ref_ddp_fn(model, group) + ref_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr, norm_type=norm_type) + ref_state_dict = model.module.state_dict() + if config.get("cpu_offload", False): + for k in ref_state_dict.keys(): + ref_state_dict[k] = ref_state_dict[k].cpu() + + # Confirm we get the same behavior using FullyShardedDataParallel. + model = FullyShardedDataParallel(model_init_fn(group=group, wrapper_config=config), group, **config) + if use_cuda: + model = model.cuda() + else: + assert next(model.parameters()).device == torch.device("cpu") + shard_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr, norm_type=norm_type) + shard_state_dict = model.state_dict() + + try: + torch.testing.assert_allclose(ref_loss, shard_loss) + assert objects_are_equal(ref_state_dict, shard_state_dict, raise_exception=True) + except (AssertionError, RuntimeError) as e: + raise Exception(f"FullyShardedDataParallel didn't match PyTorch DDP using config: {config}\n\n {e}") + class TestMixedPrecision(DistributedTest): def test_all_fp32(self): @@ -313,49 +356,6 @@ def test_mixture_of_experts_grad_clip_breaks(self): def _dummy_ddp_fn(self, model, group): return DummyDDP(model) - @classmethod - def _test_identical_outputs( - cls, model_init_fn, config, rank, group, num_steps=2, use_cuda=True, lr=0.01, ref_ddp_fn=None, norm_type=2, - ): - if config.get("mixed_precision", False): - autocast = True - # Force the compute dtype to be torch.float32 so that we get - # identical results as PyTorch DDP when using autocast. Note that - # this will cause the all-gather to happen in FP32, which is slower - # than necessary in most cases. - config["compute_dtype"] = torch.float32 - else: - autocast = False - - # Establish reference behavior with PyTorch DDP (+ optionally autocast). - model = model_init_fn(group=group, wrapper_config=None).cuda() - if ref_ddp_fn is None: - model = nn.parallel.DistributedDataParallel( - model, device_ids=[rank], output_device=rank, process_group=group - ) - else: - model = ref_ddp_fn(model, group) - ref_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr, norm_type=norm_type) - ref_state_dict = model.module.state_dict() - if config.get("cpu_offload", False): - for k in ref_state_dict.keys(): - ref_state_dict[k] = ref_state_dict[k].cpu() - - # Confirm we get the same behavior using FullyShardedDataParallel. 
- model = FullyShardedDataParallel(model_init_fn(group=group, wrapper_config=config), group, **config) - if use_cuda: - model = model.cuda() - else: - assert next(model.parameters()).device == torch.device("cpu") - shard_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr, norm_type=norm_type) - shard_state_dict = model.state_dict() - - try: - torch.testing.assert_allclose(ref_loss, shard_loss) - assert objects_are_equal(ref_state_dict, shard_state_dict, raise_exception=True) - except (AssertionError, RuntimeError) as e: - raise Exception(f"FullyShardedDataParallel didn't match PyTorch DDP using config: {config}\n\n {e}") - @parameterized.expand([[1], [inf]], name_func=rename_test) def test_clip_norm_transformer(self, norm_type): config = {"mixed_precision": True} diff --git a/tests/nn/data_parallel/test_fsdp_apply.py b/tests/nn/data_parallel/test_fsdp_apply.py new file mode 100644 index 000000000..7d5c8ff98 --- /dev/null +++ b/tests/nn/data_parallel/test_fsdp_apply.py @@ -0,0 +1,65 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import functools +import unittest + +from parameterized import parameterized +import torch.nn as nn + +from .test_fsdp import ( + CONFIG_OPTIONS, + DistributedTest, + NestedWrappedModule, + TransformerWithSharedParams, + rename_test, + spawn_and_init, +) + + +class TestApply(DistributedTest): + @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test) + def test_transformer_weight_init(self, config): + model_init_fn = functools.partial(model_init_and_apply_custom_weight_init, TransformerWithSharedParams) + test_fn = functools.partial(self._test_identical_outputs, model_init_fn, config, lr=0.01) + spawn_and_init(test_fn) + + @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test) + def test_nested_wrapped_weight_init(self, config): + model_init_fn = functools.partial(model_init_and_apply_custom_weight_init, NestedWrappedModule) + test_fn = functools.partial(self._test_identical_outputs, model_init_fn, config, lr=0.01) + spawn_and_init(test_fn) + + +def model_init_and_apply_custom_weight_init(model_init_fn, *args, **kwargs): + model = model_init_fn(*args, **kwargs) + model.apply(init_bert_params_) + return model + + +def init_bert_params_(module): + """ + Initialize the weights specific to the BERT Model. + """ + + def normal_(data): + # with FSDP, module params will be on CUDA, so we cast them back to CPU + # so that the RNG is consistent with and without FSDP + data.copy_(data.cpu().normal_(mean=0.0, std=0.02)) + + if isinstance(module, nn.Linear): + normal_(module.weight.data) + if module.bias is not None: + module.bias.data.zero_() + if isinstance(module, nn.Embedding): + normal_(module.weight.data) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + if isinstance(module, nn.MultiheadAttention): + normal_(module.in_proj_weight.data) + + +if __name__ == "__main__": + unittest.main()
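Usage sketch: with this change, custom weight initialization can go through ``FullyShardedDataParallel.apply``, which gathers the full (unsharded) parameters via ``summon_full_params`` before calling the user function, and which now works even before the wrapped module has been moved to CUDA thanks to the ``compute_device`` fallback. The example below is a minimal illustration, assuming a fairscale build that includes this patch and an already-initialized default process group (one rank per GPU); the small ``nn.Sequential`` model and the ``init_weights_`` helper are illustrative only.

import torch
import torch.distributed as dist
import torch.nn as nn

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP


def init_weights_(module: nn.Module) -> None:
    # Runs with the full (unsharded) parameters materialized, because
    # FSDP.apply() wraps the call in summon_full_params().
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)


# Assumes dist.init_process_group(...) has already been called, e.g. with the
# NCCL backend and one rank per GPU; the model itself is a toy placeholder.
model = FSDP(nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8)))

# Custom weight init through apply(). This works even while the wrapped
# module's parameters are still on CPU, since compute_device falls back to
# the current CUDA device when it cannot be inferred from the parameters.
model.apply(init_weights_)

Modifications made inside ``apply`` persist, since ``summon_full_params`` is entered with its default ``volatile=False``.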