Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CLeanup] Revert SDPA attention changes that got in the static kv cache PR #29027

Merged
merged 2 commits into from
Feb 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 11 additions & 16 deletions src/transformers/models/mistral/modeling_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,34 +659,28 @@ def forward(
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

kv_seq_len = key_states.shape[-2]
past_key_value = getattr(self, "past_key_value", past_key_value)
if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) # add what was seen
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

past_seen_tokens = kv_seq_len - key_states.shape[-2]
new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=key_states.device)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "position_ids": new_cache_positions} # Specific to RoPE models
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

if (
attention_mask is not None and not torch.all(attention_mask[..., 0] == 1) and q_len != 1
): # user defined causal mask
causal_mask = attention_mask[:, :, past_seen_tokens : past_seen_tokens + q_len, : key_states.shape[-2]]
# this one liner is equivalent to the pad_unpad function
causal_mask.mul_(~torch.eq(causal_mask, causal_mask.min()).all(dim=-1)[..., None])
else:
causal_mask = None
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)

# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and causal_mask is not None:
if query_states.device.type == "cuda" and attention_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
Expand All @@ -695,9 +689,10 @@ def forward(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
attn_mask=attention_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=causal_mask is None and q_len > 1,
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
is_causal=self.is_causal and attention_mask is None and q_len > 1,
)

attn_output = attn_output.transpose(1, 2).contiguous()
Expand Down
27 changes: 11 additions & 16 deletions src/transformers/models/mixtral/modeling_mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,34 +736,28 @@ def forward(
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

kv_seq_len = key_states.shape[-2]
past_key_value = getattr(self, "past_key_value", past_key_value)
if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) # add what was seen
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

past_seen_tokens = kv_seq_len - key_states.shape[-2]
new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=key_states.device)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "position_ids": new_cache_positions} # Specific to RoPE models
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

if (
attention_mask is not None and not torch.all(attention_mask[..., 0] == 1) and q_len != 1
): # user defined causal mask
causal_mask = attention_mask[:, :, past_seen_tokens : past_seen_tokens + q_len, : key_states.shape[-2]]
# this one liner is equivalent to the pad_unpad function
causal_mask.mul_(~torch.eq(causal_mask, causal_mask.min()).all(dim=-1)[..., None])
else:
causal_mask = None
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)

# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and causal_mask is not None:
if query_states.device.type == "cuda" and attention_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
Expand All @@ -772,9 +766,10 @@ def forward(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
attn_mask=attention_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=causal_mask is None and q_len > 1,
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
is_causal=self.is_causal and attention_mask is None and q_len > 1,
)

attn_output = attn_output.transpose(1, 2).contiguous()
Expand Down
27 changes: 11 additions & 16 deletions src/transformers/models/qwen2/modeling_qwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,34 +669,28 @@ def forward(
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

kv_seq_len = key_states.shape[-2]
past_key_value = getattr(self, "past_key_value", past_key_value)
if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) # add what was seen
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

past_seen_tokens = kv_seq_len - key_states.shape[-2]
new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=key_states.device)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "position_ids": new_cache_positions} # Specific to RoPE models
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

if (
attention_mask is not None and not torch.all(attention_mask[..., 0] == 1) and q_len != 1
): # user defined causal mask
causal_mask = attention_mask[:, :, past_seen_tokens : past_seen_tokens + q_len, : key_states.shape[-2]]
# this one liner is equivalent to the pad_unpad function
causal_mask.mul_(~torch.eq(causal_mask, causal_mask.min()).all(dim=-1)[..., None])
else:
causal_mask = None
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)

# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and causal_mask is not None:
if query_states.device.type == "cuda" and attention_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
Expand All @@ -705,9 +699,10 @@ def forward(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
attn_mask=attention_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=causal_mask is None and q_len > 1,
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
is_causal=self.is_causal and attention_mask is None and q_len > 1,
)

attn_output = attn_output.transpose(1, 2).contiguous()
Expand Down
Loading