Commit

[fix][FSDP] fix weight init when using apply() (fixes #490 and #444) (#543)

* Add new test for weight init (fails)
* Set FSDP.compute_device so summon_full_params works before module moves to CUDA
* Override FSDP.apply to enable custom weight init
myleott authored Mar 20, 2021
1 parent e386554 commit fa1b85f
Showing 4 changed files with 204 additions and 81 deletions.
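
Editorial note: the substance of this commit is that ``FullyShardedDataParallel.apply`` now gathers the full (unsharded) parameters before running the callback, and ``compute_device`` is set early enough that this works even before the module is moved to CUDA. Below is a minimal sketch of the usage pattern this enables; the model, init function, and process-group setup are illustrative and not taken from the commit.

import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

def init_weights(module: nn.Module) -> None:
    # Custom weight init; runs while the full (unsharded) params are gathered.
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)
        nn.init.zeros_(module.bias)

# Assumes torch.distributed.init_process_group(...) has already been called.
model = FSDP(nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8)))
model.apply(init_weights)  # the case fixed by this commit (see #490 and #444)
model = model.cuda()       # moving to GPU may happen after init, per this fix
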
133 changes: 95 additions & 38 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -9,7 +9,7 @@
import functools
from math import inf
import traceback
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, NamedTuple, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, NamedTuple, Optional, Set, Tuple, Union

import torch
from torch.autograd import Variable
@@ -150,6 +150,11 @@ class FullyShardedDataParallel(nn.Module):
based on world_size, so the max shard size is roughly
``bucket_cap_mb / world_size``. Values <= 0 disable bucketing.
Default: 25.
compute_device (torch.device, Optional):
device for computation. If not given and module params are on a CUDA
device, the param's device will be used. If not given and module
params are on CPU, then the current CUDA device (as indicated by
            ``torch.cuda.current_device()``) will be used.
"""

def __init__(
@@ -165,6 +170,7 @@ def __init__(
buffer_dtype: Optional[torch.dtype] = None,
move_grads_to_cpu: Optional[bool] = None,
bucket_cap_mb: int = 25,
compute_device: Optional[torch.device] = None,
):
super().__init__()
self.process_group = process_group or dist.new_group()
@@ -179,14 +185,21 @@ def __init__(
self.buffer_dtype = buffer_dtype or self.compute_dtype
self.move_grads_to_cpu = cpu_offload if move_grads_to_cpu is None else move_grads_to_cpu
self.bucket_cap_mb = bucket_cap_mb
self.compute_device = compute_device

if self.fp32_reduce_scatter and not self.mixed_precision:
raise ValueError("fp32_reduce_scatter requires mixed_precision=True")
if self.cpu_offload and not self.mixed_precision:
raise ValueError("cpu_offload requires mixed_precision=True")

compute_device = torch.device("cuda") if self.cpu_offload else next(module.parameters()).device
validate_process_group(compute_device, self.process_group)
if self.compute_device is None:
# Try to infer CUDA device from module parameters.
self.compute_device = next(module.parameters()).device
if self.compute_device.type != "cuda":
# Fall back to current CUDA device.
self.compute_device = torch.device("cuda")

validate_process_group(self.compute_device, self.process_group)
enable_pytorch_sync_bn(module)

# Only handle params which are not already sharded. This enables
@@ -239,11 +252,68 @@ def __init__(
def module(self) -> nn.Module:
return self._fsdp_wrapped_module # note: may be a FlattenParamsWrapper instance

@torch.no_grad()
def _all_buffers_to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
"""Move all buffers to the specified device and dtype, recursively."""
cast_fn = functools.partial(cast_buffers_, device=device, dtype=dtype)
self.apply(cast_fn)
def apply(self, fn: Callable[[nn.Module], None]) -> "FullyShardedDataParallel":
"""
Applies ``fn`` recursively to every submodule (as returned by
``.children()``) as well as self. Typical use includes initializing the
parameters of a model.
Compared to ``torch.nn.Module.apply``, this version additionally gathers
the full parameters before applying ``fn``. It should not be called from
within another ``summon_full_params`` context.
Args:
            fn (Callable[[nn.Module], None]): function to be applied to each submodule
Returns:
Module: self
"""
is_uninitialized = self._is_root is None
self.assert_state(TrainingState.IDLE)
with self.summon_full_params(recurse=False):
return_value = super().apply(fn)
# summon_full_params will call _lazy_init, which sets _is_root. However,
        # apply() may be called directly on child instances to do weight
# init, so we should reset the _is_root flag in this case.
if is_uninitialized and self._is_root:
for module in self.modules():
if isinstance(module, FullyShardedDataParallel):
module._reset_lazy_init()
return return_value

def _cast_buffers(
self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, memo: Optional[Set] = None
) -> None:
"""Move all buffers to the given *device* and *dtype*.
If *device* or *dtype* are not given, then they will default to
``self.compute_device`` and ``self.buffer_dtype``, respectively. In the
case of nested FSDP instances, we will respect the child instance's
``compute_device`` and ``buffer_dtype`` configuration.
Args:
device (torch.device, Optional):
device to cast buffers to (defaults to compute_device)
dtype (torch.dtype, Optional):
dtype to cast buffers to (defaults to buffer_dtype)
memo (Set, Optional):
set of modules that have already been processed
"""
if memo is None:
memo = set()
for module in self.modules():
if module is not self and isinstance(module, FullyShardedDataParallel):
# Allow any child FSDP instances to handle their own buffers.
module._cast_buffers(device=device, dtype=dtype, memo=memo)
elif module not in memo:
memo.add(module)
for name, buf in module.named_buffers(recurse=False):
if buf is None:
continue
buf = buf.to(device=device or self.compute_device)
if torch.is_floating_point(buf):
buf = buf.to(dtype=dtype or self.buffer_dtype)
setattr(module, name, buf)

@property
def params_with_grad(self) -> List[Parameter]:
@@ -386,7 +456,10 @@ def extra_repr(self) -> str:
f"flatten_parameters={self.flatten_parameters}, "
f"cpu_offload={self.cpu_offload}, "
f"compute_dtype={self.compute_dtype}, "
f"move_grads_to_cpu={self.move_grads_to_cpu}"
f"buffer_dtype={self.buffer_dtype}, "
f"move_grads_to_cpu={self.move_grads_to_cpu}, "
f"bucket_cap_mb={self.bucket_cap_mb}, "
f"compute_device={self.compute_device}"
)

def __getattr__(self, name: str) -> Any:
@@ -443,7 +516,7 @@ def state_dict(self, *args: Any, **kwargs: Any) -> "OrderedDict[str, torch.Tenso
self._lazy_init()
if self.mixed_precision:
# Buffers dtype stays consistent with parameters.
self._all_buffers_to(dtype=torch.float32)
self._cast_buffers(dtype=torch.float32)

if self._return_full_state_dict:
if self.training_state != TrainingState.SUMMON_FULL_PARAMS:
@@ -463,8 +536,8 @@ def state_dict(self, *args: Any, **kwargs: Any) -> "OrderedDict[str, torch.Tenso
state_dict[k] = state_dict[k].cpu()

if self.mixed_precision:
# In case we are in mixed precision, restore buffers back to fp16.
self._all_buffers_to(dtype=self.buffer_dtype)
# In case we are in mixed precision, restore buffers back to buffer_dtype.
self._cast_buffers()
return state_dict

# TODO (Min): figuring out how to do typing for this overloaded function.
@@ -572,7 +645,7 @@ def summon_full_params(self, recurse: bool = True, volatile: bool = False) -> Ge
recurse (bool, Optional): recursively summon all params for nested
FSDP instances (default: True)
volatile (bool, Optional): if ``True``, modifications to params are
not guaranteed persist after the context manager exists;
                not guaranteed to persist after the context manager exits;
enabling this can be slightly more efficient (default: False)
"""
if recurse:
@@ -625,6 +698,9 @@ def _reset_lazy_init(self) -> None:
self._queue_wait_for_post_backward_closure: Optional[Callable] = None
self._streams: Dict[str, torch.cuda.Stream] = {}
self._reducer: Optional[ReduceScatterBucketer] = None
for p in self.params:
if hasattr(p, "_fp32_shard"):
del p._fp32_shard # reset _init_param_attributes

def _lazy_init(self) -> None:
"""Initialization steps that should happen lazily, typically right
@@ -642,12 +718,11 @@ def _lazy_init(self) -> None:
self._set_is_root()
self._setup_streams()

if self.cpu_offload: # Buffers stay on GPU, and don't get sharded
self._all_buffers_to(device=torch.device("cuda"), dtype=self.buffer_dtype)
else:
self._all_buffers_to(dtype=self.buffer_dtype)

if self._is_root:
# Buffers stay on GPU, and don't get sharded. Since _cast_buffers
# applies recursively, we only call this from the root instance.
self._cast_buffers()

# Don't free the full params for the outer-most (root) instance,
# since those params will be needed immediately after for the
# backward pass.
@@ -684,10 +759,6 @@ def _init_param_attributes(self, p: Parameter) -> None:
if hasattr(p, "_fp32_shard"):
return

# Compute device defaults to CUDA when *cpu_offload* is enabled, or the
# param's current device otherwise (could be CPU).
compute_device = torch.device("cuda") if self.cpu_offload else p.device

# A single shard of the parameters in full precision.
p._fp32_shard = p.data

@@ -707,7 +778,7 @@ def _init_param_attributes(self, p: Parameter) -> None:
# the computation in the forward/backward pass. We resize the
# storage to size 0 at init (here) and re-materialize (by copying
# from _fp32_shard) as needed.
p._fp16_shard = torch.zeros_like(p._fp32_shard, device=compute_device, dtype=self.compute_dtype)
p._fp16_shard = torch.zeros_like(p._fp32_shard, device=self.compute_device, dtype=self.compute_dtype)
free_storage_(p._fp16_shard)
else:
p._fp16_shard = None # use _fp32_shard
@@ -720,7 +791,7 @@ def _init_param_attributes(self, p: Parameter) -> None:
# relevant computation.
if p._is_sharded:
p._full_param_padded = torch.zeros(
p.data.numel() * self.world_size, device=compute_device, dtype=self.compute_dtype
p.data.numel() * self.world_size, device=self.compute_device, dtype=self.compute_dtype
)
free_storage_(p._full_param_padded)

@@ -1290,20 +1361,6 @@ def fn(x: torch.Tensor) -> torch.Tensor:
return apply_to_tensors(fn, args), apply_to_tensors(fn, kwargs)


def cast_buffers_(
module: nn.Module, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None
) -> None:
"""Cast all of module.named_buffers to device and floating point buffers to dtype."""
# if buffers are already on the right device and/or dtype this is just python loop cost
assert dtype in {torch.float32, torch.float16} # assumes compute_dtype == float16
for key, buf in module.named_buffers(recurse=False):
if buf is not None:
buf = buf.to(device=device)
if torch.is_floating_point(buf):
buf = buf.to(dtype=dtype)
setattr(module, key, buf)


def free_storage_(data: torch.Tensor) -> None:
"""Free underlying storage of a Tensor."""
if data.storage().size() > 0:
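
The ``_reset_lazy_init`` bookkeeping in the new ``apply`` override matters when ``apply`` is called on a nested (child) FSDP instance before the root has run its own lazy init: ``summon_full_params`` triggers ``_lazy_init``, which would otherwise leave the child marked as root. A hedged sketch of that nested pattern, reusing ``init_weights`` and the imports from the earlier example (sizes and structure are illustrative):

# Nested wrapping: the inner FSDP instance can be initialized on its own.
inner = FSDP(nn.Linear(8, 8))
outer = FSDP(nn.Sequential(inner, nn.Linear(8, 8)))

# apply() on the child gathers only that child's full params; the override
# then resets the lazily derived _is_root flags so the real root is still
# detected correctly on the first forward pass.
inner.apply(init_weights)

# apply() from the root recurses into the child as well, with each FSDP
# instance summoning its own full params.
outer.apply(init_weights)
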
1 change: 1 addition & 0 deletions tests/ci_test_list_2.txt
@@ -33,3 +33,4 @@ tests/nn/pipe/test_deferred_batch_norm.py
tests/nn/pipe/test_dependency.py
tests/nn/pipe/test_stream.py
tests/experimental/nn/test_multiprocess_pipe.py
tests/nn/data_parallel/test_fsdp_apply.py
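
The test file registered on the line above, ``tests/nn/data_parallel/test_fsdp_apply.py``, is new in this commit and its diff is not shown on this page. As a purely hypothetical illustration of the kind of check such a test performs, not the actual file contents:

import torch
import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

def check_apply_weight_init(fsdp_model: FSDP) -> None:
    # Hypothetical check (assumes flatten_parameters=False so parameter
    # names survive wrapping): after FSDP.apply(init_fn), gathering the
    # full parameters should reflect the initializer's values.
    fsdp_model.apply(lambda m: nn.init.ones_(m.weight) if isinstance(m, nn.Linear) else None)
    with fsdp_model.summon_full_params():
        for name, param in fsdp_model.named_parameters():
            if name.endswith("weight"):
                assert torch.all(param == 1.0), f"{name} was not re-initialized"
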
86 changes: 43 additions & 43 deletions tests/nn/data_parallel/test_fsdp.py
@@ -77,6 +77,49 @@ def get_wrapped_model(group, cuda_first=False, config={}, **model_kwargs) -> Ful
model = FullyShardedDataParallel(TransformerWithSharedParams(group, **model_kwargs), group, **config).cuda()
return model

@classmethod
def _test_identical_outputs(
cls, model_init_fn, config, rank, group, num_steps=2, use_cuda=True, lr=0.01, ref_ddp_fn=None, norm_type=2,
):
if config.get("mixed_precision", False):
autocast = True
# Force the compute dtype to be torch.float32 so that we get
# identical results as PyTorch DDP when using autocast. Note that
# this will cause the all-gather to happen in FP32, which is slower
# than necessary in most cases.
config["compute_dtype"] = torch.float32
else:
autocast = False

# Establish reference behavior with PyTorch DDP (+ optionally autocast).
model = model_init_fn(group=group, wrapper_config=None).cuda()
if ref_ddp_fn is None:
model = nn.parallel.DistributedDataParallel(
model, device_ids=[rank], output_device=rank, process_group=group
)
else:
model = ref_ddp_fn(model, group)
ref_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr, norm_type=norm_type)
ref_state_dict = model.module.state_dict()
if config.get("cpu_offload", False):
for k in ref_state_dict.keys():
ref_state_dict[k] = ref_state_dict[k].cpu()

# Confirm we get the same behavior using FullyShardedDataParallel.
model = FullyShardedDataParallel(model_init_fn(group=group, wrapper_config=config), group, **config)
if use_cuda:
model = model.cuda()
else:
assert next(model.parameters()).device == torch.device("cpu")
shard_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr, norm_type=norm_type)
shard_state_dict = model.state_dict()

try:
torch.testing.assert_allclose(ref_loss, shard_loss)
assert objects_are_equal(ref_state_dict, shard_state_dict, raise_exception=True)
except (AssertionError, RuntimeError) as e:
raise Exception(f"FullyShardedDataParallel didn't match PyTorch DDP using config: {config}\n\n {e}")


class TestMixedPrecision(DistributedTest):
def test_all_fp32(self):
@@ -313,49 +356,6 @@ def test_mixture_of_experts_grad_clip_breaks(self):
def _dummy_ddp_fn(self, model, group):
return DummyDDP(model)

@classmethod
def _test_identical_outputs(
cls, model_init_fn, config, rank, group, num_steps=2, use_cuda=True, lr=0.01, ref_ddp_fn=None, norm_type=2,
):
if config.get("mixed_precision", False):
autocast = True
# Force the compute dtype to be torch.float32 so that we get
# identical results as PyTorch DDP when using autocast. Note that
# this will cause the all-gather to happen in FP32, which is slower
# than necessary in most cases.
config["compute_dtype"] = torch.float32
else:
autocast = False

# Establish reference behavior with PyTorch DDP (+ optionally autocast).
model = model_init_fn(group=group, wrapper_config=None).cuda()
if ref_ddp_fn is None:
model = nn.parallel.DistributedDataParallel(
model, device_ids=[rank], output_device=rank, process_group=group
)
else:
model = ref_ddp_fn(model, group)
ref_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr, norm_type=norm_type)
ref_state_dict = model.module.state_dict()
if config.get("cpu_offload", False):
for k in ref_state_dict.keys():
ref_state_dict[k] = ref_state_dict[k].cpu()

# Confirm we get the same behavior using FullyShardedDataParallel.
model = FullyShardedDataParallel(model_init_fn(group=group, wrapper_config=config), group, **config)
if use_cuda:
model = model.cuda()
else:
assert next(model.parameters()).device == torch.device("cpu")
shard_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr, norm_type=norm_type)
shard_state_dict = model.state_dict()

try:
torch.testing.assert_allclose(ref_loss, shard_loss)
assert objects_are_equal(ref_state_dict, shard_state_dict, raise_exception=True)
except (AssertionError, RuntimeError) as e:
raise Exception(f"FullyShardedDataParallel didn't match PyTorch DDP using config: {config}\n\n {e}")

@parameterized.expand([[1], [inf]], name_func=rename_test)
def test_clip_norm_transformer(self, norm_type):
config = {"mixed_precision": True}
(The diff for the fourth changed file, tests/nn/data_parallel/test_fsdp_apply.py, is not rendered on this page.)
