Add support for gradient checkpointing in BERT (#4659)
* add support for gradient checkpointing in BERT

* fix unit tests

* isort

* black

* workaround for `torch.utils.checkpoint.checkpoint` not accepting bool

* Revert "workaround for `torch.utils.checkpoint.checkpoint` not accepting bool"

This reverts commit 5eb68bb.

* workaround for `torch.utils.checkpoint.checkpoint` not accepting bool

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
ibeltagy and LysandreJik authored Jun 22, 2020
1 parent f4e1f02 commit 90f4b24
Showing 2 changed files with 31 additions and 8 deletions.
4 changes: 4 additions & 0 deletions src/transformers/configuration_bert.py
@@ -90,6 +90,8 @@ class BertConfig(PretrainedConfig):
             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
         layer_norm_eps (:obj:`float`, optional, defaults to 1e-12):
             The epsilon used by the layer normalization layers.
+        gradient_checkpointing (:obj:`bool`, optional, defaults to False):
+            If True, use gradient checkpointing to save memory at the expense of slower backward pass.
 
     Example::
@@ -121,6 +123,7 @@ def __init__(
         initializer_range=0.02,
         layer_norm_eps=1e-12,
         pad_token_id=0,
+        gradient_checkpointing=False,
         **kwargs
     ):
         super().__init__(pad_token_id=pad_token_id, **kwargs)
@@ -137,3 +140,4 @@ def __init__(
         self.type_vocab_size = type_vocab_size
         self.initializer_range = initializer_range
         self.layer_norm_eps = layer_norm_eps
+        self.gradient_checkpointing = gradient_checkpointing
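
For context, a minimal sketch of how the new option is enabled from user code. It uses the flag exactly as added in this commit; only the comments are illustrative:

from transformers import BertConfig, BertModel

# Turn on the flag added in this commit; all other BERT defaults are kept.
config = BertConfig(gradient_checkpointing=True)
model = BertModel(config)

# During training, activations inside each BertLayer are recomputed in the
# backward pass instead of being stored, trading extra compute for memory.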
35 changes: 27 additions & 8 deletions src/transformers/modeling_bert.py
@@ -22,6 +22,7 @@
 import warnings
 
 import torch
+import torch.utils.checkpoint
 from torch import nn
 from torch.nn import CrossEntropyLoss, MSELoss
 
@@ -391,6 +392,7 @@ def forward(
 class BertEncoder(nn.Module):
     def __init__(self, config):
         super().__init__()
+        self.config = config
         self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
 
     def forward(
@@ -409,14 +411,31 @@ def forward(
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
-            layer_outputs = layer_module(
-                hidden_states,
-                attention_mask,
-                head_mask[i],
-                encoder_hidden_states,
-                encoder_attention_mask,
-                output_attentions,
-            )
+            if getattr(self.config, "gradient_checkpointing", False):
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs, output_attentions)
+
+                    return custom_forward
+
+                layer_outputs = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(layer_module),
+                    hidden_states,
+                    attention_mask,
+                    head_mask[i],
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                )
+            else:
+                layer_outputs = layer_module(
+                    hidden_states,
+                    attention_mask,
+                    head_mask[i],
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    output_attentions,
+                )
             hidden_states = layer_outputs[0]
 
             if output_attentions:
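
The closure above is the workaround named in the commit message: `torch.utils.checkpoint.checkpoint` did not accept the bool `output_attentions`, so the flag is captured in `custom_forward` instead of being passed to `checkpoint` alongside the tensors. A self-contained sketch of the same pattern, where `ToyLayer` is an illustrative stand-in for `BertLayer` and the closure takes the flag as an explicit argument rather than capturing it from an enclosing scope:

import torch
import torch.utils.checkpoint


class ToyLayer(torch.nn.Module):
    """Illustrative stand-in for BertLayer: takes tensors plus a bool flag."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)

    def forward(self, hidden_states, output_attentions):
        out = self.linear(hidden_states)
        # A real BertLayer would also return attention probabilities here.
        return (out, None) if output_attentions else (out,)


def create_custom_forward(module, output_attentions):
    # Capture the bool in a closure so only tensors are handed to checkpoint().
    def custom_forward(*inputs):
        return module(*inputs, output_attentions)

    return custom_forward


layer = ToyLayer()
hidden_states = torch.randn(2, 16, requires_grad=True)
layer_outputs = torch.utils.checkpoint.checkpoint(
    create_custom_forward(layer, output_attentions=False), hidden_states
)
layer_outputs[0].sum().backward()  # the forward pass is re-run here to rebuild activations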
