fix: fix bugs introduced in #132 (#135)
The way #132 computes `num_kv_chunks` is buggy for short inputs; this PR
fixes the issue.
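
In the code removed below, `num_chunks` was computed as `min(max_num_kv_chunks, kv_len / 128)`, which evaluates to 0 whenever `kv_len < 128`; the subsequent `ceil_div(kv_len, num_chunks)` then divides by zero. The fix instead clamps the chunk size to at least 256 KV tokens, so `num_chunks` is always at least 1. Below is a minimal standalone sketch of the arithmetic (the occupancy/shape constants, the `main` wrapper, and this `ceil_div` reimplementation are illustrative assumptions, not FlashInfer's actual code):

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

// Illustrative reimplementation; the real helper lives in FlashInfer's headers.
static uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

int main() {
  // Hypothetical occupancy/shape numbers, chosen only to exercise the arithmetic.
  const uint32_t num_blocks_per_sm = 2, num_sm = 108;
  const uint32_t num_kv_heads = 4, qo_len = 1, group_size = 8, num_rows_per_cta = 16;
  const uint32_t kv_len = 64;  // a "short input": fewer than 128 KV tokens

  const uint32_t max_num_kv_chunks =
      (num_blocks_per_sm * num_sm) /
      (num_kv_heads * ceil_div(qo_len * group_size, num_rows_per_cta));

  // Old (#132) logic: kv_len / 128 is 0 for kv_len < 128, so the old code's
  // ceil_div(kv_len, num_chunks) would divide by zero.
  const uint32_t old_num_chunks = std::min(max_num_kv_chunks, kv_len / 128);
  std::printf("old num_chunks = %u\n", old_num_chunks);  // prints 0

  // Fixed (#135) logic: clamp chunk_size to at least 256 tokens, which
  // guarantees num_chunks >= 1 for any kv_len >= 1.
  const uint32_t chunk_size = std::max<uint32_t>(ceil_div(kv_len, max_num_kv_chunks), 256);
  const uint32_t num_chunks = ceil_div(kv_len, chunk_size);
  std::printf("fixed chunk_size = %u, num_chunks = %u\n", chunk_size, num_chunks);
  return 0;
}
```

With `kv_len = 64`, the old formula yields `num_chunks = 0`, while the fixed path gives `chunk_size = 256` and `num_chunks = 1`, so short inputs are no longer split across KV chunks.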
yzh119 authored Feb 25, 2024
1 parent 0372acc commit 9b7b0b9
Showing 2 changed files with 13 additions and 14 deletions.
25 changes: 12 additions & 13 deletions include/flashinfer/prefill.cuh
@@ -1526,13 +1526,13 @@ cudaError_t SinglePrefillWithKVCacheWorkEstimation(
       cudaOccupancyMaxActiveBlocksPerMultiprocessor(
           &num_blocks_per_sm, partition_kv_kernel, num_threads,
           smem_size));
-      uint32_t num_chunks =
-          min((num_blocks_per_sm * num_sm) /
-                  (num_kv_heads *
-                   ceil_div(qo_len * group_size, num_rows_per_cta)),
-              kv_len / 128);
-      uint32_t chunk_size = ceil_div(kv_len, num_chunks);
-      num_chunks = ceil_div(kv_len, chunk_size);
+      uint32_t max_num_kv_chunks =
+          (num_blocks_per_sm * num_sm) /
+          (num_kv_heads *
+           ceil_div(qo_len * group_size, num_rows_per_cta));
+      uint32_t chunk_size =
+          max(ceil_div(kv_len, max_num_kv_chunks), 256);
+      uint32_t num_chunks = ceil_div(kv_len, chunk_size);
 
       max_grid_size = num_blocks_per_sm * num_sm;
       if (num_chunks > 1) {
@@ -1623,12 +1623,11 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn*
       cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id));
   FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
       &num_blocks_per_sm, partition_kv_kernel, num_threads, smem_size));
-  uint32_t num_chunks =
-      min((num_blocks_per_sm * num_sm) /
-              (num_kv_heads * ceil_div(qo_len * GROUP_SIZE, num_rows_per_cta)),
-          kv_len / 128);
-  uint32_t chunk_size = ceil_div(kv_len, num_chunks);
-  num_chunks = ceil_div(kv_len, chunk_size);
+  uint32_t max_num_kv_chunks =
+      (num_blocks_per_sm * num_sm) /
+      (num_kv_heads * ceil_div(qo_len * GROUP_SIZE, num_rows_per_cta));
+  uint32_t chunk_size = max(ceil_div(kv_len, max_num_kv_chunks), 256);
+  uint32_t num_chunks = ceil_div(kv_len, chunk_size);
 
   if (num_chunks <= 1 || tmp == nullptr) {
     // Enough parallelism, do not split-kv
2 changes: 1 addition & 1 deletion python/tests/test_batch_decode_kernels.py
@@ -27,7 +27,7 @@
 @pytest.mark.parametrize("page_size", [1, 8, 16])
 @pytest.mark.parametrize("num_kv_heads", [4])
 @pytest.mark.parametrize("num_qo_heads", [4, 32])
-@pytest.mark.parametrize("head_dim", [128])
+@pytest.mark.parametrize("head_dim", [128, 256])
 @pytest.mark.parametrize("kv_layout", ["HND", "NHD"])
 def test_batch_decode_with_paged_kv_cache(
     batch_size,
