Support accumulating DDP grads using a context manager #21736
Conversation
@mrshenli has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
The failed test passed in a rerun.
LGTM in general, but with 2 comments:
- Would it be possible to consolidate some of the duplication in the test code into some kind of helper class, or a set of helper functions? There is a lot of duplication between these tests, and I think there are at least half a dozen more similar to these.
- I think there is an interaction between `no_sync` and the `_sync_params` function that could cause unintended results. For example, the gradients of model replicas are detached and zeroed in every iteration, whereas they should also accumulate. Then there is the question of the batch normalization buffer synchronization in every call to forward... not sure what to do about that one.
Yes, let me try
Sorry, I forgot about this. How about the following two options:
Looks good to me!
I think the new tests didn't fail because they don't require any buffer synchronization. If they did, they would have yielded different results without the `require_forward_param_sync` option.
@pietern do we want to raise an exception if …
@mrshenli has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@pytorchbot rebase this please
@mrshenli has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
""" | ||
old_require_backward_grad_sync = self.require_backward_grad_sync | ||
self.require_backward_grad_sync = False | ||
yield |
This should be in a `try ... finally` block, because in case of an exception you will fail to restore the flag!
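A minimal sketch of the context manager with the suggested fix applied (assuming it lives as a method on `DistributedDataParallel`):

```python
from contextlib import contextmanager

@contextmanager
def no_sync(self):
    old_require_backward_grad_sync = self.require_backward_grad_sync
    self.require_backward_grad_sync = False
    try:
        yield
    finally:
        # Restore the flag even if the body raises, so a failed iteration
        # cannot permanently disable gradient synchronization.
        self.require_backward_grad_sync = old_require_backward_grad_sync
```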
```diff
@@ -272,6 +273,8 @@ def __init__(self, module, device_ids=None,
         self.module = module
         self.broadcast_buffers = broadcast_buffers
         self.find_unused_parameters = find_unused_parameters
+        self.require_backward_grad_sync = True
+        self.require_forward_param_sync = True
```
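For context, a hypothetical sketch of how these two flags might gate synchronization inside `forward`; helper names other than `_sync_params` are placeholders, not the PR's actual code:

```python
import torch

def forward(self, *inputs, **kwargs):
    if self.require_forward_param_sync:
        # broadcast parameters and buffers to the model replicas
        self._sync_params()
    output = self.module(*inputs, **kwargs)
    if torch.is_grad_enabled() and self.require_backward_grad_sync:
        # placeholder name: arrange for grad all-reduce during backward
        self._prepare_grad_allreduce_hooks()
    return output
```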
I'm not sure if `DistributedDataParallel` is picklable, but if it is, then you should add a `__setstate__` that adds those two attributes, because otherwise people who load older checkpoints will get missing-attribute errors.
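For illustration, a minimal sketch of what such a `__setstate__` could look like (assuming default pickling of `__dict__`; not necessarily the PR's final code):

```python
def __setstate__(self, state):
    self.__dict__.update(state)
    # Checkpoints pickled before these flags existed won't contain them;
    # default both to True so old instances behave like new ones.
    self.__dict__.setdefault('require_backward_grad_sync', True)
    self.__dict__.setdefault('require_forward_param_sync', True)
```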
I've fixed the problems in a PR that's referenced above.
Summary:
The first attempt and more discussions are available in #19577.

Goal
Allow toggling DDP gradient synchronization across iterations. With this feature, users may accumulate grads in module variables, and only kick off the expensive grad synchronization every few iterations.

Concerns
Our first attempt in #19577 tried to do it using a variable or a function. But @apaszke made a good point that this would be error prone, and favored a context manager instead.

Proposed Solution
Instead of providing an `accumulate_grads` variable/function/context, we provide a `DistributedDataParallel.no_sync()` context manager. It does exactly what the name suggests, i.e., disables DDP grad synchronization within the context. Note that `accumulate_grads` means `no_sync` + no optimizer step, where the latter is not controlled by DDP.

It is true that users need to call another `model(input).backward()` after exiting the context, and this is indeed more verbose. But I think it is OK, as one major concern in the previous discussion was to prevent users from running into errors without knowing it. This API makes the expected behavior explicit, and does not interfere with other use cases when accumulating grads is not required.

The application would then look like:

```python
with ddp.no_sync():
    for input in inputs:
        ddp(input).backward()   # grads accumulate locally, no synchronization
ddp(one_more_input).backward()  # this backward pass synchronizes grads
optimizer.step()
```

@chenyangyu1988 @myleott

Pull Request resolved: pytorch#21736
Differential Revision: D15805215
Pulled By: mrshenli
fbshipit-source-id: 73405797d1e39965c52016af5cf45b15525ce21c
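As an illustration of the every-few-iterations pattern described in the goal above, a hypothetical training loop (names such as `batches` and `accum_steps` are placeholders, not from the PR):

```python
accum_steps = 4  # synchronize grads once every 4 iterations

for step, batch in enumerate(batches):
    if (step + 1) % accum_steps != 0:
        # accumulate grads locally; skip the expensive all-reduce
        with ddp.no_sync():
            ddp(batch).backward()
    else:
        # this backward pass triggers grad synchronization
        ddp(batch).backward()
        optimizer.step()
        optimizer.zero_grad()
```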