
Commit

[Core] Sliding window for block manager v2 (vllm-project#4545)
Co-authored-by: Ruth Evans <ruthevans@Ruths-MacBook-Pro.local>
2 people authored and robertgshaw2-redhat committed Jun 8, 2024
1 parent 9fb7b82 commit 954c332
Showing 12 changed files with 457 additions and 45 deletions.
26 changes: 26 additions & 0 deletions tests/core/block/e2e/conftest.py
@@ -1,3 +1,5 @@
from typing import Callable, Iterable, Optional

import pytest

from vllm import LLM
@@ -40,3 +42,27 @@ def generator_inner():
for llm in generator_inner():
yield llm
del llm


def get_text_from_llm_generator(llm_generator: Iterable[LLM],
prompts,
sampling_params,
llm_cb: Optional[Callable[[LLM],
None]] = None):
for llm in llm_generator:
if llm_cb:
llm_cb(llm)
outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
text = [output.outputs[0].text for output in outputs]
del llm

return text


def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params):
for llm in llm_generator:
outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
token_ids = [output.outputs[0].token_ids for output in outputs]
del llm

return token_ids
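
As a rough illustration of how these new helpers are meant to be consumed by the e2e tests, here is a hypothetical usage sketch; the test name, fixture, prompt, and sampling parameters below are placeholders and not part of this commit.

# Hypothetical usage sketch of the new conftest helpers (not part of this
# commit). `baseline_llm_generator` stands in for one of the pytest fixtures
# defined in this conftest that yield vllm.LLM instances.
from vllm import SamplingParams
from .conftest import get_text_from_llm_generator

def test_example(baseline_llm_generator):
    sampling_params = SamplingParams(max_tokens=16, temperature=0.0)
    prompts = ["def fibonacci(n):"]
    texts = get_text_from_llm_generator(baseline_llm_generator, prompts,
                                        sampling_params)
    # One generated text per prompt is expected back.
    assert len(texts) == len(prompts)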
11 changes: 2 additions & 9 deletions tests/core/block/e2e/test_correctness.py
@@ -4,6 +4,8 @@

from vllm import SamplingParams

from .conftest import get_token_ids_from_llm_generator


@pytest.mark.parametrize(
"common_llm_kwargs",
@@ -444,12 +446,3 @@ def test_auto_prefix_caching_with_preemption(baseline_llm_generator,
assert expected_token_ids == actual_token_ids

assert baseline_token_ids == test_token_ids


def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params):
for llm in llm_generator:
outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
token_ids = [output.outputs[0].token_ids for output in outputs]
del llm

return token_ids
168 changes: 168 additions & 0 deletions tests/core/block/e2e/test_correctness_sliding_window.py
@@ -0,0 +1,168 @@
import random
from typing import List

import pytest

from vllm import LLM, SamplingParams

from .conftest import get_text_from_llm_generator

# relatively small model with 4k sliding window
MODEL = "bigcode/starcoder2-3b"
BLOCK_SIZE = 16


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": MODEL,
# skip cuda graph creation for fast test.
"enforce_eager": True,
"block_size": BLOCK_SIZE,
# needed due to https://github.com/vllm-project/vllm/issues/1908#issuecomment-2101122008
"num_gpu_blocks_override": 100000 // BLOCK_SIZE,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{
"use_v2_block_manager": False
}])
@pytest.mark.parametrize("test_llm_kwargs", [{"use_v2_block_manager": True}])
@pytest.mark.parametrize("batch_size", [5])
@pytest.mark.parametrize("seed", [1])
def test_sliding_window_retrival(baseline_llm_generator, test_llm_generator,
batch_size, seed):
"""
The test does a bunch of assignments "x1 = 10\nx2 = 33\n..." and then
asks for the value of one of them (a variable that lies outside the
sliding window). If we tell the model upfront which variable we are
going to ask about, it answers correctly (mostly).
Additionally, we compare the results of the v1 and v2 managers.
"""
sampling_params = SamplingParams(
max_tokens=1024,
ignore_eos=True,
temperature=0.0,
)

prompts, answer, indices = prep_prompts(batch_size)

print('Getting text from block manager v1')
baseline_texts = get_text_from_llm_generator(baseline_llm_generator,
prompts,
sampling_params,
llm_cb=check_window(prompts))

check_answers(indices, answer, baseline_texts)

print('Getting text from block manager v2')
test_texts = get_text_from_llm_generator(test_llm_generator, prompts,
sampling_params)
check_answers(indices, answer, test_texts)

cmp = [
expected_text == actual_text
for expected_text, actual_text in zip(baseline_texts, test_texts)
]
print(cmp)
# make sure the v1 and v2 outputs mostly match; mismatches are possibly
# due to https://github.com/vllm-project/vllm/pull/4768.
# Moreover, https://github.com/vllm-project/vllm/issues/3385#issuecomment-1995924290
# states that xformers and flash_attn have different ideas about the window
# size anyway.
assert sum(cmp) > 0.7 * len(cmp)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": MODEL,
# skip cuda graph creation for fast test.
"enforce_eager": True,
"block_size": BLOCK_SIZE,
"num_gpu_blocks_override": 100000 // BLOCK_SIZE,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"use_v2_block_manager": True,
"enable_chunked_prefill": True
}])
@pytest.mark.parametrize("batch_size", [5])
@pytest.mark.parametrize("seed", [1])
def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed):
"""
This is similar to test_sliding_window_retrival; however, it doesn't
compare against the v1 block manager since v1 doesn't support
chunked prefill with sliding window.
The results with and without chunked prefill are not the same due to
numerical instabilities.
"""
sampling_params = SamplingParams(
max_tokens=10,
ignore_eos=True,
temperature=0.0,
)

prompts, answer, indices = prep_prompts(batch_size)

# We don't compare with the baseline model here, since the results are
# slightly different due to different tailing in attention.
test_texts = get_text_from_llm_generator(test_llm_generator,
prompts,
sampling_params,
llm_cb=check_window(prompts))
check_answers(indices, answer, test_texts)


def prep_prompts(batch_size: int):
"""
Generate prompts that make a bunch of assignments,
then ask for the value of one of them.
The prompt is just under 10k tokens; sliding window is 4k
so the answer is outside sliding window, but should still be correct.
"""
prompts: List[str] = []
answer: List[int] = []
indices: List[int] = []
random.seed(1)
for _ in range(batch_size):
idx = random.randint(30, 90)
indices.append(idx)
prompt = "```python\n# We set a number of variables, " + \
f"x{idx} will be important later\n"
ln = random.randint(800, 1100)
for k in range(30, ln):
v = random.randint(10, 99)
if k == idx:
answer.append(v)
prompt += f"x{k} = {v}\n"
prompt += f"# Now, we check the value of x{idx}:\n"
prompt += f"assert x{idx} == "
prompts.append(prompt)
return prompts, answer, indices


def check_answers(indices: List[int], answer: List[int], outputs: List[str]):
answer2 = [int(text[0:2].strip()) for text in outputs]
print(list(zip(indices, zip(answer, answer2))))
numok = 0
for a1, a2 in zip(answer, answer2):
if a1 == a2:
numok += 1
frac_ok = numok / len(answer)
print(f"Num OK: {numok}/{len(answer)} {frac_ok}")
assert frac_ok > 0.7


def check_window(prompts: List[str]):

def inner(llm: LLM):
sliding_window = llm.llm_engine.model_config.get_sliding_window()
assert sliding_window and sliding_window > 0
assert any(
len(llm.get_tokenizer().tokenize(prompt)) > sliding_window
for prompt in prompts)

return inner
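
To make the retrieval setup concrete, here is a small standalone illustration (not part of the commit) of the prompt tail produced by prep_prompts and of how check_answers recovers the model's answer from the first two generated characters; the variable index, value, and generated text are invented for the example.

# Illustrative sketch only; numbers and generated text are made up.
prompt_tail = "x42 = 37\n# Now, we check the value of x42:\nassert x42 == "
generated = "37\n# further assignments would follow in a real completion"
recovered = int(generated[0:2].strip())  # same parsing as check_answers
assert recovered == 37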
69 changes: 69 additions & 0 deletions tests/core/block/test_block_manager_v2.py
@@ -101,3 +101,72 @@ def test_append_slots(block_size, prompt_len, num_slots_to_append,
range(prompt_len + num_slots_to_append + num_lookahead_slots)),
block_size)) - len(chunk_list(list(range(prompt_len)), block_size))
assert num_consumed_blocks == expected_consumed_blocks


@pytest.mark.parametrize("block_size", [8, 16])
@pytest.mark.parametrize("prompt_len", [10, 300, 1000])
@pytest.mark.parametrize("num_slots_to_append", [50])
@pytest.mark.parametrize("sliding_window", [20, 32, 200, 512])
def test_sliding_window(block_size, prompt_len, num_slots_to_append,
sliding_window):
"""Verify append_slots consumes the correct number of blocks from the block
table.
"""

num_gpu_blocks = 1024
watermark = 0.1
block_manager = BlockSpaceManagerV2(
block_size=block_size,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=0,
watermark=watermark,
sliding_window=sliding_window,
)

def check_used(min_n, max_n=None):
if max_n is None:
max_n = min_n
used = num_gpu_blocks - block_manager.get_num_free_gpu_blocks()
#print("check", min_n, used, max_n)
assert min_n <= used
assert used <= max_n

def num_blocks(num_tokens):
return (num_tokens + block_size - 1) // block_size

check_used(0)

seq_group = create_seq_group(
seq_prompt_len=prompt_len,
seq_output_lens=[0],
)

check_used(0)

# Allocate seq
assert block_manager.can_allocate(seq_group)
block_manager.allocate(seq_group)

check_used(num_blocks(prompt_len))

# Set seq to RUNNING
seq = seq_group.get_seqs()[0]
seq.status = SequenceStatus.RUNNING

seq.data.update_num_computed_tokens(prompt_len)
check_used(num_blocks(prompt_len))

# this is how we compute it in BlockSpaceManagerV2.__init__
sliding_blocks = (sliding_window // block_size) + 2
# plus one block for null block
sliding_blocks += 1

# Append tokens to the sequence
for token_id in range(num_slots_to_append):
seq.append_token_id(token_id, {token_id: Logprob(0.0)})
seq.data.update_num_computed_tokens(1)
block_manager.append_slots(seq, num_lookahead_slots=0)
if prompt_len < sliding_window + 10:
check_used(0, sliding_blocks + 1)
else:
check_used(sliding_blocks, sliding_blocks + 1)
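
For a concrete instance of the bound asserted above, here is an illustrative calculation (not part of the commit) using block_size=16 and sliding_window=512, one of the parametrized combinations.

# Worked example of the block bound used in test_sliding_window
# (illustrative numbers only).
block_size = 16
sliding_window = 512
sliding_blocks = (sliding_window // block_size) + 2  # 32 + 2 = 34
sliding_blocks += 1  # plus one block for the null block -> 35
# Once the sequence has grown well past the window, the number of used GPU
# blocks should stay within [sliding_blocks, sliding_blocks + 1] = [35, 36].
assert sliding_blocks == 35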
6 changes: 5 additions & 1 deletion vllm/attention/ops/prefix_prefill.py
@@ -697,6 +697,10 @@ def context_attention_fwd(q,

grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head,

# 0 means "disable"
if sliding_window is None or sliding_window <= 0:
sliding_window = 0

num_warps = 8 if Lk <= 64 else 8
if alibi_slopes is not None:
_fwd_kernel_alibi[grid](
@@ -794,7 +798,7 @@ def context_attention_fwd(q,
BLOCK_DMODEL=Lk,
BLOCK_DMODEL_PADDED=Lk_padded,
BLOCK_N=BLOCK,
- SLIDING_WINDOW=sliding_window if sliding_window is not None else 0,
+ SLIDING_WINDOW=sliding_window,
num_warps=num_warps,
num_stages=1,
)
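The effect of the normalization added above is that the Triton kernel always receives an integer, with 0 meaning the window is disabled. A minimal sketch of that contract follows; the helper function is hypothetical and not part of the kernel code.

# Hypothetical helper mirroring the normalization in context_attention_fwd.
def normalize_sliding_window(sliding_window):
    # 0 means "disable"
    if sliding_window is None or sliding_window <= 0:
        return 0
    return sliding_window

assert normalize_sliding_window(None) == 0
assert normalize_sliding_window(-1) == 0
assert normalize_sliding_window(4096) == 4096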
34 changes: 32 additions & 2 deletions vllm/core/block/block_table.py
@@ -20,6 +20,10 @@ class BlockTable:
_blocks (Optional[List[Block]], optional): An optional list of existing
blocks to initialize the BlockTable with. If not provided, an empty
BlockTable is created.
max_block_sliding_window (Optional[int], optional): The number of
blocks to keep around for each sequence. If None, all blocks
are kept (e.g., when sliding window is not used).
It should at least fit the sliding window size of the model.
Attributes:
_block_size (int): The maximum number of tokens that can be stored in a
@@ -37,13 +41,15 @@ def __init__(
block_size: int,
block_allocator: DeviceAwareBlockAllocator,
_blocks: Optional[List[Block]] = None,
max_block_sliding_window: Optional[int] = None,
):
self._block_size = block_size
self._allocator = block_allocator
if _blocks is None:
_blocks = []
self._blocks: List[Block] = _blocks

self._max_block_sliding_window = max_block_sliding_window
# Use helper method instead of directly calculating, as blocks
# may not be allocated.
self._num_full_slots = len(self._get_all_token_ids())
@@ -89,7 +95,8 @@ def allocate(self,

def append_token_ids(self,
token_ids: List[int],
- num_lookahead_slots: int = 0) -> None:
+ num_lookahead_slots: int = 0,
+ num_computed_slots: Optional[int] = None) -> None:
"""Appends a sequence of token IDs to the existing blocks in the
BlockTable.
@@ -104,13 +111,35 @@ def append_token_ids(self,
Args:
token_ids (List[int]): The sequence of token IDs to be appended.
num_computed_slots (Optional[int]): The number of KV cache slots
that are already filled (computed).
When sliding window is enabled, this is used to compute how many
blocks to drop at the front of the sequence.
Without sliding window, None can be passed.
Without chunked prefill, it should be the same as
_num_full_slots.
"""
- assert self._is_allocated
+ assert self._is_allocated, "no blocks have been allocated"
assert len(self._blocks) > 0

# Drop blocks that are no longer needed due to sliding window
if self._max_block_sliding_window is not None:
null_block = self._allocator.allocate_or_get_null_block()
assert num_computed_slots is not None
end_block_idx = (num_computed_slots //
self._block_size) - self._max_block_sliding_window
for idx in range(0, end_block_idx):
b = self._blocks[idx]
if b is not null_block:
self._allocator.free(b)
self._blocks[idx] = null_block

# Ensure there are enough empty slots for the new tokens plus
# lookahead slots
self.ensure_num_empty_slots(num_empty_slots=len(token_ids) +
num_lookahead_slots)

# Update the blocks with the new tokens
blocks = self._blocks[self._num_full_slots // self._block_size:]
token_blocks = self._chunk_token_blocks_for_append(token_ids)

@@ -168,6 +197,7 @@ def fork(self) -> "BlockTable":
block_size=self._block_size,
block_allocator=self._allocator,
_blocks=forked_blocks,
max_block_sliding_window=self._max_block_sliding_window,
)

def free(self) -> None:
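As a rough illustration of the block-dropping arithmetic in append_token_ids, the calculation below uses assumed numbers (not from the commit): with a block size of 16 and a window budget of 34 blocks, a sequence that has computed 800 slots releases its first 16 blocks and points them at the shared null block.

# Illustrative calculation of end_block_idx in append_token_ids
# (assumed numbers, not part of this commit).
block_size = 16
max_block_sliding_window = 34   # assumed budget: roughly window tokens / block_size plus slack
num_computed_slots = 800        # KV cache slots already computed for the sequence
end_block_idx = (num_computed_slots // block_size) - max_block_sliding_window
# Blocks [0, end_block_idx) are outside the sliding window; each is freed and
# replaced by the shared null block.
assert end_block_idx == 16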
