From 90f4b2452077ac3bac9453bdc63e0359aa4fe4d2 Mon Sep 17 00:00:00 2001
From: Iz Beltagy
Date: Mon, 22 Jun 2020 07:47:14 -0700
Subject: [PATCH] Add support for gradient checkpointing in BERT (#4659)

* add support for gradient checkpointing in BERT

* fix unit tests

* isort

* black

* workaround for `torch.utils.checkpoint.checkpoint` not accepting bool

* Revert "workaround for `torch.utils.checkpoint.checkpoint` not accepting bool"

This reverts commit 5eb68bb804f5ffbfc7ba13c45a47717f72d04574.

* workaround for `torch.utils.checkpoint.checkpoint` not accepting bool

Co-authored-by: Lysandre Debut
---
 src/transformers/configuration_bert.py |  4 +++
 src/transformers/modeling_bert.py      | 35 ++++++++++++++++++++------
 2 files changed, 31 insertions(+), 8 deletions(-)

diff --git a/src/transformers/configuration_bert.py b/src/transformers/configuration_bert.py
index d03f573c541a..b1beceb215b6 100644
--- a/src/transformers/configuration_bert.py
+++ b/src/transformers/configuration_bert.py
@@ -90,6 +90,8 @@ class BertConfig(PretrainedConfig):
             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
         layer_norm_eps (:obj:`float`, optional, defaults to 1e-12):
             The epsilon used by the layer normalization layers.
+        gradient_checkpointing (:obj:`bool`, optional, defaults to False):
+            If True, use gradient checkpointing to save memory at the expense of slower backward pass.
 
     Example::
 
@@ -121,6 +123,7 @@ def __init__(
         initializer_range=0.02,
         layer_norm_eps=1e-12,
         pad_token_id=0,
+        gradient_checkpointing=False,
         **kwargs
     ):
         super().__init__(pad_token_id=pad_token_id, **kwargs)
@@ -137,3 +140,4 @@ def __init__(
         self.type_vocab_size = type_vocab_size
         self.initializer_range = initializer_range
         self.layer_norm_eps = layer_norm_eps
+        self.gradient_checkpointing = gradient_checkpointing
diff --git a/src/transformers/modeling_bert.py b/src/transformers/modeling_bert.py
index 20b1a80dfdd3..313dc47d227d 100644
--- a/src/transformers/modeling_bert.py
+++ b/src/transformers/modeling_bert.py
@@ -22,6 +22,7 @@
 import warnings
 
 import torch
+import torch.utils.checkpoint
 from torch import nn
 from torch.nn import CrossEntropyLoss, MSELoss
 
@@ -391,6 +392,7 @@ def forward(
 class BertEncoder(nn.Module):
     def __init__(self, config):
         super().__init__()
+        self.config = config
         self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
 
     def forward(
@@ -409,14 +411,31 @@ def forward(
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
-            layer_outputs = layer_module(
-                hidden_states,
-                attention_mask,
-                head_mask[i],
-                encoder_hidden_states,
-                encoder_attention_mask,
-                output_attentions,
-            )
+            if getattr(self.config, "gradient_checkpointing", False):
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs, output_attentions)
+
+                    return custom_forward
+
+                layer_outputs = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(layer_module),
+                    hidden_states,
+                    attention_mask,
+                    head_mask[i],
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                )
+            else:
+                layer_outputs = layer_module(
+                    hidden_states,
+                    attention_mask,
+                    head_mask[i],
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    output_attentions,
+                )
             hidden_states = layer_outputs[0]
             if output_attentions:
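
For reference, a minimal usage sketch, not part of the patch: it assumes a transformers build that includes this change and shows the new `gradient_checkpointing` flag being enabled through `BertConfig`. The toy batch shape and the randomly initialized model are made up for illustration only.

# Usage sketch (assumption: a transformers build containing this patch is installed).
import torch
from transformers import BertConfig, BertModel

config = BertConfig(gradient_checkpointing=True)  # flag added by this patch
model = BertModel(config)
model.train()  # checkpointing only pays off when a backward pass is run

input_ids = torch.randint(0, config.vocab_size, (2, 16))  # toy batch, shapes are arbitrary
outputs = model(input_ids)

# With the flag set, each BertLayer runs under torch.utils.checkpoint.checkpoint:
# its activations are freed after the forward pass and recomputed during backward,
# trading extra compute for lower peak memory.
outputs[0].sum().backward()

Note on the closure in the diff: per the commit message, `torch.utils.checkpoint.checkpoint` did not accept the bool `output_attentions` argument at the time, so `create_custom_forward` captures it in a closure and only tensor arguments cross the checkpoint boundary.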