Fix mask slicing for models with HybridCache #35681

Merged · 16 commits · Jan 28, 2025
src/transformers/models/cohere2/modeling_cohere2.py (7 additions, 7 deletions)
@@ -258,6 +258,11 @@ def forward(
             }
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

+        # Here we need to slice as we use a static cache by default, but FA2 does not support it
+        if self.config._attn_implementation == "flash_attention_2":
+            seq_len = attention_mask.shape[-1]
+            key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :]
+
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
             if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
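To see why the added slice is needed, here is a minimal sketch with dummy tensors (the shapes are illustrative, not the actual HybridCache API): the static cache hands back key/value tensors padded to its full pre-allocated length, while the FA2 mask is 2D with only seq_len columns, so keys and values must be cut down to match.

import torch

# Illustrative shapes only (not the real HybridCache API): a static cache
# is pre-allocated to max_cache_len, so update() returns fixed-length tensors.
bs, num_heads, head_dim = 2, 8, 64
max_cache_len, prefill_len = 4096, 10  # only 10 positions are actually filled

key_states = torch.zeros(bs, num_heads, max_cache_len, head_dim)
value_states = torch.zeros(bs, num_heads, max_cache_len, head_dim)
attention_mask = torch.ones(bs, prefill_len)  # FA2 masks are 2D: [bs, seq_len]

# The added slice: trim the static cache to the positions FA2 should see.
seq_len = attention_mask.shape[-1]
key_states = key_states[:, :, :seq_len, :]
value_states = value_states[:, :, :seq_len, :]
print(key_states.shape)  # torch.Size([2, 8, 10, 64])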
@@ -344,18 +349,13 @@ def forward(
         """

         if self.is_sliding and attention_mask is not None:  # efficient SDPA and no padding
-            # Flash-attn is a 2D tensor
-            if self.config._attn_implementation == "flash_attention_2":
-                if past_key_value is not None:  # when decoding
-                    attention_mask = attention_mask[:, -self.sliding_window :]
-            else:
+            # For FA2, the mask is 2D and is of shape [bs, seq_len] (not [bs, fixed_cache_len])
+            if self.config._attn_implementation != "flash_attention_2":
                 min_dtype = torch.finfo(hidden_states.dtype).min
                 sliding_window_mask = torch.tril(
                     torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
                 )
                 attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
-                if attention_mask.shape[-1] <= 1:  # when decoding
-                    attention_mask = attention_mask[:, :, :, -self.sliding_window :]

         residual = hidden_states

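For the non-FA2 branch that survives the diff, a toy rendering of the sliding-window masking (the window size and sequence length are made up): torch.tril with a negative diagonal marks keys that are sliding_window or more positions behind the query, and torch.where pushes them to the dtype minimum.

import torch

# Toy version of the kept non-FA2 branch; the sizes are made up.
sliding_window, seq_len = 2, 5
dtype = torch.float32
min_dtype = torch.finfo(dtype).min

# A 4D additive mask of zeros (everything visible) for one batch and head.
attention_mask = torch.zeros(1, 1, seq_len, seq_len, dtype=dtype)

# True exactly where the key is sliding_window or more positions behind the query.
sliding_window_mask = torch.tril(
    torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-sliding_window
)
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
# Row i now has min_dtype at columns j <= i - sliding_window:
print(attention_mask[0, 0])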
src/transformers/models/cohere2/modular_cohere2.py (7 additions, 7 deletions)
@@ -296,6 +296,11 @@ def forward(
             }
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

+        # Here we need to slice as we use a static cache by default, but FA2 does not support it
+        if self.config._attn_implementation == "flash_attention_2":
+            seq_len = attention_mask.shape[-1]
+            key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :]
+
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
             if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
@@ -363,18 +368,13 @@ def forward(
         """

         if self.is_sliding and attention_mask is not None:  # efficient SDPA and no padding
-            # Flash-attn is a 2D tensor
-            if self.config._attn_implementation == "flash_attention_2":
-                if past_key_value is not None:  # when decoding
-                    attention_mask = attention_mask[:, -self.sliding_window :]
-            else:
+            # For FA2, the mask is 2D and is of shape [bs, seq_len] (not [bs, fixed_cache_len])
+            if self.config._attn_implementation != "flash_attention_2":
                 min_dtype = torch.finfo(hidden_states.dtype).min
                 sliding_window_mask = torch.tril(
                     torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
                 )
                 attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
-                if attention_mask.shape[-1] <= 1:  # when decoding
-                    attention_mask = attention_mask[:, :, :, -self.sliding_window :]

         residual = hidden_states

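As an aside on the deleted FA2 branch above, a toy illustration (sizes invented) of what attention_mask[:, -self.sliding_window :] used to do during decoding: with a 2D [bs, past_len] mask, a sliding-window layer only ever attends to the last sliding_window keys, so the mask was trimmed to that suffix.

import torch

# Invented sizes: 9 past tokens, window of 4.
bs, past_len, sliding_window = 1, 9, 4
attention_mask = torch.ones(bs, past_len, dtype=torch.int32)

# The deleted decode-time branch kept only the last `sliding_window` columns,
# matching the window-limited key/value cache of a sliding layer.
trimmed = attention_mask[:, -sliding_window:]
print(trimmed.shape)  # torch.Size([1, 4])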
src/transformers/models/gemma2/modeling_gemma2.py (7 additions, 7 deletions)
@@ -223,6 +223,11 @@ def forward(
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

+        # Here we need to slice as we use a static cache by default, but FA2 does not support it
+        if self.config._attn_implementation == "flash_attention_2":
+            seq_len = attention_mask.shape[-1]
+            key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :]
+
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
             if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
@@ -278,18 +283,13 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         if self.is_sliding and attention_mask is not None:  # efficient SDPA and no padding
-            # Flash-attn is a 2D tensor
-            if self.config._attn_implementation == "flash_attention_2":
-                if past_key_value is not None:  # when decoding
-                    attention_mask = attention_mask[:, -self.sliding_window :]
-            else:
+            # For FA2, the mask is 2D and is of shape [bs, seq_len] (not [bs, fixed_cache_len])
+            if self.config._attn_implementation != "flash_attention_2":
                 min_dtype = torch.finfo(hidden_states.dtype).min
                 sliding_window_mask = torch.tril(
                     torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
                 )
                 attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
-                if attention_mask.shape[-1] <= 1:  # when decoding
-                    attention_mask = attention_mask[:, :, :, -self.sliding_window :]

         residual = hidden_states
Comment on lines -291 to -292

Collaborator: This is super important! Why is it removed?
I know it is counter-intuitive, but _flash_attention_forward takes the attention mask to pad/unpad the input itself. Thus you need the slicing, otherwise this operation fails; see the blame!

Member (Author): Indeed, I was too fast on this one; the HybridCache behaves slightly differently than I remembered. There was still an issue in the slicing during prefill for FA2, though!


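To make the thread above concrete, a hypothetical sketch of the failure mode (the flattening and gather mimic what an unpad step does; the names and shapes are ours, not the library's): indices of valid tokens are derived from the 2D mask, so they only line up with key/value rows when the key/value length equals the mask length.

import torch

# Hypothetical repro, not the library's code: an unpad step flattens
# keys to [bs * kv_len, ...] and gathers the rows the 2D mask marks valid.
bs, seq_len, max_cache_len, num_heads, head_dim = 2, 10, 4096, 8, 64
attention_mask = torch.ones(bs, seq_len, dtype=torch.int32)

# Valid-token row indices, computed from the mask: assumes kv_len == seq_len.
indices = torch.nonzero(attention_mask.reshape(-1)).reshape(-1)

keys_sliced = torch.zeros(bs, seq_len, num_heads, head_dim)        # with the fix
keys_static = torch.zeros(bs, max_cache_len, num_heads, head_dim)  # without it

ok = keys_sliced.reshape(bs * seq_len, num_heads, head_dim)[indices]
print(ok.shape)  # torch.Size([20, 8, 64]): one row per valid token

bad = keys_static.reshape(bs * max_cache_len, num_heads, head_dim)[indices]
# Same shape, but rows 10..19 now come from batch 0's padded cache instead of
# batch 1, so queries get paired with the wrong keys (or shapes blow up later).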
src/transformers/models/gemma2/modular_gemma2.py (7 additions, 7 deletions)
@@ -259,6 +259,11 @@ def forward(
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

+        # Here we need to slice as we use a static cache by default, but FA2 does not support it
+        if self.config._attn_implementation == "flash_attention_2":
+            seq_len = attention_mask.shape[-1]
+            key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :]
+
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
             if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
@@ -314,18 +319,13 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         if self.is_sliding and attention_mask is not None:  # efficient SDPA and no padding
-            # Flash-attn is a 2D tensor
-            if self.config._attn_implementation == "flash_attention_2":
-                if past_key_value is not None:  # when decoding
-                    attention_mask = attention_mask[:, -self.sliding_window :]
-            else:
+            # For FA2, the mask is 2D and is of shape [bs, seq_len] (not [bs, fixed_cache_len])
+            if self.config._attn_implementation != "flash_attention_2":
                 min_dtype = torch.finfo(hidden_states.dtype).min
                 sliding_window_mask = torch.tril(
                     torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
                 )
                 attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
-                if attention_mask.shape[-1] <= 1:  # when decoding
-                    attention_mask = attention_mask[:, :, :, -self.sliding_window :]

         residual = hidden_states

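If one wants to exercise the fixed path end to end, something like the following should do it (the checkpoint name and flags are illustrative assumptions; FA2 requires a CUDA GPU with flash-attn installed): Gemma 2 builds a HybridCache by default in generate, so a multi-token FA2 prefill goes through the new slicing.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumptions: any Gemma-2 checkpoint works; flash-attn 2 must be installed
# and a CUDA device available.
model_id = "google/gemma-2-2b"
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
).to("cuda")

# A multi-token prompt triggers an FA2 prefill against the static HybridCache,
# i.e. the code path this PR slices.
inputs = tok("The sliding window is", return_tensors="pt").to("cuda")
out = model.generate(**inputs, max_new_tokens=8)
print(tok.decode(out[0], skip_special_tokens=True))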