From 0f8a54cf7b7280c5145536bc6f61f09b046c0751 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Wed, 15 Jan 2025 01:10:51 -0800 Subject: [PATCH 01/48] can run Signed-off-by: Chen Zhang --- vllm/attention/layer.py | 4 + vllm/forward_context.py | 20 ++- vllm/v1/core/kv_cache_manager.py | 49 ++++-- vllm/v1/core/kv_cache_utils.py | 12 +- vllm/v1/core/scheduler.py | 38 ++-- vllm/v1/request.py | 20 ++- vllm/v1/worker/block_table.py | 42 +++-- vllm/v1/worker/gpu_input_batch.py | 4 +- vllm/v1/worker/gpu_model_runner.py | 268 ++++++++++++++++------------- 9 files changed, 273 insertions(+), 184 deletions(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index a283e87d84070..365936426c12d 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -241,6 +241,8 @@ def unified_attention( attn_metadata = forward_context.attn_metadata self = forward_context.attn_layers[layer_name] kv_cache = self.kv_cache[forward_context.virtual_engine] + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[layer_name] return self.impl.forward(query, key, value, kv_cache, attn_metadata, self._k_scale, self._v_scale) @@ -274,6 +276,8 @@ def unified_attention_with_output( attn_metadata = forward_context.attn_metadata self = forward_context.attn_layers[layer_name] kv_cache = self.kv_cache[forward_context.virtual_engine] + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[layer_name] self.impl.forward(query, key, value, diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 828b394ec5d21..c34694790b8d9 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -2,7 +2,7 @@ from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, Optional +from typing import TYPE_CHECKING, Any, Dict, Optional, Union import torch @@ -24,11 +24,21 @@ @dataclass class ForwardContext: - # copy from vllm_config.compilation_config.static_forward_context + """ + Map from layer_name to all attention modules + copy from vllm_config.compilation_config.static_forward_context + """ attn_layers: Dict[str, Any] - # TODO: extend to support per-layer dynamic forward context - attn_metadata: "AttentionMetadata" # set dynamically for each forward pass - # TODO: remove after making all virtual_engines share the same kv cache + """ + Type AttentionMetadata for v0, + Type Dict[str, AttentionMetadata] for v1, mapping from layer_name to + AttentionMetadata of that layer + set dynamically for each forward pass + """ + attn_metadata: Union["AttentionMetadata", Dict[str, "AttentionMetadata"]] + """ + The virtual_engine for v0 pipeline parallelism + """ virtual_engine: int # set dynamically for each forward pass diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index bac77443c8560..85e7946edab5e 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -4,7 +4,8 @@ from vllm.logger import init_logger from vllm.utils import cdiv from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue, - KVCacheBlock, + KVCacheBlock, KVCacheBlocks, + ReqKVCacheBlocks, generate_block_hash_extra_keys, hash_block_tokens, hash_request_tokens) @@ -66,11 +67,12 @@ def __init__( # Mapping from request ID to blocks to track the blocks allocated # for each request, so that we can free the blocks when the request - # is finished. - self.req_to_blocks: Dict[str, List[KVCacheBlock]] = {} + # is finished. 
KVCacheManager only supports models with one layer type, + # so the blocks can be stored by KVCacheBlocks type. + self.req_to_blocks: Dict[str, KVCacheBlocks] = {} - def get_computed_blocks( - self, request: Request) -> Tuple[List[KVCacheBlock], int]: + def get_computed_blocks(self, + request: Request) -> Tuple[ReqKVCacheBlocks, int]: """Get the computed (cached) blocks for the request. Note that the computed blocks must be full. @@ -78,13 +80,14 @@ def get_computed_blocks( request: The request to get the computed blocks. Returns: + # TODO: update docstring A tuple containing: - A list of blocks that are computed for the request. - The number of computed tokens. """ if not self.enable_caching: # Prefix caching is disabled. - return [], 0 + return [[]], 0 computed_blocks = [] @@ -92,8 +95,8 @@ def get_computed_blocks( # if the request was preempted and resumed. if not request.kv_block_hashes: request.set_kv_block_hashes( - hash_request_tokens(self.block_size, request)) - block_hashes = request.kv_block_hashes + [hash_request_tokens(self.block_size, request)]) + block_hashes = request.kv_block_hashes[0] for block_hash in block_hashes: # block_hashes is a chain of block hashes. If a block hash is not @@ -108,13 +111,13 @@ def get_computed_blocks( # sharing, `num_computed_tokens` is always a multiple of # `block_size`. num_computed_tokens = len(computed_blocks) * self.block_size - return computed_blocks, num_computed_tokens + return [computed_blocks], num_computed_tokens def append_slots( self, request: Request, num_tokens: int, - ) -> Optional[List[KVCacheBlock]]: + ) -> Optional[ReqKVCacheBlocks]: """Append slots to the block table of the request. We first append slots to already allocated blocks. If the allocated blocks are not enough, we allocate new blocks. @@ -126,6 +129,7 @@ def append_slots( Returns: A list of new blocks if new blocks are allocated, or None if new blocks are required but cannot be allocated. + # TODO: update docstring """ num_required_blocks = cdiv(request.num_computed_tokens + num_tokens, self.block_size) @@ -159,7 +163,7 @@ def append_slots( req_blocks.extend(new_blocks) if not self.enable_caching: - return new_blocks + return [new_blocks] num_computed_full_blocks = (request.num_computed_tokens // self.block_size) @@ -182,16 +186,17 @@ def append_slots( full_blocks=new_full_blocks, prev_block=req_blocks[num_computed_full_blocks - 1] if num_computed_full_blocks >= 1 else None, + kv_cache_group_id=0, ) - return new_blocks + return [new_blocks] def allocate_slots( self, request: Request, num_tokens: int, - computed_blocks: List[KVCacheBlock], - ) -> Optional[List[KVCacheBlock]]: + computed_blocks_of_groups: ReqKVCacheBlocks, + ) -> Optional[ReqKVCacheBlocks]: """Allocate slots for a new request. Args: @@ -201,12 +206,14 @@ def allocate_slots( computed_blocks: A list of computed blocks. Returns: + # TODO: update docstring A list of new allocated blocks. """ if num_tokens == 0: raise ValueError( f"num_tokens must be greater than 0, got {num_tokens}") + computed_blocks = computed_blocks_of_groups[0] # only one group # If a computed block of a request is an eviction candidate (in the # free queue and ref_cnt == 0), it cannot be counted as a free block # when allocating this request. @@ -260,9 +267,10 @@ def allocate_slots( # The new full blocks are the full blocks that are not computed. 
full_blocks=new_full_blocks, prev_block=computed_blocks[-1] if computed_blocks else None, + kv_cache_group_id=0, ) - return new_blocks + return [new_blocks] def free(self, request: Request) -> None: """Free the blocks allocated for the request. @@ -333,7 +341,7 @@ def get_num_common_prefix_blocks( num_common_blocks += 1 else: break - return num_common_blocks + return [num_common_blocks] def _get_new_blocks(self, num_blocks: int) -> List[KVCacheBlock]: """Get new blocks from the free block pool. @@ -421,6 +429,7 @@ def _cache_full_blocks( blk_start_idx: int, full_blocks: List[KVCacheBlock], prev_block: Optional[KVCacheBlock], + kv_cache_group_id: int, ) -> None: """Cache a list of full blocks for prefix caching. @@ -437,7 +446,8 @@ def _cache_full_blocks( full_blocks: The list of blocks to update hash metadata. prev_block: The previous block in the chain. """ - num_cached_block_hashes = len(request.kv_block_hashes) + num_cached_block_hashes = len( + request.kv_block_hashes[kv_cache_group_id]) # Update the new blocks with the block hashes through the chain. prev_block_hash_value = None @@ -456,7 +466,8 @@ def _cache_full_blocks( # this request (either the prompt tokens or the previously # generated tokens with preemption). In this case we simply # reuse the block hash. - block_hash = request.kv_block_hashes[blk_idx] + block_hash = request.kv_block_hashes[kv_cache_group_id][ + blk_idx] else: # Otherwise compute the block hash and cache it in the request # in case it will be preempted in the future. @@ -478,7 +489,7 @@ def _cache_full_blocks( # Compute the hash of the current block. block_hash = hash_block_tokens(prev_block_hash_value, block_tokens, extra_keys) - request.append_kv_block_hashes(block_hash) + request.append_kv_block_hashes(kv_cache_group_id, block_hash) # Update and added the full block to the cache. blk.block_hash = block_hash diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 22a5d2fb08a48..a9cd1d2255c96 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -2,7 +2,6 @@ from collections.abc import Sequence from dataclasses import dataclass from typing import Any, List, NamedTuple, Optional, Tuple - from vllm.logger import init_logger from vllm.v1.request import Request @@ -61,6 +60,17 @@ def reset_hash(self): self._block_hash = None +"""When the model contains different types of layers (e.g., full attention + +sliding window attention), the layers will be splited to multiple groups, where +layers in the same group has the same type and with the same KVCacheBlock. +KVCacheBlocks: the blocks in one (group) of layer in one request +ReqKVCacheBlocks: the blocks in all groups of layers in one request. +Refer to KVCacheConfig class for the meaning of "group" +""" +KVCacheBlocks = List[KVCacheBlock] +ReqKVCacheBlocks = List[KVCacheBlocks] + + class FreeKVCacheBlockQueue: """This class organizes a list of KVCacheBlock objects to a doubly linked list of free blocks. We implement this class instead of using Python diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 45e67c94f8f15..574d128d396a4 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -95,7 +95,7 @@ def schedule(self) -> "SchedulerOutput": scheduled_running_reqs: List[Request] = [] preempted_reqs: List[Request] = [] - req_to_new_block_ids: Dict[str, List[int]] = {} + req_to_new_block_ids: Dict[str, List[List[int]]] = {} num_scheduled_tokens: Dict[str, int] = {} token_budget = self.max_num_scheduled_tokens # Encoder-related. 
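
As an illustration of the nested per-group structures introduced above (the KVCacheBlocks / ReqKVCacheBlocks aliases in kv_cache_utils.py and the scheduler's req_to_new_block_ids becoming Dict[str, List[List[int]]]), here is a minimal sketch assuming a single KV cache group; the request ID and block IDs are made-up placeholders, not part of the patch:

    from vllm.v1.core.kv_cache_utils import KVCacheBlock

    # ReqKVCacheBlocks: one inner list of KVCacheBlock per KV cache group.
    blocks_of_req = [
        [KVCacheBlock(block_id=0), KVCacheBlock(block_id=1)],  # group 0
    ]

    # The scheduler mirrors this shape with plain block IDs:
    # Dict[str, List[List[int]]], one inner list per group.
    req_to_new_block_ids = {
        "req-0": [[b.block_id for b in group] for group in blocks_of_req],
    }
    assert req_to_new_block_ids["req-0"] == [[0, 1]]
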
@@ -154,9 +154,9 @@ def schedule(self) -> "SchedulerOutput": # Schedule the request. scheduled_running_reqs.append(request) - req_to_new_block_ids[request.request_id] = [ - b.block_id for b in new_blocks - ] + req_to_new_block_ids[request.request_id] = [[ + b.block_id for b in new_blocks_of_group + ] for new_blocks_of_group in new_blocks] num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens req_index += 1 @@ -230,9 +230,11 @@ def schedule(self) -> "SchedulerOutput": raise RuntimeError( f"Invalid request status: {request.status}") - req_to_new_block_ids[request.request_id] = [ - b.block_id for b in computed_blocks + new_blocks - ] + req_to_new_block_ids[request.request_id] = [[ + b.block_id + for b in computed_blocks_of_group + new_blocks_of_group + ] for computed_blocks_of_group, new_blocks_of_group in zip( + computed_blocks, new_blocks)] num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens request.status = RequestStatus.RUNNING @@ -536,14 +538,16 @@ class NewRequestData: mm_hashes: List[str] mm_positions: List["PlaceholderRange"] sampling_params: SamplingParams - block_ids: List[int] + # List of block IDs for each group. + # Refer to KVCacheConfig class for the meaning of "group" + block_ids: List[List[int]] num_computed_tokens: int @classmethod def from_request( cls, request: Request, - block_ids: List[int], + block_ids: List[List[int]], num_computed_tokens: int, ) -> "NewRequestData": return cls( @@ -563,14 +567,16 @@ def from_request( class ResumedRequestData: req_id: str - block_ids: List[int] + # List of block IDs for each kv cache group. + # Refer to KVCacheConfig class for the meaning of "group" + block_ids: List[List[int]] num_computed_tokens: int @classmethod def from_request( cls, request: Request, - block_ids: List[int], + block_ids: List[List[int]], num_computed_tokens: int, ) -> "ResumedRequestData": return cls( @@ -584,14 +590,16 @@ def from_request( class RunningRequestData: req_id: str - new_block_ids: List[int] + # List of block IDs for each kv cache group. + # Refer to KVCacheConfig class for the meaning of "group" + new_block_ids: List[List[int]] num_computed_tokens: int @classmethod def from_request( cls, request: Request, - new_block_ids: List[int], + new_block_ids: List[List[int]], num_computed_tokens: int, ) -> "RunningRequestData": return cls( @@ -611,7 +619,9 @@ class SchedulerOutput: num_scheduled_tokens: Dict[str, int] total_num_scheduled_tokens: int scheduled_encoder_inputs: Dict[str, List[int]] - num_common_prefix_blocks: int + # Number of common prefix blocks per kv cache group + # Refer to KVCacheConfig class for the meaning of "group" + num_common_prefix_blocks: List[int] preempted_req_ids: Set[str] finished_req_ids: Set[str] diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 45450165eaefe..097b2ffe7c831 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -62,7 +62,11 @@ def __init__( # Cache the computed kv block hashes of the request to avoid # recomputing. - self._kv_block_hashes: List[BlockHashType] = [] + # Different kv cache groups may have different block_size, so save their + # hash seperately. Each outer list represents a group, and each inner + # list contains the hashes of blocks with that group's block_size. + # Refer to KVCacheConfig class for the meaning of "group". 
+ self._kv_block_hashes: List[List[BlockHashType]] = [] @classmethod def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": @@ -127,15 +131,19 @@ def get_num_encoder_tokens(self, input_id: int) -> int: return num_tokens @property - def kv_block_hashes(self) -> ConstantList["BlockHashType"]: + def kv_block_hashes(self) -> List[ConstantList["BlockHashType"]]: # Prevent directly appending to the kv_block_hashes. - return ConstantList(self._kv_block_hashes) + return ConstantList([ + ConstantList(kv_block_hashes_one_group) + for kv_block_hashes_one_group in self._kv_block_hashes + ]) - def set_kv_block_hashes(self, value: List["BlockHashType"]) -> None: + def set_kv_block_hashes(self, value: List[List["BlockHashType"]]) -> None: self._kv_block_hashes = value - def append_kv_block_hashes(self, block_hash: "BlockHashType") -> None: - self._kv_block_hashes.append(block_hash) + def append_kv_block_hashes(self, group_id: int, + block_hash: "BlockHashType") -> None: + self._kv_block_hashes[group_id].append(block_hash) class RequestStatus(enum.IntEnum): diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 26a2084b131fa..31b9efb361166 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -17,49 +17,59 @@ def __init__( max_num_blocks_per_req: int, pin_memory: bool, device: torch.device, + num_kv_cache_groups: int, ): self.max_num_reqs = max_num_reqs self.max_model_len = max_model_len self.max_num_blocks_per_req = max_num_blocks_per_req self.pin_memory = pin_memory self.device = device + self.num_kv_cache_groups = num_kv_cache_groups self.block_table = torch.zeros( - (max_num_reqs, max_num_blocks_per_req), + (num_kv_cache_groups, max_num_reqs, max_num_blocks_per_req), device=self.device, dtype=torch.int32, ) self.block_table_cpu = torch.zeros( - (max_num_reqs, max_num_blocks_per_req), + (num_kv_cache_groups, max_num_reqs, max_num_blocks_per_req), device="cpu", dtype=torch.int32, pin_memory=pin_memory, ) self.block_table_np = self.block_table_cpu.numpy() - self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32) + self.num_blocks_per_row = np.zeros((num_kv_cache_groups, max_num_reqs), + dtype=np.int32) def append_row( self, row_idx: int, - start: int, - block_ids: List[int], + block_ids: List[List[int]], ) -> None: - num_blocks = len(block_ids) - self.block_table_np[row_idx, start:start + num_blocks] = block_ids - self.num_blocks_per_row[row_idx] = start + num_blocks + for i, (num_blocks, block_ids_of_group) in enumerate( + zip(self.num_blocks_per_row[:, row_idx], block_ids)): + num_new_blocks = len(block_ids_of_group) + self.block_table_np[i, row_idx, num_blocks:num_blocks + + num_new_blocks] = block_ids_of_group + self.num_blocks_per_row[i, row_idx] = num_blocks + num_new_blocks - def add_row(self, row_idx: int, block_ids: List[int]) -> None: - self.append_row(row_idx, 0, block_ids) + def add_row(self, row_idx: int, block_ids: List[List[int]]) -> None: + self.append_row(row_idx, block_ids) def move_row(self, src: int, tgt: int) -> None: - num_blocks = self.num_blocks_per_row[src] - self.block_table_np[tgt, :num_blocks] = self.block_table_np[ - src, :num_blocks] - self.num_blocks_per_row[tgt] = num_blocks + num_blocks = self.num_blocks_per_row[:, src] + self.block_table_np[:, tgt, :max(num_blocks)] = \ + self.block_table_np[:, src, :max(num_blocks)] + self.num_blocks_per_row[:, tgt] = num_blocks def commit(self, num_reqs: int) -> None: - self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs], - 
non_blocking=True) + # NOTE: an alternative is + # self.block_table[:, :num_reqs].copy_( + # self.block_table_cpu[:, :num_reqs], non_blocking=True) + # but it will be a blocking copy when num_kv_cache_groups>1. + for i in range(self.num_kv_cache_groups): + self.block_table[i, :num_reqs].copy_( + self.block_table_cpu[i, :num_reqs], non_blocking=True) def clear(self) -> None: self.block_table.fill_(0) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 40494e64b22f0..4a74bd76841b4 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -26,7 +26,7 @@ class CachedRequestState: sampling_params: SamplingParams generator: Optional[torch.Generator] - block_ids: List[int] + block_ids: List[List[int]] num_computed_tokens: int output_token_ids: List[int] @@ -45,6 +45,7 @@ def __init__( device: torch.device, pin_memory: bool, vocab_size: int, + num_kv_cache_groups: int, ): self.max_num_reqs = max_num_reqs self.max_model_len = max_model_len @@ -78,6 +79,7 @@ def __init__( max_num_blocks_per_req=max_num_blocks_per_req, pin_memory=pin_memory, device=device, + num_kv_cache_groups=num_kv_cache_groups, ) # Sampling-related. diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index fb87dc5a8222a..60017e41ed649 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -93,7 +93,7 @@ def __init__( # Lazy initialization # self.model: nn.Module # Set after load_model - self.kv_caches: List[torch.Tensor] = [] + self.kv_caches: List[List[torch.Tensor]] = [] # req_id -> (input_id -> encoder_output) self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {} @@ -107,6 +107,7 @@ def __init__( device=self.device, pin_memory=self.pin_memory, vocab_size=model_config.get_vocab_size(), + num_kv_cache_groups=1, # TODO: update after PR #11960 ) self.use_cuda_graph = (self.vllm_config.compilation_config.level @@ -152,10 +153,12 @@ def __init__( device="cpu", pin_memory=self.pin_memory) self.positions_np = self.positions_cpu.numpy() - self.slot_mapping_cpu = torch.zeros(self.max_num_tokens, - dtype=torch.int32, - device="cpu", - pin_memory=self.pin_memory) + self.slot_mapping_cpu = torch.zeros( + 1, # TODO: update after PR #11960 + self.max_num_tokens, + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) self.slot_mapping_np = self.slot_mapping_cpu.numpy() self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1, dtype=torch.int32, @@ -206,13 +209,13 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: req_data.num_computed_tokens) # Update the block table. - num_new_blocks = len(req_data.new_block_ids) - if num_new_blocks == 0: + max_num_new_blocks = max(len(b) for b in req_data.new_block_ids) + if max_num_new_blocks == 0: continue - start_index = len(req_state.block_ids) - req_state.block_ids.extend(req_data.new_block_ids) - self.input_batch.block_table.append_row(req_index, start_index, + self.input_batch.block_table.append_row(req_index, req_data.new_block_ids) + for group_id, new_block_ids in enumerate(new_block_ids): + req_state.block_ids[group_id].extend(new_block_ids) req_ids_to_add: List[str] = [] # Add new requests to the cached states. @@ -319,22 +322,30 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): out=self.input_ids_cpu[:total_num_scheduled_tokens]) # Calculate the slot mapping. 
- # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] - # where K is the max_num_blocks_per_req and the block size is 2. - # NOTE(woosuk): We can't simply use `token_indices // block_size` here - # because M (max_model_len) is not necessarily divisible by block_size. - block_table_indices = (req_indices * self.max_num_blocks_per_req + - positions_np // self.block_size) - # NOTE(woosuk): We use torch.index_select instead of np.take here - # because torch.index_select is much faster than np.take for large - # tensors. - block_table_cpu = self.input_batch.block_table.get_cpu_tensor() - block_numbers = block_table_cpu.flatten()[block_table_indices].numpy() - block_offsets = positions_np % self.block_size - np.add(block_numbers * self.block_size, - block_offsets, - out=self.slot_mapping_np[:total_num_scheduled_tokens]) + # for i in range(len(self.kv_cache_config.groups)): + for i in range(1): # TODO: update after PR #11960 + # group_spec = self.kv_cache_config.kv_cache_spec[ + # self.kv_cache_config.groups[i][0]] + # block_size = group_spec.block_size + block_size = self.block_size + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] + # where K is the max_num_blocks_per_req and the block size is 2. + # NOTE(woosuk): We can't simply use `token_indices // block_size` + # here because M (max_model_len) is not necessarily divisible by + # block_size. + block_table_indices = (req_indices * self.max_num_blocks_per_req + + positions_np // block_size) + # NOTE(woosuk): We use torch.index_select instead of np.take here + # because torch.index_select is much faster than np.take for large + # tensors. + block_table_cpu = self.input_batch.block_table.get_cpu_tensor() + block_numbers = block_table_cpu[i].flatten()[block_table_indices]\ + .numpy() + block_offsets = positions_np % block_size + np.add(block_numbers * block_size, + block_offsets, + out=self.slot_mapping_np[i, :total_num_scheduled_tokens]) # Prepare the attention metadata. self.query_start_loc_np[0] = 0 @@ -356,105 +367,115 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): self.device, non_blocking=True) seq_start_loc = self.seq_start_loc_cpu[:num_reqs + 1].to( self.device, non_blocking=True) - slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to( - self.device, non_blocking=True).long() - - # Prepare for cascade attention if needed. - common_prefix_len = (scheduler_output.num_common_prefix_blocks * - self.block_size) - if common_prefix_len == 0: - # Common case. - use_cascade = False - else: - # NOTE(woosuk): Cascade attention uses two attention kernels: one - # for the common prefix and the other for the rest. For the first - # kernel, we concatenate all the query tokens (possibly from - # different requests) and treat them as if they are from the same - # request. Then, we use bi-directional attention to process the - # common prefix in the KV cache. Importantly, this means that the - # first kernel does not do any masking. 
- - # Consider the following example: - # Request 1's input query: [D, E, X] - # Request 1's kv cache: [A, B, C, D, E, X] - # Request 1's num_computed_tokens: 3 (i.e., [A, B, C]) - # Request 2's input query: [E, Y] - # Request 2's kv cache: [A, B, C, D, E, Y] - # Request 2's num_computed_tokens: 4 (i.e., [A, B, C, D]) - - # If we use [A, B, C, D, E] as the common prefix, then the - # first kernel will compute the bi-directional attention between - # input query [D, E, X, E, Y] and common prefix [A, B, C, D, E]. - # However, this is wrong because D in Request 1 should not attend to - # E in the common prefix (i.e., we need masking). - # To avoid this, [A, B, C, D] should be the common prefix. - # That is, the common prefix should be capped by the minimum - # num_computed_tokens among the requests, and plus one to include - # the first token of the query. - - # In practice, we use [A, B, C] as the common prefix, instead of - # [A, B, C, D] (i.e., the common prefix is capped by the minimum - # num_computed_tokens, without plus one). - # This is because of an implementation detail: We want to always - # use two kernels for cascade attention. Let's imagine: - # Request 3's input query: [D] - # Request 3's kv cache: [A, B, C, D] - # Request 3's num_computed_tokens: 4 (i.e., [A, B, C, D]) - # If we use [A, B, C, D] as the common prefix for Request 1-3, - # then Request 3 will be processed only by the first kernel, - # and the second kernel will get an empty input. While this is not - # a fundamental problem, our current implementation does not support - # this case. - common_prefix_len = min( - common_prefix_len, - self.input_batch.num_computed_tokens_cpu[:num_reqs].min()) - # common_prefix_len should be a multiple of the block size. - common_prefix_len = (common_prefix_len // self.block_size * + # layer_name -> AttentionMetadata + attn_metadata: Dict[str, FlashAttentionMetadata] = {} + # for i, layer_ids in enumerate(self.kv_cache_config.groups): + for i in range(1): + layer_ids = list(self.vllm_config.compilation_config. + static_forward_context.keys()) + slot_mapping = self.slot_mapping_cpu[ + i, :total_num_scheduled_tokens].to(self.device, + non_blocking=True).long() + + # Prepare for cascade attention if needed. + common_prefix_len = (scheduler_output.num_common_prefix_blocks[i] * self.block_size) - use_cascade = FlashAttentionBackend.use_cascade_attention( + if common_prefix_len == 0: + # Common case. + use_cascade = False + else: + # NOTE(woosuk): Cascade attention uses two attention kernels: + # one for the common prefix and the other for the rest. For the + # first kernel, we concatenate all the query tokens (possibly + # from different requests) and treat them as if they are from + # the same request. Then, we use bi-directional attention to + # process the common prefix in the KV cache. Importantly, this + # means that the first kernel does not do any masking. + + # Consider the following example: + # Request 1's input query: [D, E, X] + # Request 1's kv cache: [A, B, C, D, E, X] + # Request 1's num_computed_tokens: 3 (i.e., [A, B, C]) + # Request 2's input query: [E, Y] + # Request 2's kv cache: [A, B, C, D, E, Y] + # Request 2's num_computed_tokens: 4 (i.e., [A, B, C, D]) + + # If we use [A, B, C, D, E] as the common prefix, then the + # first kernel will compute the bi-directional attention between + # input query [D, E, X, E, Y] and common prefix [A, B, C, D, E]. 
+ # However, this is wrong because D in Request 1 should not + # attend to E in the common prefix (i.e., we need masking). + # To avoid this, [A, B, C, D] should be the common prefix. + # That is, the common prefix should be capped by the minimum + # num_computed_tokens among the requests, and plus one to + # include the first token of the query. + + # In practice, we use [A, B, C] as the common prefix, instead of + # [A, B, C, D] (i.e., the common prefix is capped by the minimum + # num_computed_tokens, without plus one). + # This is because of an implementation detail: We want to always + # use two kernels for cascade attention. Let's imagine: + # Request 3's input query: [D] + # Request 3's kv cache: [A, B, C, D] + # Request 3's num_computed_tokens: 4 (i.e., [A, B, C, D]) + # If we use [A, B, C, D] as the common prefix for Request 1-3, + # then Request 3 will be processed only by the first kernel, + # and the second kernel will get an empty input. While this is + # not a fundamental problem, our current implementation does not + # support this case. + common_prefix_len = min( + common_prefix_len, + self.input_batch.num_computed_tokens_cpu[:num_reqs].min()) + # common_prefix_len should be a multiple of the block size. + common_prefix_len = (common_prefix_len // self.block_size * + self.block_size) + use_cascade = FlashAttentionBackend.use_cascade_attention( + common_prefix_len=common_prefix_len, + query_lens=num_scheduled_tokens, + num_query_heads=self.num_query_heads, + num_kv_heads=self.num_kv_heads, + use_alibi=False, # FIXME + use_sliding_window=self.sliding_window is not None, + num_sms=self.num_sms, + ) + + if use_cascade: + # TODO: Optimize. + cu_prefix_query_lens = torch.tensor( + [0, total_num_scheduled_tokens], + dtype=torch.int32, + device=self.device) + cu_prefix_kv_lens = torch.tensor([0, common_prefix_len], + dtype=torch.int32, + device=self.device) + cu_suffix_kv_lens = ( + self.seq_start_loc_np[:num_reqs + 1] - + self.arange_np[:num_reqs + 1] * common_prefix_len) + cu_suffix_kv_lens = torch.from_numpy(cu_suffix_kv_lens).to( + self.device) + else: + cu_prefix_query_lens = None + cu_prefix_kv_lens = None + cu_suffix_kv_lens = None + + attn_metadata_of_group = FlashAttentionMetadata( + num_actual_tokens=total_num_scheduled_tokens, + max_query_len=max_num_scheduled_tokens, + query_start_loc=query_start_loc, + max_seq_len=max_seq_len, + seq_start_loc=seq_start_loc, + block_table=(self.input_batch.block_table.get_device_tensor()[ + i, :num_reqs]), + slot_mapping=slot_mapping, + use_cascade=use_cascade, common_prefix_len=common_prefix_len, - query_lens=num_scheduled_tokens, - num_query_heads=self.num_query_heads, - num_kv_heads=self.num_kv_heads, - use_alibi=False, # FIXME - use_sliding_window=self.sliding_window is not None, - num_sms=self.num_sms, + cu_prefix_query_lens=cu_prefix_query_lens, + cu_prefix_kv_lens=cu_prefix_kv_lens, + cu_suffix_kv_lens=cu_suffix_kv_lens, ) - if use_cascade: - # TODO: Optimize. 
- cu_prefix_query_lens = torch.tensor( - [0, total_num_scheduled_tokens], - dtype=torch.int32, - device=self.device) - cu_prefix_kv_lens = torch.tensor([0, common_prefix_len], - dtype=torch.int32, - device=self.device) - cu_suffix_kv_lens = ( - self.seq_start_loc_np[:num_reqs + 1] - - self.arange_np[:num_reqs + 1] * common_prefix_len) - cu_suffix_kv_lens = torch.from_numpy(cu_suffix_kv_lens).to( - self.device) - else: - cu_prefix_query_lens = None - cu_prefix_kv_lens = None - cu_suffix_kv_lens = None - - attn_metadata = FlashAttentionMetadata( - num_actual_tokens=total_num_scheduled_tokens, - max_query_len=max_num_scheduled_tokens, - query_start_loc=query_start_loc, - max_seq_len=max_seq_len, - seq_start_loc=seq_start_loc, - block_table=( - self.input_batch.block_table.get_device_tensor()[:num_reqs]), - slot_mapping=slot_mapping, - use_cascade=use_cascade, - common_prefix_len=common_prefix_len, - cu_prefix_query_lens=cu_prefix_query_lens, - cu_prefix_kv_lens=cu_prefix_kv_lens, - cu_suffix_kv_lens=cu_suffix_kv_lens, - ) + for layer_id in layer_ids: + attn_metadata[layer_id] = attn_metadata_of_group # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial # request in the batch. While we should not sample any token from this # partial request, we do so for simplicity. We will ignore the sampled @@ -582,7 +603,10 @@ def execute_model( else: # Eager mode. num_input_tokens = num_scheduled_tokens - attn_metadata.num_input_tokens = num_input_tokens + + # TODO: update after PR #11960 + attn_metadata[next(iter( + attn_metadata.keys()))].num_input_tokens = num_input_tokens if self.is_multimodal_model: # NOTE(woosuk): To unify token ids and soft tokens (vision From 990d0861585a9d25b509251d0203d0f6a721951e Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Wed, 15 Jan 2025 04:54:24 -0800 Subject: [PATCH 02/48] fix tests Signed-off-by: Chen Zhang --- tests/v1/core/test_prefix_caching.py | 127 ++++++++++++------------ tests/v1/worker/test_gpu_input_batch.py | 3 +- vllm/v1/core/kv_cache_manager.py | 6 +- vllm/v1/utils.py | 3 + vllm/v1/worker/gpu_model_runner.py | 2 +- 5 files changed, 74 insertions(+), 67 deletions(-) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index fafd9d0ce4455..c1111045e92d2 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -50,11 +50,11 @@ def test_prefill(): all_token_ids = common_token_ids + unique_token_ids req0 = make_request("0", all_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) - assert len(req0.kv_block_hashes) == 3 - assert not computed_blocks + assert len(req0.kv_block_hashes[0]) == 3 + assert not computed_blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 55, computed_blocks) - assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4] + assert [b.block_id for b in blocks[0]] == [0, 1, 2, 3, 4] # Check full block metadata parent_block_hash = None @@ -75,13 +75,13 @@ def test_prefill(): unique_token_ids = [3] * 5 req1 = make_request("1", common_token_ids + unique_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) - assert len(req1.kv_block_hashes) == 3 - assert [b.block_id for b in computed_blocks] == [0, 1, 2] + assert len(req1.kv_block_hashes[0]) == 3 + assert [b.block_id for b in computed_blocks[0]] == [0, 1, 2] assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks) - assert [b.block_id for b in 
blocks] == [5, 6] - for block in computed_blocks: + assert [b.block_id for b in blocks[0]] == [5, 6] + for block in computed_blocks[0]: assert block.ref_cnt == 2 # At this point, we should have 3 free blocks left. @@ -106,12 +106,12 @@ def test_prefill(): unique_token_ids = [3] * 6 req2 = make_request("2", common_token_ids + unique_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) - assert len(req2.kv_block_hashes) == 3 - assert [b.block_id for b in computed_blocks] == [0, 1, 2] + assert len(req2.kv_block_hashes[0]) == 3 + assert [b.block_id for b in computed_blocks[0]] == [0, 1, 2] assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks) - assert [b.block_id for b in blocks] == [7, 8] + assert [b.block_id for b in blocks[0]] == [7, 8] # Although we only have 5 free blocks, we have 8 blocks in # the free block queue due to lazy removal. @@ -127,11 +127,11 @@ def test_prefill(): # Cache miss and eviction. req3 = make_request("3", [99] * (16 * 9)) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) - assert not computed_blocks + assert not computed_blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req3, 16 * 9, computed_blocks) # This block ID order also checks the eviction order. - assert [b.block_id for b in blocks] == [9, 4, 3, 6, 5, 8, 7, 2, 1, 0] + assert [b.block_id for b in blocks[0]] == [9, 4, 3, 6, 5, 8, 7, 2, 1, 0] assert manager.free_block_queue.num_free_blocks == 0 assert manager.free_block_queue.free_list_head is None assert manager.free_block_queue.free_list_tail is None @@ -155,17 +155,17 @@ def test_decode(): unique_token_ids = [3] * 7 req0 = make_request("0", common_token_ids + unique_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) - assert not computed_blocks + assert not computed_blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 55, computed_blocks) - assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4] + assert [b.block_id for b in blocks[0]] == [0, 1, 2, 3, 4] # Append slots without allocating a new block. req0.num_computed_tokens = 55 for _ in range(4): req0.append_output_token_ids(8) new_blocks = manager.append_slots(req0, 4) - assert new_blocks is not None and len(new_blocks) == 0 + assert new_blocks is not None and len(new_blocks[0]) == 0 assert manager.req_to_blocks[req0.request_id][-2].block_hash is None # Append slots without allocating a new block, but start using the @@ -176,7 +176,7 @@ def test_decode(): for _ in range(5 + 10): req0.append_output_token_ids(7) new_blocks = manager.append_slots(req0, 15) - assert new_blocks is not None and len(new_blocks) == 0 + assert new_blocks is not None and len(new_blocks[0]) == 0 assert manager.req_to_blocks[req0.request_id][-2].block_hash is not None # Append slots with allocating a new block. @@ -187,7 +187,7 @@ def test_decode(): req0.append_output_token_ids(12) new_blocks = manager.append_slots(req0, 17) # Plus one preallocated block. 
- assert new_blocks is not None and len(new_blocks) == 2 + assert new_blocks is not None and len(new_blocks[0]) == 2 def test_evict(): @@ -203,19 +203,19 @@ def test_evict(): last_token_id = 5 * 16 + 7 req0 = make_request("0", list(range(last_token_id))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) - assert not computed_blocks + assert not computed_blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 5 * 16 + 7, computed_blocks) - assert len(blocks) == 7 # 5 full + 1 partial + 1 preallocated + assert len(blocks[0]) == 7 # 5 full + 1 partial + 1 preallocated # 3 blocks. req1 = make_request("1", list(range(last_token_id, last_token_id + 3 * 16))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) - assert not computed_blocks + assert not computed_blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req1, 3 * 16, computed_blocks) - assert len(blocks) == 3 # 3 full blocks + assert len(blocks[0]) == 3 # 3 full blocks last_token_id += 3 * 16 assert manager.free_block_queue.num_free_blocks == 0 @@ -230,10 +230,10 @@ def test_evict(): # Touch the first 2 blocks. req2 = make_request("2", list(range(2 * 16 + 3))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) - assert [b.block_id for b in computed_blocks] == [0, 1] + assert [b.block_id for b in computed_blocks[0]] == [0, 1] assert num_computed_tokens == 2 * 16 blocks = manager.allocate_slots(req2, 3, computed_blocks) - assert [b.block_id for b in blocks] == [6, 5] + assert [b.block_id for b in blocks[0]] == [6, 5] assert manager.free_block_queue.num_free_blocks == 6 @@ -256,10 +256,10 @@ def test_hash_block_correct_reuse(): num_tokens = block_size * 1 req = make_request("0", list(range(num_tokens))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) - assert not computed_blocks + assert not computed_blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req, num_tokens, computed_blocks) - assert len(blocks) == 1 + assert len(blocks[0]) == 1 # Deallocate the block. manager.free(req) @@ -268,12 +268,12 @@ def test_hash_block_correct_reuse(): # block is cleared. req = make_request("1", list(range(num_tokens - 1))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) - assert not computed_blocks + assert not computed_blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req, num_tokens - 1, computed_blocks) - assert len(blocks) == 1 + assert len(blocks[0]) == 1 - assert manager.block_pool[blocks[0].block_id].block_hash is None + assert manager.block_pool[blocks[0][0].block_id].block_hash is None def test_computed_blocks_not_evicted(): @@ -295,20 +295,20 @@ def test_computed_blocks_not_evicted(): num_tokens = block_size * 1 req0 = make_request("0", list(range(num_tokens))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) - assert not computed_blocks + assert not computed_blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, num_tokens, computed_blocks) - assert len(blocks) == 1 - assert blocks[0].block_id == 0 + assert len(blocks[0]) == 1 + assert blocks[0][0].block_id == 0 # Allocate another block. 
req1 = make_request("1", list(range(num_tokens, num_tokens * 2))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) - assert not computed_blocks + assert not computed_blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req1, num_tokens, computed_blocks) - assert len(blocks) == 1 - assert blocks[0].block_id == 1 + assert len(blocks[0]) == 1 + assert blocks[0][0].block_id == 1 # Free the blocks. manager.free(req0) @@ -318,14 +318,14 @@ def test_computed_blocks_not_evicted(): # cached block rather than the first one. req2 = make_request("2", list(range(num_tokens * 2))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) - assert len(computed_blocks) == 1 - assert computed_blocks[0].block_id == 0 + assert len(computed_blocks[0]) == 1 + assert computed_blocks[0][0].block_id == 0 assert num_computed_tokens == block_size blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens, computed_blocks) - assert len(blocks) == 1 - assert blocks[0].block_id == 1 + assert len(blocks[0]) == 1 + assert blocks[0][0].block_id == 1 def test_basic_prefix_caching_disabled(): @@ -345,10 +345,10 @@ def test_basic_prefix_caching_disabled(): req1 = make_request("1", list(range(10))) # 2 blocks and some more computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) - assert not computed_blocks + assert not computed_blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req1, 10, computed_blocks) - assert len(blocks) == 3 + assert len(blocks[0]) == 3 # Free the blocks. manager.free(req1) @@ -356,15 +356,15 @@ def test_basic_prefix_caching_disabled(): # No caching. req2 = make_request("2", list(range(16))) # shared prefix computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) - assert not computed_blocks + assert not computed_blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req2, 16, computed_blocks) - assert len(blocks) == 4 + assert len(blocks[0]) == 4 # New requests should not have any blocks. req3 = make_request("3", list(range(4))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) - assert not computed_blocks + assert not computed_blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req3, 4, computed_blocks) assert not blocks @@ -388,20 +388,20 @@ def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int): req = make_request("0", list(range(block_size * 30))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) - assert not computed_blocks + assert not computed_blocks[0] assert num_computed_tokens == 0 # Just ask for 1 block. blocks = manager.allocate_slots(req, block_size, computed_blocks) req.num_computed_tokens = block_size - assert len(blocks) == 1 + num_preallocated_blocks + assert len(blocks[0]) == 1 + num_preallocated_blocks # Assume all computed. - manager.append_slots(req, block_size * (len(blocks) - 1)) - req.num_computed_tokens = block_size * len(blocks) + manager.append_slots(req, block_size * (len(blocks[0]) - 1)) + req.num_computed_tokens = block_size * len(blocks[0]) # Append 1 block. 
blocks = manager.append_slots(req, block_size) - assert len(blocks) == 1 + num_preallocated_blocks + assert len(blocks[0]) == 1 + num_preallocated_blocks def test_cache_blocks(): @@ -424,6 +424,7 @@ def test_cache_blocks(): # Block 2: [8, 9, 10, 11] # Block 3: [12, 13] req = make_request("0", list(range(14))) + req.set_kv_block_hashes([[]]) # Test that blocks are cached correctly for 2 full blocks from the start. blocks = [KVCacheBlock(block_id=i) for i in range(2)] @@ -433,6 +434,7 @@ def test_cache_blocks(): blk_start_idx=0, full_blocks=blocks, prev_block=None, + kv_cache_group_id=0, ) assert len(manager.cached_block_hash_to_block) == 2 @@ -445,6 +447,7 @@ def test_cache_blocks(): blk_start_idx=2, full_blocks=blocks, prev_block=None, + kv_cache_group_id=0, ) assert len(manager.cached_block_hash_to_block) == 3 assert blocks[0].block_hash is not None @@ -489,26 +492,26 @@ def test_mm_prefix_caching(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) # Completed block should have hashes with extra keys. - assert not computed_blocks + assert not computed_blocks[0] assert num_computed_tokens == 0 - assert len(req0.kv_block_hashes) == 3 - assert req0.kv_block_hashes[0].extra_keys == ("aaa", ) - assert req0.kv_block_hashes[1].extra_keys == ("aaa", "bbb") - assert req0.kv_block_hashes[2].extra_keys == ("bbb", ) + assert len(req0.kv_block_hashes[0]) == 3 + assert req0.kv_block_hashes[0][0].extra_keys == ("aaa", ) + assert req0.kv_block_hashes[0][1].extra_keys == ("aaa", "bbb") + assert req0.kv_block_hashes[0][2].extra_keys == ("bbb", ) blocks = manager.allocate_slots(req0, 59, computed_blocks) - assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4] + assert [b.block_id for b in blocks[0]] == [0, 1, 2, 3, 4] req0.num_computed_tokens = 59 # Append slots without allocating a new block. for _ in range(5): req0.append_output_token_ids(8) new_blocks = manager.append_slots(req0, 5) - assert new_blocks is not None and len(new_blocks) == 0 + assert new_blocks is not None and len(new_blocks[0]) == 0 # The just completed block should have hashes with extra keys. - assert len(req0.kv_block_hashes) == 4 - assert req0.kv_block_hashes[3].extra_keys == ("ccc", ) + assert len(req0.kv_block_hashes[0]) == 4 + assert req0.kv_block_hashes[0][3].extra_keys == ("ccc", ) # Cache hit. unique_token_ids = [-1] * 7 + [200] * 5 @@ -522,7 +525,7 @@ def test_mm_prefix_caching(): mm_positions=mm_positions, mm_hashes=mm_hashes) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) - assert len(computed_blocks) == 3 + assert len(computed_blocks[0]) == 3 assert num_computed_tokens == 3 * 16 @@ -547,7 +550,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): common_token_ids = [i for i in range(3) for _ in range(16)] req0 = make_request("0", common_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) - assert not computed_blocks + assert not computed_blocks[0] assert num_computed_tokens == 0 manager.allocate_slots(req0, 48, computed_blocks) block_part0 = manager.req_to_blocks[req0.request_id] @@ -555,7 +558,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): # | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... 
| req1 = make_request("1", common_token_ids * 2) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) - assert computed_blocks == block_part0 + assert computed_blocks[0] == block_part0 assert num_computed_tokens == 3 * 16 manager.allocate_slots(req1, 48, computed_blocks) block_part1 = manager.req_to_blocks[req1.request_id] @@ -569,7 +572,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): # | Req1-5(F)| Req2-0 | Req2-1 | ... | req2 = make_request("2", [7] * block_size * 2) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) - assert not computed_blocks + assert not computed_blocks[0] assert num_computed_tokens == 0 manager.allocate_slots(req2, block_size * 2, computed_blocks) @@ -579,7 +582,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): assert manager.free_block_queue.num_free_blocks == 5 req3 = make_request("3", common_token_ids * 3) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) - assert computed_blocks == block_part1 + assert computed_blocks[0] == block_part1 assert num_computed_tokens == 6 * 16 # Req3 cannot be allocated. assert manager.allocate_slots(req3, 48, computed_blocks) is None diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py index 694ce81ff6e22..6beb9500b4eae 100644 --- a/tests/v1/worker/test_gpu_input_batch.py +++ b/tests/v1/worker/test_gpu_input_batch.py @@ -166,7 +166,8 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int): max_num_blocks_per_req=10, device=torch.device(device), pin_memory=is_pin_memory_available(), - vocab_size=1024) + vocab_size=1024, + num_kv_cache_groups=1) reqs: List[CachedRequestState] = [] req_id_reqs = {} req_id_output_token_ids = {} diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 85e7946edab5e..a46a1eac61c65 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -195,7 +195,7 @@ def allocate_slots( self, request: Request, num_tokens: int, - computed_blocks_of_groups: ReqKVCacheBlocks, + computed_blocks_all_groups: ReqKVCacheBlocks, ) -> Optional[ReqKVCacheBlocks]: """Allocate slots for a new request. @@ -213,7 +213,7 @@ def allocate_slots( raise ValueError( f"num_tokens must be greater than 0, got {num_tokens}") - computed_blocks = computed_blocks_of_groups[0] # only one group + computed_blocks = computed_blocks_all_groups[0] # only one group # If a computed block of a request is an eviction candidate (in the # free queue and ref_cnt == 0), it cannot be counted as a free block # when allocating this request. 
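
To make the new return shapes concrete, here is a hedged sketch of the single-group calling pattern that the updated tests above follow; `manager`, `req`, and the token count 48 are placeholders:

    # get_computed_blocks() and allocate_slots() now return one list of
    # blocks per KV cache group, so callers index group 0 first.
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
    new_blocks = manager.allocate_slots(req, 48, computed_blocks)
    new_block_ids = [b.block_id for b in new_blocks[0]]  # group 0 only
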
@@ -253,7 +253,7 @@ def allocate_slots( self.req_to_blocks[request.request_id] = computed_blocks + new_blocks if not self.enable_caching: - return new_blocks + return [new_blocks] num_computed_tokens = len(computed_blocks) * self.block_size num_full_blocks = (num_computed_tokens + num_tokens) // self.block_size diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index b0a7affbebb7e..e68f0be069caa 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -77,6 +77,9 @@ def __contains__(self, item): def __len__(self): return len(self._x) + def __repr__(self): + return "ConstantList(" + repr(self._x) + ")" + class BackgroundProcHandle: """ diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 60017e41ed649..6494f86bf4177 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -214,7 +214,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: continue self.input_batch.block_table.append_row(req_index, req_data.new_block_ids) - for group_id, new_block_ids in enumerate(new_block_ids): + for group_id, new_block_ids in enumerate(req_data.new_block_ids): req_state.block_ids[group_id].extend(new_block_ids) req_ids_to_add: List[str] = [] From e46fff56faa7c5d2f4bfee9fadee18af5b671406 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Wed, 15 Jan 2025 05:17:14 -0800 Subject: [PATCH 03/48] format Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_manager.py | 5 ++++- vllm/v1/core/kv_cache_utils.py | 7 ++++--- vllm/v1/core/scheduler.py | 10 +++++----- vllm/v1/request.py | 6 +++--- 4 files changed, 16 insertions(+), 12 deletions(-) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index a46a1eac61c65..08a056cd88b3c 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -15,6 +15,9 @@ class KVCacheManager: + """ + TODO: add notes about num_group=1 + """ def __init__( self, @@ -297,7 +300,7 @@ def get_num_common_prefix_blocks( self, request: Request, num_running_requests: int, - ) -> int: + ) -> List[int]: """Calculate the number of common prefix blocks shared by all requests in the RUNNING state. diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index a9cd1d2255c96..d2e225eeb5013 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -2,6 +2,7 @@ from collections.abc import Sequence from dataclasses import dataclass from typing import Any, List, NamedTuple, Optional, Tuple + from vllm.logger import init_logger from vllm.v1.request import Request @@ -61,11 +62,11 @@ def reset_hash(self): """When the model contains different types of layers (e.g., full attention + -sliding window attention), the layers will be splited to multiple groups, where -layers in the same group has the same type and with the same KVCacheBlock. +sliding window attention), the layers will be split to multiple groups, where +layers in the same group has the same type and with the same KVCacheBlock. +See KVCacheConfig class for more details of "group". KVCacheBlocks: the blocks in one (group) of layer in one request ReqKVCacheBlocks: the blocks in all groups of layers in one request. 
-Refer to KVCacheConfig class for the meaning of "group" """ KVCacheBlocks = List[KVCacheBlock] ReqKVCacheBlocks = List[KVCacheBlocks] diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 574d128d396a4..0f89cd6a147c8 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -308,7 +308,7 @@ def schedule(self) -> "SchedulerOutput": def _make_running_request_data( self, request: Request, - new_block_ids: List[int], + new_block_ids: List[List[int]], num_computed_tokens: int, ) -> "RunningRequestData": # OPTIMIZATION: Cache the RunningRequestData objects to avoid creating @@ -539,7 +539,7 @@ class NewRequestData: mm_positions: List["PlaceholderRange"] sampling_params: SamplingParams # List of block IDs for each group. - # Refer to KVCacheConfig class for the meaning of "group" + # See KVCacheConfig class for the meaning of "group". block_ids: List[List[int]] num_computed_tokens: int @@ -568,7 +568,7 @@ class ResumedRequestData: req_id: str # List of block IDs for each kv cache group. - # Refer to KVCacheConfig class for the meaning of "group" + # See KVCacheConfig class for the meaning of "group". block_ids: List[List[int]] num_computed_tokens: int @@ -591,7 +591,7 @@ class RunningRequestData: req_id: str # List of block IDs for each kv cache group. - # Refer to KVCacheConfig class for the meaning of "group" + # See KVCacheConfig class for the meaning of "group". new_block_ids: List[List[int]] num_computed_tokens: int @@ -620,7 +620,7 @@ class SchedulerOutput: total_num_scheduled_tokens: int scheduled_encoder_inputs: Dict[str, List[int]] # Number of common prefix blocks per kv cache group - # Refer to KVCacheConfig class for the meaning of "group" + # See KVCacheConfig class for the meaning of "group" num_common_prefix_blocks: List[int] preempted_req_ids: Set[str] diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 097b2ffe7c831..9d79c13f315fb 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -63,9 +63,9 @@ def __init__( # Cache the computed kv block hashes of the request to avoid # recomputing. # Different kv cache groups may have different block_size, so save their - # hash seperately. Each outer list represents a group, and each inner + # hash separately. Each outer list represents a group, and each inner # list contains the hashes of blocks with that group's block_size. - # Refer to KVCacheConfig class for the meaning of "group". + # See KVCacheConfig class for the meaning of "group". self._kv_block_hashes: List[List[BlockHashType]] = [] @classmethod @@ -131,7 +131,7 @@ def get_num_encoder_tokens(self, input_id: int) -> int: return num_tokens @property - def kv_block_hashes(self) -> List[ConstantList["BlockHashType"]]: + def kv_block_hashes(self) -> ConstantList[ConstantList["BlockHashType"]]: # Prevent directly appending to the kv_block_hashes. 
return ConstantList([ ConstantList(kv_block_hashes_one_group) From 36a649a0b600c324b0efe3252fab90d07275c657 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Wed, 15 Jan 2025 06:45:53 -0800 Subject: [PATCH 04/48] fix bug Signed-off-by: Chen Zhang --- vllm/v1/worker/block_table.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 31b9efb361166..07ef26d26c886 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -54,6 +54,7 @@ def append_row( self.num_blocks_per_row[i, row_idx] = num_blocks + num_new_blocks def add_row(self, row_idx: int, block_ids: List[List[int]]) -> None: + self.num_blocks_per_row[:, row_idx] = 0 self.append_row(row_idx, block_ids) def move_row(self, src: int, tgt: int) -> None: From 9c36e7d8f4a970037a9939c5e6ca2a4debce660c Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Wed, 15 Jan 2025 07:40:25 -0800 Subject: [PATCH 05/48] add comments Signed-off-by: Chen Zhang --- vllm/forward_context.py | 1 + vllm/v1/core/kv_cache_manager.py | 31 +++++++++++++++++------------- vllm/v1/core/kv_cache_utils.py | 14 ++++++++------ vllm/v1/core/scheduler.py | 9 +++++---- vllm/v1/request.py | 9 +++++---- vllm/v1/worker/block_table.py | 5 ++++- vllm/v1/worker/gpu_input_batch.py | 3 ++- vllm/v1/worker/gpu_model_runner.py | 3 ++- 8 files changed, 45 insertions(+), 30 deletions(-) diff --git a/vllm/forward_context.py b/vllm/forward_context.py index c34694790b8d9..536509db5ab60 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -38,6 +38,7 @@ class ForwardContext: attn_metadata: Union["AttentionMetadata", Dict[str, "AttentionMetadata"]] """ The virtual_engine for v0 pipeline parallelism + set dynamically for each forward pass """ virtual_engine: int # set dynamically for each forward pass diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 08a056cd88b3c..5297f79b7d476 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -16,7 +16,9 @@ class KVCacheManager: """ - TODO: add notes about num_group=1 + The KVCacheManager for models with one KV cache type (e.g., Llama) and + thus one kv cache group (Refer to class `KVCacheConfig` for the meaning of + kv cache group). """ def __init__( @@ -70,8 +72,10 @@ def __init__( # Mapping from request ID to blocks to track the blocks allocated # for each request, so that we can free the blocks when the request - # is finished. KVCacheManager only supports models with one layer type, - # so the blocks can be stored by KVCacheBlocks type. + # is finished. + # KVCacheManager only supports models with one kv cache group, so we + # save KVCachedBlocks of that group instead of ReqKVCacheBlocks for + # simplicity. self.req_to_blocks: Dict[str, KVCacheBlocks] = {} def get_computed_blocks(self, @@ -83,9 +87,8 @@ def get_computed_blocks(self, request: The request to get the computed blocks. Returns: - # TODO: update docstring A tuple containing: - - A list of blocks that are computed for the request. + - The blocks that are computed for the request - The number of computed tokens. """ if not self.enable_caching: @@ -130,9 +133,8 @@ def append_slots( num_tokens: The number of tokens to append. Returns: - A list of new blocks if new blocks are allocated, or None - if new blocks are required but cannot be allocated. - # TODO: update docstring + The new blocks if new blocks are allocated, or None if new blocks + are required but cannot be allocated. 
""" num_required_blocks = cdiv(request.num_computed_tokens + num_tokens, self.block_size) @@ -206,17 +208,19 @@ def allocate_slots( request: The request to allocate slots. num_tokens: The number of tokens to allocate. Note that this does not include the tokens that have already been computed. - computed_blocks: A list of computed blocks. + computed_blocks_all_groups: The computed blocks. Should contain + only one KV cache group. Returns: - # TODO: update docstring - A list of new allocated blocks. + The new blocks if new blocks are allocated, or None if new blocks + are required but cannot be allocated. """ if num_tokens == 0: raise ValueError( f"num_tokens must be greater than 0, got {num_tokens}") - computed_blocks = computed_blocks_all_groups[0] # only one group + assert len(computed_blocks_all_groups) == 1 + computed_blocks = computed_blocks_all_groups[0] # If a computed block of a request is an eviction candidate (in the # free queue and ref_cnt == 0), it cannot be counted as a free block # when allocating this request. @@ -334,7 +338,7 @@ def get_num_common_prefix_blocks( requests in the current step. Returns: - int: The number of common prefix blocks. + List[int]: The number of common prefix blocks per KV cache group. """ assert request.status == RequestStatus.RUNNING blocks = self.req_to_blocks[request.request_id] @@ -448,6 +452,7 @@ def _cache_full_blocks( to cache. full_blocks: The list of blocks to update hash metadata. prev_block: The previous block in the chain. + kv_cache_group_id: The KV cache group ID that the blocks belong to """ num_cached_block_hashes = len( request.kv_block_hashes[kv_cache_group_id]) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index d2e225eeb5013..12066a1370ac2 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -61,12 +61,14 @@ def reset_hash(self): self._block_hash = None -"""When the model contains different types of layers (e.g., full attention + -sliding window attention), the layers will be split to multiple groups, where -layers in the same group has the same type and with the same KVCacheBlock. -See KVCacheConfig class for more details of "group". -KVCacheBlocks: the blocks in one (group) of layer in one request -ReqKVCacheBlocks: the blocks in all groups of layers in one request. +"""When a model needs different types of kv_caches (e.g., full attention + +sliding window attention), the attention layers will be split to multiple +"KV cache groups", where layers in the same group has the same kv cache type and +can use the same KVCacheBlock. There will be only one group if all layers use +the same type of KV cache. +See KVCacheConfig class for more examples of "KV cache group". +KVCacheBlocks: the blocks of one group of layer in one request +ReqKVCacheBlocks: the blocks of all groups of layers in one request. """ KVCacheBlocks = List[KVCacheBlock] ReqKVCacheBlocks = List[KVCacheBlocks] diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 0f89cd6a147c8..04d78158e029d 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -95,6 +95,7 @@ def schedule(self) -> "SchedulerOutput": scheduled_running_reqs: List[Request] = [] preempted_reqs: List[Request] = [] + # Request id -> List of block IDs for each kv cache group. 
req_to_new_block_ids: Dict[str, List[List[int]]] = {} num_scheduled_tokens: Dict[str, int] = {} token_budget = self.max_num_scheduled_tokens @@ -538,8 +539,8 @@ class NewRequestData: mm_hashes: List[str] mm_positions: List["PlaceholderRange"] sampling_params: SamplingParams - # List of block IDs for each group. - # See KVCacheConfig class for the meaning of "group". + # List of block IDs for each KV cache group. + # See KVCacheConfig class for the meaning of "KV cache group". block_ids: List[List[int]] num_computed_tokens: int @@ -568,7 +569,7 @@ class ResumedRequestData: req_id: str # List of block IDs for each kv cache group. - # See KVCacheConfig class for the meaning of "group". + # See KVCacheConfig class for the meaning of "KV cache group". block_ids: List[List[int]] num_computed_tokens: int @@ -591,7 +592,7 @@ class RunningRequestData: req_id: str # List of block IDs for each kv cache group. - # See KVCacheConfig class for the meaning of "group". + # See KVCacheConfig class for the meaning of "KV cache group". new_block_ids: List[List[int]] num_computed_tokens: int diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 9d79c13f315fb..0a2ade5e64df8 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -62,10 +62,11 @@ def __init__( # Cache the computed kv block hashes of the request to avoid # recomputing. - # Different kv cache groups may have different block_size, so save their - # hash separately. Each outer list represents a group, and each inner - # list contains the hashes of blocks with that group's block_size. - # See KVCacheConfig class for the meaning of "group". + # Different KV cache groups may have different block_size, so save their + # hash separately. Each element of the outer list represents a group, + # and each inner list contains the hashes of blocks with that group's + # block_size. + # See KVCacheConfig class for the meaning of "KV cache group". self._kv_block_hashes: List[List[BlockHashType]] = [] @classmethod diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 07ef26d26c886..e5e5eff65940e 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -17,6 +17,7 @@ def __init__( max_num_blocks_per_req: int, pin_memory: bool, device: torch.device, + # NOTE: See KVCacheConfig class for the meaning of "KV cache group". num_kv_cache_groups: int, ): self.max_num_reqs = max_num_reqs @@ -67,7 +68,9 @@ def commit(self, num_reqs: int) -> None: # NOTE: an alternative is # self.block_table[:, :num_reqs].copy_( # self.block_table_cpu[:, :num_reqs], non_blocking=True) - # but it will be a blocking copy when num_kv_cache_groups>1. + # but it will be a blocking copy when num_kv_cache_groups > 1. 
+ # Can be verified by the following code: + # https://gist.github.com/heheda12345/74c7f7a68e45c242a5c901b5fb77d000 for i in range(self.num_kv_cache_groups): self.block_table[i, :num_reqs].copy_( self.block_table_cpu[i, :num_reqs], non_blocking=True) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 4a74bd76841b4..096f6c274e121 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -26,7 +26,7 @@ class CachedRequestState: sampling_params: SamplingParams generator: Optional[torch.Generator] - block_ids: List[List[int]] + block_ids: List[List[int]] # List of block ids for each kv cache group num_computed_tokens: int output_token_ids: List[int] @@ -45,6 +45,7 @@ def __init__( device: torch.device, pin_memory: bool, vocab_size: int, + # NOTE: See KVCacheConfig class for the meaning of "KV cache group". num_kv_cache_groups: int, ): self.max_num_reqs = max_num_reqs diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 6494f86bf4177..cf040214f770e 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -93,7 +93,7 @@ def __init__( # Lazy initialization # self.model: nn.Module # Set after load_model - self.kv_caches: List[List[torch.Tensor]] = [] + self.kv_caches: List[torch.Tensor] = [] # req_id -> (input_id -> encoder_output) self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {} @@ -369,6 +369,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): self.device, non_blocking=True) # layer_name -> AttentionMetadata attn_metadata: Dict[str, FlashAttentionMetadata] = {} + # TODO: update after PR #11960 # for i, layer_ids in enumerate(self.kv_cache_config.groups): for i in range(1): layer_ids = list(self.vllm_config.compilation_config. From da6b549cf8c78810c958ef642a82c29b41dd9510 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Wed, 15 Jan 2025 07:50:22 -0800 Subject: [PATCH 06/48] format Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 12066a1370ac2..28f3fe51d02ae 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -63,7 +63,7 @@ def reset_hash(self): """When a model needs different types of kv_caches (e.g., full attention + sliding window attention), the attention layers will be split to multiple -"KV cache groups", where layers in the same group has the same kv cache type and +"KV cache groups", where layers in the same group has the same kv cache type and can use the same KVCacheBlock. There will be only one group if all layers use the same type of KV cache. See KVCacheConfig class for more examples of "KV cache group". 
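To make the "KV cache group" idea above concrete, here is a small self-contained sketch of grouping layers by their cache type. It is illustrative only (made-up layer names and type ids), not the series' actual grouping code, which is driven by each layer's KVCacheSpec:

from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List


@dataclass(frozen=True)
class ToySpec:
    # Stand-in for a KVCacheSpec: layers whose type_id matches can share the
    # same kind of KV cache blocks.
    type_id: str


def group_layers_by_spec(specs: Dict[str, ToySpec]) -> List[List[str]]:
    groups: Dict[str, List[str]] = defaultdict(list)
    for layer_name, spec in specs.items():
        groups[spec.type_id].append(layer_name)
    return list(groups.values())


specs = {
    "layers.0": ToySpec("full_attention"),
    "layers.1": ToySpec("sliding_window_1024"),
    "layers.2": ToySpec("full_attention"),
    "layers.3": ToySpec("sliding_window_1024"),
}
print(group_layers_by_spec(specs))
# [['layers.0', 'layers.2'], ['layers.1', 'layers.3']]; a model whose layers
# all share one cache type collapses to a single group.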
From 41bc571d08c0928cb7b970b90ce985e3450ea449 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sun, 19 Jan 2025 19:16:32 -0800 Subject: [PATCH 07/48] update code Signed-off-by: Chen Zhang --- vllm/v1/worker/gpu_model_runner.py | 128 ++++++++++++++++++----------- 1 file changed, 80 insertions(+), 48 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 79012c49d882f..e4b178d283784 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -69,9 +69,7 @@ def __init__( self.is_multimodal_model = model_config.is_multimodal_model self.sliding_window = model_config.get_sliding_window() - self.block_size = cache_config.block_size self.max_model_len = model_config.max_model_len - self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size) self.max_num_tokens = scheduler_config.max_num_batched_tokens self.max_num_reqs = scheduler_config.max_num_seqs @@ -108,16 +106,6 @@ def __init__( # Request states. self.requests: Dict[str, CachedRequestState] = {} - # Persistent batch. - self.input_batch = InputBatch( - max_num_reqs=self.max_num_reqs, - max_model_len=self.max_model_len, - max_num_blocks_per_req=self.max_num_blocks_per_req, - device=self.device, - pin_memory=self.pin_memory, - vocab_size=model_config.get_vocab_size(), - num_kv_cache_groups=1, # TODO: update after PR #11960 - ) self.use_cuda_graph = (self.vllm_config.compilation_config.level == CompilationLevel.PIECEWISE @@ -162,13 +150,17 @@ def __init__( device="cpu", pin_memory=self.pin_memory) self.positions_np = self.positions_cpu.numpy() + + self.kv_cache_config: KVCacheConfig = None # Set by initialize_kv_cache + + # The following 3 variables depends on KVCacheConfig, assign a + # placeholder value here and initialize them in `initialize_kv_cache``. + self.input_batch: InputBatch = None # Persistent batch. self.slot_mapping_cpu = torch.zeros( - 1, # TODO: update after PR #11960 - self.max_num_tokens, - dtype=torch.int32, - device="cpu", - pin_memory=self.pin_memory) + (1, )) # Real shape: (num_kv_cache_groups, self.max_num_tokens) self.slot_mapping_np = self.slot_mapping_cpu.numpy() + self.max_num_blocks_per_req: int = 0 + self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1, dtype=torch.int32, device="cpu", @@ -331,12 +323,12 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): out=self.input_ids_cpu[:total_num_scheduled_tokens]) # Calculate the slot mapping. - # for i in range(len(self.kv_cache_config.groups)): - for i in range(1): # TODO: update after PR #11960 - # group_spec = self.kv_cache_config.kv_cache_spec[ - # self.kv_cache_config.groups[i][0]] - # block_size = group_spec.block_size - block_size = self.block_size + for i in range(len(self.kv_cache_config.groups)): + # the LayerSpec of all layers in the group is the same. Take the + # first one. + group_spec = self.kv_cache_config.kv_cache_spec[ + self.kv_cache_config.groups[i][0]] + block_size = group_spec.block_size # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] # where K is the max_num_blocks_per_req and the block size is 2. @@ -378,18 +370,16 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): self.device, non_blocking=True) # layer_name -> AttentionMetadata attn_metadata: Dict[str, FlashAttentionMetadata] = {} - # TODO: update after PR #11960 - # for i, layer_ids in enumerate(self.kv_cache_config.groups): - for i in range(1): - layer_ids = list(self.vllm_config.compilation_config. 
- static_forward_context.keys()) + for i, layer_ids in enumerate(self.kv_cache_config.groups): + block_size = self.kv_cache_config.kv_cache_spec[ + layer_ids[0]].block_size slot_mapping = self.slot_mapping_cpu[ i, :total_num_scheduled_tokens].to(self.device, non_blocking=True).long() # Prepare for cascade attention if needed. common_prefix_len = (scheduler_output.num_common_prefix_blocks[i] * - self.block_size) + block_size) if common_prefix_len == 0: # Common case. use_cascade = False @@ -437,8 +427,8 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): common_prefix_len, self.input_batch.num_computed_tokens_cpu[:num_reqs].min()) # common_prefix_len should be a multiple of the block size. - common_prefix_len = (common_prefix_len // self.block_size * - self.block_size) + common_prefix_len = (common_prefix_len // block_size * + block_size) use_cascade = FlashAttentionBackend.use_cascade_attention( common_prefix_len=common_prefix_len, query_lens=num_scheduled_tokens, @@ -604,19 +594,8 @@ def execute_model( # Prepare the decoder inputs. attn_metadata, logits_indices = self._prepare_inputs(scheduler_output) num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - if (self.use_cuda_graph - and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): - # Use piecewise CUDA graphs. - # Add padding to the batch size. - num_input_tokens = self.vllm_config.pad_for_cudagraph( - num_scheduled_tokens) - else: - # Eager mode. - num_input_tokens = num_scheduled_tokens - - # TODO: update after PR #11960 - attn_metadata[next(iter( - attn_metadata.keys()))].num_input_tokens = num_input_tokens + num_input_tokens = self.maybe_pad_for_cudagraph( + num_scheduled_tokens, attn_metadata) if self.is_multimodal_model: # NOTE(woosuk): To unify token ids and soft tokens (vision @@ -709,6 +688,28 @@ def execute_model( ) return model_runner_output + def maybe_pad_for_cudagraph( + self, num_scheduled_tokens: int, + attn_metadata: Dict[str, FlashAttentionMetadata]) -> int: + if (self.use_cuda_graph + and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): + # Use piecewise CUDA graphs. + # Add padding to the batch size. + num_input_tokens = self.vllm_config.pad_for_cudagraph( + num_scheduled_tokens) + else: + # Eager mode. + num_input_tokens = num_scheduled_tokens + + # update num_input_tokens in attn_metadata + for layer_names in self.kv_cache_config.groups: + layer_name = layer_names[0] + # All layers in the group share the same attn_metadata object. + # Only need to update the num_input_tokens once. 
+ attn_metadata[layer_name].num_input_tokens = num_input_tokens + + return num_input_tokens + def load_model(self) -> None: logger.info("Starting to load model %s...", self.model_config.model) with DeviceMemoryProfiler() as m: # noqa: SIM117 @@ -892,10 +893,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: kv_cache_config: Configuration for the KV cache, including the KV cache size of each layer """ - if len(kv_cache_config.groups) > 1: - raise NotImplementedError( - "Hybrid models with more than one KV cache type are not " - "supported yet.") + self.kv_cache_config = kv_cache_config kv_caches: Dict[str, torch.Tensor] = {} @@ -919,6 +917,40 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: self.vllm_config.compilation_config.static_forward_context, self.kv_caches) + self._initialize_kv_related_buffers(kv_cache_config) + + def _initialize_kv_related_buffers(self, + kv_cache_config: KVCacheConfig) -> None: + """ + Initialize data structures (e.g., InputBatch, slot mappings) that depend + on the kv cache configuration. + + Args: + kv_cache_config (KVCacheConfig): Configuration for the KV cache + """ + num_kv_cache_groups = len(kv_cache_config.groups) + + min_block_size = min( + spec.block_size for spec in kv_cache_config.kv_cache_spec.values()) + self.max_num_blocks_per_req = cdiv(self.max_model_len, min_block_size) + + self.input_batch = InputBatch( + max_num_reqs=self.max_num_reqs, + max_model_len=self.max_model_len, + max_num_blocks_per_req=self.max_num_blocks_per_req, + device=self.device, + pin_memory=self.pin_memory, + vocab_size=self.vllm_config.model_config.get_vocab_size(), + num_kv_cache_groups=num_kv_cache_groups, + ) + + self.slot_mapping_cpu = torch.zeros(num_kv_cache_groups, + self.max_num_tokens, + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + self.slot_mapping_np = self.slot_mapping_cpu.numpy() + def get_kv_cache_spec(self) -> KVCacheSpec: """ Generates the KVCacheSpec by parsing the kv cache format from each From 34c9d7482259811c58857fe0c453b3f4e46f7652 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sun, 19 Jan 2025 20:17:38 -0800 Subject: [PATCH 08/48] can run Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_utils.py | 70 ++++++++++++++++++++--------- vllm/v1/executor/abstract.py | 4 +- vllm/v1/kv_cache_interface.py | 22 +++++---- vllm/v1/worker/gpu_model_runner.py | 72 +++++++++++++++--------------- vllm/v1/worker/gpu_worker.py | 4 +- 5 files changed, 103 insertions(+), 69 deletions(-) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index b73e6338c0890..3251d75801d80 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -1,12 +1,12 @@ """KV-Cache Utilities.""" from collections.abc import Sequence from dataclasses import dataclass -from typing import Any, List, NamedTuple, Optional, Tuple +from typing import Any, Dict, List, NamedTuple, Optional, Tuple from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.v1.kv_cache_interface import (KVCacheConfig, KVCacheSpec, - KVCacheTensor) +from vllm.v1.kv_cache_interface import (KVCacheConfig, KVCacheGroup, + KVCacheSpec, KVCacheTensor) from vllm.v1.request import Request logger = init_logger(__name__) @@ -324,7 +324,7 @@ def hash_request_tokens(block_size: int, def check_enough_kv_cache_memory(vllm_config: VllmConfig, - kv_cache_spec: KVCacheSpec, + kv_cache_spec: Dict[str, KVCacheSpec], available_memory: int): """ Checks whether `available_memory` is enough for the KV cache 
to hold at @@ -332,7 +332,7 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig, Args: vllm_config: The global VllmConfig - kv_cache_spec: The kv cache spec of the model + kv_cache_spec: The KVCacheSpec of each attention layer in the model available_memory: Memory available for KV cache in bytes. Raises: @@ -359,12 +359,12 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig, f"`max_model_len` when initializing the engine.") -def is_kv_cache_type_uniform(kv_cache_spec: KVCacheSpec) -> bool: +def is_kv_cache_type_uniform(kv_cache_spec: Dict[str, KVCacheSpec]) -> bool: """ Whether all layers in the given KVCacheSpec have the same type of KV cache. Args: - kv_cache_spec: The KVCacheSpec of the model + kv_cache_spec: The KVCacheSpec of each attention layer in the model Returns: True if all layers have the same type, False otherwise. @@ -374,8 +374,36 @@ def is_kv_cache_type_uniform(kv_cache_spec: KVCacheSpec) -> bool: return len(layer_keys) == 1 +def _create_kv_cache_groups( + kv_cache_spec: Dict[str, KVCacheSpec], + grouped_layers: List[List[str]]) -> List[KVCacheGroup]: + """ + Create KVCacheGroup objects for each group of layers. + The layers in one group should share the same KVCacheSpec. + + Args: + kv_cache_spec (Dict[str, KVCacheSpec]): + A mapping from each layer name to its corresponding KVCacheSpec. + grouped_layers (List[List[str]]): + A list of layer groups, where each element is a list of layer names + that belongs to one group and should share the same KVCacheSpec. + + Returns: + A list of KVCacheGroup objects, one for each group of layers. + """ + kv_cache_groups = [] + for layer_names in grouped_layers: + group_spec = kv_cache_spec[layer_names[0]] + assert all( + kv_cache_spec[layer_name] == group_spec + for layer_name in layer_names[1:]), ( + "All layers in a group must share the same KVCacheSpec.") + kv_cache_groups.append(KVCacheGroup(layer_names, group_spec)) + return kv_cache_groups + + def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, - kv_cache_spec: KVCacheSpec, + kv_cache_spec: Dict[str, KVCacheSpec], available_memory: int) -> KVCacheConfig: """ Generates the KV cache configuration for a model with one type of KV cache. @@ -383,7 +411,7 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, Args: vllm_config: The global VllmConfig - kv_cache_spec: The kv cache spec of the model + kv_cache_spec: The KVCacheSpec of each attention layer in the model available_memory: Memory available for KV cache in bytes. 
Returns: @@ -408,19 +436,21 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, logger.info("# GPU blocks: %d", num_blocks) per_layer_size = page_size * num_blocks - - kv_cache_config = KVCacheConfig( - num_blocks=num_blocks, - tensors={ - layer_name: KVCacheTensor(size=per_layer_size) - for layer_name in kv_cache_spec - }, - groups=[[layer_name for layer_name in kv_cache_spec]], - kv_cache_spec=kv_cache_spec) + layers_of_group = [[layer_name for layer_name in kv_cache_spec]] + + kv_cache_config = KVCacheConfig(num_blocks=num_blocks, + tensors={ + layer_name: + KVCacheTensor(size=per_layer_size) + for layer_name in kv_cache_spec + }, + groups=_create_kv_cache_groups( + kv_cache_spec, layers_of_group)) return kv_cache_config -def get_kv_cache_config(vllm_config: VllmConfig, kv_cache_spec: KVCacheSpec, +def get_kv_cache_config(vllm_config: VllmConfig, + kv_cache_spec: Dict[str, KVCacheSpec], available_memory: int) -> KVCacheConfig: """ Generates the KV cache configuration for a model @@ -428,7 +458,7 @@ def get_kv_cache_config(vllm_config: VllmConfig, kv_cache_spec: KVCacheSpec, Args: vllm_config: The global VllmConfig - kv_cache_spec: The kv cache spec of the model + kv_cache_spec: The KVCacheSpec of each attention layer in the model available_memory: Memory available for KV cache in bytes. Returns: diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py index 131be759842c7..ebedb404b8efb 100644 --- a/vllm/v1/executor/abstract.py +++ b/vllm/v1/executor/abstract.py @@ -1,4 +1,4 @@ -from typing import Type +from typing import Dict, Type from vllm.config import VllmConfig from vllm.executor.executor_base import ExecutorBase @@ -62,7 +62,7 @@ def determine_available_memory(self) -> int: # in bytes # operators can be applied to all workers. return min(output) - def get_kv_cache_spec(self) -> KVCacheSpec: + def get_kv_cache_spec(self) -> Dict[str, KVCacheSpec]: output = self.collective_rpc("get_kv_cache_spec") for x in output: assert x == output[0] diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 6d5cc32ffc5b8..3f6dffacd0ca3 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -10,7 +10,7 @@ @dataclass -class KVCacheSpecBase: +class KVCacheSpec: """ A base class for specifying the KV cache format of one layer. """ @@ -54,7 +54,7 @@ def bytes_for_tokens(self, num_tokens: int) -> int: @dataclass -class FullAttentionSpec(KVCacheSpecBase): +class FullAttentionSpec(KVCacheSpec): num_kv_heads: int head_size: int dtype: torch.dtype @@ -72,9 +72,6 @@ def bytes_for_tokens(self, num_tokens: int) -> int: return cdiv(num_tokens, self.block_size) * self.page_size_bytes -KVCacheSpec = Dict[str, KVCacheSpecBase] - - @dataclass class KVCacheTensor: """ @@ -85,6 +82,17 @@ class KVCacheTensor: size: int # The size of KV cache Tensor in bytes +@dataclass +class KVCacheGroup: + """ + A dataclass for specifying the KV cache group of a model. + """ + # The names of layers in this group + layer_names: List[str] + # The KV cache spec of this group + kv_cache_spec: KVCacheSpec + + @dataclass class KVCacheConfig: """ @@ -106,6 +114,4 @@ class KVCacheConfig: 3. (not implemented yet) A model with 2 full attention layers and 4 sliding window attention layers: three groups, (full * 2), (sw * 2), (sw * 2). 
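As a concrete picture of the single-group case supported at this point in the series, a KVCacheConfig for a toy two-layer full-attention model could look like the sketch below. Layer names and sizes are hypothetical; the classes are the ones defined in this file as of this patch:

import torch

from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                        KVCacheGroup, KVCacheTensor)

# Hypothetical geometry: 16-token blocks, 8 KV heads of size 128, fp16.
spec = FullAttentionSpec(block_size=16, num_kv_heads=8, head_size=128,
                         dtype=torch.float16)
num_blocks = 1024

config = KVCacheConfig(
    num_blocks=num_blocks,
    tensors={
        "layers.0.attn": KVCacheTensor(size=num_blocks * spec.page_size_bytes),
        "layers.1.attn": KVCacheTensor(size=num_blocks * spec.page_size_bytes),
    },
    # Both layers share one spec, so they land in a single KV cache group.
    groups=[
        KVCacheGroup(layer_names=["layers.0.attn", "layers.1.attn"],
                     kv_cache_spec=spec),
    ],
)

The hybrid example in the docstring above would instead carry three KVCacheGroup entries, one per group of two layers.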
""" - groups: List[List[str]] - """the KVCacheSpec of the model""" - kv_cache_spec: KVCacheSpec + groups: List[KVCacheGroup] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 9176701a6dcbc..c343d7c4023ad 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -178,11 +178,12 @@ def __init__( pin_memory=self.pin_memory) self.positions_np = self.positions_cpu.numpy() - self.kv_cache_config: KVCacheConfig = None # Set by initialize_kv_cache + self.kv_cache_config = cast(KVCacheConfig, + None) # Set by initialize_kv_cache # The following 3 variables depends on KVCacheConfig, assign a # placeholder value here and initialize them in `initialize_kv_cache``. - self.input_batch: InputBatch = None # Persistent batch. + self.input_batch = cast(InputBatch, None) # Persistent batch. self.slot_mapping_cpu = torch.zeros( (1, )) # Real shape: (num_kv_cache_groups, self.max_num_tokens) self.slot_mapping_np = self.slot_mapping_cpu.numpy() @@ -384,12 +385,8 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): out=self.input_ids_cpu[:total_num_scheduled_tokens]) # Calculate the slot mapping. - for i in range(len(self.kv_cache_config.groups)): - # the LayerSpec of all layers in the group is the same. Take the - # first one. - group_spec = self.kv_cache_config.kv_cache_spec[ - self.kv_cache_config.groups[i][0]] - block_size = group_spec.block_size + for i, kv_cache_group in enumerate(self.kv_cache_config.groups): + block_size = kv_cache_group.kv_cache_spec.block_size # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] # where K is the max_num_blocks_per_req and the block size is 2. @@ -439,12 +436,11 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): self.device, non_blocking=True) # layer_name -> AttentionMetadata attn_metadata: Dict[str, FlashAttentionMetadata] = {} - for i, layer_ids in enumerate(self.kv_cache_config.groups): - block_size = self.kv_cache_config.kv_cache_spec[ - layer_ids[0]].block_size + for group_id, kv_cache_group in enumerate(self.kv_cache_config.groups): + block_size = kv_cache_group.kv_cache_spec.block_size slot_mapping = self.slot_mapping_cpu[ - i, :total_num_scheduled_tokens].to(self.device, - non_blocking=True).long() + group_id, :total_num_scheduled_tokens].to( + self.device, non_blocking=True).long() # Prepare for cascade attention if needed. common_prefix_len = (scheduler_output.num_common_prefix_blocks[i] * @@ -543,8 +539,8 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): cu_suffix_kv_lens=cu_suffix_kv_lens, ) - for layer_id in layer_ids: - attn_metadata[layer_id] = attn_metadata_of_group + for layer_name in kv_cache_group.layer_names: + attn_metadata[layer_name] = attn_metadata_of_group # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial # request in the batch. While we should not sample any token from this # partial request, we do so for simplicity. We will ignore the sampled @@ -829,8 +825,8 @@ def maybe_pad_for_cudagraph( num_input_tokens = num_scheduled_tokens # update num_input_tokens in attn_metadata - for layer_names in self.kv_cache_config.groups: - layer_name = layer_names[0] + for kv_cache_group in self.kv_cache_config.groups: + layer_name = kv_cache_group.layer_names[0] # All layers in the group share the same attn_metadata object. # Only need to update the num_input_tokens once. 
attn_metadata[layer_name].num_input_tokens = num_input_tokens @@ -1027,20 +1023,22 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: kv_caches: Dict[str, torch.Tensor] = {} - for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items(): - tensor_config = kv_cache_config.tensors[layer_name] - assert tensor_config.size % layer_spec.page_size_bytes == 0 - num_blocks = tensor_config.size // layer_spec.page_size_bytes - if isinstance(layer_spec, FullAttentionSpec): - kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape( - num_blocks, layer_spec.block_size, layer_spec.num_kv_heads, - layer_spec.head_size) - dtype = layer_spec.dtype - kv_caches[layer_name] = torch.zeros(kv_cache_shape, - dtype=dtype, - device=self.device) - else: - raise NotImplementedError + for kv_cache_group in kv_cache_config.groups: + kv_cache_spec = kv_cache_group.kv_cache_spec + for layer_name in kv_cache_group.layer_names: + tensor_config = kv_cache_config.tensors[layer_name] + assert tensor_config.size % kv_cache_spec.page_size_bytes == 0 + num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes + if isinstance(kv_cache_spec, FullAttentionSpec): + kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape( + num_blocks, kv_cache_spec.block_size, + kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) + dtype = kv_cache_spec.dtype + kv_caches[layer_name] = torch.zeros(kv_cache_shape, + dtype=dtype, + device=self.device) + else: + raise NotImplementedError bind_kv_cache( kv_caches, @@ -1052,16 +1050,16 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: def _initialize_kv_related_buffers(self, kv_cache_config: KVCacheConfig) -> None: """ - Initialize data structures (e.g., InputBatch, slot mappings) that depend - on the kv cache configuration. + Initialize data structures (e.g., InputBatch, slot mappings) that + depend on the kv cache configuration. Args: kv_cache_config (KVCacheConfig): Configuration for the KV cache """ num_kv_cache_groups = len(kv_cache_config.groups) - min_block_size = min( - spec.block_size for spec in kv_cache_config.kv_cache_spec.values()) + min_block_size = min(group.kv_cache_spec.block_size + for group in kv_cache_config.groups) self.max_num_blocks_per_req = cdiv(self.max_model_len, min_block_size) self.input_batch = InputBatch( @@ -1081,7 +1079,7 @@ def _initialize_kv_related_buffers(self, pin_memory=self.pin_memory) self.slot_mapping_np = self.slot_mapping_cpu.numpy() - def get_kv_cache_spec(self) -> KVCacheSpec: + def get_kv_cache_spec(self) -> Dict[str, KVCacheSpec]: """ Generates the KVCacheSpec by parsing the kv cache format from each Attention module in the static forward context. @@ -1092,7 +1090,7 @@ def get_kv_cache_spec(self) -> KVCacheSpec: forward_ctx = self.vllm_config.compilation_config.static_forward_context block_size = self.vllm_config.cache_config.block_size - kv_cache_spec: KVCacheSpec = {} + kv_cache_spec: Dict[str, KVCacheSpec] = {} for layer_name, attn_module in forward_ctx.items(): # TODO: Support other attention modules, e.g., sliding window, # cross-attention, MLA. 
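One practical consequence of grouping worth spelling out: the runner-side buffers are sized for the most demanding group. A minimal runnable sketch of the arithmetic in `_initialize_kv_related_buffers` (toy sizes, two hypothetical groups with different block sizes):

import torch
from vllm.utils import cdiv

max_model_len = 8192
max_num_tokens = 2048
group_block_sizes = [16, 32]     # hypothetical per-group block sizes

# The block table must fit the group with the smallest block size, since it
# needs the most blocks per request.
min_block_size = min(group_block_sizes)
max_num_blocks_per_req = cdiv(max_model_len, min_block_size)   # 512

# One slot-mapping row per KV cache group.
slot_mapping_cpu = torch.zeros(len(group_block_sizes),
                               max_num_tokens,
                               dtype=torch.int32)
print(max_num_blocks_per_req, tuple(slot_mapping_cpu.shape))   # 512 (2, 2048)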
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 4fb4197f1822f..4f299029514ee 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -1,7 +1,7 @@ """A GPU worker class.""" import gc import os -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Dict, Optional import torch import torch.distributed @@ -162,7 +162,7 @@ def determine_available_memory(self) -> int: return int(available_kv_cache_memory) - def get_kv_cache_spec(self) -> KVCacheSpec: + def get_kv_cache_spec(self) -> Dict[str, KVCacheSpec]: return self.model_runner.get_kv_cache_spec() def initialize_cache(self, kv_cache_config: KVCacheConfig) -> None: From cfcf2b4b45828453295d7d0e845755793ada12c1 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sun, 19 Jan 2025 23:24:19 -0800 Subject: [PATCH 09/48] update comments Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_manager.py | 2 +- vllm/v1/core/kv_cache_utils.py | 4 ++-- vllm/v1/worker/block_table.py | 3 +++ 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 5297f79b7d476..f945a579bba2a 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -452,7 +452,7 @@ def _cache_full_blocks( to cache. full_blocks: The list of blocks to update hash metadata. prev_block: The previous block in the chain. - kv_cache_group_id: The KV cache group ID that the blocks belong to + kv_cache_group_id: The KV cache group that the blocks belong to """ num_cached_block_hashes = len( request.kv_block_hashes[kv_cache_group_id]) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 3251d75801d80..3706146b70d8f 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -436,7 +436,7 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, logger.info("# GPU blocks: %d", num_blocks) per_layer_size = page_size * num_blocks - layers_of_group = [[layer_name for layer_name in kv_cache_spec]] + grouped_layers = [[layer_name for layer_name in kv_cache_spec]] kv_cache_config = KVCacheConfig(num_blocks=num_blocks, tensors={ @@ -445,7 +445,7 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, for layer_name in kv_cache_spec }, groups=_create_kv_cache_groups( - kv_cache_spec, layers_of_group)) + kv_cache_spec, grouped_layers)) return kv_cache_config diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index e5e5eff65940e..4f4a72501bcf3 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -27,6 +27,9 @@ def __init__( self.device = device self.num_kv_cache_groups = num_kv_cache_groups + # NOTE: Pad the block table to the max possible number of blocks among + # all KV cache groups. This waste some memory if block_size of the + # groups differ. 
self.block_table = torch.zeros( (num_kv_cache_groups, max_num_reqs, max_num_blocks_per_req), device=self.device, From 4898973a7f93b06ee1077ba10ad8827bddfd6731 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 21 Jan 2025 06:26:58 -0800 Subject: [PATCH 10/48] init kv cache for group allocation Signed-off-by: Chen Zhang --- vllm/config.py | 2 +- vllm/v1/core/kv_cache_utils.py | 68 +++++++++++++++++++++- vllm/v1/kv_cache_interface.py | 44 +++++++++++++-- vllm/v1/worker/gpu_model_runner.py | 90 +++++++++++++++++++++--------- 4 files changed, 172 insertions(+), 32 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 4698a05020332..03b4b86ce1f63 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1040,7 +1040,7 @@ def _verify_prefix_caching(self) -> None: if not self.enable_prefix_caching: return - if self.sliding_window is not None: + if self.sliding_window is not None and not envs.VLLM_USE_V1: raise NotImplementedError( "Prefix caching is not supported with sliding window. " "Run with --disable-sliding-window to use prefix caching.") diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 3706146b70d8f..395b74045e527 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -1,12 +1,15 @@ """KV-Cache Utilities.""" +from collections import defaultdict from collections.abc import Sequence from dataclasses import dataclass +import math from typing import Any, Dict, List, NamedTuple, Optional, Tuple from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.v1.kv_cache_interface import (KVCacheConfig, KVCacheGroup, - KVCacheSpec, KVCacheTensor) + KVCacheNewTensor, KVCacheReuseTensor, + KVCacheSpec) from vllm.v1.request import Request logger = init_logger(__name__) @@ -374,6 +377,21 @@ def is_kv_cache_type_uniform(kv_cache_spec: Dict[str, KVCacheSpec]) -> bool: return len(layer_keys) == 1 +def is_kv_cache_page_size_uniform(kv_cache_spec: KVCacheSpec): + """ + Whether all layers in the given KVCacheSpec have the same page size. + + Args: + kv_cache_spec: The KVCacheSpec of each attention layer in the model + + Returns: + True if all layers have the same page size, False otherwise. + """ + + page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()} + return len(page_sizes) == 1 + + def _create_kv_cache_groups( kv_cache_spec: Dict[str, KVCacheSpec], grouped_layers: List[List[str]]) -> List[KVCacheGroup]: @@ -441,7 +459,7 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, kv_cache_config = KVCacheConfig(num_blocks=num_blocks, tensors={ layer_name: - KVCacheTensor(size=per_layer_size) + KVCacheNewTensor(size=per_layer_size) for layer_name in kv_cache_spec }, groups=_create_kv_cache_groups( @@ -449,6 +467,47 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, return kv_cache_config +def _get_kv_cache_config_uniform_page_size( + vllm_config: VllmConfig, kv_cache_spec: Dict[str, KVCacheSpec], + available_memory: int) -> KVCacheConfig: + # Grouped allocation + # TODO(Chen): explain it, need test + + # Group all layers by type_id + same_type_layers: Dict[str, List[str]] = defaultdict(list) + for layer_name, layer_spec in kv_cache_spec.items(): + same_type_layers[layer_spec.type_id].append(layer_name) + + # Split each group into smaller groups, to make the number of layers in + # each group identical + # E.g., 2 full attention layers and 4 sliding window attention layers, + # split from (full * 2), (sw * 4) to (full * 2), (sw * 2), (sw * 2). 
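Since the TODO above asks for an explanation, here is the splitting step in isolation as a worked example. The helper is a simplified rendering of the GCD-based chunking that follows; layer names are invented, and the real code keys the dict by KVCacheSpec.type_id:

import math
from typing import Dict, List


def split_into_equal_groups(
        same_type_layers: Dict[str, List[str]]) -> List[List[str]]:
    # Chunk every per-type layer list into pieces of the GCD of all list
    # lengths, so each resulting group holds the same number of layers and
    # later groups can reuse the first group's tensors one-to-one.
    group_size = math.gcd(
        *(len(layers) for layers in same_type_layers.values()))
    grouped: List[List[str]] = []
    for layers in same_type_layers.values():
        for i in range(0, len(layers), group_size):
            grouped.append(layers[i:i + group_size])
    return grouped


print(split_into_equal_groups({
    "full_attention": ["full.0", "full.1"],
    "sliding_window": ["sw.0", "sw.1", "sw.2", "sw.3"],
}))
# [['full.0', 'full.1'], ['sw.0', 'sw.1'], ['sw.2', 'sw.3']]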
+ group_size_gcd = math.gcd( + *[len(layers) for layers in same_type_layers.values()]) + grouped_layers = [] + for layers in same_type_layers.values(): + for i in range(0, len(layers), group_size_gcd): + grouped_layers.append(layers[i:i + group_size_gcd]) + + # TODO: explain it + kv_cache_spec_first_group = { + layer_name: kv_cache_spec[layer_name] + for layer_name in grouped_layers[0] + } + kv_cache_config = _get_kv_cache_config_uniform_type( + vllm_config, kv_cache_spec_first_group, available_memory) + + for layers in grouped_layers[1:]: + for layer_name, layer_name_first_group in zip(layers, + grouped_layers[0]): + kv_cache_config.tensors[layer_name] = KVCacheReuseTensor( + reused_layer_name=layer_name_first_group) + + kv_cache_config.groups = _create_kv_cache_groups(kv_cache_spec, + grouped_layers) + return kv_cache_config + + def get_kv_cache_config(vllm_config: VllmConfig, kv_cache_spec: Dict[str, KVCacheSpec], available_memory: int) -> KVCacheConfig: @@ -470,5 +529,10 @@ def get_kv_cache_config(vllm_config: VllmConfig, # Allocate the same amount of memory for each layer. return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec, available_memory) + elif is_kv_cache_page_size_uniform(kv_cache_spec): + # TODO: add comments + return _get_kv_cache_config_uniform_page_size(vllm_config, + kv_cache_spec, + available_memory) else: raise NotImplementedError diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 3f6dffacd0ca3..2776d2c74e3a7 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -73,15 +73,51 @@ def bytes_for_tokens(self, num_tokens: int) -> int: @dataclass -class KVCacheTensor: +class SlidingWindowSpec(KVCacheSpec): + num_kv_heads: int + head_size: int + dtype: torch.dtype + sliding_window: int + + @property + def type_id(self) -> str: + return f"sliding_window_{self.sliding_window}_{self.block_size}_{self.page_size_bytes}" # noqa + + @property + def page_size_bytes(self) -> int: + return 2 * self.block_size * self.num_kv_heads * self.head_size \ + * get_dtype_size(self.dtype) + + def bytes_for_tokens(self, num_tokens: int) -> int: + num_tokens = min(num_tokens, self.sliding_window) + return cdiv(num_tokens, self.block_size) * self.page_size_bytes + + +@dataclass +class KVCacheTensorBase: """ A dataclass for specifying how the workers should initialize the KV cache - for a layer. Only contains the size of KV cache for that layer for now. Will - be extended to support multiple layers sharing the same memory pool. + for a layer. + """ + pass + + +@dataclass +class KVCacheNewTensor(KVCacheTensorBase): + """ + Initialize the KV cache with a tensor of `size` bytes. """ size: int # The size of KV cache Tensor in bytes +@dataclass +class KVCacheReuseTensor(KVCacheTensorBase): + """ + Reuse the KV cache tensor of `layer_name` for the current layer. + """ + reused_layer_name: str + + @dataclass class KVCacheGroup: """ @@ -101,7 +137,7 @@ class KVCacheConfig: """The number of KV cache blocks""" num_blocks: int """layer_name -> how to initialize KV cache for that layer""" - tensors: Dict[str, KVCacheTensor] + tensors: Dict[str, KVCacheTensorBase] """ A list of kv-cache groups. 
Each group includes a set of layers with the same kv-cache spec, and the total page_size of layers inside a group diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index c343d7c4023ad..d1290d944cc00 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -25,7 +25,8 @@ from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.engine.mm_input_mapper import MMInputMapperClient from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheSpec) + KVCacheNewTensor, KVCacheReuseTensor, + KVCacheSpec, SlidingWindowSpec) from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.utils import bind_kv_cache @@ -1012,34 +1013,65 @@ def capture_model(self) -> None: logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", elapsed_time, cuda_graph_size / (1 << 30)) - def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: - """ - Initialize KV cache based on `kv_cache_config`. - Args: - kv_cache_config: Configuration for the KV cache, including the KV - cache size of each layer - """ - self.kv_cache_config = kv_cache_config - + def _initialize_kv_cache_buffer( + self, kv_cache_config: KVCacheConfig) -> Dict[str, torch.Tensor]: + # TODO: add docstring + kv_cache_raw_tensors: Dict[str, torch.Tensor] = {} + for layer_name, tensor_config in kv_cache_config.tensors.items(): + if isinstance(tensor_config, KVCacheNewTensor): + # A new tensor with `tensor_config.size` bytes + kv_cache_raw_tensors[layer_name] = torch.zeros( + tensor_config.size, dtype=torch.int8, device=self.device) + for layer_name, tensor_config in kv_cache_config.tensors.items(): + if isinstance(tensor_config, KVCacheReuseTensor): + # Reuse the tensor from `kv_cache_raw_tensors` + kv_cache_raw_tensors[layer_name] = kv_cache_raw_tensors[ + tensor_config.reused_layer_name] + assert len(kv_cache_raw_tensors) == len( + kv_cache_config.tensors), "Some layers are not initialized" + return kv_cache_raw_tensors + + def _setup_kv_cache_shapes( + self, + kv_cache_config: KVCacheConfig, + kv_cache_raw_tensors: Dict[str, torch.Tensor], + ) -> Dict[str, torch.Tensor]: + # TODO: add docstring kv_caches: Dict[str, torch.Tensor] = {} - for kv_cache_group in kv_cache_config.groups: kv_cache_spec = kv_cache_group.kv_cache_spec for layer_name in kv_cache_group.layer_names: - tensor_config = kv_cache_config.tensors[layer_name] - assert tensor_config.size % kv_cache_spec.page_size_bytes == 0 - num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes - if isinstance(kv_cache_spec, FullAttentionSpec): + raw_tensor = kv_cache_raw_tensors[layer_name] + assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0 + num_blocks = raw_tensor.numel( + ) // kv_cache_spec.page_size_bytes + if isinstance(kv_cache_spec, + (FullAttentionSpec, SlidingWindowSpec)): kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape( num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) dtype = kv_cache_spec.dtype - kv_caches[layer_name] = torch.zeros(kv_cache_shape, - dtype=dtype, - device=self.device) + kv_caches[layer_name] = kv_cache_raw_tensors[ + layer_name].view(dtype).view(kv_cache_shape) else: raise NotImplementedError + return kv_caches + def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: + """ + Initialize KV cache based on `kv_cache_config`. 
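A small sketch of what `_initialize_kv_cache_buffer` and `_setup_kv_cache_shapes` above amount to: KV cache memory is first allocated as flat byte tensors and then reinterpreted per layer, and a layer backed by KVCacheReuseTensor simply views the same bytes again. Sizes below are tiny and hypothetical, and the final shape is only illustrative:

import torch

num_blocks, block_size, num_kv_heads, head_size = 4, 16, 2, 8
page_size_bytes = 2 * block_size * num_kv_heads * head_size * 2   # K+V, fp16
raw = torch.zeros(num_blocks * page_size_bytes, dtype=torch.int8)

# Give the raw bytes the backend's dtype and KV cache shape.
shaped = raw.view(torch.float16).view(2, num_blocks, block_size,
                                      num_kv_heads, head_size)

# A reusing layer views the very same storage, so no extra memory is spent.
reused = raw.view(torch.float16).view(2, num_blocks, block_size,
                                      num_kv_heads, head_size)
assert shaped.data_ptr() == reused.data_ptr()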
+ Args: + kv_cache_config: Configuration for the KV cache, including the KV + cache size of each layer + # TODO: two "buffer" is confusing + """ + self.kv_cache_config = kv_cache_config + # Initialize the memory buffer for KV cache + kv_cache_raw_tensors = self._initialize_kv_cache_buffer( + kv_cache_config) + # Change the memory buffer to the desired shape + kv_caches = self._setup_kv_cache_shapes(kv_cache_config, + kv_cache_raw_tensors) bind_kv_cache( kv_caches, self.vllm_config.compilation_config.static_forward_context, @@ -1096,12 +1128,21 @@ def get_kv_cache_spec(self) -> Dict[str, KVCacheSpec]: # cross-attention, MLA. assert isinstance(attn_module, Attention) if attn_module.attn_type == AttentionType.DECODER: - kv_cache_spec[layer_name] = FullAttentionSpec( - block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=attn_module.dtype, - ) + if attn_module.sliding_window is not None: + kv_cache_spec[layer_name] = SlidingWindowSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=attn_module.dtype, + sliding_window=attn_module.sliding_window, + ) + else: + kv_cache_spec[layer_name] = FullAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=attn_module.dtype, + ) elif attn_module.attn_type in (AttentionType.ENCODER, AttentionType.ENCODER_ONLY): # encoder-only attention does not need KV cache. @@ -1111,5 +1152,4 @@ def get_kv_cache_spec(self) -> Dict[str, KVCacheSpec]: else: raise ValueError( f"Unknown attention type: {attn_module.attn_type}") - return kv_cache_spec From ef9dc9d1cca6edb10f0eb9abd7c4effdbf17e8c1 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 21 Jan 2025 08:51:19 -0800 Subject: [PATCH 11/48] can run, result a little strange Signed-off-by: Chen Zhang --- .../specialized_manager.py | 152 ++++++ vllm/v1/core/hybrid_cache_manager/utils.py | 67 +++ vllm/v1/core/kv_cache_manager.py | 470 ++++++++++++------ vllm/v1/core/kv_cache_utils.py | 21 +- vllm/v1/core/scheduler.py | 7 +- vllm/v1/engine/core.py | 17 +- 6 files changed, 561 insertions(+), 173 deletions(-) create mode 100644 vllm/v1/core/hybrid_cache_manager/specialized_manager.py create mode 100644 vllm/v1/core/hybrid_cache_manager/utils.py diff --git a/vllm/v1/core/hybrid_cache_manager/specialized_manager.py b/vllm/v1/core/hybrid_cache_manager/specialized_manager.py new file mode 100644 index 0000000000000..59324b98422e3 --- /dev/null +++ b/vllm/v1/core/hybrid_cache_manager/specialized_manager.py @@ -0,0 +1,152 @@ +from abc import ABC, abstractmethod +from collections import deque +from dataclasses import dataclass +from typing import Callable, Dict, List, Optional, Tuple, TypedDict +from vllm.utils import cdiv +from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig, KVCacheSpec, SlidingWindowSpec +from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlock +from vllm.v1.core.hybrid_cache_manager.utils import ComputedTokenRange, ComputedTokens +from vllm.v1.utils import ConstantList + + +@dataclass +class MemoryPoolOperations: + get_cached_block: Callable[[BlockHashType], Optional[KVCacheBlock]] + get_null_block: Callable[[], KVCacheBlock] + + +class SpecializedManager(ABC): + block_size: int + max_num_blocks_per_req: int + + def __init__( + self, + layer_spec: KVCacheSpec, + memory_pool_operations: MemoryPoolOperations, + ) -> None: + self.block_size = layer_spec.block_size + self.memory_pool_operations = 
memory_pool_operations + + @abstractmethod + def get_computed_blocks_and_tokens( + self, block_hashes: ConstantList[BlockHashType] + ) -> Tuple[List[KVCacheBlock], ComputedTokens]: + raise NotImplementedError + + @abstractmethod + def get_num_new_blocks(self, num_computed_tokens: int, + num_append_tokens: int, + num_allocated_blocks: int) -> int: + raise NotImplementedError + + @abstractmethod + def remove_dropped_blocks(self, block_table: List[KVCacheBlock], + num_computed_tokens: int): + # update block_table inplace + raise NotImplementedError + + +class FullAttentionManager(SpecializedManager): + + def get_computed_blocks_and_tokens( + self, block_hashes: ConstantList[BlockHashType] + ) -> Tuple[List[KVCacheBlock], ComputedTokens]: + computed_blocks: List[KVCacheBlock] = [] + for block_hash in block_hashes: + # block_hashes is a chain of block hashes. If a block hash is not + # in the cached_block_hash_to_id, the following block hashes are + # not computed yet for sure. + if cached_block := self.memory_pool_operations.get_cached_block( + block_hash): + computed_blocks.append(cached_block) + else: + break + if len(computed_blocks) == 0: + return [], [] + else: + return [ + ComputedTokenRange(0, + len(computed_blocks) * self.block_size) + ], computed_blocks + + def get_num_new_blocks(self, num_computed_tokens: int, + num_append_tokens: int, + num_allocated_blocks: int) -> int: + num_required_blocks = cdiv(num_computed_tokens + num_append_tokens, + self.block_size) + num_new_blocks = num_required_blocks - num_allocated_blocks + return num_new_blocks + + def remove_dropped_blocks(self, block_table: List[KVCacheBlock], + num_computed_tokens: int) -> List[KVCacheBlock]: + return [] + + +class SlidingWindowManager(FullAttentionManager): + + def __init__(self, layer_spec: SlidingWindowSpec, + memory_pool_operations: MemoryPoolOperations): + super().__init__(layer_spec, memory_pool_operations) + # +1 due to not aligned + self.num_block_sliding_window = cdiv(layer_spec.sliding_window, + self.block_size) + 1 + self._null_block = memory_pool_operations.get_null_block() + + def get_computed_blocks_and_tokens( + self, block_hashes: ConstantList[BlockHashType] + ) -> Tuple[List[KVCacheBlock], ComputedTokens]: + # TODO: check the hit every num_block_sliding_window blocks, to optimize + # the time complexity from O(num_block) to + # O(num_block / num_block_sliding_window) + O(num_computed_block), + # which is good for low cache hit rate senarios. 
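A quick sanity check on the `+ 1` in `num_block_sliding_window` above: a window of W tokens that is not aligned to a block boundary can straddle one extra block, so cdiv(W, block_size) + 1 block slots have to stay live. Toy numbers (block size 16, window 40):

from math import ceil

block_size, sliding_window = 16, 40


def blocks_touched(start_tok: int, end_tok: int) -> int:
    # Blocks overlapped by the half-open token range [start_tok, end_tok).
    return (end_tok - 1) // block_size - start_tok // block_size + 1


print(blocks_touched(0, 40))                    # 3: window aligned to a block
print(blocks_touched(15, 55))                   # 4: worst-case misalignment
print(ceil(sliding_window / block_size) + 1)    # 4: what the manager keeps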
+ start = 0 + ranges = [] + computed_blocks: List[KVCacheBlock] = [] + + for i, block_hash in enumerate(block_hashes): + if cached_block := self.memory_pool_operations.get_cached_block( + block_hash): + computed_blocks.append(cached_block) + else: + if start == 0: + ranges.append( + ComputedTokenRange(start * self.block_size, + i * self.block_size)) + elif i - start >= self.num_block_sliding_window: + ranges.append((ComputedTokenRange( + (start + self.num_block_sliding_window) * + self.block_size, i * self.block_size))) + computed_blocks.append( + self.memory_pool_operations.get_null_block()) + start = i + 1 + return ranges, computed_blocks + + def remove_dropped_blocks(self, block_table: List[KVCacheBlock], + num_computed_tokens: int) -> List[KVCacheBlock]: + num_block_should_free = cdiv(num_computed_tokens, self.block_size) - \ + self.num_block_sliding_window + removed_blocks = deque() + for i in range(num_block_should_free - 1, -1, -1): + if block_table[i] == self._null_block: + break + removed_blocks.appendleft(block_table[i]) + block_table[i] = self._null_block + return removed_blocks + + +spec_manager_map = { + FullAttentionSpec: FullAttentionManager, + SlidingWindowSpec: SlidingWindowManager +} + + +def get_managers( + kv_cache_config: KVCacheConfig, + memory_pool_operations: MemoryPoolOperations +) -> List[SpecializedManager]: + managers: List[SpecializedManager] = [] + for g in kv_cache_config.groups: + manager_class = spec_manager_map[type(g.kv_cache_spec)] + manager = manager_class(g.kv_cache_spec, memory_pool_operations) + managers.append(manager) + return managers diff --git a/vllm/v1/core/hybrid_cache_manager/utils.py b/vllm/v1/core/hybrid_cache_manager/utils.py new file mode 100644 index 0000000000000..b22a3e783fb0a --- /dev/null +++ b/vllm/v1/core/hybrid_cache_manager/utils.py @@ -0,0 +1,67 @@ +from dataclasses import dataclass +from typing import List + + +@dataclass +class ComputedTokenRange: + """ + [start, end) + """ + start: int + end: int + + +ComputedTokens = List[ComputedTokenRange] + + +def intersect_two_ranges( + a: List[ComputedTokenRange], + b: List[ComputedTokenRange]) -> List[ComputedTokenRange]: + """ + Intersect two sorted lists of ComputedTokenRange intervals. + + Args: + a: List of intervals + b: List of intervals + Returns: + List of intervals that are intersections of a and b + """ + i, j = 0, 0 + result = [] + + while i < len(a) and j < len(b): + overlap_start = max(a[i].start, b[j].start) + overlap_end = min(a[i].end, b[j].end) + + if overlap_start <= overlap_end: + result.append(ComputedTokenRange(overlap_start, overlap_end)) + + if a[i].end < b[j].end: + i += 1 + else: + j += 1 + + return result + + +def intersect_ranges( + ranges: List[List[ComputedTokenRange]]) -> List[ComputedTokenRange]: + """ + Intersect multiple lists of ComputedTokenRange intervals, each is sorted. 
+ + Args: + ranges: A list of lists of intervals + Returns: + A list of intervals representing the intersection of all ranges + """ + if not ranges: + return [] + + current_intersection = ranges[0] + for i in range(1, len(ranges)): + current_intersection = intersect_two_ranges(current_intersection, + ranges[i]) + if not current_intersection: + break + + return current_intersection diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index f945a579bba2a..c0a4c071a013c 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -1,14 +1,18 @@ from collections import defaultdict +import math from typing import Dict, Iterable, List, Optional, Tuple from vllm.logger import init_logger from vllm.utils import cdiv +from vllm.v1.core.hybrid_cache_manager.specialized_manager import MemoryPoolOperations, get_managers from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue, KVCacheBlock, KVCacheBlocks, ReqKVCacheBlocks, generate_block_hash_extra_keys, hash_block_tokens, hash_request_tokens) +from vllm.v1.core.hybrid_cache_manager.utils import ComputedTokens, intersect_ranges +from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import Request, RequestStatus logger = init_logger(__name__) @@ -23,18 +27,18 @@ class KVCacheManager: def __init__( self, - block_size: int, - num_gpu_blocks: int, + kv_cache_config: KVCacheConfig, max_model_len: int, - sliding_window: Optional[int] = None, enable_caching: bool = True, num_preallocate_tokens: int = 64, ) -> None: - self.block_size = block_size - self.num_gpu_blocks = num_gpu_blocks + self.kv_cache_config = kv_cache_config + self.num_gpu_blocks = kv_cache_config.num_blocks self.max_model_len = max_model_len - self.max_num_blocks_per_req = cdiv(max_model_len, block_size) - self.sliding_window = sliding_window + self.max_num_blocks_per_req = [ + cdiv(max_model_len, g.kv_cache_spec.block_size) + for g in kv_cache_config.groups + ] self.enable_caching = enable_caching # NOTE(woosuk): To avoid frequent block allocation, we preallocate some # blocks for each request. For example, when a request reaches the end @@ -46,12 +50,25 @@ def __init__( # the request gets N empty blocks, it starts to use the blocks without # further allocation. When it uses up all the N empty blocks, it gets # N new empty blocks. + # TODO: update comment self.num_preallocate_tokens = num_preallocate_tokens - self.num_preallocate_blocks = cdiv(num_preallocate_tokens, block_size) + # TODO: min or max? + self.num_preallocate_blocks = cdiv( + num_preallocate_tokens, + min(g.kv_cache_spec.block_size for g in kv_cache_config.groups)) + + self._null_block: KVCacheBlock = KVCacheBlock(-1) + + # TODO(Chen): add comments + self.managers = get_managers( + kv_cache_config, + MemoryPoolOperations(get_cached_block=self._get_cached_block, + get_null_block=self.get_null_block), + ) # A Block pool of all kv-cache blocks. self.block_pool: List[KVCacheBlock] = [ - KVCacheBlock(idx) for idx in range(num_gpu_blocks) + KVCacheBlock(idx) for idx in range(self.num_gpu_blocks) ] # Free block queue that constructs and manipulates a doubly linked # list of free blocks (including eviction candidates when caching is @@ -73,10 +90,7 @@ def __init__( # Mapping from request ID to blocks to track the blocks allocated # for each request, so that we can free the blocks when the request # is finished. 
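To see how these pieces combine: each specialized manager reports which token ranges it has cached, and the KVCacheManager intersects them so the chosen prefix is valid for every group. A toy run against the helpers defined in hybrid_cache_manager/utils.py above (the numbers are invented):

from vllm.v1.core.hybrid_cache_manager.utils import (ComputedTokenRange,
                                                     intersect_ranges)

# Group 0 (full attention) has the first 48 tokens cached; group 1 (sliding
# window) only has tokens [0, 32) and [64, 96).
per_group_ranges = [
    [ComputedTokenRange(0, 48)],
    [ComputedTokenRange(0, 32), ComputedTokenRange(64, 96)],
]
print(intersect_ranges(per_group_ranges))
# [ComputedTokenRange(start=0, end=32)], so a 32-token prefix is reusable by
# both groups.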
- # KVCacheManager only supports models with one kv cache group, so we - # save KVCachedBlocks of that group instead of ReqKVCacheBlocks for - # simplicity. - self.req_to_blocks: Dict[str, KVCacheBlocks] = {} + self.req_to_blocks: Dict[str, ReqKVCacheBlocks] = {} def get_computed_blocks(self, request: Request) -> Tuple[ReqKVCacheBlocks, int]: @@ -93,31 +107,44 @@ def get_computed_blocks(self, """ if not self.enable_caching: # Prefix caching is disabled. - return [[]], 0 - - computed_blocks = [] + return [[] for _ in self.managers], 0 # The block hashes for the request may already be computed # if the request was preempted and resumed. if not request.kv_block_hashes: - request.set_kv_block_hashes( - [hash_request_tokens(self.block_size, request)]) - block_hashes = request.kv_block_hashes[0] - - for block_hash in block_hashes: - # block_hashes is a chain of block hashes. If a block hash is not - # in the cached_block_hash_to_id, the following block hashes are - # not computed yet for sure. - if cached_block := self._get_cached_block(block_hash): - computed_blocks.append(cached_block) - else: - break - - # NOTE(woosuk): Since incomplete blocks are not eligible for - # sharing, `num_computed_tokens` is always a multiple of - # `block_size`. - num_computed_tokens = len(computed_blocks) * self.block_size - return [computed_blocks], num_computed_tokens + request.set_kv_block_hashes([ + hash_request_tokens(manager.block_size, request, i) + for i, manager in enumerate(self.managers) + ]) + + computed_blocks: ReqKVCacheBlocks = [] # group_id->[blocks] + computed_tokens: List[ComputedTokens] = [] # group_id->ComputedTokens + block_hashes = request.kv_block_hashes + for i, manager in enumerate(self.managers): + computed_tokens_i, computed_blocks_i = ( + manager.get_computed_blocks_and_tokens(block_hashes[i])) + computed_blocks.append(computed_blocks_i) + computed_tokens.append(computed_tokens_i) + + if len(self.kv_cache_config.groups) == 1: + # If there is only one group, we return the computed blocks and + # tokens directly. + # NOTE(woosuk): Since incomplete blocks are not eligible for + # sharing, `num_computed_tokens` is always a multiple of + # `block_size`. + num_computed_tokens = len(computed_blocks[0]) * self.block_size + else: + # find the common cached prefix of all groups. This path also works + # for the single group case, but it is less efficient. + num_computed_tokens = self._get_common_computed_tokens( + computed_tokens) + + for i, manager in enumerate(self.managers): + computed_blocks[i] = computed_blocks[:num_computed_tokens // + manager.block_size] + self._free_blocks_for_sliding_window(computed_blocks, + num_computed_tokens) + return computed_blocks, num_computed_tokens def append_slots( self, @@ -136,71 +163,90 @@ def append_slots( The new blocks if new blocks are allocated, or None if new blocks are required but cannot be allocated. 
""" - num_required_blocks = cdiv(request.num_computed_tokens + num_tokens, - self.block_size) + # we can free blocks even if we cannot schedule it + self._free_blocks_for_sliding_window( + self.req_to_blocks[request.request_id], + request.num_computed_tokens) req_blocks = self.req_to_blocks[request.request_id] - num_new_blocks = num_required_blocks - len(req_blocks) - if num_new_blocks > self.free_block_queue.num_free_blocks: + num_new_blocks = [ + manager.get_num_new_blocks(request.num_computed_tokens, num_tokens, + len(req_blocks_of_group)) + for manager, req_blocks_of_group in zip(self.managers, req_blocks) + ] + total_new_blocks = sum(max(x, 0) for x in num_new_blocks) + + if total_new_blocks > self.free_block_queue.num_free_blocks: # Need to allocate new blocks due to insufficient pre-allocated # slots, but we cannot allocate new blocks due to the limit. return None - if num_new_blocks <= 0: - # No new block is needed. - new_blocks = [] - else: - # Get new blocks from the free block pool considering - # preallocated blocks. - num_new_blocks = min( - num_new_blocks + self.num_preallocate_blocks, - self.free_block_queue.num_free_blocks, - # Should not exceed the maximum number of blocks per request. - # This is especially because the block table has the shape - # [..., max_num_blocks_per_req]. - # TODO(woosuk): Check and reject requests if - # num_prompt_tokens + max_tokens > max_model_len. - self.max_num_blocks_per_req - len(req_blocks), - ) - assert num_new_blocks > 0 + # TODO(Chen): add comments + num_preallocate_blocks = min( + self.num_preallocate_blocks, + (self.free_block_queue.num_free_blocks - total_new_blocks) // + len(self.managers)) - new_blocks = self._get_new_blocks(num_new_blocks) - req_blocks.extend(new_blocks) + new_blocks = [] - if not self.enable_caching: - return [new_blocks] - - num_computed_full_blocks = (request.num_computed_tokens // - self.block_size) - - # NOTE(rickyx): We are assuming the `num_tokens` are actual - # tokens rather than lookahead slots (e.g. for speculative decoding). - # TODO(rickyx): When supporting speculative decoding, we will need to - # differentiate between them so that we can know how many blocks are - # full after appending the actual tokens. - num_full_blocks_after_append = (request.num_computed_tokens + - num_tokens) // self.block_size - assert num_full_blocks_after_append <= len(req_blocks) - - new_full_blocks = req_blocks[ - num_computed_full_blocks:num_full_blocks_after_append] - if new_full_blocks: - self._cache_full_blocks( - request=request, - blk_start_idx=num_computed_full_blocks, - full_blocks=new_full_blocks, - prev_block=req_blocks[num_computed_full_blocks - 1] - if num_computed_full_blocks >= 1 else None, - kv_cache_group_id=0, - ) + for i in range(len(self.kv_cache_config.groups) + ): # TODO: self.num_kv_cache_groups + if num_new_blocks[i] <= 0: + # No new block is needed. + new_blocks.append([]) + else: + # Get new blocks from the free block pool considering + # preallocated blocks. + num_block_to_allocate = min( + num_new_blocks[i] + num_preallocate_blocks, + # Should not exceed the maximum number of blocks per request. + # This is especially because the block table has the shape + # [..., max_num_blocks_per_req]. + # TODO(woosuk): Check and reject requests if + # num_prompt_tokens + max_tokens > max_model_len. 
+ self.max_num_blocks_per_req[i] - len(req_blocks[i]), + ) + assert num_block_to_allocate > 0 + + new_blocks_of_group = self._get_new_blocks(num_new_blocks) + new_blocks.append(new_blocks_of_group) + req_blocks[i].extend(new_blocks) - return [new_blocks] + if not self.enable_caching: + return new_blocks + + for i, manager in enumerate(self.managers): + num_computed_full_blocks = (request.num_computed_tokens // + manager.block_size) + + # NOTE(rickyx): We are assuming the `num_tokens` are actual tokens + # rather than lookahead slots (e.g. for speculative decoding). + # TODO(rickyx): When supporting speculative decoding, we will need + # to differentiate between them so that we can know how many blocks + # are full after appending the actual tokens. + num_full_blocks_after_append = (request.num_computed_tokens + + num_tokens) // manager.block_size + assert num_full_blocks_after_append <= len(req_blocks) + + new_full_blocks = req_blocks[i][ + num_computed_full_blocks:num_full_blocks_after_append] + if new_full_blocks: + self._cache_full_blocks( + request=request, + blk_start_idx=num_computed_full_blocks, + full_blocks=new_full_blocks, + prev_block=req_blocks[i][num_computed_full_blocks - 1] + if num_computed_full_blocks >= 1 else None, + kv_cache_group_id=i, + ) + + return new_blocks def allocate_slots( self, request: Request, num_tokens: int, - computed_blocks_all_groups: ReqKVCacheBlocks, + computed_blocks: ReqKVCacheBlocks, ) -> Optional[ReqKVCacheBlocks]: """Allocate slots for a new request. @@ -208,8 +254,7 @@ def allocate_slots( request: The request to allocate slots. num_tokens: The number of tokens to allocate. Note that this does not include the tokens that have already been computed. - computed_blocks_all_groups: The computed blocks. Should contain - only one KV cache group. + computed_blocks: The computed blocks. Returns: The new blocks if new blocks are allocated, or None if new blocks @@ -219,16 +264,23 @@ def allocate_slots( raise ValueError( f"num_tokens must be greater than 0, got {num_tokens}") - assert len(computed_blocks_all_groups) == 1 - computed_blocks = computed_blocks_all_groups[0] # If a computed block of a request is an eviction candidate (in the # free queue and ref_cnt == 0), it cannot be counted as a free block # when allocating this request. - num_evictable_computed_blocks = sum(1 for blk in computed_blocks + num_evictable_computed_blocks = sum(1 for blk_group in computed_blocks + for blk in blk_group if blk.ref_cnt == 0) - num_required_blocks = cdiv(num_tokens, self.block_size) - if (num_required_blocks > self.free_block_queue.num_free_blocks - + num_new_blocks = [ + manager.get_num_new_blocks(request.num_computed_tokens, num_tokens, + len(computed_blocks_of_group)) + for manager, computed_blocks_of_group in zip( + self.managers, computed_blocks) + ] + + total_new_blocks = sum(max(x, 0) for x in num_new_blocks) + + if (total_new_blocks > self.free_block_queue.num_free_blocks - num_evictable_computed_blocks): # Cannot allocate new blocks. return None @@ -241,65 +293,137 @@ def allocate_slots( "Computed blocks should be empty when " "prefix caching is disabled") - # Determine the number of new blocks to allocate considering - # preallocated blocks. - num_new_blocks = min( - num_required_blocks + self.num_preallocate_blocks, - self.free_block_queue.num_free_blocks, - # Should not exceed the maximum number of blocks per request. - # This is especially because the block table has the shape - # [..., max_num_blocks_per_req]. 
- # TODO(woosuk): Check and reject requests if - # num_prompt_tokens + max_tokens > max_model_len. - self.max_num_blocks_per_req - len(computed_blocks), - ) - assert num_new_blocks > 0 + # TODO(Chen): add comments + num_preallocate_blocks = min( + self.num_preallocate_blocks, + (self.free_block_queue.num_free_blocks - total_new_blocks) // + len(self.managers)) - # Concatenate the computed block IDs and the new block IDs. - new_blocks = self._get_new_blocks(num_new_blocks) - self.req_to_blocks[request.request_id] = computed_blocks + new_blocks + new_blocks = [] + req_to_blocks = [] - if not self.enable_caching: - return [new_blocks] - - num_computed_tokens = len(computed_blocks) * self.block_size - num_full_blocks = (num_computed_tokens + num_tokens) // self.block_size - - new_full_blocks = self.req_to_blocks[ - request.request_id][len(computed_blocks):num_full_blocks] - if new_full_blocks: - self._cache_full_blocks( - request=request, - blk_start_idx=len(computed_blocks), - # The new full blocks are the full blocks that are not computed. - full_blocks=new_full_blocks, - prev_block=computed_blocks[-1] if computed_blocks else None, - kv_cache_group_id=0, + for i in range(len(self.managers)): + # Determine the number of new blocks to allocate considering + # preallocated blocks. + num_block_to_allocate = min( + num_new_blocks[i] + num_preallocate_blocks, + # Should not exceed the maximum number of blocks per request. + # This is especially because the block table has the shape + # [..., max_num_blocks_per_req]. + # TODO(woosuk): Check and reject requests if + # num_prompt_tokens + max_tokens > max_model_len. + self.max_num_blocks_per_req[i] - len(computed_blocks[i]), ) + assert num_block_to_allocate > 0 - return [new_blocks] + new_blocks_of_group = self._get_new_blocks(num_block_to_allocate) + new_blocks.append(new_blocks_of_group) + # Concatenate the computed block IDs and the new block IDs. + req_to_blocks.append(computed_blocks[i] + new_blocks_of_group) - def free(self, request: Request) -> None: - """Free the blocks allocated for the request. - When caching is enabled, we free the blocks in reverse order so that - the tail blocks are evicted first. + self.req_to_blocks[request.request_id] = req_to_blocks - Args: - request: The request to free the blocks. - """ - # Default to [] in case a request is freed (aborted) before alloc. - blocks = self.req_to_blocks.pop(request.request_id, []) + if not self.enable_caching: + return new_blocks + + for i, manager in enumerate(self.managers): + num_computed_tokens = len(computed_blocks) * manager.block_size + num_full_blocks = (num_computed_tokens + + num_tokens) // manager.block_size + + new_full_blocks = req_to_blocks[i][len(computed_blocks + ):num_full_blocks] + if new_full_blocks: + self._cache_full_blocks( + request=request, + blk_start_idx=len(computed_blocks), + # The new full blocks are the full blocks that are not computed. + full_blocks=new_full_blocks, + prev_block=computed_blocks[-1] + if computed_blocks else None, + kv_cache_group_id=i, + ) + + return new_blocks + + def _get_ordered_blocks_one_kv_cache_group( + self, blocks: KVCacheBlocks) -> Iterable[KVCacheBlock]: ordered_blocks: Iterable[KVCacheBlock] = blocks if self.enable_caching: # Free blocks in reverse order so that the tail blocks are # freed first. ordered_blocks = reversed(blocks) + return ordered_blocks + + def _get_ordered_blocks_multiple_kv_cache_groups( + self, blocks: ReqKVCacheBlocks) -> Iterable[KVCacheBlock]: + # Fast path: if all blocks are empty, return. 
This will happen during + # append_slots + blocks = [b for b in blocks if len(b) > 0] + if len(blocks) == 0: + return [] + # Free blocks in reverse order so that the tail blocks are + # freed first. + if self.enable_caching: + # TODO(Chen): add comments + # merge blocks from different groups based on the block size + block_size_set = set(manager.block_size + for manager in self.managers) + if len(block_size_set) == 1: + # O(n) time complexity if block_size of all groups are the same + ordered_blocks = [] + for i in range(len(blocks[0]) - 1, -1, -1): + for blocks_of_group in blocks: + ordered_blocks.append(blocks_of_group[i]) + else: + # O(n * log(n)) time complexity + # TODO(Chen): optimize it to O(n*len(self.managers)) time complexity + # NOTE: untested + ordered_blocks_with_key = [] + + for i, blocks_of_group in enumerate(blocks): + block_size = self.managers[i].block_size + for i, block in enumerate(blocks_of_group): + ordered_blocks_with_key.append((block_size * i, block)) + + ordered_blocks_with_key.sort(reverse=True) + ordered_blocks = [ + block for _, block in ordered_blocks_with_key + ] + else: + # TODO: need to implement this path + raise NotImplementedError + return ordered_blocks + + def _free_blocks(self, blocks: ReqKVCacheBlocks) -> None: + if len(self.kv_cache_config.groups) == 1: + ordered_blocks = self._get_ordered_blocks_one_kv_cache_group( + blocks[0]) + else: + ordered_blocks = self._get_ordered_blocks_multiple_kv_cache_groups( + blocks) for block in ordered_blocks: block.decr_ref() if block.ref_cnt == 0: self.free_block_queue.append(block) + def free(self, request: Request) -> None: + """Free the blocks allocated for the request. + When caching is enabled, we free the blocks in reverse order so that + the tail blocks are evicted first. + + Args: + request: The request to free the blocks. + """ + # Default to [] in case a request is freed (aborted) before alloc. + blocks = self.req_to_blocks.pop(request.request_id, []) + if len(blocks) == 0: + # This request is freed before alloc. just return + return + else: + self._free_blocks(blocks) + def get_num_common_prefix_blocks( self, request: Request, @@ -342,13 +466,16 @@ def get_num_common_prefix_blocks( """ assert request.status == RequestStatus.RUNNING blocks = self.req_to_blocks[request.request_id] - num_common_blocks = 0 - for block in blocks: - if block.ref_cnt == num_running_requests: - num_common_blocks += 1 - else: - break - return [num_common_blocks] + num_common_blocks_per_group = [] + for blocks_of_group in blocks: + num_common_blocks = 0 + for block in blocks_of_group: + if block.ref_cnt == num_running_requests: + num_common_blocks += 1 + else: + break + num_common_blocks_per_group.append(num_common_blocks) + return num_common_blocks_per_group def _get_new_blocks(self, num_blocks: int) -> List[KVCacheBlock]: """Get new blocks from the free block pool. @@ -415,7 +542,7 @@ def _get_cached_block(self, return self.cached_block_hash_to_block[block_hash][first_block_id] return None - def _touch(self, blocks: List[KVCacheBlock]) -> None: + def _touch(self, blocks: ReqKVCacheBlocks) -> None: """Touch a block increases its reference count by 1, and may remove the block from the free queue. This is used when a block is hit by another request with the same prefix. @@ -423,12 +550,13 @@ def _touch(self, blocks: List[KVCacheBlock]) -> None: Args: blocks: A list of blocks to touch. """ - for block in blocks: - # ref_cnt=0 means this block is in the free list (i.e. eviction - # candidate), so remove it. 
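To make the eviction order produced by _get_ordered_blocks_multiple_kv_cache_groups above concrete, here is a standalone sketch (block IDs are plain strings and group handling is simplified, so it is not the patch's code) of the two merge strategies: interleave from the tail when all groups share a block size, otherwise key each block by the token offset it starts at and sort descending, so that blocks covering the newest tokens are freed, and therefore evicted, first.

from typing import List, Tuple

def ordered_for_free(blocks_per_group: List[List[str]],
                     block_sizes: List[int]) -> List[str]:
    # Skip groups with nothing to free (the fast path in the patch).
    pairs = [(bs, blks) for bs, blks in zip(block_sizes, blocks_per_group)
             if blks]
    if not pairs:
        return []
    if len({bs for bs, _ in pairs}) == 1:
        # Same block size in every group: interleave from the tail, O(n).
        out: List[str] = []
        longest = max(len(blks) for _, blks in pairs)
        for i in range(longest - 1, -1, -1):
            for _, blks in pairs:
                if i < len(blks):
                    out.append(blks[i])
        return out
    # Different block sizes: key each block by the token offset it starts at
    # and sort descending, O(n log n).
    keyed: List[Tuple[int, str]] = []
    for bs, blks in pairs:
        for i, blk in enumerate(blks):
            keyed.append((bs * i, blk))
    keyed.sort(key=lambda kv: kv[0], reverse=True)
    return [blk for _, blk in keyed]

# Group 0 uses block size 2 (4 blocks), group 1 uses block size 4 (2 blocks).
# Blocks covering the newest tokens come out first and get evicted first.
print(ordered_for_free([["a0", "a1", "a2", "a3"], ["b0", "b1"]], [2, 4]))
# ['a3', 'a2', 'b1', 'a1', 'a0', 'b0']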
- if block.ref_cnt == 0: - self.free_block_queue.remove(block) - block.incr_ref() + for blocks_of_group in blocks: + for block in blocks_of_group: + # ref_cnt=0 means this block is in the free list (i.e. eviction + # candidate), so remove it. + if block.ref_cnt == 0 and block != self._null_block: + self.free_block_queue.remove(block) + block.incr_ref() def _cache_full_blocks( self, @@ -465,6 +593,8 @@ def _cache_full_blocks( assert prev_block.block_hash is not None prev_block_hash_value = prev_block.block_hash.hash_value + block_size = self.kv_cache_config.groups[ + kv_cache_group_id].kv_cache_spec.block_size for i, blk in enumerate(full_blocks): blk_idx = blk_start_idx + i @@ -479,12 +609,12 @@ def _cache_full_blocks( else: # Otherwise compute the block hash and cache it in the request # in case it will be preempted in the future. - start_token_idx = blk_idx * self.block_size - end_token_idx = (blk_idx + 1) * self.block_size + start_token_idx = blk_idx * block_size + end_token_idx = (blk_idx + 1) * block_size block_tokens = request.all_token_ids[ start_token_idx:end_token_idx] - assert len(block_tokens) == self.block_size, ( - f"Expected {self.block_size} tokens, got " + assert len(block_tokens) == block_size, ( + f"Expected {block_size} tokens, got " f"{len(block_tokens)} at {blk_idx}th block for request " f"{request.request_id}({request})") @@ -503,3 +633,39 @@ def _cache_full_blocks( blk.block_hash = block_hash self.cached_block_hash_to_block[block_hash][blk.block_id] = blk prev_block_hash_value = block_hash.hash_value + + def get_null_block(self) -> KVCacheBlock: + return self._null_block + + def _get_common_computed_tokens(self, + computed_tokens: KVCacheBlocks) -> int: + # TODO: add comments: the largest in the intersection, and alignment + intersection = intersect_ranges(computed_tokens) + + # Since incomplete blocks are not eligible for sharing, + # `num_computed_tokens` should be a multiple of `block_size` of + # all managers, so we take the least common multiple (LCM) of them + alignment = math.lcm( + *[manager.block_size for manager in self.managers]) + + num_computed_tokens = 0 + for range_ in intersection: + aligned_end = cdiv(range_.end, alignment) * alignment + if aligned_end > range_.start: + num_computed_tokens = aligned_end + break + + return num_computed_tokens + + def _free_blocks_for_sliding_window(self, req_blocks: ReqKVCacheBlocks, + num_computed_tokens: int) -> None: + # NOTE(Chen): do all free before allocation to make less eviction + # req_blocks = self.req_to_blocks[request.request_id] + removed_blocks = [] + for manager, req_blocks_of_group in zip(self.managers, req_blocks): + removed_blocks.append( + manager.remove_dropped_blocks(req_blocks_of_group, + num_computed_tokens)) + # TODO: better handling of free order (e.g., this order have problem + # when different layer has different sliding window size) + self._free_blocks(removed_blocks) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 395b74045e527..a7fee270673a9 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -33,7 +33,8 @@ class BlockHashType(NamedTuple): @dataclass class KVCacheBlock: """KV-cache block metadata.""" - # Block ID, ranging from 0 to num_gpu_blocks - 1. + # Block ID, ranging from 0 to num_gpu_blocks - 1, and a special null_block + # with block_id = -1. block_id: int # Reference count. 
ref_cnt: int = 0 @@ -282,14 +283,15 @@ def hash_block_tokens( tuple(curr_block_token_ids), extra_keys) -def hash_request_tokens(block_size: int, - request: Request) -> List[BlockHashType]: +def hash_request_tokens(block_size: int, request: Request, + group_id: int) -> List[BlockHashType]: """Computes hash values of a chain of blocks given a sequence of token IDs. The hash value is used for prefix caching. Args: block_size: The size of each block. request: The request object. + group_id: TODO Returns: The list of computed hash values. @@ -301,8 +303,7 @@ def hash_request_tokens(block_size: int, "The number of multi-modal positions and hashes must match.") # TODO: Extend this to support other features such as LoRA. - need_extra_keys = bool(mm_positions) - extra_keys = None + need_mm_keys = bool(mm_positions) curr_mm_idx = 0 ret = [] @@ -314,13 +315,17 @@ def hash_request_tokens(block_size: int, if len(block_token_ids) < block_size: break + extra_keys = [group_id] + # Add extra keys if the block is a multi-modal block. - if need_extra_keys: - extra_keys, curr_mm_idx = generate_block_hash_extra_keys( + if need_mm_keys: + mm_keys, curr_mm_idx = generate_block_hash_extra_keys( request, start, end, curr_mm_idx) + if mm_keys is not None: + extra_keys.extend(mm_keys) block_hash = hash_block_tokens(parent_block_hash_value, - block_token_ids, extra_keys) + block_token_ids, tuple(extra_keys)) ret.append(block_hash) parent_block_hash_value = block_hash.hash_value return ret diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 82929011dc95c..f17e33bedf0fb 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -10,6 +10,7 @@ compute_encoder_budget) from vllm.v1.core.kv_cache_manager import KVCacheManager from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs +from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus @@ -29,6 +30,7 @@ def __init__( model_config: ModelConfig, cache_config: CacheConfig, lora_config: Optional[LoRAConfig], + kv_cache_config: KVCacheConfig, ) -> None: self.scheduler_config = scheduler_config self.cache_config = cache_config @@ -46,10 +48,8 @@ def __init__( assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0 # Create the KV cache manager. self.kv_cache_manager = KVCacheManager( - block_size=self.cache_config.block_size, - num_gpu_blocks=num_gpu_blocks, + kv_cache_config=kv_cache_config, max_model_len=self.max_model_len, - sliding_window=self.cache_config.sliding_window, enable_caching=self.cache_config.enable_prefix_caching) self.block_size = self.cache_config.block_size @@ -203,6 +203,7 @@ def schedule(self) -> "SchedulerOutput": # which have output tokens. num_new_tokens = request.num_tokens - num_computed_tokens if num_new_tokens == 0: + raise NotImplementedError # The happens when prompt length is divisible by the block # size and all blocks are cached. Now we force to recompute # the last block. 
Note that we have to re-compute an entire diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 26ebc7edcf03e..c3bb149a1efb5 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -23,6 +23,7 @@ EngineCoreRequestUnion) from vllm.v1.engine.mm_input_mapper import MMInputMapperServer from vllm.v1.executor.abstract import Executor +from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import Request, RequestStatus from vllm.v1.serial_utils import PickleEncoder from vllm.version import __version__ as VLLM_VERSION @@ -49,10 +50,9 @@ def __init__( self.model_executor = executor_class(vllm_config) # Setup KV Caches and update CacheConfig after profiling. - num_gpu_blocks, num_cpu_blocks = self._initialize_kv_caches( - vllm_config) - vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks - vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks + kv_cache_config = self._initialize_kv_caches(vllm_config) + vllm_config.cache_config.num_gpu_blocks = kv_cache_config.num_blocks + vllm_config.cache_config.num_cpu_blocks = 0 # Setup scheduler. self.scheduler = Scheduler( @@ -60,13 +60,12 @@ def __init__( model_config=vllm_config.model_config, cache_config=vllm_config.cache_config, lora_config=vllm_config.lora_config, - ) + kv_cache_config=kv_cache_config) self.mm_input_mapper_server = MMInputMapperServer( vllm_config.model_config) - def _initialize_kv_caches(self, - vllm_config: VllmConfig) -> Tuple[int, int]: + def _initialize_kv_caches(self, vllm_config: VllmConfig) -> KVCacheConfig: start = time.time() # Get all kv cache needed by the model @@ -79,8 +78,6 @@ def _initialize_kv_caches(self, # Get the kv cache tensor size kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec, availble_gpu_memory) - num_gpu_blocks = kv_cache_config.num_blocks - num_cpu_blocks = 0 # Initialize kv cache and warmup the execution self.model_executor.initialize(kv_cache_config) @@ -88,7 +85,7 @@ def _initialize_kv_caches(self, elapsed = time.time() - start logger.info(("init engine (profile, create kv cache, " "warmup model) took %.2f seconds"), elapsed) - return num_gpu_blocks, num_cpu_blocks + return kv_cache_config def add_request(self, request: EngineCoreRequest): """Add request to the scheduler.""" From 5b71ccd7e0e7507dc2043263ff884a6f585d781f Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 21 Jan 2025 21:40:42 -0800 Subject: [PATCH 12/48] fix small bug Signed-off-by: Chen Zhang --- vllm/v1/core/hybrid_cache_manager/utils.py | 2 +- vllm/v1/core/kv_cache_manager.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm/v1/core/hybrid_cache_manager/utils.py b/vllm/v1/core/hybrid_cache_manager/utils.py index b22a3e783fb0a..cb6c895354021 100644 --- a/vllm/v1/core/hybrid_cache_manager/utils.py +++ b/vllm/v1/core/hybrid_cache_manager/utils.py @@ -5,7 +5,7 @@ @dataclass class ComputedTokenRange: """ - [start, end) + (start, end] """ start: int end: int diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index c0a4c071a013c..cb1362cc2d8a4 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -132,7 +132,10 @@ def get_computed_blocks(self, # NOTE(woosuk): Since incomplete blocks are not eligible for # sharing, `num_computed_tokens` is always a multiple of # `block_size`. 
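The change just below stops deriving the computed-token count from the number of cached blocks and instead reads it off the cached prefix ranges. A small sketch of that rule for the single-group fast path (Range is a stand-in for the patch's interval type, not part of the patch):

from dataclasses import dataclass
from typing import List

@dataclass
class Range:
    start: int
    end: int  # inclusive; the largest usable prefix length in this range

def computed_tokens_single_group(prefix_lengths: List[Range]) -> int:
    # With a single KV cache group there is nothing to intersect: the longest
    # reusable prefix is the end of the last (largest) cached range, or 0
    # when nothing is cached.
    if not prefix_lengths:
        return 0
    return prefix_lengths[-1].end

assert computed_tokens_single_group([]) == 0
assert computed_tokens_single_group([Range(0, 48)]) == 48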
- num_computed_tokens = len(computed_blocks[0]) * self.block_size + if len(computed_tokens[0]) == 0: + num_computed_tokens = 0 + else: + num_computed_tokens = computed_tokens[0][-1].end else: # find the common cached prefix of all groups. This path also works # for the single group case, but it is less efficient. From 6a0eb69ae3e69d7c3f5c73d3a1cc002b8f139604 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Wed, 29 Jan 2025 01:29:44 -0800 Subject: [PATCH 13/48] cleanup SpecializedManager Signed-off-by: Chen Zhang --- .../specialized_manager.py | 126 ++++++++++++------ vllm/v1/core/hybrid_cache_manager/utils.py | 18 +-- vllm/v1/core/kv_cache_manager.py | 23 ++-- 3 files changed, 106 insertions(+), 61 deletions(-) diff --git a/vllm/v1/core/hybrid_cache_manager/specialized_manager.py b/vllm/v1/core/hybrid_cache_manager/specialized_manager.py index 59324b98422e3..e624c959e8638 100644 --- a/vllm/v1/core/hybrid_cache_manager/specialized_manager.py +++ b/vllm/v1/core/hybrid_cache_manager/specialized_manager.py @@ -1,73 +1,121 @@ from abc import ABC, abstractmethod from collections import deque from dataclasses import dataclass -from typing import Callable, Dict, List, Optional, Tuple, TypedDict +from typing import Callable, Deque, Dict, List, Optional, Tuple, TypedDict from vllm.utils import cdiv from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig, KVCacheSpec, SlidingWindowSpec from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlock -from vllm.v1.core.hybrid_cache_manager.utils import ComputedTokenRange, ComputedTokens +from vllm.v1.core.hybrid_cache_manager.utils import PrefixLength, PrefixLengthRange from vllm.v1.utils import ConstantList @dataclass -class MemoryPoolOperations: +class BlockPoolOperations: get_cached_block: Callable[[BlockHashType], Optional[KVCacheBlock]] get_null_block: Callable[[], KVCacheBlock] class SpecializedManager(ABC): + """ + An abstract base class for specialized managers that handle the kv + cache management logic of different attention layers. + """ block_size: int max_num_blocks_per_req: int def __init__( self, - layer_spec: KVCacheSpec, - memory_pool_operations: MemoryPoolOperations, + kv_cache_spec: KVCacheSpec, + block_pool_operations: BlockPoolOperations, ) -> None: - self.block_size = layer_spec.block_size - self.memory_pool_operations = memory_pool_operations + """ + Initializes the SpecializedManager. + + Args: + kv_cache_spec: The kv_cache_spec for this manager. + block_pool_operations: Operations to interact with the block pool. + + Returns: + None + """ + + self.block_size = kv_cache_spec.block_size + self.kv_cache_spec = kv_cache_spec + self.block_pool_operations = block_pool_operations @abstractmethod - def get_computed_blocks_and_tokens( + def get_possible_cached_prefix( self, block_hashes: ConstantList[BlockHashType] - ) -> Tuple[List[KVCacheBlock], ComputedTokens]: + ) -> Tuple[PrefixLength, List[KVCacheBlock]]: + """ + Get the possible cached prefixes of a request based on its block hashes. + If no cached prefixes are found, returns a tuple with a prefix length + range of [0, 0] and an empty list of blocks. + + Args: + block_hashes: The block hashes of the request. + + Returns: + A tuple containing: + - A list of all possible cached prefix lengths. + - The computed blocks that are cached. + """ + raise NotImplementedError @abstractmethod def get_num_new_blocks(self, num_computed_tokens: int, num_append_tokens: int, num_allocated_blocks: int) -> int: + """ + Calculate the number of new blocks needed by this manager. 
+ + Args: + num_computed_tokens: The number of tokens that have been computed. + num_append_tokens: The number of tokens that need to be appended. + num_allocated_blocks: The number of blocks that have already been + allocated. + + Returns: + int: The number of new blocks needed. + """ raise NotImplementedError @abstractmethod - def remove_dropped_blocks(self, block_table: List[KVCacheBlock], - num_computed_tokens: int): - # update block_table inplace + def remove_useless_blocks(self, block_table: List[KVCacheBlock], + num_computed_tokens: int) -> List[KVCacheBlock]: + """ + Update the `block_table` in place to remove blocks that are no longer + needed. Returns the removed blocks. + + Args: + block_table: The block table to be updated. + num_computed_tokens: The number of tokens that have been computed. + + Returns: + List[KVCacheBlock]: The removed blocks. + """ raise NotImplementedError class FullAttentionManager(SpecializedManager): - def get_computed_blocks_and_tokens( + def get_possible_cached_prefix( self, block_hashes: ConstantList[BlockHashType] - ) -> Tuple[List[KVCacheBlock], ComputedTokens]: + ) -> Tuple[List[PrefixLengthRange], List[KVCacheBlock]]: computed_blocks: List[KVCacheBlock] = [] for block_hash in block_hashes: # block_hashes is a chain of block hashes. If a block hash is not # in the cached_block_hash_to_id, the following block hashes are # not computed yet for sure. - if cached_block := self.memory_pool_operations.get_cached_block( + if cached_block := self.block_pool_operations.get_cached_block( block_hash): computed_blocks.append(cached_block) else: break - if len(computed_blocks) == 0: - return [], [] - else: - return [ - ComputedTokenRange(0, - len(computed_blocks) * self.block_size) - ], computed_blocks + return [PrefixLengthRange(0, + len(computed_blocks) * self.block_size) + ], computed_blocks def get_num_new_blocks(self, num_computed_tokens: int, num_append_tokens: int, @@ -77,24 +125,24 @@ def get_num_new_blocks(self, num_computed_tokens: int, num_new_blocks = num_required_blocks - num_allocated_blocks return num_new_blocks - def remove_dropped_blocks(self, block_table: List[KVCacheBlock], + def remove_useless_blocks(self, block_table: List[KVCacheBlock], num_computed_tokens: int) -> List[KVCacheBlock]: return [] class SlidingWindowManager(FullAttentionManager): - def __init__(self, layer_spec: SlidingWindowSpec, - memory_pool_operations: MemoryPoolOperations): - super().__init__(layer_spec, memory_pool_operations) + def __init__(self, kv_cache_spec: SlidingWindowSpec, + block_pool_operations: BlockPoolOperations): + super().__init__(kv_cache_spec, block_pool_operations) # +1 due to not aligned - self.num_block_sliding_window = cdiv(layer_spec.sliding_window, + self.num_block_sliding_window = cdiv(kv_cache_spec.sliding_window, self.block_size) + 1 - self._null_block = memory_pool_operations.get_null_block() + self._null_block = block_pool_operations.get_null_block() - def get_computed_blocks_and_tokens( + def get_possible_cached_prefix( self, block_hashes: ConstantList[BlockHashType] - ) -> Tuple[List[KVCacheBlock], ComputedTokens]: + ) -> Tuple[List[PrefixLengthRange], List[KVCacheBlock]]: # TODO: check the hit every num_block_sliding_window blocks, to optimize # the time complexity from O(num_block) to # O(num_block / num_block_sliding_window) + O(num_computed_block), @@ -104,28 +152,28 @@ def get_computed_blocks_and_tokens( computed_blocks: List[KVCacheBlock] = [] for i, block_hash in enumerate(block_hashes): - if cached_block := 
self.memory_pool_operations.get_cached_block( + if cached_block := self.block_pool_operations.get_cached_block( block_hash): computed_blocks.append(cached_block) else: if start == 0: ranges.append( - ComputedTokenRange(start * self.block_size, - i * self.block_size)) + PrefixLengthRange(start * self.block_size, + i * self.block_size)) elif i - start >= self.num_block_sliding_window: - ranges.append((ComputedTokenRange( + ranges.append((PrefixLengthRange( (start + self.num_block_sliding_window) * self.block_size, i * self.block_size))) computed_blocks.append( - self.memory_pool_operations.get_null_block()) + self.block_pool_operations.get_null_block()) start = i + 1 return ranges, computed_blocks - def remove_dropped_blocks(self, block_table: List[KVCacheBlock], + def remove_useless_blocks(self, block_table: List[KVCacheBlock], num_computed_tokens: int) -> List[KVCacheBlock]: num_block_should_free = cdiv(num_computed_tokens, self.block_size) - \ self.num_block_sliding_window - removed_blocks = deque() + removed_blocks: Deque[KVCacheBlock] = deque() for i in range(num_block_should_free - 1, -1, -1): if block_table[i] == self._null_block: break @@ -142,11 +190,11 @@ def remove_dropped_blocks(self, block_table: List[KVCacheBlock], def get_managers( kv_cache_config: KVCacheConfig, - memory_pool_operations: MemoryPoolOperations + block_pool_operations: BlockPoolOperations ) -> List[SpecializedManager]: managers: List[SpecializedManager] = [] for g in kv_cache_config.groups: manager_class = spec_manager_map[type(g.kv_cache_spec)] - manager = manager_class(g.kv_cache_spec, memory_pool_operations) + manager = manager_class(g.kv_cache_spec, block_pool_operations) managers.append(manager) return managers diff --git a/vllm/v1/core/hybrid_cache_manager/utils.py b/vllm/v1/core/hybrid_cache_manager/utils.py index cb6c895354021..5d89de60cea54 100644 --- a/vllm/v1/core/hybrid_cache_manager/utils.py +++ b/vllm/v1/core/hybrid_cache_manager/utils.py @@ -3,22 +3,22 @@ @dataclass -class ComputedTokenRange: +class PrefixLengthRange: """ - (start, end] + [start, end] """ start: int end: int -ComputedTokens = List[ComputedTokenRange] +PrefixLength = List[PrefixLengthRange] def intersect_two_ranges( - a: List[ComputedTokenRange], - b: List[ComputedTokenRange]) -> List[ComputedTokenRange]: + a: List[PrefixLengthRange], + b: List[PrefixLengthRange]) -> List[PrefixLengthRange]: """ - Intersect two sorted lists of ComputedTokenRange intervals. + Intersect two sorted lists of PrefixLengthRange intervals. Args: a: List of intervals @@ -34,7 +34,7 @@ def intersect_two_ranges( overlap_end = min(a[i].end, b[j].end) if overlap_start <= overlap_end: - result.append(ComputedTokenRange(overlap_start, overlap_end)) + result.append(PrefixLengthRange(overlap_start, overlap_end)) if a[i].end < b[j].end: i += 1 @@ -45,9 +45,9 @@ def intersect_two_ranges( def intersect_ranges( - ranges: List[List[ComputedTokenRange]]) -> List[ComputedTokenRange]: + ranges: List[List[PrefixLengthRange]]) -> List[PrefixLengthRange]: """ - Intersect multiple lists of ComputedTokenRange intervals, each is sorted. + Intersect multiple lists of PrefixLengthRange intervals, each is sorted. 
Args: ranges: A list of lists of intervals diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index cb1362cc2d8a4..bd1a187ba16b2 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -4,14 +4,14 @@ from vllm.logger import init_logger from vllm.utils import cdiv -from vllm.v1.core.hybrid_cache_manager.specialized_manager import MemoryPoolOperations, get_managers +from vllm.v1.core.hybrid_cache_manager.specialized_manager import BlockPoolOperations, get_managers from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue, KVCacheBlock, KVCacheBlocks, ReqKVCacheBlocks, generate_block_hash_extra_keys, hash_block_tokens, hash_request_tokens) -from vllm.v1.core.hybrid_cache_manager.utils import ComputedTokens, intersect_ranges +from vllm.v1.core.hybrid_cache_manager.utils import PrefixLength, intersect_ranges from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import Request, RequestStatus @@ -62,8 +62,8 @@ def __init__( # TODO(Chen): add comments self.managers = get_managers( kv_cache_config, - MemoryPoolOperations(get_cached_block=self._get_cached_block, - get_null_block=self.get_null_block), + BlockPoolOperations(get_cached_block=self._get_cached_block, + get_null_block=self.get_null_block), ) # A Block pool of all kv-cache blocks. @@ -118,7 +118,7 @@ def get_computed_blocks(self, ]) computed_blocks: ReqKVCacheBlocks = [] # group_id->[blocks] - computed_tokens: List[ComputedTokens] = [] # group_id->ComputedTokens + computed_tokens: List[PrefixLength] = [] # group_id->PrefixLength block_hashes = request.kv_block_hashes for i, manager in enumerate(self.managers): computed_tokens_i, computed_blocks_i = ( @@ -132,10 +132,7 @@ def get_computed_blocks(self, # NOTE(woosuk): Since incomplete blocks are not eligible for # sharing, `num_computed_tokens` is always a multiple of # `block_size`. - if len(computed_tokens[0]) == 0: - num_computed_tokens = 0 - else: - num_computed_tokens = computed_tokens[0][-1].end + num_computed_tokens = computed_tokens[0][-1].end else: # find the common cached prefix of all groups. This path also works # for the single group case, but it is less efficient. 
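For the multi-group path, _get_common_computed_tokens (touched in the next hunk) intersects the per-group prefix ranges and then aligns the result. The sketch below illustrates the intent stated in the helper's TODO, picking the largest prefix length in the intersection that is a multiple of the LCM of all group block sizes; it rounds down inside each range and scans from the longest range, so it is a simplified illustration rather than a line-for-line copy of the helper.

import math
from dataclasses import dataclass
from typing import List

@dataclass
class Range:
    start: int  # smallest usable prefix length in this range (inclusive)
    end: int    # largest usable prefix length in this range (inclusive)

def common_computed_tokens(intersection: List[Range],
                           block_sizes: List[int]) -> int:
    # Every group may only reuse whole blocks, so the chosen prefix length
    # must be a multiple of the LCM of all group block sizes.
    alignment = math.lcm(*block_sizes)
    # Walk the ranges from the longest prefixes down and return the largest
    # aligned length that still lies inside a cached range.
    for r in reversed(intersection):
        aligned = (r.end // alignment) * alignment
        if aligned >= r.start:
            return aligned
    return 0

# Block sizes 16 and 32 give an alignment of 32. With cached prefix ranges
# [0, 48] and [112, 144], the longest reusable prefix is 128 tokens.
print(common_computed_tokens([Range(0, 48), Range(112, 144)], [16, 32]))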
@@ -640,8 +637,8 @@ def _cache_full_blocks( def get_null_block(self) -> KVCacheBlock: return self._null_block - def _get_common_computed_tokens(self, - computed_tokens: KVCacheBlocks) -> int: + def _get_common_computed_tokens( + self, computed_tokens: List[PrefixLength]) -> int: # TODO: add comments: the largest in the intersection, and alignment intersection = intersect_ranges(computed_tokens) @@ -654,7 +651,7 @@ def _get_common_computed_tokens(self, num_computed_tokens = 0 for range_ in intersection: aligned_end = cdiv(range_.end, alignment) * alignment - if aligned_end > range_.start: + if aligned_end >= range_.start: num_computed_tokens = aligned_end break @@ -667,7 +664,7 @@ def _free_blocks_for_sliding_window(self, req_blocks: ReqKVCacheBlocks, removed_blocks = [] for manager, req_blocks_of_group in zip(self.managers, req_blocks): removed_blocks.append( - manager.remove_dropped_blocks(req_blocks_of_group, + manager.remove_useless_blocks(req_blocks_of_group, num_computed_tokens)) # TODO: better handling of free order (e.g., this order have problem # when different layer has different sliding window size) From 14ad04e9d1c7fce1a81a4f397daa0a3e137cb008 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Wed, 29 Jan 2025 08:06:08 -0800 Subject: [PATCH 14/48] add test and fix bug for sliding window manager Signed-off-by: Chen Zhang --- tests/v1/core/test_specialized_manager.py | 112 ++++++++++++++++++ .../specialized_manager.py | 63 +++++++--- vllm/v1/core/kv_cache_manager.py | 9 +- 3 files changed, 162 insertions(+), 22 deletions(-) create mode 100644 tests/v1/core/test_specialized_manager.py diff --git a/tests/v1/core/test_specialized_manager.py b/tests/v1/core/test_specialized_manager.py new file mode 100644 index 0000000000000..959972521461a --- /dev/null +++ b/tests/v1/core/test_specialized_manager.py @@ -0,0 +1,112 @@ +from collections import deque +import torch +from vllm.v1.core.hybrid_cache_manager.specialized_manager import BlockPoolOperations, SlidingWindowManager +from vllm.v1.core.hybrid_cache_manager.utils import PrefixLengthRange +from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlock +from vllm.v1.kv_cache_interface import SlidingWindowSpec + + +def test_sliding_window_possible_cached_prefix(): + sliding_window_spec = SlidingWindowSpec( + block_size=2, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + sliding_window=4, + ) + + block_pool_result = deque() + null_block = KVCacheBlock(-1, 0) + + def get_cached_block(_block_hash): + if isinstance(_block_hash, + BlockHashType) and _block_hash.hash_value == -1: + # the dummy block hash + return None + is_cached = block_pool_result.popleft() + if is_cached: + return 1 + else: + return None + + def get_null_block(): + return null_block + + manager = SlidingWindowManager( + sliding_window_spec, + BlockPoolOperations(get_cached_block, get_null_block)) + + block_pool_result.clear() + block_pool_result.extend([ + True, True, False, True, False, False, True, True, False, True, True, + True + ]) + ranges, computed_blocks = manager.get_possible_cached_prefix( + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]) + assert ranges == [ + PrefixLengthRange(0, 4), + PrefixLengthRange(16, 16), + PrefixLengthRange(22, 24) + ] + assert computed_blocks == [ + 1, 1, null_block, 1, null_block, null_block, 1, 1, null_block, 1, 1, 1 + ] + + +def test_sliding_window_remove_useless_blocks(): + sliding_window_spec = SlidingWindowSpec( + block_size=2, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + sliding_window=4, + ) + + def 
get_cached_block(_block_hash): + # should not be called + raise NotImplementedError + + def get_null_block(): + return KVCacheBlock(-1, 0) + + manager = SlidingWindowManager( + sliding_window_spec, + BlockPoolOperations(get_cached_block, get_null_block)) + + def id_to_block_table(ids): + return [ + KVCacheBlock(id_, 0) if id_ != -1 else get_null_block() + for id_ in ids + ] + + def assert_block_id(block_table, ids): + for block, id_ in zip(block_table, ids): + if id_ == -1: + assert block == get_null_block() + else: + assert block.block_id == id_ + + block_table = id_to_block_table([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + removed = manager.remove_useless_blocks(block_table, 0) + assert_block_id(removed, []) + assert_block_id(block_table, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + + removed = manager.remove_useless_blocks(block_table, 5) + assert_block_id(removed, []) + assert_block_id(block_table, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + + removed = manager.remove_useless_blocks(block_table, 6) + assert_block_id(removed, [0]) + assert_block_id(block_table, [-1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + + removed = manager.remove_useless_blocks(block_table, 7) + assert_block_id(removed, []) + assert_block_id(block_table, [-1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + + removed = manager.remove_useless_blocks(block_table, 8) + assert_block_id(removed, [1]) + assert_block_id(block_table, [-1, -1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + + removed = manager.remove_useless_blocks(block_table, 12) + assert_block_id(removed, [3, 2]) + assert_block_id(block_table, [-1, -1, -1, -1, 4, 5, 6, 7, 8, 9, 10]) diff --git a/vllm/v1/core/hybrid_cache_manager/specialized_manager.py b/vllm/v1/core/hybrid_cache_manager/specialized_manager.py index e624c959e8638..64717ee14476d 100644 --- a/vllm/v1/core/hybrid_cache_manager/specialized_manager.py +++ b/vllm/v1/core/hybrid_cache_manager/specialized_manager.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from collections import deque from dataclasses import dataclass +from itertools import chain from typing import Callable, Deque, Dict, List, Optional, Tuple, TypedDict from vllm.utils import cdiv from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig, KVCacheSpec, SlidingWindowSpec @@ -86,7 +87,11 @@ def remove_useless_blocks(self, block_table: List[KVCacheBlock], num_computed_tokens: int) -> List[KVCacheBlock]: """ Update the `block_table` in place to remove blocks that are no longer - needed. Returns the removed blocks. + needed. Replace the removed blocks with null_block and returns the + removed blocks. + The removed blocks should be in the order of the + priority to be evicted, where the first block should have the highest + priority. Args: block_table: The block table to be updated. @@ -135,9 +140,13 @@ class SlidingWindowManager(FullAttentionManager): def __init__(self, kv_cache_spec: SlidingWindowSpec, block_pool_operations: BlockPoolOperations): super().__init__(kv_cache_spec, block_pool_operations) - # +1 due to not aligned - self.num_block_sliding_window = cdiv(kv_cache_spec.sliding_window, - self.block_size) + 1 + self.sliding_window = kv_cache_spec.sliding_window + # # +1 here because the sliding window may not start from the beginning + # # of the first block. For example, if the block size is 2, and sliding + # # window size is 4, [XX, XA, BC, D] where ABCD are the 4 tokens inside + # # the sliding window, we need to hold the last 3 blocks. 
+ # self.num_block_sliding_window = cdiv(kv_cache_spec.sliding_window, + # self.block_size) + 1 self._null_block = block_pool_operations.get_null_block() def get_possible_cached_prefix( @@ -147,42 +156,60 @@ def get_possible_cached_prefix( # the time complexity from O(num_block) to # O(num_block / num_block_sliding_window) + O(num_computed_block), # which is good for low cache hit rate senarios. + # TODO: add test for this function start = 0 ranges = [] computed_blocks: List[KVCacheBlock] = [] - for i, block_hash in enumerate(block_hashes): + dummy_block_hash = BlockHashType(-1, ()) + # Add a dummy block hash to support the case that the last block is + # cached. + for i, block_hash in enumerate(chain(block_hashes, + [dummy_block_hash])): if cached_block := self.block_pool_operations.get_cached_block( block_hash): computed_blocks.append(cached_block) else: if start == 0: + # All tokens between [0, i * block_size] are cached. + # All of them are possible cached prefix. + ranges.append(PrefixLengthRange(0, i * self.block_size)) + elif (i - start) * self.block_size >= self.sliding_window: + # All tokens between [start * block_size, + # i * block_size)] are cached. These tokens except the + # first `self.sliding_window - 1` ones are possible cached + # prefix. + first_cached_token = start * self.block_size + # should be first_cached_token + self.sliding_window - 1 + 1 + # +1 is for converting the token index to the prefix length. + first_possible_length = first_cached_token + \ + self.sliding_window ranges.append( - PrefixLengthRange(start * self.block_size, + PrefixLengthRange(first_possible_length, i * self.block_size)) - elif i - start >= self.num_block_sliding_window: - ranges.append((PrefixLengthRange( - (start + self.num_block_sliding_window) * - self.block_size, i * self.block_size))) - computed_blocks.append( - self.block_pool_operations.get_null_block()) + computed_blocks.append(self._null_block) start = i + 1 + computed_blocks = computed_blocks[:-1] # remove the dummy block return ranges, computed_blocks def remove_useless_blocks(self, block_table: List[KVCacheBlock], num_computed_tokens: int) -> List[KVCacheBlock]: - num_block_should_free = cdiv(num_computed_tokens, self.block_size) - \ - self.num_block_sliding_window - removed_blocks: Deque[KVCacheBlock] = deque() - for i in range(num_block_should_free - 1, -1, -1): + # Remove the blocks that are no longer be in the sliding window. + last_useful_token = num_computed_tokens - self.sliding_window + last_useful_block = last_useful_token // self.block_size + + removed_blocks: List[KVCacheBlock] = [] + for i in range(last_useful_block - 1, -1, -1): if block_table[i] == self._null_block: + # If the block is already a null block, the blocks before it + # should also be null blocks. 
break - removed_blocks.appendleft(block_table[i]) + removed_blocks.append(block_table[i]) block_table[i] = self._null_block return removed_blocks -spec_manager_map = { +spec_manager_map: Dict[KVCacheSpec, SpecializedManager] = { FullAttentionSpec: FullAttentionManager, SlidingWindowSpec: SlidingWindowManager } diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index bd1a187ba16b2..8aca485bac655 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -122,7 +122,7 @@ def get_computed_blocks(self, block_hashes = request.kv_block_hashes for i, manager in enumerate(self.managers): computed_tokens_i, computed_blocks_i = ( - manager.get_computed_blocks_and_tokens(block_hashes[i])) + manager.get_possible_cached_prefix(block_hashes[i])) computed_blocks.append(computed_blocks_i) computed_tokens.append(computed_tokens_i) @@ -139,9 +139,9 @@ def get_computed_blocks(self, num_computed_tokens = self._get_common_computed_tokens( computed_tokens) - for i, manager in enumerate(self.managers): - computed_blocks[i] = computed_blocks[:num_computed_tokens // - manager.block_size] + for i, manager in enumerate(self.managers): + computed_blocks[i] = computed_blocks[:num_computed_tokens // + manager.block_size] self._free_blocks_for_sliding_window(computed_blocks, num_computed_tokens) return computed_blocks, num_computed_tokens @@ -348,6 +348,7 @@ def allocate_slots( def _get_ordered_blocks_one_kv_cache_group( self, blocks: KVCacheBlocks) -> Iterable[KVCacheBlock]: + # TODO (Chen): rethink where to do the reverse operation ordered_blocks: Iterable[KVCacheBlock] = blocks if self.enable_caching: # Free blocks in reverse order so that the tail blocks are From eb34a44f1368989d8338573f9869118260500728 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Wed, 29 Jan 2025 08:07:19 -0800 Subject: [PATCH 15/48] remove useless code Signed-off-by: Chen Zhang --- vllm/v1/core/hybrid_cache_manager/specialized_manager.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/vllm/v1/core/hybrid_cache_manager/specialized_manager.py b/vllm/v1/core/hybrid_cache_manager/specialized_manager.py index 64717ee14476d..ef58e17c89726 100644 --- a/vllm/v1/core/hybrid_cache_manager/specialized_manager.py +++ b/vllm/v1/core/hybrid_cache_manager/specialized_manager.py @@ -141,12 +141,6 @@ def __init__(self, kv_cache_spec: SlidingWindowSpec, block_pool_operations: BlockPoolOperations): super().__init__(kv_cache_spec, block_pool_operations) self.sliding_window = kv_cache_spec.sliding_window - # # +1 here because the sliding window may not start from the beginning - # # of the first block. For example, if the block size is 2, and sliding - # # window size is 4, [XX, XA, BC, D] where ABCD are the 4 tokens inside - # # the sliding window, we need to hold the last 3 blocks. 
- # self.num_block_sliding_window = cdiv(kv_cache_spec.sliding_window, - # self.block_size) + 1 self._null_block = block_pool_operations.get_null_block() def get_possible_cached_prefix( From f53e824612fe1d48f98670fbd8397c0578c2de14 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Thu, 30 Jan 2025 23:11:46 -0800 Subject: [PATCH 16/48] fix several bugs Signed-off-by: Chen Zhang --- .../v1/e2e/test_correctness_sliding_window.py | 30 +++++++++++++ .../specialized_manager.py | 3 +- vllm/v1/core/kv_cache_manager.py | 30 ++++++------- vllm/v1/core/kv_cache_utils.py | 42 +++++++++++++------ 4 files changed, 76 insertions(+), 29 deletions(-) create mode 100644 tests/v1/e2e/test_correctness_sliding_window.py diff --git a/tests/v1/e2e/test_correctness_sliding_window.py b/tests/v1/e2e/test_correctness_sliding_window.py new file mode 100644 index 0000000000000..bbeb79e1c6a11 --- /dev/null +++ b/tests/v1/e2e/test_correctness_sliding_window.py @@ -0,0 +1,30 @@ +import random +from typing import List +from vllm import LLM, SamplingParams +from ...core.block.e2e.test_correctness_sliding_window import (prep_prompts, + check_answers) +import pytest + + +@pytest.mark.parametrize("model", ["bigcode/starcoder2-3b"]) +@pytest.mark.parametrize("batch_size", [5]) +@pytest.mark.parametrize("seed", [1]) +def test_sliding_window_retrival(monkeypatch, model, batch_size, seed): + """ + The test does a bunch of assignments "x1 = 10\nx2 = 33\n..." and then + asks for value of one of them (which is outside the sliding window). + If we tell it upfront which we are going to be looking for, then + it answers correctly (mostly). + """ + # TODO: implement check_window + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + + llm = LLM(model=model, enable_prefix_caching=True) + sampling_params = SamplingParams(temperature=0.0, max_tokens=100) + + prompts, answer, indices = prep_prompts(batch_size) + + responses = llm.generate(prompts, sampling_params) + check_answers(indices, answer, + [response.outputs[0].text for response in responses]) diff --git a/vllm/v1/core/hybrid_cache_manager/specialized_manager.py b/vllm/v1/core/hybrid_cache_manager/specialized_manager.py index ef58e17c89726..c61ff83ac0efe 100644 --- a/vllm/v1/core/hybrid_cache_manager/specialized_manager.py +++ b/vllm/v1/core/hybrid_cache_manager/specialized_manager.py @@ -150,12 +150,11 @@ def get_possible_cached_prefix( # the time complexity from O(num_block) to # O(num_block / num_block_sliding_window) + O(num_computed_block), # which is good for low cache hit rate senarios. - # TODO: add test for this function start = 0 ranges = [] computed_blocks: List[KVCacheBlock] = [] - dummy_block_hash = BlockHashType(-1, ()) + dummy_block_hash = BlockHashType(-1, (), -1) # Add a dummy block hash to support the case that the last block is # cached. for i, block_hash in enumerate(chain(block_hashes, diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 8aca485bac655..1f481c4b4b7a9 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -138,10 +138,9 @@ def get_computed_blocks(self, # for the single group case, but it is less efficient. 
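The sliding-window logic being fixed here is easiest to see in isolation. The sketch below reproduces, outside the patch, the prefix-range computation that get_possible_cached_prefix performs for a sliding-window group given only a per-block hit/miss pattern; the example mirrors the case from test_specialized_manager.py added earlier in the series (block_size=2, sliding_window=4).

from dataclasses import dataclass
from typing import List

@dataclass
class Range:
    start: int
    end: int

def sliding_window_prefix_ranges(hits: List[bool], block_size: int,
                                 sliding_window: int) -> List[Range]:
    # A prefix of length L is reusable iff the last `sliding_window` tokens
    # before L are all cached; blocks before that may be missing and are
    # replaced by the null block in the patch.
    ranges: List[Range] = []
    start = 0
    for i, hit in enumerate(hits + [False]):  # trailing miss closes the scan
        if hit:
            continue
        if start == 0:
            # Everything from the start of the request up to this miss is
            # cached, so any prefix length up to i * block_size works.
            ranges.append(Range(0, i * block_size))
        elif (i - start) * block_size >= sliding_window:
            # A cached run long enough to cover a full window ends here.
            ranges.append(Range(start * block_size + sliding_window,
                                i * block_size))
        start = i + 1
    return ranges

# Same hit/miss pattern as the unit test: block_size=2, sliding_window=4.
hits = [True, True, False, True, False, False,
        True, True, False, True, True, True]
print(sliding_window_prefix_ranges(hits, 2, 4))
# [Range(start=0, end=4), Range(start=16, end=16), Range(start=22, end=24)]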
num_computed_tokens = self._get_common_computed_tokens( computed_tokens) - for i, manager in enumerate(self.managers): - computed_blocks[i] = computed_blocks[:num_computed_tokens // - manager.block_size] + computed_blocks[i] = computed_blocks[i][:num_computed_tokens // + manager.block_size] self._free_blocks_for_sliding_window(computed_blocks, num_computed_tokens) return computed_blocks, num_computed_tokens @@ -208,9 +207,10 @@ def append_slots( ) assert num_block_to_allocate > 0 - new_blocks_of_group = self._get_new_blocks(num_new_blocks) + new_blocks_of_group = self._get_new_blocks( + num_block_to_allocate) new_blocks.append(new_blocks_of_group) - req_blocks[i].extend(new_blocks) + req_blocks[i].extend(new_blocks_of_group) if not self.enable_caching: return new_blocks @@ -226,7 +226,7 @@ def append_slots( # are full after appending the actual tokens. num_full_blocks_after_append = (request.num_computed_tokens + num_tokens) // manager.block_size - assert num_full_blocks_after_append <= len(req_blocks) + assert num_full_blocks_after_append <= len(req_blocks[i]) new_full_blocks = req_blocks[i][ num_computed_full_blocks:num_full_blocks_after_append] @@ -289,7 +289,7 @@ def allocate_slots( if self.enable_caching: self._touch(computed_blocks) else: - assert not computed_blocks, ( + assert all(len(blks) == 0 for blks in computed_blocks), ( "Computed blocks should be empty when " "prefix caching is disabled") @@ -325,22 +325,21 @@ def allocate_slots( if not self.enable_caching: return new_blocks - for i, manager in enumerate(self.managers): - num_computed_tokens = len(computed_blocks) * manager.block_size + num_computed_tokens = len(computed_blocks[i]) * manager.block_size num_full_blocks = (num_computed_tokens + num_tokens) // manager.block_size - new_full_blocks = req_to_blocks[i][len(computed_blocks + new_full_blocks = req_to_blocks[i][len(computed_blocks[i] ):num_full_blocks] if new_full_blocks: self._cache_full_blocks( request=request, - blk_start_idx=len(computed_blocks), + blk_start_idx=len(computed_blocks[i]), # The new full blocks are the full blocks that are not computed. full_blocks=new_full_blocks, - prev_block=computed_blocks[-1] - if computed_blocks else None, + prev_block=computed_blocks[i][-1] + if computed_blocks[i] else None, kv_cache_group_id=i, ) @@ -405,6 +404,8 @@ def _free_blocks(self, blocks: ReqKVCacheBlocks) -> None: ordered_blocks = self._get_ordered_blocks_multiple_kv_cache_groups( blocks) for block in ordered_blocks: + if block == self._null_block: + continue block.decr_ref() if block.ref_cnt == 0: self.free_block_queue.append(block) @@ -627,7 +628,8 @@ def _cache_full_blocks( # Compute the hash of the current block. block_hash = hash_block_tokens(prev_block_hash_value, - block_tokens, extra_keys) + block_tokens, kv_cache_group_id, + extra_keys) request.append_kv_block_hashes(kv_cache_group_id, block_hash) # Update and added the full block to the cache. diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index a7fee270673a9..312e99d9808c1 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -26,6 +26,8 @@ class BlockHashType(NamedTuple): hash_value: int # Token IDs in the block. token_ids: Tuple[int, ...] + # The KV cache group that the block belongs to. + kv_cache_group_id: int # Extra keys for the block. 
extra_keys: Optional[Any] = None @@ -67,6 +69,19 @@ def reset_hash(self): """Reset the block hash when the block is evicted.""" self._block_hash = None + def __repr__(self): + # print block_id instead of KVCacheBlock object to avoid printing the + # KVCacheBlock object recursively. + prev_block_id = self.prev_free_block.block_id \ + if self.prev_free_block else None + next_block_id = self.next_free_block.block_id \ + if self.next_free_block else None + return (f"KVCacheBlock(block_id={self.block_id}, " + f"ref_cnt={self.ref_cnt}), " + f"_block_hash={self._block_hash}, " + f"prev_free_block={prev_block_id}, " + f"next_free_block={next_block_id})") + """When a model needs different types of kv_caches (e.g., full attention + sliding window attention), the attention layers will be split to multiple @@ -259,6 +274,7 @@ def generate_block_hash_extra_keys( def hash_block_tokens( parent_block_hash: Optional[int], curr_block_token_ids: Sequence[int], + kv_cache_group_id: int, extra_keys: Optional[Tuple[Any, ...]] = None) -> BlockHashType: """Computes a hash value corresponding to the contents of a block and the contents of the preceding block(s). The hash value is used for @@ -279,19 +295,20 @@ def hash_block_tokens( The hash value of the block and the token ids in the block. The entire tuple is used as the hash key of the block. """ - return BlockHashType(hash((parent_block_hash, *curr_block_token_ids)), - tuple(curr_block_token_ids), extra_keys) + return BlockHashType( + hash((parent_block_hash, kv_cache_group_id, *curr_block_token_ids)), + tuple(curr_block_token_ids), kv_cache_group_id, extra_keys) def hash_request_tokens(block_size: int, request: Request, - group_id: int) -> List[BlockHashType]: + kv_cache_group_id: int) -> List[BlockHashType]: """Computes hash values of a chain of blocks given a sequence of token IDs. The hash value is used for prefix caching. Args: block_size: The size of each block. request: The request object. - group_id: TODO + kv_cache_group_id: The KV cache group that the blocks belong to Returns: The list of computed hash values. @@ -303,7 +320,8 @@ def hash_request_tokens(block_size: int, request: Request, "The number of multi-modal positions and hashes must match.") # TODO: Extend this to support other features such as LoRA. - need_mm_keys = bool(mm_positions) + need_extra_keys = bool(mm_positions) + extra_keys = None curr_mm_idx = 0 ret = [] @@ -315,19 +333,18 @@ def hash_request_tokens(block_size: int, request: Request, if len(block_token_ids) < block_size: break - extra_keys = [group_id] - # Add extra keys if the block is a multi-modal block. - if need_mm_keys: - mm_keys, curr_mm_idx = generate_block_hash_extra_keys( + if need_extra_keys: + extra_keys, curr_mm_idx = generate_block_hash_extra_keys( request, start, end, curr_mm_idx) - if mm_keys is not None: - extra_keys.extend(mm_keys) block_hash = hash_block_tokens(parent_block_hash_value, - block_token_ids, tuple(extra_keys)) + block_token_ids, kv_cache_group_id, + extra_keys) ret.append(block_hash) parent_block_hash_value = block_hash.hash_value + if start < 5 * block_size: + print("block_hash", start // block_size, block_hash) return ret @@ -518,7 +535,6 @@ def get_kv_cache_config(vllm_config: VllmConfig, available_memory: int) -> KVCacheConfig: """ Generates the KV cache configuration for a model - TODO: support hybrid models with more than one type of KV cache. 
Args: vllm_config: The global VllmConfig From 0ecf3fa11c16eebbd7efbf454822020ef544abad Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Thu, 30 Jan 2025 23:39:47 -0800 Subject: [PATCH 17/48] update sliding window test Signed-off-by: Chen Zhang --- .../e2e/test_correctness_sliding_window.py | 13 +++-- .../v1/e2e/test_correctness_sliding_window.py | 47 ++++++++++++++++--- vllm/v1/core/kv_cache_utils.py | 2 - 3 files changed, 48 insertions(+), 14 deletions(-) diff --git a/tests/core/block/e2e/test_correctness_sliding_window.py b/tests/core/block/e2e/test_correctness_sliding_window.py index 415d0bd8237df..e287cf796c818 100644 --- a/tests/core/block/e2e/test_correctness_sliding_window.py +++ b/tests/core/block/e2e/test_correctness_sliding_window.py @@ -1,5 +1,5 @@ import random -from typing import List +from typing import List, Tuple import pytest @@ -118,7 +118,7 @@ def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed, check_answers(indices, answer, test_texts) -def prep_prompts(batch_size: int): +def prep_prompts(batch_size: int, assign_range: Tuple[int, int] = (800, 1100)): """ Generate prompts which a bunch of assignments, then asking for the value of one of them. @@ -134,7 +134,7 @@ def prep_prompts(batch_size: int): indices.append(idx) prompt = "```python\n# We set a number of variables, " + \ f"x{idx} will be important later\n" - ln = random.randint(800, 1100) + ln = random.randint(*assign_range) for k in range(30, ln): v = random.randint(10, 99) if k == idx: @@ -146,7 +146,10 @@ def prep_prompts(batch_size: int): return prompts, answer, indices -def check_answers(indices: List[int], answer: List[int], outputs: List[str]): +def check_answers(indices: List[int], + answer: List[int], + outputs: List[str], + accept_rate=0.7): answer2 = [int(text[0:2].strip()) for text in outputs] print(list(zip(indices, zip(answer, answer2)))) numok = 0 @@ -155,7 +158,7 @@ def check_answers(indices: List[int], answer: List[int], outputs: List[str]): numok += 1 frac_ok = numok / len(answer) print(f"Num OK: {numok}/{len(answer)} {frac_ok}") - assert frac_ok > 0.7 + assert frac_ok >= accept_rate def check_window(prompts: List[str]): diff --git a/tests/v1/e2e/test_correctness_sliding_window.py b/tests/v1/e2e/test_correctness_sliding_window.py index bbeb79e1c6a11..3c43772492c29 100644 --- a/tests/v1/e2e/test_correctness_sliding_window.py +++ b/tests/v1/e2e/test_correctness_sliding_window.py @@ -1,12 +1,26 @@ -import random -from typing import List +from dataclasses import dataclass +from typing import List, Tuple from vllm import LLM, SamplingParams from ...core.block.e2e.test_correctness_sliding_window import (prep_prompts, check_answers) import pytest -@pytest.mark.parametrize("model", ["bigcode/starcoder2-3b"]) +@dataclass +class TestConfig: + sliding_window: int + assign_range: Tuple[int, int] + + +model_config = { + "bigcode/starcoder2-3b": TestConfig(4096, (800, 1100)), + "google/gemma-2-2b-it": TestConfig(4096, (400, 800)), +} + + +# @pytest.mark.parametrize("model", +# ["bigcode/starcoder2-3b", "google/gemma-2-2b-it"]) +@pytest.mark.parametrize("model", ["google/gemma-2-2b-it"]) @pytest.mark.parametrize("batch_size", [5]) @pytest.mark.parametrize("seed", [1]) def test_sliding_window_retrival(monkeypatch, model, batch_size, seed): @@ -20,11 +34,30 @@ def test_sliding_window_retrival(monkeypatch, model, batch_size, seed): with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") - llm = LLM(model=model, enable_prefix_caching=True) + test_config = model_config[model] + + llm = 
LLM(model=model) sampling_params = SamplingParams(temperature=0.0, max_tokens=100) - prompts, answer, indices = prep_prompts(batch_size) + prompts, answer, indices = prep_prompts( + batch_size, assign_range=test_config.assign_range) + + # both starcoder2-3b and gemma-2-2b-it have 4096 sliding window + check_window(prompts, llm, test_config.sliding_window) responses = llm.generate(prompts, sampling_params) - check_answers(indices, answer, - [response.outputs[0].text for response in responses]) + check_answers(indices, + answer, + [response.outputs[0].text for response in responses], + accept_rate=1.0) + + +def check_window(prompts: List[str], llm: LLM, sliding_window: int): + tokenizer = llm.get_tokenizer() + max_model_len = llm.llm_engine.model_config.max_model_len + assert any( + len(tokenizer.encode(prompt)) > sliding_window + for prompt in prompts), "Prompt is too short for test" + assert all( + len(tokenizer.encode(prompt)) <= max_model_len + for prompt in prompts), "Prompt is too long for test" diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 312e99d9808c1..8c5422a41de07 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -343,8 +343,6 @@ def hash_request_tokens(block_size: int, request: Request, extra_keys) ret.append(block_hash) parent_block_hash_value = block_hash.hash_value - if start < 5 * block_size: - print("block_hash", start // block_size, block_hash) return ret From 3998e924b5530b409ddd715c0cf57a3523667197 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 31 Jan 2025 06:08:17 -0800 Subject: [PATCH 18/48] small fix Signed-off-by: Chen Zhang --- vllm/v1/worker/gpu_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 349ed3140131b..6f51191fff6a5 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -531,7 +531,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): max_seq_len=max_seq_len, seq_start_loc=seq_start_loc, block_table=(self.input_batch.block_table.get_device_tensor()[ - i, :num_reqs]), + group_id, :num_reqs]), slot_mapping=slot_mapping, use_cascade=use_cascade, common_prefix_len=common_prefix_len, From 446e99dc62536cfcae23d7d1603b110483c4bf62 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 31 Jan 2025 06:13:25 -0800 Subject: [PATCH 19/48] small fix, can run gemma2 Signed-off-by: Chen Zhang --- tests/v1/e2e/test_correctness_sliding_window.py | 5 ++--- vllm/v1/worker/gpu_model_runner.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/v1/e2e/test_correctness_sliding_window.py b/tests/v1/e2e/test_correctness_sliding_window.py index 3c43772492c29..28d1fb192deae 100644 --- a/tests/v1/e2e/test_correctness_sliding_window.py +++ b/tests/v1/e2e/test_correctness_sliding_window.py @@ -18,9 +18,8 @@ class TestConfig: } -# @pytest.mark.parametrize("model", -# ["bigcode/starcoder2-3b", "google/gemma-2-2b-it"]) -@pytest.mark.parametrize("model", ["google/gemma-2-2b-it"]) +@pytest.mark.parametrize("model", + ["bigcode/starcoder2-3b", "google/gemma-2-2b-it"]) @pytest.mark.parametrize("batch_size", [5]) @pytest.mark.parametrize("seed", [1]) def test_sliding_window_retrival(monkeypatch, model, batch_size, seed): diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d1290d944cc00..7dc9e5472d71e 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -531,7 
+531,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): max_seq_len=max_seq_len, seq_start_loc=seq_start_loc, block_table=(self.input_batch.block_table.get_device_tensor()[ - i, :num_reqs]), + group_id, :num_reqs]), slot_mapping=slot_mapping, use_cascade=use_cascade, common_prefix_len=common_prefix_len, From d97c1b0dadf0e3650cd44a524a55863c2fee3595 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 31 Jan 2025 06:28:10 -0800 Subject: [PATCH 20/48] add test for range_intersect Signed-off-by: Chen Zhang --- tests/v1/core/test_kv_cache_utils.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index f4081766e39a2..a71eae5a05f70 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -2,6 +2,7 @@ from vllm.multimodal.inputs import MultiModalKwargs from vllm.sampling_params import SamplingParams +from vllm.v1.core.hybrid_cache_manager.utils import PrefixLengthRange, intersect_ranges from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue, KVCacheBlock, generate_block_hash_extra_keys, @@ -243,3 +244,25 @@ def test_hash_request_tokens_no_mm_inputs(): assert block_hashes[0].extra_keys is None assert block_hashes[1].token_ids == (3, 4, 5) assert block_hashes[1].extra_keys is None + + +def test_prefix_length_range_intersection(): + range0 = [ + PrefixLengthRange(1, 5), + PrefixLengthRange(10, 14), + PrefixLengthRange(16, 18) + ] + range1 = [ + PrefixLengthRange(2, 6), + PrefixLengthRange(8, 12), + PrefixLengthRange(15, 17) + ] + range2 = [PrefixLengthRange(3, 11), PrefixLengthRange(13, 19)] + ranges = [range0, range1, range2] + + intersection = intersect_ranges(ranges) + assert intersection == [ + PrefixLengthRange(3, 5), + PrefixLengthRange(10, 11), + PrefixLengthRange(16, 17) + ] From 5ebfeac81448a1367b7db2f5d930d57816e04ef9 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 31 Jan 2025 07:14:49 -0800 Subject: [PATCH 21/48] clean up get_computed_blocks, append_slots, allocate_slots Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_manager.py | 79 +++++++++++++++++++------------- 1 file changed, 47 insertions(+), 32 deletions(-) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 1f481c4b4b7a9..277be95f50f23 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -50,21 +50,25 @@ def __init__( # the request gets N empty blocks, it starts to use the blocks without # further allocation. When it uses up all the N empty blocks, it gets # N new empty blocks. - # TODO: update comment + # NOTE(Chen): For simplicity, we keep the number of preallocated blocks + # the same for all kv cache groups, which will result in different + # preallocated tokens for different groups if their block sizes are + # different. self.num_preallocate_tokens = num_preallocate_tokens - # TODO: min or max? self.num_preallocate_blocks = cdiv( num_preallocate_tokens, - min(g.kv_cache_spec.block_size for g in kv_cache_config.groups)) + max(g.kv_cache_spec.block_size for g in kv_cache_config.groups)) self._null_block: KVCacheBlock = KVCacheBlock(-1) - # TODO(Chen): add comments + # Specialized managers for each kv cache group, which handle the + # different kv cache management logic of different attention layers. 
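+        # For example, a model with full attention and sliding window
+        # layers gets one FullAttentionManager and one SlidingWindowManager,
+        # indexed by the kv cache group id.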
self.managers = get_managers( kv_cache_config, BlockPoolOperations(get_cached_block=self._get_cached_block, get_null_block=self.get_null_block), ) + self.num_kv_cache_groups = len(self.kv_cache_config.groups) # A Block pool of all kv-cache blocks. self.block_pool: List[KVCacheBlock] = [ @@ -117,32 +121,37 @@ def get_computed_blocks(self, for i, manager in enumerate(self.managers) ]) - computed_blocks: ReqKVCacheBlocks = [] # group_id->[blocks] - computed_tokens: List[PrefixLength] = [] # group_id->PrefixLength + computed_blocks: ReqKVCacheBlocks = [] # computed blocks of each group + prefix_length: List[PrefixLength] = [ + ] # possible cached prefix length of each group block_hashes = request.kv_block_hashes for i, manager in enumerate(self.managers): - computed_tokens_i, computed_blocks_i = ( + prefix_length_i, computed_blocks_i = ( manager.get_possible_cached_prefix(block_hashes[i])) computed_blocks.append(computed_blocks_i) - computed_tokens.append(computed_tokens_i) + prefix_length.append(prefix_length_i) if len(self.kv_cache_config.groups) == 1: # If there is only one group, we return the computed blocks and # tokens directly. - # NOTE(woosuk): Since incomplete blocks are not eligible for - # sharing, `num_computed_tokens` is always a multiple of - # `block_size`. - num_computed_tokens = computed_tokens[0][-1].end + num_computed_tokens = prefix_length[0][-1].end else: - # find the common cached prefix of all groups. This path also works + # Find the common cached prefix of all groups. This path also works # for the single group case, but it is less efficient. num_computed_tokens = self._get_common_computed_tokens( - computed_tokens) + prefix_length) + + # Truncate the computed blocks to the number of computed tokens. + # E.g., group 0 has 3 computed blocks, and group 1 has 4 computed + # blocks with the same block size, we truncate both groups to 3 blocks. for i, manager in enumerate(self.managers): computed_blocks[i] = computed_blocks[i][:num_computed_tokens // manager.block_size] - self._free_blocks_for_sliding_window(computed_blocks, - num_computed_tokens) + + # Free the blocks that are not needed. E.g., sliding window layer + # with window size 2 and block size 1, we can change the computed + # blocks from [1, 2, 3] to [-1, 2, 3] (-1 refers to null block) + self._free_useless_blocks(computed_blocks, num_computed_tokens) return computed_blocks, num_computed_tokens def append_slots( @@ -162,10 +171,10 @@ def append_slots( The new blocks if new blocks are allocated, or None if new blocks are required but cannot be allocated. """ - # we can free blocks even if we cannot schedule it - self._free_blocks_for_sliding_window( - self.req_to_blocks[request.request_id], - request.num_computed_tokens) + # We can free blocks that are no longer needed even if we cannot + # schedule this request due to the limit of free blocks. + self._free_useless_blocks(self.req_to_blocks[request.request_id], + request.num_computed_tokens) req_blocks = self.req_to_blocks[request.request_id] num_new_blocks = [ @@ -180,16 +189,16 @@ def append_slots( # slots, but we cannot allocate new blocks due to the limit. return None - # TODO(Chen): add comments + # Truncate the number of pre-allocated blocks to ensure that we can + # have at least `num_new_blocks` free blocks for each group. 
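+        # E.g., with 10 free blocks, 4 new blocks required in total and
+        # 2 groups, each group gets at most (10 - 4) // 2 = 3 extra
+        # preallocated blocks.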
num_preallocate_blocks = min( self.num_preallocate_blocks, (self.free_block_queue.num_free_blocks - total_new_blocks) // len(self.managers)) - new_blocks = [] + new_blocks: ReqKVCacheBlocks = [] - for i in range(len(self.kv_cache_config.groups) - ): # TODO: self.num_kv_cache_groups + for i in range(self.num_kv_cache_groups): if num_new_blocks[i] <= 0: # No new block is needed. new_blocks.append([]) @@ -205,7 +214,10 @@ def append_slots( # num_prompt_tokens + max_tokens > max_model_len. self.max_num_blocks_per_req[i] - len(req_blocks[i]), ) - assert num_block_to_allocate > 0 + + assert num_block_to_allocate >= 0 + assert num_block_to_allocate <= \ + self.free_block_queue.num_free_blocks new_blocks_of_group = self._get_new_blocks( num_block_to_allocate) @@ -293,7 +305,8 @@ def allocate_slots( "Computed blocks should be empty when " "prefix caching is disabled") - # TODO(Chen): add comments + # Truncate the number of pre-allocated blocks to ensure that we can + # have at least `num_new_blocks` free blocks for each group. num_preallocate_blocks = min( self.num_preallocate_blocks, (self.free_block_queue.num_free_blocks - total_new_blocks) // @@ -314,7 +327,9 @@ def allocate_slots( # num_prompt_tokens + max_tokens > max_model_len. self.max_num_blocks_per_req[i] - len(computed_blocks[i]), ) - assert num_block_to_allocate > 0 + assert num_block_to_allocate >= 0 + assert num_block_to_allocate <= \ + self.free_block_queue.num_free_blocks new_blocks_of_group = self._get_new_blocks(num_block_to_allocate) new_blocks.append(new_blocks_of_group) @@ -640,10 +655,10 @@ def _cache_full_blocks( def get_null_block(self) -> KVCacheBlock: return self._null_block - def _get_common_computed_tokens( - self, computed_tokens: List[PrefixLength]) -> int: + def _get_common_computed_tokens(self, + prefix_length: List[PrefixLength]) -> int: # TODO: add comments: the largest in the intersection, and alignment - intersection = intersect_ranges(computed_tokens) + intersection = intersect_ranges(prefix_length) # Since incomplete blocks are not eligible for sharing, # `num_computed_tokens` should be a multiple of `block_size` of @@ -660,8 +675,8 @@ def _get_common_computed_tokens( return num_computed_tokens - def _free_blocks_for_sliding_window(self, req_blocks: ReqKVCacheBlocks, - num_computed_tokens: int) -> None: + def _free_useless_blocks(self, req_blocks: ReqKVCacheBlocks, + num_computed_tokens: int) -> None: # NOTE(Chen): do all free before allocation to make less eviction # req_blocks = self.req_to_blocks[request.request_id] removed_blocks = [] From 4e0dc489619fa80a5c5392589dba674ba4471ee0 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 31 Jan 2025 08:06:26 -0800 Subject: [PATCH 22/48] finish the clean up of kv cache manager Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_manager.py | 113 +++++++++++++++---------------- 1 file changed, 56 insertions(+), 57 deletions(-) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 277be95f50f23..c60d5dcb29a06 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -148,9 +148,7 @@ def get_computed_blocks(self, computed_blocks[i] = computed_blocks[i][:num_computed_tokens // manager.block_size] - # Free the blocks that are not needed. E.g., sliding window layer - # with window size 2 and block size 1, we can change the computed - # blocks from [1, 2, 3] to [-1, 2, 3] (-1 refers to null block) + # Free the blocks that are not needed. 
self._free_useless_blocks(computed_blocks, num_computed_tokens) return computed_blocks, num_computed_tokens @@ -173,6 +171,8 @@ def append_slots( """ # We can free blocks that are no longer needed even if we cannot # schedule this request due to the limit of free blocks. + # Should call this function before allocating new blocks to reduce + # the number of evicted blocks. self._free_useless_blocks(self.req_to_blocks[request.request_id], request.num_computed_tokens) req_blocks = self.req_to_blocks[request.request_id] @@ -360,64 +360,44 @@ def allocate_slots( return new_blocks - def _get_ordered_blocks_one_kv_cache_group( - self, blocks: KVCacheBlocks) -> Iterable[KVCacheBlock]: - # TODO (Chen): rethink where to do the reverse operation - ordered_blocks: Iterable[KVCacheBlock] = blocks - if self.enable_caching: - # Free blocks in reverse order so that the tail blocks are - # freed first. - ordered_blocks = reversed(blocks) - return ordered_blocks + def _merge_blocks_by_eviction_order( + self, blocks: ReqKVCacheBlocks) -> List[KVCacheBlock]: + """ + Merge the blocks of different groups to one list. The returned blocks + are sorted by eviction order, with the first block having the highest + eviction priority. + + Args: + blocks: the blocks of each kv cache group, ordered by eviction + priority. + + Returns: + A list of KVCacheBlocks sorted by eviction order. + """ - def _get_ordered_blocks_multiple_kv_cache_groups( - self, blocks: ReqKVCacheBlocks) -> Iterable[KVCacheBlock]: - # Fast path: if all blocks are empty, return. This will happen during - # append_slots - blocks = [b for b in blocks if len(b) > 0] - if len(blocks) == 0: - return [] - # Free blocks in reverse order so that the tail blocks are - # freed first. if self.enable_caching: - # TODO(Chen): add comments - # merge blocks from different groups based on the block size - block_size_set = set(manager.block_size - for manager in self.managers) - if len(block_size_set) == 1: - # O(n) time complexity if block_size of all groups are the same - ordered_blocks = [] - for i in range(len(blocks[0]) - 1, -1, -1): - for blocks_of_group in blocks: + # NOTE (Chen): A simple strategy that interleaves the blocks of + # different KV cache groups. We can investigate more advanced + # strategies in the future. + ordered_blocks = [] + max_len = max(len(blocks_of_group) for blocks_of_group in blocks) + for i in range(max_len): + for blocks_of_group in blocks: + if i < len(blocks_of_group): ordered_blocks.append(blocks_of_group[i]) - else: - # O(n * log(n)) time complexity - # TODO(Chen): optimize it to O(n*len(self.managers)) time complexity - # NOTE: untested - ordered_blocks_with_key = [] - - for i, blocks_of_group in enumerate(blocks): - block_size = self.managers[i].block_size - for i, block in enumerate(blocks_of_group): - ordered_blocks_with_key.append((block_size * i, block)) - - ordered_blocks_with_key.sort(reverse=True) - ordered_blocks = [ - block for _, block in ordered_blocks_with_key - ] else: - # TODO: need to implement this path - raise NotImplementedError + ordered_blocks = [] + for blocks_of_group in blocks: + ordered_blocks.extend(blocks_of_group) return ordered_blocks def _free_blocks(self, blocks: ReqKVCacheBlocks) -> None: if len(self.kv_cache_config.groups) == 1: - ordered_blocks = self._get_ordered_blocks_one_kv_cache_group( - blocks[0]) + # Fast path for single kv cache group models. 
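+            # No cross-group merging is needed; the blocks are freed in the
+            # order given by the caller.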
+ ordered_blocks = blocks[0] else: - ordered_blocks = self._get_ordered_blocks_multiple_kv_cache_groups( - blocks) + ordered_blocks = self._merge_blocks_by_eviction_order(blocks) for block in ordered_blocks: if block == self._null_block: continue @@ -439,7 +419,9 @@ def free(self, request: Request) -> None: # This request is freed before alloc. just return return else: - self._free_blocks(blocks) + # Reverse the blocks so that the tail blocks can have higher + # eviction priority. + self._free_blocks([list(reversed(blks)) for blks in blocks]) def get_num_common_prefix_blocks( self, @@ -657,7 +639,17 @@ def get_null_block(self) -> KVCacheBlock: def _get_common_computed_tokens(self, prefix_length: List[PrefixLength]) -> int: - # TODO: add comments: the largest in the intersection, and alignment + """ + Find a prefix that is cached by all KV cache groups. Returns the number + of tokens of that prefix. + + Args: + prefix_length (List[PrefixLength]): The valid cached prefix lengths + of each KV cache group. + + Returns: + The number of tokens of the common prefix. + """ intersection = intersect_ranges(prefix_length) # Since incomplete blocks are not eligible for sharing, @@ -666,6 +658,7 @@ def _get_common_computed_tokens(self, alignment = math.lcm( *[manager.block_size for manager in self.managers]) + # Get the longest common prefix that is aligned with the block size. num_computed_tokens = 0 for range_ in intersection: aligned_end = cdiv(range_.end, alignment) * alignment @@ -677,13 +670,19 @@ def _get_common_computed_tokens(self, def _free_useless_blocks(self, req_blocks: ReqKVCacheBlocks, num_computed_tokens: int) -> None: - # NOTE(Chen): do all free before allocation to make less eviction - # req_blocks = self.req_to_blocks[request.request_id] + """ + Frees memory blocks that are not needed. E.g., sliding window + layer with window size 2 and block size 1, we have req_blocks as + [[1, 2, 3]], this function will free block 1 and change the req_blocks + to [[-1, 2, 3]] (-1 refers to null block) + + Args: + req_blocks: The KV cache blocks of one request. + num_computed_tokens: The number of computed tokens. 
+ """ removed_blocks = [] for manager, req_blocks_of_group in zip(self.managers, req_blocks): removed_blocks.append( manager.remove_useless_blocks(req_blocks_of_group, num_computed_tokens)) - # TODO: better handling of free order (e.g., this order have problem - # when different layer has different sliding window size) self._free_blocks(removed_blocks) From cd4f8e21c8d699ae53446e55ead1a53e2acd37f2 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 31 Jan 2025 08:49:45 -0800 Subject: [PATCH 23/48] clean up the code Signed-off-by: Chen Zhang --- tests/v1/core/test_kv_cache_utils.py | 3 +- tests/v1/core/test_specialized_manager.py | 8 +++- .../v1/e2e/test_correctness_sliding_window.py | 9 +++-- .../specialized_manager.py | 14 ++++--- vllm/v1/core/kv_cache_manager.py | 24 ++++++----- vllm/v1/core/kv_cache_utils.py | 40 ++++++++++++++----- vllm/v1/engine/core.py | 3 +- 7 files changed, 67 insertions(+), 34 deletions(-) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index a71eae5a05f70..c31d478527beb 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -2,7 +2,8 @@ from vllm.multimodal.inputs import MultiModalKwargs from vllm.sampling_params import SamplingParams -from vllm.v1.core.hybrid_cache_manager.utils import PrefixLengthRange, intersect_ranges +from vllm.v1.core.hybrid_cache_manager.utils import (PrefixLengthRange, + intersect_ranges) from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue, KVCacheBlock, generate_block_hash_extra_keys, diff --git a/tests/v1/core/test_specialized_manager.py b/tests/v1/core/test_specialized_manager.py index 959972521461a..4aad3fa6b638f 100644 --- a/tests/v1/core/test_specialized_manager.py +++ b/tests/v1/core/test_specialized_manager.py @@ -1,6 +1,10 @@ from collections import deque +from typing import Deque + import torch -from vllm.v1.core.hybrid_cache_manager.specialized_manager import BlockPoolOperations, SlidingWindowManager + +from vllm.v1.core.hybrid_cache_manager.specialized_manager import ( + BlockPoolOperations, SlidingWindowManager) from vllm.v1.core.hybrid_cache_manager.utils import PrefixLengthRange from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlock from vllm.v1.kv_cache_interface import SlidingWindowSpec @@ -15,7 +19,7 @@ def test_sliding_window_possible_cached_prefix(): sliding_window=4, ) - block_pool_result = deque() + block_pool_result: Deque[bool] = deque() null_block = KVCacheBlock(-1, 0) def get_cached_block(_block_hash): diff --git a/tests/v1/e2e/test_correctness_sliding_window.py b/tests/v1/e2e/test_correctness_sliding_window.py index 28d1fb192deae..2bc6d3d1712fc 100644 --- a/tests/v1/e2e/test_correctness_sliding_window.py +++ b/tests/v1/e2e/test_correctness_sliding_window.py @@ -1,10 +1,13 @@ from dataclasses import dataclass from typing import List, Tuple -from vllm import LLM, SamplingParams -from ...core.block.e2e.test_correctness_sliding_window import (prep_prompts, - check_answers) + import pytest +from vllm import LLM, SamplingParams + +from ...core.block.e2e.test_correctness_sliding_window import (check_answers, + prep_prompts) + @dataclass class TestConfig: diff --git a/vllm/v1/core/hybrid_cache_manager/specialized_manager.py b/vllm/v1/core/hybrid_cache_manager/specialized_manager.py index c61ff83ac0efe..db8f20e2d0c93 100644 --- a/vllm/v1/core/hybrid_cache_manager/specialized_manager.py +++ b/vllm/v1/core/hybrid_cache_manager/specialized_manager.py @@ -1,12 +1,14 @@ from abc import ABC, abstractmethod 
-from collections import deque from dataclasses import dataclass from itertools import chain -from typing import Callable, Deque, Dict, List, Optional, Tuple, TypedDict +from typing import Callable, Dict, List, Optional, Tuple, Type + from vllm.utils import cdiv -from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig, KVCacheSpec, SlidingWindowSpec +from vllm.v1.core.hybrid_cache_manager.utils import (PrefixLength, + PrefixLengthRange) from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlock -from vllm.v1.core.hybrid_cache_manager.utils import PrefixLength, PrefixLengthRange +from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, + KVCacheSpec, SlidingWindowSpec) from vllm.v1.utils import ConstantList @@ -149,7 +151,7 @@ def get_possible_cached_prefix( # TODO: check the hit every num_block_sliding_window blocks, to optimize # the time complexity from O(num_block) to # O(num_block / num_block_sliding_window) + O(num_computed_block), - # which is good for low cache hit rate senarios. + # which is good for low cache hit rate scenarios. start = 0 ranges = [] computed_blocks: List[KVCacheBlock] = [] @@ -202,7 +204,7 @@ def remove_useless_blocks(self, block_table: List[KVCacheBlock], return removed_blocks -spec_manager_map: Dict[KVCacheSpec, SpecializedManager] = { +spec_manager_map: Dict[Type[KVCacheSpec], Type[SpecializedManager]] = { FullAttentionSpec: FullAttentionManager, SlidingWindowSpec: SlidingWindowManager } diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index c60d5dcb29a06..5fcf6fd4888f4 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -1,17 +1,18 @@ -from collections import defaultdict import math -from typing import Dict, Iterable, List, Optional, Tuple +from collections import defaultdict +from typing import Dict, List, Optional, Tuple from vllm.logger import init_logger from vllm.utils import cdiv -from vllm.v1.core.hybrid_cache_manager.specialized_manager import BlockPoolOperations, get_managers +from vllm.v1.core.hybrid_cache_manager.specialized_manager import ( + BlockPoolOperations, get_managers) +from vllm.v1.core.hybrid_cache_manager.utils import (PrefixLength, + intersect_ranges) from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue, - KVCacheBlock, KVCacheBlocks, - ReqKVCacheBlocks, + KVCacheBlock, ReqKVCacheBlocks, generate_block_hash_extra_keys, hash_block_tokens, hash_request_tokens) -from vllm.v1.core.hybrid_cache_manager.utils import PrefixLength, intersect_ranges from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import Request, RequestStatus @@ -207,7 +208,7 @@ def append_slots( # preallocated blocks. num_block_to_allocate = min( num_new_blocks[i] + num_preallocate_blocks, - # Should not exceed the maximum number of blocks per request. + # Should not exceed the maximum number of blocks per request # This is especially because the block table has the shape # [..., max_num_blocks_per_req]. # TODO(woosuk): Check and reject requests if @@ -351,7 +352,8 @@ def allocate_slots( self._cache_full_blocks( request=request, blk_start_idx=len(computed_blocks[i]), - # The new full blocks are the full blocks that are not computed. + # The new full blocks are the full blocks that are not + # computed. 
full_blocks=new_full_blocks, prev_block=computed_blocks[i][-1] if computed_blocks[i] else None, @@ -640,15 +642,15 @@ def get_null_block(self) -> KVCacheBlock: def _get_common_computed_tokens(self, prefix_length: List[PrefixLength]) -> int: """ - Find a prefix that is cached by all KV cache groups. Returns the number - of tokens of that prefix. + Find the longest prefix that is cached by all KV cache groups. Returns + the number of tokens in that prefix. Args: prefix_length (List[PrefixLength]): The valid cached prefix lengths of each KV cache group. Returns: - The number of tokens of the common prefix. + The number of tokens in the common prefix. """ intersection = intersect_ranges(prefix_length) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 8c5422a41de07..88bc7aeec31b5 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -1,8 +1,8 @@ """KV-Cache Utilities.""" +import math from collections import defaultdict from collections.abc import Sequence from dataclasses import dataclass -import math from typing import Any, Dict, List, NamedTuple, Optional, Tuple from vllm.config import VllmConfig @@ -397,7 +397,8 @@ def is_kv_cache_type_uniform(kv_cache_spec: Dict[str, KVCacheSpec]) -> bool: return len(layer_keys) == 1 -def is_kv_cache_page_size_uniform(kv_cache_spec: KVCacheSpec): +def is_kv_cache_page_size_uniform( + kv_cache_spec: Dict[str, KVCacheSpec]) -> bool: """ Whether all layers in the given KVCacheSpec have the same page size. @@ -490,18 +491,28 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, def _get_kv_cache_config_uniform_page_size( vllm_config: VllmConfig, kv_cache_spec: Dict[str, KVCacheSpec], available_memory: int) -> KVCacheConfig: - # Grouped allocation - # TODO(Chen): explain it, need test + """ + Generates the KV cache configuration for a model with one page size. - # Group all layers by type_id + Args: + vllm_config: The global VllmConfig + kv_cache_spec: The KVCacheSpec of each attention layer in the model + available_memory: Memory available for KV cache in bytes. + + Returns: + The generated KVCacheConfig + """ + # Group all layers by type_id. + # E.g., 2 full attention layers and 4 sliding window attention layers, + # -> (full.0, full.1), (sw.0, sw.1, sw.2, sw.3). same_type_layers: Dict[str, List[str]] = defaultdict(list) for layer_name, layer_spec in kv_cache_spec.items(): same_type_layers[layer_spec.type_id].append(layer_name) # Split each group into smaller groups, to make the number of layers in - # each group identical - # E.g., 2 full attention layers and 4 sliding window attention layers, - # split from (full * 2), (sw * 4) to (full * 2), (sw * 2), (sw * 2). + # each group identical. + # E.g., (full.0, full.1), (sw.0, sw.1, sw.2, sw.3) is split to 3 groups: + # (full.0, full.1), (sw.0, sw.1), (sw.2, sw.3). group_size_gcd = math.gcd( *[len(layers) for layers in same_type_layers.values()]) grouped_layers = [] @@ -509,7 +520,10 @@ def _get_kv_cache_config_uniform_page_size( for i in range(0, len(layers), group_size_gcd): grouped_layers.append(layers[i:i + group_size_gcd]) - # TODO: explain it + # Divide the available memory equally among all layers in the first group. 
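+    # In the running example, the first group is (full.0, full.1).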
+ # The memory layout in the example will be: + # full.0: Tensor with size=available_memory//2 + # full.1: Tensor with size=available_memory//2 kv_cache_spec_first_group = { layer_name: kv_cache_spec[layer_name] for layer_name in grouped_layers[0] @@ -517,6 +531,12 @@ def _get_kv_cache_config_uniform_page_size( kv_cache_config = _get_kv_cache_config_uniform_type( vllm_config, kv_cache_spec_first_group, available_memory) + # Reuse the KV cache tensors of the first group for the other groups. + # The memory layout in the example will be: + # full.0, sw.0, sw.2: share a Tensor with size=available_memory//2 + # full.1, sw.1, sw.3: share another Tensor with size=available_memory//2 + # Layers of different groups have different block table, so they will + # use different parts of the shared Tensor. for layers in grouped_layers[1:]: for layer_name, layer_name_first_group in zip(layers, grouped_layers[0]): @@ -549,7 +569,7 @@ def get_kv_cache_config(vllm_config: VllmConfig, return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec, available_memory) elif is_kv_cache_page_size_uniform(kv_cache_spec): - # TODO: add comments + # KV cache of all layers have the same page size. return _get_kv_cache_config_uniform_page_size(vllm_config, kv_cache_spec, available_memory) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index c3bb149a1efb5..db24695f2d373 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -4,7 +4,7 @@ import threading import time from multiprocessing.connection import Connection -from typing import List, Tuple, Type +from typing import List, Type import psutil import zmq @@ -78,6 +78,7 @@ def _initialize_kv_caches(self, vllm_config: VllmConfig) -> KVCacheConfig: # Get the kv cache tensor size kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec, availble_gpu_memory) + print("kv_cache_config", kv_cache_config) # Initialize kv cache and warmup the execution self.model_executor.initialize(kv_cache_config) From 2d7bbca4cd59f7813961f41ececdbd077147886d Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 31 Jan 2025 20:04:25 -0800 Subject: [PATCH 24/48] fix some tests Signed-off-by: Chen Zhang --- tests/v1/core/test_kv_cache_utils.py | 15 ++- tests/v1/core/test_prefix_caching.py | 134 +++++++++++++++------------ vllm/v1/core/kv_cache_manager.py | 4 +- vllm/v1/core/scheduler.py | 24 +++-- 4 files changed, 104 insertions(+), 73 deletions(-) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index c31d478527beb..9d8d6d0097a8b 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -49,7 +49,9 @@ def test_kv_cache_block(): assert block.ref_cnt == 0 # Test block hash setting and resetting - block_hash = BlockHashType(hash_value=123, token_ids=(1, 2, 3)) + block_hash = BlockHashType(hash_value=123, + kv_cache_group_id=0, + token_ids=(1, 2, 3)) block.block_hash = block_hash assert block.block_hash == block_hash @@ -190,11 +192,11 @@ def test_hash_block_tokens(): curr_block_token_ids = (1, 2, 3) extra_keys = ("key1", "key2") - block_hash = hash_block_tokens(parent_block_hash, curr_block_token_ids, + block_hash = hash_block_tokens(parent_block_hash, curr_block_token_ids, 0, extra_keys) assert isinstance(block_hash, BlockHashType) assert block_hash.hash_value == hash( - (parent_block_hash, *curr_block_token_ids)) + (parent_block_hash, 0, *curr_block_token_ids)) assert block_hash.token_ids == curr_block_token_ids assert block_hash.extra_keys == extra_keys @@ -214,7 +216,7 @@ def 
test_hash_request_tokens(): ) block_size = 3 - block_hashes = hash_request_tokens(block_size, request) + block_hashes = hash_request_tokens(block_size, request, 0) assert len(block_hashes) == 2 assert isinstance(block_hashes[0], BlockHashType) @@ -238,7 +240,7 @@ def test_hash_request_tokens_no_mm_inputs(): ) block_size = 3 - block_hashes = hash_request_tokens(block_size, request) + block_hashes = hash_request_tokens(block_size, request, 0) assert len(block_hashes) == 2 assert block_hashes[0].token_ids == (0, 1, 2) @@ -267,3 +269,6 @@ def test_prefix_length_range_intersection(): PrefixLengthRange(10, 11), PrefixLengthRange(16, 17) ] + + +# TODO: add tests for hash of kv_cache_group_id diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index c1111045e92d2..6a3316557d688 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -1,11 +1,13 @@ """Compare the with and without prefix caching.""" import pytest +import torch from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.sampling_params import SamplingParams from vllm.utils import cdiv from vllm.v1.core.kv_cache_manager import KVCacheManager, Request from vllm.v1.core.kv_cache_utils import KVCacheBlock, hash_block_tokens +from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig, KVCacheGroup def make_request(request_id, @@ -31,12 +33,21 @@ def make_request(request_id, ) +def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig: + return KVCacheConfig( + num_blocks=num_blocks, + tensors={}, + groups=[ + KVCacheGroup(['layer'], + FullAttentionSpec(block_size, 1, 1, torch.float32)) + ], + ) + + def test_prefill(): manager = KVCacheManager( - block_size=16, - num_gpu_blocks=10, + make_kv_cache_config(16, 10), max_model_len=8192, - sliding_window=None, enable_caching=True, num_preallocate_tokens=16, ) @@ -53,14 +64,15 @@ def test_prefill(): assert len(req0.kv_block_hashes[0]) == 3 assert not computed_blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req0, 55, computed_blocks) + blocks = manager.allocate_slots(req0, 55, computed_blocks, + num_computed_tokens) assert [b.block_id for b in blocks[0]] == [0, 1, 2, 3, 4] # Check full block metadata parent_block_hash = None for block_id in (0, 1, 2): block_tokens = tuple(all_token_ids[block_id * 16:(block_id + 1) * 16]) - block_hash = hash_block_tokens(parent_block_hash, block_tokens) + block_hash = hash_block_tokens(parent_block_hash, block_tokens, 0) assert manager.block_pool[block_id].block_hash == block_hash assert manager.block_pool[block_id].ref_cnt == 1 parent_block_hash = block_hash.hash_value @@ -79,7 +91,8 @@ def test_prefill(): assert [b.block_id for b in computed_blocks[0]] == [0, 1, 2] assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 - blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks) + blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks, + num_computed_tokens) assert [b.block_id for b in blocks[0]] == [5, 6] for block in computed_blocks[0]: assert block.ref_cnt == 2 @@ -110,7 +123,8 @@ def test_prefill(): assert [b.block_id for b in computed_blocks[0]] == [0, 1, 2] assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 - blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks) + blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks, + num_computed_tokens) assert [b.block_id for b in blocks[0]] == [7, 8] # Although we only have 5 free 
blocks, we have 8 blocks in @@ -129,7 +143,8 @@ def test_prefill(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) assert not computed_blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req3, 16 * 9, computed_blocks) + blocks = manager.allocate_slots(req3, 16 * 9, computed_blocks, + num_computed_tokens) # This block ID order also checks the eviction order. assert [b.block_id for b in blocks[0]] == [9, 4, 3, 6, 5, 8, 7, 2, 1, 0] assert manager.free_block_queue.num_free_blocks == 0 @@ -139,10 +154,8 @@ def test_prefill(): def test_decode(): manager = KVCacheManager( - block_size=16, - num_gpu_blocks=10, + make_kv_cache_config(16, 10), max_model_len=8192, - sliding_window=None, enable_caching=True, num_preallocate_tokens=16, ) @@ -157,7 +170,8 @@ def test_decode(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req0, 55, computed_blocks) + blocks = manager.allocate_slots(req0, 55, computed_blocks, + num_computed_tokens) assert [b.block_id for b in blocks[0]] == [0, 1, 2, 3, 4] # Append slots without allocating a new block. @@ -166,7 +180,7 @@ def test_decode(): req0.append_output_token_ids(8) new_blocks = manager.append_slots(req0, 4) assert new_blocks is not None and len(new_blocks[0]) == 0 - assert manager.req_to_blocks[req0.request_id][-2].block_hash is None + assert manager.req_to_blocks[req0.request_id][0][-2].block_hash is None # Append slots without allocating a new block, but start using the # preallocated block. @@ -177,7 +191,7 @@ def test_decode(): req0.append_output_token_ids(7) new_blocks = manager.append_slots(req0, 15) assert new_blocks is not None and len(new_blocks[0]) == 0 - assert manager.req_to_blocks[req0.request_id][-2].block_hash is not None + assert manager.req_to_blocks[req0.request_id][0][-2].block_hash is not None # Append slots with allocating a new block. req0.num_computed_tokens = 74 @@ -192,10 +206,8 @@ def test_decode(): def test_evict(): manager = KVCacheManager( - block_size=16, - num_gpu_blocks=10, + make_kv_cache_config(16, 10), max_model_len=8192, - sliding_window=None, enable_caching=True, num_preallocate_tokens=16, ) @@ -205,7 +217,8 @@ def test_evict(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req0, 5 * 16 + 7, computed_blocks) + blocks = manager.allocate_slots(req0, 5 * 16 + 7, computed_blocks, + num_computed_tokens) assert len(blocks[0]) == 7 # 5 full + 1 partial + 1 preallocated # 3 blocks. 
@@ -214,7 +227,8 @@ def test_evict(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert not computed_blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req1, 3 * 16, computed_blocks) + blocks = manager.allocate_slots(req1, 3 * 16, computed_blocks, + num_computed_tokens) assert len(blocks[0]) == 3 # 3 full blocks last_token_id += 3 * 16 @@ -232,7 +246,8 @@ def test_evict(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert [b.block_id for b in computed_blocks[0]] == [0, 1] assert num_computed_tokens == 2 * 16 - blocks = manager.allocate_slots(req2, 3, computed_blocks) + blocks = manager.allocate_slots(req2, 3, computed_blocks, + num_computed_tokens) assert [b.block_id for b in blocks[0]] == [6, 5] assert manager.free_block_queue.num_free_blocks == 6 @@ -244,10 +259,8 @@ def test_hash_block_correct_reuse(): """ block_size = 16 manager = KVCacheManager( - block_size=block_size, - num_gpu_blocks=1, + make_kv_cache_config(block_size, 1), max_model_len=8192, - sliding_window=None, enable_caching=True, num_preallocate_tokens=0, ) @@ -258,7 +271,8 @@ def test_hash_block_correct_reuse(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert not computed_blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req, num_tokens, computed_blocks) + blocks = manager.allocate_slots(req, num_tokens, computed_blocks, + num_computed_tokens) assert len(blocks[0]) == 1 # Deallocate the block. @@ -270,7 +284,8 @@ def test_hash_block_correct_reuse(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert not computed_blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req, num_tokens - 1, computed_blocks) + blocks = manager.allocate_slots(req, num_tokens - 1, computed_blocks, + num_computed_tokens) assert len(blocks[0]) == 1 assert manager.block_pool[blocks[0][0].block_id].block_hash is None @@ -283,10 +298,8 @@ def test_computed_blocks_not_evicted(): """ block_size = 16 manager = KVCacheManager( - block_size=block_size, - num_gpu_blocks=2, + make_kv_cache_config(block_size, 2), max_model_len=8192, - sliding_window=None, enable_caching=True, num_preallocate_tokens=0, ) @@ -297,7 +310,8 @@ def test_computed_blocks_not_evicted(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req0, num_tokens, computed_blocks) + blocks = manager.allocate_slots(req0, num_tokens, computed_blocks, + num_computed_tokens) assert len(blocks[0]) == 1 assert blocks[0][0].block_id == 0 @@ -306,7 +320,8 @@ def test_computed_blocks_not_evicted(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert not computed_blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req1, num_tokens, computed_blocks) + blocks = manager.allocate_slots(req1, num_tokens, computed_blocks, + num_computed_tokens) assert len(blocks[0]) == 1 assert blocks[0][0].block_id == 1 @@ -323,7 +338,7 @@ def test_computed_blocks_not_evicted(): assert num_computed_tokens == block_size blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens, - computed_blocks) + computed_blocks, num_computed_tokens) assert len(blocks[0]) == 1 assert blocks[0][0].block_id == 1 @@ -334,10 +349,8 @@ def test_basic_prefix_caching_disabled(): """ block_size = 4 manager = KVCacheManager( - block_size=block_size, - num_gpu_blocks=4, + 
make_kv_cache_config(block_size, 4), max_model_len=8192, - sliding_window=None, enable_caching=False, num_preallocate_tokens=0, ) @@ -347,7 +360,8 @@ def test_basic_prefix_caching_disabled(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert not computed_blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req1, 10, computed_blocks) + blocks = manager.allocate_slots(req1, 10, computed_blocks, + num_computed_tokens) assert len(blocks[0]) == 3 # Free the blocks. @@ -358,7 +372,8 @@ def test_basic_prefix_caching_disabled(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert not computed_blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req2, 16, computed_blocks) + blocks = manager.allocate_slots(req2, 16, computed_blocks, + num_computed_tokens) assert len(blocks[0]) == 4 # New requests should not have any blocks. @@ -366,7 +381,8 @@ def test_basic_prefix_caching_disabled(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) assert not computed_blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req3, 4, computed_blocks) + blocks = manager.allocate_slots(req3, 4, computed_blocks, + num_computed_tokens) assert not blocks @@ -377,10 +393,8 @@ def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int): This tests that the preallocated blocks are correctly added. """ manager = KVCacheManager( - block_size=block_size, - num_gpu_blocks=10, + make_kv_cache_config(block_size, 10), max_model_len=8192, - sliding_window=None, enable_caching=True, num_preallocate_tokens=num_preallocate_tokens, ) @@ -391,7 +405,8 @@ def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int): assert not computed_blocks[0] assert num_computed_tokens == 0 # Just ask for 1 block. - blocks = manager.allocate_slots(req, block_size, computed_blocks) + blocks = manager.allocate_slots(req, block_size, computed_blocks, + num_computed_tokens) req.num_computed_tokens = block_size assert len(blocks[0]) == 1 + num_preallocated_blocks @@ -411,10 +426,8 @@ def test_cache_blocks(): """ block_size = 4 manager = KVCacheManager( - block_size=block_size, - num_gpu_blocks=5, + make_kv_cache_config(block_size, 5), max_model_len=8192, - sliding_window=None, enable_caching=True, num_preallocate_tokens=0, ) @@ -458,10 +471,8 @@ def test_mm_prefix_caching(): This tests that the multi-modal prefix caching is correct. 
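+    Multi-modal placeholder hashes are stored as extra keys of the block
+    hashes, so blocks covering different multi-modal items do not share
+    cache entries.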
""" manager = KVCacheManager( - block_size=16, - num_gpu_blocks=10, + make_kv_cache_config(16, 10), max_model_len=8192, - sliding_window=None, enable_caching=True, num_preallocate_tokens=16, ) @@ -499,7 +510,8 @@ def test_mm_prefix_caching(): assert req0.kv_block_hashes[0][1].extra_keys == ("aaa", "bbb") assert req0.kv_block_hashes[0][2].extra_keys == ("bbb", ) - blocks = manager.allocate_slots(req0, 59, computed_blocks) + blocks = manager.allocate_slots(req0, 59, computed_blocks, + num_computed_tokens) assert [b.block_id for b in blocks[0]] == [0, 1, 2, 3, 4] req0.num_computed_tokens = 59 @@ -538,10 +550,8 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): """ block_size = 16 manager = KVCacheManager( - block_size=block_size, - num_gpu_blocks=10, + make_kv_cache_config(block_size, 10), max_model_len=8192, - sliding_window=None, enable_caching=True, num_preallocate_tokens=0, ) @@ -552,21 +562,21 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks[0] assert num_computed_tokens == 0 - manager.allocate_slots(req0, 48, computed_blocks) + manager.allocate_slots(req0, 48, computed_blocks, num_computed_tokens) block_part0 = manager.req_to_blocks[req0.request_id] # | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... | req1 = make_request("1", common_token_ids * 2) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) - assert computed_blocks[0] == block_part0 + assert computed_blocks == block_part0 assert num_computed_tokens == 3 * 16 - manager.allocate_slots(req1, 48, computed_blocks) + manager.allocate_slots(req1, 48, computed_blocks, num_computed_tokens) block_part1 = manager.req_to_blocks[req1.request_id] # | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) | # | Req1-5(F)| ... | manager.free(req1) - assert {block.ref_cnt for block in block_part1[:3]} == {1} - assert {block.ref_cnt for block in block_part1[3:]} == {0} + assert {block.ref_cnt for block in block_part1[0][:3]} == {1} + assert {block.ref_cnt for block in block_part1[0][3:]} == {0} # | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) | # | Req1-5(F)| Req2-0 | Req2-1 | ... | @@ -574,7 +584,8 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert not computed_blocks[0] assert num_computed_tokens == 0 - manager.allocate_slots(req2, block_size * 2, computed_blocks) + manager.allocate_slots(req2, block_size * 2, computed_blocks, + num_computed_tokens) # Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed, # but it cannot be allocated due to insufficient free blocks (2). @@ -582,11 +593,12 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): assert manager.free_block_queue.num_free_blocks == 5 req3 = make_request("3", common_token_ids * 3) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) - assert computed_blocks[0] == block_part1 + assert computed_blocks == block_part1 assert num_computed_tokens == 6 * 16 # Req3 cannot be allocated. - assert manager.allocate_slots(req3, 48, computed_blocks) is None + assert manager.allocate_slots(req3, 48, computed_blocks, + num_computed_tokens) is None # Block 0-2 are used by Req 1. - assert {block.ref_cnt for block in block_part1[:3]} == {1} + assert {block.ref_cnt for block in block_part1[0][:3]} == {1} # Block 3-5 are free. 
- assert {block.ref_cnt for block in block_part1[3:]} == {0} + assert {block.ref_cnt for block in block_part1[0][3:]} == {0} diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 5fcf6fd4888f4..9c2599f6ebf82 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -260,6 +260,7 @@ def allocate_slots( request: Request, num_tokens: int, computed_blocks: ReqKVCacheBlocks, + num_computed_tokens: int, ) -> Optional[ReqKVCacheBlocks]: """Allocate slots for a new request. @@ -268,6 +269,7 @@ def allocate_slots( num_tokens: The number of tokens to allocate. Note that this does not include the tokens that have already been computed. computed_blocks: The computed blocks. + num_computed_tokens: The number of computed tokens. Returns: The new blocks if new blocks are allocated, or None if new blocks @@ -285,7 +287,7 @@ def allocate_slots( if blk.ref_cnt == 0) num_new_blocks = [ - manager.get_num_new_blocks(request.num_computed_tokens, num_tokens, + manager.get_num_new_blocks(num_computed_tokens, num_tokens, len(computed_blocks_of_group)) for manager, computed_blocks_of_group in zip( self.managers, computed_blocks) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index f17e33bedf0fb..cc51d503eeda3 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -10,7 +10,7 @@ compute_encoder_budget) from vllm.v1.core.kv_cache_manager import KVCacheManager from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs -from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus @@ -203,7 +203,6 @@ def schedule(self) -> "SchedulerOutput": # which have output tokens. num_new_tokens = request.num_tokens - num_computed_tokens if num_new_tokens == 0: - raise NotImplementedError # The happens when prompt length is divisible by the block # size and all blocks are cached. Now we force to recompute # the last block. Note that we have to re-compute an entire @@ -211,9 +210,21 @@ def schedule(self) -> "SchedulerOutput": # is always a multiple of the block size. This limitation # can potentially be removed in the future to slightly # improve the performance. - num_computed_tokens -= self.block_size - num_new_tokens = self.block_size - computed_blocks.pop() + kv_groups = self.kv_cache_manager.kv_cache_config.groups + if len(kv_groups) > 1 or \ + not isinstance(kv_groups[0].kv_cache_spec, + FullAttentionSpec): + # It is difficult to handle the last block problem + # for hybrid models. Ignore all computed tokens as + # a temporary solution. + num_computed_tokens = 0 + num_new_tokens = request.num_tokens + computed_blocks = [[] for _ in kv_groups] + else: + block_size = kv_groups[0].kv_cache_spec.block_size + num_computed_tokens -= block_size + num_new_tokens = block_size + computed_blocks[0].pop() num_new_tokens = min(num_new_tokens, token_budget) assert num_new_tokens > 0 @@ -227,7 +238,8 @@ def schedule(self) -> "SchedulerOutput": break new_blocks = self.kv_cache_manager.allocate_slots( - request, num_new_tokens, computed_blocks) + request, num_new_tokens, computed_blocks, + num_computed_tokens) if new_blocks is None: # The request cannot be scheduled. 
break From 68fe2db3c28deb837481135ef92793d54f0fd85c Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 31 Jan 2025 20:11:56 -0800 Subject: [PATCH 25/48] remove print kvcacheconfig Signed-off-by: Chen Zhang --- vllm/v1/engine/core.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index db24695f2d373..66139f458e537 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -78,7 +78,6 @@ def _initialize_kv_caches(self, vllm_config: VllmConfig) -> KVCacheConfig: # Get the kv cache tensor size kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec, availble_gpu_memory) - print("kv_cache_config", kv_cache_config) # Initialize kv cache and warmup the execution self.model_executor.initialize(kv_cache_config) From 30e983744e46418c077df019250e0e9f4258dea5 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 31 Jan 2025 20:23:35 -0800 Subject: [PATCH 26/48] move files Signed-off-by: Chen Zhang --- tests/v1/core/test_kv_cache_utils.py | 5 +- tests/v1/core/test_specialized_manager.py | 8 +-- vllm/v1/core/hybrid_cache_manager/utils.py | 67 ------------------- vllm/v1/core/kv_cache_manager.py | 9 ++- vllm/v1/core/kv_cache_utils.py | 65 ++++++++++++++++++ .../specialized_manager.py | 5 +- 6 files changed, 77 insertions(+), 82 deletions(-) delete mode 100644 vllm/v1/core/hybrid_cache_manager/utils.py rename vllm/v1/core/{hybrid_cache_manager => }/specialized_manager.py (97%) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 9d8d6d0097a8b..1c64ff1958c3f 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -2,13 +2,12 @@ from vllm.multimodal.inputs import MultiModalKwargs from vllm.sampling_params import SamplingParams -from vllm.v1.core.hybrid_cache_manager.utils import (PrefixLengthRange, - intersect_ranges) from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue, KVCacheBlock, generate_block_hash_extra_keys, hash_block_tokens, - hash_request_tokens) + hash_request_tokens, + PrefixLengthRange, intersect_ranges) from vllm.v1.request import Request diff --git a/tests/v1/core/test_specialized_manager.py b/tests/v1/core/test_specialized_manager.py index 4aad3fa6b638f..969b18afbe977 100644 --- a/tests/v1/core/test_specialized_manager.py +++ b/tests/v1/core/test_specialized_manager.py @@ -3,10 +3,10 @@ import torch -from vllm.v1.core.hybrid_cache_manager.specialized_manager import ( - BlockPoolOperations, SlidingWindowManager) -from vllm.v1.core.hybrid_cache_manager.utils import PrefixLengthRange -from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlock +from vllm.v1.core.specialized_manager import (BlockPoolOperations, + SlidingWindowManager) +from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock, + PrefixLengthRange) from vllm.v1.kv_cache_interface import SlidingWindowSpec diff --git a/vllm/v1/core/hybrid_cache_manager/utils.py b/vllm/v1/core/hybrid_cache_manager/utils.py deleted file mode 100644 index 5d89de60cea54..0000000000000 --- a/vllm/v1/core/hybrid_cache_manager/utils.py +++ /dev/null @@ -1,67 +0,0 @@ -from dataclasses import dataclass -from typing import List - - -@dataclass -class PrefixLengthRange: - """ - [start, end] - """ - start: int - end: int - - -PrefixLength = List[PrefixLengthRange] - - -def intersect_two_ranges( - a: List[PrefixLengthRange], - b: List[PrefixLengthRange]) -> List[PrefixLengthRange]: - """ - Intersect two sorted lists of PrefixLengthRange intervals. 
- - Args: - a: List of intervals - b: List of intervals - Returns: - List of intervals that are intersections of a and b - """ - i, j = 0, 0 - result = [] - - while i < len(a) and j < len(b): - overlap_start = max(a[i].start, b[j].start) - overlap_end = min(a[i].end, b[j].end) - - if overlap_start <= overlap_end: - result.append(PrefixLengthRange(overlap_start, overlap_end)) - - if a[i].end < b[j].end: - i += 1 - else: - j += 1 - - return result - - -def intersect_ranges( - ranges: List[List[PrefixLengthRange]]) -> List[PrefixLengthRange]: - """ - Intersect multiple lists of PrefixLengthRange intervals, each is sorted. - - Args: - ranges: A list of lists of intervals - Returns: - A list of intervals representing the intersection of all ranges - """ - if not ranges: - return [] - - current_intersection = ranges[0] - for i in range(1, len(ranges)): - current_intersection = intersect_two_ranges(current_intersection, - ranges[i]) - if not current_intersection: - break - - return current_intersection diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 9c2599f6ebf82..4672e0eded793 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -4,15 +4,14 @@ from vllm.logger import init_logger from vllm.utils import cdiv -from vllm.v1.core.hybrid_cache_manager.specialized_manager import ( - BlockPoolOperations, get_managers) -from vllm.v1.core.hybrid_cache_manager.utils import (PrefixLength, - intersect_ranges) +from vllm.v1.core.specialized_manager import (BlockPoolOperations, + get_managers) from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue, KVCacheBlock, ReqKVCacheBlocks, generate_block_hash_extra_keys, hash_block_tokens, - hash_request_tokens) + hash_request_tokens, PrefixLength, + intersect_ranges) from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import Request, RequestStatus diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 88bc7aeec31b5..b8cc725f1ec78 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -575,3 +575,68 @@ def get_kv_cache_config(vllm_config: VllmConfig, available_memory) else: raise NotImplementedError + + +@dataclass +class PrefixLengthRange: + """ + [start, end] + """ + start: int + end: int + + +PrefixLength = List[PrefixLengthRange] + + +def intersect_two_ranges( + a: List[PrefixLengthRange], + b: List[PrefixLengthRange]) -> List[PrefixLengthRange]: + """ + Intersect two sorted lists of PrefixLengthRange intervals. + + Args: + a: List of intervals + b: List of intervals + Returns: + List of intervals that are intersections of a and b + """ + i, j = 0, 0 + result = [] + + while i < len(a) and j < len(b): + overlap_start = max(a[i].start, b[j].start) + overlap_end = min(a[i].end, b[j].end) + + if overlap_start <= overlap_end: + result.append(PrefixLengthRange(overlap_start, overlap_end)) + + if a[i].end < b[j].end: + i += 1 + else: + j += 1 + + return result + + +def intersect_ranges( + ranges: List[List[PrefixLengthRange]]) -> List[PrefixLengthRange]: + """ + Intersect multiple lists of PrefixLengthRange intervals, each is sorted. 
+ + Args: + ranges: A list of lists of intervals + Returns: + A list of intervals representing the intersection of all ranges + """ + if not ranges: + return [] + + current_intersection = ranges[0] + for i in range(1, len(ranges)): + current_intersection = intersect_two_ranges(current_intersection, + ranges[i]) + if not current_intersection: + break + + return current_intersection diff --git a/vllm/v1/core/hybrid_cache_manager/specialized_manager.py b/vllm/v1/core/specialized_manager.py similarity index 97% rename from vllm/v1/core/hybrid_cache_manager/specialized_manager.py rename to vllm/v1/core/specialized_manager.py index db8f20e2d0c93..c0f626218fa15 100644 --- a/vllm/v1/core/hybrid_cache_manager/specialized_manager.py +++ b/vllm/v1/core/specialized_manager.py @@ -4,9 +4,8 @@ from typing import Callable, Dict, List, Optional, Tuple, Type from vllm.utils import cdiv -from vllm.v1.core.hybrid_cache_manager.utils import (PrefixLength, - PrefixLengthRange) -from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlock +from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock, + PrefixLength, PrefixLengthRange) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheSpec, SlidingWindowSpec) from vllm.v1.utils import ConstantList From e6016e5f1c0b03c8882895d4b16925a7dc07c8ec Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sat, 1 Feb 2025 19:14:03 -0800 Subject: [PATCH 27/48] add docstrings Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_utils.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 27 ++++++++++++++++++++++++--- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index b8cc725f1ec78..335db9cd379c2 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -580,7 +580,7 @@ def get_kv_cache_config(vllm_config: VllmConfig, @dataclass class PrefixLengthRange: """ - [start, end] + A closed interval [start, end] representing a range of valid prefix lengths. """ start: int end: int diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 7dc9e5472d71e..09a76976e585e 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1015,7 +1015,17 @@ def capture_model(self) -> None: def _initialize_kv_cache_buffer( self, kv_cache_config: KVCacheConfig) -> Dict[str, torch.Tensor]: - # TODO: add docstring + """ + Initializes the KV cache buffer with the correct size. The buffer needs + to be reshaped to the desired shape before being used by the models. + + Args: + kv_cache_config: The KV cache config + + Returns: + Dict[str, torch.Tensor]: A map between layer names to their + corresponding memory buffer for KV cache. 
+ """ kv_cache_raw_tensors: Dict[str, torch.Tensor] = {} for layer_name, tensor_config in kv_cache_config.tensors.items(): if isinstance(tensor_config, KVCacheNewTensor): @@ -1024,7 +1034,7 @@ def _initialize_kv_cache_buffer( tensor_config.size, dtype=torch.int8, device=self.device) for layer_name, tensor_config in kv_cache_config.tensors.items(): if isinstance(tensor_config, KVCacheReuseTensor): - # Reuse the tensor from `kv_cache_raw_tensors` + # Reuse a tensor from `kv_cache_raw_tensors` kv_cache_raw_tensors[layer_name] = kv_cache_raw_tensors[ tensor_config.reused_layer_name] assert len(kv_cache_raw_tensors) == len( @@ -1036,7 +1046,18 @@ def _setup_kv_cache_shapes( kv_cache_config: KVCacheConfig, kv_cache_raw_tensors: Dict[str, torch.Tensor], ) -> Dict[str, torch.Tensor]: - # TODO: add docstring + """ + Reshape the KV cache tensors to the desired shape. + + Args: + kv_cache_config: The KV cache config + kv_cache_raw_tensors: The KV cache buffer of each layer, with + correct size but uninitialized shape. + + Returns: + Dict[str, torch.Tensor]: A map between layer names to their + corresponding memory buffer for KV cache. + """ kv_caches: Dict[str, torch.Tensor] = {} for kv_cache_group in kv_cache_config.groups: kv_cache_spec = kv_cache_group.kv_cache_spec From b369fa23ec58006cde8b810c6e173d3357095f02 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sat, 1 Feb 2025 19:50:31 -0800 Subject: [PATCH 28/48] fix pre-commit Signed-off-by: Chen Zhang --- vllm/v1/core/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 3c2a62b26e984..f47bce6018f46 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -272,7 +272,7 @@ def schedule(self) -> "SchedulerOutput": # Get the longest common prefix among all requests in the running queue. # This can be potentially used for cascade attention. - num_common_prefix_blocks = 0 + num_common_prefix_blocks = [0] if self.running: any_request = self.running[0] num_common_prefix_blocks = ( From ea65e60c3743e5cefa3d016256221e6641b73907 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 4 Feb 2025 05:29:16 -0800 Subject: [PATCH 29/48] add request.py Signed-off-by: Chen Zhang --- vllm/v1/request.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 1f22f79e05eca..ab647748f223e 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -125,7 +125,7 @@ def get_num_encoder_tokens(self, input_id: int) -> int: num_tokens = self.mm_positions[input_id]["length"] return num_tokens - def set_kv_block_hashes(self, value: List["BlockHashType"]) -> None: + def set_kv_block_hashes(self, value: List[List["BlockHashType"]]) -> None: self._kv_block_hashes = value # NOTE: self.kv_block_hashes._x is not self._kv_block_hashes, but # self.kv_block_hashes[0]._x is self._kv_block_hashes[0]. 
This is From 42c391dee48f509cc2a0f5fe3c0267d1e80d8c11 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 4 Feb 2025 06:27:28 -0800 Subject: [PATCH 30/48] remove small comment Signed-off-by: Chen Zhang --- vllm/v1/worker/block_table.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 6b82b3c6c1128..d12c8a96b9a07 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -30,8 +30,7 @@ def __init__( self.num_kv_cache_groups = num_kv_cache_groups # NOTE: Pad the block table to the max possible number of blocks among - # all KV cache groups. This waste some memory if block_size of the - # groups differ. + # all KV cache groups. self.block_table = torch.zeros( (num_kv_cache_groups, max_num_reqs, max_num_blocks_per_req), device=self.device, From f6d2bfd168b1eb72d7946baae753c72fd8942dd0 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Wed, 5 Feb 2025 07:54:17 -0800 Subject: [PATCH 31/48] avoid loop in block table Signed-off-by: Chen Zhang --- vllm/v1/worker/block_table.py | 94 +++++++++++++++++++----------- vllm/v1/worker/gpu_input_batch.py | 9 ++- vllm/v1/worker/gpu_model_runner.py | 14 ++--- 3 files changed, 71 insertions(+), 46 deletions(-) diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index d12c8a96b9a07..efa74f00d3fe2 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -19,71 +19,55 @@ def __init__( max_num_blocks_per_req: int, pin_memory: bool, device: torch.device, - # NOTE: See KVCacheConfig class for the meaning of "KV cache group". - num_kv_cache_groups: int, ): self.max_num_reqs = max_num_reqs self.max_model_len = max_model_len self.max_num_blocks_per_req = max_num_blocks_per_req self.pin_memory = pin_memory self.device = device - self.num_kv_cache_groups = num_kv_cache_groups - # NOTE: Pad the block table to the max possible number of blocks among - # all KV cache groups. 
self.block_table = torch.zeros( - (num_kv_cache_groups, max_num_reqs, max_num_blocks_per_req), + (max_num_reqs, max_num_blocks_per_req), device=self.device, dtype=torch.int32, ) self.block_table_cpu = torch.zeros( - (num_kv_cache_groups, max_num_reqs, max_num_blocks_per_req), + (max_num_reqs, max_num_blocks_per_req), device="cpu", dtype=torch.int32, pin_memory=pin_memory, ) self.block_table_np = self.block_table_cpu.numpy() - self.num_blocks_per_row = np.zeros((num_kv_cache_groups, max_num_reqs), - dtype=np.int32) + self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32) def append_row( self, + block_ids: List[int], row_idx: int, - block_ids: List[List[int]], ) -> None: - if max(len(b) for b in block_ids) > 0: + if not block_ids: return - for i, (num_blocks, block_ids_of_group) in enumerate( - zip(self.num_blocks_per_row[:, row_idx], block_ids)): - num_new_blocks = len(block_ids_of_group) - self.block_table_np[i, row_idx, num_blocks:num_blocks + - num_new_blocks] = block_ids_of_group - self.num_blocks_per_row[i, row_idx] = num_blocks + num_new_blocks + num_blocks = len(block_ids) + start = self.num_blocks_per_row[row_idx] + self.block_table_np[row_idx, start:start + num_blocks] = block_ids + self.num_blocks_per_row[row_idx] = start + num_blocks - def add_row(self, row_idx: int, block_ids: List[List[int]]) -> None: - self.num_blocks_per_row[:, row_idx] = 0 - self.append_row(row_idx, block_ids) + def add_row(self, block_ids: List[int], row_idx: int) -> None: + self.num_blocks_per_row[row_idx] = 0 + self.append_row(block_ids, row_idx) def move_row(self, src: int, tgt: int) -> None: - num_blocks = self.num_blocks_per_row[:, src] - self.block_table_np[:, tgt, :max(num_blocks)] = \ - self.block_table_np[:, src, :max(num_blocks)] - self.num_blocks_per_row[:, tgt] = num_blocks + num_blocks = self.num_blocks_per_row[src] + self.block_table_np[tgt, :num_blocks] = self.block_table_np[ + src, :num_blocks] + self.num_blocks_per_row[tgt] = num_blocks def commit(self, num_reqs: int) -> None: - # NOTE: an alternative is - # self.block_table[:, :num_reqs].copy_( - # self.block_table_cpu[:, :num_reqs], non_blocking=True) - # but it will be a blocking copy when num_kv_cache_groups > 1. 
- # Can be verified by the following code: - # https://gist.github.com/heheda12345/74c7f7a68e45c242a5c901b5fb77d000 - for i in range(self.num_kv_cache_groups): - self.block_table[i, :num_reqs].copy_( - self.block_table_cpu[i, :num_reqs], non_blocking=True) + self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs], + non_blocking=True) def clear(self) -> None: self.block_table.fill_(0) - self.block_table_cpu.fill_(0) def get_device_tensor(self) -> torch.Tensor: """Ruturns the device tensor of the block table.""" @@ -96,3 +80,45 @@ def get_cpu_tensor(self) -> torch.Tensor: def get_numpy_array(self) -> np.ndarray: """Returns the numpy array of the block table.""" return self.block_table_np + + +class GroupedBlockTable: + + def __init__(self, max_num_reqs: int, max_model_len: int, + max_num_blocks_per_req: int, pin_memory: bool, + device: torch.device, num_kv_cache_groups: int): + self.block_tables = [ + BlockTable( + max_num_reqs, + max_model_len, + max_num_blocks_per_req, + pin_memory, + device, + ) for _ in range(num_kv_cache_groups) + ] + for f_name in ('move_row', 'commit', 'clear'): + setattr(self, f_name, self._make_grouped_func(f_name)) + + for f_name in ('append_row', 'add_row'): + # NOTE: requires to pass block_ids as the first argument + setattr(self, f_name, + self._make_grouped_func_with_block_ids(f_name)) + + def _make_grouped_func(self, f_name): + + def grouped_func(*args, **kwargs): + for block_table in self.block_tables: + getattr(block_table, f_name)(*args, **kwargs) + + return grouped_func + + def _make_grouped_func_with_block_ids(self, f_name): + + def grouped_func(block_ids: List[List[int]], *args, **kwargs): + for i, block_table in enumerate(self.block_tables): + getattr(block_table, f_name)(block_ids[i], *args, **kwargs) + + return grouped_func + + def __getitem__(self, idx): + return self.block_tables[idx] diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 968cab15e6372..d3288fb32d665 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -11,7 +11,7 @@ from vllm.multimodal import MultiModalKwargs from vllm.sampling_params import SamplingParams, SamplingType from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.worker.block_table import BlockTable +from vllm.v1.worker.block_table import BlockTable, GroupedBlockTable if TYPE_CHECKING: from vllm.multimodal.inputs import PlaceholderRange @@ -79,14 +79,13 @@ def __init__( self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32) # Block table. - self.block_table = BlockTable( + self.block_table = GroupedBlockTable( max_num_reqs=max_num_reqs, max_model_len=max_model_len, max_num_blocks_per_req=max_num_blocks_per_req, pin_memory=pin_memory, device=device, - num_kv_cache_groups=num_kv_cache_groups, - ) + num_kv_cache_groups=num_kv_cache_groups) # Sampling-related. 
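# A minimal sketch of the dispatch pattern used by GroupedBlockTable above:
# method names listed in the constructor are bound to wrappers that fan the
# call out to every per-group table. The Toy* names are illustrative stand-ins,
# not vLLM classes.
from typing import List


class ToyTable:

    def __init__(self) -> None:
        self.rows: List[List[int]] = []

    def add_row(self, block_ids: List[int]) -> None:
        self.rows.append(block_ids)

    def clear(self) -> None:
        self.rows.clear()


class ToyGroupedTable:

    def __init__(self, num_groups: int) -> None:
        self.tables = [ToyTable() for _ in range(num_groups)]
        # Plain fan-out: the same arguments go to every group.
        for f_name in ("clear", ):
            setattr(self, f_name, self._make_grouped_func(f_name))
        # Per-group block IDs: group i receives block_ids[i].
        for f_name in ("add_row", ):
            setattr(self, f_name,
                    self._make_grouped_func_with_block_ids(f_name))

    def _make_grouped_func(self, f_name):

        def grouped_func(*args, **kwargs):
            for table in self.tables:
                getattr(table, f_name)(*args, **kwargs)

        return grouped_func

    def _make_grouped_func_with_block_ids(self, f_name):

        def grouped_func(block_ids: List[List[int]], *args, **kwargs):
            for i, table in enumerate(self.tables):
                getattr(table, f_name)(block_ids[i], *args, **kwargs)

        return grouped_func


grouped = ToyGroupedTable(num_groups=2)
grouped.add_row([[10, 11], [20]])  # group 0 gets [10, 11], group 1 gets [20]
assert [t.rows for t in grouped.tables] == [[[10, 11]], [[20]]]
grouped.clear()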
self.temperature = torch.empty((max_num_reqs, ), @@ -197,7 +196,7 @@ def add_request( self.num_tokens[req_index] = request.num_tokens self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens - self.block_table.add_row(req_index, request.block_ids) + self.block_table.add_row(request.block_ids, req_index) sampling_params = request.sampling_params self.temperature_cpu[req_index] = sampling_params.temperature diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index f11a1a2b04b52..9c7ba24391e94 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -337,8 +337,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: req_data.num_computed_tokens) # Update the block table. - self.input_batch.block_table.append_row(req_index, - req_data.new_block_ids) + self.input_batch.block_table.append_row(req_data.new_block_ids, + req_index) for group_id, new_block_ids in enumerate(req_data.new_block_ids): req_state.block_ids[group_id].extend(new_block_ids) @@ -439,9 +439,9 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): # NOTE(woosuk): We use torch.index_select instead of np.take here # because torch.index_select is much faster than np.take for large # tensors. - block_table_cpu = self.input_batch.block_table.get_cpu_tensor() - block_numbers = block_table_cpu[i].flatten()[block_table_indices]\ - .numpy() + block_table_cpu = self.input_batch.block_table[i].get_cpu_tensor() + block_numbers = block_table_cpu.flatten( + )[block_table_indices].numpy() block_offsets = positions_np % block_size np.add(block_numbers * block_size, block_offsets, @@ -567,8 +567,8 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): query_start_loc=query_start_loc, max_seq_len=max_seq_len, seq_lens=seq_lens, - block_table=(self.input_batch.block_table.get_device_tensor()[ - group_id, :num_reqs]), + block_table=(self.input_batch.block_table[group_id]. 
+ get_device_tensor()[:num_reqs]), slot_mapping=slot_mapping, use_cascade=use_cascade, common_prefix_len=common_prefix_len, From b614b422738007866e7759d57edae87edd00d6ab Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Thu, 6 Feb 2025 02:54:15 -0800 Subject: [PATCH 32/48] clean up attn_metadata Signed-off-by: Chen Zhang --- vllm/v1/kv_cache_interface.py | 1 + vllm/v1/worker/block_table.py | 38 ++++++---- vllm/v1/worker/gpu_input_batch.py | 20 +++-- vllm/v1/worker/gpu_model_runner.py | 115 +++++++++++------------------ 4 files changed, 76 insertions(+), 98 deletions(-) diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 500a9ec985b1e..14082e5d32f1d 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -57,6 +57,7 @@ def bytes_for_tokens(self, num_tokens: int) -> int: @dataclass class FullAttentionSpec(KVCacheSpec): + num_heads: int num_kv_heads: int head_size: int dtype: torch.dtype diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index efa74f00d3fe2..3b8173e3435e1 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -1,11 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List +from typing import Any, Dict, List import numpy as np import torch +from triton import cdiv from vllm.logger import init_logger +from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend +from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig, KVCacheSpec logger = init_logger(__name__) @@ -16,23 +19,26 @@ def __init__( self, max_num_reqs: int, max_model_len: int, - max_num_blocks_per_req: int, + max_num_tokens: int, pin_memory: bool, device: torch.device, + kv_cache_spec: KVCacheSpec, ): self.max_num_reqs = max_num_reqs self.max_model_len = max_model_len - self.max_num_blocks_per_req = max_num_blocks_per_req + self.max_num_tokens = max_num_tokens + self.max_num_blocks_per_req = cdiv(max_model_len, + kv_cache_spec.block_size) self.pin_memory = pin_memory self.device = device self.block_table = torch.zeros( - (max_num_reqs, max_num_blocks_per_req), + (max_num_reqs, self.max_num_blocks_per_req), device=self.device, dtype=torch.int32, ) self.block_table_cpu = torch.zeros( - (max_num_reqs, max_num_blocks_per_req), + (max_num_reqs, self.max_num_blocks_per_req), device="cpu", dtype=torch.int32, pin_memory=pin_memory, @@ -40,6 +46,12 @@ def __init__( self.block_table_np = self.block_table_cpu.numpy() self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32) + self.slot_mapping_cpu = torch.zeros(self.max_num_tokens, + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + self.slot_mapping_np = self.slot_mapping_cpu.numpy() + def append_row( self, block_ids: List[int], @@ -68,6 +80,7 @@ def commit(self, num_reqs: int) -> None: def clear(self) -> None: self.block_table.fill_(0) + self.block_table_cpu.fill_(0) def get_device_tensor(self) -> torch.Tensor: """Ruturns the device tensor of the block table.""" @@ -85,16 +98,11 @@ def get_numpy_array(self) -> np.ndarray: class GroupedBlockTable: def __init__(self, max_num_reqs: int, max_model_len: int, - max_num_blocks_per_req: int, pin_memory: bool, - device: torch.device, num_kv_cache_groups: int): + max_num_tokens: int, pin_memory: bool, device: torch.device, + kv_cache_config: KVCacheConfig): self.block_tables = [ - BlockTable( - max_num_reqs, - max_model_len, - max_num_blocks_per_req, - pin_memory, - device, - ) for _ in range(num_kv_cache_groups) + BlockTable(max_num_reqs, max_model_len, max_num_tokens, 
pin_memory, + device, g.kv_cache_spec) for g in kv_cache_config.groups ] for f_name in ('move_row', 'commit', 'clear'): setattr(self, f_name, self._make_grouped_func(f_name)) @@ -120,5 +128,5 @@ def grouped_func(block_ids: List[List[int]], *args, **kwargs): return grouped_func - def __getitem__(self, idx): + def __getitem__(self, idx) -> BlockTable: return self.block_tables[idx] diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index d3288fb32d665..ba7f0f4aa7fd6 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -10,6 +10,7 @@ from vllm.multimodal import MultiModalKwargs from vllm.sampling_params import SamplingParams, SamplingType +from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.worker.block_table import BlockTable, GroupedBlockTable @@ -46,16 +47,14 @@ def __init__( self, max_num_reqs: int, max_model_len: int, - max_num_blocks_per_req: int, + max_num_tokens: int, device: torch.device, pin_memory: bool, vocab_size: int, - # NOTE: See KVCacheConfig class for the meaning of "KV cache group". - num_kv_cache_groups: int, + kv_cache_config: KVCacheConfig, ): self.max_num_reqs = max_num_reqs self.max_model_len = max_model_len - self.max_num_blocks_per_req = max_num_blocks_per_req self.device = device self.pin_memory = pin_memory self.vocab_size = vocab_size @@ -79,13 +78,12 @@ def __init__( self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32) # Block table. - self.block_table = GroupedBlockTable( - max_num_reqs=max_num_reqs, - max_model_len=max_model_len, - max_num_blocks_per_req=max_num_blocks_per_req, - pin_memory=pin_memory, - device=device, - num_kv_cache_groups=num_kv_cache_groups) + self.block_table = GroupedBlockTable(max_num_reqs=max_num_reqs, + max_model_len=max_model_len, + max_num_tokens=max_num_tokens, + pin_memory=pin_memory, + device=device, + kv_cache_config=kv_cache_config) # Sampling-related. self.temperature = torch.empty((max_num_reqs, ), diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 9c7ba24391e94..ba7855234bbe1 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -77,13 +77,6 @@ def __init__( self.max_num_tokens = scheduler_config.max_num_batched_tokens self.max_num_reqs = scheduler_config.max_num_seqs - # Model-related. - self.num_attn_layers = model_config.get_num_layers_by_block_type( - parallel_config, LayerBlockType.attention) - self.num_query_heads = model_config.get_num_attention_heads( - parallel_config) - self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) - self.head_size = model_config.get_head_size() self.hidden_size = model_config.get_hidden_size() # Multi-modal data support @@ -182,13 +175,9 @@ def __init__( self.kv_cache_config = cast(KVCacheConfig, None) # Set by initialize_kv_cache - # The following 3 variables depends on KVCacheConfig, assign a - # placeholder value here and initialize them in `initialize_kv_cache``. + # InputBatch depends on KVCacheConfig, assign a fake value here and + # initialize in `initialize_kv_cache``. self.input_batch = cast(InputBatch, None) # Persistent batch. 
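# Sizing sketch for the per-group BlockTable above: each group derives its own
# max_num_blocks_per_req from its block size instead of padding every group to
# a shared maximum. The block sizes and model length are made-up example values.
def cdiv(a: int, b: int) -> int:
    return -(-a // b)


max_model_len = 8192
group_block_sizes = {"full_attention": 16, "sliding_window": 32}
per_group_blocks = {
    name: cdiv(max_model_len, block_size)
    for name, block_size in group_block_sizes.items()
}
assert per_group_blocks == {"full_attention": 512, "sliding_window": 256}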
- self.slot_mapping_cpu = torch.zeros( - (1, )) # Real shape: (num_kv_cache_groups, self.max_num_tokens) - self.slot_mapping_np = self.slot_mapping_cpu.numpy() - self.max_num_blocks_per_req: int = 0 self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1, dtype=torch.int32, @@ -425,28 +414,6 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): torch.from_numpy(token_indices), out=self.input_ids_cpu[:total_num_scheduled_tokens]) - # Calculate the slot mapping. - for i, kv_cache_group in enumerate(self.kv_cache_config.groups): - block_size = kv_cache_group.kv_cache_spec.block_size - # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] - # where K is the max_num_blocks_per_req and the block size is 2. - # NOTE(woosuk): We can't simply use `token_indices // block_size` - # here because M (max_model_len) is not necessarily divisible by - # block_size. - block_table_indices = (req_indices * self.max_num_blocks_per_req + - positions_np // block_size) - # NOTE(woosuk): We use torch.index_select instead of np.take here - # because torch.index_select is much faster than np.take for large - # tensors. - block_table_cpu = self.input_batch.block_table[i].get_cpu_tensor() - block_numbers = block_table_cpu.flatten( - )[block_table_indices].numpy() - block_offsets = positions_np % block_size - np.add(block_numbers * block_size, - block_offsets, - out=self.slot_mapping_np[i, :total_num_scheduled_tokens]) - # Prepare the attention metadata. self.query_start_loc_np[0] = 0 self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens @@ -473,17 +440,41 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): self.device, non_blocking=True) seq_lens = self.seq_lens_cpu[:num_reqs].to(self.device, non_blocking=True) - # layer_name -> AttentionMetadata + attn_metadata: Dict[str, FlashAttentionMetadata] = {} + for group_id, kv_cache_group in enumerate(self.kv_cache_config.groups): block_size = kv_cache_group.kv_cache_spec.block_size - slot_mapping = self.slot_mapping_cpu[ - group_id, :total_num_scheduled_tokens].to( - self.device, non_blocking=True).long() + block_table = self.input_batch.block_table[group_id] + + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] + # where K is the max_num_blocks_per_req and the block size is 2. + # NOTE(woosuk): We can't simply use `token_indices // block_size` + # here because M (max_model_len) is not necessarily divisible by + # block_size. + block_table_indices = ( + req_indices * block_table.max_num_blocks_per_req + + positions_np // block_size) + # NOTE(woosuk): We use torch.index_select instead of np.take here + # because torch.index_select is much faster than np.take for large + # tensors. + block_table_cpu = block_table.get_cpu_tensor() + block_numbers = block_table_cpu.flatten( + )[block_table_indices].numpy() + block_offsets = positions_np % block_size + np.add( + block_numbers * block_size, + block_offsets, + out=block_table.slot_mapping_np[:total_num_scheduled_tokens]) + slot_mapping = block_table.slot_mapping_cpu \ + [:total_num_scheduled_tokens] \ + .to(self.device, non_blocking=True).long() # Prepare for cascade attention if needed. - common_prefix_len = (scheduler_output.num_common_prefix_blocks[i] * - block_size) + common_prefix_len = ( + scheduler_output.num_common_prefix_blocks[group_id] * + block_size) if common_prefix_len == 0: # Common case. 
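# A numpy sketch of the per-group slot-mapping arithmetic above, using the
# same convention as the comment in the loop: block_size=2 and
# K = max_num_blocks_per_req. The tiny block table is a made-up example.
import numpy as np

block_size = 2
max_num_blocks_per_req = 4  # K
# Three requests scheduled back to back: 2, 5 and 3 tokens.
req_indices = np.array([0, 0, 1, 1, 1, 1, 1, 2, 2, 2])
positions = np.array([0, 1, 0, 1, 2, 3, 4, 0, 1, 2])

# Row i holds the physical block numbers of request i (0 marks "unused").
block_table = np.array([
    [7, 0, 0, 0],
    [3, 5, 9, 0],
    [2, 8, 0, 0],
])

block_table_indices = (req_indices * max_num_blocks_per_req +
                       positions // block_size)
block_numbers = block_table.flatten()[block_table_indices]
slot_mapping = block_numbers * block_size + positions % block_size
print(slot_mapping)  # [14 15  6  7 10 11 18  4  5 16]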
use_cascade = False @@ -533,11 +524,13 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): # common_prefix_len should be a multiple of the block size. common_prefix_len = (common_prefix_len // block_size * block_size) + kv_cache_spec = kv_cache_group.kv_cache_spec + assert isinstance(kv_cache_spec, FullAttentionSpec) use_cascade = FlashAttentionBackend.use_cascade_attention( common_prefix_len=common_prefix_len, query_lens=num_scheduled_tokens, - num_query_heads=self.num_query_heads, - num_kv_heads=self.num_kv_heads, + num_query_heads=kv_cache_spec.num_heads, + num_kv_heads=kv_cache_spec.num_kv_heads, use_alibi=False, # FIXME use_sliding_window=self.sliding_window is not None, num_sms=self.num_sms, @@ -567,8 +560,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): query_start_loc=query_start_loc, max_seq_len=max_seq_len, seq_lens=seq_lens, - block_table=(self.input_batch.block_table[group_id]. - get_device_tensor()[:num_reqs]), + block_table=block_table.get_device_tensor()[:num_reqs], slot_mapping=slot_mapping, use_cascade=use_cascade, common_prefix_len=common_prefix_len, @@ -936,9 +928,11 @@ def profile_run(self) -> None: # it is important to create tensors inside the loop, rather than # multiplying the list, to avoid Dynamo from treating them as # tensor aliasing. + num_attn_layers = self.model_config.get_num_layers_by_block_type( + self.parallel_config, LayerBlockType.attention) dummy_kv_caches = [ torch.tensor([], dtype=torch.float32, device=self.device) - for _ in range(self.num_attn_layers) + for _ in range(num_attn_layers) ] # Profile with multimodal encoder & encoder cache. @@ -1104,40 +1098,16 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: self.vllm_config.compilation_config.static_forward_context, self.kv_caches) - self._initialize_kv_related_buffers(kv_cache_config) - - def _initialize_kv_related_buffers(self, - kv_cache_config: KVCacheConfig) -> None: - """ - Initialize data structures (e.g., InputBatch, slot mappings) that - depend on the kv cache configuration. 
- - Args: - kv_cache_config (KVCacheConfig): Configuration for the KV cache - """ - num_kv_cache_groups = len(kv_cache_config.groups) - - min_block_size = min(group.kv_cache_spec.block_size - for group in kv_cache_config.groups) - self.max_num_blocks_per_req = cdiv(self.max_model_len, min_block_size) - self.input_batch = InputBatch( max_num_reqs=self.max_num_reqs, max_model_len=self.max_model_len, - max_num_blocks_per_req=self.max_num_blocks_per_req, + max_num_tokens=self.max_num_tokens, device=self.device, pin_memory=self.pin_memory, vocab_size=self.vllm_config.model_config.get_vocab_size(), - num_kv_cache_groups=num_kv_cache_groups, + kv_cache_config=kv_cache_config, ) - self.slot_mapping_cpu = torch.zeros(num_kv_cache_groups, - self.max_num_tokens, - dtype=torch.int32, - device="cpu", - pin_memory=self.pin_memory) - self.slot_mapping_np = self.slot_mapping_cpu.numpy() - def get_kv_cache_spec(self) -> Dict[str, KVCacheSpec]: """ Generates the KVCacheSpec by parsing the kv cache format from each @@ -1157,6 +1127,7 @@ def get_kv_cache_spec(self) -> Dict[str, KVCacheSpec]: if attn_module.attn_type == AttentionType.DECODER: kv_cache_spec[layer_name] = FullAttentionSpec( block_size=block_size, + num_heads=attn_module.num_heads, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, dtype=attn_module.dtype, From ca91b306da8eb9e285a66b97768227edab439351 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Thu, 6 Feb 2025 05:01:53 -0800 Subject: [PATCH 33/48] BlockIDList Signed-off-by: Chen Zhang --- vllm/v1/core/scheduler.py | 35 ++++++++++++++---------------- vllm/v1/kv_cache_interface.py | 31 +++++++++++++++++++++++++- vllm/v1/worker/block_table.py | 7 +++--- vllm/v1/worker/gpu_input_batch.py | 4 ++-- vllm/v1/worker/gpu_model_runner.py | 6 +---- 5 files changed, 53 insertions(+), 30 deletions(-) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 580c4d42478f6..7fbf052c79310 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -12,6 +12,7 @@ compute_encoder_budget) from vllm.v1.core.kv_cache_manager import KVCacheManager from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs +from vllm.v1.kv_cache_interface import BlockIDList from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus @@ -107,8 +108,7 @@ def schedule(self) -> "SchedulerOutput": scheduled_running_reqs: List[Request] = [] preempted_reqs: List[Request] = [] - # Request id -> List of block IDs for each kv cache group. - req_to_new_block_ids: Dict[str, List[List[int]]] = {} + req_to_new_block_ids: Dict[str, BlockIDList] = {} num_scheduled_tokens: Dict[str, int] = {} token_budget = self.max_num_scheduled_tokens # Encoder-related. @@ -165,9 +165,10 @@ def schedule(self) -> "SchedulerOutput": # Schedule the request. 
scheduled_running_reqs.append(request) - req_to_new_block_ids[request.request_id] = [[ - b.block_id for b in new_blocks_of_group - ] for new_blocks_of_group in new_blocks] + req_to_new_block_ids[ + request.request_id] = BlockIDList.from_kv_cache_blocks( + new_blocks) + num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens req_index += 1 @@ -235,11 +236,11 @@ def schedule(self) -> "SchedulerOutput": raise RuntimeError( f"Invalid request status: {request.status}") - req_to_new_block_ids[request.request_id] = [[ - b.block_id - for b in computed_blocks_of_group + new_blocks_of_group - ] for computed_blocks_of_group, new_blocks_of_group in zip( - computed_blocks, new_blocks)] + req_to_new_block_ids[ + request.request_id] = BlockIDList.from_kv_cache_blocks( + computed_blocks) + BlockIDList.from_kv_cache_blocks( + new_blocks) + num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens request.status = RequestStatus.RUNNING @@ -318,7 +319,7 @@ def schedule(self) -> "SchedulerOutput": def _make_cached_request_data( self, request: Request, - new_block_ids: List[List[int]], + new_block_ids: BlockIDList, num_computed_tokens: int, resumed_from_preemption: bool, ) -> "CachedRequestData": @@ -569,16 +570,14 @@ class NewRequestData: mm_hashes: List[str] mm_positions: List["PlaceholderRange"] sampling_params: SamplingParams - # List of block IDs for each KV cache group. - # See KVCacheConfig class for the meaning of "KV cache group". - block_ids: List[List[int]] + block_ids: BlockIDList num_computed_tokens: int @classmethod def from_request( cls, request: Request, - block_ids: List[List[int]], + block_ids: BlockIDList, num_computed_tokens: int, ) -> "NewRequestData": return cls( @@ -602,9 +601,7 @@ class CachedRequestData: # the request's block IDs. If True, new_block_ids will be used as the # request's block IDs instead of appending to the existing block IDs. resumed_from_preemption: bool - # List of block IDs for each kv cache group. - # See KVCacheConfig class for the meaning of "KV cache group". - new_block_ids: List[List[int]] + new_block_ids: BlockIDList num_computed_tokens: int @classmethod @@ -612,7 +609,7 @@ def from_request( cls, request: Request, resumed_from_preemption: bool, - new_block_ids: List[List[int]], + new_block_ids: BlockIDList, num_computed_tokens: int, ) -> "CachedRequestData": return cls( diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 14082e5d32f1d..debdb1c75b5c1 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -1,12 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass -from typing import Dict, List +from typing import TYPE_CHECKING, Dict, List import torch from vllm.logger import init_logger from vllm.utils import cdiv, get_dtype_size +if TYPE_CHECKING: + from vllm.v1.core.kv_cache_utils import ReqKVCacheBlocks logger = init_logger(__name__) @@ -118,3 +120,30 @@ class KVCacheConfig: window attention layers: three groups, (full * 2), (sw * 2), (sw * 2). 
""" groups: List[KVCacheGroup] + + +@dataclass +class BlockIDList: + # A list of block IDs for each group of KV cache blocks + _block_ids: List[List[int]] + + def __init__(self, block_ids: List[List[int]]): + self._block_ids = block_ids + + @classmethod + def from_kv_cache_blocks(cls, kv_cache_blocks: "ReqKVCacheBlocks"): + return cls( + block_ids=[[blk.block_id for blk in kv_cache_blocks_one_group] + for kv_cache_blocks_one_group in kv_cache_blocks]) + + def extend(self, new_block_ids: "BlockIDList") -> None: + for i, block_ids in enumerate(new_block_ids._block_ids): + self._block_ids[i].extend(block_ids) + + def __add__(self, other: "BlockIDList") -> "BlockIDList": + return BlockIDList(block_ids=[ + a + b for a, b in zip(self._block_ids, other._block_ids) + ]) + + def get_group(self, group_idx: int) -> List[int]: + return self._block_ids[group_idx] diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 3b8173e3435e1..62ca99cbc0321 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -8,7 +8,7 @@ from vllm.logger import init_logger from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend -from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig, KVCacheSpec +from vllm.v1.kv_cache_interface import BlockIDList, FullAttentionSpec, KVCacheConfig, KVCacheSpec logger = init_logger(__name__) @@ -122,9 +122,10 @@ def grouped_func(*args, **kwargs): def _make_grouped_func_with_block_ids(self, f_name): - def grouped_func(block_ids: List[List[int]], *args, **kwargs): + def grouped_func(block_ids: BlockIDList, *args, **kwargs): for i, block_table in enumerate(self.block_tables): - getattr(block_table, f_name)(block_ids[i], *args, **kwargs) + getattr(block_table, f_name)(block_ids.get_group(i), *args, + **kwargs) return grouped_func diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index ba7f0f4aa7fd6..1fcbd3fdd9e0a 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -10,7 +10,7 @@ from vllm.multimodal import MultiModalKwargs from vllm.sampling_params import SamplingParams, SamplingType -from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.kv_cache_interface import BlockIDList, KVCacheConfig from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.worker.block_table import BlockTable, GroupedBlockTable @@ -29,7 +29,7 @@ class CachedRequestState: sampling_params: SamplingParams generator: Optional[torch.Generator] - block_ids: List[List[int]] # List of block ids for each kv cache group + block_ids: BlockIDList num_computed_tokens: int output_token_ids: List[int] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ba7855234bbe1..7bb3f72827272 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -305,9 +305,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: req_state.num_computed_tokens = req_data.num_computed_tokens if not req_data.resumed_from_preemption: # Append the new blocks to the existing block IDs. - for group_id, new_block_ids in enumerate( - req_data.new_block_ids): - req_state.block_ids[group_id].extend(new_block_ids) + req_state.block_ids.extend(req_data.new_block_ids) else: # The request is resumed from preemption. # Replace the existing block IDs with the new ones. @@ -328,8 +326,6 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: # Update the block table. 
self.input_batch.block_table.append_row(req_data.new_block_ids, req_index) - for group_id, new_block_ids in enumerate(req_data.new_block_ids): - req_state.block_ids[group_id].extend(new_block_ids) # Add the new or resumed requests to the persistent batch. # The smaller empty indices are filled first. From 0475e9f5c32b01403edc12b4f375c647f46e707a Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Thu, 6 Feb 2025 05:17:05 -0800 Subject: [PATCH 34/48] forward metadata Signed-off-by: Chen Zhang --- vllm/forward_context.py | 16 +++++++++++++--- vllm/v1/worker/gpu_model_runner.py | 19 +++++++++++++++---- 2 files changed, 28 insertions(+), 7 deletions(-) diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 42af91b9d7500..48bf9694d806a 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -45,6 +45,14 @@ class ForwardContext: virtual_engine: int # set dynamically for each forward pass +@dataclass +class ForwardMetadata: + """ + Forward metadata for each forward pass + """ + num_input_tokens: int + + _forward_context: Optional[ForwardContext] = None @@ -59,7 +67,8 @@ def get_forward_context() -> ForwardContext: @contextmanager def set_forward_context(attn_metadata: Any, vllm_config: VllmConfig, - virtual_engine: int = 0): + virtual_engine: int = 0, + forward_metadata: Optional[ForwardMetadata] = None): """A context manager that stores the current forward context, can be attention metadata, etc. Here we can inject common logic for every model forward pass. @@ -79,13 +88,14 @@ def set_forward_context(attn_metadata: Any, finally: global last_logging_time, batchsize_logging_interval if need_to_track_batchsize: - if hasattr(attn_metadata, "num_prefill_tokens"): + if not envs.VLLM_USE_V1: # for v0 attention backends batchsize = attn_metadata.num_prefill_tokens + \ attn_metadata.num_decode_tokens else: # for v1 attention backends - batchsize = attn_metadata.num_input_tokens + assert forward_metadata is not None + batchsize = forward_metadata.num_input_tokens # we use synchronous scheduling right now, # adding a sync point here should not affect # scheduling of the next batch diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 7bb3f72827272..05473785f0f2e 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -13,7 +13,7 @@ from vllm.attention.layer import Attention from vllm.config import CompilationLevel, VllmConfig from vllm.distributed.parallel_state import graph_capture -from vllm.forward_context import set_forward_context +from vllm.forward_context import ForwardMetadata, set_forward_context from vllm.inputs import INPUT_REGISTRY from vllm.logger import init_logger from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding @@ -751,8 +751,17 @@ def execute_model( # Prepare the decoder inputs. attn_metadata, logits_indices = self._prepare_inputs(scheduler_output) num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - num_input_tokens = self.maybe_pad_for_cudagraph( - num_scheduled_tokens, attn_metadata) + if (self.use_cuda_graph + and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): + # Use piecewise CUDA graphs. + # Add padding to the batch size. + num_input_tokens = self.vllm_config.pad_for_cudagraph( + num_scheduled_tokens) + else: + # Eager mode. 
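# A toy version of the CUDA-graph padding decision above: the scheduled token
# count is rounded up to a captured graph size when piecewise CUDA graphs are
# in use, otherwise it passes through unchanged. The capture sizes and the
# rounding helper are illustrative assumptions, not the real
# vllm_config.pad_for_cudagraph implementation.
import bisect
from typing import List


def pad_for_cudagraph(num_scheduled_tokens: int, use_cuda_graph: bool,
                      cudagraph_batch_sizes: List[int]) -> int:
    if use_cuda_graph and num_scheduled_tokens <= cudagraph_batch_sizes[-1]:
        # Round up to the smallest captured batch size that fits.
        idx = bisect.bisect_left(cudagraph_batch_sizes, num_scheduled_tokens)
        return cudagraph_batch_sizes[idx]
    return num_scheduled_tokens  # eager mode


sizes = [1, 2, 4, 8, 16, 32]
assert pad_for_cudagraph(3, True, sizes) == 4
assert pad_for_cudagraph(100, True, sizes) == 100  # too large: run eagerly
# The resulting size is what gets recorded in ForwardMetadata.num_input_tokens.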
+ num_input_tokens = num_scheduled_tokens + + forward_metadata = ForwardMetadata(num_input_tokens=num_input_tokens) if self.is_multimodal_model: # NOTE(woosuk): To unify token ids and soft tokens (vision @@ -778,7 +787,9 @@ def execute_model( # Run the decoder. # Use persistent buffers for CUDA graphs. - with set_forward_context(attn_metadata, self.vllm_config): + with set_forward_context(attn_metadata, + self.vllm_config, + forward_metadata=forward_metadata): positions = self.mrope_positions[:, :num_input_tokens] \ if self.model_config.uses_mrope \ else self.positions[:num_input_tokens] From bcfc994aa383750e9345fea6c8ad17536c8a533e Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Thu, 6 Feb 2025 05:38:17 -0800 Subject: [PATCH 35/48] cleanup Signed-off-by: Chen Zhang --- vllm/v1/core/scheduler.py | 22 ++++++++--------- vllm/v1/kv_cache_interface.py | 9 +++---- vllm/v1/worker/block_table.py | 38 ++++++++++++++++++++---------- vllm/v1/worker/gpu_input_batch.py | 20 +++++++++------- vllm/v1/worker/gpu_model_runner.py | 25 -------------------- 5 files changed, 52 insertions(+), 62 deletions(-) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 7fbf052c79310..d5ad67f659b95 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -12,7 +12,7 @@ compute_encoder_budget) from vllm.v1.core.kv_cache_manager import KVCacheManager from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs -from vllm.v1.kv_cache_interface import BlockIDList +from vllm.v1.kv_cache_interface import GroupedBlockIDs from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus @@ -108,7 +108,7 @@ def schedule(self) -> "SchedulerOutput": scheduled_running_reqs: List[Request] = [] preempted_reqs: List[Request] = [] - req_to_new_block_ids: Dict[str, BlockIDList] = {} + req_to_new_block_ids: Dict[str, GroupedBlockIDs] = {} num_scheduled_tokens: Dict[str, int] = {} token_budget = self.max_num_scheduled_tokens # Encoder-related. @@ -166,7 +166,7 @@ def schedule(self) -> "SchedulerOutput": # Schedule the request. 
scheduled_running_reqs.append(request) req_to_new_block_ids[ - request.request_id] = BlockIDList.from_kv_cache_blocks( + request.request_id] = GroupedBlockIDs.from_kv_cache_blocks( new_blocks) num_scheduled_tokens[request.request_id] = num_new_tokens @@ -237,9 +237,9 @@ def schedule(self) -> "SchedulerOutput": f"Invalid request status: {request.status}") req_to_new_block_ids[ - request.request_id] = BlockIDList.from_kv_cache_blocks( - computed_blocks) + BlockIDList.from_kv_cache_blocks( - new_blocks) + request.request_id] = GroupedBlockIDs.from_kv_cache_blocks( + computed_blocks + ) + GroupedBlockIDs.from_kv_cache_blocks(new_blocks) num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens @@ -319,7 +319,7 @@ def schedule(self) -> "SchedulerOutput": def _make_cached_request_data( self, request: Request, - new_block_ids: BlockIDList, + new_block_ids: GroupedBlockIDs, num_computed_tokens: int, resumed_from_preemption: bool, ) -> "CachedRequestData": @@ -570,14 +570,14 @@ class NewRequestData: mm_hashes: List[str] mm_positions: List["PlaceholderRange"] sampling_params: SamplingParams - block_ids: BlockIDList + block_ids: GroupedBlockIDs num_computed_tokens: int @classmethod def from_request( cls, request: Request, - block_ids: BlockIDList, + block_ids: GroupedBlockIDs, num_computed_tokens: int, ) -> "NewRequestData": return cls( @@ -601,7 +601,7 @@ class CachedRequestData: # the request's block IDs. If True, new_block_ids will be used as the # request's block IDs instead of appending to the existing block IDs. resumed_from_preemption: bool - new_block_ids: BlockIDList + new_block_ids: GroupedBlockIDs num_computed_tokens: int @classmethod @@ -609,7 +609,7 @@ def from_request( cls, request: Request, resumed_from_preemption: bool, - new_block_ids: BlockIDList, + new_block_ids: GroupedBlockIDs, num_computed_tokens: int, ) -> "CachedRequestData": return cls( diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index debdb1c75b5c1..0825acbfb9869 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -7,6 +7,7 @@ from vllm.logger import init_logger from vllm.utils import cdiv, get_dtype_size + if TYPE_CHECKING: from vllm.v1.core.kv_cache_utils import ReqKVCacheBlocks @@ -123,7 +124,7 @@ class KVCacheConfig: @dataclass -class BlockIDList: +class GroupedBlockIDs: # A list of block IDs for each group of KV cache blocks _block_ids: List[List[int]] @@ -136,12 +137,12 @@ def from_kv_cache_blocks(cls, kv_cache_blocks: "ReqKVCacheBlocks"): block_ids=[[blk.block_id for blk in kv_cache_blocks_one_group] for kv_cache_blocks_one_group in kv_cache_blocks]) - def extend(self, new_block_ids: "BlockIDList") -> None: + def extend(self, new_block_ids: "GroupedBlockIDs") -> None: for i, block_ids in enumerate(new_block_ids._block_ids): self._block_ids[i].extend(block_ids) - def __add__(self, other: "BlockIDList") -> "BlockIDList": - return BlockIDList(block_ids=[ + def __add__(self, other: "GroupedBlockIDs") -> "GroupedBlockIDs": + return GroupedBlockIDs(block_ids=[ a + b for a, b in zip(self._block_ids, other._block_ids) ]) diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 62ca99cbc0321..5223e8e7ba8b0 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -1,14 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List +from typing import Callable, Concatenate, List, ParamSpec import numpy as np import torch from triton import cdiv from vllm.logger import 
init_logger -from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend -from vllm.v1.kv_cache_interface import BlockIDList, FullAttentionSpec, KVCacheConfig, KVCacheSpec +from vllm.v1.kv_cache_interface import (GroupedBlockIDs, KVCacheConfig, + KVCacheSpec) logger = init_logger(__name__) @@ -95,39 +95,51 @@ def get_numpy_array(self) -> np.ndarray: return self.block_table_np +P = ParamSpec("P") + + class GroupedBlockTable: + move_row: Callable[P, None] + commit: Callable[P, None] + clear: Callable[P, None] + + append_row: Callable[Concatenate["GroupedBlockIDs", P], None] + add_row: Callable[Concatenate["GroupedBlockIDs", P], None] def __init__(self, max_num_reqs: int, max_model_len: int, max_num_tokens: int, pin_memory: bool, device: torch.device, - kv_cache_config: KVCacheConfig): + kv_cache_config: KVCacheConfig) -> None: self.block_tables = [ BlockTable(max_num_reqs, max_model_len, max_num_tokens, pin_memory, device, g.kv_cache_spec) for g in kv_cache_config.groups ] - for f_name in ('move_row', 'commit', 'clear'): + # For methods that just pass the arguments to each BlockTable. + for f_name in ("move_row", "commit", "clear"): setattr(self, f_name, self._make_grouped_func(f_name)) - - for f_name in ('append_row', 'add_row'): - # NOTE: requires to pass block_ids as the first argument + # For methods that require a block_ids as the first argument. + for f_name in ("append_row", "add_row"): setattr(self, f_name, self._make_grouped_func_with_block_ids(f_name)) - def _make_grouped_func(self, f_name): + def _make_grouped_func(self, f_name: str) -> Callable[P, None]: - def grouped_func(*args, **kwargs): + def grouped_func(*args: P.args, **kwargs: P.kwargs) -> None: for block_table in self.block_tables: getattr(block_table, f_name)(*args, **kwargs) return grouped_func - def _make_grouped_func_with_block_ids(self, f_name): + def _make_grouped_func_with_block_ids( + self, + f_name: str) -> Callable[Concatenate["GroupedBlockIDs", P], None]: - def grouped_func(block_ids: BlockIDList, *args, **kwargs): + def grouped_func(block_ids: "GroupedBlockIDs", *args: P.args, + **kwargs: P.kwargs) -> None: for i, block_table in enumerate(self.block_tables): getattr(block_table, f_name)(block_ids.get_group(i), *args, **kwargs) return grouped_func - def __getitem__(self, idx) -> BlockTable: + def __getitem__(self, idx: int) -> "BlockTable": return self.block_tables[idx] diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 1fcbd3fdd9e0a..c4b7cdc99f1fa 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -10,9 +10,9 @@ from vllm.multimodal import MultiModalKwargs from vllm.sampling_params import SamplingParams, SamplingType -from vllm.v1.kv_cache_interface import BlockIDList, KVCacheConfig +from vllm.v1.kv_cache_interface import GroupedBlockIDs, KVCacheConfig from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.worker.block_table import BlockTable, GroupedBlockTable +from vllm.v1.worker.block_table import GroupedBlockTable if TYPE_CHECKING: from vllm.multimodal.inputs import PlaceholderRange @@ -29,7 +29,7 @@ class CachedRequestState: sampling_params: SamplingParams generator: Optional[torch.Generator] - block_ids: BlockIDList + block_ids: GroupedBlockIDs num_computed_tokens: int output_token_ids: List[int] @@ -78,12 +78,14 @@ def __init__( self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32) # Block table. 
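# A minimal sketch of the typing pattern above: methods that are attached with
# setattr() at runtime are declared as ParamSpec-typed class attributes so
# static type checkers still see their call signatures. ToyGrouped is an
# illustrative stand-in, not a vLLM class.
from typing import Callable, List

from typing_extensions import Concatenate, ParamSpec

P = ParamSpec("P")


class ToyGrouped:
    # Visible to the type checker even though they are assigned in __init__.
    clear: Callable[P, None]
    add_row: Callable[Concatenate[List[List[int]], P], None]

    def __init__(self) -> None:
        setattr(self, "clear", lambda: None)
        setattr(self, "add_row", lambda block_ids, row_idx: None)


g = ToyGrouped()
g.add_row([[1, 2], [3]], 0)  # the checker sees the Concatenate'd signature
g.clear()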
- self.block_table = GroupedBlockTable(max_num_reqs=max_num_reqs, - max_model_len=max_model_len, - max_num_tokens=max_num_tokens, - pin_memory=pin_memory, - device=device, - kv_cache_config=kv_cache_config) + self.block_table = GroupedBlockTable( + max_num_reqs=max_num_reqs, + max_model_len=max_model_len, + max_num_tokens=max_num_tokens, + pin_memory=pin_memory, + device=device, + kv_cache_config=kv_cache_config, + ) # Sampling-related. self.temperature = torch.empty((max_num_reqs, ), diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 05473785f0f2e..57d2acc6c8d2a 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -61,7 +61,6 @@ def __init__( model_config = self.model_config cache_config = self.cache_config scheduler_config = self.scheduler_config - parallel_config = self.parallel_config self.device = device self.pin_memory = is_pin_memory_available() self.dtype = self.model_config.dtype @@ -322,8 +321,6 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: # Update the persistent batch. self.input_batch.num_computed_tokens_cpu[req_index] = ( req_data.num_computed_tokens) - - # Update the block table. self.input_batch.block_table.append_row(req_data.new_block_ids, req_index) @@ -868,28 +865,6 @@ def execute_model( ) return model_runner_output - def maybe_pad_for_cudagraph( - self, num_scheduled_tokens: int, - attn_metadata: Dict[str, FlashAttentionMetadata]) -> int: - if (self.use_cuda_graph - and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): - # Use piecewise CUDA graphs. - # Add padding to the batch size. - num_input_tokens = self.vllm_config.pad_for_cudagraph( - num_scheduled_tokens) - else: - # Eager mode. - num_input_tokens = num_scheduled_tokens - - # update num_input_tokens in attn_metadata - for kv_cache_group in self.kv_cache_config.groups: - layer_name = kv_cache_group.layer_names[0] - # All layers in the group share the same attn_metadata object. - # Only need to update the num_input_tokens once. 
- attn_metadata[layer_name].num_input_tokens = num_input_tokens - - return num_input_tokens - def load_model(self) -> None: logger.info("Starting to load model %s...", self.model_config.model) with DeviceMemoryProfiler() as m: # noqa: SIM117 From bcab7afef101eb162215b172b1242dea537d87b5 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Thu, 6 Feb 2025 05:55:29 -0800 Subject: [PATCH 36/48] fix pre-commit Signed-off-by: Chen Zhang --- vllm/v1/worker/block_table.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 5223e8e7ba8b0..7692972f0316f 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -1,10 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Callable, Concatenate, List, ParamSpec +from typing import Callable, List import numpy as np import torch from triton import cdiv +from typing_extensions import Concatenate, ParamSpec from vllm.logger import init_logger from vllm.v1.kv_cache_interface import (GroupedBlockIDs, KVCacheConfig, From 0a9701ed7dd01943d3e594b73eb5c5299021833d Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 7 Feb 2025 05:21:01 -0800 Subject: [PATCH 37/48] fix Signed-off-by: Chen Zhang --- tests/v1/core/test_kv_cache_utils.py | 5 ++--- tests/v1/core/test_prefix_caching.py | 3 ++- vllm/v1/core/kv_cache_manager.py | 9 ++++----- vllm/v1/core/scheduler.py | 2 +- 4 files changed, 9 insertions(+), 10 deletions(-) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 28a400c2303b4..97595a9422ec8 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -5,11 +5,10 @@ from vllm.multimodal.inputs import MultiModalKwargs from vllm.sampling_params import SamplingParams from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue, - KVCacheBlock, + KVCacheBlock, PrefixLengthRange, generate_block_hash_extra_keys, hash_block_tokens, - hash_request_tokens, - PrefixLengthRange, intersect_ranges) + hash_request_tokens, intersect_ranges) from vllm.v1.request import Request diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index b0ec747a97890..f76cfa13b53c2 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -8,7 +8,8 @@ from vllm.utils import cdiv from vllm.v1.core.kv_cache_manager import KVCacheManager, Request from vllm.v1.core.kv_cache_utils import KVCacheBlock, hash_block_tokens -from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig, KVCacheGroup +from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, + KVCacheGroup) def make_request(request_id, diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 79dfc4dcafa55..ba784600c5063 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -6,14 +6,13 @@ from vllm.logger import init_logger from vllm.utils import cdiv -from vllm.v1.core.specialized_manager import (BlockPoolOperations, - get_managers) from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue, - KVCacheBlock, ReqKVCacheBlocks, + KVCacheBlock, PrefixLength, + ReqKVCacheBlocks, generate_block_hash_extra_keys, hash_block_tokens, - hash_request_tokens, PrefixLength, - intersect_ranges) + hash_request_tokens, intersect_ranges) +from vllm.v1.core.specialized_manager import BlockPoolOperations, get_managers from vllm.v1.kv_cache_interface import KVCacheConfig 
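# A worked example of the prefix-length range intersection imported above.
# Roughly, each KV cache group reports the prefix lengths it could reuse as
# closed [start, end] ranges, and the common cache-hit length is taken from
# their intersection; the concrete ranges below are made-up example values.
from vllm.v1.core.kv_cache_utils import PrefixLengthRange, intersect_ranges

full_attention = [PrefixLengthRange(0, 48)]   # any prefix up to 48 tokens
sliding_window = [PrefixLengthRange(0, 0),
                  PrefixLengthRange(32, 48)]  # nothing, or the last window

common = intersect_ranges([full_attention, sliding_window])
print(common)  # [PrefixLengthRange(0, 0), PrefixLengthRange(32, 48)]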
from vllm.v1.request import Request, RequestStatus diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index c6ccc0bac564d..cef1fa0bcb2b8 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -12,7 +12,7 @@ compute_encoder_budget) from vllm.v1.core.kv_cache_manager import KVCacheManager from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs -from vllm.v1.kv_cache_interface import (GroupedBlockIDs, FullAttentionSpec, +from vllm.v1.kv_cache_interface import (FullAttentionSpec, GroupedBlockIDs, KVCacheConfig) from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.outputs import ModelRunnerOutput From a7173a2af82773797d7da8638954591621c3a6f9 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 7 Feb 2025 05:44:54 -0800 Subject: [PATCH 38/48] fix Signed-off-by: Chen Zhang --- examples/offline_inference/basic.py | 4 ++-- vllm/v1/core/kv_cache_manager.py | 15 ++++++--------- vllm/v1/worker/gpu_model_runner.py | 3 ++- 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/examples/offline_inference/basic.py b/examples/offline_inference/basic.py index a6e96c0bb4339..9875b80a971cd 100644 --- a/examples/offline_inference/basic.py +++ b/examples/offline_inference/basic.py @@ -13,7 +13,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # Create an LLM. -llm = LLM(model="facebook/opt-125m") +llm = LLM(model="google/gemma-2-2b-it") # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. outputs = llm.generate(prompts, sampling_params) @@ -21,4 +21,4 @@ for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") \ No newline at end of file + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index ba784600c5063..ec11462dff23c 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -95,8 +95,8 @@ def __init__( # Mapping from request ID to blocks to track the blocks allocated # for each request, so that we can free the blocks when the request # is finished. - self.req_to_blocks: DefaultDict[str, - ReqKVCacheBlocks] = defaultdict(list) + self.req_to_blocks: DefaultDict[str, ReqKVCacheBlocks] = defaultdict( + lambda: [[] for _ in range(self.num_kv_cache_groups)]) @property def usage(self) -> float: @@ -155,8 +155,6 @@ def get_computed_blocks(self, computed_blocks[i] = computed_blocks[i][:num_computed_tokens // manager.block_size] - # Free the blocks that are not needed. (?????) - # self._free_useless_blocks(computed_blocks, num_computed_tokens) return computed_blocks, num_computed_tokens def allocate_slots( @@ -193,22 +191,21 @@ def allocate_slots( if num_tokens == 0: raise ValueError("num_tokens must be greater than 0") + req_blocks = self.req_to_blocks[request.request_id] # We can free blocks that are no longer needed even if we cannot # schedule this request due to the limit of free blocks. # Should call this function before allocating new blocks to reduce # the number of evicted blocks. 
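# [Editor's aside] A minimal sketch of what "freeing blocks that are no longer
# needed" means for a sliding-window layer, based on the example in the
# _free_useless_blocks docstring removed later in this series (window 2, block
# size 1, blocks [1, 2, 3] -> [-1, 2, 3]). ToyBlock, NULL_BLOCK and
# remove_out_of_window_blocks are illustrative names, not the PR's API.
from dataclasses import dataclass
from typing import List


@dataclass
class ToyBlock:
    block_id: int


NULL_BLOCK = ToyBlock(-1)


def remove_out_of_window_blocks(blocks: List[ToyBlock],
                                num_computed_tokens: int,
                                sliding_window: int,
                                block_size: int) -> List[ToyBlock]:
    """Replace blocks that lie entirely before the attention window with a
    shared null placeholder and return the blocks that can be freed."""
    first_useful_token = max(num_computed_tokens - sliding_window, 0)
    first_useful_block = first_useful_token // block_size
    removed: List[ToyBlock] = []
    for i in range(first_useful_block):
        if blocks[i] is not NULL_BLOCK:
            removed.append(blocks[i])
            blocks[i] = NULL_BLOCK
    return removed


blocks = [ToyBlock(1), ToyBlock(2), ToyBlock(3)]
freed = remove_out_of_window_blocks(blocks, num_computed_tokens=3,
                                    sliding_window=2, block_size=1)
assert [b.block_id for b in blocks] == [-1, 2, 3]
assert [b.block_id for b in freed] == [1]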
- self._free_useless_blocks(self.req_to_blocks[request.request_id], - request.num_computed_tokens) + self._free_useless_blocks(req_blocks, request.num_computed_tokens) new_computed_blocks = new_computed_blocks if new_computed_blocks is not None else [ - [] for _ in self.num_kv_cache_groups + [] for _ in range(self.num_kv_cache_groups) ] # The number of computed tokens is the number of computed tokens plus # the new prefix caching hits num_computed_tokens = (request.num_computed_tokens + num_new_computed_tokens) - req_blocks = self.req_to_blocks[request.request_id] num_new_blocks = [ manager.get_num_new_blocks( @@ -291,7 +288,7 @@ def allocate_slots( # full after appending the actual tokens. num_full_blocks = (num_computed_tokens + num_tokens) // manager.block_size - num_computed_full_blocks = num_computed_tokens // self.block_size + num_computed_full_blocks = num_computed_tokens // manager.block_size new_full_blocks = req_blocks[i][ num_computed_full_blocks:num_full_blocks] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 3d8538f2a4917..7936f8b0d1a1e 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -519,7 +519,8 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): common_prefix_len = (common_prefix_len // block_size * block_size) kv_cache_spec = kv_cache_group.kv_cache_spec - assert isinstance(kv_cache_spec, FullAttentionSpec) + assert isinstance(kv_cache_spec, + (FullAttentionSpec, SlidingWindowSpec)) use_cascade = FlashAttentionBackend.use_cascade_attention( common_prefix_len=common_prefix_len, query_lens=num_scheduled_tokens, From 5e2d3bdf4e831bb653dfa4517bea17aa3eb57f7e Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 7 Feb 2025 19:14:10 -0800 Subject: [PATCH 39/48] cherry-pick: [V1] Move KV block hashes from Request to KVCacheManager (#12922) Signed-off-by: Chen Zhang --- tests/v1/core/test_prefix_caching.py | 21 ++++++++-------- vllm/v1/core/kv_cache_manager.py | 37 +++++++++++++++++++--------- vllm/v1/core/scheduler.py | 1 + vllm/v1/request.py | 19 -------------- 4 files changed, 38 insertions(+), 40 deletions(-) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index f76cfa13b53c2..568df91f2d858 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -63,7 +63,7 @@ def test_prefill(): all_token_ids = common_token_ids + unique_token_ids req0 = make_request("0", all_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) - assert len(req0.kv_block_hashes[0]) == 3 + assert len(manager.req_to_block_hashes[0][req0.request_id]) == 3 assert not computed_blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 55, computed_blocks, @@ -89,7 +89,7 @@ def test_prefill(): unique_token_ids = [3] * 5 req1 = make_request("1", common_token_ids + unique_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) - assert len(req1.kv_block_hashes[0]) == 3 + assert len(manager.req_to_block_hashes[0][req1.request_id]) == 3 assert [b.block_id for b in computed_blocks[0]] == [0, 1, 2] assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 @@ -121,7 +121,7 @@ def test_prefill(): unique_token_ids = [3] * 6 req2 = make_request("2", common_token_ids + unique_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) - assert len(req2.kv_block_hashes[0]) == 3 + assert len(manager.req_to_block_hashes[0][req2.request_id]) 
== 3 assert [b.block_id for b in computed_blocks[0]] == [0, 1, 2] assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 @@ -509,10 +509,11 @@ def test_mm_prefix_caching(): # Completed block should have hashes with extra keys. assert not computed_blocks[0] assert num_computed_tokens == 0 - assert len(req0.kv_block_hashes[0]) == 3 - assert req0.kv_block_hashes[0][0].extra_keys == ("aaa", ) - assert req0.kv_block_hashes[0][1].extra_keys == ("aaa", "bbb") - assert req0.kv_block_hashes[0][2].extra_keys == ("bbb", ) + block_hashes = manager.req_to_block_hashes[req0.request_id] + assert len(block_hashes[0]) == 3 + assert block_hashes[0][0].extra_keys == ("aaa", ) + assert block_hashes[0][1].extra_keys == ("aaa", "bbb") + assert block_hashes[0][2].extra_keys == ("bbb", ) blocks = manager.allocate_slots(req0, 59, computed_blocks, num_computed_tokens) @@ -526,8 +527,8 @@ def test_mm_prefix_caching(): assert new_blocks is not None and len(new_blocks[0]) == 0 # The just completed block should have hashes with extra keys. - assert len(req0.kv_block_hashes[0]) == 4 - assert req0.kv_block_hashes[0][3].extra_keys == ("ccc", ) + assert len(block_hashes[0]) == 4 + assert block_hashes[0][3].extra_keys == ("ccc", ) # Cache hit. unique_token_ids = [-1] * 7 + [200] * 5 @@ -632,7 +633,7 @@ def test_reset_prefix_cache(): all_token_ids = full_block_token_ids + unique_token_ids req1 = make_request("1", all_token_ids) computed_blocks, _ = manager.get_computed_blocks(req1) - assert len(req1.kv_block_hashes[0]) == 3 + assert len(manager.req_to_block_hashes[0][req1.request_id]) == 3 assert len(computed_blocks[0]) == 3 blocks = manager.allocate_slots(req1, 7, computed_blocks) assert [b.block_id for b in blocks[0]] == [4] diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index ec11462dff23c..e77b609535f0b 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -98,6 +98,13 @@ def __init__( self.req_to_blocks: DefaultDict[str, ReqKVCacheBlocks] = defaultdict( lambda: [[] for _ in range(self.num_kv_cache_groups)]) + # Mapping from request ID to kv block hashes. + # This is to avoid recomputing the block hashes for each call of + # `get_computed_blocks` or `allocate_slots`. + self.req_to_block_hashes: DefaultDict[ + str, List[List[BlockHashType]]] = defaultdict( + lambda: [[] for _ in range(self.num_kv_cache_groups)]) + @property def usage(self) -> float: return 1.0 - (self.free_block_queue.num_free_blocks / @@ -121,17 +128,19 @@ def get_computed_blocks(self, return [[] for _ in self.managers], 0 # The block hashes for the request may already be computed - # if the request was preempted and resumed. - if not request.kv_block_hashes: - request.set_kv_block_hashes([ + # if the scheduler has tried to schedule the request before. 
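# [Editor's aside] Sketch of the lifecycle this commit introduces: the manager,
# not the Request, memoizes block hashes keyed by request id, so a request that
# is preempted and scheduled again reuses them, and they are only discarded via
# free_block_hashes once the request finishes. ToyHashCache and its methods are
# hypothetical stand-ins; Python's hash() replaces vLLM's hash_block_tokens.
from collections import defaultdict
from typing import DefaultDict, List


class ToyHashCache:

    def __init__(self) -> None:
        self._req_to_hashes: DefaultDict[str, List[int]] = defaultdict(list)

    def get_or_compute(self, request_id: str, token_ids: List[int],
                       block_size: int) -> List[int]:
        hashes = self._req_to_hashes[request_id]
        if not hashes:
            # Only full blocks are hashed; a partial tail block is not
            # eligible for prefix caching.
            for start in range(0, len(token_ids) - block_size + 1, block_size):
                hashes.append(hash(tuple(token_ids[start:start + block_size])))
        return hashes

    def free(self, request_id: str) -> None:
        # Called when the request finishes, not when it is preempted.
        self._req_to_hashes.pop(request_id, None)


cache = ToyHashCache()
first = cache.get_or_compute("req-0", list(range(32)), block_size=16)
again = cache.get_or_compute("req-0", list(range(32)), block_size=16)
assert first is again  # reused across scheduling attempts
cache.free("req-0")    # dropped only once the request is finished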
+ block_hashes = self.req_to_block_hashes[request.request_id] + if not block_hashes: + block_hashes = [ hash_request_tokens(manager.block_size, request, i) for i, manager in enumerate(self.managers) - ]) + ] + self.req_to_block_hashes[request.request_id] = block_hashes computed_blocks: ReqKVCacheBlocks = [] # computed blocks of each group prefix_length: List[PrefixLength] = [ ] # possible cached prefix length of each group - block_hashes = request.kv_block_hashes + for i, manager in enumerate(self.managers): prefix_length_i, computed_blocks_i = ( manager.get_possible_cached_prefix(block_hashes[i])) @@ -154,7 +163,6 @@ def get_computed_blocks(self, for i, manager in enumerate(self.managers): computed_blocks[i] = computed_blocks[i][:num_computed_tokens // manager.block_size] - return computed_blocks, num_computed_tokens def allocate_slots( @@ -560,8 +568,8 @@ def _cache_full_blocks( prev_block: The previous block in the chain. kv_cache_group_id: The KV cache group that the blocks belong to """ - num_cached_block_hashes = len( - request.kv_block_hashes[kv_cache_group_id]) + block_hashes = self.req_to_block_hashes[request.request_id] + num_cached_block_hashes = len(block_hashes[kv_cache_group_id]) # Update the new blocks with the block hashes through the chain. prev_block_hash_value = None @@ -596,8 +604,7 @@ def _cache_full_blocks( # this request (either the prompt tokens or the previously # generated tokens with preemption). In this case we simply # reuse the block hash. - block_hash = request.kv_block_hashes[kv_cache_group_id][ - blk_idx] + block_hash = block_hashes[kv_cache_group_id][blk_idx] else: # Otherwise compute the block hash and cache it in the request # in case it will be preempted in the future. @@ -620,13 +627,21 @@ def _cache_full_blocks( block_hash = hash_block_tokens(prev_block_hash_value, block_tokens, kv_cache_group_id, extra_keys) - request.append_kv_block_hashes(kv_cache_group_id, block_hash) + block_hashes.append(kv_cache_group_id, block_hash) # Update and added the full block to the cache. blk.block_hash = block_hash self.cached_block_hash_to_block[block_hash][blk.block_id] = blk prev_block_hash_value = block_hash.hash_value + def free_block_hashes(self, request: Request) -> None: + """Discard the block hashes for the request. + + NOTE: Unlike `free`, this method should be called only when the request + is finished, not when it is preempted. 
+ """ + self.req_to_block_hashes.pop(request.request_id, None) + def get_null_block(self) -> KVCacheBlock: return self._null_block diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index cef1fa0bcb2b8..7bf6c8d933d64 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -551,6 +551,7 @@ def finish_requests( def _free_request(self, request: Request) -> None: assert request.is_finished() self.kv_cache_manager.free(request) + self.kv_cache_manager.free_block_hashes(request) self.encoder_cache_manager.free(request) self._cached_reqs_data.pop(request.request_id, None) del self.requests[request.request_id] diff --git a/vllm/v1/request.py b/vllm/v1/request.py index ab647748f223e..8c76d72f5d812 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -12,7 +12,6 @@ if TYPE_CHECKING: from vllm.multimodal import MultiModalKwargs from vllm.multimodal.inputs import PlaceholderRange - from vllm.v1.core.kv_cache_utils import BlockHashType class Request: @@ -63,12 +62,6 @@ def __init__( if self.mm_hashes: assert len(self.mm_inputs) == len(self.mm_hashes) - # Cache the computed kv block hashes of the request to avoid - # recomputing. - self._kv_block_hashes: List[List[BlockHashType]] = [] - self.kv_block_hashes = ConstantList( - [ConstantList(x) for x in self._kv_block_hashes]) - # Read-only views # Prevent directly appending to the these lists since # they should also be updated simultaneously. @@ -125,18 +118,6 @@ def get_num_encoder_tokens(self, input_id: int) -> int: num_tokens = self.mm_positions[input_id]["length"] return num_tokens - def set_kv_block_hashes(self, value: List[List["BlockHashType"]]) -> None: - self._kv_block_hashes = value - # NOTE: self.kv_block_hashes._x is not self._kv_block_hashes, but - # self.kv_block_hashes[0]._x is self._kv_block_hashes[0]. This is - # correct because we never need to update the outer list. - self.kv_block_hashes = ConstantList( - [ConstantList(x) for x in self._kv_block_hashes]) - - def append_kv_block_hashes(self, group_id: int, - block_hash: "BlockHashType") -> None: - self._kv_block_hashes[group_id].append(block_hash) - class RequestStatus(enum.IntEnum): """Status of a request.""" From 09782a20f5c01aee06040149d977167fb22bd3d4 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 7 Feb 2025 21:44:50 -0800 Subject: [PATCH 40/48] small fix Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index e77b609535f0b..40a73271fce38 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -627,7 +627,7 @@ def _cache_full_blocks( block_hash = hash_block_tokens(prev_block_hash_value, block_tokens, kv_cache_group_id, extra_keys) - block_hashes.append(kv_cache_group_id, block_hash) + block_hashes[kv_cache_group_id].append(block_hash) # Update and added the full block to the cache. 
blk.block_hash = block_hash From 1bdbc73359420a0634e8d3c2d42867c62d2d4e1e Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 14 Feb 2025 00:51:59 -0800 Subject: [PATCH 41/48] add back kvcache manager Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_manager.py | 419 ++++++++--------------------- vllm/v1/core/scheduler.py | 19 +- vllm/v1/kv_cache_interface.py | 20 +- vllm/v1/worker/block_table.py | 19 +- vllm/v1/worker/gpu_input_batch.py | 4 +- vllm/v1/worker/gpu_model_runner.py | 15 +- 6 files changed, 178 insertions(+), 318 deletions(-) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 40a73271fce38..68dad14011620 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -1,18 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 -import math from collections import defaultdict -from typing import DefaultDict, Dict, List, Optional, Tuple +from typing import DefaultDict, Dict, Iterable, List, Optional, Tuple from vllm.logger import init_logger from vllm.utils import cdiv from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue, - KVCacheBlock, PrefixLength, - ReqKVCacheBlocks, + KVCacheBlock, generate_block_hash_extra_keys, hash_block_tokens, - hash_request_tokens, intersect_ranges) -from vllm.v1.core.specialized_manager import BlockPoolOperations, get_managers + hash_request_tokens) from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import Request, RequestStatus @@ -20,11 +17,6 @@ class KVCacheManager: - """ - The KVCacheManager for models with one KV cache type (e.g., Llama) and - thus one kv cache group (Refer to class `KVCacheConfig` for the meaning of - kv cache group). - """ def __init__( self, @@ -33,13 +25,11 @@ def __init__( enable_caching: bool = True, num_preallocate_tokens: int = 64, ) -> None: - self.kv_cache_config = kv_cache_config + self.block_size = kv_cache_config.groups[0].kv_cache_spec.block_size self.num_gpu_blocks = kv_cache_config.num_blocks self.max_model_len = max_model_len - self.max_num_blocks_per_req = [ - cdiv(max_model_len, g.kv_cache_spec.block_size) - for g in kv_cache_config.groups - ] + self.max_num_blocks_per_req = cdiv(max_model_len, self.block_size) + # self.sliding_window = sliding_window self.enable_caching = enable_caching # NOTE(woosuk): To avoid frequent block allocation, we preallocate some # blocks for each request. For example, when a request reaches the end @@ -51,25 +41,9 @@ def __init__( # the request gets N empty blocks, it starts to use the blocks without # further allocation. When it uses up all the N empty blocks, it gets # N new empty blocks. - # NOTE(Chen): For simplicity, we keep the number of preallocated blocks - # the same for all kv cache groups, which will result in different - # preallocated tokens for different groups if their block sizes are - # different. self.num_preallocate_tokens = num_preallocate_tokens - self.num_preallocate_blocks = cdiv( - num_preallocate_tokens, - max(g.kv_cache_spec.block_size for g in kv_cache_config.groups)) - - self._null_block: KVCacheBlock = KVCacheBlock(-1) - - # Specialized managers for each kv cache group, which handle the - # different kv cache management logic of different attention layers. 
- self.managers = get_managers( - kv_cache_config, - BlockPoolOperations(get_cached_block=self._get_cached_block, - get_null_block=self.get_null_block), - ) - self.num_kv_cache_groups = len(self.kv_cache_config.groups) + self.num_preallocate_blocks = cdiv(num_preallocate_tokens, + self.block_size) # A Block pool of all kv-cache blocks. self.block_pool: List[KVCacheBlock] = [ @@ -95,23 +69,22 @@ def __init__( # Mapping from request ID to blocks to track the blocks allocated # for each request, so that we can free the blocks when the request # is finished. - self.req_to_blocks: DefaultDict[str, ReqKVCacheBlocks] = defaultdict( - lambda: [[] for _ in range(self.num_kv_cache_groups)]) + self.req_to_blocks: DefaultDict[str, + List[KVCacheBlock]] = defaultdict(list) # Mapping from request ID to kv block hashes. # This is to avoid recomputing the block hashes for each call of # `get_computed_blocks` or `allocate_slots`. self.req_to_block_hashes: DefaultDict[ - str, List[List[BlockHashType]]] = defaultdict( - lambda: [[] for _ in range(self.num_kv_cache_groups)]) + str, List[BlockHashType]] = defaultdict(list) @property def usage(self) -> float: return 1.0 - (self.free_block_queue.num_free_blocks / self.num_gpu_blocks) - def get_computed_blocks(self, - request: Request) -> Tuple[ReqKVCacheBlocks, int]: + def get_computed_blocks( + self, request: Request) -> Tuple[List[KVCacheBlock], int]: """Get the computed (cached) blocks for the request. Note that the computed blocks must be full. @@ -120,66 +93,51 @@ def get_computed_blocks(self, Returns: A tuple containing: - - The blocks that are computed for the request + - A list of blocks that are computed for the request. - The number of computed tokens. """ if not self.enable_caching: # Prefix caching is disabled. - return [[] for _ in self.managers], 0 + return [], 0 + + computed_blocks = [] # The block hashes for the request may already be computed # if the scheduler has tried to schedule the request before. block_hashes = self.req_to_block_hashes[request.request_id] if not block_hashes: - block_hashes = [ - hash_request_tokens(manager.block_size, request, i) - for i, manager in enumerate(self.managers) - ] + block_hashes = hash_request_tokens(self.block_size, request, 0) self.req_to_block_hashes[request.request_id] = block_hashes - computed_blocks: ReqKVCacheBlocks = [] # computed blocks of each group - prefix_length: List[PrefixLength] = [ - ] # possible cached prefix length of each group - - for i, manager in enumerate(self.managers): - prefix_length_i, computed_blocks_i = ( - manager.get_possible_cached_prefix(block_hashes[i])) - computed_blocks.append(computed_blocks_i) - prefix_length.append(prefix_length_i) + for block_hash in block_hashes: + # block_hashes is a chain of block hashes. If a block hash is not + # in the cached_block_hash_to_id, the following block hashes are + # not computed yet for sure. + if cached_block := self._get_cached_block(block_hash): + computed_blocks.append(cached_block) + else: + break - if len(self.kv_cache_config.groups) == 1: - # If there is only one group, we return the computed blocks and - # tokens directly. - num_computed_tokens = prefix_length[0][-1].end - else: - # Find the common cached prefix of all groups. This path also works - # for the single group case, but it is less efficient. - num_computed_tokens = self._get_common_computed_tokens( - prefix_length) - - # Truncate the computed blocks to the number of computed tokens. 
- # E.g., group 0 has 3 computed blocks, and group 1 has 4 computed - # blocks with the same block size, we truncate both groups to 3 blocks. - for i, manager in enumerate(self.managers): - computed_blocks[i] = computed_blocks[i][:num_computed_tokens // - manager.block_size] + # NOTE(woosuk): Since incomplete blocks are not eligible for + # sharing, `num_computed_tokens` is always a multiple of + # `block_size`. + num_computed_tokens = len(computed_blocks) * self.block_size return computed_blocks, num_computed_tokens def allocate_slots( self, request: Request, num_tokens: int, - new_computed_blocks: Optional[ReqKVCacheBlocks] = None, - num_new_computed_tokens: int = 0, - ) -> Optional[ReqKVCacheBlocks]: + new_computed_blocks: Optional[List[KVCacheBlock]] = None + ) -> Optional[List[KVCacheBlock]]: """Add slots for a request with new tokens to append. Args: request: The request to allocate slots. num_tokens: The number of tokens to allocate. Note that this does not include the tokens that have already been computed. - new_computed_blocks_all_groups: A list of new computed blocks - just hitting the prefix caching. + new_computed_blocks: A list of new computed blocks just hitting the + prefix caching. Blocks layout: ----------------------------------------------------------------------- @@ -199,166 +157,86 @@ def allocate_slots( if num_tokens == 0: raise ValueError("num_tokens must be greater than 0") - req_blocks = self.req_to_blocks[request.request_id] - # We can free blocks that are no longer needed even if we cannot - # schedule this request due to the limit of free blocks. - # Should call this function before allocating new blocks to reduce - # the number of evicted blocks. - self._free_useless_blocks(req_blocks, request.num_computed_tokens) - - new_computed_blocks = new_computed_blocks if new_computed_blocks is not None else [ - [] for _ in range(self.num_kv_cache_groups) - ] + new_computed_blocks = new_computed_blocks or [] # The number of computed tokens is the number of computed tokens plus # the new prefix caching hits num_computed_tokens = (request.num_computed_tokens + - num_new_computed_tokens) - - num_new_blocks = [ - manager.get_num_new_blocks( - num_computed_tokens, num_tokens, - len(req_blocks[i]) + len(new_computed_blocks[i])) - for i, manager in enumerate(self.managers) - ] - - total_new_blocks = sum(max(x, 0) for x in num_new_blocks) + len(new_computed_blocks) * self.block_size) + num_required_blocks = cdiv(num_computed_tokens + num_tokens, + self.block_size) + req_blocks = self.req_to_blocks[request.request_id] + num_new_blocks = (num_required_blocks - len(req_blocks) - + len(new_computed_blocks)) # If a computed block of a request is an eviction candidate (in the # free queue and ref_cnt == 0), it cannot be counted as a free block # when allocating this request. - num_evictable_computed_blocks = sum( - 1 for blk_group in new_computed_blocks for blk in blk_group - if blk.ref_cnt == 0) - - if (total_new_blocks > self.free_block_queue.num_free_blocks - + num_evictable_computed_blocks = sum(1 for blk in new_computed_blocks + if blk.ref_cnt == 0) + if (num_new_blocks > self.free_block_queue.num_free_blocks - num_evictable_computed_blocks): - # Cannot allocate new blocks. + # Cannot allocate new blocks return None # Touch the computed blocks to make sure they won't be evicted. 
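# [Editor's aside] Toy model of the "touch" semantics used below, with a plain
# deque standing in for vLLM's doubly linked FreeKVCacheBlockQueue: a block
# whose ref_cnt drops to 0 stays cached but becomes an eviction candidate in
# the free queue, and touching it on a prefix-cache hit pulls it back out
# before bumping the reference count. Illustrative names only.
from collections import deque
from typing import Deque


class ToyBlock:

    def __init__(self, block_id: int) -> None:
        self.block_id = block_id
        self.ref_cnt = 0


free_queue: Deque[ToyBlock] = deque()


def touch(block: ToyBlock) -> None:
    if block.ref_cnt == 0 and block in free_queue:
        free_queue.remove(block)  # no longer an eviction candidate
    block.ref_cnt += 1


def release(block: ToyBlock) -> None:
    block.ref_cnt -= 1
    if block.ref_cnt == 0:
        free_queue.append(block)  # evictable, but its contents are kept


blk = ToyBlock(0)
touch(blk)    # first user of the block
release(blk)  # now sits in the free queue with its data intact
touch(blk)    # prefix-cache hit revives it without recomputation
assert blk.ref_cnt == 1 and blk not in free_queue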
if self.enable_caching: self._touch(new_computed_blocks) else: - assert all(len(blks) == 0 for blks in new_computed_blocks), ( + assert not new_computed_blocks, ( "Computed blocks should be empty when " "prefix caching is disabled") # Append the new computed blocks to the request blocks until now to # avoid the case where the new blocks cannot be allocated. - for i, new_computed_blocks_of_group in enumerate(new_computed_blocks): - req_blocks[i].extend(new_computed_blocks_of_group) + req_blocks.extend(new_computed_blocks) # Start to handle new blocks - new_blocks: ReqKVCacheBlocks = [] - - # Truncate the number of pre-allocated blocks to ensure that we can - # have at least `num_new_blocks` free blocks for each group. - num_preallocate_blocks = min( - self.num_preallocate_blocks, - (self.free_block_queue.num_free_blocks - total_new_blocks) // - len(self.managers)) - - for i in range(self.num_kv_cache_groups): - if num_new_blocks[i] <= 0: - # No new block is needed. - new_blocks.append([]) - else: - # Get new blocks from the free block pool considering - # preallocated blocks. - num_block_to_allocate = min( - num_new_blocks[i] + num_preallocate_blocks, - # Should not exceed the maximum number of blocks per request - # This is especially because the block table has the shape - # [..., max_num_blocks_per_req]. - # TODO(woosuk): Check and reject requests if - # num_prompt_tokens + max_tokens > max_model_len. - self.max_num_blocks_per_req[i] - len(req_blocks[i]), - ) - - assert num_block_to_allocate >= 0 - assert num_block_to_allocate <= \ - self.free_block_queue.num_free_blocks - - new_blocks_of_group = self._get_new_blocks( - num_block_to_allocate) - new_blocks.append(new_blocks_of_group) - req_blocks[i].extend(new_blocks_of_group) + + if num_new_blocks <= 0: + # No new block is needed. + new_blocks = [] + else: + # Get new blocks from the free block pool considering + # preallocated blocks. + num_new_blocks = min( + num_new_blocks + self.num_preallocate_blocks, + self.free_block_queue.num_free_blocks, + # Should not exceed the maximum number of blocks per request. + # This is especially because the block table has the shape + # [..., max_num_blocks_per_req]. + # TODO(woosuk): Check and reject requests if + # num_prompt_tokens + max_tokens > max_model_len. + self.max_num_blocks_per_req - len(req_blocks), + ) + assert num_new_blocks > 0 + + # Concatenate the computed block IDs and the new block IDs. + new_blocks = self._get_new_blocks(num_new_blocks) + req_blocks.extend(new_blocks) if not self.enable_caching: return new_blocks - for i, manager in enumerate(self.managers): - # NOTE(rickyx): We are assuming the `num_tokens` are actual - # tokens rather than lookahead slots (e.g. for speculative decoding). - # TODO(rickyx): When supporting speculative decoding, we will need to - # differentiate between them so that we can know how many blocks are - # full after appending the actual tokens. - num_full_blocks = (num_computed_tokens + - num_tokens) // manager.block_size - num_computed_full_blocks = num_computed_tokens // manager.block_size - - new_full_blocks = req_blocks[i][ - num_computed_full_blocks:num_full_blocks] - if new_full_blocks: - self._cache_full_blocks( - request=request, - blk_start_idx=num_computed_full_blocks, - # The new full blocks are the full blocks that are not - # computed. 
- full_blocks=new_full_blocks, - prev_block=(req_blocks[i][num_computed_full_blocks - 1] - if num_computed_full_blocks > 0 else None), - kv_cache_group_id=i, - ) + # NOTE(rickyx): We are assuming the `num_tokens` are actual + # tokens rather than lookahead slots (e.g. for speculative decoding). + # TODO(rickyx): When supporting speculative decoding, we will need to + # differentiate between them so that we can know how many blocks are + # full after appending the actual tokens. + num_full_blocks = (num_computed_tokens + num_tokens) // self.block_size + num_computed_full_blocks = num_computed_tokens // self.block_size + new_full_blocks = req_blocks[num_computed_full_blocks:num_full_blocks] + if new_full_blocks: + self._cache_full_blocks( + request=request, + blk_start_idx=num_computed_full_blocks, + # The new full blocks are the full blocks that are not computed. + full_blocks=new_full_blocks, + prev_block=(req_blocks[num_computed_full_blocks - 1] + if num_computed_full_blocks > 0 else None)) return new_blocks - def _merge_blocks_by_eviction_order( - self, blocks: ReqKVCacheBlocks) -> List[KVCacheBlock]: - """ - Merge the blocks of different groups to one list. The returned blocks - are sorted by eviction order, with the first block having the highest - eviction priority. - - Args: - blocks: the blocks of each kv cache group, ordered by eviction - priority. - - Returns: - A list of KVCacheBlocks sorted by eviction order. - """ - - if self.enable_caching: - # NOTE (Chen): A simple strategy that interleaves the blocks of - # different KV cache groups. We can investigate more advanced - # strategies in the future. - ordered_blocks = [] - max_len = max(len(blocks_of_group) for blocks_of_group in blocks) - for i in range(max_len): - for blocks_of_group in blocks: - if i < len(blocks_of_group): - ordered_blocks.append(blocks_of_group[i]) - else: - ordered_blocks = [] - for blocks_of_group in blocks: - ordered_blocks.extend(blocks_of_group) - - return ordered_blocks - - def _free_blocks(self, blocks: ReqKVCacheBlocks) -> None: - if len(self.kv_cache_config.groups) == 1: - # Fast path for single kv cache group models. - ordered_blocks = blocks[0] - else: - ordered_blocks = self._merge_blocks_by_eviction_order(blocks) - for block in ordered_blocks: - if block == self._null_block: - continue - block.decr_ref() - if block.ref_cnt == 0: - self.free_block_queue.append(block) - def free(self, request: Request) -> None: """Free the blocks allocated for the request. When caching is enabled, we free the blocks in reverse order so that @@ -369,13 +247,16 @@ def free(self, request: Request) -> None: """ # Default to [] in case a request is freed (aborted) before alloc. blocks = self.req_to_blocks.pop(request.request_id, []) - if len(blocks) == 0: - # This request is freed before alloc. just return - return - else: - # Reverse the blocks so that the tail blocks can have higher - # eviction priority. - self._free_blocks([list(reversed(blks)) for blks in blocks]) + ordered_blocks: Iterable[KVCacheBlock] = blocks + if self.enable_caching: + # Free blocks in reverse order so that the tail blocks are + # freed first. + ordered_blocks = reversed(blocks) + + for block in ordered_blocks: + block.decr_ref() + if block.ref_cnt == 0: + self.free_block_queue.append(block) def reset_prefix_cache(self) -> bool: """Reset prefix cache. 
This function may be used in RLHF @@ -408,7 +289,7 @@ def get_num_common_prefix_blocks( self, request: Request, num_running_requests: int, - ) -> List[int]: + ) -> int: """Calculate the number of common prefix blocks shared by all requests in the RUNNING state. @@ -442,20 +323,17 @@ def get_num_common_prefix_blocks( requests in the current step. Returns: - List[int]: The number of common prefix blocks per KV cache group. + int: The number of common prefix blocks. """ assert request.status == RequestStatus.RUNNING blocks = self.req_to_blocks[request.request_id] - num_common_blocks_per_group = [] - for blocks_of_group in blocks: - num_common_blocks = 0 - for block in blocks_of_group: - if block.ref_cnt == num_running_requests: - num_common_blocks += 1 - else: - break - num_common_blocks_per_group.append(num_common_blocks) - return num_common_blocks_per_group + num_common_blocks = 0 + for block in blocks: + if block.ref_cnt == num_running_requests: + num_common_blocks += 1 + else: + break + return num_common_blocks def _get_new_blocks(self, num_blocks: int) -> List[KVCacheBlock]: """Get new blocks from the free block pool. @@ -528,7 +406,7 @@ def _get_cached_block(self, return self.cached_block_hash_to_block[block_hash][first_block_id] return None - def _touch(self, blocks: ReqKVCacheBlocks) -> None: + def _touch(self, blocks: List[KVCacheBlock]) -> None: """Touch a block increases its reference count by 1, and may remove the block from the free queue. This is used when a block is hit by another request with the same prefix. @@ -536,13 +414,12 @@ def _touch(self, blocks: ReqKVCacheBlocks) -> None: Args: blocks: A list of blocks to touch. """ - for blocks_of_group in blocks: - for block in blocks_of_group: - # ref_cnt=0 means this block is in the free list (i.e. eviction - # candidate), so remove it. - if block.ref_cnt == 0 and block != self._null_block: - self.free_block_queue.remove(block) - block.incr_ref() + for block in blocks: + # ref_cnt=0 means this block is in the free list (i.e. eviction + # candidate), so remove it. + if block.ref_cnt == 0: + self.free_block_queue.remove(block) + block.incr_ref() def _cache_full_blocks( self, @@ -550,7 +427,6 @@ def _cache_full_blocks( blk_start_idx: int, full_blocks: List[KVCacheBlock], prev_block: Optional[KVCacheBlock], - kv_cache_group_id: int, ) -> None: """Cache a list of full blocks for prefix caching. @@ -566,10 +442,9 @@ def _cache_full_blocks( to cache. full_blocks: The list of blocks to update hash metadata. prev_block: The previous block in the chain. - kv_cache_group_id: The KV cache group that the blocks belong to """ block_hashes = self.req_to_block_hashes[request.request_id] - num_cached_block_hashes = len(block_hashes[kv_cache_group_id]) + num_cached_block_hashes = len(block_hashes) # Update the new blocks with the block hashes through the chain. prev_block_hash_value = None @@ -579,8 +454,6 @@ def _cache_full_blocks( assert prev_block.block_hash is not None prev_block_hash_value = prev_block.block_hash.hash_value - block_size = self.kv_cache_config.groups[ - kv_cache_group_id].kv_cache_spec.block_size # Find the first uncached block. This case should only happen when # speculative decoding is used. offset = 0 @@ -604,16 +477,16 @@ def _cache_full_blocks( # this request (either the prompt tokens or the previously # generated tokens with preemption). In this case we simply # reuse the block hash. 
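# [Editor's aside] The block hashes handled here form a chain: each full
# block's hash covers the previous block's hash plus this block's tokens (and
# any extra keys, e.g. multimodal hashes), so a hit on block i implies every
# earlier block matches as well. A self-contained sketch of that chaining,
# with Python's hash() standing in for hash_block_tokens:
from typing import List, Optional, Tuple


def chain_block_hashes(
        token_ids: List[int],
        block_size: int,
        extra_keys: Optional[Tuple[str, ...]] = None) -> List[int]:
    hashes: List[int] = []
    prev_hash: Optional[int] = None
    for start in range(0, len(token_ids) - block_size + 1, block_size):
        block_tokens = tuple(token_ids[start:start + block_size])
        prev_hash = hash((prev_hash, block_tokens, extra_keys))
        hashes.append(prev_hash)
    return hashes


# Two prompts that share only their first block get the same first hash and
# then diverge, which is what lets the manager match a cached prefix.
a = chain_block_hashes([1, 2, 3, 4, 5, 6, 7, 8], block_size=4)
b = chain_block_hashes([1, 2, 3, 4, 9, 9, 9, 9], block_size=4)
assert a[0] == b[0] and a[1] != b[1]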
- block_hash = block_hashes[kv_cache_group_id][blk_idx] + block_hash = block_hashes[blk_idx] else: # Otherwise compute the block hash and cache it in the request # in case it will be preempted in the future. - start_token_idx = blk_idx * block_size - end_token_idx = (blk_idx + 1) * block_size + start_token_idx = blk_idx * self.block_size + end_token_idx = (blk_idx + 1) * self.block_size block_tokens = request.all_token_ids[ start_token_idx:end_token_idx] - assert len(block_tokens) == block_size, ( - f"Expected {block_size} tokens, got " + assert len(block_tokens) == self.block_size, ( + f"Expected {self.block_size} tokens, got " f"{len(block_tokens)} at {blk_idx}th block for request " f"{request.request_id}({request})") @@ -625,9 +498,8 @@ def _cache_full_blocks( # Compute the hash of the current block. block_hash = hash_block_tokens(prev_block_hash_value, - block_tokens, kv_cache_group_id, - extra_keys) - block_hashes[kv_cache_group_id].append(block_hash) + block_tokens, 0, extra_keys) + block_hashes.append(block_hash) # Update and added the full block to the cache. blk.block_hash = block_hash @@ -641,56 +513,3 @@ def free_block_hashes(self, request: Request) -> None: is finished, not when it is preempted. """ self.req_to_block_hashes.pop(request.request_id, None) - - def get_null_block(self) -> KVCacheBlock: - return self._null_block - - def _get_common_computed_tokens(self, - prefix_length: List[PrefixLength]) -> int: - """ - Find the longest prefix that is cached by all KV cache groups. Returns - the number of tokens in that prefix. - - Args: - prefix_length (List[PrefixLength]): The valid cached prefix lengths - of each KV cache group. - - Returns: - The number of tokens in the common prefix. - """ - intersection = intersect_ranges(prefix_length) - - # Since incomplete blocks are not eligible for sharing, - # `num_computed_tokens` should be a multiple of `block_size` of - # all managers, so we take the least common multiple (LCM) of them - alignment = math.lcm( - *[manager.block_size for manager in self.managers]) - - # Get the longest common prefix that is aligned with the block size. - num_computed_tokens = 0 - for range_ in intersection: - aligned_end = cdiv(range_.end, alignment) * alignment - if aligned_end >= range_.start: - num_computed_tokens = aligned_end - break - - return num_computed_tokens - - def _free_useless_blocks(self, req_blocks: ReqKVCacheBlocks, - num_computed_tokens: int) -> None: - """ - Frees memory blocks that are not needed. E.g., sliding window - layer with window size 2 and block size 1, we have req_blocks as - [[1, 2, 3]], this function will free block 1 and change the req_blocks - to [[-1, 2, 3]] (-1 refers to null block) - - Args: - req_blocks: The KV cache blocks of one request. - num_computed_tokens: The number of computed tokens. 
- """ - removed_blocks = [] - for manager, req_blocks_of_group in zip(self.managers, req_blocks): - removed_blocks.append( - manager.remove_useless_blocks(req_blocks_of_group, - num_computed_tokens)) - self._free_blocks(removed_blocks) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 7bf6c8d933d64..95fea07a7ac78 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -12,8 +12,8 @@ compute_encoder_budget) from vllm.v1.core.kv_cache_manager import KVCacheManager from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs -from vllm.v1.kv_cache_interface import (FullAttentionSpec, GroupedBlockIDs, - KVCacheConfig) +from vllm.v1.kv_cache_interface import (BlockIDGenerator, FullAttentionSpec, + GroupedBlockIDs, KVCacheConfig) from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus @@ -55,6 +55,7 @@ def __init__( max_model_len=self.max_model_len, enable_caching=self.cache_config.enable_prefix_caching) self.block_size = self.cache_config.block_size + BlockIDGenerator.num_kv_cache_groups = len(kv_cache_config.groups) # req_id -> Request self.requests: Dict[str, Request] = {} @@ -166,8 +167,7 @@ def schedule(self) -> "SchedulerOutput": # Schedule the request. scheduled_running_reqs.append(request) req_to_new_block_ids[ - request.request_id] = GroupedBlockIDs.from_kv_cache_blocks( - new_blocks) + request.request_id] = BlockIDGenerator.generate(new_blocks) num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens @@ -219,7 +219,7 @@ def schedule(self) -> "SchedulerOutput": block_size = kv_groups[0].kv_cache_spec.block_size num_computed_tokens -= block_size num_new_tokens = block_size - computed_blocks[0].pop() + computed_blocks.pop() num_new_tokens = min(num_new_tokens, token_budget) assert num_new_tokens > 0 @@ -233,8 +233,7 @@ def schedule(self) -> "SchedulerOutput": break new_blocks = self.kv_cache_manager.allocate_slots( - request, num_new_tokens, computed_blocks, - num_computed_tokens) + request, num_new_tokens, computed_blocks) if new_blocks is None: # The request cannot be scheduled. 
break @@ -250,9 +249,9 @@ def schedule(self) -> "SchedulerOutput": f"Invalid request status: {request.status}") req_to_new_block_ids[ - request.request_id] = GroupedBlockIDs.from_kv_cache_blocks( - computed_blocks - ) + GroupedBlockIDs.from_kv_cache_blocks(new_blocks) + request.request_id] = BlockIDGenerator.generate( + computed_blocks) + BlockIDGenerator.generate( + new_blocks) num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 26fb35a780d17..38e13a08c225c 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List +from typing import TYPE_CHECKING, Dict, List, Union import torch @@ -185,3 +185,21 @@ def __add__(self, other: "GroupedBlockIDs") -> "GroupedBlockIDs": def get_group(self, group_idx: int) -> List[int]: return self._block_ids[group_idx] + + +MayGroupedBlockIDs = Union[GroupedBlockIDs, List[int]] +MayGroupedInt = Union[int, List[int]] + + +class BlockIDGenerator: + num_kv_cache_groups: int + + @classmethod + def generate( + cls, kv_cache_blocks: Union[List["KVCacheBlock"], + List[List["KVCacheBlock"]]] + ) -> MayGroupedBlockIDs: + if cls.num_kv_cache_groups == 1: + return [blk.block_id for blk in kv_cache_blocks] + else: + return GroupedBlockIDs.from_kv_cache_blocks(kv_cache_blocks) diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 7692972f0316f..06dd0c593f51d 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Callable, List +from typing import Callable, List, Union import numpy as np import torch @@ -144,3 +144,20 @@ def grouped_func(block_ids: "GroupedBlockIDs", *args: P.args, def __getitem__(self, idx: int) -> "BlockTable": return self.block_tables[idx] + + +def initialize_block_table( + max_num_reqs: int, + max_model_len: int, + max_num_tokens: int, + pin_memory: bool, + device: torch.device, + kv_cache_config: KVCacheConfig, +) -> Union[BlockTable, GroupedBlockTable]: + if len(kv_cache_config.groups) == 1: + return BlockTable(max_num_reqs, max_model_len, max_num_tokens, + pin_memory, device, + kv_cache_config.groups[0].kv_cache_spec) + else: + return GroupedBlockTable(max_num_reqs, max_model_len, max_num_tokens, + pin_memory, device, kv_cache_config) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index c4b7cdc99f1fa..745ff5dac8817 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -12,7 +12,7 @@ from vllm.sampling_params import SamplingParams, SamplingType from vllm.v1.kv_cache_interface import GroupedBlockIDs, KVCacheConfig from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.worker.block_table import GroupedBlockTable +from vllm.v1.worker.block_table import initialize_block_table if TYPE_CHECKING: from vllm.multimodal.inputs import PlaceholderRange @@ -78,7 +78,7 @@ def __init__( self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32) # Block table. 
- self.block_table = GroupedBlockTable( + self.block_table = initialize_block_table( max_num_reqs=max_num_reqs, max_model_len=max_model_len, max_num_tokens=max_num_tokens, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 7936f8b0d1a1e..860e6200ba2cd 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -33,6 +33,7 @@ from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.utils import bind_kv_cache +from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch if TYPE_CHECKING: @@ -437,9 +438,15 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): attn_metadata: Dict[str, FlashAttentionMetadata] = {} + if len(self.kv_cache_config.groups) == 1: + may_grouped_unwrapper = lambda x, _group_id: x + else: + may_grouped_unwrapper = lambda x, group_id: x[group_id] + for group_id, kv_cache_group in enumerate(self.kv_cache_config.groups): block_size = kv_cache_group.kv_cache_spec.block_size - block_table = self.input_batch.block_table[group_id] + block_table: BlockTable = may_grouped_unwrapper( + self.input_batch.block_table, group_id) # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] @@ -466,9 +473,9 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): .to(self.device, non_blocking=True).long() # Prepare for cascade attention if needed. - common_prefix_len = ( - scheduler_output.num_common_prefix_blocks[group_id] * - block_size) + common_prefix_len = (may_grouped_unwrapper( + scheduler_output.num_common_prefix_blocks, group_id) * + block_size) if common_prefix_len == 0: # Common case. use_cascade = False From 7bc9f7df63da8dbeae00ba68b8f951588faa3763 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 14 Feb 2025 04:37:08 -0800 Subject: [PATCH 42/48] add back hybrid kvcache manager Signed-off-by: Chen Zhang --- vllm/v1/core/hybrid_kv_cache_manager.py | 696 ++++++++++++++++++++++++ vllm/v1/core/kv_cache_manager.py | 29 +- vllm/v1/core/scheduler.py | 7 +- 3 files changed, 727 insertions(+), 5 deletions(-) create mode 100644 vllm/v1/core/hybrid_kv_cache_manager.py diff --git a/vllm/v1/core/hybrid_kv_cache_manager.py b/vllm/v1/core/hybrid_kv_cache_manager.py new file mode 100644 index 0000000000000..614241fee4d52 --- /dev/null +++ b/vllm/v1/core/hybrid_kv_cache_manager.py @@ -0,0 +1,696 @@ +# SPDX-License-Identifier: Apache-2.0 + +import math +from collections import defaultdict +from typing import DefaultDict, Dict, List, Optional, Tuple + +from vllm.logger import init_logger +from vllm.utils import cdiv +from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue, + KVCacheBlock, PrefixLength, + ReqKVCacheBlocks, + generate_block_hash_extra_keys, + hash_block_tokens, + hash_request_tokens, intersect_ranges) +from vllm.v1.core.specialized_manager import BlockPoolOperations, get_managers +from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.request import Request, RequestStatus + +logger = init_logger(__name__) + + +class HybridKVCacheManager: + """ + The HybridKVCacheManager for models with multiple KV cache types + (e.g., Gemma-2) and thus multiple kv cache group (Refer to class + `KVCacheConfig` for the meaning of kv cache group). 
+ """ + + def __init__( + self, + kv_cache_config: KVCacheConfig, + max_model_len: int, + enable_caching: bool = True, + num_preallocate_tokens: int = 64, + ) -> None: + self.kv_cache_config = kv_cache_config + self.num_gpu_blocks = kv_cache_config.num_blocks + self.max_model_len = max_model_len + self.max_num_blocks_per_req = [ + cdiv(max_model_len, g.kv_cache_spec.block_size) + for g in kv_cache_config.groups + ] + self.enable_caching = enable_caching + # NOTE(woosuk): To avoid frequent block allocation, we preallocate some + # blocks for each request. For example, when a request reaches the end + # of its block table, we preallocate N blocks in advance. This way, we + # reduce the overhead of updating free_block_ids and ref_cnts for each + # request every step (at the cost of some memory waste). + # NOTE(woosuk): This is different from the "lookahead" slots since this + # does not guarantee that the request always has N empty blocks. After + # the request gets N empty blocks, it starts to use the blocks without + # further allocation. When it uses up all the N empty blocks, it gets + # N new empty blocks. + # NOTE(Chen): For simplicity, we keep the number of preallocated blocks + # the same for all kv cache groups, which will result in different + # preallocated tokens for different groups if their block sizes are + # different. + self.num_preallocate_tokens = num_preallocate_tokens + self.num_preallocate_blocks = cdiv( + num_preallocate_tokens, + max(g.kv_cache_spec.block_size for g in kv_cache_config.groups)) + + self._null_block: KVCacheBlock = KVCacheBlock(-1) + + # Specialized managers for each kv cache group, which handle the + # different kv cache management logic of different attention layers. + self.managers = get_managers( + kv_cache_config, + BlockPoolOperations(get_cached_block=self._get_cached_block, + get_null_block=self.get_null_block), + ) + self.num_kv_cache_groups = len(self.kv_cache_config.groups) + + # A Block pool of all kv-cache blocks. + self.block_pool: List[KVCacheBlock] = [ + KVCacheBlock(idx) for idx in range(self.num_gpu_blocks) + ] + # Free block queue that constructs and manipulates a doubly linked + # list of free blocks (including eviction candidates when caching is + # enabled). + self.free_block_queue = FreeKVCacheBlockQueue(self.block_pool) + + # {block_hash: {block ID: block}}. A cached block is + # a full block with a block hash that can be used for prefix caching. + # The cached block may be used by running requests or in the + # free_block_queue that could potentially be evicted. + # NOTE: We currently don't de-duplicate the blocks in the cache, + # meaning that if a block becomes full and is cached, we don't check + # if there is already an identical block in the cache. This is because + # we want to make sure the allocated block IDs won't change so that + # block tables are append-only. + self.cached_block_hash_to_block: Dict[BlockHashType, Dict[ + int, KVCacheBlock]] = defaultdict(dict) + + # Mapping from request ID to blocks to track the blocks allocated + # for each request, so that we can free the blocks when the request + # is finished. + self.req_to_blocks: DefaultDict[str, ReqKVCacheBlocks] = defaultdict( + lambda: [[] for _ in range(self.num_kv_cache_groups)]) + + # Mapping from request ID to kv block hashes. + # This is to avoid recomputing the block hashes for each call of + # `get_computed_blocks` or `allocate_slots`. 
+ self.req_to_block_hashes: DefaultDict[ + str, List[List[BlockHashType]]] = defaultdict( + lambda: [[] for _ in range(self.num_kv_cache_groups)]) + + @property + def usage(self) -> float: + return 1.0 - (self.free_block_queue.num_free_blocks / + self.num_gpu_blocks) + + def get_computed_blocks(self, + request: Request) -> Tuple[ReqKVCacheBlocks, int]: + """Get the computed (cached) blocks for the request. + Note that the computed blocks must be full. + + Args: + request: The request to get the computed blocks. + + Returns: + A tuple containing: + - The blocks that are computed for the request + - The number of computed tokens. + """ + if not self.enable_caching: + # Prefix caching is disabled. + return [[] for _ in self.managers], 0 + + # The block hashes for the request may already be computed + # if the scheduler has tried to schedule the request before. + block_hashes = self.req_to_block_hashes[request.request_id] + if not block_hashes: + block_hashes = [ + hash_request_tokens(manager.block_size, request, i) + for i, manager in enumerate(self.managers) + ] + self.req_to_block_hashes[request.request_id] = block_hashes + + computed_blocks: ReqKVCacheBlocks = [] # computed blocks of each group + prefix_length: List[PrefixLength] = [ + ] # possible cached prefix length of each group + + for i, manager in enumerate(self.managers): + prefix_length_i, computed_blocks_i = ( + manager.get_possible_cached_prefix(block_hashes[i])) + computed_blocks.append(computed_blocks_i) + prefix_length.append(prefix_length_i) + + if len(self.kv_cache_config.groups) == 1: + # If there is only one group, we return the computed blocks and + # tokens directly. + num_computed_tokens = prefix_length[0][-1].end + else: + # Find the common cached prefix of all groups. This path also works + # for the single group case, but it is less efficient. + num_computed_tokens = self._get_common_computed_tokens( + prefix_length) + + # Truncate the computed blocks to the number of computed tokens. + # E.g., group 0 has 3 computed blocks, and group 1 has 4 computed + # blocks with the same block size, we truncate both groups to 3 blocks. + for i, manager in enumerate(self.managers): + computed_blocks[i] = computed_blocks[i][:num_computed_tokens // + manager.block_size] + return computed_blocks, num_computed_tokens + + def allocate_slots( + self, + request: Request, + num_tokens: int, + new_computed_blocks: Optional[ReqKVCacheBlocks] = None, + num_new_computed_tokens: int = 0, + ) -> Optional[ReqKVCacheBlocks]: + """Add slots for a request with new tokens to append. + + Args: + request: The request to allocate slots. + num_tokens: The number of tokens to allocate. Note that this does + not include the tokens that have already been computed. + new_computed_blocks_all_groups: A list of new computed blocks + just hitting the prefix caching. + + Blocks layout: + ----------------------------------------------------------------------- + | < computed > | < new computed > | < new > | < pre-allocated > | + ----------------------------------------------------------------------- + | < required > | + -------------------------------------------------- + | < full > | + ------------------------------------------------ + | | + -------------- + The following *_blocks are illustrated in this layout. + + Returns: + A list of new allocated blocks. 
+ """ + if num_tokens == 0: + raise ValueError("num_tokens must be greater than 0") + + req_blocks = self.req_to_blocks[request.request_id] + # We can free blocks that are no longer needed even if we cannot + # schedule this request due to the limit of free blocks. + # Should call this function before allocating new blocks to reduce + # the number of evicted blocks. + self._free_useless_blocks(req_blocks, request.num_computed_tokens) + + new_computed_blocks = new_computed_blocks if new_computed_blocks is not None else [ + [] for _ in range(self.num_kv_cache_groups) + ] + + # The number of computed tokens is the number of computed tokens plus + # the new prefix caching hits + num_computed_tokens = (request.num_computed_tokens + + num_new_computed_tokens) + + num_new_blocks = [ + manager.get_num_new_blocks( + num_computed_tokens, num_tokens, + len(req_blocks[i]) + len(new_computed_blocks[i])) + for i, manager in enumerate(self.managers) + ] + + total_new_blocks = sum(max(x, 0) for x in num_new_blocks) + + # If a computed block of a request is an eviction candidate (in the + # free queue and ref_cnt == 0), it cannot be counted as a free block + # when allocating this request. + num_evictable_computed_blocks = sum( + 1 for blk_group in new_computed_blocks for blk in blk_group + if blk.ref_cnt == 0) + + if (total_new_blocks > self.free_block_queue.num_free_blocks - + num_evictable_computed_blocks): + # Cannot allocate new blocks. + return None + + # Touch the computed blocks to make sure they won't be evicted. + if self.enable_caching: + self._touch(new_computed_blocks) + else: + assert all(len(blks) == 0 for blks in new_computed_blocks), ( + "Computed blocks should be empty when " + "prefix caching is disabled") + + # Append the new computed blocks to the request blocks until now to + # avoid the case where the new blocks cannot be allocated. + for i, new_computed_blocks_of_group in enumerate(new_computed_blocks): + req_blocks[i].extend(new_computed_blocks_of_group) + + # Start to handle new blocks + new_blocks: ReqKVCacheBlocks = [] + + # Truncate the number of pre-allocated blocks to ensure that we can + # have at least `num_new_blocks` free blocks for each group. + num_preallocate_blocks = min( + self.num_preallocate_blocks, + (self.free_block_queue.num_free_blocks - total_new_blocks) // + len(self.managers)) + + for i in range(self.num_kv_cache_groups): + if num_new_blocks[i] <= 0: + # No new block is needed. + new_blocks.append([]) + else: + # Get new blocks from the free block pool considering + # preallocated blocks. + num_block_to_allocate = min( + num_new_blocks[i] + num_preallocate_blocks, + # Should not exceed the maximum number of blocks per request + # This is especially because the block table has the shape + # [..., max_num_blocks_per_req]. + # TODO(woosuk): Check and reject requests if + # num_prompt_tokens + max_tokens > max_model_len. + self.max_num_blocks_per_req[i] - len(req_blocks[i]), + ) + + assert num_block_to_allocate >= 0 + assert num_block_to_allocate <= \ + self.free_block_queue.num_free_blocks + + new_blocks_of_group = self._get_new_blocks( + num_block_to_allocate) + new_blocks.append(new_blocks_of_group) + req_blocks[i].extend(new_blocks_of_group) + + if not self.enable_caching: + return new_blocks + + for i, manager in enumerate(self.managers): + # NOTE(rickyx): We are assuming the `num_tokens` are actual + # tokens rather than lookahead slots (e.g. for speculative decoding). 
+ # TODO(rickyx): When supporting speculative decoding, we will need to + # differentiate between them so that we can know how many blocks are + # full after appending the actual tokens. + num_full_blocks = (num_computed_tokens + + num_tokens) // manager.block_size + num_computed_full_blocks = num_computed_tokens // manager.block_size + + new_full_blocks = req_blocks[i][ + num_computed_full_blocks:num_full_blocks] + if new_full_blocks: + self._cache_full_blocks( + request=request, + blk_start_idx=num_computed_full_blocks, + # The new full blocks are the full blocks that are not + # computed. + full_blocks=new_full_blocks, + prev_block=(req_blocks[i][num_computed_full_blocks - 1] + if num_computed_full_blocks > 0 else None), + kv_cache_group_id=i, + ) + + return new_blocks + + def _merge_blocks_by_eviction_order( + self, blocks: ReqKVCacheBlocks) -> List[KVCacheBlock]: + """ + Merge the blocks of different groups to one list. The returned blocks + are sorted by eviction order, with the first block having the highest + eviction priority. + + Args: + blocks: the blocks of each kv cache group, ordered by eviction + priority. + + Returns: + A list of KVCacheBlocks sorted by eviction order. + """ + + if self.enable_caching: + # NOTE (Chen): A simple strategy that interleaves the blocks of + # different KV cache groups. We can investigate more advanced + # strategies in the future. + ordered_blocks = [] + max_len = max(len(blocks_of_group) for blocks_of_group in blocks) + for i in range(max_len): + for blocks_of_group in blocks: + if i < len(blocks_of_group): + ordered_blocks.append(blocks_of_group[i]) + else: + ordered_blocks = [] + for blocks_of_group in blocks: + ordered_blocks.extend(blocks_of_group) + + return ordered_blocks + + def _free_blocks(self, blocks: ReqKVCacheBlocks) -> None: + if len(self.kv_cache_config.groups) == 1: + # Fast path for single kv cache group models. + ordered_blocks = blocks[0] + else: + ordered_blocks = self._merge_blocks_by_eviction_order(blocks) + for block in ordered_blocks: + if block == self._null_block: + continue + block.decr_ref() + if block.ref_cnt == 0: + self.free_block_queue.append(block) + + def free(self, request: Request) -> None: + """Free the blocks allocated for the request. + When caching is enabled, we free the blocks in reverse order so that + the tail blocks are evicted first. + + Args: + request: The request to free the blocks. + """ + # Default to [] in case a request is freed (aborted) before alloc. + blocks = self.req_to_blocks.pop(request.request_id, []) + if len(blocks) == 0: + # This request is freed before alloc. just return + return + else: + # Reverse the blocks so that the tail blocks can have higher + # eviction priority. + self._free_blocks([list(reversed(blks)) for blks in blocks]) + + def reset_prefix_cache(self) -> bool: + """Reset prefix cache. This function may be used in RLHF + flows to invalid prefix caching after the weights are updated, + or used for resetting prefix caching status for benchmarking. + + Returns: + bool: True if the prefix cache is successfully reset, + False otherwise. + """ + num_used_blocks = (self.num_gpu_blocks - + self.free_block_queue.num_free_blocks) + if num_used_blocks > 0: + logger.warning( + "Failed to reset prefix cache because some " + "blocks (%d) are not freed yet", num_used_blocks) + return False + + # Remove all hashes so that no new blocks will hit. + self.cached_block_hash_to_block = defaultdict(dict) + + # Remove all hashes from all blocks. 
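The interleaving loop in _merge_blocks_by_eviction_order above is equivalent to a round-robin merge of the per-group block lists; a compact standalone version using itertools (purely illustrative):

from itertools import chain, zip_longest

def interleave(blocks_per_group):
    # Round-robin merge: block i of every group is ordered before
    # block i+1 of any group.
    sentinel = object()
    merged = chain.from_iterable(
        zip_longest(*blocks_per_group, fillvalue=sentinel))
    return [blk for blk in merged if blk is not sentinel]

print(interleave([["a0", "a1", "a2"], ["b0"]]))  # -> ['a0', 'b0', 'a1', 'a2']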
+ for block in self.block_pool: + block.reset_hash() + + logger.info("Successfully reset prefix cache") + return True + + def get_num_common_prefix_blocks( + self, + request: Request, + num_running_requests: int, + ) -> List[int]: + """Calculate the number of common prefix blocks shared by all requests + in the RUNNING state. + + The function determines this by selecting any request and iterating + through its blocks. A block is considered a common prefix block if its + `ref_cnt` equals the total number of requests in the RUNNING state. + + NOTE(woosuk): The number of requests in the RUNNING state is **greater + than or equal to** the number of requests scheduled in the current step. + This is because the RUNNING state only indicates that: + 1. The request has not yet finished, and + 2. The request holds its blocks unfreed. + + While all scheduled requests must be in the RUNNING state, the inverse + is not necessarily true. There may be RUNNING requests that are not + scheduled in the current step. As of 1/1/2025, the scheduler does not + allow this case, but it is possible in the future, as we allow more + flexible scheduling. + + This can result in an edge case where the number of common prefix blocks + is 0, even though all scheduled requests share a common prefix. This + occurs because there may be unscheduled RUNNING requests that do not + share the common prefix. Currently, this case cannot be easily detected, + so the function returns 0 in such cases. + + Args: + request: Any request in the RUNNING state, used to identify the + common prefix blocks. + num_running_requests: The total number of requests in the RUNNING + state. This can be different from the number of scheduled + requests in the current step. + + Returns: + List[int]: The number of common prefix blocks per KV cache group. + """ + assert request.status == RequestStatus.RUNNING + blocks = self.req_to_blocks[request.request_id] + num_common_blocks_per_group = [] + for blocks_of_group in blocks: + num_common_blocks = 0 + for block in blocks_of_group: + if block.ref_cnt == num_running_requests: + num_common_blocks += 1 + else: + break + num_common_blocks_per_group.append(num_common_blocks) + return num_common_blocks_per_group + + def _get_new_blocks(self, num_blocks: int) -> List[KVCacheBlock]: + """Get new blocks from the free block pool. + + Note that we do not check block cache in this function. + + Args: + num_blocks: The number of blocks to allocate. + + Returns: + A list of new block. + """ + if num_blocks > self.free_block_queue.num_free_blocks: + raise ValueError( + f"Cannot get {num_blocks} free blocks from the pool") + + ret: List[KVCacheBlock] = [] + idx = 0 + while idx < num_blocks: + # First allocate blocks. + curr_block = self.free_block_queue.popleft() + assert curr_block.ref_cnt == 0 + + # If the block is cached, evict it. + if self.enable_caching: + self._maybe_evict_cached_block(curr_block) + + curr_block.incr_ref() + ret.append(curr_block) + idx += 1 + + return ret + + def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool: + """ + If a block is cached in `cached_block_hash_to_block`, we reset its hash + metadata and evict it from the cache. + + Args: + block: The block to evict. + + Returns: + True if the block is evicted, False otherwise. 
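get_num_common_prefix_blocks stops at the first block whose ref_cnt differs from the number of RUNNING requests. A small sketch with plain integers standing in for the per-block ref_cnt values (not the patch's data structures):

def num_common_prefix_blocks(ref_cnts, num_running_requests):
    count = 0
    for ref_cnt in ref_cnts:
        if ref_cnt != num_running_requests:
            break
        count += 1
    return count

# Three running requests; only the first two blocks are shared by all of them.
print(num_common_prefix_blocks([3, 3, 2, 1], num_running_requests=3))  # -> 2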
+ """ + block_hash = block.block_hash + if block_hash and block_hash in self.cached_block_hash_to_block: + block.reset_hash() + del self.cached_block_hash_to_block[block_hash][block.block_id] + + if len(self.cached_block_hash_to_block[block_hash]) == 0: + del self.cached_block_hash_to_block[block_hash] + + return True + return False + + def _get_cached_block(self, + block_hash: BlockHashType) -> Optional[KVCacheBlock]: + """Get a cached block by the block hash, or None if cache miss. + If there are duplicated blocks, we return the first block in the cache. + + Args: + block_hash: The hash value of the block. + + Returns: + The cached block if it exists, or None. + """ + if block_hash in self.cached_block_hash_to_block: + first_block_id = list( + self.cached_block_hash_to_block[block_hash].keys())[0] + return self.cached_block_hash_to_block[block_hash][first_block_id] + return None + + def _touch(self, blocks: ReqKVCacheBlocks) -> None: + """Touch a block increases its reference count by 1, and may remove + the block from the free queue. This is used when a block is hit by + another request with the same prefix. + + Args: + blocks: A list of blocks to touch. + """ + for blocks_of_group in blocks: + for block in blocks_of_group: + # ref_cnt=0 means this block is in the free list (i.e. eviction + # candidate), so remove it. + if block.ref_cnt == 0 and block != self._null_block: + self.free_block_queue.remove(block) + block.incr_ref() + + def _cache_full_blocks( + self, + request: Request, + blk_start_idx: int, + full_blocks: List[KVCacheBlock], + prev_block: Optional[KVCacheBlock], + kv_cache_group_id: int, + ) -> None: + """Cache a list of full blocks for prefix caching. + + This function takes a list of blocks that will have their block hash + metadata to be updated and cached. Given a request, it computes the + block hashes for the blocks starting from `blk_start_idx` to the end + of the request's full blocks, updating the metadata for each block + and caching them in the `cached_block_hash_to_block`. + + Args: + request: The request to cache the blocks. + blk_start_idx: The index of the first block in the request's blocks + to cache. + full_blocks: The list of blocks to update hash metadata. + prev_block: The previous block in the chain. + kv_cache_group_id: The KV cache group that the blocks belong to + """ + block_hashes = self.req_to_block_hashes[request.request_id] + num_cached_block_hashes = len(block_hashes[kv_cache_group_id]) + + # Update the new blocks with the block hashes through the chain. + prev_block_hash_value = None + if prev_block is not None: + # Previous block must have a block hash because it must be + # a full, cached block. + assert prev_block.block_hash is not None + prev_block_hash_value = prev_block.block_hash.hash_value + + block_size = self.kv_cache_config.groups[ + kv_cache_group_id].kv_cache_spec.block_size + # Find the first uncached block. This case should only happen when + # speculative decoding is used. + offset = 0 + for blk in full_blocks: + if blk.block_hash is None: + break + else: + prev_block_hash_value = blk.block_hash.hash_value + offset += 1 + else: + # All blocks are cached. 
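The block hashes used by _cache_full_blocks form a chain: each full block's hash folds in the previous block's hash, so a cache hit on block i implies that all earlier blocks match as well. A simplified stand-in for that chaining (vLLM's real hash_block_tokens also mixes in the KV cache group id and multi-modal extra keys, omitted here):

def chained_block_hashes(token_ids, block_size):
    hashes = []
    prev_hash = None
    num_full_tokens = len(token_ids) - len(token_ids) % block_size
    for start in range(0, num_full_tokens, block_size):
        block_tokens = tuple(token_ids[start:start + block_size])
        prev_hash = hash((prev_hash, block_tokens))
        hashes.append(prev_hash)
    return hashes

# 10 tokens with block size 4: only the 2 full blocks are hashed; the trailing
# partial block is never eligible for prefix caching.
print(len(chained_block_hashes(list(range(10)), block_size=4)))  # -> 2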
+ return + + for i, blk in enumerate(full_blocks[offset:]): + blk_idx = blk_start_idx + offset + i + assert blk.block_hash is None + + if blk_idx < num_cached_block_hashes: + # The block hash may already be computed in + # "get_computed_blocks" if the tokens are not generated by + # this request (either the prompt tokens or the previously + # generated tokens with preemption). In this case we simply + # reuse the block hash. + block_hash = block_hashes[kv_cache_group_id][blk_idx] + else: + # Otherwise compute the block hash and cache it in the request + # in case it will be preempted in the future. + start_token_idx = blk_idx * block_size + end_token_idx = (blk_idx + 1) * block_size + block_tokens = request.all_token_ids[ + start_token_idx:end_token_idx] + assert len(block_tokens) == block_size, ( + f"Expected {block_size} tokens, got " + f"{len(block_tokens)} at {blk_idx}th block for request " + f"{request.request_id}({request})") + + # Generate extra keys for multi-modal inputs. Note that since + # we reach to this branch only when the block is completed with + # generated tokens, we only need to consider the last mm input. + extra_keys, _ = generate_block_hash_extra_keys( + request, start_token_idx, end_token_idx, -1) + + # Compute the hash of the current block. + block_hash = hash_block_tokens(prev_block_hash_value, + block_tokens, kv_cache_group_id, + extra_keys) + block_hashes[kv_cache_group_id].append(block_hash) + + # Update and added the full block to the cache. + blk.block_hash = block_hash + self.cached_block_hash_to_block[block_hash][blk.block_id] = blk + prev_block_hash_value = block_hash.hash_value + + def free_block_hashes(self, request: Request) -> None: + """Discard the block hashes for the request. + + NOTE: Unlike `free`, this method should be called only when the request + is finished, not when it is preempted. + """ + self.req_to_block_hashes.pop(request.request_id, None) + + def get_null_block(self) -> KVCacheBlock: + return self._null_block + + def _get_common_computed_tokens(self, + prefix_length: List[PrefixLength]) -> int: + """ + Find the longest prefix that is cached by all KV cache groups. Returns + the number of tokens in that prefix. + + Args: + prefix_length (List[PrefixLength]): The valid cached prefix lengths + of each KV cache group. + + Returns: + The number of tokens in the common prefix. + """ + intersection = intersect_ranges(prefix_length) + + # Since incomplete blocks are not eligible for sharing, + # `num_computed_tokens` should be a multiple of `block_size` of + # all managers, so we take the least common multiple (LCM) of them + alignment = math.lcm( + *[manager.block_size for manager in self.managers]) + + # Get the longest common prefix that is aligned with the block size. + num_computed_tokens = 0 + for range_ in intersection: + aligned_end = cdiv(range_.end, alignment) * alignment + if aligned_end >= range_.start: + num_computed_tokens = aligned_end + break + + return num_computed_tokens + + def _free_useless_blocks(self, req_blocks: ReqKVCacheBlocks, + num_computed_tokens: int) -> None: + """ + Frees memory blocks that are not needed. E.g., sliding window + layer with window size 2 and block size 1, we have req_blocks as + [[1, 2, 3]], this function will free block 1 and change the req_blocks + to [[-1, 2, 3]] (-1 refers to null block) + + Args: + req_blocks: The KV cache blocks of one request. + num_computed_tokens: The number of computed tokens. 
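Because incomplete blocks are never shared, the common prefix picked by _get_common_computed_tokens has to be a multiple of every group's block size, hence the LCM alignment. A short numeric illustration (hypothetical block sizes; the sketch simply floors to the alignment and ignores the per-group cached ranges the real function intersects):

import math

block_sizes = [16, 8]                  # hypothetical per-group block sizes
alignment = math.lcm(*block_sizes)
print(alignment)                       # -> 16
# A 44-token cacheable prefix can only be reused up to 32 tokens: 2 full
# blocks of 16 and 4 full blocks of 8; the last 12 tokens sit in partial blocks.
print(44 // alignment * alignment)     # -> 32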
+ """ + removed_blocks = [] + for manager, req_blocks_of_group in zip(self.managers, req_blocks): + removed_blocks.append( + manager.remove_useless_blocks(req_blocks_of_group, + num_computed_tokens)) + self._free_blocks(removed_blocks) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 68dad14011620..88020bb26b288 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -5,6 +5,7 @@ from vllm.logger import init_logger from vllm.utils import cdiv +from vllm.v1.core.hybrid_kv_cache_manager import HybridKVCacheManager from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue, KVCacheBlock, generate_block_hash_extra_keys, @@ -128,7 +129,8 @@ def allocate_slots( self, request: Request, num_tokens: int, - new_computed_blocks: Optional[List[KVCacheBlock]] = None + new_computed_blocks: Optional[List[KVCacheBlock]] = None, + num_new_computed_tokens: int = 0, ) -> Optional[List[KVCacheBlock]]: """Add slots for a request with new tokens to append. @@ -162,7 +164,7 @@ def allocate_slots( # The number of computed tokens is the number of computed tokens plus # the new prefix caching hits num_computed_tokens = (request.num_computed_tokens + - len(new_computed_blocks) * self.block_size) + num_new_computed_tokens) num_required_blocks = cdiv(num_computed_tokens + num_tokens, self.block_size) req_blocks = self.req_to_blocks[request.request_id] @@ -513,3 +515,26 @@ def free_block_hashes(self, request: Request) -> None: is finished, not when it is preempted. """ self.req_to_block_hashes.pop(request.request_id, None) + + +def init_kv_cache_manager(kv_cache_config: KVCacheConfig, + max_model_len: int, + enable_caching: bool = True, + num_preallocate_tokens: int = 64): + print("kv_cache_config", kv_cache_config) + if len(kv_cache_config.groups) > 1: + logger.info("Using HybridKVCacheManager") + return HybridKVCacheManager( + kv_cache_config=kv_cache_config, + max_model_len=max_model_len, + enable_caching=enable_caching, + num_preallocate_tokens=num_preallocate_tokens, + ) + else: + logger.info("Using KVCacheManager") + return KVCacheManager( + kv_cache_config=kv_cache_config, + max_model_len=max_model_len, + enable_caching=enable_caching, + num_preallocate_tokens=num_preallocate_tokens, + ) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 95fea07a7ac78..6bd793f7b1841 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -10,7 +10,7 @@ from vllm.sampling_params import SamplingParams from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, compute_encoder_budget) -from vllm.v1.core.kv_cache_manager import KVCacheManager +from vllm.v1.core.kv_cache_manager import KVCacheManager, init_kv_cache_manager from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs from vllm.v1.kv_cache_interface import (BlockIDGenerator, FullAttentionSpec, GroupedBlockIDs, KVCacheConfig) @@ -50,7 +50,7 @@ def __init__( num_gpu_blocks = cache_config.num_gpu_blocks assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0 # Create the KV cache manager. 
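With a single KV cache group the prefix-cache hit length could be recovered as len(new_computed_blocks) * block_size, but with several groups the same hit corresponds to a different block count per group, which is why allocate_slots now takes num_new_computed_tokens explicitly. A one-line illustration with hypothetical block sizes:

# The same 32-token prefix hit maps to different block counts per group,
# so a single block count no longer determines the token count.
block_sizes = [16, 8]
num_new_computed_tokens = 32
print([num_new_computed_tokens // b for b in block_sizes])  # -> [2, 4]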
- self.kv_cache_manager = KVCacheManager( + self.kv_cache_manager = init_kv_cache_manager( kv_cache_config=kv_cache_config, max_model_len=self.max_model_len, enable_caching=self.cache_config.enable_prefix_caching) @@ -233,7 +233,8 @@ def schedule(self) -> "SchedulerOutput": break new_blocks = self.kv_cache_manager.allocate_slots( - request, num_new_tokens, computed_blocks) + request, num_new_tokens, computed_blocks, + num_computed_tokens) if new_blocks is None: # The request cannot be scheduled. break From 3e5a8c353c376f82b9f4116f51bd03e11411e599 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 14 Feb 2025 06:18:31 -0800 Subject: [PATCH 43/48] support block pool Signed-off-by: Chen Zhang --- vllm/v1/core/block_pool.py | 254 +++++++++++++++++++++++ vllm/v1/core/hybrid_kv_cache_manager.py | 261 ++---------------------- vllm/v1/core/kv_cache_manager.py | 241 ++-------------------- vllm/v1/core/specialized_manager.py | 33 ++- 4 files changed, 297 insertions(+), 492 deletions(-) create mode 100644 vllm/v1/core/block_pool.py diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py new file mode 100644 index 0000000000000..2d73c2d3239e3 --- /dev/null +++ b/vllm/v1/core/block_pool.py @@ -0,0 +1,254 @@ +from collections import defaultdict +from typing import Dict, List, Optional + +from vllm.logger import init_logger +from vllm.v1.core.kv_cache_utils import BlockHashType, FreeKVCacheBlockQueue, KVCacheBlock, generate_block_hash_extra_keys, hash_block_tokens +from vllm.v1.request import Request + +logger = init_logger(__name__) + + +class BlockPool: + + def __init__(self, num_gpu_blocks: int, enable_caching: bool): + self.num_gpu_blocks = num_gpu_blocks + self.enable_caching = enable_caching + # A Block pool of all kv-cache blocks. + self._block_pool: List[KVCacheBlock] = [ + KVCacheBlock(idx) for idx in range(self.num_gpu_blocks) + ] + # Free block queue that constructs and manipulates a doubly linked + # list of free blocks (including eviction candidates when caching is + # enabled). + self._free_block_queue = FreeKVCacheBlockQueue(self._block_pool) + + # {block_hash: {block ID: block}}. A cached block is + # a full block with a block hash that can be used for prefix caching. + # The cached block may be used by running requests or in the + # free_block_queue that could potentially be evicted. + # NOTE: We currently don't de-duplicate the blocks in the cache, + # meaning that if a block becomes full and is cached, we don't check + # if there is already an identical block in the cache. This is because + # we want to make sure the allocated block IDs won't change so that + # block tables are append-only. + self._cached_block_hash_to_block: Dict[BlockHashType, Dict[ + int, KVCacheBlock]] = defaultdict(dict) + + self._null_block: KVCacheBlock = KVCacheBlock(-1) + + def get_cached_block(self, + block_hash: BlockHashType) -> Optional[KVCacheBlock]: + """Get a cached block by the block hash, or None if cache miss. + If there are duplicated blocks, we return the first block in the cache. + + Args: + block_hash: The hash value of the block. + + Returns: + The cached block if it exists, or None. 
+ """ + if block_hash in self._cached_block_hash_to_block: + first_block_id = list( + self._cached_block_hash_to_block[block_hash].keys())[0] + return self._cached_block_hash_to_block[block_hash][first_block_id] + return None + + def cache_full_blocks(self, + request: Request, + block_hashes: List[BlockHashType], + block_size: int, + blk_start_idx: int, + full_blocks: List[KVCacheBlock], + prev_block: Optional[KVCacheBlock], + kv_cache_group_id: int = 0) -> None: + """Cache a list of full blocks for prefix caching. + + This function takes a list of blocks that will have their block hash + metadata to be updated and cached. Given a request, it computes the + block hashes for the blocks starting from `blk_start_idx` to the end + of the request's full blocks, updating the metadata for each block + and caching them in the `cached_block_hash_to_block`. + + Args: + request: The request to cache the blocks. + blk_start_idx: The index of the first block in the request's blocks + to cache. + full_blocks: The list of blocks to update hash metadata. + prev_block: The previous block in the chain. + """ + num_cached_block_hashes = len(block_hashes) + + # Update the new blocks with the block hashes through the chain. + prev_block_hash_value = None + if prev_block is not None: + # Previous block must have a block hash because it must be + # a full, cached block. + assert prev_block.block_hash is not None + prev_block_hash_value = prev_block.block_hash.hash_value + + # Find the first uncached block. This case should only happen when + # speculative decoding is used. + offset = 0 + for blk in full_blocks: + if blk.block_hash is None: + break + else: + prev_block_hash_value = blk.block_hash.hash_value + offset += 1 + else: + # All blocks are cached. + return + + for i, blk in enumerate(full_blocks[offset:]): + blk_idx = blk_start_idx + offset + i + assert blk.block_hash is None + + if blk_idx < num_cached_block_hashes: + # The block hash may already be computed in + # "get_computed_blocks" if the tokens are not generated by + # this request (either the prompt tokens or the previously + # generated tokens with preemption). In this case we simply + # reuse the block hash. + block_hash = block_hashes[blk_idx] + else: + # Otherwise compute the block hash and cache it in the request + # in case it will be preempted in the future. + start_token_idx = blk_idx * block_size + end_token_idx = (blk_idx + 1) * block_size + block_tokens = request.all_token_ids[ + start_token_idx:end_token_idx] + assert len(block_tokens) == block_size, ( + f"Expected {block_size} tokens, got " + f"{len(block_tokens)} at {blk_idx}th block for request " + f"{request.request_id}({request})") + + # Generate extra keys for multi-modal inputs. Note that since + # we reach to this branch only when the block is completed with + # generated tokens, we only need to consider the last mm input. + extra_keys, _ = generate_block_hash_extra_keys( + request, start_token_idx, end_token_idx, -1) + + # Compute the hash of the current block. + block_hash = hash_block_tokens(prev_block_hash_value, + block_tokens, kv_cache_group_id, + extra_keys) + block_hashes.append(block_hash) + + # Update and added the full block to the cache. + blk.block_hash = block_hash + self._cached_block_hash_to_block[block_hash][blk.block_id] = blk + prev_block_hash_value = block_hash.hash_value + + def get_new_blocks(self, num_blocks: int) -> List[KVCacheBlock]: + """Get new blocks from the free block pool. + + Note that we do not check block cache in this function. 
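cached_block_hash_to_block in BlockPool is a two-level mapping (block hash -> {block id -> block}), and get_cached_block simply returns the first entry for a hash. A stripped-down sketch with plain dicts (illustrative only; the real values are KVCacheBlock objects):

from collections import defaultdict

cached = defaultdict(dict)        # {block_hash: {block_id: block}}
cached["hash_a"][7] = "block-7"   # two physical blocks happen to share a hash
cached["hash_a"][9] = "block-9"

def get_cached_block(block_hash):
    blocks_for_hash = cached.get(block_hash)
    if not blocks_for_hash:
        return None
    return next(iter(blocks_for_hash.values()))

print(get_cached_block("hash_a"))  # -> 'block-7'
print(get_cached_block("hash_b"))  # -> None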
+ + Args: + num_blocks: The number of blocks to allocate. + + Returns: + A list of new block. + """ + if num_blocks > self._free_block_queue.num_free_blocks: + raise ValueError( + f"Cannot get {num_blocks} free blocks from the pool") + + ret: List[KVCacheBlock] = [] + idx = 0 + while idx < num_blocks: + # First allocate blocks. + curr_block = self._free_block_queue.popleft() + assert curr_block.ref_cnt == 0 + + # If the block is cached, evict it. + if self.enable_caching: + self._maybe_evict_cached_block(curr_block) + + curr_block.incr_ref() + ret.append(curr_block) + idx += 1 + + return ret + + def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool: + """ + If a block is cached in `cached_block_hash_to_block`, we reset its hash + metadata and evict it from the cache. + + Args: + block: The block to evict. + + Returns: + True if the block is evicted, False otherwise. + """ + block_hash = block.block_hash + if block_hash and block_hash in self._cached_block_hash_to_block: + block.reset_hash() + del self._cached_block_hash_to_block[block_hash][block.block_id] + + if len(self._cached_block_hash_to_block[block_hash]) == 0: + del self._cached_block_hash_to_block[block_hash] + + return True + return False + + def touch(self, blocks: List[KVCacheBlock]) -> None: + """Touch a block increases its reference count by 1, and may remove + the block from the free queue. This is used when a block is hit by + another request with the same prefix. + + Args: + blocks: A list of blocks to touch. + """ + for block in blocks: + # ref_cnt=0 means this block is in the free list (i.e. eviction + # candidate), so remove it. + if block.ref_cnt == 0: + self._free_block_queue.remove(block) + block.incr_ref() + + def free_blocks(self, ordered_blocks: List[KVCacheBlock]) -> None: + """ + TODO: add docstring + the first block will be evicted first + """ + for block in ordered_blocks: + if block == self._null_block: + continue + block.decr_ref() + if block.ref_cnt == 0: + self._free_block_queue.append(block) + + def reset_prefix_cache(self) -> bool: + """Reset prefix cache. This function may be used in RLHF + flows to invalid prefix caching after the weights are updated, + or used for resetting prefix caching status for benchmarking. + + Returns: + bool: True if the prefix cache is successfully reset, + False otherwise. + """ + num_used_blocks = (self.num_gpu_blocks - + self._free_block_queue.num_free_blocks) + if num_used_blocks > 0: + logger.warning( + "Failed to reset prefix cache because some " + "blocks (%d) are not freed yet", num_used_blocks) + return False + + # Remove all hashes so that no new blocks will hit. + self.cached_block_hash_to_block = defaultdict(dict) + + # Remove all hashes from all blocks. 
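touch and free_blocks are symmetric reference-count operations around the free queue: touching a block with ref_cnt == 0 pulls it out of the eviction candidates, and freeing the last reference puts it back. A tiny model of that lifecycle with a deque standing in for FreeKVCacheBlockQueue (purely illustrative):

from collections import deque

class Block:
    def __init__(self, block_id):
        self.block_id = block_id
        self.ref_cnt = 0

free_queue = deque()

def touch(block):
    # ref_cnt == 0 means the block sits in the free queue as an eviction
    # candidate; reusing it removes it from the queue first.
    if block.ref_cnt == 0 and block in free_queue:
        free_queue.remove(block)
    block.ref_cnt += 1

def free(block):
    block.ref_cnt -= 1
    if block.ref_cnt == 0:
        free_queue.append(block)

blk = Block(0)
touch(blk); touch(blk)   # two requests share the block
free(blk)                # first request finishes: still referenced
print(len(free_queue))   # -> 0
free(blk)                # last reference dropped: block becomes evictable
print(len(free_queue))   # -> 1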
+ for block in self._block_pool: + block.reset_hash() + + logger.info("Successfully reset prefix cache") + return True + + def get_num_free_blocks(self) -> int: + return self._free_block_queue.num_free_blocks + + def get_null_block(self) -> KVCacheBlock: + return self._null_block diff --git a/vllm/v1/core/hybrid_kv_cache_manager.py b/vllm/v1/core/hybrid_kv_cache_manager.py index 614241fee4d52..24674fcd7e00a 100644 --- a/vllm/v1/core/hybrid_kv_cache_manager.py +++ b/vllm/v1/core/hybrid_kv_cache_manager.py @@ -6,13 +6,14 @@ from vllm.logger import init_logger from vllm.utils import cdiv +from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue, KVCacheBlock, PrefixLength, ReqKVCacheBlocks, generate_block_hash_extra_keys, hash_block_tokens, hash_request_tokens, intersect_ranges) -from vllm.v1.core.specialized_manager import BlockPoolOperations, get_managers +from vllm.v1.core.specialized_manager import get_managers from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import Request, RequestStatus @@ -59,39 +60,16 @@ def __init__( self.num_preallocate_blocks = cdiv( num_preallocate_tokens, max(g.kv_cache_spec.block_size for g in kv_cache_config.groups)) - - self._null_block: KVCacheBlock = KVCacheBlock(-1) + self.block_pool = BlockPool(self.num_gpu_blocks, self.enable_caching) # Specialized managers for each kv cache group, which handle the # different kv cache management logic of different attention layers. self.managers = get_managers( kv_cache_config, - BlockPoolOperations(get_cached_block=self._get_cached_block, - get_null_block=self.get_null_block), + block_pool=self.block_pool, ) self.num_kv_cache_groups = len(self.kv_cache_config.groups) - # A Block pool of all kv-cache blocks. - self.block_pool: List[KVCacheBlock] = [ - KVCacheBlock(idx) for idx in range(self.num_gpu_blocks) - ] - # Free block queue that constructs and manipulates a doubly linked - # list of free blocks (including eviction candidates when caching is - # enabled). - self.free_block_queue = FreeKVCacheBlockQueue(self.block_pool) - - # {block_hash: {block ID: block}}. A cached block is - # a full block with a block hash that can be used for prefix caching. - # The cached block may be used by running requests or in the - # free_block_queue that could potentially be evicted. - # NOTE: We currently don't de-duplicate the blocks in the cache, - # meaning that if a block becomes full and is cached, we don't check - # if there is already an identical block in the cache. This is because - # we want to make sure the allocated block IDs won't change so that - # block tables are append-only. - self.cached_block_hash_to_block: Dict[BlockHashType, Dict[ - int, KVCacheBlock]] = defaultdict(dict) - # Mapping from request ID to blocks to track the blocks allocated # for each request, so that we can free the blocks when the request # is finished. @@ -107,7 +85,7 @@ def __init__( @property def usage(self) -> float: - return 1.0 - (self.free_block_queue.num_free_blocks / + return 1.0 - (self.block_pool.get_num_free_blocks() / self.num_gpu_blocks) def get_computed_blocks(self, @@ -231,14 +209,15 @@ def allocate_slots( 1 for blk_group in new_computed_blocks for blk in blk_group if blk.ref_cnt == 0) - if (total_new_blocks > self.free_block_queue.num_free_blocks - + if (total_new_blocks > self.block_pool.get_num_free_blocks() - num_evictable_computed_blocks): # Cannot allocate new blocks. 
return None # Touch the computed blocks to make sure they won't be evicted. if self.enable_caching: - self._touch(new_computed_blocks) + for blocks in new_computed_blocks: + self.block_pool.touch(blocks) else: assert all(len(blks) == 0 for blks in new_computed_blocks), ( "Computed blocks should be empty when " @@ -256,7 +235,7 @@ def allocate_slots( # have at least `num_new_blocks` free blocks for each group. num_preallocate_blocks = min( self.num_preallocate_blocks, - (self.free_block_queue.num_free_blocks - total_new_blocks) // + (self.block_pool.get_num_free_blocks() - total_new_blocks) // len(self.managers)) for i in range(self.num_kv_cache_groups): @@ -278,9 +257,9 @@ def allocate_slots( assert num_block_to_allocate >= 0 assert num_block_to_allocate <= \ - self.free_block_queue.num_free_blocks + self.block_pool.get_num_free_blocks() - new_blocks_of_group = self._get_new_blocks( + new_blocks_of_group = self.block_pool.get_new_blocks( num_block_to_allocate) new_blocks.append(new_blocks_of_group) req_blocks[i].extend(new_blocks_of_group) @@ -301,8 +280,11 @@ def allocate_slots( new_full_blocks = req_blocks[i][ num_computed_full_blocks:num_full_blocks] if new_full_blocks: - self._cache_full_blocks( + block_hashes = self.req_to_block_hashes[request.request_id][i] + self.block_pool.cache_full_blocks( request=request, + block_hashes=block_hashes, + block_size=manager.block_size, blk_start_idx=num_computed_full_blocks, # The new full blocks are the full blocks that are not # computed. @@ -352,12 +334,7 @@ def _free_blocks(self, blocks: ReqKVCacheBlocks) -> None: ordered_blocks = blocks[0] else: ordered_blocks = self._merge_blocks_by_eviction_order(blocks) - for block in ordered_blocks: - if block == self._null_block: - continue - block.decr_ref() - if block.ref_cnt == 0: - self.free_block_queue.append(block) + self.block_pool.free_blocks(ordered_blocks) def free(self, request: Request) -> None: """Free the blocks allocated for the request. @@ -378,31 +355,7 @@ def free(self, request: Request) -> None: self._free_blocks([list(reversed(blks)) for blks in blocks]) def reset_prefix_cache(self) -> bool: - """Reset prefix cache. This function may be used in RLHF - flows to invalid prefix caching after the weights are updated, - or used for resetting prefix caching status for benchmarking. - - Returns: - bool: True if the prefix cache is successfully reset, - False otherwise. - """ - num_used_blocks = (self.num_gpu_blocks - - self.free_block_queue.num_free_blocks) - if num_used_blocks > 0: - logger.warning( - "Failed to reset prefix cache because some " - "blocks (%d) are not freed yet", num_used_blocks) - return False - - # Remove all hashes so that no new blocks will hit. - self.cached_block_hash_to_block = defaultdict(dict) - - # Remove all hashes from all blocks. - for block in self.block_pool: - block.reset_hash() - - logger.info("Successfully reset prefix cache") - return True + return self.block_pool.reset_prefix_cache() def get_num_common_prefix_blocks( self, @@ -457,183 +410,6 @@ def get_num_common_prefix_blocks( num_common_blocks_per_group.append(num_common_blocks) return num_common_blocks_per_group - def _get_new_blocks(self, num_blocks: int) -> List[KVCacheBlock]: - """Get new blocks from the free block pool. - - Note that we do not check block cache in this function. - - Args: - num_blocks: The number of blocks to allocate. - - Returns: - A list of new block. 
- """ - if num_blocks > self.free_block_queue.num_free_blocks: - raise ValueError( - f"Cannot get {num_blocks} free blocks from the pool") - - ret: List[KVCacheBlock] = [] - idx = 0 - while idx < num_blocks: - # First allocate blocks. - curr_block = self.free_block_queue.popleft() - assert curr_block.ref_cnt == 0 - - # If the block is cached, evict it. - if self.enable_caching: - self._maybe_evict_cached_block(curr_block) - - curr_block.incr_ref() - ret.append(curr_block) - idx += 1 - - return ret - - def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool: - """ - If a block is cached in `cached_block_hash_to_block`, we reset its hash - metadata and evict it from the cache. - - Args: - block: The block to evict. - - Returns: - True if the block is evicted, False otherwise. - """ - block_hash = block.block_hash - if block_hash and block_hash in self.cached_block_hash_to_block: - block.reset_hash() - del self.cached_block_hash_to_block[block_hash][block.block_id] - - if len(self.cached_block_hash_to_block[block_hash]) == 0: - del self.cached_block_hash_to_block[block_hash] - - return True - return False - - def _get_cached_block(self, - block_hash: BlockHashType) -> Optional[KVCacheBlock]: - """Get a cached block by the block hash, or None if cache miss. - If there are duplicated blocks, we return the first block in the cache. - - Args: - block_hash: The hash value of the block. - - Returns: - The cached block if it exists, or None. - """ - if block_hash in self.cached_block_hash_to_block: - first_block_id = list( - self.cached_block_hash_to_block[block_hash].keys())[0] - return self.cached_block_hash_to_block[block_hash][first_block_id] - return None - - def _touch(self, blocks: ReqKVCacheBlocks) -> None: - """Touch a block increases its reference count by 1, and may remove - the block from the free queue. This is used when a block is hit by - another request with the same prefix. - - Args: - blocks: A list of blocks to touch. - """ - for blocks_of_group in blocks: - for block in blocks_of_group: - # ref_cnt=0 means this block is in the free list (i.e. eviction - # candidate), so remove it. - if block.ref_cnt == 0 and block != self._null_block: - self.free_block_queue.remove(block) - block.incr_ref() - - def _cache_full_blocks( - self, - request: Request, - blk_start_idx: int, - full_blocks: List[KVCacheBlock], - prev_block: Optional[KVCacheBlock], - kv_cache_group_id: int, - ) -> None: - """Cache a list of full blocks for prefix caching. - - This function takes a list of blocks that will have their block hash - metadata to be updated and cached. Given a request, it computes the - block hashes for the blocks starting from `blk_start_idx` to the end - of the request's full blocks, updating the metadata for each block - and caching them in the `cached_block_hash_to_block`. - - Args: - request: The request to cache the blocks. - blk_start_idx: The index of the first block in the request's blocks - to cache. - full_blocks: The list of blocks to update hash metadata. - prev_block: The previous block in the chain. - kv_cache_group_id: The KV cache group that the blocks belong to - """ - block_hashes = self.req_to_block_hashes[request.request_id] - num_cached_block_hashes = len(block_hashes[kv_cache_group_id]) - - # Update the new blocks with the block hashes through the chain. - prev_block_hash_value = None - if prev_block is not None: - # Previous block must have a block hash because it must be - # a full, cached block. 
- assert prev_block.block_hash is not None - prev_block_hash_value = prev_block.block_hash.hash_value - - block_size = self.kv_cache_config.groups[ - kv_cache_group_id].kv_cache_spec.block_size - # Find the first uncached block. This case should only happen when - # speculative decoding is used. - offset = 0 - for blk in full_blocks: - if blk.block_hash is None: - break - else: - prev_block_hash_value = blk.block_hash.hash_value - offset += 1 - else: - # All blocks are cached. - return - - for i, blk in enumerate(full_blocks[offset:]): - blk_idx = blk_start_idx + offset + i - assert blk.block_hash is None - - if blk_idx < num_cached_block_hashes: - # The block hash may already be computed in - # "get_computed_blocks" if the tokens are not generated by - # this request (either the prompt tokens or the previously - # generated tokens with preemption). In this case we simply - # reuse the block hash. - block_hash = block_hashes[kv_cache_group_id][blk_idx] - else: - # Otherwise compute the block hash and cache it in the request - # in case it will be preempted in the future. - start_token_idx = blk_idx * block_size - end_token_idx = (blk_idx + 1) * block_size - block_tokens = request.all_token_ids[ - start_token_idx:end_token_idx] - assert len(block_tokens) == block_size, ( - f"Expected {block_size} tokens, got " - f"{len(block_tokens)} at {blk_idx}th block for request " - f"{request.request_id}({request})") - - # Generate extra keys for multi-modal inputs. Note that since - # we reach to this branch only when the block is completed with - # generated tokens, we only need to consider the last mm input. - extra_keys, _ = generate_block_hash_extra_keys( - request, start_token_idx, end_token_idx, -1) - - # Compute the hash of the current block. - block_hash = hash_block_tokens(prev_block_hash_value, - block_tokens, kv_cache_group_id, - extra_keys) - block_hashes[kv_cache_group_id].append(block_hash) - - # Update and added the full block to the cache. - blk.block_hash = block_hash - self.cached_block_hash_to_block[block_hash][blk.block_id] = blk - prev_block_hash_value = block_hash.hash_value - def free_block_hashes(self, request: Request) -> None: """Discard the block hashes for the request. @@ -642,9 +418,6 @@ def free_block_hashes(self, request: Request) -> None: """ self.req_to_block_hashes.pop(request.request_id, None) - def get_null_block(self) -> KVCacheBlock: - return self._null_block - def _get_common_computed_tokens(self, prefix_length: List[PrefixLength]) -> int: """ diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 88020bb26b288..1b023b262e8c6 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -5,6 +5,7 @@ from vllm.logger import init_logger from vllm.utils import cdiv +from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.hybrid_kv_cache_manager import HybridKVCacheManager from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue, KVCacheBlock, @@ -46,26 +47,7 @@ def __init__( self.num_preallocate_blocks = cdiv(num_preallocate_tokens, self.block_size) - # A Block pool of all kv-cache blocks. - self.block_pool: List[KVCacheBlock] = [ - KVCacheBlock(idx) for idx in range(self.num_gpu_blocks) - ] - # Free block queue that constructs and manipulates a doubly linked - # list of free blocks (including eviction candidates when caching is - # enabled). - self.free_block_queue = FreeKVCacheBlockQueue(self.block_pool) - - # {block_hash: {block ID: block}}. 
A cached block is - # a full block with a block hash that can be used for prefix caching. - # The cached block may be used by running requests or in the - # free_block_queue that could potentially be evicted. - # NOTE: We currently don't de-duplicate the blocks in the cache, - # meaning that if a block becomes full and is cached, we don't check - # if there is already an identical block in the cache. This is because - # we want to make sure the allocated block IDs won't change so that - # block tables are append-only. - self.cached_block_hash_to_block: Dict[BlockHashType, Dict[ - int, KVCacheBlock]] = defaultdict(dict) + self.block_pool = BlockPool(self.num_gpu_blocks, self.enable_caching) # Mapping from request ID to blocks to track the blocks allocated # for each request, so that we can free the blocks when the request @@ -81,7 +63,7 @@ def __init__( @property def usage(self) -> float: - return 1.0 - (self.free_block_queue.num_free_blocks / + return 1.0 - (self.block_pool.get_num_free_blocks() / self.num_gpu_blocks) def get_computed_blocks( @@ -114,7 +96,7 @@ def get_computed_blocks( # block_hashes is a chain of block hashes. If a block hash is not # in the cached_block_hash_to_id, the following block hashes are # not computed yet for sure. - if cached_block := self._get_cached_block(block_hash): + if cached_block := self.block_pool.get_cached_block(block_hash): computed_blocks.append(cached_block) else: break @@ -176,14 +158,14 @@ def allocate_slots( # when allocating this request. num_evictable_computed_blocks = sum(1 for blk in new_computed_blocks if blk.ref_cnt == 0) - if (num_new_blocks > self.free_block_queue.num_free_blocks - + if (num_new_blocks > self.block_pool.get_num_free_blocks() - num_evictable_computed_blocks): # Cannot allocate new blocks return None # Touch the computed blocks to make sure they won't be evicted. if self.enable_caching: - self._touch(new_computed_blocks) + self.block_pool.touch(new_computed_blocks) else: assert not new_computed_blocks, ( "Computed blocks should be empty when " @@ -203,7 +185,7 @@ def allocate_slots( # preallocated blocks. num_new_blocks = min( num_new_blocks + self.num_preallocate_blocks, - self.free_block_queue.num_free_blocks, + self.block_pool.get_num_free_blocks(), # Should not exceed the maximum number of blocks per request. # This is especially because the block table has the shape # [..., max_num_blocks_per_req]. @@ -214,7 +196,7 @@ def allocate_slots( assert num_new_blocks > 0 # Concatenate the computed block IDs and the new block IDs. - new_blocks = self._get_new_blocks(num_new_blocks) + new_blocks = self.block_pool.get_new_blocks(num_new_blocks) req_blocks.extend(new_blocks) if not self.enable_caching: @@ -229,8 +211,11 @@ def allocate_slots( num_computed_full_blocks = num_computed_tokens // self.block_size new_full_blocks = req_blocks[num_computed_full_blocks:num_full_blocks] if new_full_blocks: - self._cache_full_blocks( + block_hashes = self.req_to_block_hashes[request.request_id] + self.block_pool.cache_full_blocks( request=request, + block_hashes=block_hashes, + block_size=self.block_size, blk_start_idx=num_computed_full_blocks, # The new full blocks are the full blocks that are not computed. full_blocks=new_full_blocks, @@ -255,37 +240,10 @@ def free(self, request: Request) -> None: # freed first. 
ordered_blocks = reversed(blocks) - for block in ordered_blocks: - block.decr_ref() - if block.ref_cnt == 0: - self.free_block_queue.append(block) + self.block_pool.free_blocks(ordered_blocks) def reset_prefix_cache(self) -> bool: - """Reset prefix cache. This function may be used in RLHF - flows to invalid prefix caching after the weights are updated, - or used for resetting prefix caching status for benchmarking. - - Returns: - bool: True if the prefix cache is successfully reset, - False otherwise. - """ - num_used_blocks = (self.num_gpu_blocks - - self.free_block_queue.num_free_blocks) - if num_used_blocks > 0: - logger.warning( - "Failed to reset prefix cache because some " - "blocks (%d) are not freed yet", num_used_blocks) - return False - - # Remove all hashes so that no new blocks will hit. - self.cached_block_hash_to_block = defaultdict(dict) - - # Remove all hashes from all blocks. - for block in self.block_pool: - block.reset_hash() - - logger.info("Successfully reset prefix cache") - return True + return self.block_pool.reset_prefix_cache() def get_num_common_prefix_blocks( self, @@ -337,177 +295,6 @@ def get_num_common_prefix_blocks( break return num_common_blocks - def _get_new_blocks(self, num_blocks: int) -> List[KVCacheBlock]: - """Get new blocks from the free block pool. - - Note that we do not check block cache in this function. - - Args: - num_blocks: The number of blocks to allocate. - - Returns: - A list of new block. - """ - if num_blocks > self.free_block_queue.num_free_blocks: - raise ValueError( - f"Cannot get {num_blocks} free blocks from the pool") - - ret: List[KVCacheBlock] = [] - idx = 0 - while idx < num_blocks: - # First allocate blocks. - curr_block = self.free_block_queue.popleft() - assert curr_block.ref_cnt == 0 - - # If the block is cached, evict it. - if self.enable_caching: - self._maybe_evict_cached_block(curr_block) - - curr_block.incr_ref() - ret.append(curr_block) - idx += 1 - - return ret - - def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool: - """ - If a block is cached in `cached_block_hash_to_block`, we reset its hash - metadata and evict it from the cache. - - Args: - block: The block to evict. - - Returns: - True if the block is evicted, False otherwise. - """ - block_hash = block.block_hash - if block_hash and block_hash in self.cached_block_hash_to_block: - block.reset_hash() - del self.cached_block_hash_to_block[block_hash][block.block_id] - - if len(self.cached_block_hash_to_block[block_hash]) == 0: - del self.cached_block_hash_to_block[block_hash] - - return True - return False - - def _get_cached_block(self, - block_hash: BlockHashType) -> Optional[KVCacheBlock]: - """Get a cached block by the block hash, or None if cache miss. - If there are duplicated blocks, we return the first block in the cache. - - Args: - block_hash: The hash value of the block. - - Returns: - The cached block if it exists, or None. - """ - if block_hash in self.cached_block_hash_to_block: - first_block_id = list( - self.cached_block_hash_to_block[block_hash].keys())[0] - return self.cached_block_hash_to_block[block_hash][first_block_id] - return None - - def _touch(self, blocks: List[KVCacheBlock]) -> None: - """Touch a block increases its reference count by 1, and may remove - the block from the free queue. This is used when a block is hit by - another request with the same prefix. - - Args: - blocks: A list of blocks to touch. - """ - for block in blocks: - # ref_cnt=0 means this block is in the free list (i.e. 
eviction - # candidate), so remove it. - if block.ref_cnt == 0: - self.free_block_queue.remove(block) - block.incr_ref() - - def _cache_full_blocks( - self, - request: Request, - blk_start_idx: int, - full_blocks: List[KVCacheBlock], - prev_block: Optional[KVCacheBlock], - ) -> None: - """Cache a list of full blocks for prefix caching. - - This function takes a list of blocks that will have their block hash - metadata to be updated and cached. Given a request, it computes the - block hashes for the blocks starting from `blk_start_idx` to the end - of the request's full blocks, updating the metadata for each block - and caching them in the `cached_block_hash_to_block`. - - Args: - request: The request to cache the blocks. - blk_start_idx: The index of the first block in the request's blocks - to cache. - full_blocks: The list of blocks to update hash metadata. - prev_block: The previous block in the chain. - """ - block_hashes = self.req_to_block_hashes[request.request_id] - num_cached_block_hashes = len(block_hashes) - - # Update the new blocks with the block hashes through the chain. - prev_block_hash_value = None - if prev_block is not None: - # Previous block must have a block hash because it must be - # a full, cached block. - assert prev_block.block_hash is not None - prev_block_hash_value = prev_block.block_hash.hash_value - - # Find the first uncached block. This case should only happen when - # speculative decoding is used. - offset = 0 - for blk in full_blocks: - if blk.block_hash is None: - break - else: - prev_block_hash_value = blk.block_hash.hash_value - offset += 1 - else: - # All blocks are cached. - return - - for i, blk in enumerate(full_blocks[offset:]): - blk_idx = blk_start_idx + offset + i - assert blk.block_hash is None - - if blk_idx < num_cached_block_hashes: - # The block hash may already be computed in - # "get_computed_blocks" if the tokens are not generated by - # this request (either the prompt tokens or the previously - # generated tokens with preemption). In this case we simply - # reuse the block hash. - block_hash = block_hashes[blk_idx] - else: - # Otherwise compute the block hash and cache it in the request - # in case it will be preempted in the future. - start_token_idx = blk_idx * self.block_size - end_token_idx = (blk_idx + 1) * self.block_size - block_tokens = request.all_token_ids[ - start_token_idx:end_token_idx] - assert len(block_tokens) == self.block_size, ( - f"Expected {self.block_size} tokens, got " - f"{len(block_tokens)} at {blk_idx}th block for request " - f"{request.request_id}({request})") - - # Generate extra keys for multi-modal inputs. Note that since - # we reach to this branch only when the block is completed with - # generated tokens, we only need to consider the last mm input. - extra_keys, _ = generate_block_hash_extra_keys( - request, start_token_idx, end_token_idx, -1) - - # Compute the hash of the current block. - block_hash = hash_block_tokens(prev_block_hash_value, - block_tokens, 0, extra_keys) - block_hashes.append(block_hash) - - # Update and added the full block to the cache. - blk.block_hash = block_hash - self.cached_block_hash_to_block[block_hash][blk.block_id] = blk - prev_block_hash_value = block_hash.hash_value - def free_block_hashes(self, request: Request) -> None: """Discard the block hashes for the request. 
diff --git a/vllm/v1/core/specialized_manager.py b/vllm/v1/core/specialized_manager.py index c0f626218fa15..52c762bcf6210 100644 --- a/vllm/v1/core/specialized_manager.py +++ b/vllm/v1/core/specialized_manager.py @@ -4,6 +4,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Type from vllm.utils import cdiv +from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock, PrefixLength, PrefixLengthRange) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, @@ -11,12 +12,6 @@ from vllm.v1.utils import ConstantList -@dataclass -class BlockPoolOperations: - get_cached_block: Callable[[BlockHashType], Optional[KVCacheBlock]] - get_null_block: Callable[[], KVCacheBlock] - - class SpecializedManager(ABC): """ An abstract base class for specialized managers that handle the kv @@ -28,14 +23,14 @@ class SpecializedManager(ABC): def __init__( self, kv_cache_spec: KVCacheSpec, - block_pool_operations: BlockPoolOperations, + block_pool: BlockPool, ) -> None: """ Initializes the SpecializedManager. Args: kv_cache_spec: The kv_cache_spec for this manager. - block_pool_operations: Operations to interact with the block pool. + block_pool_operations: Operations to interact with the block pool. TODO update Returns: None @@ -43,7 +38,7 @@ def __init__( self.block_size = kv_cache_spec.block_size self.kv_cache_spec = kv_cache_spec - self.block_pool_operations = block_pool_operations + self.block_pool = block_pool @abstractmethod def get_possible_cached_prefix( @@ -114,8 +109,7 @@ def get_possible_cached_prefix( # block_hashes is a chain of block hashes. If a block hash is not # in the cached_block_hash_to_id, the following block hashes are # not computed yet for sure. - if cached_block := self.block_pool_operations.get_cached_block( - block_hash): + if cached_block := self.block_pool.get_cached_block(block_hash): computed_blocks.append(cached_block) else: break @@ -139,10 +133,10 @@ def remove_useless_blocks(self, block_table: List[KVCacheBlock], class SlidingWindowManager(FullAttentionManager): def __init__(self, kv_cache_spec: SlidingWindowSpec, - block_pool_operations: BlockPoolOperations): - super().__init__(kv_cache_spec, block_pool_operations) + block_pool: BlockPool): + super().__init__(kv_cache_spec, block_pool) self.sliding_window = kv_cache_spec.sliding_window - self._null_block = block_pool_operations.get_null_block() + self._null_block = block_pool.get_null_block() def get_possible_cached_prefix( self, block_hashes: ConstantList[BlockHashType] @@ -160,8 +154,7 @@ def get_possible_cached_prefix( # cached. 
for i, block_hash in enumerate(chain(block_hashes, [dummy_block_hash])): - if cached_block := self.block_pool_operations.get_cached_block( - block_hash): + if cached_block := self.block_pool.get_cached_block(block_hash): computed_blocks.append(cached_block) else: if start == 0: @@ -209,13 +202,11 @@ def remove_useless_blocks(self, block_table: List[KVCacheBlock], } -def get_managers( - kv_cache_config: KVCacheConfig, - block_pool_operations: BlockPoolOperations -) -> List[SpecializedManager]: +def get_managers(kv_cache_config: KVCacheConfig, + block_pool: BlockPool) -> List[SpecializedManager]: managers: List[SpecializedManager] = [] for g in kv_cache_config.groups: manager_class = spec_manager_map[type(g.kv_cache_spec)] - manager = manager_class(g.kv_cache_spec, block_pool_operations) + manager = manager_class(g.kv_cache_spec, block_pool) managers.append(manager) return managers From 2a48c70cbe54034d5e56c0356dd9a981ed918b06 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 14 Feb 2025 07:19:30 -0800 Subject: [PATCH 44/48] support sliding window in single group mgr Signed-off-by: Chen Zhang --- vllm/v1/core/block_pool.py | 3 ++ vllm/v1/core/hybrid_kv_cache_manager.py | 26 ++++------- vllm/v1/core/kv_cache_manager.py | 62 ++++++++++++++++--------- vllm/v1/core/specialized_manager.py | 13 ++---- 4 files changed, 59 insertions(+), 45 deletions(-) diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index 2d73c2d3239e3..6abbacabd59d4 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -252,3 +252,6 @@ def get_num_free_blocks(self) -> int: def get_null_block(self) -> KVCacheBlock: return self._null_block + + def get_usage(self) -> float: + return 1.0 - (self.get_num_free_blocks() / self.num_gpu_blocks) diff --git a/vllm/v1/core/hybrid_kv_cache_manager.py b/vllm/v1/core/hybrid_kv_cache_manager.py index 24674fcd7e00a..80faf249adf89 100644 --- a/vllm/v1/core/hybrid_kv_cache_manager.py +++ b/vllm/v1/core/hybrid_kv_cache_manager.py @@ -13,7 +13,7 @@ generate_block_hash_extra_keys, hash_block_tokens, hash_request_tokens, intersect_ranges) -from vllm.v1.core.specialized_manager import get_managers +from vllm.v1.core.specialized_manager import get_specialized_manager from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import Request, RequestStatus @@ -64,10 +64,12 @@ def __init__( # Specialized managers for each kv cache group, which handle the # different kv cache management logic of different attention layers. - self.managers = get_managers( - kv_cache_config, - block_pool=self.block_pool, - ) + self.managers = [ + get_specialized_manager( + g.kv_cache_spec, + block_pool=self.block_pool, + ) for g in kv_cache_config.groups + ] self.num_kv_cache_groups = len(self.kv_cache_config.groups) # Mapping from request ID to blocks to track the blocks allocated @@ -85,8 +87,7 @@ def __init__( @property def usage(self) -> float: - return 1.0 - (self.block_pool.get_num_free_blocks() / - self.num_gpu_blocks) + return self.block_pool.get_usage() def get_computed_blocks(self, request: Request) -> Tuple[ReqKVCacheBlocks, int]: @@ -125,15 +126,8 @@ def get_computed_blocks(self, computed_blocks.append(computed_blocks_i) prefix_length.append(prefix_length_i) - if len(self.kv_cache_config.groups) == 1: - # If there is only one group, we return the computed blocks and - # tokens directly. - num_computed_tokens = prefix_length[0][-1].end - else: - # Find the common cached prefix of all groups. 
This path also works - # for the single group case, but it is less efficient. - num_computed_tokens = self._get_common_computed_tokens( - prefix_length) + # Find the common cached prefix of all groups. + num_computed_tokens = self._get_common_computed_tokens(prefix_length) # Truncate the computed blocks to the number of computed tokens. # E.g., group 0 has 3 computed blocks, and group 1 has 4 computed diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 1b023b262e8c6..4d665712e900a 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -12,6 +12,7 @@ generate_block_hash_extra_keys, hash_block_tokens, hash_request_tokens) +from vllm.v1.core.specialized_manager import get_specialized_manager from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import Request, RequestStatus @@ -27,7 +28,8 @@ def __init__( enable_caching: bool = True, num_preallocate_tokens: int = 64, ) -> None: - self.block_size = kv_cache_config.groups[0].kv_cache_spec.block_size + kv_cache_spec = kv_cache_config.groups[0].kv_cache_spec + self.block_size = kv_cache_spec.block_size self.num_gpu_blocks = kv_cache_config.num_blocks self.max_model_len = max_model_len self.max_num_blocks_per_req = cdiv(max_model_len, self.block_size) @@ -49,6 +51,11 @@ def __init__( self.block_pool = BlockPool(self.num_gpu_blocks, self.enable_caching) + self.manager = get_specialized_manager( + kv_cache_spec=kv_cache_spec, + block_pool=self.block_pool, + ) + # Mapping from request ID to blocks to track the blocks allocated # for each request, so that we can free the blocks when the request # is finished. @@ -63,8 +70,7 @@ def __init__( @property def usage(self) -> float: - return 1.0 - (self.block_pool.get_num_free_blocks() / - self.num_gpu_blocks) + return self.block_pool.get_usage() def get_computed_blocks( self, request: Request) -> Tuple[List[KVCacheBlock], int]: @@ -83,8 +89,6 @@ def get_computed_blocks( # Prefix caching is disabled. return [], 0 - computed_blocks = [] - # The block hashes for the request may already be computed # if the scheduler has tried to schedule the request before. block_hashes = self.req_to_block_hashes[request.request_id] @@ -92,19 +96,12 @@ def get_computed_blocks( block_hashes = hash_request_tokens(self.block_size, request, 0) self.req_to_block_hashes[request.request_id] = block_hashes - for block_hash in block_hashes: - # block_hashes is a chain of block hashes. If a block hash is not - # in the cached_block_hash_to_id, the following block hashes are - # not computed yet for sure. - if cached_block := self.block_pool.get_cached_block(block_hash): - computed_blocks.append(cached_block) - else: - break + prefix_length, computed_blocks = self.manager.get_possible_cached_prefix( + block_hashes) + num_computed_tokens = prefix_length[-1].end + computed_blocks = computed_blocks[:num_computed_tokens // + self.block_size] - # NOTE(woosuk): Since incomplete blocks are not eligible for - # sharing, `num_computed_tokens` is always a multiple of - # `block_size`. - num_computed_tokens = len(computed_blocks) * self.block_size return computed_blocks, num_computed_tokens def allocate_slots( @@ -141,17 +138,24 @@ def allocate_slots( if num_tokens == 0: raise ValueError("num_tokens must be greater than 0") + req_blocks = self.req_to_blocks[request.request_id] + # We can free blocks that are no longer needed even if we cannot + # schedule this request due to the limit of free blocks. 
+ # Should call this function before allocating new blocks to reduce + # the number of evicted blocks. + self._free_useless_blocks(req_blocks, request.num_computed_tokens) + new_computed_blocks = new_computed_blocks or [] # The number of computed tokens is the number of computed tokens plus # the new prefix caching hits num_computed_tokens = (request.num_computed_tokens + num_new_computed_tokens) - num_required_blocks = cdiv(num_computed_tokens + num_tokens, - self.block_size) + req_blocks = self.req_to_blocks[request.request_id] - num_new_blocks = (num_required_blocks - len(req_blocks) - - len(new_computed_blocks)) + num_new_blocks = self.manager.get_num_new_blocks( + num_computed_tokens, num_tokens, + len(req_blocks) + len(new_computed_blocks)) # If a computed block of a request is an eviction candidate (in the # free queue and ref_cnt == 0), it cannot be counted as a free block @@ -303,6 +307,22 @@ def free_block_hashes(self, request: Request) -> None: """ self.req_to_block_hashes.pop(request.request_id, None) + def _free_useless_blocks(self, req_blocks: List[KVCacheBlock], + num_computed_tokens: int) -> None: + """ + Frees memory blocks that are not needed. E.g., sliding window + layer with window size 2 and block size 1, we have req_blocks as + [[1, 2, 3]], this function will free block 1 and change the req_blocks + to [[-1, 2, 3]] (-1 refers to null block) + + Args: + req_blocks: The KV cache blocks of one request. + num_computed_tokens: The number of computed tokens. + """ + removed_blocks = self.manager.remove_useless_blocks( + req_blocks, num_computed_tokens) + self.block_pool.free_blocks(removed_blocks) + def init_kv_cache_manager(kv_cache_config: KVCacheConfig, max_model_len: int, diff --git a/vllm/v1/core/specialized_manager.py b/vllm/v1/core/specialized_manager.py index 52c762bcf6210..bfe5cd0cea2b7 100644 --- a/vllm/v1/core/specialized_manager.py +++ b/vllm/v1/core/specialized_manager.py @@ -202,11 +202,8 @@ def remove_useless_blocks(self, block_table: List[KVCacheBlock], } -def get_managers(kv_cache_config: KVCacheConfig, - block_pool: BlockPool) -> List[SpecializedManager]: - managers: List[SpecializedManager] = [] - for g in kv_cache_config.groups: - manager_class = spec_manager_map[type(g.kv_cache_spec)] - manager = manager_class(g.kv_cache_spec, block_pool) - managers.append(manager) - return managers +def get_specialized_manager(kv_cache_spec: KVCacheSpec, + block_pool: BlockPool) -> SpecializedManager: + manager_class = spec_manager_map[type(kv_cache_spec)] + manager = manager_class(kv_cache_spec, block_pool) + return manager From 4dded76abc84a32fd59d88b2dfa162f9e98cafa2 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 14 Feb 2025 08:05:57 -0800 Subject: [PATCH 45/48] unify allocate_slots Signed-off-by: Chen Zhang --- vllm/v1/core/block_pool.py | 28 +++++++++++++--------- vllm/v1/core/hybrid_kv_cache_manager.py | 32 +++++++++---------------- vllm/v1/core/kv_cache_manager.py | 30 +++++++++++------------ 3 files changed, 42 insertions(+), 48 deletions(-) diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index 6abbacabd59d4..9baff5c7f8744 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -55,11 +55,11 @@ def get_cached_block(self, def cache_full_blocks(self, request: Request, + blocks: List[KVCacheBlock], block_hashes: List[BlockHashType], + old_num_computed_tokens: int, + new_num_computed_tokens: int, block_size: int, - blk_start_idx: int, - full_blocks: List[KVCacheBlock], - prev_block: Optional[KVCacheBlock], 
kv_cache_group_id: int = 0) -> None: """Cache a list of full blocks for prefix caching. @@ -76,20 +76,26 @@ def cache_full_blocks(self, full_blocks: The list of blocks to update hash metadata. prev_block: The previous block in the chain. """ + num_full_blocks = new_num_computed_tokens // block_size + num_computed_full_blocks = old_num_computed_tokens // block_size + new_full_blocks = blocks[num_computed_full_blocks:num_full_blocks] + if not new_full_blocks: + return num_cached_block_hashes = len(block_hashes) - # Update the new blocks with the block hashes through the chain. - prev_block_hash_value = None - if prev_block is not None: - # Previous block must have a block hash because it must be - # a full, cached block. + if num_computed_full_blocks == 0: + prev_block_hash_value = None + else: + + prev_block = blocks[num_computed_full_blocks - 1], assert prev_block.block_hash is not None prev_block_hash_value = prev_block.block_hash.hash_value + # Update the new blocks with the block hashes through the chain. # Find the first uncached block. This case should only happen when # speculative decoding is used. offset = 0 - for blk in full_blocks: + for blk in new_full_blocks: if blk.block_hash is None: break else: @@ -99,8 +105,8 @@ def cache_full_blocks(self, # All blocks are cached. return - for i, blk in enumerate(full_blocks[offset:]): - blk_idx = blk_start_idx + offset + i + for i, blk in enumerate(new_full_blocks[offset:]): + blk_idx = num_computed_full_blocks + offset + i assert blk.block_hash is None if blk_idx < num_cached_block_hashes: diff --git a/vllm/v1/core/hybrid_kv_cache_manager.py b/vllm/v1/core/hybrid_kv_cache_manager.py index 80faf249adf89..ad45c4f0d52d7 100644 --- a/vllm/v1/core/hybrid_kv_cache_manager.py +++ b/vllm/v1/core/hybrid_kv_cache_manager.py @@ -178,7 +178,7 @@ def allocate_slots( # the number of evicted blocks. self._free_useless_blocks(req_blocks, request.num_computed_tokens) - new_computed_blocks = new_computed_blocks if new_computed_blocks is not None else [ + new_computed_blocks = new_computed_blocks or [ [] for _ in range(self.num_kv_cache_groups) ] @@ -267,26 +267,16 @@ def allocate_slots( # TODO(rickyx): When supporting speculative decoding, we will need to # differentiate between them so that we can know how many blocks are # full after appending the actual tokens. - num_full_blocks = (num_computed_tokens + - num_tokens) // manager.block_size - num_computed_full_blocks = num_computed_tokens // manager.block_size - - new_full_blocks = req_blocks[i][ - num_computed_full_blocks:num_full_blocks] - if new_full_blocks: - block_hashes = self.req_to_block_hashes[request.request_id][i] - self.block_pool.cache_full_blocks( - request=request, - block_hashes=block_hashes, - block_size=manager.block_size, - blk_start_idx=num_computed_full_blocks, - # The new full blocks are the full blocks that are not - # computed. 
- full_blocks=new_full_blocks, - prev_block=(req_blocks[i][num_computed_full_blocks - 1] - if num_computed_full_blocks > 0 else None), - kv_cache_group_id=i, - ) + block_hashes = self.req_to_block_hashes[request.request_id][i] + self.block_pool.cache_full_blocks( + request=request, + blocks=req_blocks[i], + block_hashes=block_hashes, + old_num_computed_tokens=num_computed_tokens, + new_num_computed_tokens=num_computed_tokens + num_tokens, + block_size=manager.block_size, + kv_cache_group_id=i, + ) return new_blocks diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 4d665712e900a..d8817eecd310c 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -152,7 +152,6 @@ def allocate_slots( num_computed_tokens = (request.num_computed_tokens + num_new_computed_tokens) - req_blocks = self.req_to_blocks[request.request_id] num_new_blocks = self.manager.get_num_new_blocks( num_computed_tokens, num_tokens, len(req_blocks) + len(new_computed_blocks)) @@ -185,11 +184,13 @@ def allocate_slots( # No new block is needed. new_blocks = [] else: + num_preallocate_blocks = min( + self.num_preallocate_blocks, + self.block_pool.get_num_free_blocks() - num_new_blocks) # Get new blocks from the free block pool considering # preallocated blocks. num_new_blocks = min( num_new_blocks + self.num_preallocate_blocks, - self.block_pool.get_num_free_blocks(), # Should not exceed the maximum number of blocks per request. # This is especially because the block table has the shape # [..., max_num_blocks_per_req]. @@ -198,6 +199,7 @@ def allocate_slots( self.max_num_blocks_per_req - len(req_blocks), ) assert num_new_blocks > 0 + assert num_new_blocks <= self.block_pool.get_num_free_blocks() # Concatenate the computed block IDs and the new block IDs. new_blocks = self.block_pool.get_new_blocks(num_new_blocks) @@ -206,25 +208,21 @@ def allocate_slots( if not self.enable_caching: return new_blocks + block_hashes = self.req_to_block_hashes[request.request_id] + # NOTE(rickyx): We are assuming the `num_tokens` are actual # tokens rather than lookahead slots (e.g. for speculative decoding). # TODO(rickyx): When supporting speculative decoding, we will need to # differentiate between them so that we can know how many blocks are # full after appending the actual tokens. - num_full_blocks = (num_computed_tokens + num_tokens) // self.block_size - num_computed_full_blocks = num_computed_tokens // self.block_size - new_full_blocks = req_blocks[num_computed_full_blocks:num_full_blocks] - if new_full_blocks: - block_hashes = self.req_to_block_hashes[request.request_id] - self.block_pool.cache_full_blocks( - request=request, - block_hashes=block_hashes, - block_size=self.block_size, - blk_start_idx=num_computed_full_blocks, - # The new full blocks are the full blocks that are not computed. 
- full_blocks=new_full_blocks, - prev_block=(req_blocks[num_computed_full_blocks - 1] - if num_computed_full_blocks > 0 else None)) + self.block_pool.cache_full_blocks( + request=request, + blocks=req_blocks, + block_hashes=block_hashes, + old_num_computed_tokens=num_computed_tokens, + new_num_computed_tokens=num_computed_tokens + num_tokens, + block_size=self.block_size, + ) return new_blocks From 282ccd5d7a070b3c4ed0be6ced9c74e6fe652b8a Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 14 Feb 2025 08:14:55 -0800 Subject: [PATCH 46/48] unify other parts Signed-off-by: Chen Zhang --- vllm/v1/core/hybrid_kv_cache_manager.py | 98 ++++++++++++------------- vllm/v1/core/kv_cache_manager.py | 2 +- 2 files changed, 48 insertions(+), 52 deletions(-) diff --git a/vllm/v1/core/hybrid_kv_cache_manager.py b/vllm/v1/core/hybrid_kv_cache_manager.py index ad45c4f0d52d7..4c58d1bdfc02b 100644 --- a/vllm/v1/core/hybrid_kv_cache_manager.py +++ b/vllm/v1/core/hybrid_kv_cache_manager.py @@ -280,44 +280,8 @@ def allocate_slots( return new_blocks - def _merge_blocks_by_eviction_order( - self, blocks: ReqKVCacheBlocks) -> List[KVCacheBlock]: - """ - Merge the blocks of different groups to one list. The returned blocks - are sorted by eviction order, with the first block having the highest - eviction priority. - - Args: - blocks: the blocks of each kv cache group, ordered by eviction - priority. - - Returns: - A list of KVCacheBlocks sorted by eviction order. - """ - - if self.enable_caching: - # NOTE (Chen): A simple strategy that interleaves the blocks of - # different KV cache groups. We can investigate more advanced - # strategies in the future. - ordered_blocks = [] - max_len = max(len(blocks_of_group) for blocks_of_group in blocks) - for i in range(max_len): - for blocks_of_group in blocks: - if i < len(blocks_of_group): - ordered_blocks.append(blocks_of_group[i]) - else: - ordered_blocks = [] - for blocks_of_group in blocks: - ordered_blocks.extend(blocks_of_group) - - return ordered_blocks - def _free_blocks(self, blocks: ReqKVCacheBlocks) -> None: - if len(self.kv_cache_config.groups) == 1: - # Fast path for single kv cache group models. - ordered_blocks = blocks[0] - else: - ordered_blocks = self._merge_blocks_by_eviction_order(blocks) + ordered_blocks = self._merge_blocks_by_eviction_order(blocks) self.block_pool.free_blocks(ordered_blocks) def free(self, request: Request) -> None: @@ -402,6 +366,25 @@ def free_block_hashes(self, request: Request) -> None: """ self.req_to_block_hashes.pop(request.request_id, None) + def _free_useless_blocks(self, req_blocks: ReqKVCacheBlocks, + num_computed_tokens: int) -> None: + """ + Frees memory blocks that are not needed. E.g., sliding window + layer with window size 2 and block size 1, we have req_blocks as + [[1, 2, 3]], this function will free block 1 and change the req_blocks + to [[-1, 2, 3]] (-1 refers to null block) + + Args: + req_blocks: The KV cache blocks of one request. + num_computed_tokens: The number of computed tokens. 
+ """ + removed_blocks = [] + for manager, req_blocks_of_group in zip(self.managers, req_blocks): + removed_blocks.append( + manager.remove_useless_blocks(req_blocks_of_group, + num_computed_tokens)) + self._free_blocks(removed_blocks) + def _get_common_computed_tokens(self, prefix_length: List[PrefixLength]) -> int: """ @@ -433,21 +416,34 @@ def _get_common_computed_tokens(self, return num_computed_tokens - def _free_useless_blocks(self, req_blocks: ReqKVCacheBlocks, - num_computed_tokens: int) -> None: + def _merge_blocks_by_eviction_order( + self, blocks: ReqKVCacheBlocks) -> List[KVCacheBlock]: """ - Frees memory blocks that are not needed. E.g., sliding window - layer with window size 2 and block size 1, we have req_blocks as - [[1, 2, 3]], this function will free block 1 and change the req_blocks - to [[-1, 2, 3]] (-1 refers to null block) + Merge the blocks of different groups to one list. The returned blocks + are sorted by eviction order, with the first block having the highest + eviction priority. Args: - req_blocks: The KV cache blocks of one request. - num_computed_tokens: The number of computed tokens. + blocks: the blocks of each kv cache group, ordered by eviction + priority. + + Returns: + A list of KVCacheBlocks sorted by eviction order. """ - removed_blocks = [] - for manager, req_blocks_of_group in zip(self.managers, req_blocks): - removed_blocks.append( - manager.remove_useless_blocks(req_blocks_of_group, - num_computed_tokens)) - self._free_blocks(removed_blocks) + + if self.enable_caching: + # NOTE (Chen): A simple strategy that interleaves the blocks of + # different KV cache groups. We can investigate more advanced + # strategies in the future. + ordered_blocks = [] + max_len = max(len(blocks_of_group) for blocks_of_group in blocks) + for i in range(max_len): + for blocks_of_group in blocks: + if i < len(blocks_of_group): + ordered_blocks.append(blocks_of_group[i]) + else: + ordered_blocks = [] + for blocks_of_group in blocks: + ordered_blocks.extend(blocks_of_group) + + return ordered_blocks diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index d8817eecd310c..7b9d89153cda8 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -190,7 +190,7 @@ def allocate_slots( # Get new blocks from the free block pool considering # preallocated blocks. num_new_blocks = min( - num_new_blocks + self.num_preallocate_blocks, + num_new_blocks + num_preallocate_blocks, # Should not exceed the maximum number of blocks per request. # This is especially because the block table has the shape # [..., max_num_blocks_per_req]. 
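
The sliding-window freeing behavior that `_free_useless_blocks` delegates to the specialized manager can be illustrated with a minimal, self-contained sketch. The helper name `free_out_of_window_blocks`, the `NULL_BLOCK_ID` placeholder, and the exact window-boundary arithmetic below are assumptions for illustration only, not the actual `SlidingWindowManager.remove_useless_blocks` implementation; the worked example, however, matches the [[1, 2, 3]] -> [[-1, 2, 3]] case described in the docstring above.

from typing import List

NULL_BLOCK_ID = -1  # stand-in for the block pool's null block (assumption)


def free_out_of_window_blocks(block_ids: List[int],
                              num_computed_tokens: int,
                              sliding_window: int,
                              block_size: int) -> List[int]:
    """Replace blocks that fell out of the sliding window with
    NULL_BLOCK_ID and return the freed block ids, oldest first."""
    # Tokens older than (num_computed_tokens - sliding_window) can no
    # longer be attended to; every block strictly before the block that
    # holds the first still-useful token can be freed. The exact
    # boundary convention here is an assumption chosen to reproduce the
    # docstring example.
    first_useful_token = max(num_computed_tokens - sliding_window, 0)
    first_useful_block = first_useful_token // block_size

    removed: List[int] = []
    for idx in range(min(first_useful_block, len(block_ids))):
        if block_ids[idx] != NULL_BLOCK_ID:
            removed.append(block_ids[idx])
            block_ids[idx] = NULL_BLOCK_ID
    return removed


# Window size 2, block size 1, blocks [1, 2, 3]: block 1 is freed and the
# table becomes [-1, 2, 3].
table = [1, 2, 3]
print(free_out_of_window_blocks(table, num_computed_tokens=3,
                                sliding_window=2, block_size=1), table)
# -> [1] [-1, 2, 3]
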
From 0e33053ff583f34c34c85518c8e1561f5261c26f Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 14 Feb 2025 19:05:16 -0800 Subject: [PATCH 47/48] reduce code duplication Signed-off-by: Chen Zhang --- vllm/v1/core/hybrid_kv_cache_manager.py | 18 ++++-------------- vllm/v1/core/kv_cache_manager.py | 2 +- 2 files changed, 5 insertions(+), 15 deletions(-) diff --git a/vllm/v1/core/hybrid_kv_cache_manager.py b/vllm/v1/core/hybrid_kv_cache_manager.py index 4c58d1bdfc02b..a3c9256d997c0 100644 --- a/vllm/v1/core/hybrid_kv_cache_manager.py +++ b/vllm/v1/core/hybrid_kv_cache_manager.py @@ -7,6 +7,7 @@ from vllm.logger import init_logger from vllm.utils import cdiv from vllm.v1.core.block_pool import BlockPool +from vllm.v1.core.kv_cache_manager import KVCacheManager from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue, KVCacheBlock, PrefixLength, ReqKVCacheBlocks, @@ -85,9 +86,9 @@ def __init__( str, List[List[BlockHashType]]] = defaultdict( lambda: [[] for _ in range(self.num_kv_cache_groups)]) - @property - def usage(self) -> float: - return self.block_pool.get_usage() + usage = KVCacheManager.usage + reset_prefix_cache = KVCacheManager.reset_prefix_cache + free_block_hashes = KVCacheManager.free_block_hashes def get_computed_blocks(self, request: Request) -> Tuple[ReqKVCacheBlocks, int]: @@ -302,9 +303,6 @@ def free(self, request: Request) -> None: # eviction priority. self._free_blocks([list(reversed(blks)) for blks in blocks]) - def reset_prefix_cache(self) -> bool: - return self.block_pool.reset_prefix_cache() - def get_num_common_prefix_blocks( self, request: Request, @@ -358,14 +356,6 @@ def get_num_common_prefix_blocks( num_common_blocks_per_group.append(num_common_blocks) return num_common_blocks_per_group - def free_block_hashes(self, request: Request) -> None: - """Discard the block hashes for the request. - - NOTE: Unlike `free`, this method should be called only when the request - is finished, not when it is preempted. 
- """ - self.req_to_block_hashes.pop(request.request_id, None) - def _free_useless_blocks(self, req_blocks: ReqKVCacheBlocks, num_computed_tokens: int) -> None: """ diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 7b9d89153cda8..fd0a4267fa9ef 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -6,7 +6,6 @@ from vllm.logger import init_logger from vllm.utils import cdiv from vllm.v1.core.block_pool import BlockPool -from vllm.v1.core.hybrid_kv_cache_manager import HybridKVCacheManager from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue, KVCacheBlock, generate_block_hash_extra_keys, @@ -326,6 +325,7 @@ def init_kv_cache_manager(kv_cache_config: KVCacheConfig, max_model_len: int, enable_caching: bool = True, num_preallocate_tokens: int = 64): + from vllm.v1.core.hybrid_kv_cache_manager import HybridKVCacheManager print("kv_cache_config", kv_cache_config) if len(kv_cache_config.groups) > 1: logger.info("Using HybridKVCacheManager") From 5491910868bcecf305b916ee3d282fe4077cd774 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 14 Feb 2025 19:47:07 -0800 Subject: [PATCH 48/48] fix typing Signed-off-by: Chen Zhang --- tests/v1/core/test_specialized_manager.py | 5 +++-- .../v1/e2e/test_correctness_sliding_window.py | 1 + vllm/v1/core/block_pool.py | 15 +++++++++------ vllm/v1/core/hybrid_kv_cache_manager.py | 19 ++++++++----------- vllm/v1/core/kv_cache_manager.py | 11 ++++------- vllm/v1/core/scheduler.py | 16 ++++++++-------- vllm/v1/core/specialized_manager.py | 18 +++++++++--------- vllm/v1/kv_cache_interface.py | 17 +++++++++++++++-- vllm/v1/worker/gpu_input_batch.py | 6 +++--- vllm/v1/worker/gpu_model_runner.py | 8 +++++--- 10 files changed, 65 insertions(+), 51 deletions(-) diff --git a/tests/v1/core/test_specialized_manager.py b/tests/v1/core/test_specialized_manager.py index 969b18afbe977..54f68a917c1c0 100644 --- a/tests/v1/core/test_specialized_manager.py +++ b/tests/v1/core/test_specialized_manager.py @@ -1,12 +1,13 @@ +# SPDX-License-Identifier: Apache-2.0 from collections import deque from typing import Deque import torch -from vllm.v1.core.specialized_manager import (BlockPoolOperations, - SlidingWindowManager) from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock, PrefixLengthRange) +from vllm.v1.core.specialized_manager import (BlockPoolOperations, + SlidingWindowManager) from vllm.v1.kv_cache_interface import SlidingWindowSpec diff --git a/tests/v1/e2e/test_correctness_sliding_window.py b/tests/v1/e2e/test_correctness_sliding_window.py index 2bc6d3d1712fc..4506ad08b27af 100644 --- a/tests/v1/e2e/test_correctness_sliding_window.py +++ b/tests/v1/e2e/test_correctness_sliding_window.py @@ -1,3 +1,4 @@ +# SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass from typing import List, Tuple diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index 9baff5c7f8744..fe547b4254c9a 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -1,8 +1,12 @@ +# SPDX-License-Identifier: Apache-2.0 from collections import defaultdict -from typing import Dict, List, Optional +from typing import Dict, Iterable, List, Optional from vllm.logger import init_logger -from vllm.v1.core.kv_cache_utils import BlockHashType, FreeKVCacheBlockQueue, KVCacheBlock, generate_block_hash_extra_keys, hash_block_tokens +from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue, + KVCacheBlock, + generate_block_hash_extra_keys, + 
hash_block_tokens) from vllm.v1.request import Request logger = init_logger(__name__) @@ -86,8 +90,7 @@ def cache_full_blocks(self, if num_computed_full_blocks == 0: prev_block_hash_value = None else: - - prev_block = blocks[num_computed_full_blocks - 1], + prev_block = blocks[num_computed_full_blocks - 1] assert prev_block.block_hash is not None prev_block_hash_value = prev_block.block_hash.hash_value # Update the new blocks with the block hashes through the chain. @@ -214,7 +217,7 @@ def touch(self, blocks: List[KVCacheBlock]) -> None: self._free_block_queue.remove(block) block.incr_ref() - def free_blocks(self, ordered_blocks: List[KVCacheBlock]) -> None: + def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None: """ TODO: add docstring the first block will be evicted first @@ -244,7 +247,7 @@ def reset_prefix_cache(self) -> bool: return False # Remove all hashes so that no new blocks will hit. - self.cached_block_hash_to_block = defaultdict(dict) + self._cached_block_hash_to_block = defaultdict(dict) # Remove all hashes from all blocks. for block in self._block_pool: diff --git a/vllm/v1/core/hybrid_kv_cache_manager.py b/vllm/v1/core/hybrid_kv_cache_manager.py index a3c9256d997c0..d8b75c7f2f678 100644 --- a/vllm/v1/core/hybrid_kv_cache_manager.py +++ b/vllm/v1/core/hybrid_kv_cache_manager.py @@ -2,17 +2,14 @@ import math from collections import defaultdict -from typing import DefaultDict, Dict, List, Optional, Tuple +from typing import DefaultDict, List, Optional, Tuple from vllm.logger import init_logger from vllm.utils import cdiv from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_manager import KVCacheManager -from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue, - KVCacheBlock, PrefixLength, - ReqKVCacheBlocks, - generate_block_hash_extra_keys, - hash_block_tokens, +from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock, + PrefixLength, ReqKVCacheBlocks, hash_request_tokens, intersect_ranges) from vllm.v1.core.specialized_manager import get_specialized_manager from vllm.v1.kv_cache_interface import KVCacheConfig @@ -262,12 +259,12 @@ def allocate_slots( if not self.enable_caching: return new_blocks + # NOTE(rickyx): We are assuming the `num_tokens` are actual + # tokens rather than lookahead slots (e.g. for speculative decoding). + # TODO(rickyx): When supporting speculative decoding, we will need to + # differentiate between them so that we can know how many blocks are + # full after appending the actual tokens. for i, manager in enumerate(self.managers): - # NOTE(rickyx): We are assuming the `num_tokens` are actual - # tokens rather than lookahead slots (e.g. for speculative decoding). - # TODO(rickyx): When supporting speculative decoding, we will need to - # differentiate between them so that we can know how many blocks are - # full after appending the actual tokens. 
block_hashes = self.req_to_block_hashes[request.request_id][i] self.block_pool.cache_full_blocks( request=request, diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index fd0a4267fa9ef..5009ca934bae3 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -1,15 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 from collections import defaultdict -from typing import DefaultDict, Dict, Iterable, List, Optional, Tuple +from typing import DefaultDict, Iterable, List, Optional, Tuple from vllm.logger import init_logger from vllm.utils import cdiv from vllm.v1.core.block_pool import BlockPool -from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue, - KVCacheBlock, - generate_block_hash_extra_keys, - hash_block_tokens, +from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock, hash_request_tokens) from vllm.v1.core.specialized_manager import get_specialized_manager from vllm.v1.kv_cache_interface import KVCacheConfig @@ -95,8 +92,8 @@ def get_computed_blocks( block_hashes = hash_request_tokens(self.block_size, request, 0) self.req_to_block_hashes[request.request_id] = block_hashes - prefix_length, computed_blocks = self.manager.get_possible_cached_prefix( - block_hashes) + prefix_length, computed_blocks = \ + self.manager.get_possible_cached_prefix(block_hashes) num_computed_tokens = prefix_length[-1].end computed_blocks = computed_blocks[:num_computed_tokens // self.block_size] diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 6bd793f7b1841..c434c22701056 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -10,10 +10,10 @@ from vllm.sampling_params import SamplingParams from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, compute_encoder_budget) -from vllm.v1.core.kv_cache_manager import KVCacheManager, init_kv_cache_manager +from vllm.v1.core.kv_cache_manager import init_kv_cache_manager from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs from vllm.v1.kv_cache_interface import (BlockIDGenerator, FullAttentionSpec, - GroupedBlockIDs, KVCacheConfig) + KVCacheConfig, MayGroupedBlockIDs) from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus @@ -109,7 +109,7 @@ def schedule(self) -> "SchedulerOutput": scheduled_running_reqs: List[Request] = [] preempted_reqs: List[Request] = [] - req_to_new_block_ids: Dict[str, GroupedBlockIDs] = {} + req_to_new_block_ids: Dict[str, MayGroupedBlockIDs] = {} num_scheduled_tokens: Dict[str, int] = {} token_budget = self.max_num_scheduled_tokens # Encoder-related. @@ -332,7 +332,7 @@ def schedule(self) -> "SchedulerOutput": def _make_cached_request_data( self, request: Request, - new_block_ids: GroupedBlockIDs, + new_block_ids: MayGroupedBlockIDs, num_computed_tokens: int, resumed_from_preemption: bool, ) -> "CachedRequestData": @@ -584,14 +584,14 @@ class NewRequestData: mm_hashes: List[str] mm_positions: List["PlaceholderRange"] sampling_params: SamplingParams - block_ids: GroupedBlockIDs + block_ids: MayGroupedBlockIDs num_computed_tokens: int @classmethod def from_request( cls, request: Request, - block_ids: GroupedBlockIDs, + block_ids: MayGroupedBlockIDs, num_computed_tokens: int, ) -> "NewRequestData": return cls( @@ -615,7 +615,7 @@ class CachedRequestData: # the request's block IDs. If True, new_block_ids will be used as the # request's block IDs instead of appending to the existing block IDs. 
resumed_from_preemption: bool - new_block_ids: GroupedBlockIDs + new_block_ids: MayGroupedBlockIDs num_computed_tokens: int @classmethod @@ -623,7 +623,7 @@ def from_request( cls, request: Request, resumed_from_preemption: bool, - new_block_ids: GroupedBlockIDs, + new_block_ids: MayGroupedBlockIDs, num_computed_tokens: int, ) -> "CachedRequestData": return cls( diff --git a/vllm/v1/core/specialized_manager.py b/vllm/v1/core/specialized_manager.py index bfe5cd0cea2b7..dcbc716fbb681 100644 --- a/vllm/v1/core/specialized_manager.py +++ b/vllm/v1/core/specialized_manager.py @@ -1,15 +1,14 @@ +# SPDX-License-Identifier: Apache-2.0 from abc import ABC, abstractmethod -from dataclasses import dataclass from itertools import chain -from typing import Callable, Dict, List, Optional, Tuple, Type +from typing import Dict, List, Tuple, Type from vllm.utils import cdiv from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock, PrefixLength, PrefixLengthRange) -from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheSpec, SlidingWindowSpec) -from vllm.v1.utils import ConstantList +from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec, + SlidingWindowSpec) class SpecializedManager(ABC): @@ -30,7 +29,8 @@ def __init__( Args: kv_cache_spec: The kv_cache_spec for this manager. - block_pool_operations: Operations to interact with the block pool. TODO update + block_pool_operations: Operations to interact with the block pool. + TODO update Returns: None @@ -42,7 +42,7 @@ def __init__( @abstractmethod def get_possible_cached_prefix( - self, block_hashes: ConstantList[BlockHashType] + self, block_hashes: List[BlockHashType] ) -> Tuple[PrefixLength, List[KVCacheBlock]]: """ Get the possible cached prefixes of a request based on its block hashes. @@ -102,7 +102,7 @@ def remove_useless_blocks(self, block_table: List[KVCacheBlock], class FullAttentionManager(SpecializedManager): def get_possible_cached_prefix( - self, block_hashes: ConstantList[BlockHashType] + self, block_hashes: List[BlockHashType] ) -> Tuple[List[PrefixLengthRange], List[KVCacheBlock]]: computed_blocks: List[KVCacheBlock] = [] for block_hash in block_hashes: @@ -139,7 +139,7 @@ def __init__(self, kv_cache_spec: SlidingWindowSpec, self._null_block = block_pool.get_null_block() def get_possible_cached_prefix( - self, block_hashes: ConstantList[BlockHashType] + self, block_hashes: List[BlockHashType] ) -> Tuple[List[PrefixLengthRange], List[KVCacheBlock]]: # TODO: check the hit every num_block_sliding_window blocks, to optimize # the time complexity from O(num_block) to diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 38e13a08c225c..33ca612c321c0 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Union +from typing import TYPE_CHECKING, Dict, List, Union, cast, overload import torch @@ -9,7 +9,7 @@ from vllm.utils import cdiv, get_dtype_size if TYPE_CHECKING: - from vllm.v1.core.kv_cache_utils import ReqKVCacheBlocks + from vllm.v1.core.kv_cache_utils import KVCacheBlock, ReqKVCacheBlocks logger = init_logger(__name__) @@ -194,12 +194,25 @@ def get_group(self, group_idx: int) -> List[int]: class BlockIDGenerator: num_kv_cache_groups: int + @overload + @classmethod + def generate(cls, kv_cache_blocks: List[KVCacheBlock]) -> List[int]: + ... 
+ + @overload + @classmethod + def generate(cls, + kv_cache_blocks: List[List[KVCacheBlock]]) -> GroupedBlockIDs: + ... + @classmethod def generate( cls, kv_cache_blocks: Union[List["KVCacheBlock"], List[List["KVCacheBlock"]]] ) -> MayGroupedBlockIDs: if cls.num_kv_cache_groups == 1: + kv_cache_blocks = cast(List["KVCacheBlock"], kv_cache_blocks) return [blk.block_id for blk in kv_cache_blocks] else: + kv_cache_blocks = cast(List[List["KVCacheBlock"]], kv_cache_blocks) return GroupedBlockIDs.from_kv_cache_blocks(kv_cache_blocks) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 745ff5dac8817..5c3b153988571 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -10,7 +10,7 @@ from vllm.multimodal import MultiModalKwargs from vllm.sampling_params import SamplingParams, SamplingType -from vllm.v1.kv_cache_interface import GroupedBlockIDs, KVCacheConfig +from vllm.v1.kv_cache_interface import KVCacheConfig, MayGroupedBlockIDs from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.worker.block_table import initialize_block_table @@ -29,7 +29,7 @@ class CachedRequestState: sampling_params: SamplingParams generator: Optional[torch.Generator] - block_ids: GroupedBlockIDs + block_ids: MayGroupedBlockIDs num_computed_tokens: int output_token_ids: List[int] @@ -196,7 +196,7 @@ def add_request( self.num_tokens[req_index] = request.num_tokens self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens - self.block_table.add_row(request.block_ids, req_index) + self.block_table.add_row(request.block_ids, req_index) # type: ignore sampling_params = request.sampling_params self.temperature_cpu[req_index] = sampling_params.temperature diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 860e6200ba2cd..2f5c23584fab2 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -306,7 +306,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: req_state.num_computed_tokens = req_data.num_computed_tokens if not req_data.resumed_from_preemption: # Append the new blocks to the existing block IDs. - req_state.block_ids.extend(req_data.new_block_ids) + req_state.block_ids.extend( + req_data.new_block_ids) # type: ignore else: # The request is resumed from preemption. # Replace the existing block IDs with the new ones. @@ -323,8 +324,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: # Update the persistent batch. self.input_batch.num_computed_tokens_cpu[req_index] = ( req_data.num_computed_tokens) - self.input_batch.block_table.append_row(req_data.new_block_ids, - req_index) + self.input_batch.block_table.append_row( + req_data.new_block_ids, # type: ignore + req_index) # Add the new or resumed requests to the persistent batch. # The smaller empty indices are filled first.
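
The `@overload`/`cast` dispatch added to `BlockIDGenerator.generate` in this last patch can be exercised in isolation with the following sketch. `Block`, `GroupedIDs`, and `IDGenerator` are illustrative stand-ins rather than the vLLM classes (the real grouped form is a `GroupedBlockIDs` object, not a nested list), but the dispatch on the number of KV cache groups works the same way: single-group models keep the existing flat `List[int]` representation, while hybrid models get one id list per group.

from typing import List, Union, cast, overload


class Block:
    """Stand-in for KVCacheBlock: only carries an integer block id."""

    def __init__(self, block_id: int) -> None:
        self.block_id = block_id


GroupedIDs = List[List[int]]  # one id list per KV cache group (stand-in)


class IDGenerator:
    num_groups: int = 1

    @overload
    @classmethod
    def generate(cls, blocks: List[Block]) -> List[int]: ...

    @overload
    @classmethod
    def generate(cls, blocks: List[List[Block]]) -> GroupedIDs: ...

    @classmethod
    def generate(
        cls, blocks: Union[List[Block], List[List[Block]]]
    ) -> Union[List[int], GroupedIDs]:
        if cls.num_groups == 1:
            # Flat path: the input is a plain list of blocks.
            flat = cast(List[Block], blocks)
            return [blk.block_id for blk in flat]
        # Grouped path: one list of blocks per KV cache group.
        grouped = cast(List[List[Block]], blocks)
        return [[blk.block_id for blk in group] for group in grouped]


print(IDGenerator.generate([Block(1), Block(2)]))                # [1, 2]
IDGenerator.num_groups = 2
print(IDGenerator.generate([[Block(1)], [Block(2), Block(3)]]))  # [[1], [2, 3]]

The overload stubs carry no runtime behavior; they only let the type checker give callers the flat return type in the single-group case and the grouped type otherwise, which is presumably why the patch needed the `cast` calls and the `# type: ignore` annotations in the scheduler and worker paths.
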