handle non-contiguous states
eaidova committed Feb 21, 2024
1 parent 9bf6c62 commit 4a03d87
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion src/transformers/models/stablelm/modeling_stablelm.py
@@ -445,12 +445,21 @@ def forward(
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and attention_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
            dropout_p=self.attention_dropout.p if self.training else 0.0,
            # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
            is_causal=self.is_causal and attention_mask is None and q_len > 1,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
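
The .contiguous() guard added above matters because query_states, key_states, and value_states reach this point as transposed views, which are not contiguous in memory. A minimal standalone sketch of the same workaround pattern (illustrative only, not part of the commit; tensor names and shapes are made up):

import torch
import torch.nn.functional as F

batch, heads, q_len, head_dim = 1, 8, 16, 64

# transpose() returns a view, so these tensors are non-contiguous,
# mirroring how the query/key/value states look inside the attention module
query = torch.randn(batch, q_len, heads, head_dim).transpose(1, 2)
key = torch.randn(batch, q_len, heads, head_dim).transpose(1, 2)
value = torch.randn(batch, q_len, heads, head_dim).transpose(1, 2)
mask = torch.zeros(batch, 1, q_len, q_len)

print(query.is_contiguous())  # False

# torch==2.1.2's memory-efficient SDPA backend is reported to misbehave on CUDA
# for non-contiguous inputs combined with a custom attn_mask
# (https://github.com/pytorch/pytorch/issues/112577), hence the dense copies
if query.device.type == "cuda" and mask is not None:
    query, key, value = query.contiguous(), key.contiguous(), value.contiguous()

out = F.scaled_dot_product_attention(query, key, value, attn_mask=mask)
print(out.shape)  # torch.Size([1, 8, 16, 64])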
@@ -947,7 +956,8 @@ def forward(
        if self._attn_implementation == "flash_attention_2":
            # 2d mask is passed through the layers
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        elif self._attn_implementation == "sdpa":
        # When output_attentions=True, SDPA cannot return attention weights, so fall back to the eager attention realization
        elif self._attn_implementation == "sdpa" and not output_attentions:
            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
                attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
            )
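
The narrowed elif in the hunk above ties into the mask preparation path: torch.nn.functional.scaled_dot_product_attention only returns the attention output and never materializes the attention-probability matrix for the caller, so when output_attentions is requested the model keeps the eager realization, which computes the weights explicitly. A hedged standalone sketch (not taken from the commit) of that difference:

import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

# SDPA path: fast fused kernels, but no attention weights are returned
sdpa_out = F.scaled_dot_product_attention(q, k, v)

# eager path: the weights are materialized, so they can be handed back to the caller
scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
attn_weights = scores.softmax(dim=-1)
eager_out = attn_weights @ v

print(torch.allclose(sdpa_out, eager_out, atol=1e-5))  # True, up to kernel numerics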