diff --git a/allennlp/training/trainer.py b/allennlp/training/trainer.py
index 655c3fd8d20..e9f52d63f38 100644
--- a/allennlp/training/trainer.py
+++ b/allennlp/training/trainer.py
@@ -60,6 +60,7 @@ def __init__(
         should_log_learning_rate: bool = False,
         log_batch_size_period: Optional[int] = None,
         moving_average: Optional[MovingAverage] = None,
+        gradient_accumulation_batch_size: int = None
     ) -> None:
         """
         A trainer for doing supervised learning. It just takes a labeled dataset
@@ -171,6 +172,9 @@ def __init__(
            parameters. Be careful that when saving the checkpoint, we will save the moving averages of
            parameters. This is necessary because we want the saved model to perform as well as the validated
            model if we load it later. But this may cause problems if you restart the training from checkpoint.
+        gradient_accumulation_batch_size: ``int``, (default = None)
+            if provided, then accumulate gradients until the effective batch
+            size is at least this value.
         """
         super().__init__(serialization_dir, cuda_device)
@@ -254,6 +258,8 @@ def __init__(
         if histogram_interval is not None:
             self._tensorboard.enable_activation_logging(self.model)
 
+        self.gradient_accumulation_batch_size = gradient_accumulation_batch_size
+
     def rescale_gradients(self) -> Optional[float]:
         return training_util.rescale_gradients(self.model, self._grad_norm)
 
@@ -318,22 +324,54 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]:
         logger.info("Training")
         train_generator_tqdm = Tqdm.tqdm(train_generator,
                                          total=num_training_batches)
         cumulative_batch_size = 0
+
+        accumulated_batches = []
+        accumulated_batch_sizes = []
+
         for batch_group in train_generator_tqdm:
+            accumulated_batches.append(batch_group)
+            cur_batch = sum([training_util.get_batch_size(batch) for batch in batch_group])
+            accumulated_batch_sizes.append(cur_batch)
+            effective_batch_size = sum(accumulated_batch_sizes)
+
+            # check to see if this is a gradient update step
+            if self.gradient_accumulation_batch_size is None:
+                do_update_grads = True
+            else:
+                if effective_batch_size >= self.gradient_accumulation_batch_size:
+                    do_update_grads = True
+                else:
+                    do_update_grads = False
+
+            if not do_update_grads:
+                # get another batch from the generator
+                continue
+
+            # else run the forward/backward for each batch
             batches_this_epoch += 1
             self._batch_num_total += 1
             batch_num_total = self._batch_num_total
 
             self.optimizer.zero_grad()
 
-            loss = self.batch_loss(batch_group, for_training=True)
+            # process all the accumulated gradients
+            for this_batch, this_batch_size in zip(
+                accumulated_batches, accumulated_batch_sizes
+            ):
+                loss = self.batch_loss(this_batch, for_training=True)
+                loss = loss * (this_batch_size / float(effective_batch_size))
+
+                if torch.isnan(loss):
+                    raise ValueError("nan loss encountered")
 
-            if torch.isnan(loss):
-                raise ValueError("nan loss encountered")
+                loss.backward()
 
-            loss.backward()
+                train_loss += loss.item()
 
-            train_loss += loss.item()
+
+            accumulated_batches = []
+            accumulated_batch_sizes = []
+            # now update the gradients
             batch_grad_norm = self.rescale_gradients()
 
             # This does nothing if batch_num_total is None or you are using a
@@ -697,6 +735,7 @@ def from_params(  # type: ignore
         grad_clipping = params.pop_float("grad_clipping", None)
         lr_scheduler_params = params.pop("learning_rate_scheduler", None)
         momentum_scheduler_params = params.pop("momentum_scheduler", None)
+        gradient_accumulation_batch_size = params.pop_int("gradient_accumulation_batch_size", None)
 
         if isinstance(cuda_device, list):
             model_device = cuda_device[0]
@@ -779,4 +818,5 @@ def from_params(  # type: ignore
             should_log_learning_rate=should_log_learning_rate,
             log_batch_size_period=log_batch_size_period,
             moving_average=moving_average,
+            gradient_accumulation_batch_size=gradient_accumulation_batch_size
         )
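
For reference, here is a minimal standalone sketch of the accumulation and loss-weighting scheme used in the _train_epoch hunk above, written against plain PyTorch rather than the AllenNLP Trainer (the model, optimizer, and make_batch helper are hypothetical stand-ins, not part of this change): batches are buffered until their combined size reaches gradient_accumulation_batch_size, and each buffered loss is scaled by this_batch_size / effective_batch_size before backward().

    import torch

    # hypothetical stand-ins for the real model, optimizer, and data
    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loss_fn = torch.nn.MSELoss()

    gradient_accumulation_batch_size = 8  # accumulate until at least this many examples

    def make_batch(num_examples):
        # hypothetical helper producing one small batch of (inputs, targets)
        return torch.randn(num_examples, 4), torch.randn(num_examples, 1)

    accumulated_batches = []
    accumulated_batch_sizes = []

    for _ in range(6):
        batch = make_batch(3)  # small per-step batches of 3 examples
        accumulated_batches.append(batch)
        accumulated_batch_sizes.append(batch[0].size(0))
        effective_batch_size = sum(accumulated_batch_sizes)

        # keep accumulating until the effective batch is large enough
        if effective_batch_size < gradient_accumulation_batch_size:
            continue

        optimizer.zero_grad()
        for (inputs, targets), this_batch_size in zip(accumulated_batches, accumulated_batch_sizes):
            loss = loss_fn(model(inputs), targets)
            # weight each batch's mean loss by its share of the effective batch,
            # so the accumulated gradient matches that of one large batch
            (loss * (this_batch_size / float(effective_batch_size))).backward()
        optimizer.step()

        accumulated_batches = []
        accumulated_batch_sizes = []

Without the weighting, each small batch would contribute its mean loss equally regardless of how many examples it holds. Based on the from_params change (params.pop_int("gradient_accumulation_batch_size", None)), the option would presumably be set via a gradient_accumulation_batch_size key in the trainer configuration; that usage is inferred from the diff rather than stated in it.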