perf: use packed bit array for attention mask (flashinfer-ai#308)
yzh119 authored Jun 16, 2024
1 parent 876cc53 commit 3d43dc9
Showing 23 changed files with 593 additions and 128 deletions.
2 changes: 1 addition & 1 deletion cmake/config.cmake
@@ -27,7 +27,7 @@ set(FLASHINFER_GEN_HEAD_DIMS 64 128 256)
set(FLASHINFER_GEN_KV_LAYOUTS 0 1)
set(FLASHINFER_GEN_POS_ENCODING_MODES 0 1 2)
set(FLASHINFER_GEN_ALLOW_FP16_QK_REDUCTIONS "false" "true")
-set(FLASHINFER_GEN_MASK_MODES 0 1)
+set(FLASHINFER_GEN_MASK_MODES 0 1 2)

# Set target cuda architectures for tests/benchmarks, defaults to native.
# "native" is a special value for CMAKE_CUDA_ARCHITECTURES which means use the architectures of the host's GPU.
14 changes: 14 additions & 0 deletions docs/api/python/quantization.rst
@@ -0,0 +1,14 @@
.. _apiquantization:

flashinfer.quantization
=======================

Quantization-related kernels.

.. currentmodule:: flashinfer.quantization

.. autosummary::
:toctree: _generate

packbits
segment_packbits
1 change: 1 addition & 0 deletions docs/index.rst
@@ -34,3 +34,4 @@ FlashInfer is a library for Large Language Models that provides high-perform
api/python/sampling
api/python/group_gemm
api/python/norm
+api/python/quantization
8 changes: 8 additions & 0 deletions docs/tutorials/kv_layout.rst
@@ -75,13 +75,21 @@ to store the start offset of each request's mask in the flattened mask array: ``
``mask_data`` has shape ``(qk_indptr[-1],)``, we can use ``mask_data[qk_indptr[i]:qk_indptr[i+1]]`` to slice the flattened
mask of request ``i``.

+To save memory, we can further pack the flattened boolean mask array into a bit-packed array
+(1 bit per element, 8 elements packed into one ``uint8``) with "little" bit-order (see
+`numpy.packbits <https://numpy.org/doc/stable/reference/generated/numpy.packbits.html>`_
+for more details). FlashInfer accepts both boolean and bit-packed masks; if a boolean mask
+is provided, FlashInfer packs it into a bit-packed array internally.
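
For illustration (this example is not part of the commit), the "little" bit-order described above matches ``numpy.packbits`` on the host: element ``i`` of the flattened mask lands in bit ``i % 8`` of byte ``i // 8``:

    import numpy as np

    # Flattened boolean mask, e.g. qo_len = 2, kv_len = 8, row-major.
    mask = np.array([1, 1, 0, 0, 0, 0, 0, 0,
                     1, 1, 1, 0, 0, 0, 0, 0], dtype=bool)

    packed = np.packbits(mask, bitorder="little")
    print(packed)  # [3 7]

    # Recover element i from the packed array.
    i = 9
    assert ((packed[i // 8] >> (i % 8)) & 1) == mask[i]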

FlashInfer APIs
~~~~~~~~~~~~~~~

:class:`flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper` and :class:`flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper`
allow the user to specify ``qo_indptr``, ``kv_indptr`` and a custom attention mask ``custom_mask`` in their ``begin_forward`` functions;
the mask data is added to the attention scores before softmax (and after softmax scaling) in the attention kernel.
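
As a minimal numpy sketch of these semantics (all names hypothetical, not the actual kernel code): with a boolean mask, disallowed entries are driven to a large negative value (the kernels in this commit use -5e4), so they vanish under softmax.

    import numpy as np

    def masked_softmax(q, k, mask_bool, sm_scale):
        # Scale the scores first, then mask, then softmax.
        s = (q @ k.T) * sm_scale
        s = np.where(mask_bool, s, -5e4)
        s = s - s.max(axis=-1, keepdims=True)
        p = np.exp(s)
        return p / p.sum(axis=-1, keepdims=True)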

+:meth:`flashinfer.quantization.packbits` and :meth:`flashinfer.quantization.segment_packbits` are utility functions
+for packing boolean masks into bit-packed arrays.
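
A hypothetical usage sketch (the signatures are assumed from the API listing in docs/api/python/quantization.rst above; verify against the built package):

    import torch
    import flashinfer

    # Flattened boolean masks for two requests, mask lengths 10 and 6.
    mask = torch.rand(16, device="cuda") > 0.5
    qk_indptr = torch.tensor([0, 10, 16], dtype=torch.int32, device="cuda")

    # Pack one contiguous mask.
    packed = flashinfer.quantization.packbits(mask, bitorder="little")

    # Pack per-request segments so each request's mask stays byte-aligned.
    packed_seg, packed_indptr = flashinfer.quantization.segment_packbits(
        mask, qk_indptr, bitorder="little")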

.. _page-layout:

Page Table
28 changes: 14 additions & 14 deletions include/flashinfer/attention/prefill.cuh
@@ -547,7 +547,7 @@ template <bool partition_kv, MaskMode mask_mode, uint32_t num_warps, uint32_t nu
__device__ __forceinline__ void mask_s(const uint32_t qo_packed_idx_base,
const uint32_t kv_idx_base, const uint32_t qo_len,
const uint32_t kv_len, const uint32_t chunk_end,
-const uint_fastdiv group_size, float* custom_mask,
+const uint_fastdiv group_size, uint8_t* custom_mask,
DTypeQKAccum (*s_frag)[num_frags_z][8]) {
const uint32_t tx = threadIdx.x;
#pragma unroll
@@ -565,11 +565,11 @@ __device__ __forceinline__ void mask_s(const uint32_t qo_packed_idx_base,
? (kv_idx > kv_len + q_idx - qo_len || (partition_kv && kv_idx >= chunk_end))
: kv_idx >= chunk_end);
s_frag[fx][fz][reg_id] =
-out_of_boundary ? DTypeQKAccum(-5e4)
-                : s_frag[fx][fz][reg_id] +
-                      DTypeQKAccum((mask_mode == MaskMode::kCustom && q_idx < qo_len)
-                                       ? custom_mask[q_idx * kv_len + kv_idx]
-                                       : 0.f);
+(out_of_boundary ||
+ ((mask_mode == MaskMode::kCustom && q_idx < qo_len &&
+   !((custom_mask[(q_idx * kv_len + kv_idx) / 8] >> ((q_idx * kv_len + kv_idx) % 8)) & 1))))
+    ? DTypeQKAccum(-5e4)
+    : s_frag[fx][fz][reg_id];
}
}
}
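
The new branch above reads one bit per (q_idx, kv_idx) pair instead of one float. In Python terms, the lookup into the little bit-order packed mask is (illustrative sketch only):

    def mask_bit(custom_mask, q_idx, kv_idx, kv_len):
        # Flattened position of (q_idx, kv_idx) in this request's mask.
        idx = q_idx * kv_len + kv_idx
        # Little bit-order: bit (idx % 8) of byte (idx // 8).
        return (custom_mask[idx // 8] >> (idx % 8)) & 1  # 0 means masked out (-5e4)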
@@ -891,7 +891,7 @@ template <LogitsPostHook logits_post_hook, bool partition_kv, MaskMode mask_mode
typename DTypeQKAccum, typename DTypeOut>
__global__ void SinglePrefillWithKVCacheKernel(DTypeIn* __restrict__ q, DTypeIn* __restrict__ k,
DTypeIn* __restrict__ v,
-float* __restrict__ custom_mask,
+uint8_t* __restrict__ custom_mask,
DTypeOut* __restrict__ o, void* __restrict__ tmp,
float* __restrict__ lse, const uint32_t qo_len,
const uint32_t kv_len, const uint_fastdiv group_size,
@@ -1107,7 +1107,7 @@ template <LogitsPostHook logits_post_hook, MaskMode mask_mode, QKVLayout kv_layo
__global__ void BatchPrefillWithRaggedKVCacheKernel(
DTypeIn* __restrict__ q, IdType* __restrict__ request_indices,
IdType* __restrict__ tile_indices, IdType* __restrict__ qo_indptr, DTypeIn* __restrict__ k,
-DTypeIn* __restrict__ v, IdType* __restrict__ kv_indptr, float* __restrict__ custom_mask,
+DTypeIn* __restrict__ v, IdType* __restrict__ kv_indptr, uint8_t* __restrict__ custom_mask,
IdType* __restrict__ qk_indptr, IdType* __restrict__ q_offset,
IdType* __restrict__ k_rope_pos_offset, DTypeOut* __restrict__ o, float* __restrict__ tmp,
float* __restrict__ lse, uint32_t batch_size, const uint_fastdiv group_size, float sm_scale,
@@ -1324,9 +1324,9 @@ template <LogitsPostHook logits_post_hook, MaskMode mask_mode, PosEncodingMode p
__global__ void BatchPrefillWithPagedKVCacheKernel(
IdType* __restrict__ request_indices, IdType* __restrict__ tile_indices,
DTypeIn* __restrict__ q, paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv,
-IdType* __restrict__ qo_indptr, float* __restrict__ custom_mask, IdType* __restrict__ qk_indptr,
-IdType* __restrict__ q_offset, DTypeOut* __restrict__ o, float* __restrict__ tmp,
-float* __restrict__ lse, const uint_fastdiv group_size, float sm_scale,
+IdType* __restrict__ qo_indptr, uint8_t* __restrict__ custom_mask,
+IdType* __restrict__ qk_indptr, IdType* __restrict__ q_offset, DTypeOut* __restrict__ o,
+float* __restrict__ tmp, float* __restrict__ lse, const uint_fastdiv group_size, float sm_scale,
float log2_rope_rcp_scale, float log2_rope_rcp_theta) {
static_assert(sizeof(DTypeIn) == 2);
static_assert(sizeof(DTypeOut) == 2);
@@ -1534,7 +1534,7 @@ template <uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HOOK, QKVLayout KV_LAYOU
PosEncodingMode pos_encoding_mode, bool ALLOW_FP16_QK_REDUCTION, MaskMode MASK_MODE,
typename DTypeIn, typename DTypeOut>
cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v,
-float* custom_mask, DTypeOut* o, float* tmp,
+uint8_t* custom_mask, DTypeOut* o, float* tmp,
float* lse, uint32_t num_qo_heads,
uint32_t num_kv_heads, uint32_t qo_len,
uint32_t kv_len, float sm_scale, float rope_scale,
@@ -1674,7 +1674,7 @@ template <uint32_t num_frags_x, uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HO
MaskMode MASK_MODE, typename DTypeIn, typename DTypeOut, typename IdType>
cudaError_t BatchPrefillWithRaggedKVCacheDispatched(
DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, DTypeIn* k,
-DTypeIn* v, IdType* kv_indptr, float* custom_mask, IdType* qk_indptr, IdType* q_offset,
+DTypeIn* v, IdType* kv_indptr, uint8_t* custom_mask, IdType* qk_indptr, IdType* q_offset,
IdType* k_rope_pos_offset, DTypeOut* o, float* tmp, float* lse, const uint32_t batch_size,
const uint32_t num_qo_heads, const uint32_t num_qo_tiles, const uint32_t num_kv_heads,
const float sm_scale, const float rope_scale, const float rope_theta,
@@ -1758,7 +1758,7 @@ template <PageStorage page_storage, uint32_t num_frags_x, uint32_t HEAD_DIM,
typename IdType>
cudaError_t BatchPrefillWithPagedKVCacheDispatched(
DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, IdType* q_offset,
paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv, float* custom_mask,
paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv, uint8_t* custom_mask,
IdType* qk_indptr, DTypeOut* o, float* tmp, float* lse, uint32_t num_qo_heads,
uint32_t num_qo_tiles, float sm_scale, float rope_scale, float rope_theta,
cudaStream_t stream) {
10 changes: 5 additions & 5 deletions include/flashinfer/prefill_attention_decl.cuh
@@ -32,7 +32,7 @@ template <uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HOOK, QKVLayout KV_LAYOU
PosEncodingMode POS_ENCODING_MODE, bool ALLOW_FP16_QK_REDUCTION, MaskMode MASK_MODE,
typename DTypeIn, typename DTypeOut>
cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v,
-float* custom_mask, DTypeOut* o, float* tmp,
+uint8_t* custom_mask, DTypeOut* o, float* tmp,
float* lse, uint32_t num_qo_heads,
uint32_t num_kv_heads, uint32_t qo_len,
uint32_t kv_len, float sm_scale, float rope_scale,
@@ -43,7 +43,7 @@ template <uint32_t NUM_FRAGS_X, uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HO
MaskMode MASK_MODE, typename DTypeIn, typename DTypeOut, typename IdType>
cudaError_t BatchPrefillWithRaggedKVCacheDispatched(
DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, DTypeIn* k,
-DTypeIn* v, IdType* kv_indptr, float* custom_mask, IdType* qk_indptr, IdType* q_offset,
+DTypeIn* v, IdType* kv_indptr, uint8_t* custom_mask, IdType* qk_indptr, IdType* q_offset,
IdType* k_rope_pos_offset, DTypeOut* o, float* tmp, float* lse, uint32_t batch_size,
uint32_t num_qo_tiles, uint32_t num_qo_heads, uint32_t num_kv_heads, float sm_scale,
float rope_scale, float rope_theta, cudaStream_t stream = nullptr);
@@ -54,7 +54,7 @@ template <PageStorage PAGE_STORAGE, uint32_t NUM_FRAGS_X, uint32_t HEAD_DIM,
typename IdType>
cudaError_t BatchPrefillWithPagedKVCacheDispatched(
DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, IdType* q_offset,
paged_kv_t<PAGE_STORAGE, KV_LAYOUT, DTypeIn, IdType> paged_kv, float* custom_mask,
paged_kv_t<PAGE_STORAGE, KV_LAYOUT, DTypeIn, IdType> paged_kv, uint8_t* custom_mask,
IdType* qk_indptr, DTypeOut* o, float* tmp, float* lse, uint32_t num_qo_tiles,
uint32_t num_qo_heads, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream);

@@ -63,7 +63,7 @@ template <PageStorage PAGE_STORAGE, uint32_t HEAD_DIM, LogitsPostHook LOGITS_POS
MaskMode MASK_MODE, typename DTypeIn, typename DTypeOut, typename IdType>
cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched(
BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, IdType* q_offset,
paged_kv_t<PAGE_STORAGE, KV_LAYOUT, DTypeIn, IdType> paged_kv, float* custom_mask,
paged_kv_t<PAGE_STORAGE, KV_LAYOUT, DTypeIn, IdType> paged_kv, uint8_t* custom_mask,
IdType* qk_indptr, DTypeOut* o, float* lse, uint32_t num_qo_heads, float sm_scale,
float rope_scale, float rope_theta, cudaStream_t stream) {
float* tmp = nullptr;
@@ -98,7 +98,7 @@ template <uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HOOK, QKVLayout KV_LAYOU
typename DTypeIn, typename DTypeOut, typename IdType>
cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched(
BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, DTypeIn* k, DTypeIn* v,
-IdType* kv_indptr, float* custom_mask, IdType* qk_indptr, IdType* q_offset,
+IdType* kv_indptr, uint8_t* custom_mask, IdType* qk_indptr, IdType* q_offset,
IdType* k_rope_pos_offset, DTypeOut* o, float* lse, uint32_t batch_size, uint32_t num_qo_heads,
uint32_t num_kv_heads, float sm_scale, float rope_scale, float rope_theta,
cudaStream_t stream) {
114 changes: 114 additions & 0 deletions include/flashinfer/quantization.cuh
@@ -0,0 +1,114 @@
/*
* Copyright (c) 2024 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FLASHINFER_QUANTIZATION_CUH_
#define FLASHINFER_QUANTIZATION_CUH_
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>

#include <cub/cub.cuh>

#include "utils.cuh"

namespace flashinfer {
namespace quantization {

enum class BitOrder { kBig = 0U, kLittle = 1U };

#define DISPATCH_BITORDER(bitorder, BITORDER, ...) \
if (bitorder == BitOrder::kBig) { \
constexpr BitOrder BITORDER = BitOrder::kBig; \
__VA_ARGS__ \
} else { \
constexpr BitOrder BITORDER = BitOrder::kLittle; \
__VA_ARGS__ \
}

template <BitOrder BITORDER>
__global__ void PackBitsKernel(bool* input, uint8_t* output, int64_t num_elements) {
int64_t start_offset = blockIdx.x * blockDim.x * 8, tx = threadIdx.x;
uint8_t ret = 0;
bool input_vec[8];
typedef cub::BlockLoad<bool, 256, 8, cub::BLOCK_LOAD_VECTORIZE> BlockLoad;
__shared__ typename BlockLoad::TempStorage temp_storage;
BlockLoad(temp_storage)
.Load(input + start_offset, input_vec, num_elements - start_offset, /*default=*/0);

if constexpr (BITORDER == BitOrder::kBig) {
ret = (input_vec[0] << 7) | (input_vec[1] << 6) | (input_vec[2] << 5) | (input_vec[3] << 4) |
(input_vec[4] << 3) | (input_vec[5] << 2) | (input_vec[6] << 1) | input_vec[7];
} else {
ret = (input_vec[7] << 7) | (input_vec[6] << 6) | (input_vec[5] << 5) | (input_vec[4] << 4) |
(input_vec[3] << 3) | (input_vec[2] << 2) | (input_vec[1] << 1) | input_vec[0];
}
if (start_offset + tx * 8 < num_elements) output[start_offset / 8 + tx] = ret;
}

template <BitOrder BITORDER, typename IdType>
__global__ void SegmentPackBitsKernel(bool* input, uint8_t* output, IdType* input_indptr,
IdType* output_indptr) {
int64_t bx = blockIdx.x, tx = threadIdx.x;
bool input_vec[8];
typedef cub::BlockLoad<bool, 256, 8, cub::BLOCK_LOAD_VECTORIZE> BlockLoad;
__shared__ typename BlockLoad::TempStorage temp_storage;
int64_t num_elements = input_indptr[bx + 1] - input_indptr[bx];
for (uint32_t start_offset = 0; start_offset < num_elements; start_offset += 8 * blockDim.x) {
uint8_t ret = 0;
BlockLoad(temp_storage)
.Load(input + input_indptr[bx] + start_offset, input_vec, num_elements - start_offset,
/*default=*/0);

if constexpr (BITORDER == BitOrder::kBig) {
ret = (input_vec[0] << 7) | (input_vec[1] << 6) | (input_vec[2] << 5) | (input_vec[3] << 4) |
(input_vec[4] << 3) | (input_vec[5] << 2) | (input_vec[6] << 1) | input_vec[7];
} else {
ret = (input_vec[7] << 7) | (input_vec[6] << 6) | (input_vec[5] << 5) | (input_vec[4] << 4) |
(input_vec[3] << 3) | (input_vec[2] << 2) | (input_vec[1] << 1) | input_vec[0];
}
if (start_offset + tx * 8 < num_elements)
output[output_indptr[bx] + start_offset / 8 + tx] = ret;
}
}

cudaError_t PackBits(bool* input, uint8_t* output, int64_t num_elements, BitOrder bitorder,
cudaStream_t stream) {
DISPATCH_BITORDER(bitorder, BITORDER, {
auto kernel = PackBitsKernel<BITORDER>;
const dim3 nthrs(256);
const dim3 nblks(ceil_div(num_elements, nthrs.x * 8));
void* args[] = {&input, &output, &num_elements};
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream));
});
return cudaSuccess;
}

template <typename IdType>
cudaError_t SegmentPackBits(bool* input, uint8_t* output, IdType* input_indptr,
IdType* output_indptr, uint32_t batch_size, BitOrder bitorder,
cudaStream_t stream) {
DISPATCH_BITORDER(bitorder, BITORDER, {
auto kernel = SegmentPackBitsKernel<BITORDER, IdType>;
const dim3 nthrs(256);
const dim3 nblks(batch_size);
void* args[] = {&input, &output, &input_indptr, &output_indptr};
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream));
});
return cudaSuccess;
}

} // namespace quantization
} // namespace flashinfer

#endif // FLASHINFER_QUANTIZATION_CUH_
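
A host-side numpy reference (a sketch for checking semantics, not part of this commit) for what the two kernels compute; ``output_indptr`` is assumed to hold per-segment byte offsets, i.e. the running sum of ceil(segment_len / 8):

    import numpy as np

    def pack_bits_ref(x, bitorder="little"):
        # np.packbits zero-pads the tail, matching the kernel's /*default=*/0 load.
        return np.packbits(np.asarray(x, dtype=bool), bitorder=bitorder)

    def segment_pack_bits_ref(x, input_indptr, output_indptr, bitorder="little"):
        # Each segment is packed independently, so every segment's packed mask
        # stays byte-aligned in the output.
        out = np.zeros(output_indptr[-1], dtype=np.uint8)
        for i in range(len(input_indptr) - 1):
            seg = x[input_indptr[i]:input_indptr[i + 1]]
            packed = np.packbits(seg, bitorder=bitorder)
            out[output_indptr[i]:output_indptr[i] + len(packed)] = packed
        return out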
4 changes: 2 additions & 2 deletions python/csrc/batch_prefill.cu
@@ -232,7 +232,7 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCu
handler_.get(), static_cast<c_type*>(q.data_ptr()),
static_cast<int32_t*>(qo_indptr.data_ptr()),
/*q_offset=*/nullptr, paged_kv,
static_cast<float*>(custom_mask.data_ptr()),
static_cast<uint8_t*>(custom_mask.data_ptr()),
static_cast<int32_t*>(qk_indptr.data_ptr()),
static_cast<c_type*>(o.data_ptr()),
/*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr,
@@ -434,7 +434,7 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardC
static_cast<int32_t*>(qo_indptr.data_ptr()),
static_cast<c_type*>(k.data_ptr()), static_cast<c_type*>(v.data_ptr()),
static_cast<int32_t*>(kv_indptr.data_ptr()),
static_cast<float*>(custom_mask.data_ptr()),
static_cast<uint8_t*>(custom_mask.data_ptr()),
static_cast<int32_t*>(qk_indptr.data_ptr()),
/*q_offset=*/nullptr, /*k_rope_pos_offset=*/nullptr,
static_cast<c_type*>(o.data_ptr()),
2 changes: 2 additions & 0 deletions python/csrc/flashinfer_ops.cu
@@ -42,6 +42,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("chain_speculative_sampling", &chain_speculative_sampling,
"Speculative sampling from sequence of probabilities");
m.def("rmsnorm", &rmsnorm, "Root mean square normalization");
m.def("packbits", &packbits, "GPU packbits operator");
m.def("segment_packbits", &segment_packbits, "GPU segment packbits operator");
py::class_<BatchDecodeWithPagedKVCachePyTorchWrapper>(m,
"BatchDecodeWithPagedKVCachePyTorchWrapper")
.def(py::init<unsigned int, bool, unsigned int>())