Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FSDP][feature] optimizer state dict save and load #537

Merged
merged 33 commits into from
Mar 25, 2021
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
ee088bb
consolidate works
sshleifer Mar 19, 2021
ad7df24
cat
sshleifer Mar 19, 2021
ed7526a
Unpad before cat
sshleifer Mar 19, 2021
ed75c59
update params list
sshleifer Mar 19, 2021
44158f7
simple case passing
sshleifer Mar 19, 2021
f82f3b6
found other bug
sshleifer Mar 19, 2021
1022e1e
Broken tests for other optimizers
sshleifer Mar 19, 2021
75119c2
boom boom
sshleifer Mar 20, 2021
89947a4
Merge branch 'master' into fsdp-gather-optimizer
sshleifer Mar 20, 2021
8dcf0a8
remove oss changes
sshleifer Mar 20, 2021
2caf928
passing besides mypy
sshleifer Mar 20, 2021
0b888fd
Smaller delta
sshleifer Mar 20, 2021
a2aacd0
Nesting works
sshleifer Mar 21, 2021
0fc045d
passing, lint attempt
sshleifer Mar 21, 2021
d859734
merge master
sshleifer Mar 21, 2021
3635277
update test list
sshleifer Mar 21, 2021
dbb426f
mypy
sshleifer Mar 22, 2021
f537632
Simpler consolidate_optim_state_dict
sshleifer Mar 22, 2021
a04b406
slightly cleaner
sshleifer Mar 22, 2021
e5e91df
Simplified signature, helper fn for unflattening
sshleifer Mar 22, 2021
ea9d4b5
add todo
sshleifer Mar 22, 2021
47e7cba
Give CI more time to show me a traceback
sshleifer Mar 23, 2021
6cebcec
Fix broadcast_object regression
sshleifer Mar 23, 2021
93c0857
Move most dictionary manipulation to fsdp_optim_utils.py
sshleifer Mar 23, 2021
c93d1db
passing
sshleifer Mar 23, 2021
13b0537
style
sshleifer Mar 23, 2021
9d3dfb7
passing
sshleifer Mar 23, 2021
9f619b2
stateless fix
sshleifer Mar 23, 2021
a4778b7
Min comments
sshleifer Mar 23, 2021
c77a9f7
Min comments
sshleifer Mar 24, 2021
aeefe69
Apply suggestions from code review
sshleifer Mar 24, 2021
d645337
Min comments
sshleifer Mar 24, 2021
75bdd3f
also test param groups
sshleifer Mar 25, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
258 changes: 250 additions & 8 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from fairscale.nn.misc import FlattenParamsWrapper
from fairscale.nn.wrap import auto_wrap, default_auto_wrap_policy, enable_wrap
from fairscale.optim.utils import calc_grad_norm
from fairscale.optim.utils import broadcast_object, calc_grad_norm, recursive_copy_to_device
from fairscale.utils.containers import apply_to_tensors
from fairscale.utils.parallel import chunk_and_pad, enable_pytorch_sync_bn, validate_process_group
from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer
Expand Down Expand Up @@ -88,8 +88,8 @@ class FullyShardedDataParallel(nn.Module):
import torch
from fairscale.nn.auto_wrap import enable_wrap, auto_wrap
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
fsdp_params = dict(mixed_precision=True, flatten_parameters=True)
with enable_wrap(wrapper_cls=FSDP, **fsdp_params):
fsdp_params = dict(wrapper_cls=FSDP, mixed_precision=True, flatten_parameters=True)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for fixing the doc here!

