From 668c4bc74d242dedf387251443c62f5e3ec2fcab Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 4 May 2020 22:16:47 +0200 Subject: [PATCH] add example for chunking --- src/transformers/modeling_utils.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 821f322ac4e0..02ae2240bc33 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2115,6 +2115,18 @@ def apply_chunking_to_forward( input_tensors: tuple(torch.Tensor) - the input tensors of `forward_fn` which are chunked Returns: a Tensor with the same shape the foward_fn would have given if applied + + + Examples:: + + # rename the usual forward() fn to forward_chunk() + def forward_chunk(self, hidden_states): + hidden_states = self.decoder(hidden_states) + return hidden_states + + # implement a chunked forward function + def forward(self, hidden_states): + return apply_chunking_to_forward(self.chunk_size_lm_head, self.seq_len_dim, self.forward_chunk, hidden_states) """ assert len(input_tensors) > 0, "{} has to be a tuple/list of tensors".format(input_tensors)