Support accumulating DDP grads using a context manager #21736

Closed
wants to merge 3 commits

Conversation

mrshenli
Contributor

The first attempt and further discussion are available in #19577.

Goal

Allow toggling DDP gradient synchronization across iterations. With this feature, users can accumulate grads in module variables and only kick off the expensive gradient synchronization every few iterations.

Concerns

Our first attempt in #19577 tried to do it using a variable or a function, but @apaszke made a good point that this would be error-prone, and he favors a context manager instead.

Proposed Solution

Instead of providing an accumulate_grads variable/function/context, we provide a DistributedDataParallel.no_sync() context manager. It does exactly what the name suggests, i.e., it disables DDP gradient synchronization within the context. Note that accumulate_grads means no_sync plus no optimizer step, where the latter is not controlled by DDP.

It is true that users need to call another model(input).backward() after exiting the context, and this is indeed more verbose. But I think that is OK, as one major concern in the previous discussion was to prevent users from running into errors without knowing it. This API reaffirms the expected behavior and does not interfere with other use cases when accumulating grads is not required.

The application would then look like:

```python
with ddp.no_sync():
  for input in inputs:
    ddp(input).backward()

ddp(one_more_input).backward()
optimizer.step()
```
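
For the stated goal of synchronizing only every few iterations, the loop could also be arranged like the sketch below (accumulation_steps, inputs, ddp, and optimizer are assumed to be defined by the caller):

```python
for step, input in enumerate(inputs):
    if (step + 1) % accumulation_steps != 0:
        # Accumulate gradients locally; skip the all-reduce for this iteration.
        with ddp.no_sync():
            ddp(input).backward()
    else:
        # Synchronization is enabled again here, so this backward all-reduces
        # the gradients accumulated over the whole group of iterations.
        ddp(input).backward()
        optimizer.step()
        optimizer.zero_grad()
```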

@chenyangyu1988 @myleott

@mrshenli mrshenli requested review from pietern and apaszke June 13, 2019 15:30
@pytorchbot pytorchbot added the oncall: distributed and module: nn labels Jun 13, 2019
@mrshenli mrshenli mentioned this pull request Jun 13, 2019
Contributor

@facebook-github-bot facebook-github-bot left a comment

@mrshenli has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@mrshenli
Contributor Author

The failed test passed on rerun.

Contributor

@pietern pietern left a comment

LGTM in general, but with 2 comments:

  1. Would it be possible to consolidate some of the duplication in the test code into some kind of helper class, or a set of helper functions? There is a lot of duplication between these tests and I think there are at least half a dozen more similar to these.
  2. I think there is interaction between no_sync and the _sync_params function that could cause unintended results. For example, the gradients of model replicas are detached and zeroed in every iteration, whereas they should also accumulate. Then there is the question of the batch normalization buffer synchronization in every call to forward... not sure what to do about that one.

@mrshenli
Contributor Author

  1. Would it be possible to consolidate some of the duplication in the test code into some kind of helper class, or a set of helper functions? There is a lot of duplication between these tests and I think there are at least half a dozen more similar to these.

Yes, let me try

  2. I think there is interaction between no_sync and the _sync_params function that could cause unintended results. For example, the gradients of model replicas are detached and zeroed in every iteration, whereas they should also accumulate. Then there is the question of the batch normalization buffer synchronization in every call to forward... not sure what to do about that one.

Sorry, I forgot about this. How about the following two options:

  1. Only allow creating a no_sync context if the DDP does not contain module replicas, i.e., _sync_params becomes a no-op.

  2. _sync_params should only be called if the previous iteration invoked prepare_for_backward. So, maybe we can add an additional variable recording whether prepare_for_backward was invoked last time? (See the sketch after this list.)
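
A rough sketch of what option 2 could look like (not the actual DistributedDataParallel code; the class and _run_module are stand-ins, and require_forward_param_sync is the extra flag that appears in the diff further down):

```python
class _NoSyncSketch(object):
    """Illustrative only: gate parameter/buffer sync on whether the
    previous iteration ran a synchronized backward."""

    def __init__(self):
        self.require_backward_grad_sync = True   # flipped off inside no_sync()
        self.require_forward_param_sync = True   # did the last backward sync grads?

    def forward(self, *inputs, **kwargs):
        if self.require_forward_param_sync:
            self._sync_params()                  # broadcast params/buffers to replicas
        output = self._run_module(*inputs, **kwargs)
        if self.require_backward_grad_sync:
            self.require_forward_param_sync = True
            self._prepare_for_backward(output)   # register grad all-reduce hooks
        else:
            self.require_forward_param_sync = False
        return output

    # Placeholders so the sketch stands alone.
    def _sync_params(self): pass
    def _prepare_for_backward(self, output): pass
    def _run_module(self, *inputs, **kwargs): return inputs
```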

Contributor

@pietern pietern left a comment

Looks good to me!

I think the new tests didn't fail because they don't require any buffer synchronization. If they did, they would have yielded different results without the require_forward_param_sync option.

@mrshenli
Contributor Author

I think the new tests didn't fail because they don't require any buffer synchronization. If they did, they would have yielded different results without the require_forward_param_sync option.

@pietern do we want to raise an exception if no_sync is called on a model with buffers?

Contributor

@facebook-github-bot facebook-github-bot left a comment

@mrshenli has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@mrshenli
Contributor Author

@pytorchbot rebase this please

Contributor

@facebook-github-bot facebook-github-bot left a comment

@mrshenli has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@mrshenli merged this pull request in 08facca.

"""
old_require_backward_grad_sync = self.require_backward_grad_sync
self.require_backward_grad_sync = False
yield
Contributor

This should be in a try ... finally block, because in case of an exception you will fail to restore the flag!
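
For reference, a minimal sketch of the suggested shape, assuming no_sync() is written as a contextlib.contextmanager method on DistributedDataParallel (illustrative, not the code as merged):

```python
from contextlib import contextmanager

@contextmanager
def no_sync(self):
    old_require_backward_grad_sync = self.require_backward_grad_sync
    self.require_backward_grad_sync = False
    try:
        yield
    finally:
        # Restore the flag even if the body of the with-block raises.
        self.require_backward_grad_sync = old_require_backward_grad_sync
```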

@@ -272,6 +273,8 @@ def __init__(self, module, device_ids=None,
```python
self.module = module
self.broadcast_buffers = broadcast_buffers
self.find_unused_parameters = find_unused_parameters
self.require_backward_grad_sync = True
self.require_forward_param_sync = True
```
Contributor

I'm not sure if DistributedDataParallel is picklable, but if it is, then you should add a __setstate__ that adds those two attributes, because otherwise people who load older checkpoints will get missing attribute errors.
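
A minimal sketch of that suggestion (the class here is a stand-in, not the real DistributedDataParallel; the attribute names match the diff above):

```python
class _DDPStateSketch(object):
    def __setstate__(self, state):
        self.__dict__.update(state)
        # Checkpoints pickled before these attributes existed would otherwise
        # hit missing-attribute errors; fall back to the defaults instead.
        self.__dict__.setdefault('require_forward_param_sync', True)
        self.__dict__.setdefault('require_backward_grad_sync', True)
```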

apaszke added a commit that referenced this pull request Jun 21, 2019
@apaszke
Contributor

apaszke commented Jun 21, 2019

I've fixed the problems in a PR that's referenced above.

iotamudelta pushed a commit to ROCm/pytorch that referenced this pull request Jun 21, 2019
Pull Request resolved: pytorch#21736

Differential Revision: D15805215

Pulled By: mrshenli

fbshipit-source-id: 73405797d1e39965c52016af5cf45b15525ce21c
facebook-github-bot pushed a commit that referenced this pull request Jun 24, 2019
Summary:
cc mrshenli
Pull Request resolved: #22074

Differential Revision: D15965376

Pulled By: mrshenli

fbshipit-source-id: 50ff96de6390817d8ea52c04322c6bee3d649b32