Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for gradient checkpointing in BERT #4659

Merged
merged 9 commits into from
Jun 22, 2020
Prev Previous commit
Next Next commit
fix unit tests
  • Loading branch information
ibeltagy committed May 29, 2020
commit 1765a1430d091e0464a2413354eac362628db56a
10 changes: 7 additions & 3 deletions src/transformers/modeling_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,10 +405,14 @@ def forward(
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

if self.config.gradient_checkpointing:
if getattr(self.config, "gradient_checkpointing", False):
layer_outputs = torch.utils.checkpoint.checkpoint(
layer_module, hidden_states, attention_mask, head_mask[i],
encoder_hidden_states, encoder_attention_mask
layer_module,
hidden_states,
attention_mask,
head_mask[i],
encoder_hidden_states,
encoder_attention_mask,
)
else:
layer_outputs = layer_module(
Expand Down