-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Gradient accumulation #3512
Gradient accumulation #3512
Changes from 28 commits
9573e55
563647a
5ccf8ad
aaeddef
96db26f
64552a0
d0ac4ca
383bf6d
ad94b1f
a33865b
f7a8ff7
6f45aa3
955e2c4
e2e48e3
12a87ca
71398f1
07cd9b2
c66557f
b1dbf6d
43fc57b
2a45ecf
b6c47ff
6c697ea
1cc0f87
c56006f
437b490
4659312
83087fa
a7fc41b
7da83de
e7935c5
74ee2b3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -60,6 +60,7 @@ def __init__( | |
should_log_learning_rate: bool = False, | ||
log_batch_size_period: Optional[int] = None, | ||
moving_average: Optional[MovingAverage] = None, | ||
num_gradient_accumulation_steps: int = 1, | ||
) -> None: | ||
""" | ||
A trainer for doing supervised learning. It just takes a labeled dataset | ||
|
@@ -171,6 +172,10 @@ def __init__( | |
parameters. Be careful that when saving the checkpoint, we will save the moving averages of | ||
parameters. This is necessary because we want the saved model to perform as well as the validated | ||
model if we load it later. But this may cause problems if you restart the training from checkpoint. | ||
num_gradient_accumulation_steps: ``int``, optional, (default = 1) | ||
Gradients are accumulated for the given number of steps before doing an optimizer step. This can | ||
be useful to accommodate batches that are larger than the RAM size. Refer Thomas Wolf's | ||
[post](https://tinyurl.com/y5mv44fw) for details on Gradient Accumulation. | ||
""" | ||
super().__init__(serialization_dir, cuda_device) | ||
|
||
|
@@ -250,6 +255,8 @@ def __init__( | |
|
||
self._last_log = 0.0 # time of last logging | ||
|
||
self._num_gradient_accumulation_steps = num_gradient_accumulation_steps | ||
|
||
# Enable activation logging. | ||
if histogram_interval is not None: | ||
self._tensorboard.enable_activation_logging(self.model) | ||
|
@@ -300,12 +307,18 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]: | |
# Set the model to "train" mode. | ||
self.model.train() | ||
|
||
num_gpus = len(self._cuda_devices) | ||
# A `batch_group` has chunks of tensors that form a single batch together for an optimizer | ||
# step. A single chunk always contains as many instances as configured in the iterator's | ||
# `batch_size` param. The number of chunks in a single `batch_group` is | ||
# `num_gradient_accumulation_steps` * `num_gpus`. | ||
batch_group_length = self._num_gradient_accumulation_steps * len(self._cuda_devices) | ||
|
||
# Get tqdm for the training batches | ||
raw_train_generator = self.iterator(self.train_data, num_epochs=1, shuffle=self.shuffle) | ||
train_generator = lazy_groups_of(raw_train_generator, num_gpus) | ||
num_training_batches = math.ceil(self.iterator.get_num_batches(self.train_data) / num_gpus) | ||
train_generator = lazy_groups_of(raw_train_generator, batch_group_length) | ||
num_training_batches = math.ceil( | ||
self.iterator.get_num_batches(self.train_data) / batch_group_length | ||
) | ||
self._last_log = time.time() | ||
last_save_time = time.time() | ||
|
||
|
@@ -325,14 +338,18 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]: | |
|
||
self.optimizer.zero_grad() | ||
|
||
loss = self.batch_loss(batch_group, for_training=True) | ||
|
||
if torch.isnan(loss): | ||
raise ValueError("nan loss encountered") | ||
batches_for_step = list(lazy_groups_of(iter(batch_group), len(self._cuda_devices))) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The conversion There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's needed because I need to know how many there are to scale the loss. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👍 |
||
for batch_for_step in batches_for_step: | ||
loss = self.batch_loss(batch_for_step, for_training=True) | ||
if torch.isnan(loss): | ||
raise ValueError("nan loss encountered") | ||
|
||
loss.backward() | ||
# `len(batches_for_step)` should always be `num_gradient_accumulation_steps`, except | ||
# for the last batch in the epoch. | ||
loss = loss / len(batches_for_step) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it would be good if the loss per sub-batch was scaled relative to it's proportion of the overall gradient accumulated batch - e.g a batch of size 64 and a batch of size 12 would get weighted evenly here. You can do this with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are we sure we want this? Sometimes we're already scaling by sample size. For instance, https://github.com/allenai/allennlp/blob/master/allennlp/models/language_model.py#L322. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh! Sorry, my bad. Disregard. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm trying this right now. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done in a7fc41b. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @dirkgr, I might not have been clear. We shouldn't do this. It breaks cases where users have scaled by sample size in their models. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. After a long discussion, I reverted a7fc41b. Looks like we're going to break somebody, but this at least keeps the more common cases the same. |
||
loss.backward() | ||
|
||
train_loss += loss.item() | ||
train_loss += loss.item() | ||
|
||
batch_grad_norm = self.rescale_gradients() | ||
|
||
|
@@ -697,6 +714,7 @@ def from_params( # type: ignore | |
grad_clipping = params.pop_float("grad_clipping", None) | ||
lr_scheduler_params = params.pop("learning_rate_scheduler", None) | ||
momentum_scheduler_params = params.pop("momentum_scheduler", None) | ||
num_gradient_accumulation_steps = params.pop("num_gradient_accumulation_steps", 1) | ||
|
||
if isinstance(cuda_device, list): | ||
model_device = cuda_device[0] | ||
|
@@ -779,4 +797,5 @@ def from_params( # type: ignore | |
should_log_learning_rate=should_log_learning_rate, | ||
log_batch_size_period=log_batch_size_period, | ||
moving_average=moving_average, | ||
num_gradient_accumulation_steps=num_gradient_accumulation_steps, | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This isn't precisely accurate. Iterators create batches that aren't
batch_size
for many reasons. Maybe say something like, for the "simple case of a BasicIterator".There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"typically contains"?
What are the other possibilities? Gradient accumulation becomes more complicated when the batches aren't all the same size, and this code doesn't handle that case properly. Neither does the multi-GPU code.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See this comment: https://github.com/allenai/allennlp/blob/master/training_config/bidirectional_language_model.jsonnet#L28
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maximum_samples_per_batch
being the relevant config option.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the case of that LM we also scale our loss internally based on the number of tokens.