diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index b97f55b8c6535..fafd9d0ce4455 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -49,9 +49,10 @@ def test_prefill(): unique_token_ids = [3] * 7 all_token_ids = common_token_ids + unique_token_ids req0 = make_request("0", all_token_ids) - computed_blocks = manager.get_computed_blocks(req0) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert len(req0.kv_block_hashes) == 3 assert not computed_blocks + assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 55, computed_blocks) assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4] @@ -73,9 +74,10 @@ def test_prefill(): # Incomplete 1 block (5 tokens) unique_token_ids = [3] * 5 req1 = make_request("1", common_token_ids + unique_token_ids) - computed_blocks = manager.get_computed_blocks(req1) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert len(req1.kv_block_hashes) == 3 assert [b.block_id for b in computed_blocks] == [0, 1, 2] + assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks) assert [b.block_id for b in blocks] == [5, 6] @@ -91,7 +93,7 @@ def test_prefill(): # All blocks should be available. assert manager.free_block_queue.num_free_blocks == 10 # The order should be - # [unallocated (7, 8)] + # [unallocated (7, 8, 9)] # [unique_req0 (4, 3)] # [unique_req1 (6, 5)] # [common (2, 1, 0)] @@ -103,9 +105,10 @@ def test_prefill(): # Incomplete 1 block (6 tokens) unique_token_ids = [3] * 6 req2 = make_request("2", common_token_ids + unique_token_ids) - computed_blocks = manager.get_computed_blocks(req2) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert len(req2.kv_block_hashes) == 3 assert [b.block_id for b in computed_blocks] == [0, 1, 2] + assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks) assert [b.block_id for b in blocks] == [7, 8] @@ -123,8 +126,9 @@ def test_prefill(): # Cache miss and eviction. req3 = make_request("3", [99] * (16 * 9)) - computed_blocks = manager.get_computed_blocks(req3) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) assert not computed_blocks + assert num_computed_tokens == 0 blocks = manager.allocate_slots(req3, 16 * 9, computed_blocks) # This block ID order also checks the eviction order. assert [b.block_id for b in blocks] == [9, 4, 3, 6, 5, 8, 7, 2, 1, 0] @@ -150,8 +154,9 @@ def test_decode(): # Incomplete 1 block (7 tokens) unique_token_ids = [3] * 7 req0 = make_request("0", common_token_ids + unique_token_ids) - computed_blocks = manager.get_computed_blocks(req0) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks + assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 55, computed_blocks) assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4] @@ -197,16 +202,18 @@ def test_evict(): last_token_id = 5 * 16 + 7 req0 = make_request("0", list(range(last_token_id))) - computed_blocks = manager.get_computed_blocks(req0) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks + assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 5 * 16 + 7, computed_blocks) assert len(blocks) == 7 # 5 full + 1 partial + 1 preallocated # 3 blocks. req1 = make_request("1", list(range(last_token_id, last_token_id + 3 * 16))) - computed_blocks = manager.get_computed_blocks(req1) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert not computed_blocks + assert num_computed_tokens == 0 blocks = manager.allocate_slots(req1, 3 * 16, computed_blocks) assert len(blocks) == 3 # 3 full blocks last_token_id += 3 * 16 @@ -222,8 +229,9 @@ def test_evict(): # Touch the first 2 blocks. req2 = make_request("2", list(range(2 * 16 + 3))) - computed_blocks = manager.get_computed_blocks(req2) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert [b.block_id for b in computed_blocks] == [0, 1] + assert num_computed_tokens == 2 * 16 blocks = manager.allocate_slots(req2, 3, computed_blocks) assert [b.block_id for b in blocks] == [6, 5] assert manager.free_block_queue.num_free_blocks == 6 @@ -247,8 +255,9 @@ def test_hash_block_correct_reuse(): # Allocate 1 block and cache it. num_tokens = block_size * 1 req = make_request("0", list(range(num_tokens))) - computed_blocks = manager.get_computed_blocks(req) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert not computed_blocks + assert num_computed_tokens == 0 blocks = manager.allocate_slots(req, num_tokens, computed_blocks) assert len(blocks) == 1 @@ -258,8 +267,9 @@ def test_hash_block_correct_reuse(): # Allocate a new block that's not full, make sure hash info on the # block is cleared. req = make_request("1", list(range(num_tokens - 1))) - computed_blocks = manager.get_computed_blocks(req) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert not computed_blocks + assert num_computed_tokens == 0 blocks = manager.allocate_slots(req, num_tokens - 1, computed_blocks) assert len(blocks) == 1 @@ -284,16 +294,18 @@ def test_computed_blocks_not_evicted(): # Allocate a block and cache it. num_tokens = block_size * 1 req0 = make_request("0", list(range(num_tokens))) - computed_blocks = manager.get_computed_blocks(req0) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks + assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, num_tokens, computed_blocks) assert len(blocks) == 1 assert blocks[0].block_id == 0 # Allocate another block. req1 = make_request("1", list(range(num_tokens, num_tokens * 2))) - computed_blocks = manager.get_computed_blocks(req1) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert not computed_blocks + assert num_computed_tokens == 0 blocks = manager.allocate_slots(req1, num_tokens, computed_blocks) assert len(blocks) == 1 assert blocks[0].block_id == 1 @@ -305,9 +317,10 @@ def test_computed_blocks_not_evicted(): # Now if we have a cache hit on the first block, we should evict the second # cached block rather than the first one. req2 = make_request("2", list(range(num_tokens * 2))) - computed_blocks = manager.get_computed_blocks(req2) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert len(computed_blocks) == 1 assert computed_blocks[0].block_id == 0 + assert num_computed_tokens == block_size blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens, computed_blocks) @@ -331,8 +344,9 @@ def test_basic_prefix_caching_disabled(): req1 = make_request("1", list(range(10))) # 2 blocks and some more - computed_blocks = manager.get_computed_blocks(req1) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert not computed_blocks + assert num_computed_tokens == 0 blocks = manager.allocate_slots(req1, 10, computed_blocks) assert len(blocks) == 3 @@ -341,15 +355,17 @@ def test_basic_prefix_caching_disabled(): # No caching. req2 = make_request("2", list(range(16))) # shared prefix - computed_blocks = manager.get_computed_blocks(req2) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert not computed_blocks + assert num_computed_tokens == 0 blocks = manager.allocate_slots(req2, 16, computed_blocks) assert len(blocks) == 4 # New requests should not have any blocks. req3 = make_request("3", list(range(4))) - computed_blocks = manager.get_computed_blocks(req3) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) assert not computed_blocks + assert num_computed_tokens == 0 blocks = manager.allocate_slots(req3, 4, computed_blocks) assert not blocks @@ -371,8 +387,9 @@ def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int): num_preallocated_blocks = cdiv(num_preallocate_tokens, block_size) req = make_request("0", list(range(block_size * 30))) - computed_blocks = manager.get_computed_blocks(req) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert not computed_blocks + assert num_computed_tokens == 0 # Just ask for 1 block. blocks = manager.allocate_slots(req, block_size, computed_blocks) req.num_computed_tokens = block_size @@ -469,10 +486,11 @@ def test_mm_prefix_caching(): all_token_ids, mm_positions=mm_positions, mm_hashes=mm_hashes) - computed_blocks = manager.get_computed_blocks(req0) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) # Completed block should have hashes with extra keys. assert not computed_blocks + assert num_computed_tokens == 0 assert len(req0.kv_block_hashes) == 3 assert req0.kv_block_hashes[0].extra_keys == ("aaa", ) assert req0.kv_block_hashes[1].extra_keys == ("aaa", "bbb") @@ -503,8 +521,9 @@ def test_mm_prefix_caching(): all_token_ids, mm_positions=mm_positions, mm_hashes=mm_hashes) - computed_blocks = manager.get_computed_blocks(req1) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert len(computed_blocks) == 3 + assert num_computed_tokens == 3 * 16 def test_prefill_not_enough_free_blocks_with_computed_blocks(): @@ -527,15 +546,17 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): # | Common-0 | Common-1 | Common-2 | ... | common_token_ids = [i for i in range(3) for _ in range(16)] req0 = make_request("0", common_token_ids) - computed_blocks = manager.get_computed_blocks(req0) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks + assert num_computed_tokens == 0 manager.allocate_slots(req0, 48, computed_blocks) block_part0 = manager.req_to_blocks[req0.request_id] # | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... | req1 = make_request("1", common_token_ids * 2) - computed_blocks = manager.get_computed_blocks(req1) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert computed_blocks == block_part0 + assert num_computed_tokens == 3 * 16 manager.allocate_slots(req1, 48, computed_blocks) block_part1 = manager.req_to_blocks[req1.request_id] # | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) | @@ -547,8 +568,9 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): # | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) | # | Req1-5(F)| Req2-0 | Req2-1 | ... | req2 = make_request("2", [7] * block_size * 2) - computed_blocks = manager.get_computed_blocks(req2) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert not computed_blocks + assert num_computed_tokens == 0 manager.allocate_slots(req2, block_size * 2, computed_blocks) # Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed, @@ -556,8 +578,9 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): # In this case, the ref_cnt of the computed blocks should not be changed. assert manager.free_block_queue.num_free_blocks == 5 req3 = make_request("3", common_token_ids * 3) - computed_blocks = manager.get_computed_blocks(req3) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) assert computed_blocks == block_part1 + assert num_computed_tokens == 6 * 16 # Req3 cannot be allocated. assert manager.allocate_slots(req3, 48, computed_blocks) is None # Block 0-2 are used by Req 1. diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 1cbff1e2d767e..bac77443c8560 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -1,5 +1,5 @@ from collections import defaultdict -from typing import Dict, Iterable, List, Optional +from typing import Dict, Iterable, List, Optional, Tuple from vllm.logger import init_logger from vllm.utils import cdiv @@ -69,7 +69,8 @@ def __init__( # is finished. self.req_to_blocks: Dict[str, List[KVCacheBlock]] = {} - def get_computed_blocks(self, request: Request) -> List[KVCacheBlock]: + def get_computed_blocks( + self, request: Request) -> Tuple[List[KVCacheBlock], int]: """Get the computed (cached) blocks for the request. Note that the computed blocks must be full. @@ -77,11 +78,13 @@ def get_computed_blocks(self, request: Request) -> List[KVCacheBlock]: request: The request to get the computed blocks. Returns: - A list of blocks that are computed for the request. + A tuple containing: + - A list of blocks that are computed for the request. + - The number of computed tokens. """ if not self.enable_caching: # Prefix caching is disabled. - return [] + return [], 0 computed_blocks = [] @@ -101,7 +104,11 @@ def get_computed_blocks(self, request: Request) -> List[KVCacheBlock]: else: break - return computed_blocks + # NOTE(woosuk): Since incomplete blocks are not eligible for + # sharing, `num_computed_tokens` is always a multiple of + # `block_size`. + num_computed_tokens = len(computed_blocks) * self.block_size + return computed_blocks, num_computed_tokens def append_slots( self, diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 2503d136aea7e..45e67c94f8f15 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -184,12 +184,8 @@ def schedule(self) -> "SchedulerOutput": request = self.waiting[0] # Get already-cached tokens. - computed_blocks = self.kv_cache_manager.get_computed_blocks( - request) - # NOTE(woosuk): Since incomplete blocks are not eligible for - # sharing, `num_computed_tokens` is always a multiple of - # `block_size`. - num_computed_tokens = len(computed_blocks) * self.block_size + computed_blocks, num_computed_tokens = \ + self.kv_cache_manager.get_computed_blocks(request) # Number of tokens to be scheduled. # We use `request.num_tokens` instead of # `request.num_prompt_tokens` to consider the resumed requests,