[FSDP] add no_broadcast_optim_state option #560

Merged: 10 commits, Apr 4, 2021
Changes from 6 commits
38 changes: 26 additions & 12 deletions fairscale/nn/data_parallel/fsdp_optim_utils.py
@@ -21,22 +21,22 @@ def flatten_optim_state_dict(sd: Dict) -> Dict:
non_tensor_state = {}

# Populate `new_state["state"]`. (Assuming sd is sorted)
for expanded_pid, buffers in sd["state"].items():
consolidated_pid = param_id_map[expanded_pid]
for global_id, buffers in sd["state"].items():
Contributor: I like the rename of the variables!
local_id = param_id_map[global_id]
for buffer_name, p in buffers.items():
if torch.is_tensor(p):
if buffer_name not in new_state[consolidated_pid]:
new_state[consolidated_pid][buffer_name] = []
new_state[consolidated_pid][buffer_name].append(p.reshape(-1))
if buffer_name not in new_state[local_id]:
new_state[local_id][buffer_name] = []
new_state[local_id][buffer_name].append(p.reshape(-1))
else:
non_tensor_state[buffer_name] = p

# Now combine all tensors in each buffer using torch.cat().
for consolidated_pid, state in new_state.items():
for local_id, state in new_state.items():
for buffer_name, tensors in state.items():
new_state[consolidated_pid][buffer_name] = torch.cat(tensors)
new_state[consolidated_pid].update(non_tensor_state)
new_sd = {"state": new_state, "param_groups": sd["param_groups"]}
new_state[local_id][buffer_name] = torch.cat(tensors)
new_state[local_id].update(non_tensor_state)
new_sd = {"state": new_state, "param_groups": copy.deepcopy(sd["param_groups"])}
Contributor: Add a comment on the deep copy? Also, last time I checked it seemed that deepcopy doesn't really copy tensors. You may want to double-check (with some asserts and testing) that the deep copy here is doing the right thing.

Contributor (author): Luckily there are no tensors in param_groups, but that's very useful to know!

Contributor: +1 to adding a comment about param groups not having tensors (and thus deepcopy being okay).

# add pointers from the `params` dict.
for pg_id, _ in enumerate(sd["param_groups"]):
@@ -53,7 +53,8 @@ def check_param_counts_before_sharding(full_optim_state_dict: Dict, n_instances:
f"there were {n_local_params_in_opt}"
)
stateless = len(full_optim_state_dict["state"]) == 0
assert stateless or (n_instances == n_local_params_in_opt), msg
if not (stateless or (n_instances == n_local_params_in_opt)):
print(msg)


# All functions below here help saving the list of optimizer states, one from each rank
@@ -134,7 +135,9 @@ def _unflatten_optim_state(
return unflat_state, global_to_local_id


def build_unflat_state_dict(instance_list: List[torch.nn.Module], world_optim_states: List[Dict]) -> Dict:
def build_unflat_state_dict(
instance_list: List[torch.nn.Module], world_optim_states: List[Dict], uncollected_opt_state: Dict[int, Dict]
) -> Dict:
"""Build an unflattened optimizer state dict given a list of flattened optimizer state dicts from each rank."""
world_pad_info: List[List[List[int]]] = [s.pop("num_padded") for s in world_optim_states]
assert all(len(s) == len(instance_list) for s in world_pad_info)
@@ -144,14 +147,25 @@ def build_unflat_state_dict(instance_list: List[torch.nn.Module], world_optim_st

# Aggregate from a list of dictionaries to a dictionary of lists
combined_state = _combine_state([x["state"] for x in world_optim_states])
for local_id, v in uncollected_opt_state.items():
assert local_id not in combined_state
combined_state[local_id] = {}
for buffer_name, tensor in v.items():
combined_state[local_id][buffer_name] = [tensor]
del world_optim_states

# local ids are in the current state, global_ids will be in returned state.
unflat_state, global_to_local_id = _unflatten_optim_state(combined_state, instance_list, world_pad_info)
uncollected_global_ids = set()
for g, l in global_to_local_id.items():
if l in uncollected_opt_state:
uncollected_global_ids.add(g)
num_params = sum([len(m._param_numels) for m in instance_list]) # type: ignore
param_groups[0]["params"] = list(range(num_params)) # This could be a large list. #TODO: is it essential
param_groups[0]["params"] = list(range(num_params))
return {
"state": dict(sorted(unflat_state.items())), # NOTE: this is probably already sorted
"param_id_map": global_to_local_id,
"param_groups": param_groups,
"uncollected_global_ids": uncollected_global_ids,
"uncollected_local_ids": list(uncollected_opt_state.keys()),
}
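For orientation, a sketch of the dict that build_unflat_state_dict now returns (ids, buffer names, and values below are made up; only the keys mirror the return statement above):

```python
# Hypothetical layout of the consolidated, unflattened optimizer state dict.
consolidated_osd = {
    # Unflattened per-parameter buffers, keyed by global (unflattened) param id.
    "state": {0: {"exp_avg": "tensor(...)"}, 1: {"exp_avg": "tensor(...)"}},
    # Maps each global id back to the local (flattened) id it came from.
    "param_id_map": {0: 0, 1: 1},
    "param_groups": [{"lr": 0.01, "params": [0, 1]}],
    # Ids whose state was not gathered because no_broadcast_optim_state was set.
    "uncollected_global_ids": {1},
    "uncollected_local_ids": [1],
}
```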
46 changes: 38 additions & 8 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -157,6 +157,12 @@ class FullyShardedDataParallel(nn.Module):
device, the param's device will be used. If not given and module
params are on CPU, then the current CUDA device (as indicated by
``torch.cuda.current_device()`` will be used.
no_broadcast_optim_state: (bool, Optional)
do not broadcast this module's optimizer state when ``gather_full_optim_state_dict`` is called.
If you set this to True, you are expected to overwrite the relevant state entries of the returned OSD
with the proper state at each rank. This is useful for situations like Mixture of Experts,
where all but a few parameters can fit on one node.
Default: False
"""

def __init__(
@@ -173,6 +179,7 @@ def __init__(
move_grads_to_cpu: Optional[bool] = None,
bucket_cap_mb: int = 25,
compute_device: Optional[torch.device] = None,
no_broadcast_optim_state: Optional[bool] = False,
):
super().__init__()
self.process_group = process_group or dist.new_group()
@@ -187,6 +194,8 @@ def __init__(
self.buffer_dtype = buffer_dtype or self.compute_dtype
self.move_grads_to_cpu = cpu_offload if move_grads_to_cpu is None else move_grads_to_cpu
self.bucket_cap_mb = bucket_cap_mb
self.uncollected_opt_state: Dict[int, Dict] = {}
self.no_broadcast_optim_state = no_broadcast_optim_state

self.numel_padded_per_param: List[int] = []
self.compute_device = compute_device
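A minimal usage sketch of the new option (module names and the process-group setup are hypothetical; they mirror the MixtureOfExperts test further below):

```python
import torch.distributed as dist

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# The "expert" lives in a single-rank process group, so its parameters are not
# sharded and its optimizer state stays local instead of being gathered.
expert_group = dist.new_group([dist.get_rank()])  # world size 1
expert = FSDP(expert_module, process_group=expert_group, no_broadcast_optim_state=True)

# Shared parameters use the default (full-world) process group as usual.
model = FSDP(build_model_around(expert))
```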
@@ -841,6 +850,12 @@ def _set_is_root(self) -> None:
if m.process_group != self.process_group:
self.children_share_process_group = False

# If a child instance is in its own (smaller) world, that was probably an attempt to avoid OOM.
# Therefore gathering this child's optim state would probably cause OOM, so we won't do it.
m.no_broadcast_optim_state = m.no_broadcast_optim_state or (
(m.world_size == 1) and (m.world_size < self.world_size) and (m.process_group != self.process_group)
)
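To make the auto-detection above concrete, a small worked example (hypothetical numbers): with a global world size of 8, an expert wrapped in its own single-rank group satisfies all three clauses, so the flag is set even if the caller never passed it.

```python
# Hypothetical values walking through the _set_is_root() condition above.
parent_world_size = 8
child_world_size = 1          # expert wrapped with dist.new_group([rank])
same_process_group = False    # the child uses its own group

no_broadcast = (
    (child_world_size == 1)
    and (child_world_size < parent_world_size)
    and not same_process_group
)
assert no_broadcast  # this child's optimizer state will not be broadcast
```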

def _setup_streams(self) -> None:
"""Create streams to overlap data transfer and computation."""
if len(self._streams) > 0 or not self._is_root:
@@ -1380,7 +1395,7 @@ def _consolidate_optim_state_dict(
dummy_tensor = torch.tensor([0], dtype=torch.uint8, device=self.compute_device)
for rank in range(self.world_size):
if rank == self.rank:
sd = optim.state_dict()
sd = self._remove_uncollectable_params_from_optim_state_dict(optim.state_dict())
sd["num_padded"] = [m.numel_padded_per_param for m in self._fsdp_instances]
else:
sd = dummy_tensor # type: ignore
@@ -1417,15 +1432,29 @@ def gather_full_optim_state_dict(
if self.rank != recipient_rank and recipient_rank is not None:
return None
# Unify the shard states by concatenating tensors and unflattening params
new_state_dict = ou.build_unflat_state_dict(self._fsdp_instances, world_optim_states)
# TODO: check if this code supports nested instances with different world size
new_state_dict = ou.build_unflat_state_dict(
self._fsdp_instances, world_optim_states, self.uncollected_opt_state
)
self.uncollected_opt_state = {}
assert "uncollected_local_ids" in new_state_dict
return new_state_dict

@property
def _fsdp_instances(self) -> List[nn.Module]:
"""Returns all fsdp modules in self.modules() including self."""
return [m for m in self.modules() if isinstance(m, FullyShardedDataParallel)]

def _remove_uncollectable_params_from_optim_state_dict(self, osd: Dict) -> Dict:
uncollected_ids = [i for i, m in enumerate(self._fsdp_instances) if m.no_broadcast_optim_state]
new_dct = {"state": {k: v for k, v in osd["state"].items() if k not in uncollected_ids}}
if self.rank == 0:
# Save placeholders for uncollected opt state to keep the same unflat OSD format.
self.uncollected_opt_state = {k: v for k, v in osd["state"].items() if k in uncollected_ids}

pg = copy.deepcopy(osd["param_groups"])
new_dct["param_groups"] = pg
return new_dct

def get_shard_from_optim_state_dict(self, full_optim_state_dict: Dict[str, Any]) -> Dict[str, Any]:
"""Get the portion of the optimizer state dict associated with the shard

@@ -1440,18 +1469,19 @@ def get_shard_from_optim_state_dict(full_optim_state_dict: Dict[str, Any])
"""
# Assert nesting is the same as it was at save time
instance_list = self._fsdp_instances
assert all(
x.world_size == self.world_size for x in instance_list
), "all nested instances must have same world size"
ou.check_param_counts_before_sharding(full_optim_state_dict, len(instance_list))
ids_not_to_shard = copy.deepcopy(full_optim_state_dict["uncollected_local_ids"])
if self.flatten_parameters:
full_optim_state_dict = ou.flatten_optim_state_dict(full_optim_state_dict)
assert len(full_optim_state_dict["state"]) in (0, len(instance_list))
assert len(full_optim_state_dict["state"]) in (
0,
len(instance_list),
), f'{len(full_optim_state_dict["state"])}, {len(instance_list)}'

# get the portion of dict associated with the shard, in place
for id, s in full_optim_state_dict["state"].items():
for k, v in s.items():
if torch.is_tensor(v):
if torch.is_tensor(v) and id not in ids_not_to_shard:
v_shard, _ = self._get_shard(v)
else:
v_shard = v # dont shard entries that are not tensors
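Putting the pieces of this file together, a hedged sketch of the intended save/restore flow (model and optim are hypothetical; the overwrite step is what the new docstring asks callers to do for uncollected modules):

```python
import torch

# Gather the full, unflattened optimizer state on the recipient rank (rank 0 by default).
full_osd = model.gather_full_optim_state_dict(optim)  # returns None on other ranks
if full_osd is not None:
    # Entries listed under "uncollected_local_ids" hold rank-0 placeholders; callers
    # are expected to overwrite them with the proper expert state before saving.
    torch.save(full_osd, "optim_full.pt")

# Later, on every rank: reload and take only this rank's shard.
full_osd = torch.load("optim_full.pt", map_location="cpu")
shard_osd = model.get_shard_from_optim_state_dict(full_osd)
optim.load_state_dict(shard_osd)
```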
3 changes: 2 additions & 1 deletion tests/nn/data_parallel/test_fsdp.py
@@ -782,6 +782,7 @@ def __init__(self, group, wrapper_config, checkpoint_act=False, delay_before_fre
# "expert" params are different on each rank
torch.manual_seed(42 + group.rank())
expert = nn.Linear(16, 4)
self.num_expert_params = sum([p.numel() for p in expert.parameters()])
for p in expert.parameters():
p.expert = True

@@ -795,7 +796,7 @@ def __init__(self, group, wrapper_config, checkpoint_act=False, delay_before_fre

if wrapper_config is not None:
# we create a process group of size 1 for the expert params
expert_group = torch.distributed.new_group([group.rank()])
expert_group = torch.distributed.new_group([group.rank()]) # world size 1 means no shard
expert = FullyShardedDataParallel(expert, expert_group, **wrapper_config)

shared = FullyShardedDataParallel(shared, group, **wrapper_config)
54 changes: 32 additions & 22 deletions tests/nn/data_parallel/test_fsdp_optimizer_utils.py
@@ -16,7 +16,7 @@
from .test_fsdp import (
DistributedTest,
DummyProcessGroup,
NestedWrappedModule,
MixtureOfExperts,
TransformerWithSharedParams,
rename_test,
spawn_and_init,
@@ -36,11 +36,12 @@ def assert_equal(a, b):

class TestOptimizerUtils(DistributedTest):
@parameterized.expand(
[[functools.partial(SGD, momentum=0.9), True], [SGD, False], [Adam, False], [Adadelta, True]],
[[functools.partial(SGD, momentum=0.9), True], [SGD, False], [Adam, False], [Adadelta, True], [Adam, True]],
name_func=rename_test,
)
def test_consolidate_optimizer(self, optim_fn, transformer):
config = {"mixed_precision": True, "flatten_parameters": True}
config["compute_dtype"] = torch.float32
Contributor: I'm guessing this and the autocast change were required to get the tests to pass?

Contributor (author): Yes, I followed the logic of __test_identical_outputs.

test_fn = functools.partial(
self._test_consolidated_optimizer, config, optim_fn=optim_fn, transformer=transformer
)
@@ -54,10 +55,10 @@ def _test_consolidated_optimizer(self, config, rank, group, optim_fn=torch.optim

if transformer:
fsdp = self.get_wrapped_model(group, config=config).cuda()
unwrapped_model = TransformerWithSharedParams(group).cuda()
unwrapped_model = TransformerWithSharedParams(group, wrapper_config=config).cuda()
else:
fsdp = FullyShardedDataParallel(NestedWrappedModule(group, wrapper_config=config), group, **config).cuda()
unwrapped_model = NestedWrappedModule(group, wrapper_config=None).cuda()
fsdp = FullyShardedDataParallel(MixtureOfExperts(group, wrapper_config=config)).cuda()
unwrapped_model = MixtureOfExperts(group, wrapper_config=None).cuda()

try:
fsdp_optim = optim_fn(fsdp.parameters(), lr=0.01,)
@@ -68,17 +69,17 @@ def _test_consolidated_optimizer(self, config, rank, group, optim_fn=torch.optim

fsdp_optim.zero_grad()
optim_unwrapped.zero_grad()

x = fsdp.module.get_input(torch.device("cuda"))
output = fsdp(*x)
loss = fsdp.module.get_loss(x, output).to("cuda")
fsdp.module.run_backward(loss)
fsdp_optim.step()

output = unwrapped_model(*x)
loss = unwrapped_model.get_loss(x, output)
unwrapped_model.run_backward(loss)
optim_unwrapped.step()
with torch.cuda.amp.autocast(enabled=True):
x = fsdp.module.get_input(torch.device("cuda"))
output = fsdp(*x)
loss = fsdp.module.get_loss(x, output).to("cuda")
fsdp.module.run_backward(loss)
fsdp_optim.step()

output = unwrapped_model(*x)
loss = unwrapped_model.get_loss(x, output)
unwrapped_model.run_backward(loss)
optim_unwrapped.step()
unwrapped_sd = optim_unwrapped.state_dict()

tstart = time()
@@ -89,6 +90,12 @@ def _test_consolidated_optimizer(self, config, rank, group, optim_fn=torch.optim

if fsdp.rank > 0:
return
unflat_state = sd["state"]
assert "uncollected_local_ids" in sd
shard_sd = fsdp.get_shard_from_optim_state_dict(sd)
shard_sd = recursive_copy_to_device(shard_sd, non_blocking=False, device="cpu")
state_after_get_shard = sd["state"]
assert objects_are_equal(unflat_state, state_after_get_shard) # no side effects.

assert_equal(len(sd["state"]), len(unwrapped_sd["state"]))
assert_equal(len(sd["param_groups"][0]["params"]), len(unwrapped_sd["param_groups"][0]["params"]))
@@ -97,18 +104,21 @@ def _test_consolidated_optimizer(self, config, rank, group, optim_fn=torch.optim
sum([first_tensor_numel(v) for k, v in unwrapped_sd["state"].items()]),
)

shard_sd = fsdp.get_shard_from_optim_state_dict(sd)

original_shard_sd = fsdp_optim.state_dict()
assert_equal(len(shard_sd["state"]), len(original_shard_sd["state"]))
assert_equal(shard_sd.keys(), original_shard_sd.keys())
original_shard_sd = recursive_copy_to_device(original_shard_sd, non_blocking=False, device="cpu")

# Before asserting that the dicts are equal, we check keys individually to allow nice tracebacks.
assert_equal(
[first_tensor_numel(v) for k, v in shard_sd["state"].items()],
[first_tensor_numel(v) for k, v in original_shard_sd["state"].items()],
)
assert_equal(
sum([first_tensor_numel(v) for k, v in shard_sd["state"].items()]),
sum([first_tensor_numel(v) for k, v in original_shard_sd["state"].items()]),
[v for k, v in shard_sd["param_groups"][0].items()],
[v for k, v in original_shard_sd["param_groups"][0].items()],
)
assert objects_are_equal(shard_sd, original_shard_sd)
assert objects_are_equal(shard_sd["state"], original_shard_sd["state"])
assert objects_are_equal({k: shard_sd[k] for k in original_shard_sd}, original_shard_sd)

def test_named_params_ordering(self):
"""Test assumption of consolidate_optimizer_state_dict"""