
Fix PagedPrefill python api and some typos #441

Merged: 3 commits into flashinfer-ai:main on Aug 13, 2024

Conversation

jianfei-wangg (Contributor) commented:

Fix two small bugs:

1. "NHD" and "HND" layout strings are used in a confusing, inconsistent way (see the layout sketch below).
2. PagedPrefill uses self._custom_mask_buf to decide whether a custom mask was supplied, but the buffer is never initialized.
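For the 1st bug, here is a minimal sketch of the two KV-cache layout conventions (the tensor names are illustrative, not part of the flashinfer API):

```python
import torch

seq_len, num_heads, head_dim = 4, 32, 128

# "NHD" layout: [seq_len, num_heads, head_dim] -- sequence dimension first
k_nhd = torch.randn(seq_len, num_heads, head_dim)

# "HND" layout: [num_heads, seq_len, head_dim] -- head dimension first
k_hnd = k_nhd.transpose(0, 1).contiguous()

assert k_hnd.shape == (num_heads, seq_len, head_dim)
```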
Here is a code snippet that reproduces the 2nd bug:

```python
import torch
import flashinfer

# try to reproduce the bug in a speculative-decoding-style case
device = torch.device("cuda:0")
num_heads = 32
num_qo_heads = num_heads
num_kv_heads = 32
head_dim = 128
page_size = 4
max_num_pages = 4
batch_size = 1
seq_len = 4
query = torch.randn(seq_len, num_heads, head_dim, dtype=torch.bfloat16, device=device)
packed_kv_cache = torch.randn(max_num_pages, 2, page_size, num_kv_heads, head_dim, dtype=torch.bfloat16, device=device)
ragged_key_cache = packed_kv_cache[:, 0].reshape(-1, num_kv_heads, head_dim)
ragged_value_cache = packed_kv_cache[:, 1].reshape(-1, num_kv_heads, head_dim)

# custom attention mask of shape [4, 15] (qo_len x kv_len)
attn_mask = torch.tensor([
    [ True,  True,  True,  True,  True,  True,  True,  True, False, False, False,  True, False, False, False],
    [ True,  True,  True,  True,  True,  True,  True, False,  True, False, False, False,  True, False, False],
    [ True,  True,  True,  True,  True,  True,  True,  True, False, False, False, False, False,  True, False],
    [ True,  True,  True,  True,  True,  True,  True, False, False,  True, False, False, False, False,  True]
    ], device=device)

mask = attn_mask.reshape(-1)
# packed_mask = flashinfer.quantization.packbits(mask)
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
paged_prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
    workspace_buffer, "NHD"
)
kv_page_indices = torch.arange(max_num_pages, dtype=torch.int32, device=device)
kv_page_indptr = torch.tensor([0, 4], dtype=torch.int32, device=device)
# 1 <= kv_last_page_len <= page_size
kv_last_page_len = torch.tensor([3], dtype=torch.int32, device=device)
qo_indptr = torch.tensor([0, 4], dtype=torch.int32, device=device)

# create auxiliary data structures for batch prefill attention
paged_prefill_wrapper.begin_forward(
    qo_indptr,
    kv_page_indptr,
    kv_page_indices,
    kv_last_page_len,
    num_qo_heads,
    num_kv_heads,
    head_dim,
    page_size,
    mask,
    q_data_type=torch.bfloat16
)
# assert torch.equal(paged_prefill_wrapper._custom_mask, packed_mask)
# assert paged_prefill_wrapper._custom_mask_buf is not None
q = query
o = paged_prefill_wrapper.forward(q, packed_kv_cache, causal=False)
paged_prefill_wrapper.end_forward()

# run the same attention with the ragged (non-paged) wrapper as a reference
workspace_buffer_ragged = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
ragged_prefill_wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(
    workspace_buffer_ragged, "NHD"
)
kv_indptr = torch.tensor([0, 15], dtype=torch.int32, device=device)
ragged_prefill_wrapper.begin_forward(
    qo_indptr,
    kv_indptr,
    num_qo_heads,
    num_kv_heads,
    head_dim,
    mask,
    q_data_type="bfloat16",
)
ragged_o = ragged_prefill_wrapper.forward(q, ragged_key_cache, ragged_value_cache)
ragged_prefill_wrapper.end_forward()
print("query shape: ", q.shape)
print("paged vs ragged allclose: ", torch.allclose(o, ragged_o, rtol=1e-3, atol=1e-3))
print("paged vs ragged equal: ", torch.equal(o, ragged_o))
assert torch.allclose(o, ragged_o, rtol=1e-3, atol=1e-3)
assert torch.equal(o, ragged_o)
```
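The commented-out asserts above hint at the expected post-fix behavior: begin_forward should bit-pack the boolean mask (via flashinfer.quantization.packbits) and store it, so _custom_mask_buf ends up initialized. A quick sanity check, to be placed between begin_forward and forward (it reads private wrapper attributes, so treat it as illustrative only):

```python
# Illustrative check of the fix; relies on private attributes of the wrapper.
packed_mask = flashinfer.quantization.packbits(mask)
assert paged_prefill_wrapper._custom_mask_buf is not None
assert torch.equal(paged_prefill_wrapper._custom_mask, packed_mask)
```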

@yzh119 (Collaborator) left a comment:


LGTM @jianfei-wangg, thanks for your contribution!

yzh119 merged commit 3fff008 into flashinfer-ai:main on Aug 13, 2024.

yzh119 added a commit that referenced this pull request on Aug 13, 2024:
🤖 I have created a release *beep* *boop*

---

## [0.1.5](v0.1.4...v0.1.5) (2024-08-13)


### Bugfix

* Fix PagedPrefill python api and some typos ([#441](#441)) ([3fff008](3fff008))
* fix prefill kernels' lse result for empty kv-cache ([#440](#440)) ([6ac28f4](6ac28f4))

### Features

* decouple float and int workspace buffer ([#442](#442)) ([a7ee566](a7ee566))


### Performance Improvements

* faster fp8->fp16 dequantization for pre-sm_90 arch ([#439](#439)) ([c93f647](c93f647))

### Acknowledgement

We thank the community for their contributions and feedback:
[@comaniac](https://github.com/comaniac),
[@hnyls2002](https://github.com/hnyls2002),
[@jianfei-wangg](https://github.com/jianfei-wangg),
[@Yard1](https://github.com/Yard1).


---
This PR was generated with [Release Please](https://github.com/googleapis/release-please). See the [documentation](https://github.com/googleapis/release-please#release-please).

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Zihao Ye <expye@outlook.com>