Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v1] Move block pool operations to a separate class #13973

Merged
merged 9 commits into from
Feb 28, 2025
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 49 additions & 40 deletions tests/v1/core/test_prefix_caching.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
"""Compare behavior with and without prefix caching."""
from typing import List

import pytest

from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.utils import cdiv
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
from vllm.v1.core.kv_cache_utils import KVCacheBlock, hash_block_tokens
from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock,
hash_block_tokens)


def make_request(request_id,
Expand Down Expand Up @@ -62,14 +66,14 @@ def test_prefill():
for block_id in (0, 1, 2):
block_tokens = tuple(all_token_ids[block_id * 16:(block_id + 1) * 16])
block_hash = hash_block_tokens(parent_block_hash, block_tokens)
assert manager.block_pool[block_id].block_hash == block_hash
assert manager.block_pool[block_id].ref_cnt == 1
assert manager.block_pool.blocks[block_id].block_hash == block_hash
assert manager.block_pool.blocks[block_id].ref_cnt == 1
parent_block_hash = block_hash.hash_value

# Check partial/preallocated block metadata
for block_id in (3, 4):
assert manager.block_pool[block_id].block_hash is None
assert manager.block_pool[block_id].ref_cnt == 1
assert manager.block_pool.blocks[block_id].block_hash is None
assert manager.block_pool.blocks[block_id].ref_cnt == 1

# Cache hit in the common prefix when the original block is still in use.
# Incomplete 1 block (5 tokens)
Expand All @@ -86,20 +90,21 @@ def test_prefill():
assert block.ref_cnt == 2

# At this point, we should have 3 free blocks left.
assert manager.free_block_queue.num_free_blocks == 3
assert manager.block_pool.free_block_queue.num_free_blocks == 3

manager.free(req0)
manager.free(req1)

# All blocks should be available.
assert manager.free_block_queue.num_free_blocks == 10
assert manager.block_pool.free_block_queue.num_free_blocks == 10
# The order should be
# [unallocated (7, 8, 9)]
# [unique_req0 (4, 3)]
# [unique_req1 (6, 5)]
# [common (2, 1, 0)]
assert [
b.block_id for b in manager.free_block_queue.get_all_free_blocks()
b.block_id
for b in manager.block_pool.free_block_queue.get_all_free_blocks()
] == [7, 8, 9, 4, 3, 6, 5, 2, 1, 0]

# Cache hit in the common prefix when the original block is already free.
Expand All @@ -116,12 +121,14 @@ def test_prefill():

# Although we only have 5 free blocks, we have 8 blocks in
# the free block queue due to lazy removal.
assert manager.free_block_queue.num_free_blocks == 5
assert manager.block_pool.free_block_queue.num_free_blocks == 5
assert all([
b.ref_cnt == 0 for b in manager.free_block_queue.get_all_free_blocks()
b.ref_cnt == 0
for b in manager.block_pool.free_block_queue.get_all_free_blocks()
])
assert len([b
for b in manager.free_block_queue.get_all_free_blocks()]) == 5
assert len([
b for b in manager.block_pool.free_block_queue.get_all_free_blocks()
]) == 5

manager.free(req2)

Expand All @@ -133,9 +140,9 @@ def test_prefill():
blocks = manager.allocate_slots(req3, 16 * 9, computed_blocks)
# This block ID order also checks the eviction order.
assert [b.block_id for b in blocks] == [9, 4, 3, 6, 5, 8, 7, 2, 1, 0]
assert manager.free_block_queue.num_free_blocks == 0
assert manager.free_block_queue.free_list_head is None
assert manager.free_block_queue.free_list_tail is None
assert manager.block_pool.free_block_queue.num_free_blocks == 0
assert manager.block_pool.free_block_queue.free_list_head is None
assert manager.block_pool.free_block_queue.free_list_tail is None


def test_decode():
Expand Down Expand Up @@ -219,13 +226,14 @@ def test_evict():
assert len(blocks) == 3 # 3 full blocks
last_token_id += 3 * 16

assert manager.free_block_queue.num_free_blocks == 0
assert manager.block_pool.free_block_queue.num_free_blocks == 0

manager.free(req0)
manager.free(req1)
assert manager.free_block_queue.num_free_blocks == 10
assert manager.block_pool.free_block_queue.num_free_blocks == 10
assert [
b.block_id for b in manager.free_block_queue.get_all_free_blocks()
b.block_id
for b in manager.block_pool.free_block_queue.get_all_free_blocks()
] == [6, 5, 4, 3, 2, 1, 0, 9, 8, 7]

# Touch the first 2 blocks.
Expand All @@ -235,7 +243,7 @@ def test_evict():
assert num_computed_tokens == 2 * 16
blocks = manager.allocate_slots(req2, 3, computed_blocks)
assert [b.block_id for b in blocks] == [6, 5]
assert manager.free_block_queue.num_free_blocks == 6
assert manager.block_pool.free_block_queue.num_free_blocks == 6


def test_hash_block_correct_reuse():
Expand Down Expand Up @@ -274,7 +282,7 @@ def test_hash_block_correct_reuse():
blocks = manager.allocate_slots(req, num_tokens - 1, computed_blocks)
assert len(blocks) == 1

assert manager.block_pool[blocks[0].block_id].block_hash is None
assert manager.block_pool.blocks[blocks[0].block_id].block_hash is None


def test_computed_blocks_not_evicted():
Expand Down Expand Up @@ -413,13 +421,9 @@ def test_cache_blocks():
function of KVCacheManager.
"""
block_size = 4
manager = KVCacheManager(
block_size=block_size,
block_pool = BlockPool(
num_gpu_blocks=5,
max_model_len=8192,
sliding_window=None,
enable_caching=True,
num_preallocate_tokens=0,
)
# Req:
# Block 0: [0, 1, 2, 3]
Expand All @@ -430,26 +434,31 @@ def test_cache_blocks():

# Test that blocks are cached correctly for 2 full blocks from the start.
blocks = [KVCacheBlock(block_id=i) for i in range(2)]
block_hashes: List[BlockHashType] = []

manager._cache_full_blocks(
block_pool.cache_full_blocks(
request=req,
blk_start_idx=0,
full_blocks=blocks,
prev_block=None,
blocks=blocks,
block_hashes=block_hashes,
num_cached_blocks=0,
num_full_blocks=2,
block_size=block_size,
)

assert len(manager.cached_block_hash_to_block) == 2
assert len(block_pool.cached_block_hash_to_block) == 2
assert all([block.block_hash is not None for block in blocks])

# Test that blocks that don't start from the beginning are cached correctly.
blocks = [KVCacheBlock(block_id=2)]
manager._cache_full_blocks(
blocks += [KVCacheBlock(block_id=2)]
block_pool.cache_full_blocks(
request=req,
blk_start_idx=2,
full_blocks=blocks,
prev_block=None,
blocks=blocks,
block_hashes=block_hashes,
num_cached_blocks=2,
num_full_blocks=3,
block_size=block_size,
)
assert len(manager.cached_block_hash_to_block) == 3
assert len(block_pool.cached_block_hash_to_block) == 3
assert blocks[0].block_hash is not None


Expand Down Expand Up @@ -580,7 +589,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
# Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed,
# but it cannot be allocated due to insufficient free blocks (2).
# In this case, the ref_cnt of the computed blocks should not be changed.
assert manager.free_block_queue.num_free_blocks == 5
assert manager.block_pool.free_block_queue.num_free_blocks == 5
req3 = make_request("3", common_token_ids * 3)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
assert computed_blocks == block_part1
Expand Down Expand Up @@ -621,12 +630,12 @@ def test_reset_prefix_cache():

# Failed to reset prefix cache because some blocks are not freed yet.
assert not manager.reset_prefix_cache()
assert manager.cached_block_hash_to_block
assert manager.block_pool.cached_block_hash_to_block

# Free the blocks.
manager.free(req0)
manager.free(req1)

assert manager.reset_prefix_cache()
assert not manager.cached_block_hash_to_block
assert all([blk.block_hash is None for blk in manager.block_pool])
assert not manager.block_pool.cached_block_hash_to_block
assert all([blk.block_hash is None for blk in manager.block_pool.blocks])
Loading