with enable_wrap(**fsdp_params):
# Wraps layer in FSDP by default if within context
self.l1 = wrap(torch.nn.Linear(5, 5))
assert isinstance(self.l1, FSDP)
Expand Down Expand Up @@ -185,6 +185,9 @@ def __init__(
self.buffer_dtype = buffer_dtype or self.compute_dtype
self.move_grads_to_cpu = cpu_offload if move_grads_to_cpu is None else move_grads_to_cpu
self.bucket_cap_mb = bucket_cap_mb

self.num_padded: List[int] = []
sshleifer marked this conversation as resolved.
Show resolved Hide resolved
self._all_optimizer_states: List[Dict[str, Any]] = [] # Optional consolidated optimizer state
self.compute_device = compute_device

if self.fp32_reduce_scatter and not self.mixed_precision:
Expand Down Expand Up @@ -412,6 +415,7 @@ def _shard_parameters_(self) -> None:
allocate less memory for optimizer state, avoiding redundancy across
data parallel workers.
"""
self.num_padded = []
for p in self.params:
assert not hasattr(p, "_is_sharded")
assert p.is_floating_point()
Expand All @@ -423,16 +427,18 @@ def _shard_parameters_(self) -> None:
p._orig_size = p.data.size()

if not p._is_sharded:
self.num_padded.append(0)
continue
p._is_sharded = True

# Replace p.data with the relevant shard.
orig_data = p.data
p.data = self._get_shard(p.data)
p.data, num_padded = self._get_shard(p.data)
self.num_padded.append(num_padded)
free_storage_(orig_data)

sshleifer marked this conversation as resolved.
Show resolved Hide resolved
def _get_shard(self, tensor: torch.Tensor) -> torch.Tensor:
"""Return the local shard of a given full tensor."""
def _get_shard(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, int]:
"""Return the local shard of a full tensor."""
# Shard using torch.chunk to match all-gather/reduce-scatter.
chunks = list(torch.flatten(tensor).chunk(self.world_size))
while len(chunks) < self.world_size:
Expand All @@ -445,7 +451,7 @@ def _get_shard(self, tensor: torch.Tensor) -> torch.Tensor:
shard = chunks[self.rank].clone()
if num_to_pad > 0:
shard = F.pad(shard, [0, num_to_pad])
return shard
return shard, num_to_pad

def extra_repr(self) -> str:
return (
Expand Down Expand Up @@ -684,7 +690,7 @@ def summon_full_params(self, recurse: bool = True, volatile: bool = False) -> Ge
if not volatile:
# Copy any changes made to the full params back into
# the corresponding local shards.
local_shard = self._get_shard(full_tensor)
local_shard, _ = self._get_shard(full_tensor)
p._fp32_shard.copy_(local_shard.view_as(p._fp32_shard))
if safe_to_free:
free_storage_(full_tensor)
Expand Down Expand Up @@ -1346,6 +1352,242 @@ def assert_state(self, state: Union[TrainingState, List[TrainingState]]) -> None
traceback.print_stack()
raise ValueError(msg)

# Optim State dict functions
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I considered moving these to a separate FSDPOptimizerMixin in fsdp_optimizer_utils.py, but decided it wasn't really a mixin since it depends heavily on FSDP.


def consolidate_optim_state_dict(self, optim: torch.optim.Optimizer, recipient_rank: int = 0) -> None:
sshleifer marked this conversation as resolved.
Show resolved Hide resolved
"""Update the consolidated state_dict list, one per rank.

Arguments:
sshleifer marked this conversation as resolved.
Show resolved Hide resolved
recipient_rank (int): on which rank to materialize the full state dict.
sshleifer marked this conversation as resolved.
Show resolved Hide resolved
-1 is a special value, which means that all ranks should have the state
sshleifer marked this conversation as resolved.
Show resolved Hide resolved

.. warning: This needs to be called on all replicas"""
sshleifer marked this conversation as resolved.
Show resolved Hide resolved
_default_device = torch.device("cuda")
sshleifer marked this conversation as resolved.
Show resolved Hide resolved
# NOTE(SS): we do not support param groups yet, as they seem to break FSDP

# Pull the sharded state from all the other replicas
# Store all the states in order, rank by rank
should_collect_state = self.rank == recipient_rank or recipient_rank == -1
should_send_state = (self.rank != recipient_rank and recipient_rank != -1) or recipient_rank == -1
_all_optimizer_states: List[Dict[str, Any]] = []
for rank in range(self.world_size):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there might be complications here when nested FSDP instance have different world_size, right? For example, if BN layers are in their own world_size == 1 process groups, then we collect duplicated states for them? add a TODO?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added TODO in the caller

if rank == self.rank:
sd = optim.state_dict()
sd["num_padded"] = [m.num_padded for m in self.modules() if isinstance(m, FullyShardedDataParallel)]
if should_collect_state:
_all_optimizer_states.append(
recursive_copy_to_device(sd, non_blocking=True, device=torch.device("cpu"))
sshleifer marked this conversation as resolved.
Show resolved Hide resolved
)

# Sync with other replicas
state_to_share = (
sshleifer marked this conversation as resolved.
Show resolved Hide resolved
sd if should_send_state else torch.tensor([0], dtype=torch.uint8, device=_default_device)
)
broadcast_object(
state_to_share, src_rank=self.rank, group=self.process_group, dist_device=_default_device,
)
else:
# Fetch the optim state from the other replicas
replica_state = broadcast_object(
torch.tensor([0], dtype=torch.uint8, device=_default_device),
sshleifer marked this conversation as resolved.
Show resolved Hide resolved
src_rank=rank,
group=self.process_group,
dist_device=_default_device,
)

if should_collect_state:
_all_optimizer_states.append(
recursive_copy_to_device(replica_state, non_blocking=True, device=torch.device("cpu"))
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can this be rearranged to remove some duplication? Something like:

for rank in range(self.world_size):
    if rank == self.rank:
        state = optim.state_dict()
        sd["num_padded"] = ...
        state = broadcast_object(state, src_rank=rank, ...)
    else:
        state = broadcast_object(None, src_rank=rank, ...)

    if should_collect_state:
        _all_optimizer_states.append(recursive_copy_to_device(state, device=torch.device("cpu"))

Copy link
Contributor Author

@sshleifer sshleifer Mar 22, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just copy pasted this func from OSS. I think the reason for the extra append is to save useless communication from recipient_rank to recipient_rank

Copy link
Contributor Author

@sshleifer sshleifer Mar 22, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have the simplified implem working with torch.distributed.broadcast_object_list.
I no longer need compute_device. Still calling lazy_init_ for safety.


self._all_optimizer_states = _all_optimizer_states

def gather_full_optim_state_dict(self) -> Dict[str, Any]:
"""Return the last known global optimizer state. The returned state is compatible with Pytorch, in that the
sharded properties are not exposed. Multiple parameter groups are not yet supported.
sshleifer marked this conversation as resolved.
Show resolved Hide resolved

Returns:
a dict with two entries
* state - a dict holding current optimization state. Its content
differs between optimizer classes.
* param_groups - a dict containing all parameter groups

"""
if not self.flatten_parameters:
raise NotImplementedError("optim state dict requires flatten_parameters=True")
if len(self._all_optimizer_states) == 0:
raise ValueError("You must call consolidate_optim_state_dict before gather_full_optim_state_dict")

# Unify the shard states by concatenating tensors and unflattening params
world_pad_info: List[List[int]] = [s.pop("num_padded") for s in self._all_optimizer_states]

param_groups = copy.deepcopy(self._all_optimizer_states[0]["param_groups"])

# combined_state refers to tensor values in sd[state][param_id].
# Here we just aggregate them into a list inside the dictionary from a list of dictionaries.
combined_state = self._combine_tensor_optim_state(
[x["state"] for x in self._all_optimizer_states], self.world_size
)

# constant_state refers to entries in sd[state][param_id] that are not tensors, like "step"
# we check that these are identical across workers and then take the first
constant_state = [self._extract_constant_state(combined_state, id) for id in combined_state]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these comments/helper methods are very nice 😄


# cleanup all_optimizer_states_list
self._all_optimizer_states = []

new_state_dict = {"state": {}, "param_groups": param_groups}
instance_list = self._fsdp_instances
numels_per_instance = [sum(m._param_numels) for m in instance_list] # type: ignore

# loop over parameters in state.
# Tensor state will be padded, concatenated, and then restored to their original
# shape with FlattenParamsWrapper.get_views
# get_views multiple tensors, each of which is a new parameter with a new "global" id.

local_to_global_param_id: Dict[int, List[int]] = {}
# local ids are in the current state, global_ids will be in returned state.

next_global_param_id = 0 # gets incremented
for pg_id, param_group in enumerate(param_groups):
for local_id in param_group["params"]:
local_to_global_param_id[local_id] = []
if local_id not in combined_state:
continue
# undo the work of shard_parameters
for k, v in combined_state[local_id].items():
if k in constant_state[local_id]:
continue
assert isinstance(v, list), f"expected list, got {k}:{v} for {local_id} at rank {self.rank}"
assert all(len(s[local_id]) == 1 for s in world_pad_info) # because of flatten_parameters
pad_info = [s[local_id][0] for s in world_pad_info]
assert len(pad_info) == self.world_size == len(v), f"{len(pad_info), self.world_size, len(v)}"

v_unpad = [t[:-np] if np > 0 else t for t, np in zip(v, pad_info)]
flat_buffer = torch.cat(v_unpad)
assert (
numels_per_instance[local_id] == flat_buffer.shape[0] == flat_buffer.numel()
), f"{numels_per_instance[local_id]} {flat_buffer.shape[0]}, {flat_buffer.numel()}"
param_views: Generator = instance_list[local_id].get_param_views(flat_buffer)
for i, param_view in enumerate(param_views):
if i == len(local_to_global_param_id[local_id]):
# We have not seen this global param before, and make a new ID
local_to_global_param_id[local_id].append(next_global_param_id)
next_global_param_id += 1
global_id = local_to_global_param_id[local_id][i]
if global_id not in new_state_dict["state"]:
new_state_dict["state"][global_id] = copy.deepcopy(constant_state[local_id])
assert k not in new_state_dict["state"][global_id], f"already added {k} to new[{global_id}]"
new_state_dict["state"][global_id][k] = param_view
sshleifer marked this conversation as resolved.
Show resolved Hide resolved

if next_global_param_id == 0: # stateless optimizer
num_params = sum([len(m._param_numels) for m in instance_list]) # type: ignore
new_state_dict["param_groups"][pg_id]["params"] = list(range(num_params))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this list could be quite large, right? I guess this only affects SGD w/o momentum, but I wonder if there's a more compact way. Let's not worry about it for now, but perhaps put a note or TODO to make it more efficient

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you talking about list(range(num_params))? If so, it affects both cases.
I'll leave a TODO

else:
new_state_dict["param_groups"][pg_id]["params"] = list(range(next_global_param_id))

global_to_local_id = {}
for old_pid, global_ids in local_to_global_param_id.items():
for new_id in global_ids:
global_to_local_id[new_id] = old_pid

new_state_dict["param_id_map"] = global_to_local_id
# Make sure that the parameters are sorted in the state, as expected for a pytorch dict
new_state_dict["state"] = dict(sorted(new_state_dict["state"].items()))
return new_state_dict

@staticmethod
def _combine_tensor_optim_state(states: List[Dict], world_size: int) -> Dict[int, Dict]:
combined_state = states[0]
for param_id in combined_state:
combined_state[param_id] = {k: [v] for k, v in combined_state[param_id].items()}
if world_size == 1:
return combined_state

for rank, s in enumerate(states[1:]):
for param_id, param_state in s.items():
for k, tensor in param_state.items():
combined_state[param_id][k].append(tensor)
return combined_state

@staticmethod
def _extract_constant_state(combined_state: Dict[int, Dict[str, List]], param_id: int) -> Dict:
constant_state = {} # This state is like step in Adam, not a tensor so we dont unpad or cat it.
for k, v in combined_state[param_id].items():

if torch.is_tensor(v[0]):
continue
elif len(set(v)) == 1:
constant_state[k] = v[0]
else:
raise TypeError(f"Dont know how to expand optimizer param {k} with values {v}")
return constant_state

@property
def _fsdp_instances(self) -> List[nn.Module]:
sshleifer marked this conversation as resolved.
Show resolved Hide resolved
"""Returns all fsdp modules in self.modules() including self."""
return [m for m in self.modules() if isinstance(m, FullyShardedDataParallel)]

def get_shard_from_optim_state_dict(self, full_optim_state_dict: Dict[str, Any]) -> Dict[str, Any]:
"""Get the portion of the optimizer state dict associated with the shard"""
sshleifer marked this conversation as resolved.
Show resolved Hide resolved
# Assert nesting is the same as it was at save time
n_instances = len(self._fsdp_instances)
n_local_params_in_opt = len(set(full_optim_state_dict["param_id_map"].values()))
msg = f"Including itself, this model has {n_instances} nested instances. When the optimizer state was saved there were {n_local_params_in_opt}"
stateless = len(full_optim_state_dict["state"]) == 0
assert stateless or (n_instances == n_local_params_in_opt), msg

stateless = len(full_optim_state_dict["state"]) == 0
instance_list = self._fsdp_instances
if self.flatten_parameters:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this assume all inner FSDP instances also have flatten == True?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, will assert

full_optim_state_dict = self._flatten_optim_state_dict(full_optim_state_dict)
assert stateless or len(full_optim_state_dict["state"]) == len(instance_list)

# get the portion of dict associated with the shard
for id, s in full_optim_state_dict["state"].items():
for k, v in s.items():
if torch.is_tensor(v):
v_shard, _ = self._get_shard(v)
else:
v_shard = v # dont partition entries that are not tensors
full_optim_state_dict["state"][id][k] = v_shard

return full_optim_state_dict

@staticmethod
def _flatten_optim_state_dict(sd: Dict) -> Dict:
param_id_map = sd["param_id_map"]
num_local_params = len(set(param_id_map.values()))
if sd["state"]:
new_state: Dict = {local_id: {} for local_id in range(num_local_params)}
else:
new_state = {}
constant_state = {}

# assumes sd sorted
for expanded_pid, buffers in sd["state"].items():
consolidated_pid = param_id_map[expanded_pid]
for buffer_name, p in buffers.items():
if torch.is_tensor(p):
if buffer_name not in new_state[consolidated_pid]:
new_state[consolidated_pid][buffer_name] = []
new_state[consolidated_pid][buffer_name].append(p.reshape(-1))
else:
assert isinstance(p, (float, int)), f"unexpected type {type(p)} in optimizer state[{buffer_name}]"
constant_state[buffer_name] = p
# TODO(SS): THIS COULD BE WRONG. What if step is different for different params... At least check

for consolidated_pid, state in new_state.items():
for buffer_name, tensors in state.items():
new_state[consolidated_pid][buffer_name] = torch.cat(tensors)
new_state[consolidated_pid].update(constant_state)
new_sd = {"state": new_state, "param_groups": sd["param_groups"]}

for pg_id, _ in enumerate(sd["param_groups"]):
new_sd["param_groups"][pg_id]["params"] = list(range(num_local_params))

return new_sd


@torch.no_grad()
def cast_inputs_to_fp16(*args: Any, **kwargs: Any) -> Tuple[Any, Any]:
Expand Down
6 changes: 3 additions & 3 deletions fairscale/nn/misc/flatten_params_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,15 +122,15 @@ def _flatten_params(self, flat_param: Optional[nn.Parameter] = None) -> None:
# register the views as plain attributes
self._unflatten_params_as_views()

def _get_param_views(self, flat_param: Tensor) -> Generator:
def get_param_views(self, flat_param: Tensor) -> Generator:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since this is becoming an public method, can you please:

  1. add docstring with proper doc
  2. assert flat_param is valid before using it?

return (t.view(s) for (t, s) in zip(flat_param.split(self._param_numels), self._param_shapes))

def _unflatten_params(self, flat_param: Optional[Tensor] = None) -> None:
assert self.is_flattened or flat_param is not None
self.is_flattened = False
flat_param = flat_param if flat_param is not None else self.flat_param

ps = self._get_param_views(flat_param)
ps = self.get_param_views(flat_param)
for (m, n), p in zip(self._param_infos, ps):
if hasattr(m, n):
delattr(m, n)
Expand All @@ -144,7 +144,7 @@ def _unflatten_params(self, flat_param: Optional[Tensor] = None) -> None:

def _unflatten_params_as_views(self) -> None:
assert self.is_flattened
ps = self._get_param_views(self.flat_param)
ps = self.get_param_views(self.flat_param)
for (m, n), p in zip(self._param_infos, ps):
setattr(m, n, p) # This will set as plain attr
for (m, n, shared_m, shared_n) in self._shared_param_infos:
Expand Down
1 change: 1 addition & 0 deletions tests/ci_test_list_3.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ tests/nn/data_parallel/test_fsdp_summon_full_params.py
tests/nn/data_parallel/test_fsdp_input.py
tests/nn/data_parallel/test_fsdp_multiple_forward.py
tests/nn/data_parallel/test_fsdp_regnet.py
tests/nn/data_parallel/test_fsdp_optimizer_utils.py
tests/nn/data_parallel/test_sharded_ddp_features.py
tests/nn/data_parallel/test_sharded_ddp_pytorch_parity.py
tests/nn/pipe/skip/test_gpipe.py
Expand Down
Loading