From d4f398590786f0015d474b03a3d078db1e7d1be2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Moskal?= Date: Mon, 27 May 2024 19:07:07 -0700 Subject: [PATCH] [Core] Sliding window for block manager v2 (#4545) Co-authored-by: Ruth Evans --- tests/core/block/e2e/conftest.py | 26 +++ tests/core/block/e2e/test_correctness.py | 11 +- .../e2e/test_correctness_sliding_window.py | 168 ++++++++++++++++++ tests/core/block/test_block_manager_v2.py | 69 +++++++ vllm/attention/ops/prefix_prefill.py | 6 +- vllm/core/block/block_table.py | 34 +++- vllm/core/block/cpu_gpu_block_allocator.py | 74 ++++++++ vllm/core/block/interfaces.py | 9 + vllm/core/block_manager_v2.py | 24 ++- vllm/engine/arg_utils.py | 3 +- vllm/worker/cache_engine.py | 5 +- vllm/worker/model_runner.py | 73 +++++--- 12 files changed, 457 insertions(+), 45 deletions(-) create mode 100644 tests/core/block/e2e/test_correctness_sliding_window.py diff --git a/tests/core/block/e2e/conftest.py b/tests/core/block/e2e/conftest.py index b0d62c8993d3f..e870597b7a011 100644 --- a/tests/core/block/e2e/conftest.py +++ b/tests/core/block/e2e/conftest.py @@ -1,3 +1,5 @@ +from typing import Callable, Iterable, Optional + import pytest from vllm import LLM @@ -40,3 +42,27 @@ def generator_inner(): for llm in generator_inner(): yield llm del llm + + +def get_text_from_llm_generator(llm_generator: Iterable[LLM], + prompts, + sampling_params, + llm_cb: Optional[Callable[[LLM], + None]] = None): + for llm in llm_generator: + if llm_cb: + llm_cb(llm) + outputs = llm.generate(prompts, sampling_params, use_tqdm=True) + text = [output.outputs[0].text for output in outputs] + del llm + + return text + + +def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params): + for llm in llm_generator: + outputs = llm.generate(prompts, sampling_params, use_tqdm=True) + token_ids = [output.outputs[0].token_ids for output in outputs] + del llm + + return token_ids diff --git a/tests/core/block/e2e/test_correctness.py b/tests/core/block/e2e/test_correctness.py index c3666da7542b5..3713ef2fed4d1 100644 --- a/tests/core/block/e2e/test_correctness.py +++ b/tests/core/block/e2e/test_correctness.py @@ -4,6 +4,8 @@ from vllm import SamplingParams +from .conftest import get_token_ids_from_llm_generator + @pytest.mark.parametrize( "common_llm_kwargs", @@ -444,12 +446,3 @@ def test_auto_prefix_caching_with_preemption(baseline_llm_generator, assert expected_token_ids == actual_token_ids assert baseline_token_ids == test_token_ids - - -def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params): - for llm in llm_generator: - outputs = llm.generate(prompts, sampling_params, use_tqdm=True) - token_ids = [output.outputs[0].token_ids for output in outputs] - del llm - - return token_ids diff --git a/tests/core/block/e2e/test_correctness_sliding_window.py b/tests/core/block/e2e/test_correctness_sliding_window.py new file mode 100644 index 0000000000000..e98292e807d73 --- /dev/null +++ b/tests/core/block/e2e/test_correctness_sliding_window.py @@ -0,0 +1,168 @@ +import random +from typing import List + +import pytest + +from vllm import LLM, SamplingParams + +from .conftest import get_text_from_llm_generator + +# relatively small model with 4k sliding window +MODEL = "bigcode/starcoder2-3b" +BLOCK_SIZE = 16 + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": MODEL, + + # skip cuda graph creation for fast test. 
+ "enforce_eager": True, + "block_size": BLOCK_SIZE, + # needed due to https://github.com/vllm-project/vllm/issues/1908#issuecomment-2101122008 + "num_gpu_blocks_override": 100000 // BLOCK_SIZE, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{ + "use_v2_block_manager": False +}]) +@pytest.mark.parametrize("test_llm_kwargs", [{"use_v2_block_manager": True}]) +@pytest.mark.parametrize("batch_size", [5]) +@pytest.mark.parametrize("seed", [1]) +def test_sliding_window_retrival(baseline_llm_generator, test_llm_generator, + batch_size, seed): + """ + The test does a bunch of assignments "x1 = 10\nx2 = 33\n..." and then + asks for value of one of them (which is outside the sliding window). + If we tell it upfront which we are going to be looking for, then + it answers correctly (mostly). + + Additionally, we compare the results of the v1 and v2 managers. + """ + sampling_params = SamplingParams( + max_tokens=1024, + ignore_eos=True, + temperature=0.0, + ) + + prompts, answer, indices = prep_prompts(batch_size) + + print('Getting token ids from block manager v1') + baseline_texts = get_text_from_llm_generator(baseline_llm_generator, + prompts, + sampling_params, + llm_cb=check_window(prompts)) + + check_answers(indices, answer, baseline_texts) + + print('Getting token ids from block manager v2') + test_texts = get_text_from_llm_generator(test_llm_generator, prompts, + sampling_params) + check_answers(indices, answer, test_texts) + + cmp = [ + expected_text == actual_text + for expected_text, actual_text in zip(baseline_texts, test_texts) + ] + print(cmp) + # make sure it's mostly OK; this is possibly because https://github.com/vllm-project/vllm/pull/4768 + # however, https://github.com/vllm-project/vllm/issues/3385#issuecomment-1995924290 + # states that xformers and flash_attn have different ideas about the window + # size anyways + assert sum(cmp) > 0.7 * len(cmp) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": MODEL, + + # skip cuda graph creation for fast test. + "enforce_eager": True, + "block_size": BLOCK_SIZE, + "num_gpu_blocks_override": 100000 // BLOCK_SIZE, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "use_v2_block_manager": True, + "enable_chunked_prefill": True +}]) +@pytest.mark.parametrize("batch_size", [5]) +@pytest.mark.parametrize("seed", [1]) +def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed): + """ + This is similar to test_sliding_window_retrival, however, it doesn't + compare against the v1 block manager since v1 doesn't support + chunked prefill with sliding window. + + The results with and without chunked prefill are not the same due to + numerical instabilities. + """ + sampling_params = SamplingParams( + max_tokens=10, + ignore_eos=True, + temperature=0.0, + ) + + prompts, answer, indices = prep_prompts(batch_size) + + # We don't compare with the baseline model here, since the results + # slightly different due to different tailing in attention. + test_texts = get_text_from_llm_generator(test_llm_generator, + prompts, + sampling_params, + llm_cb=check_window(prompts)) + check_answers(indices, answer, test_texts) + + +def prep_prompts(batch_size: int): + """ + Generate prompts which a bunch of assignments, + then asking for the value of one of them. 
+ The prompt is just under 10k tokens; sliding window is 4k + so the answer is outside sliding window, but should still be correct. + """ + prompts: List[str] = [] + answer: List[int] = [] + indices: List[int] = [] + random.seed(1) + for _ in range(batch_size): + idx = random.randint(30, 90) + indices.append(idx) + prompt = "```python\n# We set a number of variables, " + \ + f"x{idx} will be important later\n" + ln = random.randint(800, 1100) + for k in range(30, ln): + v = random.randint(10, 99) + if k == idx: + answer.append(v) + prompt += f"x{k} = {v}\n" + prompt += f"# Now, we check the value of x{idx}:\n" + prompt += f"assert x{idx} == " + prompts.append(prompt) + return prompts, answer, indices + + +def check_answers(indices: List[int], answer: List[int], outputs: List[str]): + answer2 = [int(text[0:2].strip()) for text in outputs] + print(list(zip(indices, zip(answer, answer2)))) + numok = 0 + for a1, a2 in zip(answer, answer2): + if a1 == a2: + numok += 1 + frac_ok = numok / len(answer) + print(f"Num OK: {numok}/{len(answer)} {frac_ok}") + assert frac_ok > 0.7 + + +def check_window(prompts: List[str]): + + def inner(llm: LLM): + sliding_window = llm.llm_engine.model_config.get_sliding_window() + assert sliding_window and sliding_window > 0 + assert any( + len(llm.get_tokenizer().tokenize(prompt)) > sliding_window + for prompt in prompts) + + return inner diff --git a/tests/core/block/test_block_manager_v2.py b/tests/core/block/test_block_manager_v2.py index 1e8e4ccdfb151..91b047f0e183e 100644 --- a/tests/core/block/test_block_manager_v2.py +++ b/tests/core/block/test_block_manager_v2.py @@ -101,3 +101,72 @@ def test_append_slots(block_size, prompt_len, num_slots_to_append, range(prompt_len + num_slots_to_append + num_lookahead_slots)), block_size)) - len(chunk_list(list(range(prompt_len)), block_size)) assert num_consumed_blocks == expected_consumed_blocks + + +@pytest.mark.parametrize("block_size", [8, 16]) +@pytest.mark.parametrize("prompt_len", [10, 300, 1000]) +@pytest.mark.parametrize("num_slots_to_append", [50]) +@pytest.mark.parametrize("sliding_window", [20, 32, 200, 512]) +def test_sliding_window(block_size, prompt_len, num_slots_to_append, + sliding_window): + """Verify append_slots consumes the correct number of blocks from the block + table. 
+ """ + + num_gpu_blocks = 1024 + watermark = 0.1 + block_manager = BlockSpaceManagerV2( + block_size=block_size, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=0, + watermark=watermark, + sliding_window=sliding_window, + ) + + def check_used(min_n, max_n=None): + if max_n is None: + max_n = min_n + used = num_gpu_blocks - block_manager.get_num_free_gpu_blocks() + #print("check", min_n, used, max_n) + assert min_n <= used + assert used <= max_n + + def num_blocks(num_tokens): + return (num_tokens + block_size - 1) // block_size + + check_used(0) + + seq_group = create_seq_group( + seq_prompt_len=prompt_len, + seq_output_lens=[0], + ) + + check_used(0) + + # Allocate seq + assert block_manager.can_allocate(seq_group) + block_manager.allocate(seq_group) + + check_used(num_blocks(prompt_len)) + + # Seq seq to RUNNING + seq = seq_group.get_seqs()[0] + seq.status = SequenceStatus.RUNNING + + seq.data.update_num_computed_tokens(prompt_len) + check_used(num_blocks(prompt_len)) + + # this is how we compute it in BlockSpaceManagerV2.__init__ + sliding_blocks = (sliding_window // block_size) + 2 + # plus one block for null block + sliding_blocks += 1 + + # Append tokens to the sequeqnce + for token_id in range(num_slots_to_append): + seq.append_token_id(token_id, {token_id: Logprob(0.0)}) + seq.data.update_num_computed_tokens(1) + block_manager.append_slots(seq, num_lookahead_slots=0) + if prompt_len < sliding_window + 10: + check_used(0, sliding_blocks + 1) + else: + check_used(sliding_blocks, sliding_blocks + 1) diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index 997b25e887e30..b99cf9a50d105 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -697,6 +697,10 @@ def context_attention_fwd(q, grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, + # 0 means "disable" + if sliding_window is None or sliding_window <= 0: + sliding_window = 0 + num_warps = 8 if Lk <= 64 else 8 if alibi_slopes is not None: _fwd_kernel_alibi[grid]( @@ -794,7 +798,7 @@ def context_attention_fwd(q, BLOCK_DMODEL=Lk, BLOCK_DMODEL_PADDED=Lk_padded, BLOCK_N=BLOCK, - SLIDING_WINDOW=sliding_window if sliding_window is not None else 0, + SLIDING_WINDOW=sliding_window, num_warps=num_warps, num_stages=1, ) diff --git a/vllm/core/block/block_table.py b/vllm/core/block/block_table.py index b0d9511fba521..26c704b8de901 100644 --- a/vllm/core/block/block_table.py +++ b/vllm/core/block/block_table.py @@ -20,6 +20,10 @@ class BlockTable: _blocks (Optional[List[Block]], optional): An optional list of existing blocks to initialize the BlockTable with. If not provided, an empty BlockTable is created. + max_block_sliding_window (Optional[int], optional): The number of + blocks to keep around for each sequance. If None, all blocks + are kept (eg., when sliding window is not used). + It should at least fit the sliding window size of the model. Attributes: _block_size (int): The maximum number of tokens that can be stored in a @@ -37,6 +41,7 @@ def __init__( block_size: int, block_allocator: DeviceAwareBlockAllocator, _blocks: Optional[List[Block]] = None, + max_block_sliding_window: Optional[int] = None, ): self._block_size = block_size self._allocator = block_allocator @@ -44,6 +49,7 @@ def __init__( _blocks = [] self._blocks: List[Block] = _blocks + self._max_block_sliding_window = max_block_sliding_window # Use helper method instead of directly calculating, as blocks # may not be allocated. 
self._num_full_slots = len(self._get_all_token_ids()) @@ -89,7 +95,8 @@ def allocate(self, def append_token_ids(self, token_ids: List[int], - num_lookahead_slots: int = 0) -> None: + num_lookahead_slots: int = 0, + num_computed_slots: Optional[int] = None) -> None: """Appends a sequence of token IDs to the existing blocks in the BlockTable. @@ -104,13 +111,35 @@ def append_token_ids(self, Args: token_ids (List[int]): The sequence of token IDs to be appended. + num_computed_slots (Optional[int]): The number of KV cache slots + that are already filled (computed). + When sliding window is enabled, this is used to compute how many + blocks to drop at the front of the sequence. + Without sliding window, None can be passed. + Without chunked prefill, it should be the same as + _num_full_slots. """ - assert self._is_allocated + assert self._is_allocated, "no blocks have been allocated" assert len(self._blocks) > 0 + # Drop blocks that are no longer needed due to sliding window + if self._max_block_sliding_window is not None: + null_block = self._allocator.allocate_or_get_null_block() + assert num_computed_slots is not None + end_block_idx = (num_computed_slots // + self._block_size) - self._max_block_sliding_window + for idx in range(0, end_block_idx): + b = self._blocks[idx] + if b is not null_block: + self._allocator.free(b) + self._blocks[idx] = null_block + + # Ensure there are enough empty slots for the new tokens plus + # lookahead slots self.ensure_num_empty_slots(num_empty_slots=len(token_ids) + num_lookahead_slots) + # Update the blocks with the new tokens blocks = self._blocks[self._num_full_slots // self._block_size:] token_blocks = self._chunk_token_blocks_for_append(token_ids) @@ -168,6 +197,7 @@ def fork(self) -> "BlockTable": block_size=self._block_size, block_allocator=self._allocator, _blocks=forked_blocks, + max_block_sliding_window=self._max_block_sliding_window, ) def free(self) -> None: diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py index 0577ca76ea971..d28a684376974 100644 --- a/vllm/core/block/cpu_gpu_block_allocator.py +++ b/vllm/core/block/cpu_gpu_block_allocator.py @@ -105,11 +105,19 @@ def __init__( Device.GPU: gpu_block_allocator, } + self._null_block: Optional[Block] = None + self._block_ids_to_allocator: Dict[int, BlockAllocator] = {} for _, allocator in self._allocators.items(): for block_id in allocator.all_block_ids: self._block_ids_to_allocator[block_id] = allocator + def allocate_or_get_null_block(self) -> Block: + if self._null_block is None: + self._null_block = NullBlock( + self.allocate_mutable(None, Device.GPU)) + return self._null_block + def allocate_mutable(self, prev_block: Optional[Block], device: Device) -> Block: """Allocates a new mutable block on the specified device. @@ -149,6 +157,9 @@ def free(self, block: Block) -> None: Args: block (Block): The block to be freed. """ + # Null block should never be freed + if isinstance(block, NullBlock): + return block_id = block.block_id assert block_id is not None allocator = self._block_ids_to_allocator[block_id] @@ -165,6 +176,8 @@ def fork(self, last_block: Block) -> List[Block]: List[Block]: A new list of blocks that shares the same memory as the original sequence. 
""" + # do not attempt to fork the null block + assert not isinstance(last_block, NullBlock) block_id = last_block.block_id assert block_id is not None allocator = self._block_ids_to_allocator[block_id] @@ -226,3 +239,64 @@ def promote_to_immutable_block(self, block: Block) -> BlockId: def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]: raise NotImplementedError + + +class NullBlock(Block): + """ + Null blocks are used as a placeholders for KV cache blocks that have + been dropped due to sliding window. + This implementation just wraps an ordinary block and prevents it from + being modified. It also allows for testing if a block is NullBlock + via isinstance(). + """ + + def __init__(self, proxy: Block): + super().__init__() + self._proxy = proxy + + def append_token_ids(self, token_ids: List[BlockId]): + raise ValueError("null block should not be modified") + + @property + def block_id(self): + return self._proxy.block_id + + @block_id.setter + def block_id(self, value: Optional[BlockId]): + raise ValueError("null block should not be modified") + + @property + def token_ids(self) -> List[BlockId]: + return self._proxy.token_ids + + @property + def num_empty_slots(self) -> BlockId: + return self._proxy.num_empty_slots + + @property + def is_full(self): + return self._proxy.is_full + + @property + def prev_block(self): + return self._proxy.prev_block + + @property + def computed(self): + return self._proxy.computed + + @computed.setter + def computed(self, value): + self._proxy.computed = value + + @property + def last_accessed(self) -> float: + return self._proxy.last_accessed + + @last_accessed.setter + def last_accessed(self, last_accessed_ts: float): + self._proxy.last_accessed = last_accessed_ts + + @property + def content_hash(self): + return self._proxy.content_hash diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py index 140fbbb0949cc..8fc4c601106cd 100644 --- a/vllm/core/block/interfaces.py +++ b/vllm/core/block/interfaces.py @@ -203,3 +203,12 @@ def mark_blocks_as_computed(self, block_ids: List[int]) -> None: def get_common_computed_block_ids( self, seq_block_ids: List[List[int]]) -> List[int]: pass + + @abstractmethod + def allocate_or_get_null_block(self) -> Block: + """ + Null blocks are used as a placeholders for KV cache blocks that have + been dropped due to sliding window. + There is at most one null block per allocator. + """ + pass diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index f0bc96564050a..834436c25e160 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -66,9 +66,18 @@ def __init__( self.num_total_gpu_blocks = num_gpu_blocks self.num_total_cpu_blocks = num_cpu_blocks - assert sliding_window is None, "Sliding window not yet supported" - - self.block_sliding_window = None + self.sliding_window = sliding_window + # max_block_sliding_window is the max number of blocks that need to be + # allocated + self.max_block_sliding_window = None + if sliding_window is not None: + # +1 here because // rounds down + num_blocks = sliding_window // block_size + 1 + # +1 here because the last block may not be full, + # and so the sequence stretches one more block at the beginning + # For example, if sliding_window is 3 and block_size is 4, + # we may need 2 blocks when the second block only holds 1 token. 
+ self.max_block_sliding_window = num_blocks + 1 self.watermark = watermark assert watermark >= 0.0 @@ -96,10 +105,9 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: block_size=self.block_size, ) - assert self.block_sliding_window is None - if self.block_sliding_window is not None: + if self.max_block_sliding_window is not None: num_required_blocks = min(num_required_blocks, - self.block_sliding_window) + self.max_block_sliding_window) num_free_gpu_blocks = self.block_allocator.get_num_free_blocks( device=Device.GPU) @@ -125,8 +133,9 @@ def allocate(self, seq_group: SequenceGroup) -> None: block_table = BlockTable( block_size=self.block_size, block_allocator=self.block_allocator, + max_block_sliding_window=self.max_block_sliding_window, ) - assert self.block_sliding_window is None + block_table.allocate(seq.get_token_ids()) self.block_tables[seq.seq_id] = block_table @@ -174,6 +183,7 @@ def append_slots( block_table.append_token_ids( token_ids=block_table.get_unseen_token_ids(seq.get_token_ids()), num_lookahead_slots=num_lookahead_slots, + num_computed_slots=seq.data.get_num_computed_tokens(), ) # Return any new copy-on-writes. diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 3267c8c9f44d2..11485aa2438c0 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -648,7 +648,8 @@ def create_engine_config(self, ) -> EngineConfig: guided_decoding_backend=self.guided_decoding_backend) if (model_config.get_sliding_window() is not None - and scheduler_config.chunked_prefill_enabled): + and scheduler_config.chunked_prefill_enabled + and not scheduler_config.use_v2_block_manager): raise ValueError( "Chunked prefill is not supported with sliding window. " "Set --disable-sliding-window to disable sliding window.") diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 07d51dca226bd..2f0e59f7ae7c9 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -68,8 +68,11 @@ def _allocate_kv_cache( pin_memory = is_pin_memory_available() if device == "cpu" else False kv_cache: List[torch.Tensor] = [] for _ in range(self.num_layers): + # null block in CpuGpuBlockAllocator requires at least that + # block to be zeroed-out. + # We zero-out everything for simplicity. kv_cache.append( - torch.empty(kv_cache_shape, + torch.zeros(kv_cache_shape, dtype=self.dtype, pin_memory=pin_memory, device=device)) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 87d5f5c1b9d67..5ddd2d1b65f81 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -269,6 +269,12 @@ def _prepare_model_input( if len(seq_group_metadata_list) == 0: return ModelInput.empty(self.device) + if self.sliding_window is not None: + sliding_window_blocks = (self.sliding_window + self.block_size - + 1) // self.block_size + block_aligned_sliding_window = \ + sliding_window_blocks * self.block_size + for seq_group_metadata in seq_group_metadata_list: seq_ids = list(seq_group_metadata.seq_data.keys()) is_prompt = seq_group_metadata.is_prompt @@ -309,6 +315,30 @@ def _prepare_model_input( and self.sliding_window is None and is_prompt) + # These are seq_len/context_len capped to the sliding window. + # They are passed to decode kernel. + # We still need original seq_len/context_len to compute slot + # mapping (and input position) below. + curr_sliding_window_blocks = None + sliding_seq_len = seq_len + sliding_context_len = context_len + + # TODO(sang): This is a hack to make sliding window work with + # paged attn. 
We can remove it if we make paged attn kernel + # to properly handle slinding window attn. + if (self.sliding_window is not None and not is_prompt): + curr_sliding_window_blocks = sliding_window_blocks + if self.scheduler_config.use_v2_block_manager: + # number of elements in last block + suff_len = seq_len % self.block_size + sliding_seq_len = min( + seq_len, block_aligned_sliding_window + suff_len) + if suff_len > 0: + curr_sliding_window_blocks += 1 + else: + sliding_seq_len = min(seq_len, self.sliding_window) + sliding_context_len = sliding_seq_len - 1 + # TODO(sang): Combine chunked prefill and prefix caching by # only allowing multiple of block_size chunk size. # NOTE: This only works for oooooooxxx style attention. @@ -316,6 +346,13 @@ def _prepare_model_input( assert computed_block_nums is not None context_len = len(computed_block_nums) * self.block_size tokens = tokens[context_len:] + + # need to think what to set it to when we have both sliding + # window and prefix caching... + assert self.sliding_window is None, \ + "Prefix caching is not supported with sliding window" + sliding_context_len = context_len + if self.attn_backend.get_name() == "flash-attn": # NOTE(woosuk): For flash-attn, the block table should # include the entries for the incoming prefill tokens. @@ -329,14 +366,9 @@ def _prepare_model_input( if seq_group_metadata.block_tables is not None: # chunked prefill or decode block_table = seq_group_metadata.block_tables[seq_id] - if self.sliding_window is not None: - # chunked prefill doesn't support sliding window. - assert (not self.scheduler_config. - chunked_prefill_enabled) - sliding_window_blocks = (self.sliding_window // - self.block_size) - block_table = block_table[-sliding_window_blocks:] - + if curr_sliding_window_blocks is not None: + block_table = block_table[ + -curr_sliding_window_blocks:] if self.attn_backend.get_name() == "flashinfer": paged_kv_indices.extend(block_table) paged_kv_indptr.append(paged_kv_indptr[-1] + @@ -354,16 +386,9 @@ def _prepare_model_input( block_table = [] block_tables.append(block_table) - # TODO(sang): This is a hack to make sliding window work with - # paged attn. We can remove it if we make paged attn kernel - # to properly handle slinding window attn. 
- if (self.sliding_window is not None and not is_prompt): - seq_len = min(seq_len, self.sliding_window) - context_len = seq_len - 1 - - seq_lens.append(seq_len) - context_lens.append(context_len) - query_len = seq_len - context_len + seq_lens.append(sliding_seq_len) + context_lens.append(sliding_context_len) + query_len = sliding_seq_len - sliding_context_len query_lens.append(query_len) input_tokens.extend(tokens) input_positions.extend(list(range(context_len, seq_len))) @@ -380,16 +405,15 @@ def _prepare_model_input( "seq_len: {}, context_len: {}, query_len: {}".format( seq_len, context_len, query_len)) num_decode_tokens += query_len - decode_seq_lens.append(seq_len) + decode_seq_lens.append(sliding_seq_len) if lora_id > 0: lora_requests.add(seq_group_metadata.lora_request) - lora_index_mapping += [lora_id] * (seq_len - context_len) + lora_index_mapping += [lora_id] * query_len lora_prompt_mapping.extend( [lora_id] * - (seq_len - - context_len if seq_group_metadata.sampling_params + (query_len if seq_group_metadata.sampling_params and seq_group_metadata.sampling_params.prompt_logprobs else 1)) @@ -417,9 +441,10 @@ def _prepare_model_input( start_idx = 0 if self.sliding_window is not None: if is_prompt: - assert context_len == 0, ( + assert self.scheduler_config.use_v2_block_manager \ + or context_len == 0, ( "Prefix caching is currently not supported with " - "sliding window attention") + "sliding window attention in V1 block manager") # It is an optimization. When it is decoding, it is always # 0. When prefill, we use it to not write slots to kv cache # to save memory.
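
A note on the block accounting introduced above: the sliding-window support in
BlockSpaceManagerV2 boils down to two pieces of arithmetic, the
max_block_sliding_window bound computed in __init__ and the block-dropping loop
at the top of BlockTable.append_token_ids. The standalone sketch below
reproduces only that arithmetic; the SlidingBlockTable class and the integer
NULL placeholder are hypothetical stand-ins for illustration, not the vLLM
classes themselves.

    from typing import List

    NULL = -1  # stands in for the shared null block


    def max_block_sliding_window(sliding_window: int, block_size: int) -> int:
        # +1 because // rounds down, +1 more because the last block may be
        # only partially filled, so the window can straddle one extra block.
        return sliding_window // block_size + 1 + 1


    class SlidingBlockTable:
        """Toy per-sequence block table that drops out-of-window blocks."""

        def __init__(self, block_size: int, sliding_window: int) -> None:
            self.block_size = block_size
            self.max_blocks = max_block_sliding_window(sliding_window,
                                                       block_size)
            self.blocks: List[int] = []
            self._next_id = 0

        def append_tokens(self, num_new_tokens: int,
                          num_computed_slots: int) -> None:
            # Mirrors the drop loop in append_token_ids: every block strictly
            # before end_block_idx can no longer be touched by the window.
            end_block_idx = (num_computed_slots // self.block_size
                             - self.max_blocks)
            for idx in range(max(end_block_idx, 0)):
                if self.blocks[idx] != NULL:
                    # Freed and replaced by the null block in the real code.
                    self.blocks[idx] = NULL
            # Grow the table so the newly appended tokens fit.
            total_tokens = num_computed_slots + num_new_tokens
            while len(self.blocks) * self.block_size < total_tokens:
                self.blocks.append(self._next_id)
                self._next_id += 1


    if __name__ == "__main__":
        table = SlidingBlockTable(block_size=16, sliding_window=64)
        computed = 0
        for _ in range(20):
            table.append_tokens(16, num_computed_slots=computed)
            computed += 16
        live = sum(1 for b in table.blocks if b != NULL)
        # max_blocks = 64 // 16 + 2 = 6; one extra block of slack shows up
        # because dropping lags the append by one step, which is why the unit
        # test above checks check_used(sliding_blocks, sliding_blocks + 1).
        print(live, table.max_blocks)
        assert live <= table.max_blocks + 1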
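
Similarly, the model_runner.py change replaces the old min(seq_len,
sliding_window) cap with a block-aligned cap for decode when the v2 block
manager is in use, so that the truncated block table and the capped sequence
length stay consistent. The helper below is only a sketch of that arithmetic
under assumed inputs (the name sliding_seq_len_v2 is invented here); it is not
part of the patch.

    def sliding_seq_len_v2(seq_len: int, sliding_window: int, block_size: int):
        """Return (capped seq_len, number of block-table entries kept) for a
        decode step with the v2 block manager."""
        sliding_window_blocks = (sliding_window + block_size - 1) // block_size
        block_aligned_sliding_window = sliding_window_blocks * block_size
        # Tokens sitting in the (possibly partial) last block.
        suff_len = seq_len % block_size
        capped_seq_len = min(seq_len, block_aligned_sliding_window + suff_len)
        num_blocks = sliding_window_blocks + (1 if suff_len > 0 else 0)
        return capped_seq_len, num_blocks


    # A 4096-token window with 16-token blocks and a 10_001-token sequence:
    # the kernel is given the last 4096 + 1 = 4097 tokens, spread over 256
    # full window blocks plus the partial last block (257 block-table entries).
    assert sliding_seq_len_v2(10_001, 4096, 16) == (4097, 257)
    # On the v1 path the cap stays min(seq_len, sliding_window), and in both
    # cases the decode context length is the capped seq_len minus one.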