
Gradient accumulation #3512

Closed · wants to merge 32 commits

Conversation

@dirkgr (Member) commented Dec 11, 2019

This adds gradient accumulation to the trainer, following @scarecrow1123's work which I resurrected from #2721.

Fixes #3469 and #2112.

scarecrow1123 and others added 24 commits November 5, 2018 13:03
This fixes issue allenai#2717 by including the day count in the `training_duration` key in metrics.
`time.strftime` does not account for day counts greater than 31. Switch to `datetime.timedelta` and use its `str` representation for printing the epoch duration as well as the training duration.
Gradient accumulation computes multiple mini-batches before doing a gradient update, to accommodate effective batch sizes larger than what fits in memory. This commit adds a new trainer config key, `num_steps_to_accumulate`: the trainer performs an optimizer step only after the specified number of mini-batch computations. (A minimal sketch of this loop follows the commit list below.)
This removes checks that previously allowed gradient accumulation
only for single GPU training. This condition was not really necessary.
Gradient accumulation reduces the effective number of batches from
what is configured. The added check tests just that.
# Conflicts:
#	allennlp/tests/training/trainer_test.py
#	allennlp/training/trainer.py
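A minimal sketch of the fixed-step accumulation loop the commit message above describes, assuming a model that returns a dict with a "loss" key and an iterable of tensor-dict batches. Apart from `num_gradient_accumulation_steps`, which the PR uses, the names and stubs here are illustrative, not the trainer's actual code:

```python
from typing import Dict, Iterable, List

import torch


def train_epoch(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    batches: Iterable[Dict[str, torch.Tensor]],
    num_gradient_accumulation_steps: int = 1,
) -> None:
    model.train()
    optimizer.zero_grad()
    group: List[Dict[str, torch.Tensor]] = []
    for batch in batches:
        group.append(batch)
        if len(group) < num_gradient_accumulation_steps:
            continue
        for sub_batch in group:
            loss = model(**sub_batch)["loss"]
            # Average over the accumulated sub-batches so the update is
            # comparable to a single large batch.
            (loss / len(group)).backward()
        optimizer.step()
        optimizer.zero_grad()
        group = []
    if group:
        # Flush a final, possibly smaller, group at the end of the epoch.
        for sub_batch in group:
            loss = model(**sub_batch)["loss"]
            (loss / len(group)).backward()
        optimizer.step()
        optimizer.zero_grad()
```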
@DeNeutoy (Contributor) left a comment
Broadly looks good, here is one suggested improvement:
matt-peters@4992ed6#diff-043dcd121296c3cef3f3ff8c74127ff1R339-R349

Currently this solution accumulates N batches regardless of their size. Another way to do this would be to have an "effective batch size", such that when we sample batches, we do a grad step whenever the cumulative batch size goes over that number.

See previous comment from Matt P here:
#2721 (comment)
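For concreteness, the alternative would look roughly like this: keep pulling batches and only step the optimizer once the accumulated number of instances crosses the target. This is only a sketch of the suggestion, not what the PR implements; `get_batch_size` stands in for the instance-counting helper in `training_util`, and the trainer plumbing is omitted:

```python
def train_epoch_by_size(model, optimizer, batches, effective_batch_size, get_batch_size):
    model.train()
    optimizer.zero_grad()
    accumulated = 0
    for batch in batches:
        loss = model(**batch)["loss"]
        loss.backward()
        accumulated += get_batch_size(batch)
        if accumulated >= effective_batch_size:
            # One optimizer step per "effective" batch, however many
            # sampled batches it took to reach that size.
            optimizer.step()
            optimizer.zero_grad()
            accumulated = 0
    if accumulated > 0:
        optimizer.step()
        optimizer.zero_grad()
```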

@brendan-ai2 (Contributor) commented Dec 12, 2019

I'm definitely not strongly opposed to Matt P's suggestion. It's more that it's confusing how to present this to our users. At least in the short term, iterators will probably still have their own batch size setting. So I can imagine a user modifying just "effective_batch_size" while their iterator has a batch size of 1 or something, and needlessly using gradient accumulation. Smart ideas to avoid this very welcome! :)

@brendan-ai2 (Contributor)

Sorry, typo. Not strongly opposed, haha

@DeNeutoy (Contributor)

@dirkgr ah, I see, that makes sense. I think that problem might just go away once we use https://github.com/allenai/allennlp/compare/torch-distributed to do multi-GPU training (because each DDP training process will be doing its own data loading, meaning this clash doesn't exist any more), do you agree? It might make sense to tolerate this difference with the way splitting across GPUs works for now, because it would mean we don't have to change it later.

@brendan-ai2 - do you think that if we call it something like gradient_accumulation_effective_batch_size, we can avoid that possibility of confusion?

@dirkgr (Member, Author) commented Dec 12, 2019

It's still problematic from an implementation perspective, because only the iterator knows how to make batches. Would we put logic into the trainer to cut up batches? What about useless padding?

I don't like the name "effective batch size", because it's not clear what that means. Effective for the hardware, or effective for the math? That's why I started calling it "mathematical" batch size, even though that's kind of awkward.

# `len(batches_for_step)` should always be `num_gradient_accumulation_steps`, except
# for the last batch in the epoch.
loss = loss / len(batches_for_step)
loss.backward()
Contributor left a comment

It would be good if the loss per sub-batch was scaled relative to its proportion of the overall accumulated batch - e.g. a batch of size 64 and a batch of size 12 would get weighted evenly here. You can do this with training_util.get_batch_size(batch).
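For concreteness, the weighting being suggested would look roughly like this (a sketch only; `get_batch_size` stands in for `training_util.get_batch_size`, and `batches_for_step` is the accumulated group from the snippet above):

```python
def backward_weighted(model, batches_for_step, get_batch_size):
    sizes = [get_batch_size(batch) for batch in batches_for_step]
    total = sum(sizes)
    for batch, size in zip(batches_for_step, sizes):
        loss = model(**batch)["loss"]
        # A sub-batch of 64 instances now contributes 64/76 of the gradient
        # and a sub-batch of 12 contributes 12/76, instead of each
        # contributing half.
        (loss * (size / total)).backward()
```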

Contributor left a comment

Are we sure we want this? Sometimes we're already scaling by sample size. For instance, https://github.com/allenai/allennlp/blob/master/allennlp/models/language_model.py#L322.

Contributor left a comment

Oh! Sorry, my bad. Disregard.

Member (Author) left a comment

I'm trying this right now.

Member (Author) left a comment

Done in a7fc41b.

Contributor left a comment

@dirkgr, I might not have been clear. We shouldn't do this. It breaks cases where users have scaled by sample size in their models.

Member (Author) left a comment

After a long discussion, I reverted a7fc41b. Looks like we're going to break somebody, but this at least keeps the more common cases the same.

@brendan-ai2 (Contributor)

Hmm, it does seem like the notion of effective batch size at least in principle remains in the DDP setting, right? In that there's an update to the model that incorporates the gradients from all distributed workers. Correct me if I'm wildly misunderstanding. But yeah, if we decide to restrict it to the gradient accumulation sense, then a verbose name sounds sufficient.

In the short term I think @dirkgr should do something simple here (like what he's done) since we're going to have to integrate with the DDP work anyway and it'll be clearer how to do that integration then. Which may be what you meant by "tolerate this difference"?

@DeNeutoy (Contributor)

@brendan-ai2 I didn't mean that the accumulation vs fixed size distinction goes away - I was commenting on the problem that if we are doing multi-GPU training as we currently do (by making a large batch and then splitting it across GPUs), it's hard to do the "effective batch size" version of gradient accumulation (which I think is what @dirkgr was saying). I believe that this problem will go away once we switch to the new multi-GPU training, because each worker would be able to accumulate batches however it likes - there would still be a synchronisation point between the workers at the optimizer step, but before that point they can accumulate batches up to a certain size themselves, rather than doing a fixed number of batches.

@dirkgr (Member, Author) commented Dec 12, 2019

It seems like DDP will change everything about this. Right now, gradient accumulation is outside, and multi-gpu is inside. But with DDP, it'll be the other way around.

The other thing I just realized is this: Padding changes the results, so there will be a difference in results between batch_size=10, num_gradient_accumulation_steps=2 and batch_size=20, num_gradient_accumulation_steps=1. It's no longer just an implementation detail ... unless we do it by cutting up the batch in the trainer (and thus preserving the padding).
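A toy illustration of the padding point, with hypothetical sequence lengths and bucketing ignored: each batch is padded to its own longest sequence, so smaller accumulated sub-batches generally see less padding than one large batch.

```python
lengths = [5, 7, 9, 23]  # hypothetical sequence lengths in one "mathematical" batch

# As a single batch, everything is padded to the longest sequence.
as_one_batch = [max(lengths)] * len(lengths)                       # [23, 23, 23, 23]

# Split into two accumulated sub-batches, each pads only to its own maximum.
as_two_batches = [max(lengths[:2])] * 2 + [max(lengths[2:])] * 2   # [7, 7, 23, 23]
```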

@scarecrow1123 (Contributor)

Thanks for taking this up! Just a heads up from the DDP perspective here. With the current changes that have gone into the torch-distributed branch, the data parallel stuff still works; distributed training happens only if the `distributed: true` config is set. This is for backward compatibility, and hence this multi-GPU vs effective batch size difference would still be there if multiple GPUs are configured for training without the distributed flag - unless of course you guys have an idea for doing away with DataParallel entirely.

@DeNeutoy (Contributor)

@scarecrow1123 - certainly from my perspective I think your DDP stuff should be the way we do distributed training, as it's so much faster - I'd hope we can remove the old DP code when we merge your branch in.

@scarecrow1123 (Contributor)

I'd hope we can remove the old DP code when we merge your branch in

Makes sense. Thanks!

@DeNeutoy (Contributor) left a comment

OK, actually maybe this fixed-number-of-steps vs accumulate-up-to-a-certain-size thing is not a massive issue - we can revisit it if it causes sync issues with the DDP stuff (e.g. if one worker gets 5 really big batches, all of the others have to wait for it at the optimisation step).

LGTM 👍

@brendan-ai2 (Contributor) left a comment

LGTM. (The multi-GPU test I suggested is probably not that useful since we're going to try to pull that code out anyway.) I'm curious, were you able to train a full model using gradient accumulation and get equivalent test performance?

@dirkgr (Member, Author) commented Dec 12, 2019

I'm curious, were you able to train a full model using gradient accumulation and get equivalent test performance?

This is why I haven't merged yet. I trained 5 Bidaf models each with and without accumulation, using only three epochs to save time, and the results were a little different. It's a small difference, but systematic. To confirm, and to get some significance, I'm training five more of each, which is still running (but about to finish).

@dirkgr (Member, Author) commented Dec 12, 2019

Looks like gradient accumulation makes it worse.

dirkg@aristo-server1 ~/allennlp> jq .validation_f1 with_accumulation_*/metrics.json
0.7505124997517719
0.7467911952063719
0.7477172867964137
0.7466802502854666
0.745641727328858
0.7486145217192928
0.7429948690365236
0.7472469311151695
0.7503132418016738
0.7502895571597034
dirkg@aristo-server1 ~/allennlp> jq .validation_f1 no_accumulation_*/metrics.json
0.7520797099948976
0.7504806157018171
0.7514005354031544
0.7491140279284737
0.7492509069570104
0.7507136568499159
0.7529942724422292
0.750237112746902
0.7463213189658947
0.7531923671016937

@brendan-ai2 (Contributor)

Doesn't seem gratuitously bad. Mean of 0.7474 vs 0.7504. Standard deviation for both is around 0.002. What are the parameters for each? (Iterator batch size, accumulations, etc.)
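For reference, these summary statistics can be recomputed from the run directories shown in the previous comment with something like the following (a sketch; the glob patterns follow the jq commands above):

```python
import glob
import json
import statistics

for pattern in ["with_accumulation_*/metrics.json", "no_accumulation_*/metrics.json"]:
    scores = [json.load(open(path))["validation_f1"] for path in sorted(glob.glob(pattern))]
    print(pattern, round(statistics.mean(scores), 4), round(statistics.stdev(scores), 4))
```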

@DeNeutoy (Contributor)

@dirkgr possible that the batch weighting makes a difference here - might be worth just throwing that commit back in and running it again?

@dirkgr (Member, Author) commented Dec 13, 2019

Doesn't seem gratuitously bad. Mean of 0.7474 vs 0.7504. Standard deviation for both is around 0.002. What are the parameters for each? (Iterator batch size, accumulations, etc.)

The difference is small, but significant. Batch size was 40 with 1 accumulation, 20 with 2. Only 3 epochs of training, because otherwise this will take all day. Everything else was the same. Is Bidaf known to be a fairly unstable model? If I mess around with this a lot, it would be nice to experiment with a model that's known to be quick and stable.

@dirkgr (Member, Author) commented Dec 13, 2019

might be worth just throwing that commit back in and running it again?

This is running right now.

@dirkgr (Member, Author) commented Dec 13, 2019

We might have to make a decision about whether we care, and how much investigation it is worth.

I've been chasing leaderboard scores for the last six months. The gains often come 0.3% at a time, so I'm inclined to give it some weight. But I don't know what other decisions are baked into the library that would jiggle the scores by that much. I'm not excited about dragging out this feature forever over a 0.3% drop.

@matt-gardner (Contributor)

I would say don't let the perfect be the enemy of the good in a case like this. What you have clearly seems like an improvement. If we find ways to improve it more in the future, then great.

@dirkgr (Member, Author) commented Dec 17, 2019

Because this bugged me, I tried this with some code that cuts the batches in half, and thus keeps the padding intact. The difference is still there, but it's much smaller.

So I think what we learned is that padding makes a difference after all (at least in Bidaf). Batch size 20 with 2 accumulations is not the same as batch size 40 with 1 accumulation.

I will revert the batch cutter and merge the change without it. I think that's the most reasonable way to run this.
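For the record, a batch cutter along these lines is what was tried (a rough sketch only, for a flat dictionary of tensors; real batches in allennlp are nested dictionaries, and this is not the exact code that was reverted). Splitting an already-padded batch along the batch dimension keeps the padding of the original batch:

```python
from typing import Dict, List

import torch


def split_batch(batch: Dict[str, torch.Tensor], pieces: int = 2) -> List[Dict[str, torch.Tensor]]:
    batch_size = next(iter(batch.values())).size(0)
    chunk = (batch_size + pieces - 1) // pieces  # ceiling division
    return [
        {name: tensor[start : start + chunk] for name, tensor in batch.items()}
        for start in range(0, batch_size, chunk)
    ]
```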

@brendan-ai2 (Contributor)

That's very interesting, Dirk. Would you mind sharing the numbers you got? I'm curious how much the padding change affected things.

@dirkgr (Member, Author) commented Dec 17, 2019

While I was messing around with this investigation, we merged the distributed trainer. So I re-did this on top of the distributed trainer at #3537.

@dirkgr closed this Dec 17, 2019