[FSDP] add no_broadcast_optim_state option #560
Changes from 6 commits
First changed file — the FSDP optimizer state-dict utilities:
@@ -21,22 +21,22 @@ def flatten_optim_state_dict(sd: Dict) -> Dict:
     non_tensor_state = {}

     # Populate `new_state["state"]`. (Assuming sd is sorted)
-    for expanded_pid, buffers in sd["state"].items():
-        consolidated_pid = param_id_map[expanded_pid]
+    for global_id, buffers in sd["state"].items():
+        local_id = param_id_map[global_id]
         for buffer_name, p in buffers.items():
             if torch.is_tensor(p):
-                if buffer_name not in new_state[consolidated_pid]:
-                    new_state[consolidated_pid][buffer_name] = []
-                new_state[consolidated_pid][buffer_name].append(p.reshape(-1))
+                if buffer_name not in new_state[local_id]:
+                    new_state[local_id][buffer_name] = []
+                new_state[local_id][buffer_name].append(p.reshape(-1))
             else:
                 non_tensor_state[buffer_name] = p

     # Now combine all tensors in each buffer using torch.cat().
-    for consolidated_pid, state in new_state.items():
+    for local_id, state in new_state.items():
         for buffer_name, tensors in state.items():
-            new_state[consolidated_pid][buffer_name] = torch.cat(tensors)
-        new_state[consolidated_pid].update(non_tensor_state)
-    new_sd = {"state": new_state, "param_groups": sd["param_groups"]}
+            new_state[local_id][buffer_name] = torch.cat(tensors)
+        new_state[local_id].update(non_tensor_state)
+    new_sd = {"state": new_state, "param_groups": copy.deepcopy(sd["param_groups"])}

     # add pointers from the `params` dict.
     for pg_id, _ in enumerate(sd["param_groups"]):

Review thread on the `copy.deepcopy(sd["param_groups"])` line:

- Reviewer: Add a comment on the deep copy? Also, last I checked, deepcopy doesn't really copy tensors, so you may want to double-check (with some asserts and testing) that the deep copy here does the right thing.
- Author: Luckily there are no tensors in param_groups, but that's very useful to know!
- Reviewer: +1 to adding a comment about param_groups not having tensors (thus deepcopy being okay).
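To illustrate the point in the thread above, here is a tiny sketch with made-up values: optimizer param_groups hold only plain Python scalars and integer param ids, so copy.deepcopy yields an independent copy without ever touching a tensor.

```python
# Illustration only, with invented values: param_groups contain no tensors,
# so deepcopy safely decouples the copy from the original.
import copy

param_groups = [{"lr": 0.01, "momentum": 0.9, "params": [0, 1, 2]}]
copied = copy.deepcopy(param_groups)
copied[0]["params"].append(3)

assert param_groups[0]["params"] == [0, 1, 2]  # original list is unchanged
```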
@@ -53,7 +53,8 @@ def check_param_counts_before_sharding(full_optim_state_dict: Dict, n_instances:
         f"there were {n_local_params_in_opt}"
     )
     stateless = len(full_optim_state_dict["state"]) == 0
-    assert stateless or (n_instances == n_local_params_in_opt), msg
+    if not (stateless or (n_instances == n_local_params_in_opt)):
+        print(msg)


 # All functions below here help saving the list of optimizer states, one from each rank
@@ -134,7 +135,9 @@ def _unflatten_optim_state(
     return unflat_state, global_to_local_id


-def build_unflat_state_dict(instance_list: List[torch.nn.Module], world_optim_states: List[Dict]) -> Dict:
+def build_unflat_state_dict(
+    instance_list: List[torch.nn.Module], world_optim_states: List[Dict], uncollected_opt_state: Dict[int, Dict]
+) -> Dict:
     """Build an unflattened optimizer state dict given a list of flattened optimizer state dicts from each rank."""
     world_pad_info: List[List[List[int]]] = [s.pop("num_padded") for s in world_optim_states]
     assert all(len(s) == len(instance_list) for s in world_pad_info)
@@ -144,14 +147,25 @@ def build_unflat_state_dict(instance_list: List[torch.nn.Module], world_optim_st

     # Aggregate from a list of dictionaries to a dictionary of lists
     combined_state = _combine_state([x["state"] for x in world_optim_states])
+    for local_id, v in uncollected_opt_state.items():
+        assert local_id not in combined_state
+        combined_state[local_id] = {}
+        for buffer_name, tensor in v.items():
+            combined_state[local_id][buffer_name] = [tensor]
     del world_optim_states

     # local ids are in the current state, global_ids will be in returned state.
     unflat_state, global_to_local_id = _unflatten_optim_state(combined_state, instance_list, world_pad_info)
+    uncollected_global_ids = set()
+    for g, l in global_to_local_id.items():
+        if l in uncollected_opt_state:
+            uncollected_global_ids.add(g)
     num_params = sum([len(m._param_numels) for m in instance_list])  # type: ignore
-    param_groups[0]["params"] = list(range(num_params))  # This could be a large list. #TODO: is it essential
+    param_groups[0]["params"] = list(range(num_params))
     return {
         "state": dict(sorted(unflat_state.items())),  # NOTE: this is probably already sorted
         "param_id_map": global_to_local_id,
         "param_groups": param_groups,
+        "uncollected_global_ids": uncollected_global_ids,
+        "uncollected_local_ids": list(uncollected_opt_state.keys()),
     }
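To make the id bookkeeping concrete, here is a small self-contained sketch with invented ids and tensors (the shapes only mirror the code above): uncollected local state is folded into `combined_state` as single-element lists, and every global id whose local flat-param id was left uncollected is recorded.

```python
# Toy illustration with invented ids/tensors; mirrors the merging and
# id-mapping logic added to build_unflat_state_dict above.
import torch

combined_state = {0: {"exp_avg": [torch.ones(4), torch.ones(2)]}}  # gathered from all ranks
uncollected_opt_state = {1: {"exp_avg": torch.zeros(3)}}           # state kept local on this rank

for local_id, buffers in uncollected_opt_state.items():
    assert local_id not in combined_state
    combined_state[local_id] = {name: [tensor] for name, tensor in buffers.items()}

# Suppose unflattening maps three global (unflattened) param ids onto two local flat params.
global_to_local_id = {0: 0, 1: 0, 2: 1}
uncollected_global_ids = {g for g, l in global_to_local_id.items() if l in uncollected_opt_state}
assert uncollected_global_ids == {2}
```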
Second changed file — the optimizer state consolidation tests:
@@ -16,7 +16,7 @@
 from .test_fsdp import (
     DistributedTest,
     DummyProcessGroup,
-    NestedWrappedModule,
+    MixtureOfExperts,
     TransformerWithSharedParams,
     rename_test,
     spawn_and_init,
@@ -36,11 +36,12 @@ def assert_equal(a, b):

 class TestOptimizerUtils(DistributedTest):
     @parameterized.expand(
-        [[functools.partial(SGD, momentum=0.9), True], [SGD, False], [Adam, False], [Adadelta, True]],
+        [[functools.partial(SGD, momentum=0.9), True], [SGD, False], [Adam, False], [Adadelta, True], [Adam, True]],
         name_func=rename_test,
     )
     def test_consolidate_optimizer(self, optim_fn, transformer):
         config = {"mixed_precision": True, "flatten_parameters": True}
+        config["compute_dtype"] = torch.float32

         test_fn = functools.partial(
             self._test_consolidated_optimizer, config, optim_fn=optim_fn, transformer=transformer
         )

Review thread on the `compute_dtype` line:

- Reviewer: I'm guessing this and the autocast change were required to get the tests to pass?
- Author: Yes, I followed the logic of …
@@ -54,10 +55,10 @@ def _test_consolidated_optimizer(self, config, rank, group, optim_fn=torch.optim

         if transformer:
             fsdp = self.get_wrapped_model(group, config=config).cuda()
-            unwrapped_model = TransformerWithSharedParams(group).cuda()
+            unwrapped_model = TransformerWithSharedParams(group, wrapper_config=config).cuda()
         else:
-            fsdp = FullyShardedDataParallel(NestedWrappedModule(group, wrapper_config=config), group, **config).cuda()
-            unwrapped_model = NestedWrappedModule(group, wrapper_config=None).cuda()
+            fsdp = FullyShardedDataParallel(MixtureOfExperts(group, wrapper_config=config)).cuda()
+            unwrapped_model = MixtureOfExperts(group, wrapper_config=None).cuda()

         try:
             fsdp_optim = optim_fn(fsdp.parameters(), lr=0.01,)
@@ -68,17 +69,17 @@ def _test_consolidated_optimizer(self, config, rank, group, optim_fn=torch.optim

         fsdp_optim.zero_grad()
         optim_unwrapped.zero_grad()

-        x = fsdp.module.get_input(torch.device("cuda"))
-        output = fsdp(*x)
-        loss = fsdp.module.get_loss(x, output).to("cuda")
-        fsdp.module.run_backward(loss)
-        fsdp_optim.step()
-
-        output = unwrapped_model(*x)
-        loss = unwrapped_model.get_loss(x, output)
-        unwrapped_model.run_backward(loss)
-        optim_unwrapped.step()
+        with torch.cuda.amp.autocast(enabled=True):
+            x = fsdp.module.get_input(torch.device("cuda"))
+            output = fsdp(*x)
+            loss = fsdp.module.get_loss(x, output).to("cuda")
+            fsdp.module.run_backward(loss)
+            fsdp_optim.step()
+
+            output = unwrapped_model(*x)
+            loss = unwrapped_model.get_loss(x, output)
+            unwrapped_model.run_backward(loss)
+            optim_unwrapped.step()
         unwrapped_sd = optim_unwrapped.state_dict()

         tstart = time()
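For context on the autocast wrapping above, here is a minimal standalone sketch (hypothetical toy model, requires a CUDA device) of running a reference model under torch.cuda.amp.autocast so its numerics stay comparable to an FSDP model configured with mixed_precision=True.

```python
# Minimal sketch with a hypothetical toy model; needs a CUDA device to run.
import torch

model = torch.nn.Linear(8, 8).cuda()
optim = torch.optim.SGD(model.parameters(), lr=0.01)
x = torch.randn(4, 8, device="cuda")

optim.zero_grad()
with torch.cuda.amp.autocast(enabled=True):
    loss = model(x).float().sum()  # forward runs in reduced precision where safe
loss.backward()  # backward is usually kept outside the autocast context
optim.step()
```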
@@ -89,6 +90,12 @@ def _test_consolidated_optimizer(self, config, rank, group, optim_fn=torch.optim

         if fsdp.rank > 0:
             return

+        unflat_state = sd["state"]
+        assert "uncollected_local_ids" in sd
+        shard_sd = fsdp.get_shard_from_optim_state_dict(sd)
+        shard_sd = recursive_copy_to_device(shard_sd, non_blocking=False, device="cpu")
+        state_after_get_shard = sd["state"]
+        assert objects_are_equal(unflat_state, state_after_get_shard)  # no side effects.
         assert_equal(len(sd["state"]), len(unwrapped_sd["state"]))
         assert_equal(len(sd["param_groups"][0]["params"]), len(unwrapped_sd["param_groups"][0]["params"]))
@@ -97,18 +104,21 @@ def _test_consolidated_optimizer(self, config, rank, group, optim_fn=torch.optim
             sum([first_tensor_numel(v) for k, v in unwrapped_sd["state"].items()]),
         )

-        shard_sd = fsdp.get_shard_from_optim_state_dict(sd)
-
         original_shard_sd = fsdp_optim.state_dict()
         assert_equal(len(shard_sd["state"]), len(original_shard_sd["state"]))
         assert_equal(shard_sd.keys(), original_shard_sd.keys())
         original_shard_sd = recursive_copy_to_device(original_shard_sd, non_blocking=False, device="cpu")

+        # Before asserting that the dicts are equal, we check keys individually to allow nice tracebacks.
         assert_equal(
             [first_tensor_numel(v) for k, v in shard_sd["state"].items()],
             [first_tensor_numel(v) for k, v in original_shard_sd["state"].items()],
         )
         assert_equal(
-            sum([first_tensor_numel(v) for k, v in shard_sd["state"].items()]),
-            sum([first_tensor_numel(v) for k, v in original_shard_sd["state"].items()]),
+            [v for k, v in shard_sd["param_groups"][0].items()],
+            [v for k, v in original_shard_sd["param_groups"][0].items()],
         )
-        assert objects_are_equal(shard_sd, original_shard_sd)
+        assert objects_are_equal(shard_sd["state"], original_shard_sd["state"])
+        assert objects_are_equal({k: shard_sd[k] for k in original_shard_sd}, original_shard_sd)

     def test_named_params_ordering(self):
         """Test assumption of consolidate_optimizer_state_dict"""
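Related to the "check keys individually" comment in the hunk above, a hypothetical helper (names invented; `compare` could be a predicate such as objects_are_equal) shows the idea: a failure names the offending key instead of dumping two whole state dicts.

```python
# Hypothetical helper sketching the key-by-key comparison idea.
from typing import Any, Callable, Dict


def assert_dicts_equal_per_key(expected: Dict[Any, Any], actual: Dict[Any, Any],
                               compare: Callable[[Any, Any], bool]) -> None:
    assert expected.keys() == actual.keys(), (list(expected), list(actual))
    for key in expected:
        assert compare(expected[key], actual[key]), f"state dicts differ at key {key!r}"
```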
General review comment:

- Reviewer: I like the rename of the variables!