Added past_key_value handling in different conditions.
Signed-off-by: gyou2021 <ganmei.you@intel.com>
gyou2021 committed Mar 6, 2025
1 parent 3ca79ea · commit 1dc414b
Showing 1 changed file with 14 additions and 5 deletions.
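For orientation, the sketch below mirrors the branching this commit introduces in prefill_forward. It is a hypothetical helper, not code from this repository: a transformers-style cache object exposes get_usable_length, a legacy tuple of key/value tensors carries the cached length on dim -2, and with reuse_cache the cache entry is assumed to be a plain shape tuple, which would explain why the diff reads past_key_value[0][-2] without .shape.

# Hypothetical sketch (not repository code): resolve_kv_seq_len and the example
# caches below are illustrative stand-ins for the branching added in prefill_forward.
import torch


def resolve_kv_seq_len(kv_seq_len, past_key_value, layer_idx, token_idx=None, reuse_cache=False):
    if token_idx is None:
        if hasattr(past_key_value, "get_usable_length"):
            # transformers-style Cache object (e.g. DynamicCache)
            return kv_seq_len + past_key_value.get_usable_length(kv_seq_len, layer_idx)
        # legacy tuple of (key, value) tensors: cached length sits on dim -2
        return kv_seq_len + past_key_value[0].shape[-2]
    if reuse_cache:
        # with reuse_cache the entry is assumed to be a shape tuple, so index it directly
        return past_key_value[0][-2]
    # cache of tensors: read the cached length off the key tensor
    return past_key_value[0].shape[-2]


# Legacy tuple cache: 10 cached positions plus the current query length of 4 -> 14.
legacy = (torch.zeros(1, 8, 10, 64), torch.zeros(1, 8, 10, 64))
print(resolve_kv_seq_len(4, legacy, layer_idx=0))  # 14

# reuse_cache path: the cached shape alone determines kv_seq_len.
shape_only = ((1, 8, 128, 64), (1, 8, 128, 64))
print(resolve_kv_seq_len(4, shape_only, layer_idx=0, token_idx=7, reuse_cache=True))  # 128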
@@ -1203,7 +1203,16 @@ def prefill_forward(
                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                     "with a layer index."
                 )
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+            if token_idx is None:
+                if hasattr(past_key_value, "get_usable_length"):
+                    kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+                else:
+                    kv_seq_len += past_key_value[0].shape[-2]
+            else:
+                if reuse_cache:
+                    kv_seq_len = past_key_value[0][-2]
+                else:
+                    kv_seq_len = past_key_value[0].shape[-2]
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         q_pe, k_pe = apply_customized_rope(q_pe, k_pe, cos, sin, position_ids)
         # update & get all compressed_kv, k_pe
@@ -1247,7 +1256,7 @@ def prefill_forward(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)

assert attention_mask is not None
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
@@ -1360,14 +1369,14 @@ def decode_forward(
torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_nope, compressed_kv.unsqueeze(-3).mT)
) * self.softmax_scale

assert attention_mask is not None
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
assert attention_mask is not None
if attention_mask is not None:
attn_weights = attn_weights + attention_mask

attn_weights = attn_weights + attention_mask

# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q_nope.dtype)
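The remaining hunks replace assert attention_mask is not None with an explicit None check. A minimal sketch of the difference, assuming the mask can legitimately arrive as None on some paths (illustrative tensors, not repository code):

# Hypothetical illustration of the guarded mask addition used in this commit.
import torch

causal_mask = torch.triu(torch.full((1, 1, 4, 4), float("-inf")), diagonal=1)

for attention_mask in (None, causal_mask):
    attn_weights = torch.zeros(1, 2, 4, 4)  # stand-in for the attention score matmul result
    if attention_mask is not None:  # the committed guard; an assert here would abort when the mask is None
        attn_weights = attn_weights + attention_mask
    print(attn_weights.isinf().any().item())  # False without a mask, True once the mask is applied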
