diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index b167cd1d1170..d572b8c8c716 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -1083,7 +1083,7 @@ def get_max_length(self) -> Optional[int]:
         # no matter how long the sentence is
         return self.max_cache_len
 
-    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+    def get_seq_length(self, layer_idx: Optional[int] = 0):
         return None
 
     def reset(self):
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 2686f3af7af3..9c69bb35d264 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1399,7 +1399,7 @@ def _get_initial_cache_position(self, input_ids, model_kwargs):
             cache = model_kwargs["past_key_values"]
             if not isinstance(cache, Cache):
                 past_length = cache[0][0].shape[2]
-            elif hasattr(cache, "get_seq_length"):
+            elif hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None:
                 past_length = cache.get_seq_length()
 
         if "inputs_embeds" in model_kwargs:
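
A minimal, self-contained sketch (not part of the patch) of why the added "is not None" guard matters: with a cache whose get_seq_length() returns None, as in the cache_utils.py hunk above, the old hasattr-only check would assign past_length = None and the subsequent torch.arange call in _get_initial_cache_position would fail. DummyNoneLengthCache and initial_cache_position below are illustrative stand-ins, not library code.

    import torch

    class DummyNoneLengthCache:
        def get_seq_length(self, layer_idx=0):
            # Mirrors the cache_utils.py change above: no usable length to report.
            return None

    def initial_cache_position(cache, cur_len):
        # Simplified version of the patched condition in _get_initial_cache_position.
        past_length = 0
        if hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None:
            past_length = cache.get_seq_length()
        # With the old condition, past_length would be None here and arange would raise.
        return torch.arange(past_length, cur_len)

    print(initial_cache_position(DummyNoneLengthCache(), cur_len=4))  # tensor([0, 1, 2, 3])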