[v1] Move block pool operations to a separate class (#13973)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
heheda12345 and comaniac authored Feb 28, 2025
1 parent b526ca6 commit 28943d3
Showing 3 changed files with 360 additions and 277 deletions.
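
This commit pulls the block-pool bookkeeping out of KVCacheManager into a dedicated BlockPool class: state the tests previously reached directly on the manager (blocks by index, free_block_queue, cached_block_hash_to_block) now lives under manager.block_pool, and the private KVCacheManager._cache_full_blocks becomes the public BlockPool.cache_full_blocks. Below is a minimal sketch of the interface the updated tests exercise; the attribute and method names are taken from the diff, while the constructor body, the FreeKVCacheBlockQueue helper, and the docstrings are illustrative assumptions rather than the actual vLLM implementation.

from typing import Dict, List

from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
                                         KVCacheBlock)


class BlockPool:
    """Sketch: owns the KV-cache blocks plus the prefix-cache index."""

    def __init__(self, num_gpu_blocks: int, enable_caching: bool):
        self.num_gpu_blocks = num_gpu_blocks
        self.enable_caching = enable_caching
        # All KV-cache blocks, indexed by block_id; the tests read
        # blocks[block_id].block_hash and blocks[block_id].ref_cnt.
        self.blocks: List[KVCacheBlock] = [
            KVCacheBlock(block_id=i) for i in range(num_gpu_blocks)
        ]
        # Free blocks in eviction order (assumed helper); entries are
        # removed lazily, so stale nodes may linger until touched.
        self.free_block_queue = FreeKVCacheBlockQueue(self.blocks)
        # Prefix-cache index: block hash -> {block_id: block}.
        self.cached_block_hash_to_block: Dict[BlockHashType,
                                              Dict[int, KVCacheBlock]] = {}

    def cache_full_blocks(self, request, blocks: List[KVCacheBlock],
                          block_hashes: List[BlockHashType],
                          num_cached_blocks: int, num_full_blocks: int,
                          block_size: int) -> None:
        """Hash blocks[num_cached_blocks:num_full_blocks], append the new
        hashes to block_hashes, and register the blocks in
        cached_block_hash_to_block (body elided in this sketch)."""
        ...
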
89 changes: 49 additions & 40 deletions tests/v1/core/test_prefix_caching.py
@@ -1,12 +1,16 @@
 # SPDX-License-Identifier: Apache-2.0
 """Compare the with and without prefix caching."""
+from typing import List
+
 import pytest
 
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.sampling_params import SamplingParams
 from vllm.utils import cdiv
+from vllm.v1.core.block_pool import BlockPool
 from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
-from vllm.v1.core.kv_cache_utils import KVCacheBlock, hash_block_tokens
+from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock,
+                                         hash_block_tokens)
 
 
 def make_request(request_id,
@@ -62,14 +66,14 @@ def test_prefill():
     for block_id in (0, 1, 2):
         block_tokens = tuple(all_token_ids[block_id * 16:(block_id + 1) * 16])
         block_hash = hash_block_tokens(parent_block_hash, block_tokens)
-        assert manager.block_pool[block_id].block_hash == block_hash
-        assert manager.block_pool[block_id].ref_cnt == 1
+        assert manager.block_pool.blocks[block_id].block_hash == block_hash
+        assert manager.block_pool.blocks[block_id].ref_cnt == 1
         parent_block_hash = block_hash.hash_value
 
     # Check partial/preallocated block metadata
     for block_id in (3, 4):
-        assert manager.block_pool[block_id].block_hash is None
-        assert manager.block_pool[block_id].ref_cnt == 1
+        assert manager.block_pool.blocks[block_id].block_hash is None
+        assert manager.block_pool.blocks[block_id].ref_cnt == 1
 
     # Cache hit in the common prefix when the original block is still in use.
     # Incomplete 1 block (5 tokens)
@@ -86,20 +90,21 @@ def test_prefill():
         assert block.ref_cnt == 2
 
     # At this point, we should have 3 free blocks left.
-    assert manager.free_block_queue.num_free_blocks == 3
+    assert manager.block_pool.free_block_queue.num_free_blocks == 3
 
     manager.free(req0)
     manager.free(req1)
 
     # All blocks should be available.
-    assert manager.free_block_queue.num_free_blocks == 10
+    assert manager.block_pool.free_block_queue.num_free_blocks == 10
     # The order should be
     # [unallocated (7, 8, 9)]
     # [unique_req0 (4, 3)]
     # [unique_req1 (6, 5)]
     # [common (2, 1, 0)]
     assert [
-        b.block_id for b in manager.free_block_queue.get_all_free_blocks()
+        b.block_id
+        for b in manager.block_pool.free_block_queue.get_all_free_blocks()
     ] == [7, 8, 9, 4, 3, 6, 5, 2, 1, 0]
 
     # Cache hit in the common prefix when the original block is already free.
@@ -116,12 +121,14 @@ def test_prefill():
 
     # Although we only have 5 free blocks, we have 8 blocks in
     # the free block queue due to lazy removal.
-    assert manager.free_block_queue.num_free_blocks == 5
+    assert manager.block_pool.free_block_queue.num_free_blocks == 5
     assert all([
-        b.ref_cnt == 0 for b in manager.free_block_queue.get_all_free_blocks()
+        b.ref_cnt == 0
+        for b in manager.block_pool.free_block_queue.get_all_free_blocks()
     ])
-    assert len([b
-                for b in manager.free_block_queue.get_all_free_blocks()]) == 5
+    assert len([
+        b for b in manager.block_pool.free_block_queue.get_all_free_blocks()
+    ]) == 5
 
     manager.free(req2)
 
@@ -133,9 +140,9 @@ def test_prefill():
     blocks = manager.allocate_slots(req3, 16 * 9, computed_blocks)
     # This block ID order also checks the eviction order.
     assert [b.block_id for b in blocks] == [9, 4, 3, 6, 5, 8, 7, 2, 1, 0]
-    assert manager.free_block_queue.num_free_blocks == 0
-    assert manager.free_block_queue.free_list_head is None
-    assert manager.free_block_queue.free_list_tail is None
+    assert manager.block_pool.free_block_queue.num_free_blocks == 0
+    assert manager.block_pool.free_block_queue.free_list_head is None
+    assert manager.block_pool.free_block_queue.free_list_tail is None
 
 
 def test_decode():
@@ -219,13 +226,14 @@ def test_evict():
     assert len(blocks) == 3  # 3 full blocks
     last_token_id += 3 * 16
 
-    assert manager.free_block_queue.num_free_blocks == 0
+    assert manager.block_pool.free_block_queue.num_free_blocks == 0
 
     manager.free(req0)
     manager.free(req1)
-    assert manager.free_block_queue.num_free_blocks == 10
+    assert manager.block_pool.free_block_queue.num_free_blocks == 10
     assert [
-        b.block_id for b in manager.free_block_queue.get_all_free_blocks()
+        b.block_id
+        for b in manager.block_pool.free_block_queue.get_all_free_blocks()
     ] == [6, 5, 4, 3, 2, 1, 0, 9, 8, 7]
 
     # Touch the first 2 blocks.
@@ -235,7 +243,7 @@ def test_evict():
     assert num_computed_tokens == 2 * 16
     blocks = manager.allocate_slots(req2, 3, computed_blocks)
    assert [b.block_id for b in blocks] == [6, 5]
-    assert manager.free_block_queue.num_free_blocks == 6
+    assert manager.block_pool.free_block_queue.num_free_blocks == 6
 
 
 def test_hash_block_correct_reuse():
@@ -274,7 +282,7 @@ def test_hash_block_correct_reuse():
     blocks = manager.allocate_slots(req, num_tokens - 1, computed_blocks)
     assert len(blocks) == 1
 
-    assert manager.block_pool[blocks[0].block_id].block_hash is None
+    assert manager.block_pool.blocks[blocks[0].block_id].block_hash is None
 
 
 def test_computed_blocks_not_evicted():
@@ -413,13 +421,9 @@ def test_cache_blocks():
     function of KVCacheManager.
     """
     block_size = 4
-    manager = KVCacheManager(
-        block_size=block_size,
+    block_pool = BlockPool(
         num_gpu_blocks=5,
-        max_model_len=8192,
-        sliding_window=None,
         enable_caching=True,
-        num_preallocate_tokens=0,
     )
     # Req:
     # Block 0: [0, 1, 2, 3]
@@ -430,26 +434,31 @@ def test_cache_blocks():
 
     # Test that blocks are cached correctly for 2 full blocks from the start.
     blocks = [KVCacheBlock(block_id=i) for i in range(2)]
+    block_hashes: List[BlockHashType] = []
 
-    manager._cache_full_blocks(
+    block_pool.cache_full_blocks(
         request=req,
-        blk_start_idx=0,
-        full_blocks=blocks,
-        prev_block=None,
+        blocks=blocks,
+        block_hashes=block_hashes,
+        num_cached_blocks=0,
+        num_full_blocks=2,
+        block_size=block_size,
     )
 
-    assert len(manager.cached_block_hash_to_block) == 2
+    assert len(block_pool.cached_block_hash_to_block) == 2
     assert all([block.block_hash is not None for block in blocks])
 
     # Test that blocks that don't start from the beginning are cached correctly.
-    blocks = [KVCacheBlock(block_id=2)]
-    manager._cache_full_blocks(
+    blocks += [KVCacheBlock(block_id=2)]
+    block_pool.cache_full_blocks(
         request=req,
-        blk_start_idx=2,
-        full_blocks=blocks,
-        prev_block=None,
+        blocks=blocks,
+        block_hashes=block_hashes,
+        num_cached_blocks=2,
+        num_full_blocks=3,
+        block_size=block_size,
     )
-    assert len(manager.cached_block_hash_to_block) == 3
+    assert len(block_pool.cached_block_hash_to_block) == 3
     assert blocks[0].block_hash is not None
 
 
@@ -580,7 +589,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
     # Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed,
     # but it cannot be allocated due to insufficient free blocks (2).
     # In this case, the ref_cnt of the computed blocks should not be changed.
-    assert manager.free_block_queue.num_free_blocks == 5
+    assert manager.block_pool.free_block_queue.num_free_blocks == 5
     req3 = make_request("3", common_token_ids * 3)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
     assert computed_blocks == block_part1
@@ -621,12 +630,12 @@ def test_reset_prefix_cache():
 
     # Failed to reset prefix cache because some blocks are not freed yet.
     assert not manager.reset_prefix_cache()
-    assert manager.cached_block_hash_to_block
+    assert manager.block_pool.cached_block_hash_to_block
 
     # Free the blocks.
     manager.free(req0)
     manager.free(req1)
 
     assert manager.reset_prefix_cache()
-    assert not manager.cached_block_hash_to_block
-    assert all([blk.block_hash is None for blk in manager.block_pool])
+    assert not manager.block_pool.cached_block_hash_to_block
+    assert all([blk.block_hash is None for blk in manager.block_pool.blocks])
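
A design note on the cache_full_blocks signature seen in the test_cache_blocks hunk above: the old private helper took blk_start_idx, full_blocks, and prev_block, so each call re-derived its position within the request; the new public method instead takes the request's whole block list, a caller-owned block_hashes accumulator, and a (num_cached_blocks, num_full_blocks) window, so repeated calls hash and register only the newly filled blocks. Condensed from the test (same identifiers; values are the test's own):

# First call: blocks 0..1 are full, nothing cached yet.
block_hashes: List[BlockHashType] = []
block_pool.cache_full_blocks(request=req, blocks=blocks,
                             block_hashes=block_hashes,
                             num_cached_blocks=0, num_full_blocks=2,
                             block_size=block_size)

# Second call: block 2 has filled; only the [2:3] window is hashed,
# reusing the hashes already accumulated in block_hashes.
block_pool.cache_full_blocks(request=req, blocks=blocks,
                             block_hashes=block_hashes,
                             num_cached_blocks=2, num_full_blocks=3,
                             block_size=block_size)
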