Skip to content

Commit

Permalink
add example for chunking
Browse files Browse the repository at this point in the history
  • Loading branch information
patrickvonplaten committed May 4, 2020
1 parent c113844 commit 668c4bc
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2115,6 +2115,18 @@ def apply_chunking_to_forward(
input_tensors: tuple(torch.Tensor) - the input tensors of `forward_fn` which are chunked
Returns:
a Tensor with the same shape the foward_fn would have given if applied
Examples::
# rename the usual forward() fn to forward_chunk()
def forward_chunk(self, hidden_states):
hidden_states = self.decoder(hidden_states)
return hidden_states
# implement a chunked forward function
def forward(self, hidden_states):
return apply_chunking_to_forward(self.chunk_size_lm_head, self.seq_len_dim, self.forward_chunk, hidden_states)
"""

assert len(input_tensors) > 0, "{} has to be a tuple/list of tensors".format(input_tensors)
Expand Down

0 comments on commit 668c4bc

Please sign in to comment.