From 46be06aabbad56f956cde703e9664c0d8218cce9 Mon Sep 17 00:00:00 2001 From: "jiang1.li" Date: Fri, 15 Nov 2024 06:58:41 +0000 Subject: [PATCH] add tests Signed-off-by: jiang1.li --- .buildkite/run-cpu-test.sh | 6 ++ .../basic_correctness/test_chunked_prefill.py | 71 ++++++++++++++++++- 2 files changed, 76 insertions(+), 1 deletion(-) diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh index a00331abb7d03..e24d30b439666 100644 --- a/.buildkite/run-cpu-test.sh +++ b/.buildkite/run-cpu-test.sh @@ -57,6 +57,12 @@ function cpu_tests() { pytest -s -v \ tests/quantization/test_ipex_quant.py" + # Run chunked-prefill and prefix-cache test + docker exec cpu-test bash -c " + set -e + pytest -s -v -k cpu_only \ + tests/basic_correctness/test_chunked_prefill.py" + # online inference docker exec cpu-test bash -c " set -e diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py index cc5bc2aca27c9..ee3ed1f9e8853 100644 --- a/tests/basic_correctness/test_chunked_prefill.py +++ b/tests/basic_correctness/test_chunked_prefill.py @@ -12,6 +12,7 @@ import pytest from tests.kernels.utils import override_backend_env_variable +from vllm.platforms import current_platform from ..models.utils import check_logprobs_close, check_outputs_equal from ..utils import multi_gpu_test @@ -206,12 +207,14 @@ def test_models_with_fp8_kv_cache( # NOTE: Increasing this in this suite will fail CI because we currently cannot # reset distributed env properly. Use a value > 1 just when you test. 
@pytest.mark.parametrize("tensor_parallel_size", [1]) +@pytest.mark.parametrize("dtype", ["half"]) def test_with_prefix_caching( vllm_runner, max_tokens: int, enforce_eager: bool, chunk_size: int, tensor_parallel_size: int, + dtype: str, ) -> None: """ Checks exact match decode with and without prefix caching @@ -233,7 +236,7 @@ def test_with_prefix_caching( for enable in (True, False): with vllm_runner( model, - dtype="half", + dtype=dtype, max_num_batched_tokens=max_num_batched_tokens, enable_chunked_prefill=True, enable_prefix_caching=enable, @@ -260,3 +263,69 @@ def test_with_prefix_caching( name_0="w/o prefix caching", name_1="with prefix caching", ) + + +@pytest.mark.parametrize("model", ["facebook/opt-125m"]) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("max_tokens", [32]) +@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16]) +@pytest.mark.parametrize("enforce_eager", [False]) +# NOTE: Increasing this in this suite will fail CI because we currently cannot +# reset distributed env properly. Use a value > 1 just when you test. 
+@pytest.mark.parametrize("tensor_parallel_size", [1])
+@pytest.mark.parametrize("attention_backend", ["TORCH_SDPA"])
+@pytest.mark.cpu_only
+@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only")
+def test_models_cpu(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+    chunked_prefill_token_size: int,
+    enforce_eager: bool,
+    tensor_parallel_size: int,
+    attention_backend: str,
+    monkeypatch,
+) -> None:
+    """
+    CPU counterpart of test_models.
+
+    Passes its bfloat16 / TORCH_SDPA parametrization straight through to
+    test_models; skipped unless the current platform is CPU.
+    """
+    test_models(hf_runner, vllm_runner, example_prompts, model, dtype,
+                max_tokens, chunked_prefill_token_size, enforce_eager,
+                tensor_parallel_size, attention_backend, monkeypatch)
+
+
+@pytest.mark.parametrize("max_tokens", [16])
+@pytest.mark.parametrize("enforce_eager", [False])
+@pytest.mark.parametrize("chunk_size", [30, 32])
+# NOTE: Increasing this in this suite will fail CI because we currently cannot
+# reset distributed env properly. Use a value > 1 just when you test.
+@pytest.mark.parametrize("tensor_parallel_size", [1])
+@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.cpu_only
+@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only")
+def test_with_prefix_caching_cpu(
+    vllm_runner,
+    max_tokens: int,
+    enforce_eager: bool,
+    chunk_size: int,
+    tensor_parallel_size: int,
+    dtype: str,
+) -> None:
+    """
+    CPU counterpart of test_with_prefix_caching, parametrized with bfloat16
+    instead of half; skipped unless the current platform is CPU.
+    """
+    test_with_prefix_caching(
+        vllm_runner=vllm_runner,
+        max_tokens=max_tokens,
+        enforce_eager=enforce_eager,
+        chunk_size=chunk_size,
+        tensor_parallel_size=tensor_parallel_size,
+        dtype=dtype,
+    )