Commit
Fix cache_position initialisation for generation with `use_cache=False` (#30485)

* Fix cache_position init for generation

* Update src/transformers/generation/utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Fix cache position update

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
nurlanov-zh and ArthurZucker authored May 7, 2024
1 parent 54a2361 commit 4fda78c
Showing 1 changed file with 9 additions and 1 deletion.
src/transformers/generation/utils.py (9 additions, 1 deletion)
@@ -667,7 +667,11 @@ def _update_model_kwargs_for_generation(
                     dim=-1,
                 )
 
-        if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None:
+        if (
+            model_kwargs.get("use_cache", True)
+            and "cache_position" in model_kwargs
+            and model_kwargs["cache_position"] is not None
+        ):
             model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
 
         return model_kwargs
@@ -1293,6 +1297,10 @@ def _prepare_generation_config(
 
     def _get_initial_cache_position(self, input_ids, model_kwargs):
         """Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length"""
+        if not model_kwargs.get("use_cache", True):
+            model_kwargs["cache_position"] = None
+            return model_kwargs
+
         past_length = 0
         if "past_key_values" in model_kwargs:
             if isinstance(model_kwargs["past_key_values"], Cache):
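
For context, a minimal usage sketch of the code path this commit fixes (the checkpoint name and prompt are illustrative, not taken from the commit): when the KV cache is disabled, `_get_initial_cache_position` now leaves `cache_position` as `None` at pre-fill, and `_update_model_kwargs_for_generation` skips the per-step position update.

from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint; any decoder-only model exercises the same code path.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Hello, world", return_tensors="pt")

# With use_cache=False the model recomputes attention over the full sequence
# at every step, so there is no cache whose position needs tracking.
output_ids = model.generate(**inputs, max_new_tokens=8, use_cache=False)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))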
