tests: add more unittests for logits cap (#352)
Follow-up to #350.
Add the `logits_soft_cap=1.0` case to the correctness tests.
Add `logits_soft_cap` coverage to the batch decode/prefill tests.
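The correctness tests compare the CUDA kernels against a PyTorch reference, `attention_logits_soft_cap_torch`, whose body lies outside the hunks shown below. For orientation, here is a minimal sketch of such a reference for single-query decode shapes, assuming the common `cap * tanh(logits / cap)` formulation; the function name is hypothetical and the repository's exact scaling order may differ.

```python
import math

import torch


def attention_logits_soft_cap_ref(q, k, v, soft_cap):
    # q: [num_heads, head_dim]; k, v: [kv_len, num_heads, head_dim] (NHD layout).
    head_dim = q.shape[-1]
    # Scaled dot-product logits, one row per head: [num_heads, kv_len].
    logits = torch.einsum("hd,nhd->hn", q.float(), k.float()) / math.sqrt(head_dim)
    # Soft cap: smoothly squash every logit into (-soft_cap, soft_cap) before softmax.
    logits = soft_cap * torch.tanh(logits / soft_cap)
    p = torch.softmax(logits, dim=-1)
    return torch.einsum("hn,nhd->hd", p, v.float()).to(q.dtype)
```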
yzh119 authored Jul 3, 2024
1 parent f5f7a2a commit d1d443a
Showing 3 changed files with 138 additions and 32 deletions.
27 changes: 20 additions & 7 deletions python/tests/test_batch_decode_kernels.py
@@ -29,6 +29,7 @@
@pytest.mark.parametrize("head_dim", [128, 256])
@pytest.mark.parametrize("kv_layout", ["HND", "NHD"])
@pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA", "ALIBI"])
@pytest.mark.parametrize("logits_soft_cap", [0.0, 30.0])
@pytest.mark.parametrize("return_lse", [True, False])
@pytest.mark.parametrize("q_dtype", [torch.float16])
@pytest.mark.parametrize(
@@ -43,6 +44,7 @@ def test_batch_decode_with_paged_kv_cache(
head_dim,
kv_layout,
pos_encoding_mode,
logits_soft_cap,
return_lse,
q_dtype,
kv_dtype,
@@ -72,16 +74,23 @@ def test_batch_decode_with_paged_kv_cache(
head_dim,
page_size,
"NONE",
logits_soft_cap=logits_soft_cap,
data_type=kv_dtype,
q_data_type=q_dtype,
)
if return_lse:
o, _ = wrapper.forward_return_lse(
q, kv_data.to(kv_dtype), pos_encoding_mode=pos_encoding_mode
q,
kv_data.to(kv_dtype),
pos_encoding_mode=pos_encoding_mode,
logits_soft_cap=logits_soft_cap,
)
else:
o = wrapper.forward(
q, kv_data.to(kv_dtype), pos_encoding_mode=pos_encoding_mode
q,
kv_data.to(kv_dtype),
pos_encoding_mode=pos_encoding_mode,
logits_soft_cap=logits_soft_cap,
)

for i in range(batch_size):
@@ -119,7 +128,11 @@ def test_batch_decode_with_paged_kv_cache(
dim=0,
).to(kv_dtype)
o_ref_i = flashinfer.single_decode_with_kv_cache(
qi, ki, vi, pos_encoding_mode=pos_encoding_mode
qi,
ki,
vi,
pos_encoding_mode=pos_encoding_mode,
logits_soft_cap=logits_soft_cap,
)
o_i_np = o[i].cpu().numpy()
o_ref_i_np = o_ref_i.cpu().numpy()
@@ -293,13 +306,13 @@ def test_cuda_graph_batch_decode_with_paged_kv_cache(

if __name__ == "__main__":
test_batch_decode_with_paged_kv_cache(
256, 54, 8, 8, 8, 128, "NHD", "NONE", False, torch.float16, torch.float16
256, 54, 8, 8, 8, 128, "NHD", "NONE", 0.0, False, torch.float16, torch.float16
)
test_batch_decode_with_paged_kv_cache(
12, 2048, 8, 8, 8, 128, "NHD", "NONE", False, torch.float16, torch.float16
12, 2048, 8, 8, 8, 128, "NHD", "NONE", 0.0, False, torch.float16, torch.float16
)
test_batch_decode_with_paged_kv_cache(
12, 54, 1, 8, 8, 128, "HND", "NONE", True, torch.float16, torch.float8_e5m2
12, 54, 1, 8, 8, 128, "HND", "NONE", 0.0, True, torch.float16, torch.float8_e5m2
)
test_cuda_graph_batch_decode_with_paged_kv_cache(
12, 2048, 8, 8, 8, 128, "NHD", "NONE", torch.float16, torch.float16
@@ -308,7 +321,7 @@ def test_cuda_graph_batch_decode_with_paged_kv_cache(
128, 54, 8, 8, 8, 128, "NHD", "NONE", torch.float16, torch.float16
)
test_batch_decode_with_paged_kv_cache(
12, 54, 1, 8, 8, 128, "HND", "NONE", True, torch.float16, torch.float8_e5m2
12, 54, 1, 8, 8, 128, "HND", "NONE", 0.0, True, torch.float16, torch.float8_e5m2
)
test_cuda_graph_batch_decode_with_paged_kv_cache(
12, 54, 8, 8, 8, 128, "HND", "NONE", torch.float16, torch.float8_e5m2
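For context, the hunks above show only the tail of the `begin_forward` call and the `forward` keyword arguments. A self-contained sketch of the full batch-decode call pattern with the new `logits_soft_cap` knob follows; the wrapper constructor and the `begin_forward` arguments before `head_dim` (page indptr/indices/last-page lengths and head counts) are assumed from flashinfer's Python API rather than visible in this diff, and the shapes are chosen purely for illustration.

```python
import torch

import flashinfer

batch_size, page_size, pages_per_seq = 4, 16, 3
num_qo_heads, num_kv_heads, head_dim = 8, 8, 128
total_pages = batch_size * pages_per_seq
device = "cuda:0"

# Paged KV cache, NHD layout: [total_pages, 2 (k/v), page_size, num_kv_heads, head_dim].
kv_data = torch.randn(
    total_pages, 2, page_size, num_kv_heads, head_dim,
    dtype=torch.float16, device=device,
)
kv_indptr = (torch.arange(batch_size + 1, device=device) * pages_per_seq).int()
kv_indices = torch.arange(total_pages, device=device).int()
kv_last_page_len = torch.full((batch_size,), page_size, device=device).int()
q = torch.randn(batch_size, num_qo_heads, head_dim, dtype=torch.float16, device=device)

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace, "NHD")
wrapper.begin_forward(
    kv_indptr, kv_indices, kv_last_page_len,  # assumed leading arguments (not in the hunks)
    num_qo_heads, num_kv_heads, head_dim, page_size,
    "NONE",                                   # pos_encoding_mode at plan time, as in the diff
    logits_soft_cap=30.0,                     # the knob this commit adds coverage for
    data_type=torch.float16,
    q_data_type=torch.float16,
)
o = wrapper.forward(q, kv_data, pos_encoding_mode="NONE", logits_soft_cap=30.0)
wrapper.end_forward()
```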
139 changes: 116 additions & 23 deletions python/tests/test_batch_prefill_kernels.py
@@ -32,6 +32,7 @@
@pytest.mark.parametrize("kv_layout", ["HND", "NHD"])
@pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA", "ALIBI"])
@pytest.mark.parametrize("use_cuda_graph", [False, True])
@pytest.mark.parametrize("logits_soft_cap", [0.0, 30.0])
@pytest.mark.parametrize("return_lse", [True, False])
def test_batch_prefill_with_paged_kv_cache(
batch_size,
@@ -45,6 +46,7 @@ def test_batch_prefill_with_paged_kv_cache(
kv_layout,
pos_encoding_mode,
use_cuda_graph,
logits_soft_cap,
return_lse,
):
q = torch.randn(batch_size * qo_len, num_qo_heads, head_dim).to(0).half()
@@ -83,9 +85,22 @@ def test_batch_prefill_with_paged_kv_cache(
head_dim,
page_size,
)
o = wrapper.forward(
q, kv_data, causal=causal, pos_encoding_mode=pos_encoding_mode
)
if return_lse:
o, _ = wrapper.forward_return_lse(
q,
kv_data,
causal=causal,
pos_encoding_mode=pos_encoding_mode,
logits_soft_cap=logits_soft_cap,
)
else:
o = wrapper.forward(
q,
kv_data,
causal=causal,
pos_encoding_mode=pos_encoding_mode,
logits_soft_cap=logits_soft_cap,
)
else:
q_indptr_buffer = torch.empty(batch_size + 1).int().to(0)
kv_indptr_buffer = torch.empty(batch_size + 1).int().to(0)
@@ -124,23 +139,39 @@ def test_batch_prefill_with_paged_kv_cache(
for _ in range(3):
if return_lse:
o, _ = wrapper.forward_return_lse(
q, kv_data, causal=causal, pos_encoding_mode=pos_encoding_mode
q,
kv_data,
causal=causal,
pos_encoding_mode=pos_encoding_mode,
logits_soft_cap=logits_soft_cap,
)
else:
o = wrapper.forward(
q, kv_data, causal=causal, pos_encoding_mode=pos_encoding_mode
q,
kv_data,
causal=causal,
pos_encoding_mode=pos_encoding_mode,
logits_soft_cap=logits_soft_cap,
)
torch.cuda.current_stream().wait_stream(s)
# capture
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
if return_lse:
o, _ = wrapper.forward_return_lse(
q, kv_data, causal=causal, pos_encoding_mode=pos_encoding_mode
q,
kv_data,
causal=causal,
pos_encoding_mode=pos_encoding_mode,
logits_soft_cap=logits_soft_cap,
)
else:
o = wrapper.forward(
q, kv_data, causal=causal, pos_encoding_mode=pos_encoding_mode
q,
kv_data,
causal=causal,
pos_encoding_mode=pos_encoding_mode,
logits_soft_cap=logits_soft_cap,
)
wrapper.end_forward()

@@ -196,7 +227,12 @@ def test_batch_prefill_with_paged_kv_cache(
dim=0,
)
o_ref_i = flashinfer.single_prefill_with_kv_cache(
qi, ki, vi, causal=causal, pos_encoding_mode=pos_encoding_mode
qi,
ki,
vi,
causal=causal,
pos_encoding_mode=pos_encoding_mode,
logits_soft_cap=logits_soft_cap,
)
o_i_np = o[q_indptr_cpu[i] : q_indptr_cpu[i + 1]].cpu().numpy()
o_ref_i_np = o_ref_i.cpu().numpy()
@@ -212,6 +248,7 @@ def test_batch_prefill_with_paged_kv_cache(
@pytest.mark.parametrize("head_dim", [128, 256])
@pytest.mark.parametrize("kv_layout", ["HND", "NHD"])
@pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA", "ALIBI"])
@pytest.mark.parametrize("logits_soft_cap", [0.0, 30.0])
@pytest.mark.parametrize("return_lse", [True, False])
def test_batch_prefill_with_paged_kv_cache_custom_mask(
batch_size,
@@ -223,6 +260,7 @@ def test_batch_prefill_with_paged_kv_cache_custom_mask(
head_dim,
kv_layout,
pos_encoding_mode,
logits_soft_cap,
return_lse,
):
q = torch.randn(batch_size * qo_len, num_qo_heads, head_dim).to(0).half()
@@ -269,10 +307,18 @@ def test_batch_prefill_with_paged_kv_cache_custom_mask(
)
if return_lse:
o_custom, _ = wrapper.forward_return_lse(
q, kv_data, pos_encoding_mode=pos_encoding_mode
q,
kv_data,
pos_encoding_mode=pos_encoding_mode,
logits_soft_cap=logits_soft_cap,
)
else:
o_custom = wrapper.forward(q, kv_data, pos_encoding_mode=pos_encoding_mode)
o_custom = wrapper.forward(
q,
kv_data,
pos_encoding_mode=pos_encoding_mode,
logits_soft_cap=logits_soft_cap,
)
wrapper.end_forward()

# use causal
@@ -288,11 +334,19 @@ def test_batch_prefill_with_paged_kv_cache_custom_mask(
)
if return_lse:
o_causal, _ = wrapper.forward_return_lse(
q, kv_data, causal=True, pos_encoding_mode=pos_encoding_mode
q,
kv_data,
causal=True,
pos_encoding_mode=pos_encoding_mode,
logits_soft_cap=logits_soft_cap,
)
else:
o_causal = wrapper.forward(
q, kv_data, causal=True, pos_encoding_mode=pos_encoding_mode
q,
kv_data,
causal=True,
pos_encoding_mode=pos_encoding_mode,
logits_soft_cap=logits_soft_cap,
)
numpy.testing.assert_allclose(
o_custom.cpu().numpy(), o_causal.cpu().numpy(), rtol=1e-3, atol=1e-3
@@ -307,6 +361,7 @@ def test_batch_prefill_with_paged_kv_cache_custom_mask(
@pytest.mark.parametrize("head_dim", [128, 256])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA", "ALIBI"])
@pytest.mark.parametrize("logits_soft_cap", [0.0, 30.0])
@pytest.mark.parametrize("return_lse", [True, False])
def test_batch_prefill_with_ragged_kv_cache(
batch_size,
Expand All @@ -317,6 +372,7 @@ def test_batch_prefill_with_ragged_kv_cache(
head_dim,
causal,
pos_encoding_mode,
logits_soft_cap,
return_lse,
):
kv_layout = "NHD"
@@ -340,10 +396,22 @@ def test_batch_prefill_with_ragged_kv_cache(
)
if return_lse:
o, _ = wrapper.forward_return_lse(
q, k, v, causal=causal, pos_encoding_mode=pos_encoding_mode
q,
k,
v,
causal=causal,
pos_encoding_mode=pos_encoding_mode,
logits_soft_cap=logits_soft_cap,
)
else:
o = wrapper.forward(q, k, v, causal=causal, pos_encoding_mode=pos_encoding_mode)
o = wrapper.forward(
q,
k,
v,
causal=causal,
pos_encoding_mode=pos_encoding_mode,
logits_soft_cap=logits_soft_cap,
)

for i in range(batch_size):
o_ref_i = flashinfer.single_prefill_with_kv_cache(
@@ -352,6 +420,7 @@ def test_batch_prefill_with_ragged_kv_cache(
v[kv_indptr[i] : kv_indptr[i + 1]],
causal=causal,
pos_encoding_mode=pos_encoding_mode,
logits_soft_cap=logits_soft_cap,
)
o_i_np = o[q_indptr[i] : q_indptr[i + 1]].cpu().numpy()
o_ref_i_np = o_ref_i.cpu().numpy()
@@ -365,6 +434,7 @@ def test_batch_prefill_with_ragged_kv_cache(
@pytest.mark.parametrize("num_qo_heads", [4, 32])
@pytest.mark.parametrize("head_dim", [128, 256])
@pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA", "ALIBI"])
@pytest.mark.parametrize("logits_soft_cap", [0.0, 30.0])
@pytest.mark.parametrize("return_lse", [True, False])
def test_batch_prefill_with_ragged_kv_cache_custom_mask(
batch_size,
@@ -374,6 +444,7 @@ def test_batch_prefill_with_ragged_kv_cache_custom_mask(
num_qo_heads,
head_dim,
pos_encoding_mode,
logits_soft_cap,
return_lse,
):
kv_layout = "NHD"
@@ -409,21 +480,41 @@ def test_batch_prefill_with_ragged_kv_cache_custom_mask(
)
if return_lse:
o_custom, _ = wrapper.forward_return_lse(
q, k, v, pos_encoding_mode=pos_encoding_mode
q,
k,
v,
pos_encoding_mode=pos_encoding_mode,
logits_soft_cap=logits_soft_cap,
)
else:
o_custom = wrapper.forward(q, k, v, pos_encoding_mode=pos_encoding_mode)
o_custom = wrapper.forward(
q,
k,
v,
pos_encoding_mode=pos_encoding_mode,
logits_soft_cap=logits_soft_cap,
)
wrapper.end_forward()

# use causal
wrapper.begin_forward(q_indptr, kv_indptr, num_qo_heads, num_kv_heads, head_dim)
if return_lse:
o_causal, _ = wrapper.forward_return_lse(
q, k, v, causal=True, pos_encoding_mode=pos_encoding_mode
q,
k,
v,
causal=True,
pos_encoding_mode=pos_encoding_mode,
logits_soft_cap=logits_soft_cap,
)
else:
o_causal = wrapper.forward(
q, k, v, causal=True, pos_encoding_mode=pos_encoding_mode
q,
k,
v,
causal=True,
pos_encoding_mode=pos_encoding_mode,
logits_soft_cap=logits_soft_cap,
)
numpy.testing.assert_allclose(
o_custom.cpu().numpy(), o_causal.cpu().numpy(), rtol=1e-3, atol=1e-3
@@ -432,15 +523,17 @@ def test_batch_prefill_with_ragged_kv_cache_custom_mask(

if __name__ == "__main__":
test_batch_prefill_with_paged_kv_cache(
12, 54, 37, 16, 8, 8, 128, True, "HND", "NONE", True, False
12, 54, 37, 16, 8, 8, 128, True, "HND", "NONE", True, 0.0, False
)
test_batch_prefill_with_paged_kv_cache(
12, 54, 37, 1, 8, 8, 128, True, "HND", "NONE", False, False
12, 54, 37, 1, 8, 8, 128, True, "HND", "NONE", False, 0.0, False
)
test_batch_prefill_with_paged_kv_cache_custom_mask(
12, 137, 137, 1, 8, 8, 128, "HND", "NONE", False
12, 137, 137, 1, 8, 8, 128, "HND", "NONE", 0.0, False
)
test_batch_prefill_with_ragged_kv_cache(
12, 54, 37, 8, 8, 128, True, "NONE", 0.0, False
)
test_batch_prefill_with_ragged_kv_cache(12, 54, 37, 8, 8, 128, True, "NONE", False)
test_batch_prefill_with_ragged_kv_cache_custom_mask(
12, 137, 137, 8, 8, 128, "NONE", False
12, 137, 137, 8, 8, 128, "NONE", 0.0, False
)
4 changes: 2 additions & 2 deletions python/tests/test_logits_cap.py
@@ -34,7 +34,7 @@ def attention_logits_soft_cap_torch(q, k, v, soft_cap):
@pytest.mark.parametrize("seq_len", [1, 9, 81, 729, 33001])
@pytest.mark.parametrize("num_heads", [4, 8, 32])
@pytest.mark.parametrize("head_dim", [128, 256])
@pytest.mark.parametrize("soft_cap", [30.0, 50.0])
@pytest.mark.parametrize("soft_cap", [1.0, 30.0, 50.0])
def test_single_decode_logits_soft_cap(
seq_len,
num_heads,
@@ -56,7 +56,7 @@ def test_single_decode_logits_soft_cap(
@pytest.mark.parametrize("kv_len", [1, 17, 81, 987, 31111])
@pytest.mark.parametrize("num_heads", [4, 8, 32])
@pytest.mark.parametrize("head_dim", [128, 256])
@pytest.mark.parametrize("soft_cap", [30.0, 50.0])
@pytest.mark.parametrize("soft_cap", [1.0, 30.0, 50.0])
def test_single_prefill_logits_soft_cap(
q_len,
kv_len,
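The parametrization above now includes `soft_cap=1.0`, which squashes every logit into (-1, 1) and therefore stresses the capped code path far harder than the milder 30.0 and 50.0 settings. A minimal standalone call mirroring what these tests exercise (shapes chosen here only for illustration) might look like:

```python
import torch

import flashinfer

qo_len, kv_len, num_heads, head_dim = 17, 81, 8, 128
device = "cuda:0"
q = torch.randn(qo_len, num_heads, head_dim, dtype=torch.float16, device=device)
k = torch.randn(kv_len, num_heads, head_dim, dtype=torch.float16, device=device)
v = torch.randn(kv_len, num_heads, head_dim, dtype=torch.float16, device=device)

# logits_soft_cap=0.0 leaves the logits uncapped, as in the batch tests above;
# any positive value enables the tanh-based cap.
o = flashinfer.single_prefill_with_kv_cache(
    q, k, v, causal=True, pos_encoding_mode="NONE", logits_soft_cap=1.0
)
```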
