Override FSDP.apply to enable custom weight init
myleott committed Mar 20, 2021
1 parent bfcb5a3 commit 46f422f
Showing 1 changed file with 45 additions and 21 deletions.
fairscale/nn/data_parallel/fully_sharded_data_parallel.py (45 additions, 21 deletions)
@@ -252,11 +252,44 @@ def __init__(
     def module(self) -> nn.Module:
         return self._fsdp_wrapped_module  # note: may be a FlattenParamsWrapper instance
 
+    def apply(self, fn: Callable[[nn.Module], None]) -> "FullyShardedDataParallel":
+        """
+        Applies ``fn`` recursively to every submodule (as returned by
+        ``.children()``) as well as self. Typical use includes initializing the
+        parameters of a model.
+
+        Compared to ``torch.nn.Module.apply``, this version additionally gathers
+        the full parameters before applying ``fn``.
+
+        Args:
+            fn (Callable[[nn.Module], None]): function to be applied to each submodule
+
+        Returns:
+            Module: self
+        """
+        is_uninitialized = self._is_root is None
+        with self.summon_full_params(recurse=False):
+            return_value = super().apply(fn)
+        # summon_full_params will call _lazy_init, which sets _is_root. However,
+        # apply() may be called directly on children instances to do weight
+        # init, so we should reset the _is_root flag in this case.
+        if is_uninitialized and self._is_root:
+            for module in self.modules():
+                if isinstance(module, FullyShardedDataParallel):
+                    module._reset_lazy_init()
+        return return_value
+
     @torch.no_grad()
     def _all_buffers_to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
         """Move all buffers to the specified device and dtype, recursively."""
-        cast_fn = functools.partial(cast_buffers_, device=device, dtype=dtype)
-        self.apply(cast_fn)
+        for module in self.modules():
+            for name, buf in module.named_buffers(recurse=False):
+                if buf is None:
+                    continue
+                buf = buf.to(device=device)
+                if torch.is_floating_point(buf):
+                    buf = buf.to(dtype=dtype)
+                setattr(module, name, buf)
 
     @property
     def params_with_grad(self) -> List[Parameter]:
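
For context, the pattern this override enables looks roughly like the following. This is an illustrative sketch, not part of the commit: it assumes a distributed process group is already initialized, and `init_weights` is a hypothetical helper.

import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

def init_weights(module: nn.Module) -> None:
    # Hypothetical init function: it sees the full (unsharded) weights,
    # because the overridden apply() gathers them via summon_full_params.
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)
        nn.init.zeros_(module.bias)

model = FSDP(nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8)))
model.apply(init_weights)  # gathers full params, applies fn, then re-shards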
@@ -641,6 +674,9 @@ def _reset_lazy_init(self) -> None:
         self._queue_wait_for_post_backward_closure: Optional[Callable] = None
         self._streams: Dict[str, torch.cuda.Stream] = {}
         self._reducer: Optional[ReduceScatterBucketer] = None
+        for p in self.params:
+            if hasattr(p, "_fp32_shard"):
+                del p._fp32_shard  # reset _init_param_attributes
 
     def _lazy_init(self) -> None:
         """Initialization steps that should happen lazily, typically right
@@ -658,12 +694,14 @@ def _lazy_init(self) -> None:
         self._set_is_root()
         self._setup_streams()
 
-        if self.cpu_offload:  # Buffers stay on GPU, and don't get sharded
-            self._all_buffers_to(device=self.compute_device, dtype=self.buffer_dtype)
-        else:
-            self._all_buffers_to(dtype=self.buffer_dtype)
-
         if self._is_root:
+            # Buffers stay on GPU, and don't get sharded. Since _all_buffers_to
+            # applies recursively, we only call this from the root instance.
+            if self.cpu_offload:
+                self._all_buffers_to(device=self.compute_device, dtype=self.buffer_dtype)
+            else:
+                self._all_buffers_to(dtype=self.buffer_dtype)
+
             # Don't free the full params for the outer-most (root) instance,
             # since those params will be needed immediately after for the
             # backward pass.
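
A plain-PyTorch sketch (no FSDP involved) of why one recursive call from the root suffices: iterating `modules()` already visits every descendant, and `named_buffers(recurse=False)` ensures each buffer is touched exactly once, on the module that owns it.

import torch
import torch.nn as nn

model = nn.Sequential(nn.BatchNorm1d(4), nn.Sequential(nn.BatchNorm1d(4)))
for module in model.modules():
    for name, buf in module.named_buffers(recurse=False):
        if torch.is_floating_point(buf):
            buf = buf.to(dtype=torch.float16)
        setattr(module, name, buf)  # rebinds the buffer on its owning module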
@@ -1302,20 +1340,6 @@ def fn(x: torch.Tensor) -> torch.Tensor:
     return apply_to_tensors(fn, args), apply_to_tensors(fn, kwargs)
 
 
-def cast_buffers_(
-    module: nn.Module, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None
-) -> None:
-    """Cast all of module.named_buffers to device and floating point buffers to dtype."""
-    # if buffers are already on the right device and/or dtype this is just python loop cost
-    assert dtype in {torch.float32, torch.float16}  # assumes compute_dtype == float16
-    for key, buf in module.named_buffers(recurse=False):
-        if buf is not None:
-            buf = buf.to(device=device)
-            if torch.is_floating_point(buf):
-                buf = buf.to(dtype=dtype)
-            setattr(module, key, buf)
-
-
 def free_storage_(data: torch.Tensor) -> None:
     """Free underlying storage of a Tensor."""
     if data.storage().size() > 0:
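
To see why apply() resets lazy-init state, consider initializing a nested child before the root has run a forward pass. Another illustrative sketch under the same assumptions as above (initialized process group, hypothetical `init_weights`):

import torch
import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

inner = FSDP(nn.Linear(8, 8))
model = FSDP(nn.Sequential(inner, nn.Linear(8, 8)))

# summon_full_params inside apply() calls _lazy_init, which would mark
# `inner` as a root. Because `inner` was uninitialized beforehand, the
# override calls _reset_lazy_init on every FSDP instance afterwards, so
# the first real forward pass still assigns root/child roles correctly.
inner.apply(init_weights)
out = model(torch.randn(2, 8))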