[fix][FSDP] fix weight init when using apply() (fixes #490 and #444) #543

Merged
merged 6 commits on Mar 20, 2021
Changes from 3 commits
98 changes: 67 additions & 31 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -150,6 +150,11 @@ class FullyShardedDataParallel(nn.Module):
based on world_size, so the max shard size is roughly
``bucket_cap_mb / world_size``. Values <= 0 disable bucketing.
Default: 25.
compute_device (torch.device, Optional):
device for computation. If not given and module params are on a CUDA
device, the param's device will be used. If not given and module
params are on CPU, then the current CUDA device (as indicated by
``torch.cuda.current_device()``) will be used.
"""

def __init__(
@@ -165,6 +170,7 @@ def __init__(
buffer_dtype: Optional[torch.dtype] = None,
move_grads_to_cpu: Optional[bool] = None,
bucket_cap_mb: int = 25,
compute_device: Optional[torch.device] = None,
):
super().__init__()
self.process_group = process_group or dist.new_group()
@@ -179,14 +185,21 @@ def __init__(
self.buffer_dtype = buffer_dtype or self.compute_dtype
self.move_grads_to_cpu = cpu_offload if move_grads_to_cpu is None else move_grads_to_cpu
self.bucket_cap_mb = bucket_cap_mb
self.compute_device = compute_device

if self.fp32_reduce_scatter and not self.mixed_precision:
raise ValueError("fp32_reduce_scatter requires mixed_precision=True")
if self.cpu_offload and not self.mixed_precision:
raise ValueError("cpu_offload requires mixed_precision=True")

compute_device = torch.device("cuda") if self.cpu_offload else next(module.parameters()).device
validate_process_group(compute_device, self.process_group)
if self.compute_device is None:
# Try to infer CUDA device from module parameters.
self.compute_device = next(module.parameters()).device
if self.compute_device.type != "cuda":
# Fall back to current CUDA device.
self.compute_device = torch.device("cuda")

validate_process_group(self.compute_device, self.process_group)
enable_pytorch_sync_bn(module)

# Only handle params which are not already sharded. This enables
@@ -239,11 +252,44 @@ def __init__(
def module(self) -> nn.Module:
return self._fsdp_wrapped_module # note: may be a FlattenParamsWrapper instance

def apply(self, fn: Callable[[nn.Module], None]) -> "FullyShardedDataParallel":
"""
Applies ``fn`` recursively to every submodule (as returned by
``.children()``) as well as self. Typical use includes initializing the
parameters of a model.

Compared to ``torch.nn.Module.apply``, this version additionally gathers
the full parameters before applying ``fn``.

Args:
fn (Callable[[nn.Module], None]): function to be applied to each submodule

Returns:
Module: self
"""
is_uninitialized = self._is_root is None
with self.summon_full_params(recurse=False):
return_value = super().apply(fn)
# summon_full_params will call _lazy_init, which sets _is_root. However,
# apply() may be called directly on child instances to do weight
# init, so we should reset the _is_root flag in this case.
if is_uninitialized and self._is_root:
for module in self.modules():
if isinstance(module, FullyShardedDataParallel):
module._reset_lazy_init()
return return_value
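
A short, hedged sketch of the workflow this method is meant to support; the initializer below is illustrative (the PR's own tests use a BERT-style initializer), and the snippet assumes ``torch.distributed`` is initialized and CUDA is available:

```python
import torch.nn as nn

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

def init_weights_(module: nn.Module) -> None:
    # Runs while the full (unsharded) parameters are summoned, so it sees the
    # same tensor shapes it would see on the unwrapped model.
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)

fsdp_model = FSDP(nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4)))
# apply() gathers the full parameters, calls init_weights_ on every submodule
# (including nested FSDP children), then re-shards.
fsdp_model.apply(init_weights_)
```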

@torch.no_grad()
def _all_buffers_to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
"""Move all buffers to the specified device and dtype, recursively."""
cast_fn = functools.partial(cast_buffers_, device=device, dtype=dtype)
self.apply(cast_fn)
for module in self.modules():
for name, buf in module.named_buffers(recurse=False):
if buf is None:
continue
buf = buf.to(device=device)
if torch.is_floating_point(buf):
buf = buf.to(dtype=dtype)
setattr(module, name, buf)

@property
def params_with_grad(self) -> List[Parameter]:
@@ -386,7 +432,10 @@ def extra_repr(self) -> str:
f"flatten_parameters={self.flatten_parameters}, "
f"cpu_offload={self.cpu_offload}, "
f"compute_dtype={self.compute_dtype}, "
f"move_grads_to_cpu={self.move_grads_to_cpu}"
f"buffer_dtype={self.buffer_dtype}, "
f"move_grads_to_cpu={self.move_grads_to_cpu}, "
f"bucket_cap_mb={self.bucket_cap_mb}, "
f"compute_device={self.compute_device}"
)

def __getattr__(self, name: str) -> Any:
@@ -572,7 +621,7 @@ def summon_full_params(self, recurse: bool = True, volatile: bool = False) -> Ge
recurse (bool, Optional): recursively summon all params for nested
FSDP instances (default: True)
volatile (bool, Optional): if ``True``, modifications to params are
not guaranteed persist after the context manager exists;
not guaranteed to persist after the context manager exits;
enabling this can be slightly more efficient (default: False)
"""
if recurse:
@@ -625,6 +674,9 @@ def _reset_lazy_init(self) -> None:
self._queue_wait_for_post_backward_closure: Optional[Callable] = None
self._streams: Dict[str, torch.cuda.Stream] = {}
self._reducer: Optional[ReduceScatterBucketer] = None
for p in self.params:
if hasattr(p, "_fp32_shard"):
del p._fp32_shard # reset _init_param_attributes

def _lazy_init(self) -> None:
"""Initialization steps that should happen lazily, typically right
@@ -642,12 +694,14 @@ def _lazy_init(self) -> None:
self._set_is_root()
self._setup_streams()

if self.cpu_offload: # Buffers stay on GPU, and don't get sharded
self._all_buffers_to(device=torch.device("cuda"), dtype=self.buffer_dtype)
else:
self._all_buffers_to(dtype=self.buffer_dtype)

if self._is_root:
# Buffers stay on GPU, and don't get sharded. Since _all_buffers_to
# applies recursively, we only call this from the root instance.
if self.cpu_offload:
self._all_buffers_to(device=self.compute_device, dtype=self.buffer_dtype)
else:
self._all_buffers_to(dtype=self.buffer_dtype)

# Don't free the full params for the outer-most (root) instance,
# since those params will be needed immediately after for the
# backward pass.
@@ -684,10 +738,6 @@ def _init_param_attributes(self, p: Parameter) -> None:
if hasattr(p, "_fp32_shard"):
return

# Compute device defaults to CUDA when *cpu_offload* is enabled, or the
# param's current device otherwise (could be CPU).
compute_device = torch.device("cuda") if self.cpu_offload else p.device

# A single shard of the parameters in full precision.
p._fp32_shard = p.data

@@ -707,7 +757,7 @@ def _init_param_attributes(self, p: Parameter) -> None:
# the computation in the forward/backward pass. We resize the
# storage to size 0 at init (here) and re-materialize (by copying
# from _fp32_shard) as needed.
p._fp16_shard = torch.zeros_like(p._fp32_shard, device=compute_device, dtype=self.compute_dtype)
p._fp16_shard = torch.zeros_like(p._fp32_shard, device=self.compute_device, dtype=self.compute_dtype)
free_storage_(p._fp16_shard)
else:
p._fp16_shard = None # use _fp32_shard
@@ -720,7 +770,7 @@ def _init_param_attributes(self, p: Parameter) -> None:
# relevant computation.
if p._is_sharded:
p._full_param_padded = torch.zeros(
p.data.numel() * self.world_size, device=compute_device, dtype=self.compute_dtype
p.data.numel() * self.world_size, device=self.compute_device, dtype=self.compute_dtype
)
free_storage_(p._full_param_padded)
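
For clarity, a standalone sketch of the resize-to-zero / re-materialize pattern described in the comments above. The helpers mirror fairscale's ``free_storage_`` but are illustrative, and assume the typed-storage semantics of PyTorch versions current at the time of this PR (``storage().resize_`` takes an element count):

```python
import torch

def free_storage_(t: torch.Tensor) -> None:
    # Drop the underlying storage; the tensor keeps its shape metadata but
    # holds no memory until the storage is resized again.
    if t.storage().size() > 0:
        assert t.storage_offset() == 0
        t.storage().resize_(0)

def materialize_(t: torch.Tensor, src: torch.Tensor) -> None:
    # Re-allocate the storage to match the tensor's shape, then copy data in.
    t.storage().resize_(t.numel())
    t.copy_(src)

buf = torch.zeros(1024)
free_storage_(buf)                     # storage now holds 0 elements
materialize_(buf, torch.randn(1024))   # storage re-allocated on demand
```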

@@ -1290,20 +1340,6 @@ def fn(x: torch.Tensor) -> torch.Tensor:
return apply_to_tensors(fn, args), apply_to_tensors(fn, kwargs)


def cast_buffers_(
module: nn.Module, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None
) -> None:
"""Cast all of module.named_buffers to device and floating point buffers to dtype."""
# if buffers are already on the right device and/or dtype this is just python loop cost
assert dtype in {torch.float32, torch.float16} # assumes compute_dtype == float16
for key, buf in module.named_buffers(recurse=False):
if buf is not None:
buf = buf.to(device=device)
if torch.is_floating_point(buf):
buf = buf.to(dtype=dtype)
setattr(module, key, buf)


def free_storage_(data: torch.Tensor) -> None:
"""Free underlying storage of a Tensor."""
if data.storage().size() > 0:
1 change: 1 addition & 0 deletions tests/ci_test_list_2.txt
@@ -33,3 +33,4 @@ tests/nn/pipe/test_deferred_batch_norm.py
tests/nn/pipe/test_dependency.py
tests/nn/pipe/test_stream.py
tests/experimental/nn/test_multiprocess_pipe.py
tests/nn/data_parallel/test_fsdp_apply.py
86 changes: 43 additions & 43 deletions tests/nn/data_parallel/test_fsdp.py
@@ -77,6 +77,49 @@ def get_wrapped_model(group, cuda_first=False, config={}, **model_kwargs) -> Ful
model = FullyShardedDataParallel(TransformerWithSharedParams(group, **model_kwargs), group, **config).cuda()
return model

@classmethod
def _test_identical_outputs(
cls, model_init_fn, config, rank, group, num_steps=2, use_cuda=True, lr=0.01, ref_ddp_fn=None, norm_type=2,
):
if config.get("mixed_precision", False):
autocast = True
# Force the compute dtype to be torch.float32 so that we get
# identical results as PyTorch DDP when using autocast. Note that
# this will cause the all-gather to happen in FP32, which is slower
# than necessary in most cases.
config["compute_dtype"] = torch.float32
else:
autocast = False

# Establish reference behavior with PyTorch DDP (+ optionally autocast).
model = model_init_fn(group=group, wrapper_config=None).cuda()
if ref_ddp_fn is None:
model = nn.parallel.DistributedDataParallel(
model, device_ids=[rank], output_device=rank, process_group=group
)
else:
model = ref_ddp_fn(model, group)
ref_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr, norm_type=norm_type)
ref_state_dict = model.module.state_dict()
if config.get("cpu_offload", False):
for k in ref_state_dict.keys():
ref_state_dict[k] = ref_state_dict[k].cpu()

# Confirm we get the same behavior using FullyShardedDataParallel.
model = FullyShardedDataParallel(model_init_fn(group=group, wrapper_config=config), group, **config)
if use_cuda:
model = model.cuda()
else:
assert next(model.parameters()).device == torch.device("cpu")
shard_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr, norm_type=norm_type)
shard_state_dict = model.state_dict()

try:
torch.testing.assert_allclose(ref_loss, shard_loss)
assert objects_are_equal(ref_state_dict, shard_state_dict, raise_exception=True)
except (AssertionError, RuntimeError) as e:
raise Exception(f"FullyShardedDataParallel didn't match PyTorch DDP using config: {config}\n\n {e}")


class TestMixedPrecision(DistributedTest):
def test_all_fp32(self):
@@ -313,49 +356,6 @@ def test_mixture_of_experts_grad_clip_breaks(self):
def _dummy_ddp_fn(self, model, group):
return DummyDDP(model)

@classmethod
def _test_identical_outputs(
cls, model_init_fn, config, rank, group, num_steps=2, use_cuda=True, lr=0.01, ref_ddp_fn=None, norm_type=2,
):
if config.get("mixed_precision", False):
autocast = True
# Force the compute dtype to be torch.float32 so that we get
# identical results as PyTorch DDP when using autocast. Note that
# this will cause the all-gather to happen in FP32, which is slower
# than necessary in most cases.
config["compute_dtype"] = torch.float32
else:
autocast = False

# Establish reference behavior with PyTorch DDP (+ optionally autocast).
model = model_init_fn(group=group, wrapper_config=None).cuda()
if ref_ddp_fn is None:
model = nn.parallel.DistributedDataParallel(
model, device_ids=[rank], output_device=rank, process_group=group
)
else:
model = ref_ddp_fn(model, group)
ref_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr, norm_type=norm_type)
ref_state_dict = model.module.state_dict()
if config.get("cpu_offload", False):
for k in ref_state_dict.keys():
ref_state_dict[k] = ref_state_dict[k].cpu()

# Confirm we get the same behavior using FullyShardedDataParallel.
model = FullyShardedDataParallel(model_init_fn(group=group, wrapper_config=config), group, **config)
if use_cuda:
model = model.cuda()
else:
assert next(model.parameters()).device == torch.device("cpu")
shard_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr, norm_type=norm_type)
shard_state_dict = model.state_dict()

try:
torch.testing.assert_allclose(ref_loss, shard_loss)
assert objects_are_equal(ref_state_dict, shard_state_dict, raise_exception=True)
except (AssertionError, RuntimeError) as e:
raise Exception(f"FullyShardedDataParallel didn't match PyTorch DDP using config: {config}\n\n {e}")

@parameterized.expand([[1], [inf]], name_func=rename_test)
def test_clip_norm_transformer(self, norm_type):
config = {"mixed_precision": True}
65 changes: 65 additions & 0 deletions tests/nn/data_parallel/test_fsdp_apply.py
@@ -0,0 +1,65 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import functools
import unittest

from parameterized import parameterized
import torch.nn as nn

from .test_fsdp import (
CONFIG_OPTIONS,
DistributedTest,
NestedWrappedModule,
TransformerWithSharedParams,
rename_test,
spawn_and_init,
)


class TestApply(DistributedTest):
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_transformer_weight_init(self, config):
model_init_fn = functools.partial(apply_custom_weight_init, TransformerWithSharedParams)
test_fn = functools.partial(self._test_identical_outputs, model_init_fn, config, lr=0.01)
spawn_and_init(test_fn)

@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_nested_wrapped_weight_init(self, config):
model_init_fn = functools.partial(apply_custom_weight_init, NestedWrappedModule)
test_fn = functools.partial(self._test_identical_outputs, model_init_fn, config, lr=0.01)
spawn_and_init(test_fn)


def apply_custom_weight_init(model_init_fn, *args, **kwargs):
model = model_init_fn(*args, **kwargs)
model.apply(init_bert_params_)
return model


def init_bert_params_(module):
"""
Initialize the weights specific to the BERT Model.
"""

def normal_(data):
# with FSDP, module params will be on CUDA, so we cast them back to CPU
# so that the RNG is consistent with and without FSDP
data.copy_(data.cpu().normal_(mean=0.0, std=0.02))

if isinstance(module, nn.Linear):
normal_(module.weight.data)
if module.bias is not None:
module.bias.data.zero_()
if isinstance(module, nn.Embedding):
normal_(module.weight.data)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
if isinstance(module, nn.MultiheadAttention):
normal_(module.in_proj_weight.data)


if __name__ == "__main__":
unittest.main()
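
As a side note, a small sketch of why the ``normal_`` helper in ``init_bert_params_`` round-trips through CPU; this is illustrative only (requires a CUDA device) and is not part of the test:

```python
import torch

torch.manual_seed(0)
cpu_draw = torch.empty(4).normal_(mean=0.0, std=0.02)  # uses the CPU RNG

torch.manual_seed(0)
gpu_param = torch.empty(4, device="cuda")
# Drawing on CPU and copying over reproduces cpu_draw exactly; sampling
# directly on the CUDA tensor would consume the CUDA RNG stream instead and
# give different values.
gpu_param.copy_(gpu_param.cpu().normal_(mean=0.0, std=0.02))

assert torch.allclose(cpu_draw, gpu_param.cpu())
```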