From 7d638f7e89a40e79fdf3194877d032624cb77ca6 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 10 Jul 2024 18:13:45 +0000 Subject: [PATCH 1/9] tmp commit --- src/transformers/generation/utils.py | 19 +++-- .../models/llama/modeling_llama.py | 78 ++++++++++--------- 2 files changed, 52 insertions(+), 45 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 6b4a055fba8d..6c1e257482cd 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -689,13 +689,16 @@ def _update_model_kwargs_for_generation( dim=-1, ) - if ( - model_kwargs.get("use_cache", True) - and "cache_position" in model_kwargs - and model_kwargs["cache_position"] is not None - ): + if model_kwargs.get("use_cache", True): model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens - + else: + new_positions = torch.arange( + model_kwargs["cache_position"][-1], + model_kwargs["cache_position"][-1] + num_new_tokens, + device=model_kwargs["cache_position"].device, + dtype=model_kwargs["cache_position"].dtype + ) + model_kwargs["cache_position"] = torch.cat((model_kwargs["cache_position"], new_positions)) return model_kwargs def _reorder_cache(self, past_key_values, beam_idx): @@ -1393,10 +1396,6 @@ def _prepare_generation_config( def _get_initial_cache_position(self, input_ids, model_kwargs): """Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length""" - if not model_kwargs.get("use_cache", True): - model_kwargs["cache_position"] = None - return model_kwargs - past_length = 0 if model_kwargs.get("past_key_values") is not None: cache = model_kwargs["past_key_values"] diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 6a7a7145ba4b..10c2b4aaee42 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -1123,35 +1123,43 @@ def prepare_inputs_for_generation( use_cache=True, **kwargs, ): - past_length = 0 - if past_key_values is not None: - # Past key values are always initialized with a `Cache` object -> no need for if-else anymore - past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() - max_cache_length = ( - torch.tensor(past_key_values.get_max_length(), device=input_ids.device) - if past_key_values.get_max_length() is not None - else None - ) - cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 
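As a quick illustration of the `_update_model_kwargs_for_generation` hunk above: with `use_cache=True`, `cache_position` shrinks to a single entry after pre-fill and is simply advanced at every decoding step. A minimal standalone sketch (plain PyTorch, values invented for illustration, not library code):

import torch

# Pre-fill: a 5-token prompt with nothing cached yet -> positions 0..4
cache_position = torch.arange(0, 5)

# After the pre-fill forward pass one token was generated; with a cache, only the
# next position is needed for the following step (the `use_cache` branch above)
num_new_tokens = 1
cache_position = cache_position[-1:] + num_new_tokens   # tensor([5])

# every later decoding step keeps advancing the one-element tensor
cache_position = cache_position[-1:] + num_new_tokens   # tensor([6])

Without a cache, the new `else` branch instead concatenates the freshly generated positions onto the existing tensor, so the model keeps receiving positions for the full sequence it re-reads.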
- if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + input_ids.shape[1] > max_cache_length - ): - attention_mask = attention_mask[:, -max_cache_length:] + # If we have cache + cache positions: let's slice `input_ids` accordingly, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + if past_key_values is not None and input_ids.shape[1] != cache_position.shape[0]: + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0]:] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + + + # # Past key values are always initialized with a `Cache` object -> no need for if-else anymore + # past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + # max_cache_length = ( + # torch.tensor(past_key_values.get_max_length(), device=input_ids.device) + # if past_key_values.get_max_length() is not None + # else None + # ) + # cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) + + # # Keep only the unprocessed tokens: + # # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input) + # if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + # input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # # input_ids based on the past_length. + # elif past_length < input_ids.shape[1]: + # input_ids = input_ids[:, past_length:] + # # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + # if ( + # max_cache_length is not None + # and attention_mask is not None + # and cache_length + input_ids.shape[1] > max_cache_length + # ): + # attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: @@ -1162,7 +1170,7 @@ def prepare_inputs_for_generation( position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_length == 0: + if inputs_embeds is not None and cache_position[0] == 0: model_inputs = {"inputs_embeds": inputs_embeds} else: # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise @@ -1170,11 +1178,11 @@ def prepare_inputs_for_generation( # TODO: use `next_tokens` directly instead. 
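The `contiguous()` comment kept here is worth unpacking: slicing out the last token produces a view whose stride carries the prompt length, and torchdynamo guards on strides, so every new prompt length would otherwise trigger a recompile. A tiny standalone demonstration (not library code):

import torch

ids = torch.arange(12).reshape(2, 6)        # pretend: batch of 2, current sequence length 6
last = ids[:, -1:]                          # view onto the original storage
print(last.is_contiguous(), last.stride())  # False (6, 1)  -> stride depends on the sequence length
print(last.contiguous().stride())           # (1, 1)        -> identical for any sequence length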
model_inputs = {"input_ids": input_ids.contiguous()} - input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] - if cache_position is None: - cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) - elif use_cache: - cache_position = cache_position[-input_length:] + # input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] + # if cache_position is None: + # cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) + # if use_cache: + # cache_position = cache_position[-input_length:] model_inputs.update( { From 32e7aa93a3d02c7c418567056ba0a7d37e2282ba Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 10 Jul 2024 18:46:31 +0000 Subject: [PATCH 2/9] shorter --- src/transformers/generation/utils.py | 2 +- .../models/llama/modeling_llama.py | 63 ++----------------- tests/utils/test_cache_utils.py | 4 +- 3 files changed, 8 insertions(+), 61 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 6c1e257482cd..d47169e8d9b4 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -696,7 +696,7 @@ def _update_model_kwargs_for_generation( model_kwargs["cache_position"][-1], model_kwargs["cache_position"][-1] + num_new_tokens, device=model_kwargs["cache_position"].device, - dtype=model_kwargs["cache_position"].dtype + dtype=model_kwargs["cache_position"].dtype, ) model_kwargs["cache_position"] = torch.cat((model_kwargs["cache_position"], new_positions)) return model_kwargs diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 10c2b4aaee42..c4fd7ea89798 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -1114,53 +1114,17 @@ def forward( ) def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - cache_position=None, - use_cache=True, - **kwargs, + self, input_ids, cache_position, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): - # If we have cache + cache positions: let's slice `input_ids` accordingly, to keep only the unprocessed tokens + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens # Exception 1: when passing input_embeds, input_ids may be missing entries # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here if past_key_values is not None and input_ids.shape[1] != cache_position.shape[0]: if inputs_embeds is not None: # Exception 1 - input_ids = input_ids[:, -cache_position.shape[0]:] + input_ids = input_ids[:, -cache_position.shape[0] :] elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) input_ids = input_ids[:, cache_position] - - # # Past key values are always initialized with a `Cache` object -> no need for if-else anymore - # past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() - # max_cache_length = ( - # torch.tensor(past_key_values.get_max_length(), device=input_ids.device) - # if past_key_values.get_max_length() is not None - # else None - # ) - # cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) - - # # Keep only the unprocessed tokens: - # # 1 - If the length of the 
attention_mask exceeds the length of input_ids, then we are in a setting where - # # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input) - # if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - # input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # # input_ids based on the past_length. - # elif past_length < input_ids.shape[1]: - # input_ids = input_ids[:, past_length:] - # # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. - # if ( - # max_cache_length is not None - # and attention_mask is not None - # and cache_length + input_ids.shape[1] > max_cache_length - # ): - # attention_mask = attention_mask[:, -max_cache_length:] - position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation @@ -1173,37 +1137,20 @@ def prepare_inputs_for_generation( if inputs_embeds is not None and cache_position[0] == 0: model_inputs = {"inputs_embeds": inputs_embeds} else: - # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise - # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114 - # TODO: use `next_tokens` directly instead. + # The `contiguous()` here is necessary to have a static stride during decoding model_inputs = {"input_ids": input_ids.contiguous()} - # input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] - # if cache_position is None: - # cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) - # if use_cache: - # cache_position = cache_position[-input_length:] - model_inputs.update( { "position_ids": position_ids, "cache_position": cache_position, "past_key_values": past_key_values, - "use_cache": use_cache, + "use_cache": kwargs.get("use_cache", True), "attention_mask": attention_mask, } ) return model_inputs - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - @add_start_docstrings( """ diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index 6924bf482f11..b8366cc27765 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -35,8 +35,8 @@ AutoModelForCausalLM, AutoTokenizer, DynamicCache, + GPT2LMHeadModel, LlamaConfig, - LlamaForCausalLM, SinkCache, StaticCache, ) @@ -94,7 +94,7 @@ def test_dynamic_cache_retrocompatibility(self): def test_reorder_cache_retrocompatibility(self): """Tests that Cache.reorder_cache is retrocompatible with the legacy code path""" - legacy_reorder_fn = LlamaForCausalLM._reorder_cache # An example of a legacy `_reorder_cache` function + legacy_reorder_fn = GPT2LMHeadModel._reorder_cache # An example of a legacy `_reorder_cache` function legacy_cache = () new_cache = DynamicCache() From ab64f9fb6ccc336bd2024669da4ce151973faee9 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 10 Jul 2024 18:53:07 +0000 Subject: [PATCH 3/9] nit --- 
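The `cache_position`-based slicing introduced above is easy to exercise in isolation. A hedged sketch with invented token ids (the default case keeps exactly the positions the cache has not seen; the `inputs_embeds` exception falls back to a length-based slice):

import torch

input_ids = torch.tensor([[11, 12, 13, 14, 15]])        # full sequence tracked by `generate`
cache_position = torch.tensor([4])                      # only the last position is not cached yet

step_ids = input_ids[:, cache_position]                 # default case -> tensor([[15]])
embeds_case = input_ids[:, -cache_position.shape[0]:]   # Exception 1  -> tensor([[15]])

The implicit no-op `else` is Exception 2: generation methods that already do their own slicing of `input_ids`.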
src/transformers/models/llama/modeling_llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index c4fd7ea89798..d460557b3972 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -1119,7 +1119,7 @@ def prepare_inputs_for_generation( # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens # Exception 1: when passing input_embeds, input_ids may be missing entries # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here - if past_key_values is not None and input_ids.shape[1] != cache_position.shape[0]: + if past_key_values is not None: if inputs_embeds is not None: # Exception 1 input_ids = input_ids[:, -cache_position.shape[0] :] elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) From 5c0f6afbd478d5188d8b5aefd8a7838c27b8e005 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 11 Jul 2024 09:53:48 +0000 Subject: [PATCH 4/9] explicit kwargs --- src/transformers/models/llama/modeling_llama.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index d460557b3972..5c0c57f3effe 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -1114,7 +1114,15 @@ def forward( ) def prepare_inputs_for_generation( - self, input_ids, cache_position, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens # Exception 1: when passing input_embeds, input_ids may be missing entries @@ -1125,7 +1133,6 @@ def prepare_inputs_for_generation( elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) input_ids = input_ids[:, cache_position] - position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 @@ -1137,15 +1144,14 @@ def prepare_inputs_for_generation( if inputs_embeds is not None and cache_position[0] == 0: model_inputs = {"inputs_embeds": inputs_embeds} else: - # The `contiguous()` here is necessary to have a static stride during decoding - model_inputs = {"input_ids": input_ids.contiguous()} + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases model_inputs.update( { "position_ids": position_ids, "cache_position": cache_position, "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache", True), + "use_cache": use_cache, "attention_mask": attention_mask, } ) From e8e085b1913ee3f022f98d202d30ece5815fe288 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 11 Jul 2024 10:14:04 +0000 Subject: [PATCH 5/9] propagate changes --- .../models/cohere/modeling_cohere.py | 59 ++++--------------- src/transformers/models/dbrx/modeling_dbrx.py | 59 ++++--------------- .../models/idefics2/modeling_idefics2.py | 2 +- .../models/jetmoe/modeling_jetmoe.py | 10 ---- 
src/transformers/models/olmo/modeling_olmo.py | 59 ++++--------------- src/transformers/models/phi/modeling_phi.py | 10 ---- src/transformers/models/phi3/modeling_phi3.py | 10 ---- 7 files changed, 31 insertions(+), 178 deletions(-) diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index e8b94fafe2a1..5322c2334d37 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -1057,40 +1057,19 @@ def prepare_inputs_for_generation( attention_mask=None, inputs_embeds=None, cache_position=None, + position_ids=None, use_cache=True, **kwargs, ): - past_length = 0 + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here if past_key_values is not None: - # Past key values are always initialized with a `Cache` object -> no need for if-else anymore - past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() - max_cache_length = ( - torch.tensor(past_key_values.get_max_length(), device=input_ids.device) - if past_key_values.get_max_length() is not None - else None - ) - cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. - if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + input_ids.shape[1] > max_cache_length - ): - attention_mask = attention_mask[:, -max_cache_length:] + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] - position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 @@ -1099,19 +1078,10 @@ def prepare_inputs_for_generation( position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_length == 0: + if inputs_embeds is not None and cache_position[0] == 0: model_inputs = {"inputs_embeds": inputs_embeds} else: - # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise - # recompiles graphs as the stride of the inputs is a guard. 
Ref: https://github.com/huggingface/transformers/pull/29114 - # TODO: use `next_tokens` directly instead. - model_inputs = {"input_ids": input_ids.contiguous()} - - input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] - if cache_position is None: - cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) - elif use_cache: - cache_position = cache_position[-input_length:] + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases model_inputs.update( { @@ -1123,12 +1093,3 @@ def prepare_inputs_for_generation( } ) return model_inputs - - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 13e2b14830ee..31810028ef44 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -1330,40 +1330,19 @@ def prepare_inputs_for_generation( attention_mask=None, inputs_embeds=None, cache_position=None, + position_ids=None, use_cache=True, **kwargs, ): - past_length = 0 + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here if past_key_values is not None: - # Past key values are always initialized with a `Cache` object -> no need for if-else anymore - past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() - max_cache_length = ( - torch.tensor(past_key_values.get_max_length(), device=input_ids.device) - if past_key_values.get_max_length() is not None - else None - ) - cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 
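Throughout these hunks the on-the-fly `position_ids` construction from `attention_mask` is kept unchanged; for a left-padded batch it behaves like this (standalone sketch, mask values invented):

import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1],   # left-padded row
                               [1, 1, 1, 1, 1]])
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])

# during decoding (past_key_values present), only the positions matching the
# already-sliced `input_ids` are kept; here a single decoding token
position_ids = position_ids[:, -1:]               # tensor([[2], [4]])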
- if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + input_ids.shape[1] > max_cache_length - ): - attention_mask = attention_mask[:, -max_cache_length:] + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] - position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 @@ -1372,19 +1351,10 @@ def prepare_inputs_for_generation( position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_length == 0: + if inputs_embeds is not None and cache_position[0] == 0: model_inputs = {"inputs_embeds": inputs_embeds} else: - # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise - # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114 - # TODO: use `next_tokens` directly instead. - model_inputs = {"input_ids": input_ids.contiguous()} - - input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] - if cache_position is None: - cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) - elif use_cache: - cache_position = cache_position[-input_length:] + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases model_inputs.update( { @@ -1396,12 +1366,3 @@ def prepare_inputs_for_generation( } ) return model_inputs - - @staticmethod - def _reorder_cache(past_key_values: Cache, beam_idx: torch.LongTensor): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index 3f487d379ecd..4d978c053d3f 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -1697,7 +1697,7 @@ def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_ return model_kwargs @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM._reorder_cache + # Copied from transformers.models.opt.modeling_opt.OPTForCausalLM._reorder_cache def _reorder_cache(past_key_values, beam_idx): reordered_past = () for layer_past in past_key_values: diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index 40783b3524f0..93fa3d407224 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -1366,16 +1366,6 @@ def prepare_inputs_for_generation( ) return model_inputs - @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM._reorder_cache - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - 
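For context on the `_reorder_cache` deletions in this patch: the legacy staticmethod only re-indexes every cached tensor along the batch dimension so beam search can shuffle hypotheses, and the equivalent logic lives on the `Cache` classes (`Cache.reorder_cache`, exercised by the `test_reorder_cache_retrocompatibility` test touched in patch 2, which now uses `GPT2LMHeadModel._reorder_cache` as its legacy example). A self-contained sketch with an invented two-layer, tuple-style cache:

import torch

batch, heads, seq, dim = 4, 2, 3, 8
legacy_cache = tuple(                                   # (key, value) per layer, legacy tuple format
    (torch.randn(batch, heads, seq, dim), torch.randn(batch, heads, seq, dim))
    for _ in range(2)
)
beam_idx = torch.tensor([2, 2, 0, 1])                   # which beam each new hypothesis continues

# body of the removed staticmethod, applied to the invented cache above
reordered = tuple(
    tuple(state.index_select(0, beam_idx.to(state.device)) for state in layer_past)
    for layer_past in legacy_cache
)
assert torch.equal(reordered[0][0][0], legacy_cache[0][0][2])   # row 0 now holds beam 2's keys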
@add_start_docstrings( """ diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 9b7db4e1786d..4fd0c9268683 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -1098,40 +1098,19 @@ def prepare_inputs_for_generation( attention_mask=None, inputs_embeds=None, cache_position=None, + position_ids=None, use_cache=True, **kwargs, ): - past_length = 0 + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here if past_key_values is not None: - # Past key values are always initialized with a `Cache` object -> no need for if-else anymore - past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() - max_cache_length = ( - torch.tensor(past_key_values.get_max_length(), device=input_ids.device) - if past_key_values.get_max_length() is not None - else None - ) - cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. - if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + input_ids.shape[1] > max_cache_length - ): - attention_mask = attention_mask[:, -max_cache_length:] + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] - position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 @@ -1140,19 +1119,10 @@ def prepare_inputs_for_generation( position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_length == 0: + if inputs_embeds is not None and cache_position[0] == 0: model_inputs = {"inputs_embeds": inputs_embeds} else: - # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise - # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114 - # TODO: use `next_tokens` directly instead. 
- model_inputs = {"input_ids": input_ids.contiguous()} - - input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] - if cache_position is None: - cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) - elif use_cache: - cache_position = cache_position[-input_length:] + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases model_inputs.update( { @@ -1164,12 +1134,3 @@ def prepare_inputs_for_generation( } ) return model_inputs - - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 7e7600d8eba7..28029303cde8 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -1274,16 +1274,6 @@ def prepare_inputs_for_generation( ) return model_inputs - @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM._reorder_cache - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - @add_start_docstrings( """ diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index b8d1d0f3f39b..a593a45fadb8 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -1270,16 +1270,6 @@ def prepare_inputs_for_generation( ) return model_inputs - @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM._reorder_cache - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - @add_start_docstrings( """ From 50c8260f1ea675398e2b1203c5049b28d4b58e29 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 11 Jul 2024 11:30:22 +0000 Subject: [PATCH 6/9] mass propagation with a few manual touches (let's see how CI behaves) --- .../models/gemma/modeling_gemma.py | 59 +++------------ src/transformers/models/gemma2/diff_gemma2.py | 74 ------------------- .../models/gemma2/modeling_gemma2.py | 61 +++------------ .../models/jamba/modeling_jamba.py | 40 ++++------ .../models/jetmoe/modeling_jetmoe.py | 65 ++++------------ .../models/mistral/modeling_mistral.py | 68 +++-------------- .../models/mixtral/modeling_mixtral.py | 62 +++------------- .../models/paligemma/modeling_paligemma.py | 53 ++++--------- .../models/persimmon/modeling_persimmon.py | 60 +++------------ src/transformers/models/phi/modeling_phi.py | 52 +++---------- src/transformers/models/phi3/modeling_phi3.py | 52 +++---------- .../models/qwen2/modeling_qwen2.py | 61 +++------------ .../models/qwen2_moe/modeling_qwen2_moe.py | 61 +++------------ .../models/stablelm/modeling_stablelm.py | 60 +++------------ .../models/starcoder2/modeling_starcoder2.py | 61 +++------------ .../models/whisper/modeling_whisper.py | 9 --- 16 files changed, 167 insertions(+), 731 deletions(-) diff --git 
a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index a78869cb1976..80e97fe700b5 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -1067,40 +1067,19 @@ def prepare_inputs_for_generation( attention_mask=None, inputs_embeds=None, cache_position=None, + position_ids=None, use_cache=True, **kwargs, ): - past_length = 0 + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here if past_key_values is not None: - # Past key values are always initialized with a `Cache` object -> no need for if-else anymore - past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() - max_cache_length = ( - torch.tensor(past_key_values.get_max_length(), device=input_ids.device) - if past_key_values.get_max_length() is not None - else None - ) - cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. - if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + input_ids.shape[1] > max_cache_length - ): - attention_mask = attention_mask[:, -max_cache_length:] + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] - position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 @@ -1109,19 +1088,10 @@ def prepare_inputs_for_generation( position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_length == 0: + if inputs_embeds is not None and cache_position[0] == 0: model_inputs = {"inputs_embeds": inputs_embeds} else: - # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise - # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114 - # TODO: use `next_tokens` directly instead. 
- model_inputs = {"input_ids": input_ids.contiguous()} - - input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] - if cache_position is None: - cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) - elif use_cache: - cache_position = cache_position[-input_length:] + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases model_inputs.update( { @@ -1134,15 +1104,6 @@ def prepare_inputs_for_generation( ) return model_inputs - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - @add_start_docstrings( """ diff --git a/src/transformers/models/gemma2/diff_gemma2.py b/src/transformers/models/gemma2/diff_gemma2.py index bdf97a0f02ea..0e300c6337e2 100644 --- a/src/transformers/models/gemma2/diff_gemma2.py +++ b/src/transformers/models/gemma2/diff_gemma2.py @@ -569,80 +569,6 @@ def forward( attentions=outputs.attentions, ) - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - cache_position=None, - use_cache=True, - **kwargs, - ): - past_length = 0 - if past_key_values is not None: - # Past key values are always initialized with a `Cache` object -> no need for if-else anymore - past_length = cache_position[0] if cache_position is not None else torch.tensor(0, device=input_ids.device) - max_cache_length = ( - torch.tensor(past_key_values.get_max_length(), device=input_ids.device) - if past_key_values.get_max_length() is not None - else None - ) - cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 
- if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + input_ids.shape[1] > max_cache_length - ): - attention_mask = attention_mask[:, -max_cache_length:] - - position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_length == 0: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise - # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114 - # TODO: use `next_tokens` directly instead. - model_inputs = {"input_ids": input_ids.contiguous()} - - input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] - if cache_position is None: - cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) - elif use_cache: - cache_position = cache_position[-input_length:] - - model_inputs.update( - { - "position_ids": position_ids, - "cache_position": cache_position, - "past_key_values": past_key_values, - "use_cache": use_cache, - "attention_mask": attention_mask, - } - ) - return model_inputs - class Gemma2ForSequenceClassification(GemmaForSequenceClassification): pass diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 063b38cbfb70..10d00fa460ba 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -994,40 +994,19 @@ def prepare_inputs_for_generation( attention_mask=None, inputs_embeds=None, cache_position=None, + position_ids=None, use_cache=True, **kwargs, ): - past_length = 0 + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here if past_key_values is not None: - # Past key values are always initialized with a `Cache` object -> no need for if-else anymore - past_length = cache_position[0] if cache_position is not None else torch.tensor(0, device=input_ids.device) - max_cache_length = ( - torch.tensor(past_key_values.get_max_length(), device=input_ids.device) - if past_key_values.get_max_length() is not None - else None - ) - cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. 
- elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. - if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + input_ids.shape[1] > max_cache_length - ): - attention_mask = attention_mask[:, -max_cache_length:] - - position_ids = kwargs.get("position_ids", None) + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 @@ -1036,19 +1015,10 @@ def prepare_inputs_for_generation( position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_length == 0: + if inputs_embeds is not None and cache_position[0] == 0: model_inputs = {"inputs_embeds": inputs_embeds} else: - # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise - # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114 - # TODO: use `next_tokens` directly instead. - model_inputs = {"input_ids": input_ids.contiguous()} - - input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] - if cache_position is None: - cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) - elif use_cache: - cache_position = cache_position[-input_length:] + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases model_inputs.update( { @@ -1061,15 +1031,6 @@ def prepare_inputs_for_generation( ) return model_inputs - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - @add_start_docstrings( """ diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index 9b1fc301a312..5682c53aea19 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -1544,39 +1544,25 @@ def prepare_inputs_for_generation( inputs_embeds=None, output_router_logits=False, cache_position=None, + position_ids=None, + use_cache=True, **kwargs, ): empty_past_kv = past_key_values is None - # Omit tokens covered by past_key_values - if not empty_past_kv: - past_length = cache_position[0] if cache_position is not None else attention_mask.shape[1] - max_cache_length = self.config.sliding_window - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. 
when passing input_embeds as - # input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. - if ( - max_cache_length is not None - and attention_mask is not None - and past_length + input_ids.shape[1] > max_cache_length - ): - attention_mask = attention_mask[:, -max_cache_length:] + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + if past_key_values is not None: + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] else: past_key_values = HybridMambaAttentionDynamicCache( self.config, input_ids.shape[0], self.dtype, device=self.device ) - position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 @@ -1585,16 +1571,16 @@ def prepare_inputs_for_generation( position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and empty_past_kv: + if inputs_embeds is not None and cache_position[0] == 0: model_inputs = {"inputs_embeds": inputs_embeds} else: - model_inputs = {"input_ids": input_ids} + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases model_inputs.update( { "position_ids": position_ids, "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), + "use_cache": use_cache, "attention_mask": attention_mask, "output_router_logits": output_router_logits, "num_logits_to_keep": self.config.num_logits_to_keep, diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index 93fa3d407224..16d8335e0a52 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -1280,6 +1280,7 @@ def forward( router_logits=outputs.router_logits, ) + # Copied from transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM.prepare_inputs_for_generation def prepare_inputs_for_generation( self, input_ids, @@ -1288,47 +1289,19 @@ def prepare_inputs_for_generation( inputs_embeds=None, cache_position=None, output_router_logits=False, + position_ids=None, + use_cache=True, **kwargs, ): - # With static cache, the `past_key_values` is None - # TODO joao: standardize interface for the different Cache classes and remove of this if - has_static_cache = False - if past_key_values is None: - past_key_values = getattr(getattr(self.model.layers[0], "self_attn", {}), "past_key_value", None) - has_static_cache = past_key_values 
is not None - - past_length = 0 + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here if past_key_values is not None: - # Past key values are always initialized with a `Cache` object -> no need for if-else anymore - past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() - max_cache_length = ( - torch.tensor(past_key_values.get_max_length(), device=input_ids.device) - if past_key_values.get_max_length() is not None - else None - ) - cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. - if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + input_ids.shape[1] > max_cache_length - ): - attention_mask = attention_mask[:, -max_cache_length:] + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] - position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 @@ -1337,29 +1310,17 @@ def prepare_inputs_for_generation( position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_length == 0: + if inputs_embeds is not None and cache_position[0] == 0: model_inputs = {"inputs_embeds": inputs_embeds} else: - # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise - # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114 - # TODO: use `next_tokens` directly instead. 
- model_inputs = {"input_ids": input_ids.contiguous()} - - input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] - if cache_position is None: - cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) - else: - cache_position = cache_position[-input_length:] - - if has_static_cache: - past_key_values = None + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases model_inputs.update( { "position_ids": position_ids, "cache_position": cache_position, "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), + "use_cache": use_cache, "attention_mask": attention_mask, "output_router_logits": output_router_logits, } diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 4625fe8933b6..5bd74a71e772 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -1070,6 +1070,7 @@ def forward( attentions=outputs.attentions, ) + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation def prepare_inputs_for_generation( self, input_ids, @@ -1077,42 +1078,19 @@ def prepare_inputs_for_generation( attention_mask=None, inputs_embeds=None, cache_position=None, + position_ids=None, use_cache=True, **kwargs, ): - past_length = 0 - # Omit tokens covered by past_key_values + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here if past_key_values is not None: - # Past key values are always initialized with a `Cache` object -> no need for if-else anymore - past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() - max_cache_length = ( - torch.tensor(past_key_values.get_max_length(), device=input_ids.device) - if past_key_values.get_max_length() is not None - else None - ) - cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 
- if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + input_ids.shape[1] > max_cache_length - ): - attention_mask = attention_mask[:, -max_cache_length:] + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] - position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 @@ -1120,26 +1098,11 @@ def prepare_inputs_for_generation( if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] - # crop the attention_mask to sliding window size during decode phase if using SlidingWindowCache - if ( - past_length > 0 - and attention_mask is not None - and isinstance(past_key_values, SlidingWindowCache) - and attention_mask.shape[1] > past_key_values.max_cache_len - ): - attention_mask = attention_mask[:, -past_key_values.max_cache_len :] - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_length == 0: + if inputs_embeds is not None and cache_position[0] == 0: model_inputs = {"inputs_embeds": inputs_embeds} else: - model_inputs = {"input_ids": input_ids.contiguous()} - - input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] - if cache_position is None: - cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) - elif use_cache: - cache_position = cache_position[-input_length:] + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases model_inputs.update( { @@ -1152,15 +1115,6 @@ def prepare_inputs_for_generation( ) return model_inputs - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - @add_start_docstrings( """ diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 778e7a741dff..4b88afcded37 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -1286,44 +1286,21 @@ def prepare_inputs_for_generation( past_key_values=None, attention_mask=None, inputs_embeds=None, - output_router_logits=False, cache_position=None, + output_router_logits=False, + position_ids=None, use_cache=True, **kwargs, ): - past_length = 0 - # Omit tokens covered by past_key_values + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here if past_key_values is not None: - # Past key values are always initialized with a `Cache` object -> no need for if-else anymore - past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() - max_cache_length = ( - torch.tensor(past_key_values.get_max_length(), device=input_ids.device) - if past_key_values.get_max_length() is not None - else None - ) - cache_length = past_length if 
max_cache_length is None else torch.min(max_cache_length, past_length) - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. - if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + input_ids.shape[1] > max_cache_length - ): - attention_mask = attention_mask[:, -max_cache_length:] + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] - position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 @@ -1332,38 +1309,23 @@ def prepare_inputs_for_generation( position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_length == 0: + if inputs_embeds is not None and cache_position[0] == 0: model_inputs = {"inputs_embeds": inputs_embeds} else: - model_inputs = {"input_ids": input_ids} - - input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] - if cache_position is None: - cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) - elif use_cache: - cache_position = cache_position[-input_length:] + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases model_inputs.update( { "position_ids": position_ids, + "cache_position": cache_position, "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, "output_router_logits": output_router_logits, - "cache_position": cache_position, } ) return model_inputs - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - @add_start_docstrings( """ diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index a640e7c7465a..8a693e56f80b 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -518,46 +518,22 @@ def prepare_inputs_for_generation( past_key_values=None, inputs_embeds=None, cache_position=None, + position_ids=None, pixel_values=None, attention_mask=None, token_type_ids=None, + use_cache=True, **kwargs, ): - past_length = 0 + # If we have cache: let's slice 
`input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here if past_key_values is not None: - if isinstance(past_key_values, Cache): - past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() - max_cache_length = ( - torch.tensor(past_key_values.get_max_length(), device=input_ids.device) - if past_key_values.get_max_length() is not None - else None - ) - cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) - # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects - else: - cache_length = past_length = past_key_values[0][0].shape[2] - max_cache_length = None - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. - # here we need to recall past_length is num_image_tokens + previous input_ids. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - elif self.config.image_token_index in input_ids: - input_ids = input_ids[:, input_ids.shape[1] - 1 :] - # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the - # older attention values, as their corresponding values are not part of the input. 
- if cache_length < past_length and attention_mask is not None: - attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :] - - position_ids = kwargs.get("position_ids", None) + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 @@ -566,23 +542,20 @@ def prepare_inputs_for_generation( position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: + if inputs_embeds is not None and cache_position[0] == 0: model_inputs = {"inputs_embeds": inputs_embeds} else: - model_inputs = {"input_ids": input_ids} + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases model_inputs.update( { "position_ids": position_ids, "past_key_values": past_key_values, "cache_position": cache_position, - "use_cache": kwargs.get("use_cache"), + "use_cache": use_cache, "attention_mask": attention_mask, "pixel_values": pixel_values, "token_type_ids": token_type_ids, } ) return model_inputs - - def _reorder_cache(self, *args, **kwargs): - return self.language_model._reorder_cache(*args, **kwargs) diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index fc7bcb74f6bb..fc1b729fa654 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -918,6 +918,7 @@ def forward( attentions=outputs.attentions, ) + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation def prepare_inputs_for_generation( self, input_ids, @@ -925,41 +926,19 @@ def prepare_inputs_for_generation( attention_mask=None, inputs_embeds=None, cache_position=None, + position_ids=None, use_cache=True, **kwargs, ): - past_length = 0 + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here if past_key_values is not None: - # Past key values are always initialized with a `Cache` object -> no need for if-else anymore - past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() - max_cache_length = ( - torch.tensor(past_key_values.get_max_length(), device=input_ids.device) - if past_key_values.get_max_length() is not None - else None - ) - cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. 
We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. - if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + input_ids.shape[1] > max_cache_length - ): - attention_mask = attention_mask[:, -max_cache_length:] + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] - position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 @@ -968,37 +947,22 @@ def prepare_inputs_for_generation( position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_length == 0: + if inputs_embeds is not None and cache_position[0] == 0: model_inputs = {"inputs_embeds": inputs_embeds} else: - model_inputs = {"input_ids": input_ids} - - input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] - if cache_position is None: - cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) - elif use_cache: - cache_position = cache_position[-input_length:] + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases model_inputs.update( { "position_ids": position_ids, + "cache_position": cache_position, "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, - "cache_position": cache_position, } ) return model_inputs - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - @add_start_docstrings( """ diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 28029303cde8..7ad34a578083 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -1201,7 +1201,7 @@ def forward( attentions=outputs.attentions, ) - # Copied from transformers.models.persimmon.modeling_persimmon.PersimmonForCausalLM.prepare_inputs_for_generation + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation def prepare_inputs_for_generation( self, input_ids, @@ -1209,41 +1209,19 @@ def prepare_inputs_for_generation( attention_mask=None, inputs_embeds=None, cache_position=None, + position_ids=None, use_cache=True, **kwargs, ): - past_length = 0 + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here if past_key_values is not None: - # Past key values are always initialized with a `Cache` object -> no need for if-else anymore - past_length = 
cache_position[0] if cache_position is not None else past_key_values.get_seq_length() - max_cache_length = ( - torch.tensor(past_key_values.get_max_length(), device=input_ids.device) - if past_key_values.get_max_length() is not None - else None - ) - cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. - if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + input_ids.shape[1] > max_cache_length - ): - attention_mask = attention_mask[:, -max_cache_length:] + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] - position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 @@ -1252,24 +1230,18 @@ def prepare_inputs_for_generation( position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_length == 0: + if inputs_embeds is not None and cache_position[0] == 0: model_inputs = {"inputs_embeds": inputs_embeds} else: - model_inputs = {"input_ids": input_ids} - - input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] - if cache_position is None: - cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) - elif use_cache: - cache_position = cache_position[-input_length:] + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases model_inputs.update( { "position_ids": position_ids, + "cache_position": cache_position, "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, - "cache_position": cache_position, } ) return model_inputs diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index a593a45fadb8..b7d05bbed6ca 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -1197,7 +1197,7 @@ def forward( attentions=outputs.attentions, ) - # Copied from transformers.models.persimmon.modeling_persimmon.PersimmonForCausalLM.prepare_inputs_for_generation + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation def prepare_inputs_for_generation( self, input_ids, @@ -1205,41 +1205,19 @@ def 
prepare_inputs_for_generation( attention_mask=None, inputs_embeds=None, cache_position=None, + position_ids=None, use_cache=True, **kwargs, ): - past_length = 0 + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here if past_key_values is not None: - # Past key values are always initialized with a `Cache` object -> no need for if-else anymore - past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() - max_cache_length = ( - torch.tensor(past_key_values.get_max_length(), device=input_ids.device) - if past_key_values.get_max_length() is not None - else None - ) - cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 
- if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + input_ids.shape[1] > max_cache_length - ): - attention_mask = attention_mask[:, -max_cache_length:] + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] - position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 @@ -1248,24 +1226,18 @@ def prepare_inputs_for_generation( position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_length == 0: + if inputs_embeds is not None and cache_position[0] == 0: model_inputs = {"inputs_embeds": inputs_embeds} else: - model_inputs = {"input_ids": input_ids} - - input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] - if cache_position is None: - cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) - elif use_cache: - cache_position = cache_position[-input_length:] + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases model_inputs.update( { "position_ids": position_ids, + "cache_position": cache_position, "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, - "cache_position": cache_position, } ) return model_inputs diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index b58f256eb09a..68923ed4052d 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -1093,6 +1093,7 @@ def forward( attentions=outputs.attentions, ) + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation def prepare_inputs_for_generation( self, input_ids, @@ -1100,42 +1101,19 @@ def prepare_inputs_for_generation( attention_mask=None, inputs_embeds=None, cache_position=None, + position_ids=None, use_cache=True, **kwargs, ): - past_length = 0 - # Omit tokens covered by past_key_values + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here if past_key_values is not None: - # Past key values are always initialized with a `Cache` object -> no need for if-else anymore - past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() - max_cache_length = ( - torch.tensor(past_key_values.get_max_length(), device=input_ids.device) - if past_key_values.get_max_length() is not None - else None - ) - cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. 
when passing input_embeds as - # input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. - if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + input_ids.shape[1] > max_cache_length - ): - attention_mask = attention_mask[:, -max_cache_length:] + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] - position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 @@ -1144,37 +1122,22 @@ def prepare_inputs_for_generation( position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_length == 0: + if inputs_embeds is not None and cache_position[0] == 0: model_inputs = {"inputs_embeds": inputs_embeds} else: - model_inputs = {"input_ids": input_ids} - - input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] - if cache_position is None: - cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) - elif use_cache: - cache_position = cache_position[-input_length:] + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases model_inputs.update( { "position_ids": position_ids, + "cache_position": cache_position, "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, - "cache_position": cache_position, } ) return model_inputs - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - @add_start_docstrings( """ diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 7e4fbc492568..c20d74fb18c4 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -1289,6 +1289,7 @@ def forward( router_logits=outputs.router_logits, ) + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation def prepare_inputs_for_generation( self, input_ids, @@ -1296,42 +1297,19 @@ def prepare_inputs_for_generation( attention_mask=None, inputs_embeds=None, cache_position=None, + position_ids=None, use_cache=True, **kwargs, ): - past_length = 0 - # Omit tokens covered by past_key_values + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may 
be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here if past_key_values is not None: - # Past key values are always initialized with a `Cache` object -> no need for if-else anymore - past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() - max_cache_length = ( - torch.tensor(past_key_values.get_max_length(), device=input_ids.device) - if past_key_values.get_max_length() is not None - else None - ) - cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. - if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + input_ids.shape[1] > max_cache_length - ): - attention_mask = attention_mask[:, -max_cache_length:] + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] - position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 @@ -1340,37 +1318,22 @@ def prepare_inputs_for_generation( position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_length == 0: + if inputs_embeds is not None and cache_position[0] == 0: model_inputs = {"inputs_embeds": inputs_embeds} else: - model_inputs = {"input_ids": input_ids} - - input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] - if cache_position is None: - cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) - elif use_cache: - cache_position = cache_position[-input_length:] + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases model_inputs.update( { "position_ids": position_ids, + "cache_position": cache_position, "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, - "cache_position": cache_position, } ) return model_inputs - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - @add_start_docstrings( """ diff --git a/src/transformers/models/stablelm/modeling_stablelm.py 
b/src/transformers/models/stablelm/modeling_stablelm.py index 2346325d99ec..a17218361802 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -1194,6 +1194,7 @@ def forward( attentions=outputs.attentions, ) + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation def prepare_inputs_for_generation( self, input_ids, @@ -1201,41 +1202,19 @@ def prepare_inputs_for_generation( attention_mask=None, inputs_embeds=None, cache_position=None, + position_ids=None, use_cache=True, **kwargs, ): - past_length = 0 + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here if past_key_values is not None: - # Past key values are always initialized with a `Cache` object -> no need for if-else anymore - past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() - max_cache_length = ( - torch.tensor(past_key_values.get_max_length(), device=input_ids.device) - if past_key_values.get_max_length() is not None - else None - ) - cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 
- if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + input_ids.shape[1] > max_cache_length - ): - attention_mask = attention_mask[:, -max_cache_length:] + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] - position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 @@ -1244,37 +1223,22 @@ def prepare_inputs_for_generation( position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_length == 0: + if inputs_embeds is not None and cache_position[0] == 0: model_inputs = {"inputs_embeds": inputs_embeds} else: - model_inputs = {"input_ids": input_ids} - - input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] - if cache_position is None: - cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) - elif use_cache: - cache_position = cache_position[-input_length:] + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases model_inputs.update( { "position_ids": position_ids, + "cache_position": cache_position, "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, - "cache_position": cache_position, } ) return model_inputs - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - @add_start_docstrings( """ diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 26fb21bbbf75..430befd24ae3 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -1072,6 +1072,7 @@ def forward( attentions=outputs.attentions, ) + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation def prepare_inputs_for_generation( self, input_ids, @@ -1079,42 +1080,19 @@ def prepare_inputs_for_generation( attention_mask=None, inputs_embeds=None, cache_position=None, + position_ids=None, use_cache=True, **kwargs, ): - past_length = 0 - # Omit tokens covered by past_key_values + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here if past_key_values is not None: - # Past key values are always initialized with a `Cache` object -> no need for if-else anymore - past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() - max_cache_length = ( - torch.tensor(past_key_values.get_max_length(), device=input_ids.device) - if past_key_values.get_max_length() is not None - else None - ) - cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) - - # Keep 
only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. - if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + input_ids.shape[1] > max_cache_length - ): - attention_mask = attention_mask[:, -max_cache_length:] + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] - position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 @@ -1123,37 +1101,22 @@ def prepare_inputs_for_generation( position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_length == 0: + if inputs_embeds is not None and cache_position[0] == 0: model_inputs = {"inputs_embeds": inputs_embeds} else: - model_inputs = {"input_ids": input_ids} - - input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] - if cache_position is None: - cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) - elif use_cache: - cache_position = cache_position[-input_length:] + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases model_inputs.update( { "position_ids": position_ids, + "cache_position": cache_position, "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, - "cache_position": cache_position, } ) return model_inputs - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - @add_start_docstrings( """ diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 91d630fea6cd..7ba2af00ad81 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -1811,15 +1811,6 @@ def prepare_inputs_for_generation( "cache_position": cache_position, } - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - class WhisperDecoderWrapper(WhisperPreTrainedModel): """ From 
d0ff984e762020a982a8bd64ffff17da0c85a7db Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Thu, 11 Jul 2024 12:37:32 +0000
Subject: [PATCH 7/9] fix cacheless case

---
 src/transformers/generation/utils.py            | 4 ++--
 src/transformers/models/jamba/modeling_jamba.py | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index d47169e8d9b4..77e67e7b2b0b 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -693,8 +693,8 @@ def _update_model_kwargs_for_generation(
             model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
         else:
             new_positions = torch.arange(
-                model_kwargs["cache_position"][-1],
-                model_kwargs["cache_position"][-1] + num_new_tokens,
+                model_kwargs["cache_position"][-1] + 1,
+                model_kwargs["cache_position"][-1] + num_new_tokens + 1,
                 device=model_kwargs["cache_position"].device,
                 dtype=model_kwargs["cache_position"].dtype,
             )
diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py
index 5682c53aea19..768e8e016075 100755
--- a/src/transformers/models/jamba/modeling_jamba.py
+++ b/src/transformers/models/jamba/modeling_jamba.py
@@ -1553,7 +1553,7 @@ def prepare_inputs_for_generation(
         # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
         # Exception 1: when passing input_embeds, input_ids may be missing entries
         # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
-        if past_key_values is not None:
+        if not empty_past_kv:
             if inputs_embeds is not None:  # Exception 1
                 input_ids = input_ids[:, -cache_position.shape[0] :]
             elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                 input_ids = input_ids[:, cache_position]
@@ -1571,7 +1571,7 @@ def prepare_inputs_for_generation(
             position_ids = position_ids[:, -input_ids.shape[1] :]

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and cache_position[0] == 0:
+        if inputs_embeds is not None and empty_past_kv:
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
             model_inputs = {"input_ids": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases

From 8223ec4df0b440be3c2f3dac41be7209aaa3b9aa Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Fri, 12 Jul 2024 10:52:19 +0100
Subject: [PATCH 8/9] Update src/transformers/generation/utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
---
 src/transformers/generation/utils.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 77e67e7b2b0b..43988780a373 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -692,13 +692,11 @@ def _update_model_kwargs_for_generation(
         if model_kwargs.get("use_cache", True):
             model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
         else:
+            previous_positions = model_kwargs.pop("cache_position")
             new_positions = torch.arange(
-                model_kwargs["cache_position"][-1] + 1,
-                model_kwargs["cache_position"][-1] + num_new_tokens + 1,
-                device=model_kwargs["cache_position"].device,
-                dtype=model_kwargs["cache_position"].dtype,
+                previous_positions[-1] + 1, previous_positions[-1] + num_new_tokens + 1, device=previous_positions.device, dtype=previous_positions.dtype,
             )
-            model_kwargs["cache_position"] = torch.cat((model_kwargs["cache_position"], new_positions))
+            model_kwargs["cache_position"] = torch.cat((previous_positions, new_positions))
         return model_kwargs

     def _reorder_cache(self, past_key_values, beam_idx):

From 02f741735c025af4b64d45ce80df7073a24f0676 Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Fri, 12 Jul 2024 10:09:06 +0000
Subject: [PATCH 9/9] make fixup

---
 src/transformers/generation/utils.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 43988780a373..5c05328d0f2d 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -692,11 +692,11 @@ def _update_model_kwargs_for_generation(
         if model_kwargs.get("use_cache", True):
             model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
         else:
-            previous_positions = model_kwargs.pop("cache_position")
+            past_positions = model_kwargs.pop("cache_position")
             new_positions = torch.arange(
-                previous_positions[-1] + 1, previous_positions[-1] + num_new_tokens + 1, device=previous_positions.device, dtype=previous_positions.dtype,
-            )
-            model_kwargs["cache_position"] = torch.cat((previous_positions, new_positions))
+                past_positions[-1] + 1, past_positions[-1] + num_new_tokens + 1, dtype=past_positions.dtype
+            ).to(past_positions.device)
+            model_kwargs["cache_position"] = torch.cat((past_positions, new_positions))
         return model_kwargs

     def _reorder_cache(self, past_key_values, beam_idx):