misc: use the new plan/run API for unittests (#467)
In the previous PR #466 we replaced the old-style
`begin_forward`/`end_forward`/`forward` APIs with the new `plan`/`run`
APIs, but didn't update the unit tests accordingly (this was intentional:
we wanted a commit in which the unit tests still use the old-style APIs,
so that backward compatibility could be checked).

This PR updates the unit tests to use the new APIs.
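
For reference, the migration pattern applied throughout the tests looks roughly like the following. This is an abbreviated, illustrative fragment that reuses the variable names from the decode tests in the diff below; wrapper and tensor setup are omitted, and the full argument lists are in the diffs themselves.

```python
# Old-style API (replaced in #466, kept only for backward compatibility):
# planning and per-call attention options were split across
# begin_forward() / forward() / end_forward().
wrapper.begin_forward(
    kv_indptr, kv_indices, kv_last_page_len,
    num_qo_heads, num_kv_heads, head_dim, page_size,
)
o = wrapper.forward(
    q, kv_data,
    pos_encoding_mode=pos_encoding_mode,
    logits_soft_cap=logits_soft_cap,
)
wrapper.end_forward()

# New-style API: plan() records the auxiliary data structures and the
# attention options once, and run() only takes the per-call tensors.
wrapper.plan(
    kv_indptr, kv_indices, kv_last_page_len,
    num_qo_heads, num_kv_heads, head_dim, page_size,
    pos_encoding_mode=pos_encoding_mode,
    logits_soft_cap=logits_soft_cap,
)
o = wrapper.run(q, kv_data)  # or wrapper.run_return_lse(q, kv_data)
```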

Some other changes:
- Remove the old-style APIs from the docstrings.
- Fix some errors in the docstring examples that use the new APIs.
yzh119 authored Aug 26, 2024
1 parent d940d2e commit 78ec6db
Showing 17 changed files with 156 additions and 306 deletions.
6 changes: 6 additions & 0 deletions docs/api/python/cascade.rst
@@ -28,11 +28,17 @@ Cascade Attention Wrapper Classes
.. autoclass:: MultiLevelCascadeAttentionWrapper
:members:

.. automethod:: __init__

:exclude-members: begin_forward, end_forward, forward, forward_return_lse

.. autoclass:: BatchDecodeWithSharedPrefixPagedKVCacheWrapper
:members:

.. automethod:: __init__

.. autoclass:: BatchPrefillWithSharedPrefixPagedKVCacheWrapper
:members:

.. automethod:: __init__

2 changes: 2 additions & 0 deletions docs/api/python/decode.rst
@@ -20,6 +20,8 @@ Batch Decoding
:members:

.. automethod:: __init__

:exclude-members: begin_forward, end_forward, forward, forward_return_lse

.. autoclass:: CUDAGraphBatchDecodeWithPagedKVCacheWrapper
:members:
4 changes: 4 additions & 0 deletions docs/api/python/prefill.rst
@@ -24,7 +24,11 @@ Batch Prefill/Append Attention

.. automethod:: __init__

:exclude-members: begin_forward, end_forward, forward, forward_return_lse

.. autoclass:: BatchPrefillWithRaggedKVCacheWrapper
:members:

.. automethod:: __init__

:exclude-members: begin_forward, end_forward, forward, forward_return_lse
4 changes: 3 additions & 1 deletion docs/api/python/sparse.rst
@@ -10,4 +10,6 @@ Kernels for block sparse flashattention.
.. autoclass:: BlockSparseAttentionWrapper
:members:

.. automethod:: __init__
.. automethod:: __init__

:exclude-members: begin_forward, end_forward, forward, forward_return_lse
2 changes: 1 addition & 1 deletion docs/tutorials/kv_layout.rst
@@ -155,7 +155,7 @@ When using multi-level `cascade inference <https://flashinfer.ai/2024/02/02/casc
the query and output are stored in ragged tensors, and KV-Cache of all levels are stored
in a unified Paged KV-Cache. Each level has a unique ``qo_indptr`` array which is the prefix sum of the
accumulated number of tokens to append in the subtree, as well as ``kv_page_indptr``, ``kv_page_indices``, and
``kv_last_page_len`` which has same semantics as in :ref:`<page-layout>` section. The following figure
``kv_last_page_len`` which has same semantics as in :ref:`page-layout` section. The following figure
introduce how to construct these data structures for append attention operation for 8 requests where we
treat their KV-Cache as 3 levels for prefix reuse:
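
As a rough illustrative sketch (not part of this commit, and not a substitute for the referenced figure), this is how such ``indptr`` arrays are typically built as prefix sums for a level in which each node holds a single request; the token and page counts below are hypothetical, and upper cascade levels aggregate several requests per node.

```python
import torch

# Hypothetical number of tokens to append for each of the 8 requests.
qo_lens = torch.tensor([4, 4, 4, 4, 2, 2, 2, 2], dtype=torch.int32)
# qo_indptr is the prefix sum with a leading zero:
# [0, 4, 8, 12, 16, 18, 20, 22, 24]
qo_indptr = torch.cat(
    [torch.zeros(1, dtype=torch.int32), torch.cumsum(qo_lens, 0, dtype=torch.int32)]
)

# Hypothetical number of KV-Cache pages used by each request at this level.
kv_page_counts = torch.tensor([3, 3, 3, 3, 1, 1, 1, 1], dtype=torch.int32)
kv_page_indptr = torch.cat(
    [torch.zeros(1, dtype=torch.int32), torch.cumsum(kv_page_counts, 0, dtype=torch.int32)]
)
# kv_page_indices concatenates each request's page ids in kv_page_indptr order;
# kv_last_page_len stores how many entries of each request's last page are in use.
kv_page_indices = torch.arange(int(kv_page_indptr[-1]), dtype=torch.int32)
kv_last_page_len = torch.full((8,), 5, dtype=torch.int32)  # assuming 5 used slots
```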

22 changes: 9 additions & 13 deletions python/flashinfer/prefill.py
@@ -478,15 +478,14 @@ class BatchPrefillWithPagedKVCacheWrapper:
... num_kv_heads,
... head_dim,
... page_size,
... causal=True,
... )
>>> outputs = []
>>> for i in range(num_layers):
... q = q_at_layer[i]
... kv_cache = kv_cache_at_layer[i]
... # compute batch prefill attention, reuse auxiliary data structures
... o = prefill_wrapper.run(
... q, kv_cache, causal=True
... )
... o = prefill_wrapper.run(q, kv_cache)
... outputs.append(o)
...
>>> outputs[0].shape
@@ -513,20 +512,18 @@ class BatchPrefillWithPagedKVCacheWrapper:
... num_kv_heads,
... head_dim,
... page_size,
... mask
... custom_mask=mask,
... )
>>> outputs_custom_mask = []
>>> for i in range(num_layers):
... q = q_at_layer[i]
... kv_cache = kv_cache_at_layer[i]
... # compute batch prefill attention, reuse auxiliary data structures
... o_custom = prefill_wrapper.run(
... q, kv_cache
... )
... o_custom = prefill_wrapper.run(q, kv_cache)
... assert torch.allclose(o_custom, outputs[i], rtol=1e-3, atol=1e-3)
...
Note
----
To accelerate computation, FlashInfer's batch prefill/append attention operators
@@ -1161,17 +1158,16 @@ class BatchPrefillWithRaggedKVCacheWrapper:
... kv_indptr,
... num_qo_heads,
... num_kv_heads,
... head_dim
... head_dim,
... causal=True,
... )
>>> outputs = []
>>> for i in range(num_layers):
... q = q_at_layer[i]
... k = k_at_layer[i]
... v = v_at_layer[i]
... # compute batch prefill attention, reuse auxiliary data structures
... o = prefill_wrapper.run(
... q, k, v, causal=True
... )
... o = prefill_wrapper.run(q, k, v)
... outputs.append(o)
...
>>> outputs[0].shape
@@ -1195,7 +1191,7 @@ class BatchPrefillWithRaggedKVCacheWrapper:
... num_qo_heads,
... num_kv_heads,
... head_dim,
... mask
... custom_mask=mask
... )
>>> outputs_custom_mask = []
>>> for i in range(num_layers):
53 changes: 16 additions & 37 deletions python/tests/test_batch_decode_kernels.py
@@ -65,33 +65,23 @@ def test_batch_decode_with_paged_kv_cache(

workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0)
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, kv_layout)
wrapper.begin_forward(
wrapper.plan(
kv_indptr,
kv_indices,
kv_last_page_len,
num_qo_heads,
num_kv_heads,
head_dim,
page_size,
"NONE",
logits_soft_cap=logits_soft_cap,
pos_encoding_mode=pos_encoding_mode,
data_type=kv_dtype,
q_data_type=q_dtype,
)
if return_lse:
o, _ = wrapper.forward_return_lse(
q,
kv_data.to(kv_dtype),
pos_encoding_mode=pos_encoding_mode,
logits_soft_cap=logits_soft_cap,
)
o, _ = wrapper.run_return_lse(q, kv_data.to(kv_dtype))
else:
o = wrapper.forward(
q,
kv_data.to(kv_dtype),
pos_encoding_mode=pos_encoding_mode,
logits_soft_cap=logits_soft_cap,
)
o = wrapper.run(q, kv_data.to(kv_dtype))

for i in range(batch_size):
perm_dims = [0, 2, 1, 3] if kv_layout == "HND" else [0, 1, 2, 3]
@@ -186,33 +176,23 @@ def test_batch_decode_with_tuple_paged_kv_cache(

workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0)
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, kv_layout)
wrapper.begin_forward(
wrapper.plan(
kv_indptr,
kv_indices,
kv_last_page_len,
num_qo_heads,
num_kv_heads,
head_dim,
page_size,
"NONE",
logits_soft_cap=logits_soft_cap,
pos_encoding_mode=pos_encoding_mode,
data_type=kv_dtype,
q_data_type=q_dtype,
)
if return_lse:
o, _ = wrapper.forward_return_lse(
q,
tuple(map(lambda _: _.to(kv_dtype), kv_data)),
pos_encoding_mode=pos_encoding_mode,
logits_soft_cap=logits_soft_cap,
)
o, _ = wrapper.run_return_lse(q, tuple(map(lambda _: _.to(kv_dtype), kv_data)))
else:
o = wrapper.forward(
q,
tuple(map(lambda _: _.to(kv_dtype), kv_data)),
pos_encoding_mode=pos_encoding_mode,
logits_soft_cap=logits_soft_cap,
)
o = wrapper.run(q, tuple(map(lambda _: _.to(kv_dtype), kv_data)))

k_cache, v_cache = kv_data
for i in range(batch_size):
@@ -313,48 +293,47 @@ def test_cuda_graph_batch_decode_with_paged_kv_cache(
kv_last_page_device_buffer,
kv_layout,
)
wrapper.begin_forward(
wrapper.plan(
kv_indptr_host_warmup,
kv_indices_host_warmup,
kv_last_page_len_host_warmup,
num_qo_heads,
num_kv_heads,
head_dim,
page_size,
"NONE",
data_type=kv_dtype,
pos_encoding_mode=pos_encoding_mode,
q_data_type=q_dtype,
)
# warmup
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
for _ in range(3):
o = wrapper.forward(q, kv_data_dtype, pos_encoding_mode=pos_encoding_mode)
o = wrapper.run(q, kv_data_dtype)
torch.cuda.current_stream().wait_stream(s)

# capture
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
o = wrapper.forward(q, kv_data_dtype, pos_encoding_mode=pos_encoding_mode)
wrapper.end_forward()
o = wrapper.run(q, kv_data_dtype)

# replay multiple times
for i in range(1, min(4, num_pages_per_seq)):
kv_indptr_host = torch.arange(0, batch_size + 1).int() * i
kv_indices_host = torch.arange(0, i * batch_size).int()
kv_last_page_len_host = torch.full((batch_size,), page_size, dtype=torch.int32)

wrapper.begin_forward(
wrapper.plan(
kv_indptr_host,
kv_indices_host,
kv_last_page_len_host,
num_qo_heads,
num_kv_heads,
head_dim,
page_size,
"NONE",
data_type=kv_dtype,
pos_encoding_mode=pos_encoding_mode,
q_data_type=q_dtype,
)
g.replay()
@@ -366,16 +345,16 @@ def test_cuda_graph_batch_decode_with_paged_kv_cache(
(batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32
)

wrapper.begin_forward(
wrapper.plan(
kv_indptr_host,
kv_indices_host,
kv_last_page_len_host,
num_qo_heads,
num_kv_heads,
head_dim,
page_size,
"NONE",
data_type=kv_dtype,
pos_encoding_mode=pos_encoding_mode,
q_data_type=q_dtype,
)
g.replay()