This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Add Gradient accumulation support to the default trainer #2721

Closed
wants to merge 17 commits
25 changes: 25 additions & 0 deletions allennlp/tests/training/trainer_test.py
@@ -6,6 +6,7 @@
import re
import time
from typing import Dict
from pathlib import Path

import torch
import pytest
@@ -678,6 +679,30 @@ def test_restoring_works_with_older_checkpointing(self):
assert trainer._metric_tracker._best_so_far == 0.1
assert trainer._metric_tracker._epochs_with_no_improvement == 1

def test_trainer_can_run_gradient_accumulation(self):
num_training_instances = 0
with Path(self.FIXTURES_ROOT / 'data' / 'sequence_tagging.tsv').open() as input_file:
num_training_instances = sum(1 for i in input_file)

steps_to_accumulate = 2

trainer = Trainer(model=self.model,
optimizer=self.optimizer,
iterator=self.iterator,
train_dataset=self.instances,
validation_dataset=self.instances,
num_epochs=2,
num_gradient_accumulation_steps=steps_to_accumulate)
assert trainer._num_gradient_accumulation_steps == steps_to_accumulate
assert trainer._accumulate_gradients

metrics = trainer.train()

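# Each optimizer step consumes `steps_to_accumulate` batches, so the number of
# optimizer steps per epoch should equal
# (num_training_instances // batch_size) // steps_to_accumulate.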
num_batches_trained_per_epoch = trainer._batch_num_total // (metrics["training_epochs"]+1)
num_batches_expected = num_training_instances // self.iterator._batch_size // steps_to_accumulate

assert num_batches_trained_per_epoch == num_batches_expected

class TestSparseClipGrad(AllenNlpTestCase):
def test_sparse_clip_grad(self):
# create a sparse embedding layer, then take gradient
66 changes: 57 additions & 9 deletions allennlp/training/trainer.py
@@ -62,7 +62,8 @@ def __init__(self,
should_log_parameter_statistics: bool = True,
should_log_learning_rate: bool = False,
log_batch_size_period: Optional[int] = None,
moving_average: Optional[MovingAverage] = None) -> None:
moving_average: Optional[MovingAverage] = None,
num_gradient_accumulation_steps: int = 1) -> None:
"""
A trainer for doing supervised learning. It just takes a labeled dataset
and a ``DataIterator``, and uses the supplied ``Optimizer`` to learn the weights
@@ -173,6 +174,10 @@ def __init__(self,
parameters. Be careful that when saving the checkpoint, we will save the moving averages of
parameters. This is necessary because we want the saved model to perform as well as the validated
model if we load it later. But this may cause problems if you restart the training from checkpoint.
num_gradient_accumulation_steps: ``int``, optional, (default = 1)
Gradients are accumulated for the given number of steps before an optimizer step is taken. This can
be useful when the desired batch size is too large to fit in memory in a single forward pass. Refer to
Thomas Wolf's [post](https://tinyurl.com/y5mv44fw) for details on gradient accumulation.
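For example, with ``batch_size = 8`` and ``num_gradient_accumulation_steps = 4``, each optimizer step
effectively uses 32 instances.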
"""
super().__init__(serialization_dir, cuda_device)

@@ -242,6 +247,16 @@ def __init__(self,

self._last_log = 0.0 # time of last logging

self._num_gradient_accumulation_steps = num_gradient_accumulation_steps

if self._num_gradient_accumulation_steps > 1 and self._multiple_gpu:
logger.warning(
"You have configured the trainer to use multiple GPUs together with gradient accumulation. "
"Because of this, the effective batch size will be "
"batch_size * num_gradient_accumulation_steps * number of GPUs.")

self._accumulate_gradients = self._num_gradient_accumulation_steps > 1

# Enable activation logging.
if histogram_interval is not None:
self._tensorboard.enable_activation_logging(self.model)
@@ -290,14 +305,35 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]:
# Set the model to "train" mode.
self.model.train()

num_gpus = len(self._cuda_devices)
# A `batch_group` is a list of chunks of tensors that together form the
# effective batch for a single optimizer step. Each chunk corresponds to the
# configured `batch_size`, while the number of chunks in the `batch_group`
# depends on how the trainer has been configured. The length of `batch_group`
# under the possible configurations is:
#
# Single GPU:
# List of 1 chunk
#
# `n` GPUs:
# Effective batch size here is `batch_size` * `n`. Hence it is a list of
# `n` chunks.
#
# Single GPU with `n` accumulation steps:
# Effective batch size here is `batch_size` * `n`. Hence `batch_group` is a
# list of `n` chunks.
#
# `n` GPUs with `m` accumulation steps:
# Effective batch size here is `batch_size` * `n` * `m`. Hence it is a
# list of `n * m` chunks.
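#
# For example, with `batch_size = 8`, 2 GPUs, and 4 accumulation steps,
# `batch_group` is a list of 8 chunks and each optimizer step effectively
# uses 64 instances.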

batch_group_length = self._num_gradient_accumulation_steps * len(self._cuda_devices)

# Get tqdm for the training batches
raw_train_generator = self.iterator(self.train_data,
num_epochs=1,
shuffle=self.shuffle)
train_generator = lazy_groups_of(raw_train_generator, num_gpus)
num_training_batches = math.ceil(self.iterator.get_num_batches(self.train_data)/num_gpus)
train_generator = lazy_groups_of(raw_train_generator, batch_group_length)
num_training_batches = math.ceil(self.iterator.get_num_batches(self.train_data)/batch_group_length)
self._last_log = time.time()
last_save_time = time.time()

@@ -319,12 +355,22 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]:

self.optimizer.zero_grad()

loss = self.batch_loss(batch_group, for_training=True)
# `batch_group` contains all the tensors needed to compute the forward pass
# for one effective batch. To accumulate gradients over `n` steps, we further
# split this group into sub-groups and run a forward pass for each, i.e. `n`
# forward passes in total.
#
# With a single GPU each sub-group has length 1. With multiple GPUs, each
# `self.batch_loss()` call is passed a sub-group whose length equals the
# number of GPUs.
batch_group_for_stepwise_accumulation = lazy_groups_of(iter(batch_group), len(self._cuda_devices))
for batch_for_step in batch_group_for_stepwise_accumulation:
Contributor:

I am having a really tough time understanding what this code is doing here and why it's doing it. Is there a way to clarify it / add comments?

Contributor Author:

Apologies if it is not easy to follow. I've added a few comments there. Hope that clarifies things a bit.

The basic idea is this: so far, batch_group has been used to do the forward pass in both the single- and multi-GPU cases. In the multi-GPU case, the tensors in batch_group are aggregated to form a single batch. I just wanted to extend the same flow to gradient accumulation. The inner for loop in _train_epoch calls batch_loss num_steps_to_accumulate times, and in each iteration it uses a chunk of the original batch_group for that call. As before, the length of the chunk passed in each iteration equals the number of GPUs configured for training. Let me know if this makes sense.
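To make the flow concrete, here is a minimal single-GPU sketch of the accumulation pattern described above. It is not the actual `Trainer` code: `model`, `optimizer`, and `batch_group` are stand-ins, and it assumes the model's forward pass returns a dict with a `"loss"` entry, as AllenNLP models do.

```python
def step_with_accumulation(model, optimizer, batch_group, num_accumulation_steps):
    """Run one optimizer step over a group of batches whose gradients are accumulated."""
    optimizer.zero_grad()
    total_loss = 0.0
    for batch in batch_group:                 # one forward/backward pass per chunk
        loss = model(**batch)["loss"]         # assumes the model returns {"loss": ...}
        loss = loss / num_accumulation_steps  # scale so summed gradients match one big batch
        loss.backward()                       # gradients keep adding up in param.grad
        total_loss += loss.item()
    optimizer.step()                          # a single parameter update for the whole group
    return total_loss
```

(In the experiment config, this PR exposes the step count through the `num_steps_to_accumulate` key that `from_params` pops from the trainer section.)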

loss = self.batch_loss(batch_for_step, for_training=True)

if torch.isnan(loss):
raise ValueError("nan loss encountered")
if torch.isnan(loss):
raise ValueError("nan loss encountered")

loss.backward()
loss = loss / self._num_gradient_accumulation_steps
loss.backward()

train_loss += loss.item()

@@ -673,6 +719,7 @@ def from_params(cls, # type: ignore
grad_clipping = params.pop_float("grad_clipping", None)
lr_scheduler_params = params.pop("learning_rate_scheduler", None)
momentum_scheduler_params = params.pop("momentum_scheduler", None)
num_gradient_accumulation_steps = params.pop("num_steps_to_accumulate", 1)

if isinstance(cuda_device, list):
model_device = cuda_device[0]
@@ -743,7 +790,8 @@
should_log_parameter_statistics=should_log_parameter_statistics,
should_log_learning_rate=should_log_learning_rate,
log_batch_size_period=log_batch_size_period,
moving_average=moving_average)
moving_average=moving_average,
num_gradient_accumulation_steps=num_gradient_accumulation_steps)


class TrainerPieces(NamedTuple):