Resolved ALiBi bias issue introduced by porting the flat PA PR
Signed-off-by: Tanner Voas <tanner.voas@intel.com>
tannervoas742 committed Nov 6, 2024
1 parent 0063520 commit 960fc31
Showing 1 changed file with 48 additions and 3 deletions.
51 changes: 48 additions & 3 deletions vllm_hpu_extension/ops.py
@@ -137,10 +137,52 @@ def block_softmax(batch_size, attn, block_mapping, block_scales, block_groups):
     return attn
 
 
-def flat_pa(query, key_cache, value_cache, block_list, block_mapping,
-            block_bias, block_scales, block_groups, scale, matmul_qk_op, matmul_av_op, keys_fetch_func,
-            values_fetch_func):
+def process_alibi_biases(alibi_biases, block_list, block_size, batch_size, num_heads):
+    # alibi_biases: [1, num_heads, seq_length, seq_length]
+    # block_list: [num_blocks]
+    # block_size: scalar integer
+    # batch_size and num_heads are supplied by the caller; seq_length and num_blocks are inferred below
+    seq_length = alibi_biases.size(2)
+    num_blocks = block_list.size(0)
+    block_per_sequence = seq_length // block_size + (1 if seq_length % block_size > 0 else 0)
+
+    # Step 1: Reduce alibi_biases to shape [1, num_heads, seq_length]
+    alibi_biases = alibi_biases[:, :, -1:, :].squeeze(-2)  # Shape [1, num_heads, seq_length]
+
+    # Step 2: Expand alibi_biases to shape [batch_size, num_heads, seq_length]
+    alibi_biases = alibi_biases.expand(batch_size, num_heads, seq_length)
+
+    # Step 3: Reshape alibi_biases to [batch_size * block_per_sequence, num_heads, block_size]
+    alibi_blocks = alibi_biases.view(batch_size, num_heads, block_per_sequence, block_size)
+    alibi_blocks = alibi_blocks.permute(0, 2, 1, 3)
+    alibi_blocks = alibi_blocks.contiguous()
+    alibi_blocks = alibi_blocks.view(-1, num_heads, block_size)
+
+    # Step 4: Scatter the per-block biases into the slots selected by block_list
+    output = torch.zeros((num_blocks, num_heads, block_size), device=block_list.device, dtype=alibi_biases.dtype)
+    output[block_list[:alibi_blocks.size(0)]] = alibi_blocks
+    output = output.unsqueeze(-2)
+
+    return output
+
+
+def flat_pa(
+    query,
+    key_cache,
+    value_cache,
+    block_list,
+    block_mapping,
+    block_bias,
+    block_scales,
+    block_groups,
+    scale,
+    alibi_slopes,
+    matmul_qk_op,
+    matmul_av_op,
+    keys_fetch_func,
+    values_fetch_func,
+):
     batch_size = query.size(0)
     block_size = key_cache.size(1)
     q_heads = query.size(1)
     kv_heads = key_cache.size(2)

@@ -158,6 +200,9 @@ def flat_pa(query, key_cache, value_cache, block_list, block_mapping,
         key = key.transpose(2, 3)
 
     attn = matmul_qk_op(query, key) + block_bias
+    if alibi_slopes is not None:
+        block_alibi_slopes = process_alibi_biases(alibi_slopes, block_list, block_size, batch_size, kv_heads)
+        attn.add_(block_alibi_slopes)
     attn = block_softmax(batch_size, attn, block_mapping, block_scales, block_groups)
     attn = matmul_av_op(attn, value)
     attn = block2batch(attn, block_mapping)

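For reference, a minimal standalone sketch of what the new helper produces. It is an illustration under assumptions: this commit is installed so process_alibi_biases can be imported from vllm_hpu_extension.ops, the toy sizes below are arbitrary, a random tensor stands in for the precomputed ALiBi bias (only its last query row is used), and seq_length is a multiple of block_size so the view in Step 3 is valid.

import torch
from vllm_hpu_extension.ops import process_alibi_biases

batch_size, num_heads = 2, 4
seq_length, block_size = 256, 128          # assumed: seq_length divides evenly into blocks
blocks_per_seq = seq_length // block_size  # 2 blocks per sequence
num_blocks = 6                             # assumed total number of physical cache blocks

# Stand-in for the precomputed ALiBi bias of shape [1, num_heads, seq_length, seq_length].
alibi_biases = torch.randn(1, num_heads, seq_length, seq_length)

# Physical block ids; the first batch_size * blocks_per_seq entries receive the biases.
block_list = torch.tensor([0, 1, 2, 3, 4, 5])

alibi_block_bias = process_alibi_biases(alibi_biases, block_list, block_size, batch_size, num_heads)
print(alibi_block_bias.shape)  # torch.Size([6, 4, 1, 128]) -> [num_blocks, num_heads, 1, block_size]

The trailing unsqueeze gives each block a singleton query dimension, so flat_pa can add the result directly onto its per-block attention scores via attn.add_().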