diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py
index c8a83dbe592f..ee51bcea794e 100644
--- a/src/transformers/models/mistral/modeling_mistral.py
+++ b/src/transformers/models/mistral/modeling_mistral.py
@@ -363,6 +363,12 @@ def forward(
 
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
+            if self.layer_idx is None:
+                raise ValueError(
+                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+                    "with a layer index."
+                )
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
         # Because the input can be padded, the absolute sequence length depends on the max position id.
@@ -385,11 +391,16 @@ def forward(
 
         if past_key_value is not None:
             # Activate slicing cache only if the config has a value `sliding_windows` attribute
-            if getattr(self.config, "sliding_window", None) is not None and kv_seq_len > self.config.sliding_window:
+            cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
+            if (
+                getattr(self.config, "sliding_window", None) is not None
+                and kv_seq_len > self.config.sliding_window
+                and cache_has_contents
+            ):
                 slicing_tokens = 1 - self.config.sliding_window
 
-                past_key = past_key_value[0]
-                past_value = past_key_value[1]
+                past_key = past_key_value[self.layer_idx][0]
+                past_value = past_key_value[self.layer_idx][1]
 
                 past_key = past_key[:, :, slicing_tokens:, :].contiguous()
                 past_value = past_value[:, :, slicing_tokens:, :].contiguous()
@@ -400,8 +411,6 @@ def forward(
                         f" {past_key.shape}"
                     )
 
-                past_key_value = (past_key, past_value)
-
                 if attention_mask is not None:
                     attention_mask = attention_mask[:, slicing_tokens:]
                     attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py
index 7268673441fe..3ae8567317ef 100644
--- a/src/transformers/models/mixtral/modeling_mixtral.py
+++ b/src/transformers/models/mixtral/modeling_mixtral.py
@@ -414,6 +414,12 @@ def forward(
 
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
+            if self.layer_idx is None:
+                raise ValueError(
+                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+                    "with a layer index."
+                )
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
         # Because the input can be padded, the absolute sequence length depends on the max position id.
@@ -436,11 +442,16 @@ def forward(
 
         if past_key_value is not None:
             # Activate slicing cache only if the config has a value `sliding_windows` attribute
-            if getattr(self.config, "sliding_window", None) is not None and kv_seq_len > self.config.sliding_window:
+            cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
+            if (
+                getattr(self.config, "sliding_window", None) is not None
+                and kv_seq_len > self.config.sliding_window
+                and cache_has_contents
+            ):
                 slicing_tokens = 1 - self.config.sliding_window
 
-                past_key = past_key_value[0]
-                past_value = past_key_value[1]
+                past_key = past_key_value[self.layer_idx][0]
+                past_value = past_key_value[self.layer_idx][1]
 
                 past_key = past_key[:, :, slicing_tokens:, :].contiguous()
                 past_value = past_value[:, :, slicing_tokens:, :].contiguous()
@@ -451,8 +462,6 @@ def forward(
                         f" {past_key.shape}"
                     )
 
-                past_key_value = (past_key, past_value)
-
                 if attention_mask is not None:
                     attention_mask = attention_mask[:, slicing_tokens:]
                     attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
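For context, both files are updated for the layer-indexed `Cache` objects introduced in transformers v4.36 (e.g. `DynamicCache`): `past_key_value[layer_idx]` yields that layer's `(key, value)` tensors and `get_seq_length(layer_idx)` reports how many tokens are already cached, which is what the new `cache_has_contents` guard checks before slicing. The snippet below is a minimal sketch of that access pattern; the shapes, sizes, and stand-alone `DynamicCache` usage are illustrative and not taken from the diff.

```python
# Minimal sketch (not from the diff): how the layer-indexed cache API behaves,
# assuming transformers >= 4.36 where DynamicCache is available.
import torch
from transformers import DynamicCache

batch, num_heads, head_dim = 1, 8, 64
cache = DynamicCache()

# Simulate a 4-token prefill followed by one decoding step for layer 0.
for step_len in (4, 1):
    key = torch.randn(batch, num_heads, step_len, head_dim)
    value = torch.randn(batch, num_heads, step_len, head_dim)
    cache.update(key, value, layer_idx=0)

# The attention layers now index the cache per layer instead of unpacking
# past_key_value as a plain (key, value) tuple.
past_key, past_value = cache[0]
print(past_key.shape)            # torch.Size([1, 8, 5, 64])
print(cache.get_seq_length(0))   # 5 -> non-zero, so cache_has_contents is True
```

The `cache_has_contents` check matters because a freshly created cache can be passed in before any tokens have been written for this layer; skipping the slicing branch in that case avoids operating on an empty per-layer entry.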