add benchmark for append_paged_kv_cache #583

Merged 4 commits on Nov 5, 2024
benchmarks/bench_append_paged_kv_cache.py (136 additions, 0 deletions)
@@ -0,0 +1,136 @@
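"""Benchmark for flashinfer.append_paged_kv_cache.

Measures the latency and effective memory throughput of appending new K/V
tokens to a paged KV cache, across several model configurations and batch
shapes.
"""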
import argparse
import dataclasses
from typing import cast

import flashinfer
import torch
from triton.testing import do_bench


@dataclasses.dataclass(kw_only=True)
class ModelConfig:
    num_kv_heads: int
    num_layers: int
    head_dim: int


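# 70B-class config: the model's 8 KV heads are split evenly across `tp`
# tensor-parallel ranks, so each rank holds 8 // tp heads.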
def _make_70b(tp: int) -> ModelConfig:
    return ModelConfig(
        num_kv_heads=8 // tp,
        num_layers=80,
        head_dim=128,
    )


MODELS = {
    "l1b": ModelConfig(
        num_kv_heads=8,
        num_layers=16,
        head_dim=64,
    ),
    "l3b": ModelConfig(
        num_kv_heads=8,
        num_layers=28,
        head_dim=128,
    ),
    "l8b": ModelConfig(
        num_kv_heads=8,
        num_layers=32,
        head_dim=128,
    ),
    "l70b-tp8": _make_70b(8),
}


@torch.inference_mode()
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--seqlen", type=int, default=5000)
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--page-len", type=int, default=16)
    parser.add_argument("--dtype", type=str, default="float16")
    args = parser.parse_args()

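    # Batch shapes to benchmark: decode-only (1 new token per request),
    # one long prefill plus single-token decodes, a single long prefill,
    # and the total length split evenly across the batch.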
    seqlens_ = [
        [1] * args.batch_size,
        [args.seqlen - args.batch_size + 1] + [1] * (args.batch_size - 1),
        [args.seqlen],
        [args.seqlen // args.batch_size] * args.batch_size,
    ]
    seqlen_strlen = max(len(str(seqlens)) for seqlens in seqlens_)
    page_len = int(args.page_len)
    dtype = getattr(torch, args.dtype)
    assert isinstance(dtype, torch.dtype)
    device = torch.device("cuda:0")
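    # Size the page pool for roughly 256k tokens of KV capacity.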
    total_pages = int(256000 / page_len)

    torch.cuda.profiler.start()

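    # For each model, allocate a single layer's page pool; each page stores
    # K and V together (leading dim 2) in "NHD" layout:
    # (page_len, num_kv_heads, head_dim).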
    for model_name, model in MODELS.items():
        page_shape = (2, page_len, model.num_kv_heads, model.head_dim)
        layer_buf = torch.empty(
            (total_pages,) + page_shape, dtype=dtype, device=device
        )
        for seqlens in seqlens_:
            k = torch.rand(
                (sum(seqlens), model.num_kv_heads, model.head_dim),
                dtype=dtype,
                device=device,
            )
            v = torch.rand(
                (sum(seqlens), model.num_kv_heads, model.head_dim),
                dtype=dtype,
                device=device,
            )
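            # Exclusive prefix sum over seqlens: x_indptr[i] is where request
            # i's new tokens start in the flattened k/v tensors.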
            x_indptr = torch.tensor([0] + seqlens, device=device, dtype=torch.int32)
            x_indptr = torch.cumsum(x_indptr, 0, dtype=torch.int32)
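            # Build the page table: each request gets a contiguous run of page
            # ids, with kv_indptr marking where each request's pages begin.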
            kv_indices_host = []
            kv_indptr_host = [0]
            next_page_id = 0
            for seqlen in seqlens:
                npages = (seqlen + page_len - 1) // page_len
                kv_indices_host.extend(range(next_page_id, next_page_id + npages))
                next_page_id += npages
                kv_indptr_host.append(len(kv_indices_host))
            kv_indices = torch.tensor(kv_indices_host, device=device, dtype=torch.int32)
            kv_indptr = torch.tensor(kv_indptr_host, device=device, dtype=torch.int32)
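            # Number of valid slots in each request's last page, in [1, page_len].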
            kv_last_page_len = torch.tensor(
                [(seqlen - 1) % page_len + 1 for seqlen in seqlens],
                device=device,
                dtype=torch.int32,
            )

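            # Wrap the kernel launch in an NVTX range so it is easy to locate
            # in an Nsight Systems timeline.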
            @torch.cuda.nvtx.range(f"model={model_name}, seqlens={seqlens}")
            def fn():
                flashinfer.append_paged_kv_cache(
                    k,
                    v,
                    x_indptr,
                    layer_buf,
                    kv_indices,
                    kv_indptr,
                    kv_last_page_len,
                    "NHD",
                )

            latency_ms = cast(float, do_bench(fn))
            all_layers_latency_ms = latency_ms * model.num_layers
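            # Effective memory throughput: bytes in k, times 2 for k and v,
            # times 2 because each element is both read and written.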
            throughput = (
                k.numel()
                * k.element_size()
                * sum(1 for _ in ["k", "v"])
                * sum(1 for _ in ["read", "write"])
                / (latency_ms * 1e-3)
            )
            print(
                f"model: {model_name:8}",
                f"seqlens: {seqlens!r:{seqlen_strlen}}",
                f"single_layer: {latency_ms:5.3f}ms",
                f"all_layers: {all_layers_latency_ms:7.3f}ms",
                f"throughput: {throughput*1e-9:8.3f}GB/s",
            )
        print("---")

    torch.cuda.profiler.stop()


if __name__ == "__main__":
    main()
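Example invocation, using the flags defined above (all optional, defaults shown):

    python benchmarks/bench_append_paged_kv_cache.py --seqlen 5000 --batch-size 8 --page-len 16 --dtype float16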