[FSDP] add no_broadcast_optim_state option (#560)
sshleifer authored Apr 4, 2021
1 parent 54a97ee commit 1fcbd62
Showing 4 changed files with 99 additions and 42 deletions.
32 changes: 21 additions & 11 deletions fairscale/nn/data_parallel/fsdp_optim_utils.py
@@ -21,22 +21,22 @@ def flatten_optim_state_dict(sd: Dict) -> Dict:
non_tensor_state = {}

# Populate `new_state["state"]`. (Assuming sd is sorted)
for expanded_pid, buffers in sd["state"].items():
consolidated_pid = param_id_map[expanded_pid]
for global_id, buffers in sd["state"].items():
local_id = param_id_map[global_id]
for buffer_name, p in buffers.items():
if torch.is_tensor(p):
if buffer_name not in new_state[consolidated_pid]:
new_state[consolidated_pid][buffer_name] = []
new_state[consolidated_pid][buffer_name].append(p.reshape(-1))
if buffer_name not in new_state[local_id]:
new_state[local_id][buffer_name] = []
new_state[local_id][buffer_name].append(p.reshape(-1))
else:
non_tensor_state[buffer_name] = p

# Now combine all tensors in each buffer using torch.cat().
for consolidated_pid, state in new_state.items():
for local_id, state in new_state.items():
for buffer_name, tensors in state.items():
new_state[consolidated_pid][buffer_name] = torch.cat(tensors)
new_state[consolidated_pid].update(non_tensor_state)
new_sd = {"state": new_state, "param_groups": sd["param_groups"]}
new_state[local_id][buffer_name] = torch.cat(tensors)
new_state[local_id].update(non_tensor_state)
new_sd = {"state": new_state, "param_groups": copy.deepcopy(sd["param_groups"])}

# add pointers from the `params` dict.
for pg_id, _ in enumerate(sd["param_groups"]):
@@ -109,6 +109,7 @@ def _unflatten_optim_state(

# If the constant state is the same as the combined state, copy it N times, no unflattening needed.
unflat_state = {i: copy.deepcopy(non_tensor_state[0]) for i in range(sum(num_unflat_params))}

if non_tensor_state[0].keys() == combined_state[0].keys():
return unflat_state, global_to_local_id

@@ -134,24 +135,33 @@ def _unflatten_optim_state(
return unflat_state, global_to_local_id


def build_unflat_state_dict(instance_list: List[torch.nn.Module], world_optim_states: List[Dict]) -> Dict:
def build_unflat_state_dict(
instance_list: List[torch.nn.Module], world_optim_states: List[Dict], uncollected_opt_state: Dict[int, Dict]
) -> Dict:
"""Build an unflattened optimizer state dict given a list of flattened optimizer state dicts from each rank."""
world_pad_info: List[List[List[int]]] = [s.pop("num_padded") for s in world_optim_states]
assert all(len(s) == len(instance_list) for s in world_pad_info)
assert all(len(s[0]) == 1 for s in world_pad_info)
# Since there are no tensors in param_groups, deepcopy is fine
param_groups = copy.deepcopy(world_optim_states[0]["param_groups"])
assert len(param_groups) == 1

# Aggregate from a list of dictionaries to a dictionary of lists
combined_state = _combine_state([x["state"] for x in world_optim_states])
for local_id, v in uncollected_opt_state.items():
assert local_id not in combined_state
combined_state[local_id] = {}
for buffer_name, tensor in v.items():
combined_state[local_id][buffer_name] = [tensor]
del world_optim_states

# local ids are in the current state, global_ids will be in returned state.
unflat_state, global_to_local_id = _unflatten_optim_state(combined_state, instance_list, world_pad_info)
num_params = sum([len(m._param_numels) for m in instance_list]) # type: ignore
param_groups[0]["params"] = list(range(num_params)) # This could be a large list. #TODO: is it essential
param_groups[0]["params"] = list(range(num_params))
return {
"state": dict(sorted(unflat_state.items())), # NOTE: this is probably already sorted
"param_id_map": global_to_local_id,
"param_groups": param_groups,
"uncollected_local_ids": list(uncollected_opt_state.keys()),
}
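Note on the merge logic above: `build_unflat_state_dict` aggregates each optimizer buffer as a list of per-rank shards keyed by local (flat-parameter) id, slots any uncollected ids in as single-element lists, and then concatenates every list. Below is a minimal standalone sketch of that shape, not fairscale's internal API; the names `combine_shards`, `per_rank_state`, and `uncollected_state` are hypothetical.

from typing import Dict, List

import torch


def combine_shards(
    per_rank_state: List[Dict[int, Dict[str, torch.Tensor]]],
    uncollected_state: Dict[int, Dict[str, torch.Tensor]],
) -> Dict[int, Dict[str, torch.Tensor]]:
    # Aggregate a list of per-rank dicts into a dict of lists, one list per buffer.
    combined: Dict[int, Dict[str, List[torch.Tensor]]] = {}
    for rank_state in per_rank_state:
        for local_id, buffers in rank_state.items():
            for name, tensor in buffers.items():
                combined.setdefault(local_id, {}).setdefault(name, []).append(tensor.reshape(-1))
    # Uncollected ids (e.g. expert params) were never broadcast, so they arrive as a
    # single local shard and reuse the same concatenation path below.
    for local_id, buffers in uncollected_state.items():
        assert local_id not in combined
        combined[local_id] = {name: [t.reshape(-1)] for name, t in buffers.items()}
    # Concatenate each buffer's shards into one flat tensor per local id.
    return {
        local_id: {name: torch.cat(parts) for name, parts in buffers.items()}
        for local_id, buffers in combined.items()
    }


# Toy usage: param 0 is sharded across two ranks, param 1 is uncollected.
rank0 = {0: {"exp_avg": torch.ones(4)}}
rank1 = {0: {"exp_avg": torch.zeros(4)}}
merged = combine_shards([rank0, rank1], {1: {"exp_avg": torch.full((3,), 2.0)}})
assert merged[0]["exp_avg"].numel() == 8 and merged[1]["exp_avg"].numel() == 3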
46 changes: 38 additions & 8 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -157,6 +157,12 @@ class FullyShardedDataParallel(nn.Module):
device, the param's device will be used. If not given and module
params are on CPU, then the current CUDA device (as indicated by
``torch.cuda.current_device()``) will be used.
no_broadcast_optim_state: (bool, Optional)
do not broadcast this module's optimizer state when ``gather_full_optim_state_dict`` is called.
If you set this to True, you are expected to overwrite the relevant state entries of the
returned optimizer state dict with the proper state at each rank. This is useful for
situations like Mixture of Experts, where all but a few parameters can fit on one node.
Default: False
"""

def __init__(
@@ -173,6 +179,7 @@ def __init__(
move_grads_to_cpu: Optional[bool] = None,
bucket_cap_mb: int = 25,
compute_device: Optional[torch.device] = None,
no_broadcast_optim_state: Optional[bool] = False,
):
super().__init__()
self.process_group = process_group or dist.new_group()
@@ -187,6 +194,8 @@ def __init__(
self.buffer_dtype = buffer_dtype or self.compute_dtype
self.move_grads_to_cpu = cpu_offload if move_grads_to_cpu is None else move_grads_to_cpu
self.bucket_cap_mb = bucket_cap_mb
self.uncollected_opt_state: Dict[int, Dict] = {}
self.no_broadcast_optim_state = no_broadcast_optim_state
self.gradient_predivide_factor: int = self.get_gradient_predivide_factor(self.world_size)
self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor

@@ -849,6 +858,12 @@ def _set_is_root(self) -> None:
if m.process_group != self.process_group:
self.children_share_process_group = False

# If a child instance is in its own (smaller) world, that was probably an attempt to avoid OOM.
# Therefore gathering this child's optim state will probably cause OOM, so we won't do it.
m.no_broadcast_optim_state = m.no_broadcast_optim_state or (
(m.world_size == 1) and (m.world_size < self.world_size) and (m.process_group != self.process_group)
)

def _setup_streams(self) -> None:
"""Create streams to overlap data transfer and computation."""
if len(self._streams) > 0 or not self._is_root:
@@ -1391,7 +1406,7 @@ def _consolidate_optim_state_dict(
dummy_tensor = torch.tensor([0], dtype=torch.uint8, device=self.compute_device)
for rank in range(self.world_size):
if rank == self.rank:
sd = optim.state_dict()
sd = self._remove_uncollectable_params_from_optim_state_dict(optim.state_dict())
sd["num_padded"] = [m.numel_padded_per_param for m in self._fsdp_instances]
else:
sd = dummy_tensor # type: ignore
@@ -1428,15 +1443,29 @@ def gather_full_optim_state_dict(
if self.rank != recipient_rank and recipient_rank is not None:
return None
# Unify the shard states by concatenating tensors and unflattening params
new_state_dict = ou.build_unflat_state_dict(self._fsdp_instances, world_optim_states)
# TODO: check if this code supports nested instances with different world size
new_state_dict = ou.build_unflat_state_dict(
self._fsdp_instances, world_optim_states, self.uncollected_opt_state
)
self.uncollected_opt_state = {}
assert "uncollected_local_ids" in new_state_dict
return new_state_dict

@property
def _fsdp_instances(self) -> List[nn.Module]:
"""Returns all fsdp modules in self.modules() including self."""
return [m for m in self.modules() if isinstance(m, FullyShardedDataParallel)]

def _remove_uncollectable_params_from_optim_state_dict(self, osd: Dict) -> Dict:
uncollected_ids = [i for i, m in enumerate(self._fsdp_instances) if m.no_broadcast_optim_state]
new_dct = {"state": {k: v for k, v in osd["state"].items() if k not in uncollected_ids}}
if self.rank == 0:
# Save placeholders for uncollected opt state to keep the same unflat OSD format.
self.uncollected_opt_state = {k: v for k, v in osd["state"].items() if k in uncollected_ids}

pg = copy.deepcopy(osd["param_groups"])
new_dct["param_groups"] = pg
return new_dct

def get_shard_from_optim_state_dict(self, full_optim_state_dict: Dict[str, Any]) -> Dict[str, Any]:
"""Get the portion of the optimizer state dict associated with the shard
@@ -1451,18 +1480,19 @@ def get_shard_from_optim_state_dict(self, full_optim_state_dict: Dict[str, Any])
"""
# Assert nesting is the same as it was at save time
instance_list = self._fsdp_instances
assert all(
x.world_size == self.world_size for x in instance_list
), "all nested instances must have same world size"
ou.check_param_counts_before_sharding(full_optim_state_dict, len(instance_list))
ids_not_to_shard = copy.deepcopy(full_optim_state_dict["uncollected_local_ids"])
if self.flatten_parameters:
full_optim_state_dict = ou.flatten_optim_state_dict(full_optim_state_dict)
assert len(full_optim_state_dict["state"]) in (0, len(instance_list))
assert len(full_optim_state_dict["state"]) in (
0,
len(instance_list),
), f'{len(full_optim_state_dict["state"])}, {len(instance_list)}'

# get the portion of dict associated with the shard, in place
for id, s in full_optim_state_dict["state"].items():
for k, v in s.items():
if torch.is_tensor(v):
if torch.is_tensor(v) and id not in ids_not_to_shard:
v_shard, _ = self._get_shard(v)
else:
v_shard = v  # don't shard entries that are not tensors
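Usage note for the option introduced in this file: a submodule whose parameters are rank-local (the Mixture-of-Experts case from the docstring) can be wrapped in its own world-size-1 process group, and its optimizer state is then left out of the gather. The following is a hedged sketch of that wiring, assuming `torch.distributed` is already initialized; `build_moe_block` and the `nn.Linear` sizes are placeholders, and the exact import path may differ by fairscale version.

import torch.distributed as dist
import torch.nn as nn

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP


def build_moe_block() -> nn.Module:
    # Expert params exist only on this rank: give them a world-size-1 process group.
    # The parent's _set_is_root pass will then skip broadcasting their optim state,
    # and no_broadcast_optim_state=True makes the intent explicit.
    expert_group = dist.new_group([dist.get_rank()])
    expert = FSDP(nn.Linear(16, 4), expert_group, no_broadcast_optim_state=True)
    shared = FSDP(nn.Linear(16, 16))  # sharded across the full world as usual
    # The returned block is typically wrapped in an outer FSDP instance by the caller.
    return nn.Sequential(shared, expert)

# After gather_full_optim_state_dict on the recipient rank, the entries listed under
# "uncollected_local_ids" still hold that rank's local expert state; every other rank
# is expected to overwrite them with its own expert state before reusing the dict.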
3 changes: 2 additions & 1 deletion tests/nn/data_parallel/test_fsdp.py
@@ -782,6 +782,7 @@ def __init__(self, group, wrapper_config, checkpoint_act=False, delay_before_fre
# "expert" params are different on each rank
torch.manual_seed(42 + group.rank())
expert = nn.Linear(16, 4)
self.num_expert_params = sum([p.numel() for p in expert.parameters()])
for p in expert.parameters():
p.expert = True

@@ -795,7 +796,7 @@ def __init__(self, group, wrapper_config, checkpoint_act=False, delay_before_fre

if wrapper_config is not None:
# we create a process group of size 1 for the expert params
expert_group = torch.distributed.new_group([group.rank()])
expert_group = torch.distributed.new_group([group.rank()]) # world size 1 means no shard
expert = FullyShardedDataParallel(expert, expert_group, **wrapper_config)

shared = FullyShardedDataParallel(shared, group, **wrapper_config)
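The world-size-1 `expert_group` above is exactly the pattern the `_set_is_root` change auto-detects, so the test does not need to pass `no_broadcast_optim_state=True` explicitly. Paraphrasing that predicate from the diff as a standalone helper (the name `child_should_skip_broadcast` is hypothetical):

def child_should_skip_broadcast(child, parent) -> bool:
    # A child FSDP instance alone in a smaller world was likely split off to avoid OOM
    # (e.g. per-rank expert params), so gathering its optimizer state is skipped.
    return child.no_broadcast_optim_state or (
        child.world_size == 1
        and child.world_size < parent.world_size
        and child.process_group != parent.process_group
    )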
60 changes: 38 additions & 22 deletions tests/nn/data_parallel/test_fsdp_optimizer_utils.py
@@ -16,7 +16,7 @@
from .test_fsdp import (
DistributedTest,
DummyProcessGroup,
NestedWrappedModule,
MixtureOfExperts,
TransformerWithSharedParams,
rename_test,
spawn_and_init,
@@ -36,11 +36,12 @@ def assert_equal(a, b):

class TestOptimizerUtils(DistributedTest):
@parameterized.expand(
[[functools.partial(SGD, momentum=0.9), True], [SGD, False], [Adam, False], [Adadelta, True]],
[[functools.partial(SGD, momentum=0.9), True], [SGD, False], [Adam, False], [Adadelta, True], [Adam, True]],
name_func=rename_test,
)
def test_consolidate_optimizer(self, optim_fn, transformer):
config = {"mixed_precision": True, "flatten_parameters": True}
config["compute_dtype"] = torch.float32
test_fn = functools.partial(
self._test_consolidated_optimizer, config, optim_fn=optim_fn, transformer=transformer
)
@@ -53,11 +54,11 @@ def _test_consolidated_optimizer(self, config, rank, group, optim_fn=torch.optim
# Establish reference behavior.

if transformer:
unwrapped_model = TransformerWithSharedParams(group, wrapper_config=config).cuda()
fsdp = self.get_wrapped_model(group, config=config).cuda()
unwrapped_model = TransformerWithSharedParams(group).cuda()
else:
fsdp = FullyShardedDataParallel(NestedWrappedModule(group, wrapper_config=config), group, **config).cuda()
unwrapped_model = NestedWrappedModule(group, wrapper_config=None).cuda()
unwrapped_model = MixtureOfExperts(group, wrapper_config=None).cuda()
fsdp = FullyShardedDataParallel(MixtureOfExperts(group, wrapper_config=config)).cuda()

try:
fsdp_optim = optim_fn(fsdp.parameters(), lr=0.01,)
@@ -68,27 +69,39 @@ def _test_consolidated_optimizer(self, config, rank, group, optim_fn=torch.optim

fsdp_optim.zero_grad()
optim_unwrapped.zero_grad()

x = fsdp.module.get_input(torch.device("cuda"))
output = fsdp(*x)
loss = fsdp.module.get_loss(x, output).to("cuda")
fsdp.module.run_backward(loss)
fsdp_optim.step()

output = unwrapped_model(*x)
loss = unwrapped_model.get_loss(x, output)
unwrapped_model.run_backward(loss)
optim_unwrapped.step()
with torch.cuda.amp.autocast(enabled=True):
x = fsdp.module.get_input(torch.device("cuda"))
output = fsdp(*x)
loss = fsdp.module.get_loss(x, output).to("cuda")
fsdp.module.run_backward(loss)
fsdp_optim.step()

output = unwrapped_model(*x)
loss = unwrapped_model.get_loss(x, output)
unwrapped_model.run_backward(loss)
optim_unwrapped.step()
unwrapped_sd = optim_unwrapped.state_dict()

if not transformer:
no_broadcast_children = [x for x in fsdp._fsdp_instances if x.no_broadcast_optim_state]
assert len(no_broadcast_children) == 1
assert fsdp._fsdp_instances[-1].no_broadcast_optim_state

tstart = time()
sd = fsdp.gather_full_optim_state_dict(fsdp_optim, recipient_rank=0)
duration = time() - tstart
# Switching from fairscale.optim.utils.broadcast_object to torch.broadcast_object_list will cause this to raise
assert duration < fsdp.world_size, f"gather optim state took {duration} seconds, suspect change in _consolidate"

if fsdp.rank > 0:
assert sd is None
return
unflat_state = sd["state"]
assert "uncollected_local_ids" in sd
shard_sd = fsdp.get_shard_from_optim_state_dict(sd)
shard_sd = recursive_copy_to_device(shard_sd, non_blocking=False, device="cpu")
state_after_get_shard = sd["state"]
assert objects_are_equal(unflat_state, state_after_get_shard) # no side effects.

assert_equal(len(sd["state"]), len(unwrapped_sd["state"]))
assert_equal(len(sd["param_groups"][0]["params"]), len(unwrapped_sd["param_groups"][0]["params"]))
Expand All @@ -97,18 +110,21 @@ def _test_consolidated_optimizer(self, config, rank, group, optim_fn=torch.optim
sum([first_tensor_numel(v) for k, v in unwrapped_sd["state"].items()]),
)

shard_sd = fsdp.get_shard_from_optim_state_dict(sd)

original_shard_sd = fsdp_optim.state_dict()
assert_equal(len(shard_sd["state"]), len(original_shard_sd["state"]))
assert_equal(shard_sd.keys(), original_shard_sd.keys())
original_shard_sd = recursive_copy_to_device(original_shard_sd, non_blocking=False, device="cpu")

# Before asserting that the dicts are equal, we check keys individually to allow nice tracebacks.
assert_equal(
[first_tensor_numel(v) for k, v in shard_sd["state"].items()],
[first_tensor_numel(v) for k, v in original_shard_sd["state"].items()],
)
assert_equal(
sum([first_tensor_numel(v) for k, v in shard_sd["state"].items()]),
sum([first_tensor_numel(v) for k, v in original_shard_sd["state"].items()]),
[v for k, v in shard_sd["param_groups"][0].items()],
[v for k, v in original_shard_sd["param_groups"][0].items()],
)
assert objects_are_equal(shard_sd, original_shard_sd)
assert objects_are_equal(shard_sd["state"], original_shard_sd["state"])
assert objects_are_equal({k: shard_sd[k] for k in original_shard_sd}, original_shard_sd)

def test_named_params_ordering(self):
"""Test assumption of consolidate_optimizer_state_dict"""
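The consolidate-then-reshard round trip exercised above can be summarized as follows. This is a hedged sketch of the test's pattern, not a library helper: `fsdp` and `optim` are assumed to be an already-trained FullyShardedDataParallel model and its optimizer on each rank, and `check_optim_state_round_trip` is a hypothetical name.

import torch


def check_optim_state_round_trip(fsdp, optim) -> None:
    local_before = optim.state_dict()
    full = fsdp.gather_full_optim_state_dict(optim, recipient_rank=0)
    if fsdp.rank != 0:
        assert full is None  # only the recipient rank receives the consolidated dict
        return
    # Ids listed here were never broadcast (e.g. expert params); other ranks must
    # patch these entries with their own local state before reusing the dict.
    assert "uncollected_local_ids" in full
    reshard = fsdp.get_shard_from_optim_state_dict(full)
    # Re-extracting this rank's shard should line up with the optimizer's own state.
    assert len(reshard["state"]) == len(local_before["state"])
    for pid, buffers in local_before["state"].items():
        for name, value in buffers.items():
            if torch.is_tensor(value):
                assert reshard["state"][pid][name].numel() == value.numel()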
