From 07411dfde86b0dd7f1b2ee12a80bb3529672c1e9 Mon Sep 17 00:00:00 2001
From: zachzzc
Date: Thu, 1 Aug 2024 06:48:50 +0000
Subject: [PATCH 1/5] [Bugfix] Fix block table for seqs with a prefix cache hit

---
 vllm/attention/backends/flash_attn.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py
index 7d7aff9dc3cdc..da8d96c5184d9 100644
--- a/vllm/attention/backends/flash_attn.py
+++ b/vllm/attention/backends/flash_attn.py
@@ -209,6 +209,7 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
         self.num_prefills = 0
         self.num_prefill_tokens = 0
         self.num_decode_tokens = 0
+        self.has_prefix_cache_hit = False
 
         self.input_builder = input_builder
         self.runner = input_builder.runner
@@ -252,10 +253,11 @@ def _add_seq_group(
             # only allowing multiple of block_size chunk size.
             # NOTE: This only works for oooooooxxx style attention.
             block_table = []
-            if inter_data.prefix_cache_hit:
+            if inter_data.prefix_cache_hit or self.has_prefix_cache_hit:
                 # NOTE(woosuk): For flash-attn, the block table should
                 # include the entries for the incoming prefill tokens.
                 block_table = block_tables[seq_id]
+                self.has_prefix_cache_hit = True
             elif ((chunked_prefill_enabled or not is_prompt)
                   and block_tables is not None):
                 block_table = block_tables[seq_id][-curr_sliding_window_block:]

From 7fbe5076f4f52e4cd3350cfbda2553f7b0bb109f Mon Sep 17 00:00:00 2001
From: zachzzc
Date: Thu, 1 Aug 2024 21:49:44 +0000
Subject: [PATCH 2/5] Add test

---
 .../basic_correctness/test_prefix_caching.py | 62 +++++++++++++++++++
 1 file changed, 62 insertions(+)
 create mode 100644 tests/basic_correctness/test_prefix_caching.py

diff --git a/tests/basic_correctness/test_prefix_caching.py b/tests/basic_correctness/test_prefix_caching.py
new file mode 100644
index 0000000000000..0638cab342e41
--- /dev/null
+++ b/tests/basic_correctness/test_prefix_caching.py
@@ -0,0 +1,62 @@
+"""Compare the outputs of HF and vLLM when using greedy sampling.
+
+It tests prefix caching. Prefix caching can be enabled by
+enable_prefix_caching=True.
+
+Run `pytest tests/basic_correctness/test_prefix_caching.py`.
+"""
+import pytest
+
+from tests.kernels.utils import override_backend_env_variable
+
+from ..models.utils import check_outputs_equal
+
+MODELS = [
+    "facebook/opt-125m",
+]
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER"])
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", [5])
+@pytest.mark.parametrize("use_v2_block_manager", [False, True])
+def test_mixed_requests(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    model: str,
+    backend: str,
+    dtype: str,
+    max_tokens: int,
+    use_v2_block_manager: bool,
+    monkeypatch,
+) -> None:
+    """
+    Test the case when some sequences have the prefix cache hit
+    and the others don't.
+    """
+    override_backend_env_variable(monkeypatch, backend)
+
+    with hf_runner(model, dtype=dtype) as hf_model:
+        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
+
+    cached_prompt = example_prompts[0]
+    with vllm_runner(
+            model,
+            dtype=dtype,
+            enable_prefix_caching=True,
+            use_v2_block_manager=use_v2_block_manager,
+    ) as vllm_model:
+        # Run the first prompt so the cache is populated
+        vllm_outputs = vllm_model.generate_greedy([cached_prompt], max_tokens)
+
+        # Run all the prompts
+        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
+
+    check_outputs_equal(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )

From 5dde625bfdfea11210640b9f7e5858164e9fbf7b Mon Sep 17 00:00:00 2001
From: zachzzc
Date: Thu, 1 Aug 2024 22:47:24 +0000
Subject: [PATCH 3/5] Fix comments

---
 .../basic_correctness/test_prefix_caching.py | 62 ------------------
 tests/prefix_caching/test_prefix_caching.py  | 53 ++++++++++++++++
 vllm/attention/backends/flash_attn.py        |  4 +-
 3 files changed, 55 insertions(+), 64 deletions(-)
 delete mode 100644 tests/basic_correctness/test_prefix_caching.py

diff --git a/tests/basic_correctness/test_prefix_caching.py b/tests/basic_correctness/test_prefix_caching.py
deleted file mode 100644
index 0638cab342e41..0000000000000
--- a/tests/basic_correctness/test_prefix_caching.py
+++ /dev/null
@@ -1,62 +0,0 @@
-"""Compare the outputs of HF and vLLM when using greedy sampling.
-
-It tests prefix caching. Prefix caching can be enabled by
-enable_prefix_caching=True.
-
-Run `pytest tests/basic_correctness/test_prefix_caching.py`.
-"""
-import pytest
-
-from tests.kernels.utils import override_backend_env_variable
-
-from ..models.utils import check_outputs_equal
-
-MODELS = [
-    "facebook/opt-125m",
-]
-
-
-@pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER"])
-@pytest.mark.parametrize("dtype", ["half"])
-@pytest.mark.parametrize("max_tokens", [5])
-@pytest.mark.parametrize("use_v2_block_manager", [False, True])
-def test_mixed_requests(
-    hf_runner,
-    vllm_runner,
-    example_prompts,
-    model: str,
-    backend: str,
-    dtype: str,
-    max_tokens: int,
-    use_v2_block_manager: bool,
-    monkeypatch,
-) -> None:
-    """
-    Test the case when some sequences have the prefix cache hit
-    and the others don't.
-    """
-    override_backend_env_variable(monkeypatch, backend)
-
-    with hf_runner(model, dtype=dtype) as hf_model:
-        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
-
-    cached_prompt = example_prompts[0]
-    with vllm_runner(
-            model,
-            dtype=dtype,
-            enable_prefix_caching=True,
-            use_v2_block_manager=use_v2_block_manager,
-    ) as vllm_model:
-        # Run the first prompt so the cache is populated
-        vllm_outputs = vllm_model.generate_greedy([cached_prompt], max_tokens)
-
-        # Run all the prompts
-        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
-
-    check_outputs_equal(
-        outputs_0_lst=hf_outputs,
-        outputs_1_lst=vllm_outputs,
-        name_0="hf",
-        name_1="vllm",
-    )
diff --git a/tests/prefix_caching/test_prefix_caching.py b/tests/prefix_caching/test_prefix_caching.py
index 7985001d34eb1..5744f4ac3c94a 100644
--- a/tests/prefix_caching/test_prefix_caching.py
+++ b/tests/prefix_caching/test_prefix_caching.py
@@ -10,6 +10,13 @@
 from vllm.core.block_manager_v1 import CachedBlockAllocator
 from vllm.utils import Device
 
+from tests.kernels.utils import override_backend_env_variable
+from ..models.utils import check_outputs_equal
+
+MODELS = [
+    "facebook/opt-125m",
+]
+
 
 @pytest.mark.parametrize("block_size", [16])
 @pytest.mark.parametrize("num_blocks", [16])
@@ -76,3 +83,49 @@ def test_eviction(num_blocks: int, ):
     assert (realloc_block != new_block)
     assert (new_block.block_hash == new_block_hash)
     assert (new_block.block_number == 2)
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"])
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", [5])
+@pytest.mark.parametrize("use_v2_block_manager", [False, True])
+def test_mixed_requests(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    model: str,
+    backend: str,
+    dtype: str,
+    max_tokens: int,
+    use_v2_block_manager: bool,
+    monkeypatch,
+) -> None:
+    """
+    Test the case when some sequences have the prefix cache hit
+    and the others don't.
+    """
+    override_backend_env_variable(monkeypatch, backend)
+
+    with hf_runner(model, dtype=dtype) as hf_model:
+        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
+
+    cached_prompt = example_prompts[0]
+    with vllm_runner(
+            model,
+            dtype=dtype,
+            enable_prefix_caching=True,
+            use_v2_block_manager=use_v2_block_manager,
+    ) as vllm_model:
+        # Run the first prompt so the cache is populated
+        vllm_outputs = vllm_model.generate_greedy([cached_prompt], max_tokens)
+
+        # Run all the prompts
+        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
+
+    check_outputs_equal(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )
diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py
index da8d96c5184d9..d00e9813193d6 100644
--- a/vllm/attention/backends/flash_attn.py
+++ b/vllm/attention/backends/flash_attn.py
@@ -253,11 +253,11 @@ def _add_seq_group(
             # only allowing multiple of block_size chunk size.
             # NOTE: This only works for oooooooxxx style attention.
             block_table = []
-            if inter_data.prefix_cache_hit or self.has_prefix_cache_hit:
+            self.has_prefix_cache_hit |= inter_data.prefix_cache_hit
+            if self.has_prefix_cache_hit:
                 # NOTE(woosuk): For flash-attn, the block table should
                 # include the entries for the incoming prefill tokens.
                 block_table = block_tables[seq_id]
-                self.has_prefix_cache_hit = True
             elif ((chunked_prefill_enabled or not is_prompt)
                   and block_tables is not None):
                 block_table = block_tables[seq_id][-curr_sliding_window_block:]

From fa0d7eea766b7cb82278e30fe67b1add357e46cf Mon Sep 17 00:00:00 2001
From: zachzzc
Date: Thu, 1 Aug 2024 23:02:32 +0000
Subject: [PATCH 4/5] Fix isort

---
 tests/prefix_caching/test_prefix_caching.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/prefix_caching/test_prefix_caching.py b/tests/prefix_caching/test_prefix_caching.py
index 5744f4ac3c94a..3ea05e818bb92 100644
--- a/tests/prefix_caching/test_prefix_caching.py
+++ b/tests/prefix_caching/test_prefix_caching.py
@@ -6,11 +6,11 @@
 
 import pytest
 
+from tests.kernels.utils import override_backend_env_variable
 from vllm.block import PhysicalTokenBlock
 from vllm.core.block_manager_v1 import CachedBlockAllocator
 from vllm.utils import Device
 
-from tests.kernels.utils import override_backend_env_variable
 from ..models.utils import check_outputs_equal
 
 MODELS = [

From 10b2e048e01bbb6c19824c079b85e75317c27c58 Mon Sep 17 00:00:00 2001
From: zachzzc
Date: Fri, 2 Aug 2024 05:02:55 +0000
Subject: [PATCH 5/5] Fix cases where non-cached seqs come before cached ones

---
 tests/prefix_caching/test_prefix_caching.py |  7 +++++--
 vllm/attention/backends/flash_attn.py       | 12 ++++++++----
 2 files changed, 13 insertions(+), 6 deletions(-)

diff --git a/tests/prefix_caching/test_prefix_caching.py b/tests/prefix_caching/test_prefix_caching.py
index 3ea05e818bb92..9821dbd066a59 100644
--- a/tests/prefix_caching/test_prefix_caching.py
+++ b/tests/prefix_caching/test_prefix_caching.py
@@ -89,6 +89,7 @@ def test_eviction(num_blocks: int, ):
 @pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"])
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [5])
+@pytest.mark.parametrize("cached_position", [0, 1])
 @pytest.mark.parametrize("use_v2_block_manager", [False, True])
 def test_mixed_requests(
     hf_runner,
@@ -98,19 +99,21 @@ def test_mixed_requests(
     backend: str,
     dtype: str,
     max_tokens: int,
+    cached_position: int,
     use_v2_block_manager: bool,
     monkeypatch,
 ) -> None:
     """
     Test the case when some sequences have the prefix cache hit
-    and the others don't.
+    and the others don't. The cached position determines where
+    the cached sequence sits within the batch of prefills.
     """
     override_backend_env_variable(monkeypatch, backend)
 
     with hf_runner(model, dtype=dtype) as hf_model:
         hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
 
-    cached_prompt = example_prompts[0]
+    cached_prompt = example_prompts[cached_position]
     with vllm_runner(
             model,
             dtype=dtype,
diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py
index d00e9813193d6..58100d6db2ae6 100644
--- a/vllm/attention/backends/flash_attn.py
+++ b/vllm/attention/backends/flash_attn.py
@@ -220,7 +220,7 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
 
     def _add_seq_group(
             self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
-            chunked_prefill_enabled: bool):
+            chunked_prefill_enabled: bool, prefix_cache_hit: bool):
         """Add a sequence group to the metadata. Specifically update/append
         1. context length.
         2. block table.
@@ -253,8 +253,7 @@ def _add_seq_group(
             # only allowing multiple of block_size chunk size.
             # NOTE: This only works for oooooooxxx style attention.
             block_table = []
-            self.has_prefix_cache_hit |= inter_data.prefix_cache_hit
-            if self.has_prefix_cache_hit:
+            if prefix_cache_hit:
                 # NOTE(woosuk): For flash-attn, the block table should
                 # include the entries for the incoming prefill tokens.
                 block_table = block_tables[seq_id]
@@ -283,9 +282,14 @@ def build(self, seq_lens: List[int], query_lens: List[int],
             -1 if cuda graph is not used.
             batch_size: The maybe padded batch size.
         """
+        prefix_cache_hit = any([
+            inter_data.prefix_cache_hit
+            for inter_data in self.input_builder.inter_data_list
+        ])
         for inter_data in self.input_builder.inter_data_list:
             self._add_seq_group(inter_data,
-                                self.input_builder.chunked_prefill_enabled)
+                                self.input_builder.chunked_prefill_enabled,
+                                prefix_cache_hit)
 
         device = self.runner.device
         use_captured_graph = cuda_graph_pad_size != -1
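
Note (illustrative sketch, not part of the patches above): the toy code below is a standalone approximation with hypothetical names. `InterData` stands in for `ModelInputForGPUBuilder.InterDataForSeqGroup` and `build_block_tables` for the block-table portion of the metadata builder; neither is the actual vLLM API. It only demonstrates the idea behind PATCH 5: the prefix-cache-hit flag is computed over the whole batch before any per-sequence block table is assembled, so the result no longer depends on whether a cached sequence comes before or after the non-cached ones.

# Standalone sketch with hypothetical names, not the actual vLLM classes.
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class InterData:
    seq_id: int
    prefix_cache_hit: bool
    is_prompt: bool = True


def build_block_tables(
    inter_data_list: List[InterData],
    block_tables: Dict[int, List[int]],
    chunked_prefill_enabled: bool = False,
) -> Dict[int, List[int]]:
    # Batch-wide flag, computed before any per-sequence decision: if any
    # prefill in the batch hit the prefix cache, every prefill gets its
    # full block table, regardless of its position in the batch.
    prefix_cache_hit = any(d.prefix_cache_hit for d in inter_data_list)

    result: Dict[int, List[int]] = {}
    for data in inter_data_list:
        if prefix_cache_hit:
            # Full block table, including blocks for the incoming prefill tokens.
            result[data.seq_id] = block_tables[data.seq_id]
        elif chunked_prefill_enabled or not data.is_prompt:
            result[data.seq_id] = block_tables[data.seq_id]
        else:
            # Plain prefill with no cache hit anywhere in the batch.
            result[data.seq_id] = []
    return result


if __name__ == "__main__":
    tables = {0: [7, 8], 1: [3, 4]}
    # The non-cached seq (id 0) comes before the cached one (id 1); a
    # per-sequence or latched check would have left seq 0 with an empty table.
    batch = [InterData(0, prefix_cache_hit=False),
             InterData(1, prefix_cache_hit=True)]
    print(build_block_tables(batch, tables))  # {0: [7, 8], 1: [3, 4]}

In this example batch the non-cached sequence precedes the cached one, which is the ordering exercised by the cached_position=1 case added to test_mixed_requests in PATCH 5.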