Skip to content

Commit

Permalink
feat: expose decoupled kv-cache to pytorch api (#383)
Browse files Browse the repository at this point in the history
Followup of #379
  • Loading branch information
yzh119 authored Jul 20, 2024
1 parent c6f20d1 commit 457a0ae
Show file tree
Hide file tree
Showing 17 changed files with 966 additions and 287 deletions.
6 changes: 0 additions & 6 deletions docs/api/python/cascade.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,6 @@ Merge Attention States
Cascade Attention
-----------------

.. autosummary::
:toctree: ../../generated

batch_decode_with_shared_prefix_padded_kv_cache


Cascade Attention Wrapper Classes
---------------------------------

Expand Down
6 changes: 0 additions & 6 deletions docs/api/python/decode.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,6 @@ Single Request Decoding
Batch Decoding
--------------

.. autosummary::
:toctree: ../../generated

batch_decode_with_padded_kv_cache
batch_decode_with_padded_kv_cache_return_lse

.. autoclass:: BatchDecodeWithPagedKVCacheWrapper
:members:

Expand Down
17 changes: 13 additions & 4 deletions docs/tutorials/kv_layout.rst
Original file line number Diff line number Diff line change
Expand Up @@ -119,14 +119,23 @@ The overall ``kv_indptr`` array (with length ``num_requests+1``) can be computed
The overall ``kv_page_indices`` array (with length ``kv_indptr[-1]``) is the concatenation of all requests' ``page_indices``.
The overall ``kv_last_page_lens`` array (with length ``num_requests``) is the concatenation of all requests' ``last_page_length``.

The ``kv_data`` tensor can either be a single 5-D tensor or a tuple of two 4-D tensors.
When stored in a single tensor, ``kv_data`` has shape:

.. code:: python
(max_num_pages, 2, page_size, num_heads, head_dim)
(max_num_pages, 2, page_size, num_heads, head_dim) # NHD layout
(max_num_pages, 2, num_heads, page_size, head_dim) # HND layout
When stored as a tuple of tensors, ``kv_data = (k_data, v_data)``, and each of them has shape:

.. code:: python
(max_num_pages, page_size, num_heads, head_dim) # NHD layout
(max_num_pages, num_heads, page_size, head_dim) # HND layout
where ``max_num_pages`` is the maximum number of pages used by all requests, and ``page_size`` is the number
of tokens that fit into each page. For single-tensor storage, the ``2`` in the second dimension separates
keys and values (the first slot stores keys, the second stores values).

FlashInfer APIs
~~~~~~~~~~~~~~~
Expand Down
42 changes: 42 additions & 0 deletions include/flashinfer/page.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,48 @@ struct paged_kv_t {
last_page_len(nullptr),
rope_pos_offset(nullptr) {}

/*!
 * \brief Construct a paged key-value cache from either fused or decoupled storage
 * \param num_heads The number of heads
 * \param page_size The size of each page
 * \param head_dim The dimension of each head
 * \param batch_size The batch size
 * \param layout The layout of the last 3 dimensions in the KV-Cache
 * \param kv_data The flattened fused key-value cache; pass nullptr to use decoupled storage
 * \param k_data The flattened key cache (consulted only when kv_data is nullptr)
 * \param v_data The flattened value cache (consulted only when kv_data is nullptr)
 * \param indices The page indices array
 * \param indptr The page indptr array
 * \param last_page_len The offset of the last page for each request in the batch
 * \param rope_pos_offset The start position of each request in the batch.
 * \note This constructor should only be used when page_storage == kIndices
 */
__host__ __forceinline__ paged_kv_t(uint32_t num_heads, uint32_t page_size, uint32_t head_dim,
                                    uint32_t batch_size, QKVLayout layout, DType* kv_data,
                                    DType* k_data, DType* v_data, IdType* indices, IdType* indptr,
                                    IdType* last_page_len, IdType* rope_pos_offset = nullptr)
    : num_heads(num_heads),
      page_size(page_size),
      head_dim(head_dim),
      batch_size(batch_size),
      indices(indices),
      indptr(indptr),
      last_page_len(last_page_len),
      rope_pos_offset(rope_pos_offset) {
  // Element count of one page for a single K (or V) slot.
  const uint32_t page_elems = num_heads * page_size * head_dim;
  if (kv_data != nullptr) {
    // Fused storage: K and V share one tensor; within each page, V follows K.
    stride_page = 2 * page_elems;
    this->k_data = kv_data;
    this->v_data = kv_data + page_elems;
  } else {
    // Decoupled storage: keys and values live in separate tensors.
    stride_page = page_elems;
    this->k_data = k_data;
    this->v_data = v_data;
  }
  // Per-token and per-head strides follow the chosen last-3-dim layout.
  if (layout == QKVLayout::kHND) {
    stride_n = head_dim;
    stride_h = page_size * head_dim;
  } else {
    stride_n = num_heads * head_dim;
    stride_h = head_dim;
  }
}

/*!
* \brief Construct a paged key-value cache
* \param num_heads The number of heads
Expand Down
80 changes: 61 additions & 19 deletions python/csrc/batch_decode.cu
Original file line number Diff line number Diff line change
Expand Up @@ -105,40 +105,71 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize(
}

std::vector<torch::Tensor> BatchDecodeWithPagedKVCachePyTorchWrapper::Forward(
torch::Tensor q, torch::Tensor paged_kv_data, torch::Tensor paged_kv_indptr,
torch::Tensor paged_kv_indices, torch::Tensor paged_kv_last_page_len,
unsigned int pos_encoding_mode, float logits_soft_cap, float sm_scale, float rope_scale,
float rope_theta, bool return_lse) {
torch::Tensor q, std::optional<torch::Tensor> paged_kv_cache,
std::optional<torch::Tensor> paged_k_cache, std::optional<torch::Tensor> paged_v_cache,
torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
torch::Tensor paged_kv_last_page_len, unsigned int pos_encoding_mode, float logits_soft_cap,
float sm_scale, float rope_scale, float rope_theta, bool return_lse) {
CHECK_INPUT(q);
CHECK_INPUT(paged_kv_data);
bool paged_kv_defined = paged_kv_cache.has_value();
if (paged_kv_defined) {
CHECK_INPUT(paged_kv_cache.value());
} else {
CHECK_INPUT(paged_k_cache.value());
CHECK_INPUT(paged_v_cache.value());
}
CHECK_INPUT(paged_kv_indptr);
CHECK_INPUT(paged_kv_indices);
CHECK_INPUT(paged_kv_last_page_len);
auto device = q.device();
CHECK_EQ(paged_kv_data.device(), device);
if (paged_kv_defined) {
CHECK_EQ(paged_kv_cache->device(), device);
} else {
CHECK_EQ(paged_k_cache->device(), device);
CHECK_EQ(paged_v_cache->device(), device);
}
CHECK_EQ(paged_kv_indices.device(), device);
CHECK_EQ(paged_kv_indptr.device(), device);
CHECK_EQ(paged_kv_last_page_len.device(), device);
CHECK_DIM(3, q); // (B, H_qo, D)
CHECK_DIM(1, paged_kv_last_page_len); // (B,)
CHECK_DIM(1, paged_kv_indptr); // (B+1,)
CHECK_DIM(1, paged_kv_indices); // (nnz,)
// (num_max_pages, 2, H_kv, page_size, head_dim) for HND
// (num_max_pages, 2, page_size, H_kv, head_dim) for NHD
CHECK_DIM(5, paged_kv_data);
if (paged_kv_defined) {
// (num_max_pages, 2, H_kv, page_size, head_dim) for HND
// (num_max_pages, 2, page_size, H_kv, head_dim) for NHD
CHECK_DIM(5, paged_kv_cache.value());
} else {
// (num_max_pages, H_kv, page_size, head_dim) for HND
// (num_max_pages, page_size, H_kv, head_dim) for NHD
CHECK_DIM(4, paged_k_cache.value());
CHECK_DIM(4, paged_v_cache.value());
}
int64_t batch_size = q.size(0);
int64_t num_qo_heads = q.size(1);
int64_t head_dim = q.size(2);
int64_t num_kv_heads, page_size;
if (kv_layout_ == QKVLayout::kHND) {
num_kv_heads = paged_kv_data.size(2);
page_size = paged_kv_data.size(3);
if (paged_kv_defined) {
CHECK_EQ(paged_kv_cache->size(1), 2);
CHECK_EQ(paged_kv_cache->size(4), head_dim);
if (kv_layout_ == QKVLayout::kHND) {
num_kv_heads = paged_kv_cache->size(2);
page_size = paged_kv_cache->size(3);
} else {
page_size = paged_kv_cache->size(2);
num_kv_heads = paged_kv_cache->size(3);
}
} else {
page_size = paged_kv_data.size(2);
num_kv_heads = paged_kv_data.size(3);
CHECK_EQ(paged_k_cache->size(3), head_dim);
CHECK_EQ(paged_v_cache->size(3), head_dim);
if (kv_layout_ == QKVLayout::kHND) {
num_kv_heads = paged_k_cache->size(1);
page_size = paged_k_cache->size(2);
} else {
page_size = paged_k_cache->size(1);
num_kv_heads = paged_k_cache->size(2);
}
}
CHECK_EQ(paged_kv_data.size(1), 2);
CHECK_EQ(paged_kv_data.size(4), head_dim);
CHECK_GE(paged_kv_indptr.size(0), batch_size + 1);
CHECK_GE(paged_kv_last_page_len.size(0), batch_size);
// TODO(Zihao): support dispatching to different data types
Expand All @@ -159,7 +190,8 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCachePyTorchWrapper::Forward(
logits_soft_cap > 0.f ? LogitsPostHook::kSoftCap : LogitsPostHook::kNone;

auto q_scalar_type = q.scalar_type();
auto kv_scalar_type = paged_kv_data.scalar_type();
auto kv_scalar_type =
paged_kv_defined ? paged_kv_cache->scalar_type() : paged_k_cache->scalar_type();

if (q_scalar_type == kv_scalar_type) {
DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q_scalar_type, qkv_type, [&] {
Expand All @@ -169,7 +201,12 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCachePyTorchWrapper::Forward(
PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] {
paged_kv_t<PageStorage::kIndices, qkv_type, int32_t> paged_kv(
num_kv_heads, page_size, head_dim, batch_size, kv_layout_,
static_cast<qkv_type*>(paged_kv_data.data_ptr()),
static_cast<qkv_type*>(paged_kv_cache.has_value() ? paged_kv_cache->data_ptr()
: nullptr),
static_cast<qkv_type*>(paged_k_cache.has_value() ? paged_k_cache->data_ptr()
: nullptr),
static_cast<qkv_type*>(paged_v_cache.has_value() ? paged_v_cache->data_ptr()
: nullptr),
static_cast<int32_t*>(paged_kv_indices.data_ptr()),
static_cast<int32_t*>(paged_kv_indptr.data_ptr()),
static_cast<int32_t*>(paged_kv_last_page_len.data_ptr()));
Expand Down Expand Up @@ -197,7 +234,12 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCachePyTorchWrapper::Forward(
PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] {
paged_kv_t<PageStorage::kIndices, kv_type, int32_t> paged_kv(
num_kv_heads, page_size, head_dim, batch_size, kv_layout_,
static_cast<kv_type*>(paged_kv_data.data_ptr()),
static_cast<kv_type*>(paged_kv_cache.has_value() ? paged_kv_cache->data_ptr()
: nullptr),
static_cast<kv_type*>(paged_k_cache.has_value() ? paged_k_cache->data_ptr()
: nullptr),
static_cast<kv_type*>(paged_v_cache.has_value() ? paged_v_cache->data_ptr()
: nullptr),
static_cast<int32_t*>(paged_kv_indices.data_ptr()),
static_cast<int32_t*>(paged_kv_indptr.data_ptr()),
static_cast<int32_t*>(paged_kv_last_page_len.data_ptr()));
Expand Down
Loading

0 comments on commit 457a0ae

Please sign in to comment.