workaround for torch.utils.checkpoint.checkpoint not accepting bool
ibeltagy committed Jun 12, 2020
1 parent b36648f · commit 5eb68bb
Showing 1 changed file with 8 additions and 2 deletions.
src/transformers/modeling_bert.py
@@ -413,14 +413,20 @@ def forward(
                 all_hidden_states = all_hidden_states + (hidden_states,)

             if getattr(self.config, "gradient_checkpointing", False):
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs, output_attentions)
+
+                    return custom_forward
+
                 layer_outputs = torch.utils.checkpoint.checkpoint(
-                    layer_module,
+                    create_custom_forward(layer_module),
                     hidden_states,
                     attention_mask,
                     head_mask[i],
                     encoder_hidden_states,
                     encoder_attention_mask,
-                    output_attentions,
                 )
             else:
                 layer_outputs = layer_module(

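For context on the commit title: at the time, torch.utils.checkpoint.checkpoint required all of its positional arguments to be tensors (they are stashed through the autograd machinery for the backward-pass recomputation), so passing the output_attentions bool directly raised a TypeError; support for non-tensor inputs to checkpoint only arrived in a later PyTorch release. Capturing the bool in a closure keeps it out of checkpoint's argument list entirely. Below is a minimal, self-contained sketch of the same pattern, not the transformers code itself; ToyLayer and the tensor shapes are invented for illustration.

import torch
import torch.utils.checkpoint


class ToyLayer(torch.nn.Module):
    # Stand-in for a BERT layer: returns a tuple whose length depends on the flag.
    def forward(self, hidden_states, output_attentions):
        out = torch.tanh(hidden_states)
        return (out, out.abs()) if output_attentions else (out,)


layer = ToyLayer()
hidden_states = torch.randn(2, 4, requires_grad=True)
output_attentions = True  # the non-tensor argument that checkpoint() rejected


def create_custom_forward(module):
    # Same trick as the commit: the bool is captured from the enclosing scope,
    # so checkpoint() only ever sees tensor arguments.
    def custom_forward(*inputs):
        return module(*inputs, output_attentions)

    return custom_forward


# Passing the bool positionally, as the removed line did, failed on the
# PyTorch versions of that era:
#   torch.utils.checkpoint.checkpoint(layer, hidden_states, output_attentions)

# The workaround: the flag never enters checkpoint's argument list.
outputs = torch.utils.checkpoint.checkpoint(
    create_custom_forward(layer), hidden_states
)
outputs[0].sum().backward()
print(hidden_states.grad.shape)  # torch.Size([2, 4])

The design choice mirrors the diff above: rather than threading the flag through checkpoint, create_custom_forward builds a tensor-only callable per layer, which is why output_attentions disappears from the call site.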