Commit

add gradient accumulation - tested on a single GPU
eladsegal committed Nov 28, 2019
1 parent a4aeafa commit 9bf0282
Showing 1 changed file with 45 additions and 5 deletions.
50 changes: 45 additions & 5 deletions allennlp/training/trainer.py
@@ -60,6 +60,7 @@ def __init__(
        should_log_learning_rate: bool = False,
        log_batch_size_period: Optional[int] = None,
        moving_average: Optional[MovingAverage] = None,
        gradient_accumulation_batch_size: Optional[int] = None,
    ) -> None:
"""
A trainer for doing supervised learning. It just takes a labeled dataset
@@ -171,6 +172,9 @@ def __init__(
            parameters. Be careful that when saving the checkpoint, we will save the moving averages of
            parameters. This is necessary because we want the saved model to perform as well as the validated
            model if we load it later. But this may cause problems if you restart the training from checkpoint.
        gradient_accumulation_batch_size : ``int``, optional (default = None)
            If provided, gradients are accumulated until the effective batch size
            is at least this value before an optimizer step is taken.
        """
        super().__init__(serialization_dir, cuda_device)

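A small worked example of the docstring above, using hypothetical numbers that do not come from this commit: with batches of 8 instances and ``gradient_accumulation_batch_size = 32``, four batches are accumulated before every optimizer step.

# Hypothetical numbers illustrating the "effective batch size" threshold described above.
per_batch_size = 8
gradient_accumulation_batch_size = 32

# Accumulation continues while the running total is below the threshold, so the
# number of batches per optimizer step is the ceiling of the ratio.
batches_per_step = -(-gradient_accumulation_batch_size // per_batch_size)
assert batches_per_step == 4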
@@ -254,6 +258,8 @@ def __init__(
        if histogram_interval is not None:
            self._tensorboard.enable_activation_logging(self.model)

        self.gradient_accumulation_batch_size = gradient_accumulation_batch_size

    def rescale_gradients(self) -> Optional[float]:
        return training_util.rescale_gradients(self.model, self._grad_norm)

@@ -318,22 +324,54 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]:
logger.info("Training")
train_generator_tqdm = Tqdm.tqdm(train_generator, total=num_training_batches)
cumulative_batch_size = 0

accumulated_batches = []
accumulated_batch_sizes = []

for batch_group in train_generator_tqdm:
accumulated_batches.append(batch_group)
cur_batch = sum([training_util.get_batch_size(batch) for batch in batch_group])
accumulated_batch_sizes.append(cur_batch)
effective_batch_size = sum(accumulated_batch_sizes)

            # check to see if this is a gradient update step
            do_update_grads = (
                self.gradient_accumulation_batch_size is None
                or effective_batch_size >= self.gradient_accumulation_batch_size
            )

            if not do_update_grads:
                # get another batch from the generator
                continue

            # else run the forward/backward for each batch
            batches_this_epoch += 1
            self._batch_num_total += 1
            batch_num_total = self._batch_num_total

            self.optimizer.zero_grad()

            # process all the accumulated batches
            for this_batch, this_batch_size in zip(
                accumulated_batches, accumulated_batch_sizes
            ):
                loss = self.batch_loss(this_batch, for_training=True)
                loss = loss * (this_batch_size / float(effective_batch_size))

                if torch.isnan(loss):
                    raise ValueError("nan loss encountered")

                loss.backward()

                train_loss += loss.item()

            accumulated_batches = []
            accumulated_batch_sizes = []

            # now update the gradients
            batch_grad_norm = self.rescale_gradients()

            # This does nothing if batch_num_total is None or you are using a
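In summary, the loop above collects batches until their combined size reaches the threshold, runs forward/backward on each collected batch with its loss scaled by this_batch_size / effective_batch_size (so the summed gradients match the gradient of the mean loss over all accumulated instances), and only then takes a single optimizer step. Below is a minimal standalone sketch of the same scheme in plain PyTorch, not the AllenNLP Trainer code itself; the model, optimizer, and (inputs, labels) batch iterator are hypothetical placeholders.

import torch


def train_with_accumulation(model, optimizer, batches, target_batch_size):
    """Accumulate gradients until the combined batch size reaches
    ``target_batch_size``, then take one optimizer step (a sketch only)."""
    pending = []  # (inputs, labels, size) tuples awaiting an update step
    pending_size = 0
    for inputs, labels in batches:
        pending.append((inputs, labels, inputs.size(0)))
        pending_size += inputs.size(0)
        if pending_size < target_batch_size:
            continue  # keep collecting until the effective batch is large enough

        optimizer.zero_grad()
        for batch_inputs, batch_labels, batch_size in pending:
            loss = torch.nn.functional.cross_entropy(model(batch_inputs), batch_labels)
            # Scale each batch's mean loss by its share of the effective batch so the
            # accumulated gradient equals the gradient of the mean over all instances.
            (loss * batch_size / pending_size).backward()
        optimizer.step()
        pending, pending_size = [], 0

As in the commit, any batches still pending when the iterator is exhausted do not trigger a final update step.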
@@ -697,6 +735,7 @@ def from_params( # type: ignore
        grad_clipping = params.pop_float("grad_clipping", None)
        lr_scheduler_params = params.pop("learning_rate_scheduler", None)
        momentum_scheduler_params = params.pop("momentum_scheduler", None)
        gradient_accumulation_batch_size = params.pop_int(
            "gradient_accumulation_batch_size", None
        )

        if isinstance(cuda_device, list):
            model_device = cuda_device[0]
@@ -779,4 +818,5 @@ def from_params( # type: ignore
            should_log_learning_rate=should_log_learning_rate,
            log_batch_size_period=log_batch_size_period,
            moving_average=moving_average,
            gradient_accumulation_batch_size=gradient_accumulation_batch_size,
        )
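Since ``from_params`` pops the new key from the trainer params, the feature can be switched on from an experiment configuration. A hedged sketch of what the relevant trainer section might look like, written as a Python dict: only "gradient_accumulation_batch_size" is the key added by this commit; every other key and value is an illustrative placeholder.

# Hypothetical trainer configuration; only "gradient_accumulation_batch_size"
# is introduced by this commit, the remaining entries are placeholders.
trainer_config = {
    "optimizer": {"type": "adam", "lr": 0.001},
    "num_epochs": 10,
    "cuda_device": 0,
    "gradient_accumulation_batch_size": 64,
}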

1 comment on commit 9bf0282

@eladsegal (Owner, Author) commented on 9bf0282, Nov 28, 2019
