diff --git a/include/flashinfer/prefill.cuh b/include/flashinfer/prefill.cuh
index cc1ff28ec..5d7e3642c 100644
--- a/include/flashinfer/prefill.cuh
+++ b/include/flashinfer/prefill.cuh
@@ -1526,13 +1526,13 @@ cudaError_t SinglePrefillWithKVCacheWorkEstimation(
         cudaOccupancyMaxActiveBlocksPerMultiprocessor(
             &num_blocks_per_sm, partition_kv_kernel, num_threads, smem_size));
-    uint32_t num_chunks =
-        min((num_blocks_per_sm * num_sm) /
-                (num_kv_heads *
-                 ceil_div(qo_len * group_size, num_rows_per_cta)),
-            kv_len / 128);
-    uint32_t chunk_size = ceil_div(kv_len, num_chunks);
-    num_chunks = ceil_div(kv_len, chunk_size);
+    uint32_t max_num_kv_chunks =
+        (num_blocks_per_sm * num_sm) /
+        (num_kv_heads *
+         ceil_div(qo_len * group_size, num_rows_per_cta));
+    uint32_t chunk_size =
+        max(ceil_div(kv_len, max_num_kv_chunks), 256);
+    uint32_t num_chunks = ceil_div(kv_len, chunk_size);
     max_grid_size = num_blocks_per_sm * num_sm;
     if (num_chunks > 1) {
@@ -1623,12 +1623,11 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn*
         cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id));
     FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
         &num_blocks_per_sm, partition_kv_kernel, num_threads, smem_size));
-    uint32_t num_chunks =
-        min((num_blocks_per_sm * num_sm) /
-                (num_kv_heads * ceil_div(qo_len * GROUP_SIZE, num_rows_per_cta)),
-            kv_len / 128);
-    uint32_t chunk_size = ceil_div(kv_len, num_chunks);
-    num_chunks = ceil_div(kv_len, chunk_size);
+    uint32_t max_num_kv_chunks =
+        (num_blocks_per_sm * num_sm) /
+        (num_kv_heads * ceil_div(qo_len * GROUP_SIZE, num_rows_per_cta));
+    uint32_t chunk_size = max(ceil_div(kv_len, max_num_kv_chunks), 256);
+    uint32_t num_chunks = ceil_div(kv_len, chunk_size);
 
     if (num_chunks <= 1 || tmp == nullptr) {
       // Enough parallelism, do not split-kv
diff --git a/python/tests/test_batch_decode_kernels.py b/python/tests/test_batch_decode_kernels.py
index e396efd18..0b83d2de2 100644
--- a/python/tests/test_batch_decode_kernels.py
+++ b/python/tests/test_batch_decode_kernels.py
@@ -27,7 +27,7 @@
 @pytest.mark.parametrize("page_size", [1, 8, 16])
 @pytest.mark.parametrize("num_kv_heads", [4])
 @pytest.mark.parametrize("num_qo_heads", [4, 32])
-@pytest.mark.parametrize("head_dim", [128])
+@pytest.mark.parametrize("head_dim", [128, 256])
 @pytest.mark.parametrize("kv_layout", ["HND", "NHD"])
 def test_batch_decode_with_paged_kv_cache(
     batch_size,
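
For reference, a minimal standalone sketch of the revised split-kv sizing arithmetic in the hunks above. The program, the local ceil_div helper, and the sample (kv_len, max_num_kv_chunks) values are illustrative and not part of the patch; the point is that the chunk size is now floored at 256 and the chunk count derived from it, so a short kv_len no longer drives the chunk count to zero the way the old kv_len / 128 cap could.

// Standalone sketch (not part of the patch) of the new split-kv sizing.
// ceil_div is a local stand-in for flashinfer's helper of the same name;
// the sample values in main() are made up for illustration.
#include <algorithm>
#include <cstdint>
#include <cstdio>

constexpr uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

void demo(uint32_t kv_len, uint32_t max_num_kv_chunks) {
  // New scheme: floor the chunk size at 256 tokens, then derive the chunk
  // count, instead of capping the chunk count at kv_len / 128 as before.
  uint32_t chunk_size =
      std::max<uint32_t>(ceil_div(kv_len, max_num_kv_chunks), 256);
  uint32_t num_chunks = ceil_div(kv_len, chunk_size);
  std::printf("kv_len=%u -> chunk_size=%u, num_chunks=%u (%s)\n", kv_len,
              chunk_size, num_chunks,
              num_chunks > 1 ? "split-kv" : "no split-kv");
}

int main() {
  demo(100, 8);    // short sequence: one chunk, no split-kv (old cap: 100 / 128 == 0)
  demo(4096, 8);   // long sequence: 512-token chunks, 8-way split
  demo(4096, 64);  // chunk count limited to 16 by the 256-token minimum chunk size
  return 0;
}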