diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index 1f5a164815aa..200852b11b34 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -971,13 +971,14 @@ def update(
         return k_out, v_out
 
     def get_max_length(self) -> Optional[int]:
-        # in theory there is no limit because the sliding window size is fixed
-        # no matter how long the sentence is
+        # in theory there is no limit because the sliding window size is fixed no matter how long the sentence is
         return None
 
     def reset(self):
-        self.key_cache.zero_()
-        self.value_cache.zero_()
+        for layer_idx in range(len(self.key_cache)):
+            # In-place ops prevent breaking the static address
+            self.key_cache[layer_idx].zero_()
+            self.value_cache[layer_idx].zero_()
 
 
 class EncoderDecoderCache(Cache):