From 1b8a0d71cf5aa1a43c14478ec90538c3fbe1b315 Mon Sep 17 00:00:00 2001
From: leiwen83 <leiwen83@users.noreply.github.com>
Date: Sat, 15 Jun 2024 08:23:56 +0800
Subject: [PATCH] [Core][Bugfix]: fix prefix caching for blockv2 (#5364)
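
In the v2 block manager, allocate_mutable() marked a block popped from
the evictor as computed. Such a block is about to be filled with new
content, so its previous KV cache is stale; mark it as not computed so
that the worker recomputes the KV cache. Also add an e2e test that
shrinks the GPU block pool to force eviction and checks that outputs
with APC enabled match a baseline with APC disabled.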

Signed-off-by: Lei Wen <wenlei03@qiyi.com>
Co-authored-by: Lei Wen <wenlei03@qiyi.com>
---
 tests/core/block/e2e/test_correctness.py | 72 ++++++++++++++++++++++++
 vllm/core/block/prefix_caching_block.py  |  7 ++-
 2 files changed, 77 insertions(+), 2 deletions(-)

diff --git a/tests/core/block/e2e/test_correctness.py b/tests/core/block/e2e/test_correctness.py
index ad253635e0ba0..8502eab0f8da0 100644
--- a/tests/core/block/e2e/test_correctness.py
+++ b/tests/core/block/e2e/test_correctness.py
@@ -477,3 +477,75 @@ def test_auto_prefix_caching_with_preemption(baseline_llm_generator,
         assert expected_token_ids == actual_token_ids
 
     assert baseline_token_ids == test_token_ids
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        # Use a small model for a fast test.
+        "model": "facebook/opt-125m",
+
+        # skip cuda graph creation for fast test.
+        "enforce_eager": True,
+
+        # Keep the blocks small so that eviction is hit quickly.
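+        # With block_size 16 and max_model_len 48, a sequence spans at
+        # most 3 blocks (48 // 16), so a 3-block pool forces the second
+        # prompt to evict blocks freed by the first.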
+        "max_model_len": 48,
+        "block_size": 16,
+        "num_gpu_blocks_override": 3,
+
+        # Test APC with the v2 block manager.
+        "use_v2_block_manager": True,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{
+    "enable_prefix_caching": False
+}])
+@pytest.mark.parametrize("test_llm_kwargs", [{
+    "enable_prefix_caching": True,
+}])
+@pytest.mark.parametrize("seed", [1])
+def test_auto_prefix_caching_after_eviction_start(baseline_llm_generator,
+                                                  test_llm_generator):
+    """Verify block manager v2 with auto prefix caching could works normal
+    even when eviction started.
+    With APC enabled, all blocks are held by native block at the beginning.
+    Then blocks are managed by evictor instead. If cache hit at the evitor's
+    block, then it could be reused, or we need to recompute its kv cache.
+    """
+    output_len = 10
+    temperature = 0.0
+
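+    # The two prompts share a long common prefix, so the second prompt
+    # can hit blocks cached by the first, including evicted ones.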
+    prompts = [
+        "You are a helpful assistant. Please answer truthfully and write "
+        "out your thinking step by step to be sure you get the right answer. "
+        "If you make a mistake, attempt to correct it. who are you?",
+        "You are a helpful assistant. Please answer truthfully and write out "
+        "your thinking step by step to be sure you get the right answer. You "
+        "are helpful and harmless and you follow ethical guidelines. "
+        "who are you?"
+    ]
+
+    sampling_params = SamplingParams(
+        max_tokens=output_len,
+        ignore_eos=True,
+        temperature=temperature,
+    )
+
+    print('Getting token ids with APC disabled')
+    baseline_token_ids = get_token_ids_from_llm_generator(
+        baseline_llm_generator, prompts, sampling_params)
+
+    print('Getting token ids with APC enabled')
+    test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
+                                                      prompts, sampling_params)
+
+    for expected_token_ids, actual_token_ids in zip(baseline_token_ids,
+                                                    test_token_ids):
+        assert expected_token_ids == actual_token_ids
+
+    assert baseline_token_ids == test_token_ids
diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py
index 405e9705659df..88dbbfb2f3690 100644
--- a/vllm/core/block/prefix_caching_block.py
+++ b/vllm/core/block/prefix_caching_block.py
@@ -176,14 +176,17 @@ def allocate_mutable(self,
 
             self._refcounter.incr(block_id)
 
-            # the block comes from evictor already contain computed result
+            # This block was just popped from the evictor and is about
+            # to be filled with new content, which most likely differs
+            # from its original content. Mark it as not computed so
+            # that the worker recomputes its KV cache.
             block = self._create_block(
                 prev_block=prev_block,
                 token_ids=[],
                 block_size=self._block_size,
                 allocator=self,
                 block_id=block_id,
-                computed=True,
+                computed=False,
             )
             assert block.content_hash is None