cherry-pick: [V1] Move KV block hashes from Request to KVCacheManager (vllm-project#12922)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
WoosukKwon authored and heheda12345 committed Feb 8, 2025
1 parent a7173a2 commit 5e2d3bd
Showing 4 changed files with 38 additions and 40 deletions.
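For orientation, here is a minimal sketch of the ownership change this commit makes: the per-request KV block hashes move out of `Request` and into a map owned by `KVCacheManager`, keyed by request ID. The field names mirror the diff below, but the surrounding classes are simplified stand-ins, not vLLM's actual definitions.

```python
from collections import defaultdict
from typing import DefaultDict, List

BlockHashType = int  # stand-in; vLLM uses a richer hash type


# Before: each Request cached its own hashes (one list per KV cache group).
class RequestBefore:
    def __init__(self, request_id: str, num_groups: int):
        self.request_id = request_id
        self.kv_block_hashes: List[List[BlockHashType]] = [
            [] for _ in range(num_groups)
        ]


# After: the manager owns the cache, keyed by request ID, so Request no
# longer needs hash-mutation helpers and cleanup happens centrally.
class KVCacheManagerAfter:
    def __init__(self, num_kv_cache_groups: int):
        self.num_kv_cache_groups = num_kv_cache_groups
        self.req_to_block_hashes: DefaultDict[
            str, List[List[BlockHashType]]] = defaultdict(
                lambda: [[] for _ in range(self.num_kv_cache_groups)])

    def free_block_hashes(self, request_id: str) -> None:
        # Called only when the request finishes, not on preemption.
        self.req_to_block_hashes.pop(request_id, None)
```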
21 changes: 11 additions & 10 deletions tests/v1/core/test_prefix_caching.py
@@ -63,7 +63,7 @@ def test_prefill():
all_token_ids = common_token_ids + unique_token_ids
req0 = make_request("0", all_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
- assert len(req0.kv_block_hashes[0]) == 3
+ assert len(manager.req_to_block_hashes[req0.request_id][0]) == 3
assert not computed_blocks[0]
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55, computed_blocks,
@@ -89,7 +89,7 @@ def test_prefill():
unique_token_ids = [3] * 5
req1 = make_request("1", common_token_ids + unique_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
- assert len(req1.kv_block_hashes[0]) == 3
+ assert len(manager.req_to_block_hashes[req1.request_id][0]) == 3
assert [b.block_id for b in computed_blocks[0]] == [0, 1, 2]
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
@@ -121,7 +121,7 @@ def test_prefill():
unique_token_ids = [3] * 6
req2 = make_request("2", common_token_ids + unique_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
- assert len(req2.kv_block_hashes[0]) == 3
+ assert len(manager.req_to_block_hashes[req2.request_id][0]) == 3
assert [b.block_id for b in computed_blocks[0]] == [0, 1, 2]
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
@@ -509,10 +509,11 @@ def test_mm_prefix_caching():
# Completed block should have hashes with extra keys.
assert not computed_blocks[0]
assert num_computed_tokens == 0
- assert len(req0.kv_block_hashes[0]) == 3
- assert req0.kv_block_hashes[0][0].extra_keys == ("aaa", )
- assert req0.kv_block_hashes[0][1].extra_keys == ("aaa", "bbb")
- assert req0.kv_block_hashes[0][2].extra_keys == ("bbb", )
+ block_hashes = manager.req_to_block_hashes[req0.request_id]
+ assert len(block_hashes[0]) == 3
+ assert block_hashes[0][0].extra_keys == ("aaa", )
+ assert block_hashes[0][1].extra_keys == ("aaa", "bbb")
+ assert block_hashes[0][2].extra_keys == ("bbb", )

blocks = manager.allocate_slots(req0, 59, computed_blocks,
num_computed_tokens)
@@ -526,8 +527,8 @@ def test_mm_prefix_caching():
assert new_blocks is not None and len(new_blocks[0]) == 0

# The just completed block should have hashes with extra keys.
- assert len(req0.kv_block_hashes[0]) == 4
- assert req0.kv_block_hashes[0][3].extra_keys == ("ccc", )
+ assert len(block_hashes[0]) == 4
+ assert block_hashes[0][3].extra_keys == ("ccc", )

# Cache hit.
unique_token_ids = [-1] * 7 + [200] * 5
@@ -632,7 +633,7 @@ def test_reset_prefix_cache():
all_token_ids = full_block_token_ids + unique_token_ids
req1 = make_request("1", all_token_ids)
computed_blocks, _ = manager.get_computed_blocks(req1)
- assert len(req1.kv_block_hashes[0]) == 3
+ assert len(manager.req_to_block_hashes[req1.request_id][0]) == 3
assert len(computed_blocks[0]) == 3
blocks = manager.allocate_slots(req1, 7, computed_blocks)
assert [b.block_id for b in blocks[0]] == [4]
37 changes: 26 additions & 11 deletions vllm/v1/core/kv_cache_manager.py
@@ -98,6 +98,13 @@ def __init__(
self.req_to_blocks: DefaultDict[str, ReqKVCacheBlocks] = defaultdict(
lambda: [[] for _ in range(self.num_kv_cache_groups)])

+ # Mapping from request ID to kv block hashes.
+ # This is to avoid recomputing the block hashes for each call of
+ # `get_computed_blocks` or `allocate_slots`.
+ self.req_to_block_hashes: DefaultDict[
+ str, List[List[BlockHashType]]] = defaultdict(
+ lambda: [[] for _ in range(self.num_kv_cache_groups)])

@property
def usage(self) -> float:
return 1.0 - (self.free_block_queue.num_free_blocks /
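A subtlety of the `defaultdict` introduced above: merely reading `req_to_block_hashes[request_id]` inserts an empty per-group entry for an unknown key, which is why the cleanup added later in this diff (`free_block_hashes`) removes entries with `pop(..., None)` rather than assuming the key exists. A standalone illustration, independent of vLLM:

```python
from collections import defaultdict

num_groups = 2
req_to_block_hashes = defaultdict(lambda: [[] for _ in range(num_groups)])

# A plain read on a missing key *creates* the entry ...
hashes = req_to_block_hashes["req-0"]
assert "req-0" in req_to_block_hashes and hashes == [[], []]

# ... so removal must tolerate keys that were never touched.
req_to_block_hashes.pop("req-1", None)  # no KeyError
```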
@@ -121,17 +128,19 @@ def get_computed_blocks(self,
return [[] for _ in self.managers], 0

# The block hashes for the request may already be computed
- # if the request was preempted and resumed.
- if not request.kv_block_hashes:
- request.set_kv_block_hashes([
+ # if the scheduler has tried to schedule the request before.
+ block_hashes = self.req_to_block_hashes[request.request_id]
+ if not block_hashes:
+ block_hashes = [
hash_request_tokens(manager.block_size, request, i)
for i, manager in enumerate(self.managers)
- ])
+ ]
+ self.req_to_block_hashes[request.request_id] = block_hashes

computed_blocks: ReqKVCacheBlocks = [] # computed blocks of each group
prefix_length: List[PrefixLength] = [
] # possible cached prefix length of each group
- block_hashes = request.kv_block_hashes

for i, manager in enumerate(self.managers):
prefix_length_i, computed_blocks_i = (
manager.get_possible_cached_prefix(block_hashes[i]))
@@ -154,7 +163,6 @@
for i, manager in enumerate(self.managers):
computed_blocks[i] = computed_blocks[i][:num_computed_tokens //
manager.block_size]

return computed_blocks, num_computed_tokens

def allocate_slots(
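The effect of the rewritten lookup above is that token hashing runs at most once per request: the first `get_computed_blocks` call populates the manager-side cache, and later calls (for example when the scheduler retries a preempted request) find a non-empty list and skip rehashing. A minimal single-group sketch of that control flow, where `hash_tokens` is a hypothetical stand-in for vLLM's `hash_request_tokens`:

```python
import hashlib
from collections import defaultdict
from typing import DefaultDict, List


def hash_tokens(token_ids: List[int], block_size: int) -> List[str]:
    """Hypothetical stand-in for vLLM's hash_request_tokens."""
    hashes, prev = [], ""
    for start in range(0, len(token_ids) - block_size + 1, block_size):
        block = token_ids[start:start + block_size]
        prev = hashlib.sha256(f"{prev}|{block}".encode()).hexdigest()
        hashes.append(prev)
    return hashes


req_to_block_hashes: DefaultDict[str, List[str]] = defaultdict(list)


def get_block_hashes(request_id: str, token_ids: List[int]) -> List[str]:
    block_hashes = req_to_block_hashes[request_id]
    if not block_hashes:  # first scheduling attempt: compute and cache
        block_hashes = hash_tokens(token_ids, block_size=16)
        req_to_block_hashes[request_id] = block_hashes
    return block_hashes  # later attempts reuse the cached hashes
```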
@@ -560,8 +568,8 @@ def _cache_full_blocks(
prev_block: The previous block in the chain.
kv_cache_group_id: The KV cache group that the blocks belong to
"""
- num_cached_block_hashes = len(
- request.kv_block_hashes[kv_cache_group_id])
+ block_hashes = self.req_to_block_hashes[request.request_id]
+ num_cached_block_hashes = len(block_hashes[kv_cache_group_id])

# Update the new blocks with the block hashes through the chain.
prev_block_hash_value = None
@@ -596,8 +604,7 @@
# this request (either the prompt tokens or the previously
# generated tokens with preemption). In this case we simply
# reuse the block hash.
- block_hash = request.kv_block_hashes[kv_cache_group_id][
- blk_idx]
+ block_hash = block_hashes[kv_cache_group_id][blk_idx]
else:
# Otherwise compute the block hash and cache it in the request
# in case it will be preempted in the future.
@@ -620,13 +627,21 @@
block_hash = hash_block_tokens(prev_block_hash_value,
block_tokens, kv_cache_group_id,
extra_keys)
- request.append_kv_block_hashes(kv_cache_group_id, block_hash)
+ block_hashes[kv_cache_group_id].append(block_hash)

# Update and add the full block to the cache.
blk.block_hash = block_hash
self.cached_block_hash_to_block[block_hash][blk.block_id] = blk
prev_block_hash_value = block_hash.hash_value

+ def free_block_hashes(self, request: Request) -> None:
+ """Discard the block hashes for the request.
+ NOTE: Unlike `free`, this method should be called only when the request
+ is finished, not when it is preempted.
+ """
+ self.req_to_block_hashes.pop(request.request_id, None)

def get_null_block(self) -> KVCacheBlock:
return self._null_block

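In `_cache_full_blocks` above, each full block's hash is linked to its predecessor through `prev_block_hash_value`, so a block's identity depends on the entire token prefix, not just its own tokens; two requests can share a cached block only if everything before it matched as well. A toy version of that chaining, where `toy_hash_block_tokens` is a hypothetical stand-in for vLLM's `hash_block_tokens`:

```python
import hashlib
from typing import List, Optional, Tuple


def toy_hash_block_tokens(prev_hash: Optional[str],
                          block_tokens: Tuple[int, ...],
                          group_id: int) -> str:
    """Hypothetical stand-in: mixes the previous hash into this block's."""
    payload = f"{prev_hash}|{group_id}|{block_tokens}".encode()
    return hashlib.sha256(payload).hexdigest()


def chain_hashes(token_ids: List[int], block_size: int,
                 group_id: int = 0) -> List[str]:
    hashes: List[str] = []
    prev: Optional[str] = None
    for start in range(0, len(token_ids) - block_size + 1, block_size):
        block = tuple(token_ids[start:start + block_size])
        prev = toy_hash_block_tokens(prev, block, group_id)
        hashes.append(prev)
    return hashes


# Same third block of tokens, different prefix -> different block hash.
a = chain_hashes([1] * 48, block_size=16)
b = chain_hashes([2] * 32 + [1] * 16, block_size=16)
assert a[2] != b[2]
```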
1 change: 1 addition & 0 deletions vllm/v1/core/scheduler.py
@@ -551,6 +551,7 @@ def finish_requests(
def _free_request(self, request: Request) -> None:
assert request.is_finished()
self.kv_cache_manager.free(request)
+ self.kv_cache_manager.free_block_hashes(request)
self.encoder_cache_manager.free(request)
self._cached_reqs_data.pop(request.request_id, None)
del self.requests[request.request_id]
19 changes: 0 additions & 19 deletions vllm/v1/request.py
@@ -12,7 +12,6 @@
if TYPE_CHECKING:
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.inputs import PlaceholderRange
- from vllm.v1.core.kv_cache_utils import BlockHashType


class Request:
@@ -63,12 +62,6 @@ def __init__(
if self.mm_hashes:
assert len(self.mm_inputs) == len(self.mm_hashes)

- # Cache the computed kv block hashes of the request to avoid
- # recomputing.
- self._kv_block_hashes: List[List[BlockHashType]] = []
- self.kv_block_hashes = ConstantList(
- [ConstantList(x) for x in self._kv_block_hashes])

# Read-only views
# Prevent directly appending to these lists since
# they should also be updated simultaneously.
@@ -125,18 +118,6 @@ def get_num_encoder_tokens(self, input_id: int) -> int:
num_tokens = self.mm_positions[input_id]["length"]
return num_tokens

- def set_kv_block_hashes(self, value: List[List["BlockHashType"]]) -> None:
- self._kv_block_hashes = value
- # NOTE: self.kv_block_hashes._x is not self._kv_block_hashes, but
- # self.kv_block_hashes[0]._x is self._kv_block_hashes[0]. This is
- # correct because we never need to update the outer list.
- self.kv_block_hashes = ConstantList(
- [ConstantList(x) for x in self._kv_block_hashes])
-
- def append_kv_block_hashes(self, group_id: int,
- block_hash: "BlockHashType") -> None:
- self._kv_block_hashes[group_id].append(block_hash)


class RequestStatus(enum.IntEnum):
"""Status of a request."""
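The NOTE in the removed `set_kv_block_hashes` concerns aliasing: rebuilding the outer `ConstantList` was safe because each inner wrapper still shared storage with the underlying per-group list, so appends made via `append_kv_block_hashes` stayed visible through the read-only view. A minimal illustration, with `ConstantList` as a simplified stand-in for vLLM's class:

```python
from typing import Generic, List, TypeVar

T = TypeVar("T")


class ConstantList(Generic[T]):
    """Simplified stand-in: read-only view over an existing list."""

    def __init__(self, x: List[T]):
        self._x = x  # shares storage; does not copy

    def __getitem__(self, i: int) -> T:
        return self._x[i]

    def __len__(self) -> int:
        return len(self._x)


inner: List[List[int]] = [[], []]
view = ConstantList([ConstantList(x) for x in inner])

inner[0].append(42)            # mutate through the underlying list
assert view[0][0] == 42        # ... and the read-only view sees it
assert view._x is not inner    # the outer wrapper is a new list ...
assert view[0]._x is inner[0]  # ... but inner wrappers alias per-group lists
```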
