[Neuron][Kernel] Vectorize KV cache load in FlashPagedAttention to maximize DMA bandwidth (vllm-project#13245)

Signed-off-by: Lingfan Yu <lingfany@amazon.com>
lingfanyu authored Feb 21, 2025
1 parent 71face8 commit 3317008
Showing 3 changed files with 764 additions and 348 deletions.
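The new test below checks the block-table transform that feeds this vectorized load. As a rough illustration only (not the NKI kernel from this commit), a transposed block table of shape (num_loads, 128, num_tiles_padded), the shape the test verifies, can drive one gather per 128-block group; the cache layout (num_physical_blocks, block_size, head_dim) and the helper name gather_kv_blocks are assumptions made for this sketch:

    import torch


    def gather_kv_blocks(kv_cache: torch.Tensor,
                         block_tables_transposed: torch.Tensor) -> torch.Tensor:
        # block_tables_transposed: (num_loads, 128, num_tiles_padded), int32.
        # Flattening it lets a single index_select fetch a whole group of KV
        # blocks at once instead of issuing one small copy per block, which is
        # the bandwidth idea behind vectorizing the KV cache load.
        flat_ids = block_tables_transposed.reshape(-1).long()
        gathered = kv_cache.index_select(0, flat_ids)
        return gathered.view(*block_tables_transposed.shape, *kv_cache.shape[1:])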
153 changes: 153 additions & 0 deletions tests/neuron/test_block_table.py
@@ -0,0 +1,153 @@
# SPDX-License-Identifier: Apache-2.0
import os

import neuronxcc.nki.language as nl
import pytest
import torch
import torch.nn.functional as F
from neuronxcc import nki

from vllm.attention.ops.nki_flash_attn import (
    load_block_tables, transform_block_tables_for_indirect_load)


def is_power_of_2(n):
    return n > 0 and (n & (n - 1) == 0)


def nki_load_and_transform_block_tables(
    block_tables,
    num_tiles,
    num_blocks_per_tile,
    num_head,
    head_id,
    block_size_tiling_factor,
):
    # NKI test helper: load the block table into SBUF, transform it for
    # indirect (gather) KV cache loads, and store the result to HBM so it
    # can be compared against the PyTorch reference below.
    assert is_power_of_2(
        num_blocks_per_tile), f"{num_blocks_per_tile=} must be power of 2"
    block_tables_sbuf = load_block_tables(block_tables, num_tiles,
                                          num_blocks_per_tile)

    # we need to pass an Index as head_id
    head_id = nl.arange(1)[None, :] + head_id

    block_tables_transposed = transform_block_tables_for_indirect_load(
        block_tables_sbuf, block_size_tiling_factor, num_head, head_id)
    B_P_SIZE = 128
    assert block_tables_transposed.shape[1] == B_P_SIZE

    out = nl.ndarray(
        block_tables_transposed.shape,
        dtype=nl.int32,
        buffer=nl.shared_hbm,
    )
    for i in nl.affine_range(block_tables_transposed.shape[0]):
        nl.store(dst=out[i], value=block_tables_transposed[i])
    return out


def ref_block_tables_transform(
    block_tables,
    num_tiles,
    num_blocks_per_tile,
    num_head,
    head_id,
    block_size_tiling_factor,
):
    # Pure-PyTorch reference for the NKI transform above.
    assert block_tables.numel() == num_tiles * num_blocks_per_tile
    block_tables = block_tables.view(num_tiles, num_blocks_per_tile)
    # Pad the tile dimension up to a multiple of B_F_SIZE = 128.
    B_F_SIZE = 128
    num_tiles_padded = (num_tiles + B_F_SIZE - 1) // B_F_SIZE * B_F_SIZE
    block_tables = F.pad(
        block_tables,
        (0, 0, 0, num_tiles_padded - num_tiles),
        "constant",
        0,
    )

    # Map logical block ids to per-KV-head block ids, then expand each block
    # into block_size_tiling_factor sub-blocks.
    block_tables = block_tables * num_head + head_id
    block_tables = block_tables.view(num_tiles_padded, num_blocks_per_tile, 1)
    offset = torch.arange(0, block_size_tiling_factor).view(1, 1, -1)
    block_tables = block_tables * block_size_tiling_factor + offset
    # Transpose so the (sub-)block dimension comes first, then split it into
    # chunks of B_F_SIZE.
    block_tables_transposed = block_tables.view(num_tiles_padded, -1).t()

    num_blocks_per_tile = block_tables_transposed.shape[0]
    assert num_blocks_per_tile % B_F_SIZE == 0
    return block_tables_transposed.view(num_blocks_per_tile // B_F_SIZE,
                                        B_F_SIZE, num_tiles_padded)
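# Example: with num_head=3, head_id=1 and block_size_tiling_factor=4, logical
# block id 5 maps to sub-block ids (5 * 3 + 1) * 4 + j for j in range(4),
# i.e. 64, 65, 66 and 67.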


@pytest.mark.parametrize(
    "q_head_per_kv_head,head_id",
    [
        (1, 0),
        (3, 1),
    ],
)
@pytest.mark.parametrize(
    "num_tiles,num_blocks_per_tile",
    [
        (1, 1),
        (13, 16),
        (17, 128),
        (35, 512),
        (128, 128),
        (130, 64),
        (280, 256),
        (315, 1),
    ],
)
@torch.inference_mode()
def test_load_and_transform_block_tables(
    num_tiles,
    num_blocks_per_tile,
    q_head_per_kv_head,
    head_id,
) -> None:
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()

    compiler_flags = [
        "-O1",
        "--retry_failed_compilation",
    ]
    compiler_flags_str = " ".join(compiler_flags)
    os.environ["NEURON_CC_FLAGS"] = compiler_flags_str

    torch.manual_seed(10000)
    torch.set_printoptions(sci_mode=False)

    # On Neuron, we need B_P_SIZE = 128 blocks to make DMA efficient.
    # If a tile has fewer than 128 blocks, split each block into sub-blocks
    # so that num_blocks_per_tile * block_size_tiling_factor == B_P_SIZE.
    B_P_SIZE = 128
    if num_blocks_per_tile < B_P_SIZE:
        assert B_P_SIZE % num_blocks_per_tile == 0
        block_size_tiling_factor = B_P_SIZE // num_blocks_per_tile
    else:
        block_size_tiling_factor = 1
    max_num_blocks = 100000
    block_tables = torch.randint(
        0,
        max_num_blocks,
        (num_tiles * num_blocks_per_tile, ),
        dtype=torch.int32,
    )
    # Run the NKI kernel and the PyTorch reference, then compare elementwise.
    nki_out = nki.jit(nki_load_and_transform_block_tables)[1, 1](
        block_tables.to(device=device),
        num_tiles,
        num_blocks_per_tile,
        q_head_per_kv_head,
        head_id,
        block_size_tiling_factor,
    ).cpu()
    ref_out = ref_block_tables_transform(
        block_tables,
        num_tiles,
        num_blocks_per_tile,
        q_head_per_kv_head,
        head_id,
        block_size_tiling_factor,
    )
    assert (nki_out.shape == ref_out.shape
            ), f"{nki_out.shape=} != {ref_out.shape=}"
    assert torch.all(nki_out == ref_out)
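To exercise this test locally (assuming a Neuron instance with torch_xla and the Neuron compiler installed, which the test requires), a standard pytest invocation should work:

    pytest tests/neuron/test_block_table.py -v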