[V1][Prefix Cache] Move the logic of num_computed_tokens into KVCacheManager #12003
@@ -1,5 +1,5 @@
 from collections import defaultdict
-from typing import Dict, Iterable, List, Optional
+from typing import Dict, Iterable, List, Optional, Tuple

 from vllm.logger import init_logger
 from vllm.utils import cdiv
@@ -9,6 +9,7 @@
                                         hash_block_tokens,
                                         hash_request_tokens)
 from vllm.v1.request import Request, RequestStatus
+from vllm.v1.utils import ConstantList

 logger = init_logger(__name__)
@@ -69,7 +70,8 @@ def __init__(
         # is finished.
         self.req_to_blocks: Dict[str, List[KVCacheBlock]] = {}

-    def get_computed_blocks(self, request: Request) -> List[KVCacheBlock]:
+    def get_computed_blocks(
+            self, request: Request) -> Tuple[List[KVCacheBlock], int]:
         """Get the computed (cached) blocks for the request.
         Note that the computed blocks must be full.
@@ -81,7 +83,7 @@ def get_computed_blocks(self, request: Request) -> List[KVCacheBlock]:
         """
         if not self.enable_caching:
             # Prefix caching is disabled.
-            return []
+            return [], 0

         computed_blocks = []
@@ -92,6 +94,16 @@ def get_computed_blocks(self, request: Request) -> List[KVCacheBlock]:
                 hash_request_tokens(self.block_size, request))
             block_hashes = request.kv_block_hashes

+        if len(block_hashes) * self.block_size == request.num_tokens:
+            # When the prompt length is divisible by the block size and all
+            # blocks are cached, we need to recompute the last token. This
+            # has to be achieved by recomputing an entire block because
+            # allocate_slots() assumes num_computed_tokens is always a
+            # multiple of the block size. To force this recomputation, the
+            # last block is removed from the computed block_hashes. This
+            # limitation can potentially be removed in the future to
+            # slightly improve performance.
+            block_hashes = ConstantList(block_hashes[:-1])
+
Review thread on this hunk:

I was making a similar change, but @WoosukKwon prefers to have it in the scheduler, because this limitation comes from the model runner, not the KV cache manager.

Is there any suggestion for supporting sliding window? I don't want to add complex special handling for sliding window in the scheduler.

What handling do you plan to add for sliding window in the scheduler? In general, we should aim for modularization. For this particular logic, I think we could just do:

    class Scheduler:
        ...
        def maybe_recompute_last_block(self, computed_blocks, num_computed_tokens):
            ...
        def schedule(self):
            ...
            computed_blocks, num_computed_tokens = self.kv_cache_manager.get_computed_blocks(...)
            computed_blocks, num_computed_tokens = self.maybe_recompute_last_block(computed_blocks, num_computed_tokens)

This interface is not very friendly to sliding window, as the removal of the last block needs a redo of …

Sure, I'm OK with it.
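Expanding the reviewer's pseudocode, a minimal sketch of what such a scheduler hook could look like. The `request` parameter, `request.num_tokens`, and `self.block_size` attribute here are assumptions for illustration, not the actual vLLM implementation:

    class Scheduler:
        ...
        def maybe_recompute_last_block(self, request, computed_blocks,
                                       num_computed_tokens):
            # Hypothetical sketch: if every prompt token is covered by
            # cached blocks, drop the last block so the model runner still
            # has at least one token to compute (it needs one to produce
            # logits for the next token).
            if num_computed_tokens == request.num_tokens:
                computed_blocks = computed_blocks[:-1]
                num_computed_tokens -= self.block_size
            return computed_blocks, num_computed_tokens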
The diff hunk then continues:

         for block_hash in block_hashes:
             # block_hashes is a chain of block hashes. If a block hash is not
             # in the cached_block_hash_to_id, the following block hashes are
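To make the divisibility edge case concrete, here is a small standalone illustration; the values and variable names are invented for the example, not taken from vLLM:

    # Toy example, not vLLM code: prompt length is an exact multiple of
    # the block size and every prompt block is already cached.
    block_size = 16
    num_tokens = 32
    block_hashes = ["h0", "h1"]  # both prompt blocks hit the cache

    if len(block_hashes) * block_size == num_tokens:
        # All tokens are cached, but the runner must compute at least one
        # token, so the last full block is dropped and recomputed.
        block_hashes = block_hashes[:-1]

    num_computed_tokens = len(block_hashes) * block_size
    assert num_computed_tokens == 16  # tokens 16..31 are recomputed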
@@ -101,7 +113,11 @@ def get_computed_blocks(self, request: Request) -> List[KVCacheBlock]:
             else:
                 break

-        return computed_blocks
+        # NOTE(woosuk): Since incomplete blocks are not eligible for
+        # sharing, `num_computed_tokens` is always a multiple of
+        # `block_size`.
+        num_computed_tokens = len(computed_blocks) * self.block_size
+        return computed_blocks, num_computed_tokens

     def append_slots(
         self,
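With this change, callers get both the cached blocks and the computed-token count from a single call. A sketch of the caller side under the new signature; the scheduler context and the exact allocate_slots() argument order are assumptions for illustration:

    # Hypothetical caller, e.g. inside the scheduler:
    computed_blocks, num_computed_tokens = (
        self.kv_cache_manager.get_computed_blocks(request))
    # num_computed_tokens is guaranteed to be a multiple of the block
    # size, so it is safe to feed into slot allocation.
    new_blocks = self.kv_cache_manager.allocate_slots(
        request, num_new_tokens, computed_blocks)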
One more review comment on this method:

Please update the docstring about the return type.
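For reference, a possible shape for the updated docstring; the wording below is a suggestion, not the text that was merged:

    # A possible updated docstring (suggested wording):
    def get_computed_blocks(
            self, request: Request) -> Tuple[List[KVCacheBlock], int]:
        """Get the computed (cached) blocks for the request.
        Note that the computed blocks must be full.

        Args:
            request: The request to get the computed blocks.

        Returns:
            A tuple of:
                - The computed blocks.
                - The number of computed tokens, which is always a
                  multiple of the block size because only full blocks
                  are cached.
        """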