From b83770a1e4cff388219a3e85e3e9866c44881e8e Mon Sep 17 00:00:00 2001
From: Cody Yu
Date: Wed, 22 Jan 2025 20:19:21 -0800
Subject: [PATCH] [V1] Add `uncache_blocks` (#12333)

---
 tests/v1/core/test_prefix_caching.py | 30 +++++++++++++++++++++++++
 vllm/v1/core/kv_cache_manager.py     | 33 ++++++++++++++++++++++++++--
 2 files changed, 61 insertions(+), 2 deletions(-)

diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py
index c5860809f9e62..f434fa8c61a80 100644
--- a/tests/v1/core/test_prefix_caching.py
+++ b/tests/v1/core/test_prefix_caching.py
@@ -626,3 +626,33 @@ def test_reset_prefix_cache():
     assert manager.reset_prefix_cache()
     assert not manager.cached_block_hash_to_block
     assert all([blk.block_hash is None for blk in manager.block_pool])
+
+
+def test_uncache_blocks():
+    manager = KVCacheManager(
+        block_size=16,
+        num_gpu_blocks=10,
+        max_model_len=8192,
+        sliding_window=None,
+        enable_caching=True,
+        num_preallocate_tokens=0,
+    )
+
+    req0 = make_request("0", list(range(30)))
+    blocks = manager.allocate_slots(req0, 30, [])
+    assert [b.block_id for b in blocks] == [0, 1]
+    assert len(manager.cached_block_hash_to_block) == 1
+
+    req0.num_computed_tokens = 30
+
+    # Simulate speculative tokens.
+    for _ in range(5):
+        req0.append_output_token_ids(8)
+    manager.append_slots(req0, 5)
+    assert len(manager.cached_block_hash_to_block) == 2
+
+    # After sampling, assuming only 1 token is accepted.
+    req0.num_computed_tokens = 31
+    num_uncached_blocks = manager.uncache_blocks(req0)
+    assert num_uncached_blocks == 1
+    assert len(manager.cached_block_hash_to_block) == 1
diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py
index 8c8c8b3b55c0b..18fdfdfe4a010 100644
--- a/vllm/v1/core/kv_cache_manager.py
+++ b/vllm/v1/core/kv_cache_manager.py
@@ -285,6 +285,29 @@ def free(self, request: Request) -> None:
             if block.ref_cnt == 0:
                 self.free_block_queue.append(block)
 
+    def uncache_blocks(self, request: Request) -> int:
+        """Uncache the blocks that are no longer full based on the
+        num_computed_tokens in the given request. This happens when
+        the blocks were full and cached due to speculative tokens, but the
+        speculative tokens are not accepted.
+
+        Args:
+            request: The request.
+
+        Returns:
+            The number of uncached blocks.
+        """
+        blocks = self.req_to_blocks[request.request_id]
+        num_computed_tokens = request.num_computed_tokens
+        num_full_blocks = num_computed_tokens // self.block_size
+        num_uncached_blocks = 0
+        for block in blocks[num_full_blocks:]:
+            # If the block is not cached, the following blocks are not cached.
+            if not self._maybe_evict_cached_block(block):
+                break
+            num_uncached_blocks += 1
+        return num_uncached_blocks
+
     def reset_prefix_cache(self) -> bool:
         """Reset prefix cache. This function may be used in RLHF
         flows to invalid prefix caching after the weights are updated,
@@ -386,7 +409,7 @@ def _get_new_blocks(self, num_blocks: int) -> List[KVCacheBlock]:
 
             # If the block is cached, evict it.
             if self.enable_caching:
-                self._evict_cached_block(curr_block)
+                self._maybe_evict_cached_block(curr_block)
 
             curr_block.incr_ref()
             ret.append(curr_block)
@@ -394,13 +417,16 @@ def _get_new_blocks(self, num_blocks: int) -> List[KVCacheBlock]:
 
         return ret
 
-    def _evict_cached_block(self, block: KVCacheBlock) -> None:
+    def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool:
         """
         If a block is cached in `cached_block_hash_to_block`, we reset its hash
         metadata and evict it from the cache.
 
         Args:
             block: The block to evict.
+
+        Returns:
+            True if the block is evicted, False otherwise.
         """
         block_hash = block.block_hash
         if block_hash and block_hash in self.cached_block_hash_to_block:
@@ -410,6 +436,9 @@ def _evict_cached_block(self, block: KVCacheBlock) -> None:
 
             if len(self.cached_block_hash_to_block[block_hash]) == 0:
                 del self.cached_block_hash_to_block[block_hash]
+            return True
+        return False
+
     def _get_cached_block(self,
                           block_hash: BlockHashType) -> Optional[KVCacheBlock]:
         """Get a cached block by the block hash, or None if cache miss.
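
A note on the arithmetic, for reviewers: `uncache_blocks` walks the request's
blocks starting at index `num_computed_tokens // block_size`, the first block
that is no longer guaranteed to be full once speculative tokens are rejected,
and stops at the first block that was never cached (blocks are only ever
cached left to right). Below is a minimal standalone sketch of that walk
using toy stand-ins (plain block indices and a set in place of `KVCacheBlock`
and `cached_block_hash_to_block`); it is illustrative only, not part of the
vLLM API, and reproduces the numbers from `test_uncache_blocks` above.

    BLOCK_SIZE = 16

    def uncache_after_rejection(block_ids, cached, num_computed_tokens):
        """Toy version of KVCacheManager.uncache_blocks."""
        # First block index that is no longer guaranteed to be full.
        num_full_blocks = num_computed_tokens // BLOCK_SIZE
        num_uncached = 0
        for block_id in block_ids[num_full_blocks:]:
            # An uncached block means all later blocks are uncached too.
            if block_id not in cached:
                break
            cached.remove(block_id)
            num_uncached += 1
        return num_uncached

    # 30 prompt tokens + 5 speculative tokens span blocks [0, 1, 2];
    # blocks 0 and 1 are full (32 tokens <= 35) and therefore cached.
    cached = {0, 1}
    # Only 1 speculative token is accepted -> 31 computed tokens ->
    # 31 // 16 = 1 full block, so block 1 must be uncached; block 2
    # was never cached, so the walk stops there.
    assert uncache_after_rejection([0, 1, 2], cached, 31) == 1
    assert cached == {0}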