forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Neuron][Kernel] Vectorize KV cache load in FlashPagedAttention to ma…
…ximize DMA bandwidth (vllm-project#13245) Signed-off-by: Lingfan Yu <lingfany@amazon.com>
- Loading branch information
Showing
3 changed files
with
764 additions
and
348 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
import os | ||
|
||
import neuronxcc.nki.language as nl | ||
import pytest | ||
import torch | ||
import torch.nn.functional as F | ||
from neuronxcc import nki | ||
|
||
from vllm.attention.ops.nki_flash_attn import ( | ||
load_block_tables, transform_block_tables_for_indirect_load) | ||
|
||
|
||
def is_power_of_2(n): | ||
return n > 0 and (n & (n - 1) == 0) | ||
|
||
|
||
def nki_load_and_transform_block_tables( | ||
block_tables, | ||
num_tiles, | ||
num_blocks_per_tile, | ||
num_head, | ||
head_id, | ||
block_size_tiling_factor, | ||
): | ||
assert is_power_of_2( | ||
num_blocks_per_tile), f"{num_blocks_per_tile=} must be power of 2" | ||
block_tables_sbuf = load_block_tables(block_tables, num_tiles, | ||
num_blocks_per_tile) | ||
|
||
# we need to pass an Index as head_id | ||
head_id = nl.arange(1)[None, :] + head_id | ||
|
||
block_tables_transposed = transform_block_tables_for_indirect_load( | ||
block_tables_sbuf, block_size_tiling_factor, num_head, head_id) | ||
B_P_SIZE = 128 | ||
assert block_tables_transposed.shape[1] == B_P_SIZE | ||
|
||
out = nl.ndarray( | ||
block_tables_transposed.shape, | ||
dtype=nl.int32, | ||
buffer=nl.shared_hbm, | ||
) | ||
for i in nl.affine_range(block_tables_transposed.shape[0]): | ||
nl.store(dst=out[i], value=block_tables_transposed[i]) | ||
return out | ||
|
||
|
||
def ref_block_tables_transform( | ||
block_tables, | ||
num_tiles, | ||
num_blocks_per_tile, | ||
num_head, | ||
head_id, | ||
block_size_tiling_factor, | ||
): | ||
assert block_tables.numel() == num_tiles * num_blocks_per_tile | ||
block_tables = block_tables.view(num_tiles, num_blocks_per_tile) | ||
B_F_SIZE = 128 | ||
num_tiles_padded = (num_tiles + B_F_SIZE - 1) // B_F_SIZE * B_F_SIZE | ||
block_tables = F.pad( | ||
block_tables, | ||
(0, 0, 0, num_tiles_padded - num_tiles), | ||
"constant", | ||
0, | ||
) | ||
|
||
block_tables = block_tables * num_head + head_id | ||
block_tables = block_tables.view(num_tiles_padded, num_blocks_per_tile, 1) | ||
offset = torch.arange(0, block_size_tiling_factor).view(1, 1, -1) | ||
block_tables = block_tables * block_size_tiling_factor + offset | ||
block_tables_transposed = block_tables.view(num_tiles_padded, -1).t() | ||
|
||
num_blocks_per_tile = block_tables_transposed.shape[0] | ||
assert num_blocks_per_tile % B_F_SIZE == 0 | ||
return block_tables_transposed.view(num_blocks_per_tile // B_F_SIZE, | ||
B_F_SIZE, num_tiles_padded) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"q_head_per_kv_head,head_id", | ||
[ | ||
(1, 0), | ||
(3, 1), | ||
], | ||
) | ||
@pytest.mark.parametrize( | ||
"num_tiles,num_blocks_per_tile", | ||
[ | ||
(1, 1), | ||
(13, 16), | ||
(17, 128), | ||
(35, 512), | ||
(128, 128), | ||
(130, 64), | ||
(280, 256), | ||
(315, 1), | ||
], | ||
) | ||
@torch.inference_mode() | ||
def test_load_and_transform_block_tables( | ||
num_tiles, | ||
num_blocks_per_tile, | ||
q_head_per_kv_head, | ||
head_id, | ||
) -> None: | ||
import torch_xla.core.xla_model as xm | ||
|
||
device = xm.xla_device() | ||
|
||
compiler_flags = [ | ||
"-O1", | ||
"--retry_failed_compilation", | ||
] | ||
compiler_flags_str = " ".join(compiler_flags) | ||
os.environ["NEURON_CC_FLAGS"] = compiler_flags_str | ||
|
||
torch.manual_seed(10000) | ||
torch.set_printoptions(sci_mode=False) | ||
|
||
# On Neuron, we need B_P_SIZE = 128 blocks to make DMA efficient | ||
B_P_SIZE = 128 | ||
if num_blocks_per_tile < B_P_SIZE: | ||
assert B_P_SIZE % num_blocks_per_tile == 0 | ||
block_size_tiling_factor = B_P_SIZE // num_blocks_per_tile | ||
else: | ||
block_size_tiling_factor = 1 | ||
max_num_blocks = 100000 | ||
block_tables = torch.randint( | ||
0, | ||
max_num_blocks, | ||
(num_tiles * num_blocks_per_tile, ), | ||
dtype=torch.int32, | ||
) | ||
nki_out = nki.jit(nki_load_and_transform_block_tables)[1, 1]( | ||
block_tables.to(device=device), | ||
num_tiles, | ||
num_blocks_per_tile, | ||
q_head_per_kv_head, | ||
head_id, | ||
block_size_tiling_factor, | ||
).cpu() | ||
ref_out = ref_block_tables_transform( | ||
block_tables, | ||
num_tiles, | ||
num_blocks_per_tile, | ||
q_head_per_kv_head, | ||
head_id, | ||
block_size_tiling_factor, | ||
) | ||
assert (nki_out.shape == ref_out.shape | ||
), f"{nki_out.shape=} != {ref_out.shape=}" | ||
assert torch.all(nki_out == ref_out) |
Oops, something went wrong.