diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 7af79099a0bd..7f4caf26aeac 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -667,7 +667,11 @@ def _update_model_kwargs_for_generation(
                     dim=-1,
                 )
 
-        if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None:
+        if (
+            model_kwargs.get("use_cache", True)
+            and "cache_position" in model_kwargs
+            and model_kwargs["cache_position"] is not None
+        ):
             model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
 
         return model_kwargs
@@ -1293,6 +1297,10 @@ def _prepare_generation_config(
 
     def _get_initial_cache_position(self, input_ids, model_kwargs):
        """Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length"""
+        if not model_kwargs.get("use_cache", True):
+            model_kwargs["cache_position"] = None
+            return model_kwargs
+
        past_length = 0
        if "past_key_values" in model_kwargs:
            if isinstance(model_kwargs["past_key_values"], Cache):
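
Taken together, the two hunks make `use_cache=False` a consistent no-cache path: `_get_initial_cache_position` now returns early with `cache_position` set to `None`, and `_update_model_kwargs_for_generation` additionally gates the per-step `cache_position[-1:] + num_new_tokens` advance on `use_cache`, so the position is never tracked when caching is disabled. A minimal sketch of the call pattern this guards (the checkpoint name and generation settings are illustrative, not taken from the patch):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    inputs = tokenizer("Hello, world", return_tensors="pt")

    # With use_cache=False there are no past key/values to index into, so the
    # patched helpers keep model_kwargs["cache_position"] as None for the whole
    # decoding loop instead of computing and advancing a position tensor.
    out = model.generate(**inputs, max_new_tokens=5, use_cache=False)
    print(tokenizer.decode(out[0]))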