diff --git a/csrc/prepare_inputs/advance_step.cu b/csrc/prepare_inputs/advance_step.cu
index bd184ee22682e..c3902f4c2a163 100644
--- a/csrc/prepare_inputs/advance_step.cu
+++ b/csrc/prepare_inputs/advance_step.cu
@@ -95,6 +95,16 @@ __global__ void advance_step_flashinfer_kernel(
     long* input_positions_ptr, int* seq_lens_ptr, long* slot_mapping_ptr,
     int const* block_tables_ptr, int64_t const block_tables_stride,
     int* paged_kv_last_page_len_ptr, int* block_table_bound_ptr) {
+  int const n_pad = num_seqs - num_queries;
+  if (n_pad && blockIdx.x == 0) {
+    // Handle cuda graph padding
+    int const offset = num_queries;
+    for (int i = threadIdx.x; i < n_pad; i += blockDim.x) {
+      input_tokens_ptr[offset + i] = 0;
+      input_positions_ptr[offset + i] = 0;
+      slot_mapping_ptr[offset + i] = -1;
+    }
+  }
   int num_query_blocks = div_ceil(num_queries, num_threads);

   if (blockIdx.x < num_query_blocks) {
diff --git a/tests/multi_step/test_correctness_llm.py b/tests/multi_step/test_correctness_llm.py
index cc1fd19252019..34030d9d6ac60 100644
--- a/tests/multi_step/test_correctness_llm.py
+++ b/tests/multi_step/test_correctness_llm.py
@@ -5,6 +5,8 @@

 import pytest

+from tests.kernels.utils import override_backend_env_variable
+
 from ..models.utils import check_logprobs_close, check_outputs_equal

 MODELS = [
@@ -19,10 +21,11 @@
 @pytest.mark.parametrize("tp_size", [1])
 @pytest.mark.parametrize("enable_chunked_prefill", [False, True])
 @pytest.mark.parametrize("max_tokens", [5])
-@pytest.mark.parametrize("enforce_eager", [True])
+@pytest.mark.parametrize("enforce_eager", [True, False])
 @pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
 @pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
 @pytest.mark.parametrize("num_logprobs", [None, 5])
+@pytest.mark.parametrize("attention_backend", ["FLASH_ATTN", "FLASHINFER"])
 def test_multi_step_llm(
     hf_runner,
     vllm_runner,
@@ -36,6 +39,8 @@ def test_multi_step_llm(
     num_scheduler_steps: int,
     num_prompts: int,
     num_logprobs: Optional[int],
+    attention_backend: str,
+    monkeypatch,
 ) -> None:
     """Test vLLM engine with multi-step scheduling via sync LLM Engine.

@@ -63,6 +68,7 @@ def test_multi_step_llm(
       num_logprobs: corresponds to the `logprobs` argument to the OpenAI
         completions endpoint; `None` -> 1 logprob returned.
     """
+    override_backend_env_variable(monkeypatch, attention_backend)

     prompts = example_prompts
     if len(prompts) < num_prompts:
@@ -114,6 +120,7 @@ def test_multi_step_llm(
 @pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
 @pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
 @pytest.mark.parametrize("num_logprobs,num_prompt_logprobs", [(5, 5)])
+@pytest.mark.parametrize("attention_backend", ["FLASH_ATTN"])
 def test_multi_step_llm_w_prompt_logprobs(
     vllm_runner,
     example_prompts,
@@ -126,6 +133,8 @@ def test_multi_step_llm_w_prompt_logprobs(
     num_prompts: int,
     num_logprobs: Optional[int],
     num_prompt_logprobs: Optional[int],
+    attention_backend: str,
+    monkeypatch,
 ) -> None:
     """Test prompt logprobs with multi-step scheduling via sync LLM Engine.

@@ -155,6 +164,7 @@ def test_multi_step_llm_w_prompt_logprobs(
         note that this argument is not supported by the
         OpenAI completions endpoint.
""" + override_backend_env_variable(monkeypatch, attention_backend) prompts = example_prompts if len(prompts) < num_prompts: @@ -205,6 +215,7 @@ def test_multi_step_llm_w_prompt_logprobs( @pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS) @pytest.mark.parametrize("num_prompts", NUM_PROMPTS) @pytest.mark.parametrize("num_logprobs", [None, 5]) +@pytest.mark.parametrize("attention_backend", ["FLASH_ATTN"]) def test_multi_step_llm_chunked_prefill_prefix_cache( vllm_runner, example_prompts, @@ -216,6 +227,8 @@ def test_multi_step_llm_chunked_prefill_prefix_cache( num_scheduler_steps: int, num_prompts: int, num_logprobs: Optional[int], + attention_backend: str, + monkeypatch, ) -> None: """Test vLLM engine with multi-step+"single-step chunked prefill"+APC. @@ -278,6 +291,8 @@ def test_multi_step_llm_chunked_prefill_prefix_cache( # # The Incorrect scheduling behavior - if it occurs - will cause an exception # in the model runner resulting from `do_sample=False`. + override_backend_env_variable(monkeypatch, attention_backend) + assert len(example_prompts) >= 2 challenge_prompts = copy.deepcopy(example_prompts) challenge_prompts[0] = ('vLLM is a high-throughput and memory-efficient ' diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index a11462b2068a5..6ca75fabdfc38 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -256,7 +256,12 @@ def prepare_graph_input_buffers(self, def begin_forward(self, model_input): assert not self._is_graph_capturing state = self - if model_input.attn_metadata.use_cuda_graph: + use_cuda_graph = model_input.attn_metadata.use_cuda_graph + is_decode = model_input.attn_metadata.num_prefills == 0 + # In case of multistep chunked-prefill, there might be prefill requests + # scheduled while CUDA graph mode is enabled. We don't run graph in that + # case. + if use_cuda_graph and is_decode: batch_size = model_input.input_tokens.shape[0] state = (self.runner.graph_runners[model_input.virtual_engine] [batch_size].attn_state) @@ -429,10 +434,24 @@ def advance_step(self, Update metadata in-place to advance one decode step. """ - assert not turn_prefills_into_decodes, \ - ("Chunked prefill is not supported with flashinfer yet." - "turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill " - "specific parameter.") + if turn_prefills_into_decodes: + # When Multi-Step is enabled with Chunked-Prefill, prefills and + # decodes are scheduled together. In the first step, all the + # prefills turn into decodes. This update reflects that + # conversion. + assert self.num_decode_tokens + self.num_prefills == num_seqs + # Flashinfer doesn't support speculative decoding + chunked-prefill + # + multi-step scheduling yet. 
+            assert self.decode_query_len == 1
+            self.num_decode_tokens += self.num_prefills
+            self.num_prefills = 0
+            self.num_prefill_tokens = 0
+            self.max_prefill_seq_len = 0
+            self.max_query_len = 1
+
+            self.slot_mapping = self.slot_mapping[:num_seqs]
+        else:
+            assert self.seq_lens_tensor is not None

         assert num_seqs > 0
         assert num_queries > 0
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 2b918483d3675..ae8b7f97c827d 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -5,6 +5,7 @@
 import time
 import warnings
 import weakref
+from contextlib import contextmanager
 from dataclasses import dataclass
 from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set,
                     Tuple, Type, TypeVar, Union)
@@ -1028,6 +1029,8 @@ def __init__(

         self.has_inner_state = model_config.has_inner_state

+        self.in_profile_run = False
+
         # When using CUDA graph, the input block tables must be padded to
         # max_seq_len_to_capture. However, creating the block table in
         # Python can be expensive. To optimize this, we cache the block table
@@ -1228,110 +1231,123 @@ def _prepare_model_input_tensors(

         return builder.build()  # type: ignore

+    @contextmanager
+    def set_in_profile_run(self):
+        self.in_profile_run = True
+        try:
+            yield
+        finally:
+            self.in_profile_run = False
+
     @torch.inference_mode()
     def profile_run(self) -> None:
-        # Enable top-k sampling to reflect the accurate memory usage.
-        sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
-        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
-        max_num_seqs = self.scheduler_config.max_num_seqs
-        # This represents the maximum number of different requests
-        # that will have unique loras, an therefore the max amount of memory
-        # consumption create dummy lora request copies from the lora request
-        # passed in, which contains a lora from the lora warmup path.
-        dummy_lora_requests: List[LoRARequest] = []
-        dummy_lora_requests_per_seq: List[LoRARequest] = []
-        if self.lora_config:
-            assert self.lora_manager is not None
-            with self.lora_manager.dummy_lora_cache():
-                for idx in range(self.lora_config.max_loras):
-                    lora_id = idx + 1
-                    dummy_lora_request = LoRARequest(
-                        lora_name=f"warmup_{lora_id}",
-                        lora_int_id=lora_id,
-                        lora_path="/not/a/real/path",
-                    )
-                    self.lora_manager.add_dummy_lora(dummy_lora_request,
-                                                     rank=LORA_WARMUP_RANK)
-                    dummy_lora_requests.append(dummy_lora_request)
-                dummy_lora_requests_per_seq = [
-                    dummy_lora_requests[idx % len(dummy_lora_requests)]
-                    for idx in range(max_num_seqs)
-                ]
-
-        # Profile memory usage with max_num_sequences sequences and the total
-        # number of tokens equal to max_num_batched_tokens.
-        seqs: List[SequenceGroupMetadata] = []
-        # Additional GPU memory may be needed for multi-modal encoding, which
-        # needs to be accounted for when calculating the GPU blocks for
-        # vLLM blocker manager.
-        # To exercise the worst scenario for GPU memory consumption,
-        # the number of seqs (batch_size) is chosen to maximize the number
-        # of images processed.
-
-        max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
-            self.model_config)
-        if max_mm_tokens > 0:
-            max_num_seqs_orig = max_num_seqs
-            max_num_seqs = min(max_num_seqs,
-                               max_num_batched_tokens // max_mm_tokens)
-            if max_num_seqs < 1:
-                expr = (f"min({max_num_seqs_orig}, "
-                        f"{max_num_batched_tokens} // {max_mm_tokens})")
-                logger.warning(
-                    "Computed max_num_seqs (%s) to be less than 1. "
" - "Setting it to the minimum value of 1.", expr) - max_num_seqs = 1 - - batch_size = 0 - for group_id in range(max_num_seqs): - seq_len = (max_num_batched_tokens // max_num_seqs + - (group_id < max_num_batched_tokens % max_num_seqs)) - batch_size += seq_len - - dummy_data = self.input_registry \ - .dummy_data_for_profiling(self.model_config, - seq_len, - self.mm_registry) - - seq = SequenceGroupMetadata( - request_id=str(group_id), - is_prompt=True, - seq_data={group_id: dummy_data.seq_data}, - sampling_params=sampling_params, - block_tables=None, - lora_request=dummy_lora_requests_per_seq[group_id] - if dummy_lora_requests_per_seq else None, - multi_modal_data=dummy_data.multi_modal_data, - multi_modal_placeholders=dummy_data.multi_modal_placeholders, - ) - seqs.append(seq) - - # Run the model with the dummy inputs. - num_layers = self.model_config.get_num_layers(self.parallel_config) - # use an empty tensor instead of `None`` to force Dynamo to pass - # it by reference, rather by specializing on the value ``None``. - # the `dtype` argument does not matter, and we use `float32` as - # a placeholder (it has wide hardware support). - # it is important to create tensors inside the loop, rather than - # multiplying the list, to avoid Dynamo from treating them as - # tensor aliasing. - kv_caches = [ - torch.tensor([], dtype=torch.float32, device=self.device) - for _ in range(num_layers) - ] - finished_requests_ids = [seq.request_id for seq in seqs] - model_input = self.prepare_model_input( - seqs, finished_requests_ids=finished_requests_ids) - intermediate_tensors = None - if not get_pp_group().is_first_rank: - intermediate_tensors = self.model.make_empty_intermediate_tensors( - batch_size=batch_size, - dtype=self.model_config.dtype, - device=self.device) - - self.execute_model(model_input, kv_caches, intermediate_tensors) - torch.cuda.synchronize() - return + with self.set_in_profile_run(): + # Enable top-k sampling to reflect the accurate memory usage. + sampling_params = \ + SamplingParams(top_p=0.99, top_k=self.vocab_size - 1) + max_num_batched_tokens = \ + self.scheduler_config.max_num_batched_tokens + max_num_seqs = self.scheduler_config.max_num_seqs + # This represents the maximum number of different requests + # that will have unique loras, an therefore the max amount of memory + # consumption create dummy lora request copies from the lora request + # passed in, which contains a lora from the lora warmup path. + dummy_lora_requests: List[LoRARequest] = [] + dummy_lora_requests_per_seq: List[LoRARequest] = [] + if self.lora_config: + assert self.lora_manager is not None + with self.lora_manager.dummy_lora_cache(): + for idx in range(self.lora_config.max_loras): + lora_id = idx + 1 + dummy_lora_request = LoRARequest( + lora_name=f"warmup_{lora_id}", + lora_int_id=lora_id, + lora_path="/not/a/real/path", + ) + self.lora_manager.add_dummy_lora(dummy_lora_request, + rank=LORA_WARMUP_RANK) + dummy_lora_requests.append(dummy_lora_request) + dummy_lora_requests_per_seq = [ + dummy_lora_requests[idx % len(dummy_lora_requests)] + for idx in range(max_num_seqs) + ] + + # Profile memory usage with max_num_sequences sequences and the + # total number of tokens equal to max_num_batched_tokens. + seqs: List[SequenceGroupMetadata] = [] + # Additional GPU memory may be needed for multi-modal encoding, + # which needs to be accounted for when calculating the GPU blocks + # for vLLM blocker manager. 
+            # To exercise the worst scenario for GPU memory consumption,
+            # the number of seqs (batch_size) is chosen to maximize the number
+            # of images processed.
+
+            max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
+                self.model_config)
+            if max_mm_tokens > 0:
+                max_num_seqs_orig = max_num_seqs
+                max_num_seqs = min(max_num_seqs,
+                                   max_num_batched_tokens // max_mm_tokens)
+                if max_num_seqs < 1:
+                    expr = (f"min({max_num_seqs_orig}, "
+                            f"{max_num_batched_tokens} // {max_mm_tokens})")
+                    logger.warning(
+                        "Computed max_num_seqs (%s) to be less than 1. "
+                        "Setting it to the minimum value of 1.", expr)
+                    max_num_seqs = 1
+
+            batch_size = 0
+            for group_id in range(max_num_seqs):
+                seq_len = (max_num_batched_tokens // max_num_seqs +
+                           (group_id < max_num_batched_tokens % max_num_seqs))
+                batch_size += seq_len
+
+                dummy_data = self.input_registry \
+                    .dummy_data_for_profiling(self.model_config,
+                                              seq_len,
+                                              self.mm_registry)
+
+                seq = SequenceGroupMetadata(
+                    request_id=str(group_id),
+                    is_prompt=True,
+                    seq_data={group_id: dummy_data.seq_data},
+                    sampling_params=sampling_params,
+                    block_tables=None,
+                    lora_request=dummy_lora_requests_per_seq[group_id]
+                    if dummy_lora_requests_per_seq else None,
+                    multi_modal_data=dummy_data.multi_modal_data,
+                    multi_modal_placeholders=dummy_data.
+                    multi_modal_placeholders,
+                )
+                seqs.append(seq)
+
+            # Run the model with the dummy inputs.
+            num_layers = self.model_config.get_num_layers(self.parallel_config)
+            # use an empty tensor instead of `None`` to force Dynamo to pass
+            # it by reference, rather by specializing on the value ``None``.
+            # the `dtype` argument does not matter, and we use `float32` as
+            # a placeholder (it has wide hardware support).
+            # it is important to create tensors inside the loop, rather than
+            # multiplying the list, to avoid Dynamo from treating them as
+            # tensor aliasing.
+            kv_caches = [
+                torch.tensor([], dtype=torch.float32, device=self.device)
+                for _ in range(num_layers)
+            ]
+            finished_requests_ids = [seq.request_id for seq in seqs]
+            model_input = self.prepare_model_input(
+                seqs, finished_requests_ids=finished_requests_ids)
+            intermediate_tensors = None
+            if not get_pp_group().is_first_rank:
+                intermediate_tensors = \
+                    self.model.make_empty_intermediate_tensors(
+                        batch_size=batch_size,
+                        dtype=self.model_config.dtype,
+                        device=self.device)
+
+            self.execute_model(model_input, kv_caches, intermediate_tensors)
+            torch.cuda.synchronize()
+            return

     def remove_all_loras(self):
         if not self.lora_manager:
diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py
index acce923498d7e..4aab09c80826b 100644
--- a/vllm/worker/multi_step_model_runner.py
+++ b/vllm/worker/multi_step_model_runner.py
@@ -32,7 +32,7 @@
 MULTI_STEP_ATTENTION_BACKENDS = [
     "FLASH_ATTN", "ROCM_FLASH", "FLASHINFER", "NO_ATTENTION"
 ]
-MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS = ["FLASH_ATTN"]
+MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS = ["FLASH_ATTN", "FLASHINFER"]

 def _get_supported_attention_backends(chunked_prefill_enabled: bool) \
         -> List[str]: