From 33cacf997035090cdd7fbc29dbb180dcb7641637 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Tue, 5 Nov 2024 00:20:31 +0000 Subject: [PATCH] index using block Signed-off-by: Cody Yu --- tests/v1/core/test_prefix_caching.py | 95 ++++++++------- vllm/v1/core/kv_cache_manager.py | 171 +++++++++++++-------------- vllm/v1/core/scheduler.py | 25 ++-- 3 files changed, 142 insertions(+), 149 deletions(-) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 4310134bf890f..c7b9951440688 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -1,7 +1,8 @@ """Compare the with and without prefix caching.""" from vllm.inputs import DecoderOnlyInputs from vllm.sampling_params import SamplingParams -from vllm.v1.core.kv_cache_manager import KVCacheManager, Request +from vllm.v1.core.kv_cache_manager import (KVCacheManager, Request, + hash_block_tokens) def make_request(request_id, prompt_token_ids): @@ -31,25 +32,25 @@ def test_prefill(): # Incomplete 1 block (7 tokens) unique_token_ids = [3] * 7 req0 = make_request("0", common_token_ids + unique_token_ids) - computed_block_ids = manager.get_computed_blocks(req0) - assert not computed_block_ids - block_ids = manager.allocate_slots(req0, 55, computed_block_ids) - assert block_ids == [0, 1, 2, 3, 4] + computed_blocks = manager.get_computed_blocks(req0) + assert not computed_blocks + blocks = manager.allocate_slots(req0, 55, computed_blocks) + assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4] # Check full block metadata - prev_block_id = None + parent_block_hash = None for block_id in (0, 1, 2): - assert manager.block_pool[block_id].parent_block_id == prev_block_id - assert manager.block_pool[block_id].block_hash is not None + block_hash = hash_block_tokens(parent_block_hash, + manager.block_pool[block_id].token_ids) + assert manager.block_pool[block_id].block_hash == block_hash assert manager.block_pool[block_id].ref_cnt == 1 assert manager.block_pool[block_id].num_hashed_tokens == 16 * ( block_id + 1) assert manager.block_pool[block_id].token_ids == [block_id] * 16 - prev_block_id = block_id + parent_block_hash = block_hash # Check partial/preallocated block metadata for block_id in (3, 4): - assert manager.block_pool[block_id].parent_block_id == block_id - 1 assert manager.block_pool[block_id].block_hash is None assert manager.block_pool[block_id].ref_cnt == 1 assert manager.block_pool[block_id].num_hashed_tokens == 0 @@ -62,14 +63,13 @@ def test_prefill(): # Incomplete 1 block (5 tokens) unique_token_ids = [3] * 5 req1 = make_request("1", common_token_ids + unique_token_ids) - computed_block_ids = manager.get_computed_blocks(req1) - assert computed_block_ids == [0, 1, 2] + computed_blocks = manager.get_computed_blocks(req1) + assert [b.block_id for b in computed_blocks] == [0, 1, 2] num_new_tokens = 53 - 3 * 16 - block_ids = manager.allocate_slots(req1, num_new_tokens, - computed_block_ids) - assert block_ids == [5, 6] - for block_id in (0, 1, 2): - assert manager.block_pool[block_id].ref_cnt == 2 + blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks) + assert [b.block_id for b in blocks] == [5, 6] + for block in computed_blocks: + assert block.ref_cnt == 2 # At this point, we should have 3 free blocks left. 
    assert manager.free_block_queue.num_free_blocks == 3
@@ -92,12 +92,11 @@ def test_prefill():
     # Incomplete 1 block (6 tokens)
     unique_token_ids = [3] * 6
     req2 = make_request("2", common_token_ids + unique_token_ids)
-    computed_block_ids = manager.get_computed_blocks(req2)
-    assert computed_block_ids == [0, 1, 2]
+    computed_blocks = manager.get_computed_blocks(req2)
+    assert [b.block_id for b in computed_blocks] == [0, 1, 2]
     num_new_tokens = 53 - 3 * 16
-    block_ids = manager.allocate_slots(req2, num_new_tokens,
-                                       computed_block_ids)
-    assert block_ids == [7, 8]
+    blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks)
+    assert [b.block_id for b in blocks] == [7, 8]
 
     # Although we only have 5 free blocks, we have 8 blocks in
     # the free block queue due to lazy removal.
@@ -112,11 +111,11 @@ def test_prefill():
 
     # Cache miss and eviction.
     req3 = make_request("3", [99] * (16 * 9))
-    computed_block_ids = manager.get_computed_blocks(req3)
-    assert not computed_block_ids
-    block_ids = manager.allocate_slots(req2, 16 * 9, computed_block_ids)
+    computed_blocks = manager.get_computed_blocks(req3)
+    assert not computed_blocks
+    blocks = manager.allocate_slots(req3, 16 * 9, computed_blocks)
     # This block ID order also checks the eviction order.
-    assert block_ids == [9, 4, 3, 6, 5, 8, 7, 2, 1, 0]
+    assert [b.block_id for b in blocks] == [9, 4, 3, 6, 5, 8, 7, 2, 1, 0]
     assert manager.free_block_queue.num_free_blocks == 0
     assert manager.free_block_queue.free_list_head is None
     assert manager.free_block_queue.free_list_tail is None
@@ -138,16 +137,16 @@ def test_decode():
     # Incomplete 1 block (7 tokens)
     unique_token_ids = [3] * 7
     req0 = make_request("0", common_token_ids + unique_token_ids)
-    computed_block_ids = manager.get_computed_blocks(req0)
-    assert not computed_block_ids
-    block_ids = manager.allocate_slots(req0, 55, computed_block_ids)
-    assert block_ids == [0, 1, 2, 3, 4]
+    computed_blocks = manager.get_computed_blocks(req0)
+    assert not computed_blocks
+    blocks = manager.allocate_slots(req0, 55, computed_blocks)
+    assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]
 
     # Append slots without allocating a new block.
     req0.num_computed_tokens = 55
     req0.output_token_ids = [8] * 4
-    new_block_ids = manager.append_slots(req0, 4)
-    assert new_block_ids is not None and len(new_block_ids) == 0
+    new_blocks = manager.append_slots(req0, 4)
+    assert new_blocks is not None and len(new_blocks) == 0
     assert len(manager.block_pool[3].token_ids) == 11
 
     # Append slots without allocating a new block, but start using the
@@ -156,8 +155,8 @@ def test_decode():
     # 6 tokens to fill the previous block, and 10 tokens to fill
     # the preallocated block.
     req0.output_token_ids += [7] * (5 + 10)
-    new_block_ids = manager.append_slots(req0, 15)
-    assert new_block_ids is not None and len(new_block_ids) == 0
+    new_blocks = manager.append_slots(req0, 15)
+    assert new_blocks is not None and len(new_blocks) == 0
    assert len(manager.block_pool[3].token_ids) == 16
     assert len(manager.block_pool[4].token_ids) == 10
 
@@ -166,9 +165,9 @@ def test_decode():
     # 6 tokens to fill the previous block, and 10 tokens to fill
     # the preallocated block.
     req0.output_token_ids += [12] * (6 + 11)
-    new_block_ids = manager.append_slots(req0, 17)
+    new_blocks = manager.append_slots(req0, 17)
     # Plus one preallocated block.
-    assert new_block_ids is not None and len(new_block_ids) == 2
+    assert new_blocks is not None and len(new_blocks) == 2
     assert len(manager.block_pool[4].token_ids) == 16
     assert len(manager.block_pool[5].token_ids) == 11
     assert len(manager.block_pool[6].token_ids) == 0
@@ -185,18 +184,18 @@ def test_evict():
     last_token_id = 5 * 16 + 7
     req0 = make_request("0", list(range(last_token_id)))
-    computed_block_ids = manager.get_computed_blocks(req0)
-    assert not computed_block_ids
-    block_ids = manager.allocate_slots(req0, 5 * 16 + 7, computed_block_ids)
-    assert len(block_ids) == 7  # 5 full + 1 partial + 1 preallocated
+    computed_blocks = manager.get_computed_blocks(req0)
+    assert not computed_blocks
+    blocks = manager.allocate_slots(req0, 5 * 16 + 7, computed_blocks)
+    assert len(blocks) == 7  # 5 full + 1 partial + 1 preallocated
 
     # 3 blocks.
     req1 = make_request("1",
                         list(range(last_token_id,
                                    last_token_id + 3 * 16)))
-    computed_block_ids = manager.get_computed_blocks(req1)
-    assert not computed_block_ids
-    block_ids = manager.allocate_slots(req1, 3 * 16, computed_block_ids)
-    assert len(block_ids) == 3  # 3 full blocks
+    computed_blocks = manager.get_computed_blocks(req1)
+    assert not computed_blocks
+    blocks = manager.allocate_slots(req1, 3 * 16, computed_blocks)
+    assert len(blocks) == 3  # 3 full blocks
     last_token_id += 3 * 16
 
     assert manager.free_block_queue.num_free_blocks == 0
@@ -210,8 +209,8 @@ def test_evict():
 
     # Touch the first 2 blocks.
     req2 = make_request("2", list(range(2 * 16 + 3)))
-    computed_block_ids = manager.get_computed_blocks(req2)
-    assert computed_block_ids == [0, 1]
-    block_ids = manager.allocate_slots(req2, 3, computed_block_ids)
-    assert block_ids == [6, 5]
+    computed_blocks = manager.get_computed_blocks(req2)
+    assert [b.block_id for b in computed_blocks] == [0, 1]
+    blocks = manager.allocate_slots(req2, 3, computed_blocks)
+    assert [b.block_id for b in blocks] == [6, 5]
     assert manager.free_block_queue.num_free_blocks == 6
diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py
index 3af52757ded86..cda0fa4273e6e 100644
--- a/vllm/v1/core/kv_cache_manager.py
+++ b/vllm/v1/core/kv_cache_manager.py
@@ -1,6 +1,5 @@
 from collections import defaultdict
 from dataclasses import dataclass, field
-from functools import lru_cache
 from typing import Dict, List, Optional, Tuple
 
 from vllm.logger import init_logger
@@ -15,8 +14,6 @@ class KVCacheBlock:
     """KV-cache block metadata."""
     # Block ID, ranging from 0 to num_gpu_blocks - 1.
     block_id: int
-    # Parent block ID. Used to include block chain in the block hash.
-    parent_block_id: Optional[int] = None
     # Reference count.
     ref_cnt: int = 0
     # Token IDs in the block.
@@ -34,7 +31,6 @@ class KVCacheBlock:
 
     def reset(self):
         """Reset the block metadata."""
-        self.parent_block_id = None
         self.ref_cnt = 0
         self.token_ids.clear()
         self.block_hash = None
@@ -126,6 +122,7 @@ def append(self, block: KVCacheBlock) -> None:
             self.free_list_tail = block
         else:
             # The free list is empty.
+            assert self.free_list_head is None
             self.free_list_head = self.free_list_tail = block
         block.next_free_block = None
@@ -172,9 +169,13 @@ def __init__(
         self.num_preallocate_tokens = num_preallocate_tokens
         self.num_preallocate_blocks = cdiv(num_preallocate_tokens,
                                            block_size)
 
+        # A pool of all KV-cache blocks.
        self.block_pool: List[KVCacheBlock] = [
             KVCacheBlock(idx) for idx in range(num_gpu_blocks)
         ]
+        # Free block queue that constructs and manipulates a doubly linked
+        # list of free blocks (including eviction candidates when caching is
+        # enabled).
         self.free_block_queue = FreeKVCacheBlockQueue(self.block_pool)
 
         # {block_hash: {block ID: block}}. A cached block is
@@ -185,16 +186,16 @@ def __init__(
         # meaning that if a block becomes full and is cached, we don't check
         # if there is already an identical block in the cache. This is because
         # we want to make sure the allocated block IDs won't change so that
-        # block IDs are append-only.
+        # block tables are append-only.
         self.cached_block_hash_to_block: Dict[int, Dict[
             int, KVCacheBlock]] = defaultdict(dict)
 
-        # Mapping from request ID to block IDs to track the blocks allocated
+        # Mapping from request ID to blocks to track the blocks allocated
         # for each request, so that we can free the blocks when the request
         # is finished.
-        self.req_to_block_ids: Dict[str, List[int]] = {}
+        self.req_to_blocks: Dict[str, List[KVCacheBlock]] = {}
 
-    def get_computed_blocks(self, request: Request) -> List[int]:
+    def get_computed_blocks(self, request: Request) -> List[KVCacheBlock]:
         """Get the computed (cached) blocks for the request.
         Note that the computed blocks must be full.
@@ -202,13 +203,13 @@ def get_computed_blocks(self, request: Request) -> List[int]:
             request: The request to get the computed blocks.
 
         Returns:
-            A list of block IDs that are computed for the request.
+            A list of blocks that are computed for the request.
         """
         if not self.enable_caching:
             # Prefix caching is disabled.
             return []
 
-        computed_block_ids = []
+        computed_blocks = []
         block_hashes = self.hash_prompt_tokens(request.prompt_token_ids)
 
         for block_hash in block_hashes:
@@ -216,17 +217,17 @@ def get_computed_blocks(self, request: Request) -> List[int]:
             # in the cached_block_hash_to_id, the following block hashes are
             # not computed yet for sure.
             if cached_block := self._get_cached_block(block_hash):
-                computed_block_ids.append(cached_block.block_id)
+                computed_blocks.append(cached_block)
             else:
                 break
 
-        return computed_block_ids
+        return computed_blocks
 
     def append_slots(
         self,
         request: Request,
         num_tokens: int,
-    ) -> Optional[List[int]]:
+    ) -> Optional[List[KVCacheBlock]]:
         """Append slots to the block table of the request.
         We first append slots to already allocated blocks. If the allocated
         blocks are not enough, we allocate new blocks.
@@ -236,22 +237,22 @@ def append_slots(
             num_tokens: The number of tokens to append.
 
         Returns:
-            A list of new block IDs if new blocks are allocated, or None
+            A list of new blocks if new blocks are allocated, or None
             if new blocks are required but cannot be allocated.
         """
         num_required_blocks = cdiv(request.num_computed_tokens + num_tokens,
                                    self.block_size)
-        req_block_ids = self.req_to_block_ids[request.request_id]
+        req_blocks = self.req_to_blocks[request.request_id]
 
-        num_new_blocks = num_required_blocks - len(req_block_ids)
+        num_new_blocks = num_required_blocks - len(req_blocks)
         if num_new_blocks > self.free_block_queue.num_free_blocks:
             # Need to allocate new blocks due to insufficient pre-allocated
             # slots, but we cannot allocate new blocks due to the limit.
             return None
 
-        # Assign token IDs to already allocated blocks.
+        # When caching is enabled, assign token IDs to already allocated
+        # blocks.
         new_token_ids = None
-        parent_block_id = None
+        parent_block = None
         if self.enable_caching:
             # Figure out the token IDs to add to the blocks.
             if request.num_computed_tokens < request.num_prompt_tokens:
@@ -269,25 +270,25 @@ def append_slots(
 
         # Find the last full block index.
         # TODO: This may be optimized by calculating the computed tokens.
-        last_full_block_idx = len(req_block_ids) - 1
-        while (last_full_block_idx >= 0 and self.block_pool[
-            req_block_ids[last_full_block_idx]].block_hash is None):
+        last_full_block_idx = len(req_blocks) - 1
+        while (last_full_block_idx >= 0
+               and req_blocks[last_full_block_idx].block_hash is None):
             last_full_block_idx -= 1
-        parent_block_id = (last_full_block_idx
-                           if last_full_block_idx >= 0 else None)
+        parent_block = (req_blocks[last_full_block_idx]
+                        if last_full_block_idx >= 0 else None)
 
         token_id_idx = self._add_token_ids_to_blocks(
-            block_ids=req_block_ids[last_full_block_idx + 1:],
+            blocks=req_blocks[last_full_block_idx + 1:],
             token_ids=new_token_ids,
-            parent_block_id=parent_block_id)
+            parent_block=parent_block)
 
         new_token_ids = new_token_ids[token_id_idx:]
-        parent_block_id = req_block_ids[-1]
+        parent_block = req_blocks[-1]
 
         # No new block is needed. When caching is enabled, we make sure
         # token_id_idx is equal to len(new_token_ids), meaning that all tokens
         # are added to allocated blocks.
-        if num_required_blocks <= len(req_block_ids):
+        if num_required_blocks <= len(req_blocks):
             assert not self.enable_caching or token_id_idx == num_tokens, \
                 f"{token_id_idx=} != {num_tokens=}"
             return []
@@ -297,27 +298,26 @@ def append_slots(
         num_new_blocks = min(num_new_blocks + self.num_preallocate_blocks,
                              self.free_block_queue.num_free_blocks)
         new_blocks = self._get_new_blocks(num_new_blocks, new_token_ids,
-                                          parent_block_id)
-        new_block_ids = [blk.block_id for blk in new_blocks]
-        req_block_ids.extend(new_block_ids)
-        return new_block_ids
+                                          parent_block)
+        req_blocks.extend(new_blocks)
+        return new_blocks
 
     def allocate_slots(
         self,
         request: Request,
         num_tokens: int,
-        computed_block_ids: List[int],
-    ) -> Optional[List[int]]:
+        computed_blocks: List[KVCacheBlock],
+    ) -> Optional[List[KVCacheBlock]]:
         """Allocate slots for a new request.
 
         Args:
             request: The request to allocate slots.
             num_tokens: The number of tokens to allocate. Note that this does
                 not include the tokens that have already been computed.
-            computed_block_ids: The block IDs that have already been computed.
+            computed_blocks: The blocks that have already been computed.
 
         Returns:
-            A list of new allocated block IDs.
+            A list of newly allocated blocks.
         """
         if num_tokens == 0:
             raise ValueError(
@@ -326,10 +326,8 @@ def allocate_slots(
         # If a computed block of a request is an eviction candidate (in the
         # free queue and ref_cnt == 0), it cannot be counted as a free block
         # when allocating this request.
-        num_evictable_computed_blocks = len([
-            bid for bid in computed_block_ids
-            if self.block_pool[bid].ref_cnt == 0
-        ])
+        num_evictable_computed_blocks = len(
+            [blk for blk in computed_blocks if blk.ref_cnt == 0])
         num_required_blocks = cdiv(num_tokens, self.block_size)
         if (num_required_blocks > self.free_block_queue.num_free_blocks -
@@ -343,11 +341,21 @@ def allocate_slots(
             num_required_blocks + self.num_preallocate_blocks,
             self.free_block_queue.num_free_blocks -
             num_evictable_computed_blocks)
-        # Get the token IDs for the blocks being allocated for hashing.
-        # Note that we expect this function to be called only once per
-        # request, so we must have all new token IDs in the prompt.
-        num_computed_tokens = len(computed_block_ids) * self.block_size
+
+        num_computed_tokens = len(computed_blocks) * self.block_size
+
+        # When caching is enabled, get the new token IDs and the parent
+        # block to generate cache keys.
+        new_token_ids = None
+        parent_block = None
         if self.enable_caching:
+            # Touch the computed blocks to make sure they won't be evicted.
+            self._touch(computed_blocks)
+
+            # Get the token IDs for the blocks being allocated for hashing.
+            # Note that we expect allocate_slots to be called only once per
+            # new request, so num_computed_tokens + num_tokens must be less
+            # than or equal to the total number of tokens in the prompt.
             new_token_ids = request.prompt_token_ids[
                 num_computed_tokens:num_computed_tokens + num_tokens]
             if not new_token_ids:
@@ -356,23 +364,15 @@ def allocate_slots(
                     f"#prompt_tokens={len(request.prompt_token_ids)} < "
                     f"#computed_tokens={num_computed_tokens}")
 
-            # Touch the computed blocks to make sure they won't be evicted.
-            self._touch(computed_block_ids)
-            # Get the parent block ID to construct the block chain.
-            parent_block_id = computed_block_ids[
-                -1] if computed_block_ids else None
-        else:
-            new_token_ids = None
-            parent_block_id = None
+            parent_block = computed_blocks[-1] if computed_blocks else None
+
         new_blocks = self._get_new_blocks(num_new_blocks, new_token_ids,
-                                          parent_block_id)
-        new_block_ids = [blk.block_id for blk in new_blocks]
+                                          parent_block)
 
-        # Concatenate the computed block IDs and the new block IDs.
-        block_ids = computed_block_ids + new_block_ids
-        self.req_to_block_ids[request.request_id] = block_ids
-        return new_block_ids
+        # Concatenate the computed blocks and the new blocks.
+        self.req_to_blocks[request.request_id] = computed_blocks + new_blocks
+        return new_blocks
 
     def free(self, request: Request) -> None:
         """Free the blocks allocated for the request.
@@ -382,22 +382,22 @@ def free(self, request: Request) -> None:
         Args:
             request: The request to free the blocks.
         """
-        block_ids = self.req_to_block_ids.pop(request.request_id)
+        blocks = self.req_to_blocks.pop(request.request_id)
         if self.enable_caching:
             # Free blocks in reverse order so that the tail blocks are
             # freed first.
-            block_ids = reversed(block_ids)
+            blocks = reversed(blocks)
 
-        for block_id in block_ids:
-            self.block_pool[block_id].ref_cnt -= 1
-            if self.block_pool[block_id].ref_cnt == 0:
-                self.free_block_queue.append(self.block_pool[block_id])
+        for block in blocks:
+            block.ref_cnt -= 1
+            if block.ref_cnt == 0:
+                self.free_block_queue.append(block)
 
     def _get_new_blocks(
             self,
             num_blocks: int,
             token_ids: Optional[List[int]] = None,
-            parent_block_id: Optional[int] = None) -> List[KVCacheBlock]:
+            parent_block: Optional[KVCacheBlock] = None) -> List[KVCacheBlock]:
         """Get new blocks from the free block pool, and add token IDs to
         allocated blocks if caching is enabled.
         Note that we do not check block cache in this function.
 
         Args:
             num_blocks: The number of blocks to allocate.
             token_ids: The token IDs in the blocks. None if caching is disabled.
-            parent_block_id: The parent block ID. Used to include block chain
+            parent_block: The parent block. Used to include the block chain
                 in the block hash.
         Returns:
             A list of new blocks.
         """
@@ -442,9 +442,7 @@ def _get_new_blocks(
         if self.enable_caching:
             assert token_ids is not None
             token_id_idx = self._add_token_ids_to_blocks(
-                block_ids=[blk.block_id for blk in ret],
-                token_ids=token_ids,
-                parent_block_id=parent_block_id)
+                blocks=ret, token_ids=token_ids, parent_block=parent_block)
             assert token_id_idx == len(token_ids)
 
         return ret
@@ -483,51 +481,46 @@ def _get_cached_block(self, block_hash: int) -> Optional[KVCacheBlock]:
             return self.cached_block_hash_to_block[block_hash][first_block_id]
         return None
 
-    def _touch(self, block_ids: List[int]) -> None:
+    def _touch(self, blocks: List[KVCacheBlock]) -> None:
         """Touch a block increases its reference count by 1, and may remove
         the block from the free queue. This is used when a block is hit by
         another request with the same prefix.
 
         Args:
-            block_id: The ID of the block to touch.
+            blocks: A list of blocks to touch.
         """
-        for block_id in block_ids:
-            curr_block = self.block_pool[block_id]
+        for block in blocks:
             # ref_cnt=0 means this block is in the free list (i.e. eviction
             # candidate), so remove it.
-            if curr_block.ref_cnt == 0:
-                self.free_block_queue.remove(curr_block)
-            curr_block.ref_cnt += 1
-
-    def _add_token_ids_to_blocks(self,
-                                 block_ids: List[int],
-                                 token_ids: List[int],
-                                 parent_block_id: Optional[int] = None) -> int:
+            if block.ref_cnt == 0:
+                self.free_block_queue.remove(block)
+            block.ref_cnt += 1
+
+    def _add_token_ids_to_blocks(
+            self,
+            blocks: List[KVCacheBlock],
+            token_ids: List[int],
+            parent_block: Optional[KVCacheBlock] = None) -> int:
         """Add token IDs to a list of allocated blocks.
         If a block becomes full after adding token IDs, cache it.
         Return the token ID index that has not been added to the blocks
         if the blocks are not enough to hold all the token IDs.
 
         Args:
-            block_ids: A list of block IDs to add token IDs.
+            blocks: A list of blocks to add token IDs to.
             token_ids: A list of token IDs to add.
-            parent_block_id: The parent block ID. None if this is the
+            parent_block: The parent block. None if this is the
                 first block.
 
         Returns:
             The starting token ID index that has not been added to the
             blocks due to insufficient given blocks.
         """
-        parent_block = self.block_pool[
-            parent_block_id] if parent_block_id is not None else None
         token_id_start = 0
-        for block_id in block_ids:
-            curr_block = self.block_pool[block_id]
-            curr_block.parent_block_id = parent_block_id
-
+        for curr_block in blocks:
             # If all token IDs are added, then the rest of the blocks are
-            # preallocated blocks, so we only need to update the
-            # parent_block_id.
+            # preallocated blocks, so we can simply skip them.
             if token_id_start == len(token_ids):
                 continue
 
@@ -539,7 +532,6 @@ def _add_token_ids_to_blocks(self,
             if len(curr_block.token_ids) == self.block_size:
                 self._cache_full_block(curr_block, parent_block)
                 parent_block = curr_block
-                parent_block_id = parent_block.block_id
 
             token_id_start = token_id_end
         return token_id_start
@@ -567,7 +559,6 @@ def hash_prompt_tokens(self, token_ids: List[int]) -> List[int]:
     return ret
 
-@lru_cache(maxsize=1024)
 def hash_block_tokens(parent_block_hash: Optional[int],
                       cur_block_token_ids: Tuple[int]) -> int:
     """Computes a hash value corresponding to the contents of a block and
diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py
index 6752f85cff18a..a60f8b8138ecf 100644
--- a/vllm/v1/core/scheduler.py
+++ b/vllm/v1/core/scheduler.py
@@ -91,9 +91,9 @@ def schedule(self) -> "SchedulerOutput":
             assert num_new_tokens > 0
 
             while True:
-                new_block_ids = self.kv_cache_manager.append_slots(
+                new_blocks = self.kv_cache_manager.append_slots(
                     request, num_new_tokens)
-                if new_block_ids is None:
+                if new_blocks is None:
                     # The request cannot be scheduled.
                     # Preempt the lowest-priority request.
                     preempted_req = self.running.pop()
@@ -110,7 +110,9 @@ def schedule(self) -> "SchedulerOutput":
 
                 # The request can be scheduled.
                 scheduled_running_reqs.append(request)
-                req_to_new_block_ids[request.request_id] = new_block_ids
+                req_to_new_block_ids[request.request_id] = [
+                    b.block_id for b in new_blocks
+                ]
                 num_scheduled_tokens[request.request_id] = num_new_tokens
                 token_budget -= num_new_tokens
                 req_index += 1
@@ -126,12 +128,12 @@ def schedule(self) -> "SchedulerOutput":
                 request = self.waiting[0]
                 # Get already-cached tokens.
-                computed_block_ids = self.kv_cache_manager.get_computed_blocks(
+                computed_blocks = self.kv_cache_manager.get_computed_blocks(
                     request)
                 # NOTE(woosuk): Since incomplete blocks are not eligible for
                 # sharing, `num_computed_tokens` is always a multiple of
                 # `block_size`.
-                num_computed_tokens = len(computed_block_ids) * self.block_size
+                num_computed_tokens = len(computed_blocks) * self.block_size
                 # Number of tokens to be scheduled.
                 # We use `request.num_tokens` instead of
                 # `request.num_prompt_tokens` to consider the resumed requests,
@@ -143,12 +145,12 @@ def schedule(self) -> "SchedulerOutput":
                     # the last token.
                     num_computed_tokens -= 1
                     num_new_tokens = 1
-                    computed_block_ids.pop()
+                    computed_blocks.pop()
                 num_new_tokens = min(num_new_tokens, token_budget)
                 assert num_new_tokens > 0
-                new_block_ids = self.kv_cache_manager.allocate_slots(
-                    request, num_new_tokens, computed_block_ids)
-                if new_block_ids is None:
+                new_blocks = self.kv_cache_manager.allocate_slots(
+                    request, num_new_tokens, computed_blocks)
+                if new_blocks is None:
                     # The request cannot be scheduled.
                     break
                 request.num_computed_tokens = num_computed_tokens
@@ -163,8 +165,9 @@ def schedule(self) -> "SchedulerOutput":
                     raise RuntimeError(
                         f"Invalid request status: {request.status}")
 
-                req_to_new_block_ids[request.request_id] = (
-                    computed_block_ids + new_block_ids)
+                req_to_new_block_ids[request.request_id] = [
+                    b.block_id for b in computed_blocks + new_blocks
+                ]
                 num_scheduled_tokens[request.request_id] = num_new_tokens
                 token_budget -= num_new_tokens
                 request.status = RequestStatus.RUNNING
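
Side note for reviewers: the reason `parent_block_id` can be dropped from
`KVCacheBlock` is that a block's identity for prefix caching is carried
entirely by its chained hash. Below is a minimal standalone sketch of that
chaining (not part of the patch): the helper names mirror
`hash_block_tokens` / `hash_prompt_tokens` from this diff, but the toy
BLOCK_SIZE constant and the exact hash mix are illustrative assumptions,
not vLLM's implementation.

    # Minimal sketch of chained block hashing (assumptions noted above).
    from typing import List, Optional, Tuple

    BLOCK_SIZE = 16  # Same block size as the tests above; assumed here.

    def hash_block_tokens(parent_block_hash: Optional[int],
                          cur_block_token_ids: Tuple[int, ...]) -> int:
        # Mixing the parent hash into the key gives two blocks with the
        # same token IDs but different prefixes different cache keys.
        return hash((parent_block_hash, ) + cur_block_token_ids)

    def hash_prompt_tokens(token_ids: List[int]) -> List[int]:
        # One hash per *full* block; a trailing partial block is never
        # hashed, which is why computed (cached) blocks must be full.
        ret: List[int] = []
        parent_hash: Optional[int] = None
        for start in range(0, len(token_ids), BLOCK_SIZE):
            block = tuple(token_ids[start:start + BLOCK_SIZE])
            if len(block) < BLOCK_SIZE:
                break
            parent_hash = hash_block_tokens(parent_hash, block)
            ret.append(parent_hash)
        return ret

    # Two prompts sharing a two-block prefix produce identical leading
    # hashes, so a lookup can walk the chain and stop at the first miss,
    # exactly like the loop in get_computed_blocks.
    a = hash_prompt_tokens([0] * 16 + [1] * 16 + [2] * 16)
    b = hash_prompt_tokens([0] * 16 + [1] * 16 + [9] * 16)
    assert a[:2] == b[:2] and a[2] != b[2]

Because the chain lives in the hash itself, the tests can recompute each
full block's expected hash directly, and no per-block parent pointer is
needed to validate a prefix.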