From a31d5dd47a30b47fc5b6d19c709614e4d25e17f4 Mon Sep 17 00:00:00 2001
From: Myle Ott
Date: Sat, 20 Mar 2021 05:12:20 -0700
Subject: [PATCH] Override FSDP.apply to enable custom weight init

---
 .../fully_sharded_data_parallel.py           | 66 +++++++++++++------
 1 file changed, 45 insertions(+), 21 deletions(-)

diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
index d81925a95..a3d9c86d6 100644
--- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
+++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -252,11 +252,44 @@ def __init__(
     def module(self) -> nn.Module:
         return self._fsdp_wrapped_module  # note: may be a FlattenParamsWrapper instance
 
+    def apply(self, fn: Callable[[nn.Module], None]) -> "FullyShardedDataParallel":
+        """
+        Applies ``fn`` recursively to every submodule (as returned by
+        ``.children()``) as well as self. Typical use includes initializing the
+        parameters of a model.
+
+        Compared to ``torch.nn.Module.apply``, this version additionally gathers
+        the full parameters before applying ``fn``.
+
+        Args:
+            fn (nn.Module): function to be applied to each submodule
+
+        Returns:
+            Module: self
+        """
+        is_uninitialized = self._is_root is None
+        with self.summon_full_params(recurse=False):
+            return_value = super().apply(fn)
+        # summon_full_params will call _lazy_init, which sets _is_root. However,
+        # apply() may be called directly on children instances to do weight
+        # init, so we should reset the _is_root flag in this case.
+        if is_uninitialized and self._is_root:
+            for module in self.modules():
+                if isinstance(module, FullyShardedDataParallel):
+                    module._reset_lazy_init()
+        return return_value
+
     @torch.no_grad()
     def _all_buffers_to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
         """Move all buffers to the specified device and dtype, recursively."""
-        cast_fn = functools.partial(cast_buffers_, device=device, dtype=dtype)
-        self.apply(cast_fn)
+        for module in self.modules():
+            for name, buf in module.named_buffers(recurse=False):
+                if buf is None:
+                    continue
+                buf = buf.to(device=device)
+                if torch.is_floating_point(buf):
+                    buf = buf.to(dtype=dtype)
+                setattr(module, name, buf)
 
     @property
     def params_with_grad(self) -> List[Parameter]:
@@ -641,6 +674,9 @@ def _reset_lazy_init(self) -> None:
         self._queue_wait_for_post_backward_closure: Optional[Callable] = None
         self._streams: Dict[str, torch.cuda.Stream] = {}
         self._reducer: Optional[ReduceScatterBucketer] = None
+        for p in self.params:
+            if hasattr(p, "_fp32_shard"):
+                del p._fp32_shard  # reset _init_param_attributes
 
     def _lazy_init(self) -> None:
         """Initialization steps that should happen lazily, typically right
@@ -658,12 +694,14 @@ def _lazy_init(self) -> None:
             self._set_is_root()
             self._setup_streams()
 
-        if self.cpu_offload:  # Buffers stay on GPU, and don't get sharded
-            self._all_buffers_to(device=self.compute_device, dtype=self.buffer_dtype)
-        else:
-            self._all_buffers_to(dtype=self.buffer_dtype)
-
         if self._is_root:
+            # Buffers stay on GPU, and don't get sharded. Since _all_buffers_to
+            # applies recursively, we only call this from the root instance.
+            if self.cpu_offload:
+                self._all_buffers_to(device=self.compute_device, dtype=self.buffer_dtype)
+            else:
+                self._all_buffers_to(dtype=self.buffer_dtype)
+
             # Don't free the full params for the outer-most (root) instance,
             # since those params will be needed immediately after for the
             # backward pass.
@@ -1302,20 +1340,6 @@ def fn(x: torch.Tensor) -> torch.Tensor:
     return apply_to_tensors(fn, args), apply_to_tensors(fn, kwargs)
 
 
-def cast_buffers_(
-    module: nn.Module, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None
-) -> None:
-    """Cast all of module.named_buffers to device and floating point buffers to dtype."""
-    # if buffers are already on the right device and/or dtype this is just python loop cost
-    assert dtype in {torch.float32, torch.float16}  # assumes compute_dtype == float16
-    for key, buf in module.named_buffers(recurse=False):
-        if buf is not None:
-            buf = buf.to(device=device)
-            if torch.is_floating_point(buf):
-                buf = buf.to(dtype=dtype)
-            setattr(module, key, buf)
-
-
 def free_storage_(data: torch.Tensor) -> None:
     """Free underlying storage of a Tensor."""
     if data.storage().size() > 0:
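
Note: a minimal usage sketch of the new FSDP.apply override for weight init follows. The toy model and the init_weights helper are hypothetical illustrations, not part of this patch, and they assume torch.distributed has already been initialized.

    import torch.nn as nn
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

    def init_weights(module: nn.Module) -> None:
        # Hypothetical init function, called once per submodule by apply().
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    # Assumes torch.distributed.init_process_group(...) has already run.
    model = FSDP(nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4)))

    # With this patch, apply() summons the full (unsharded) parameters before
    # running init_weights, so the init function sees the original parameter
    # shapes rather than the flattened shards.
    model.apply(init_weights)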