diff --git a/tests/worker/test_model_input.py b/tests/worker/test_model_input.py index a57fdac803e42..1e7f560fc68cc 100644 --- a/tests/worker/test_model_input.py +++ b/tests/worker/test_model_input.py @@ -5,6 +5,7 @@ from vllm.attention import AttentionMetadata, AttentionMetadataBuilder from vllm.attention.backends.abstract import AttentionBackend +from vllm.attention.backends.utils import CommonAttentionState from vllm.model_executor import SamplingMetadata from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.worker.embedding_model_runner import ( @@ -29,7 +30,11 @@ def get_metadata_cls() -> Type["AttentionMetadata"]: @staticmethod def get_builder_cls() -> Type["AttentionMetadataBuilder"]: - raise AttentionMetadataBuilder + return AttentionMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState @staticmethod def get_kv_cache_shape( diff --git a/vllm/attention/__init__.py b/vllm/attention/__init__.py index 4643d316d48b7..2cd4ad3e00135 100644 --- a/vllm/attention/__init__.py +++ b/vllm/attention/__init__.py @@ -1,7 +1,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionMetadata, AttentionMetadataBuilder, - AttentionType) + AttentionState, AttentionType) from vllm.attention.layer import Attention from vllm.attention.selector import get_attn_backend @@ -12,5 +12,6 @@ "AttentionType", "AttentionMetadataBuilder", "Attention", + "AttentionState", "get_attn_backend", ] diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 23c7830cd6264..ccfc6b254c1e7 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from contextlib import contextmanager from dataclasses import dataclass, fields from enum import Enum, auto from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set, @@ -7,7 +8,9 @@ import torch if TYPE_CHECKING: - from vllm.worker.model_runner_base import ModelRunnerInputBuilderBase + from vllm.worker.model_runner_base import (ModelRunnerBase, + ModelRunnerInputBase, + ModelRunnerInputBuilderBase) class AttentionType(Enum): @@ -34,6 +37,11 @@ def get_impl_cls() -> Type["AttentionImpl"]: def get_metadata_cls() -> Type["AttentionMetadata"]: raise NotImplementedError + @staticmethod + @abstractmethod + def get_state_cls() -> Type["AttentionState"]: + raise NotImplementedError + @classmethod def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata": return cls.get_metadata_cls()(*args, **kwargs) @@ -126,6 +134,47 @@ def asdict_zerocopy(self, T = TypeVar("T", bound=AttentionMetadata) +class AttentionState(ABC, Generic[T]): + """Holds attention backend-specific objects reused during the + lifetime of the model runner.""" + + @abstractmethod + def __init__(self, runner: "ModelRunnerBase"): + ... + + @abstractmethod + @contextmanager + def graph_capture(self, max_batch_size: int): + """Context manager used when capturing CUDA graphs.""" + yield + + @abstractmethod + def graph_clone(self, batch_size: int) -> "AttentionState[T]": + """Clone attention state to save in CUDA graph metadata.""" + ... + + @abstractmethod + def graph_capture_get_metadata_for_batch(self, batch_size: int) -> T: + """Get attention metadata for CUDA graph capture of batch_size.""" + ... + + @abstractmethod + def get_graph_input_buffers(self, attn_metadata: T) -> Dict[str, Any]: + """Get attention-specific input buffers for CUDA graph capture.""" + ... 
+ + @abstractmethod + def prepare_graph_input_buffers(self, input_buffers: Dict[str, Any], + attn_metadata: T) -> None: + """In-place modify input buffers dict for CUDA graph replay.""" + ... + + @abstractmethod + def begin_forward(self, model_input: "ModelRunnerInputBase") -> None: + """Prepare state for forward pass.""" + ... + + class AttentionMetadataBuilder(ABC, Generic[T]): """Abstract class for attention metadata builders.""" diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index 907b45393eeb5..d84a40890ebbd 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -5,7 +5,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) -from vllm.attention.backends.utils import CommonMetadataBuilder +from vllm.attention.backends.utils import (CommonAttentionState, + CommonMetadataBuilder) from vllm.attention.ops.blocksparse_attention.interface import ( LocalStridedBlockSparseAttn, get_head_sliding_step) from vllm.attention.ops.paged_attn import PagedAttention @@ -98,6 +99,10 @@ def get_metadata_cls() -> Type["AttentionMetadata"]: def get_builder_cls() -> Type["BlocksparseFlashAttentionMetadataBuilder"]: return BlocksparseFlashAttentionMetadataBuilder + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + @staticmethod def get_kv_cache_shape( num_blocks: int, diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index f146285bfc9e2..30ce715d5d05a 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -9,7 +9,8 @@ AttentionMetadata, AttentionMetadataBuilder, AttentionType) -from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, +from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState, + compute_slot_mapping, compute_slot_mapping_start_idx, is_block_tables_empty) from vllm.utils import async_tensor_h2d, make_tensor_with_pad @@ -142,6 +143,10 @@ def get_metadata_cls() -> Type["AttentionMetadata"]: def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]: return FlashAttentionMetadataBuilder + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + @staticmethod def get_kv_cache_shape( num_blocks: int, diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 3022fa70e2ca7..2aa3bd79e4a64 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -1,14 +1,19 @@ +from contextlib import contextmanager from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type try: from flashinfer import BatchDecodeWithPagedKVCacheWrapper + from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper import vllm.attention.backends.flash_attn # noqa + FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 except ImportError: BatchDecodeWithPagedKVCacheWrapper = None + CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None BatchPrefillWithPagedKVCacheWrapper = None + FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 import torch @@ -16,7 +21,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionMetadataBuilder, - AttentionType) + AttentionState, AttentionType) from 
vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, compute_slot_mapping_start_idx, is_block_tables_empty) @@ -46,6 +51,10 @@ def get_metadata_cls() -> Type["AttentionMetadata"]: def get_builder_cls() -> Type["FlashInferMetadataBuilder"]: return FlashInferMetadataBuilder + @staticmethod + def get_state_cls() -> Type["FlashInferState"]: + return FlashInferState + @staticmethod def get_kv_cache_shape( num_blocks: int, @@ -75,6 +84,160 @@ def get_supported_head_sizes() -> List[int]: return [64, 128, 256] +class FlashInferState(AttentionState): + + def __init__(self, runner): + self.runner = runner + self._is_graph_capturing = False + self._workspace_buffer = None + self._decode_wrapper = None + self._prefill_wrapper = None + + def _get_workspace_buffer(self): + if self._workspace_buffer is None: + self._workspace_buffer = torch.empty( + FLASHINFER_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device=self.runner.device) + return self._workspace_buffer + + def _get_prefill_wrapper(self): + if self._prefill_wrapper is None: + self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( + self._get_workspace_buffer(), "NHD") + return self._prefill_wrapper + + def _get_decode_wrapper(self): + if self._decode_wrapper is None: + num_qo_heads = (self.runner.model_config.get_num_attention_heads( + self.runner.parallel_config)) + num_kv_heads = self.runner.model_config.get_num_kv_heads( + self.runner.parallel_config) + use_tensor_cores = num_qo_heads // num_kv_heads >= 4 + self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( + self._get_workspace_buffer(), + "NHD", + use_tensor_cores=use_tensor_cores) + return self._decode_wrapper + + @contextmanager + def graph_capture(self, max_batch_size: int): + self._is_graph_capturing = True + self._graph_decode_wrapper = None + self._graph_slot_mapping = torch.full((max_batch_size, ), + PAD_SLOT_ID, + dtype=torch.long, + device=self.runner.device) + self._graph_seq_lens = torch.ones(max_batch_size, + dtype=torch.int32, + device=self.runner.device) + self._graph_block_tables = torch.from_numpy( + self.runner.graph_block_tables).to(device=self.runner.device) + self._graph_decode_workspace_buffer = self._get_workspace_buffer() + self._graph_indices_buffer = torch.empty( + max_batch_size * self.runner.cache_config.num_gpu_blocks, + dtype=torch.int32, + device=self.runner.device) + self._graph_indptr_buffer = torch.empty(max_batch_size + 1, + dtype=torch.int32, + device=self.runner.device) + self._graph_last_page_len_buffer = torch.empty( + max_batch_size, dtype=torch.int32, device=self.runner.device) + yield + self._is_graph_capturing = False + del self._graph_slot_mapping + del self._graph_seq_lens + del self._graph_block_tables + del self._graph_decode_workspace_buffer + del self._graph_indices_buffer + del self._graph_indptr_buffer + del self._graph_last_page_len_buffer + del self._graph_decode_wrapper + + def graph_clone(self, batch_size: int): + assert self._is_graph_capturing + state = self.__class__(self.runner) + state._workspace_buffer = self._graph_decode_workspace_buffer + state._decode_wrapper = self._graph_decode_wrapper + state._prefill_wrapper = self._get_prefill_wrapper() + return state + + def graph_capture_get_metadata_for_batch(self, batch_size: int): + assert self._is_graph_capturing + _indptr_buffer = self._graph_indptr_buffer[:batch_size + 1] + _last_page_len_buffer = self._graph_last_page_len_buffer[:batch_size] + + num_qo_heads = (self.runner.model_config.get_num_attention_heads( + self.runner.parallel_config)) + 
num_kv_heads = self.runner.model_config.get_num_kv_heads( + self.runner.parallel_config) + use_tensor_cores = num_qo_heads // num_kv_heads >= 4 + self._graph_decode_wrapper = \ + CUDAGraphBatchDecodeWithPagedKVCacheWrapper( + self._graph_decode_workspace_buffer, _indptr_buffer, + self._graph_indices_buffer, _last_page_len_buffer, "NHD", + use_tensor_cores) + kv_cache_dtype = get_kv_cache_torch_dtype( + self.runner.kv_cache_dtype, self.runner.model_config.dtype) + + paged_kv_indptr_tensor_host = torch.arange(0, + batch_size + 1, + dtype=torch.int32) + paged_kv_indices_tensor_host = torch.arange(0, + batch_size, + dtype=torch.int32) + paged_kv_last_page_len_tensor_host = torch.full((batch_size, ), + self.runner.block_size, + dtype=torch.int32) + query_start_loc_host = torch.arange(0, + batch_size + 1, + dtype=torch.int32) + + attn_metadata = self.runner.attn_backend.make_metadata( + num_prefills=0, + slot_mapping=self._graph_slot_mapping[:batch_size], + num_prefill_tokens=0, + num_decode_tokens=batch_size, + max_prefill_seq_len=0, + block_tables=self._graph_block_tables, + paged_kv_indptr=paged_kv_indptr_tensor_host, + paged_kv_indices=paged_kv_indices_tensor_host, + paged_kv_last_page_len=paged_kv_last_page_len_tensor_host, + num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + head_dim=self.runner.model_config.get_head_size(), + page_size=self.runner.block_size, + seq_start_loc=None, + query_start_loc=query_start_loc_host, + device=self.runner.device, + data_type=kv_cache_dtype, + use_cuda_graph=True, + decode_wrapper=self._graph_decode_wrapper, + prefill_wrapper=None) + attn_metadata.begin_forward() + return attn_metadata + + def get_graph_input_buffers(self, attn_metadata): + return { + "slot_mapping": attn_metadata.slot_mapping, + } + + def prepare_graph_input_buffers(self, input_buffers, attn_metadata): + return + + def begin_forward(self, model_input): + assert not self._is_graph_capturing + state = self + if model_input.attn_metadata.use_cuda_graph: + batch_size = model_input.input_tokens.shape[0] + state = (self.runner.graph_runners[model_input.virtual_engine] + [batch_size].attn_state) + model_input.attn_metadata.prefill_wrapper = state._get_prefill_wrapper( + ) + model_input.attn_metadata.decode_wrapper = state._get_decode_wrapper() + model_input.attn_metadata.begin_forward() + + @dataclass class FlashInferMetadata(AttentionMetadata): # Maximum sequence length among prefill batch. 
0 if there are decoding diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py index bac30aec24826..64d60e4e47e48 100644 --- a/vllm/attention/backends/ipex_attn.py +++ b/vllm/attention/backends/ipex_attn.py @@ -8,6 +8,7 @@ from vllm._ipex_ops import ipex_ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) +from vllm.attention.backends.utils import CommonAttentionState from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) @@ -28,6 +29,10 @@ def get_impl_cls() -> Type["IpexAttnBackendImpl"]: def get_metadata_cls() -> Type["IpexAttnMetadata"]: return IpexAttnMetadata + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + @staticmethod def get_kv_cache_shape( num_blocks: int, diff --git a/vllm/attention/backends/openvino.py b/vllm/attention/backends/openvino.py index 0f21b50ad4dc7..7992c70f52659 100644 --- a/vllm/attention/backends/openvino.py +++ b/vllm/attention/backends/openvino.py @@ -1,11 +1,12 @@ from dataclasses import dataclass -from typing import List, Tuple +from typing import List, Tuple, Type import openvino as ov import torch from vllm.attention.backends.abstract import (AttentionBackend, AttentionMetadata) +from vllm.attention.backends.utils import CommonAttentionState class OpenVINOAttentionBackend(AttentionBackend): @@ -24,6 +25,10 @@ def get_impl_cls(): def make_metadata(*args, **kwargs) -> "AttentionMetadata": raise NotImplementedError + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + @staticmethod def make_openvino_metadata(*args, **kwargs) -> "OpenVINOAttentionMetadata": return OpenVINOAttentionMetadata(*args, **kwargs) diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index 4ecf698c8d514..ac03b6d8b1ead 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -6,6 +6,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) +from vllm.attention.backends.utils import CommonAttentionState class PallasAttentionBackend(AttentionBackend): @@ -18,6 +19,10 @@ def get_impl_cls() -> Type["PallasAttentionBackendImpl"]: def get_metadata_cls() -> Type["PallasMetadata"]: return PallasMetadata + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + @staticmethod def get_kv_cache_shape( num_blocks: int, diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index e305679231d02..b0f4d0530b7f0 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -7,7 +7,8 @@ import vllm.envs as envs from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) -from vllm.attention.backends.utils import CommonMetadataBuilder +from vllm.attention.backends.utils import (CommonAttentionState, + CommonMetadataBuilder) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger @@ -33,6 +34,10 @@ def get_metadata_cls() -> Type["AttentionMetadata"]: def get_builder_cls() -> Type["ROCmFlashAttentionMetadataBuilder"]: return ROCmFlashAttentionMetadataBuilder + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + @staticmethod def get_kv_cache_shape( num_blocks: int, 
diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index b83c673f0165e..8a1f8f2930c84 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -8,6 +8,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) +from vllm.attention.backends.utils import CommonAttentionState from vllm.attention.ops.paged_attn import PagedAttentionMetadata from vllm.utils import is_cpu @@ -34,6 +35,10 @@ def get_impl_cls() -> Type["TorchSDPABackendImpl"]: def get_metadata_cls() -> Type["AttentionMetadata"]: return TorchSDPAMetadata + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + @staticmethod def get_kv_cache_shape( num_blocks: int, diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index e6b5f820c5fa0..0375d3488eb15 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -1,12 +1,17 @@ """Attention backend utils""" -from typing import TYPE_CHECKING, Dict, List, Type, TypeVar, Union +from contextlib import contextmanager +from typing import TYPE_CHECKING, Any, Dict, List, Type, TypeVar, Union import numpy as np import torch -from vllm.attention import AttentionMetadata, AttentionMetadataBuilder +from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder, + AttentionState) from vllm.utils import async_tensor_h2d, make_tensor_with_pad +if TYPE_CHECKING: + from vllm.worker.model_runner_base import ModelRunnerBase + # Error string(s) for encoder/decoder # unsupported attention scenarios STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported " @@ -269,3 +274,69 @@ def build(self, seq_lens: List[int], query_lens: List[int], block_tables=block_tables, use_cuda_graph=use_captured_graph, ) + + +class CommonAttentionState(AttentionState): + + def __init__(self, runner: "ModelRunnerBase"): + self.runner = runner + self._is_graph_capturing = False + + @contextmanager + def graph_capture(self, max_batch_size: int): + self._is_graph_capturing = True + self._graph_slot_mapping = torch.full((max_batch_size, ), + PAD_SLOT_ID, + dtype=torch.long, + device=self.runner.device) + self._graph_seq_lens = torch.ones(max_batch_size, + dtype=torch.int32, + device=self.runner.device) + self._graph_block_tables = torch.from_numpy( + self.runner.graph_block_tables).to(device=self.runner.device) + yield + self._is_graph_capturing = False + del self._graph_slot_mapping + del self._graph_seq_lens + del self._graph_block_tables + + def graph_clone(self, batch_size: int) -> "CommonAttentionState": + assert self._is_graph_capturing + return self.__class__(self.runner) + + def graph_capture_get_metadata_for_batch(self, batch_size: int): + assert self._is_graph_capturing + attn_metadata = self.runner.attn_backend.make_metadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=batch_size, + slot_mapping=self._graph_slot_mapping[:batch_size], + seq_lens=None, + seq_lens_tensor=self._graph_seq_lens[:batch_size], + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.runner.max_seq_len_to_capture, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self._graph_block_tables[:batch_size], + use_cuda_graph=True, + ) + return attn_metadata + + def get_graph_input_buffers(self, attn_metadata) -> Dict[str, Any]: + return { + "slot_mapping": attn_metadata.slot_mapping, + "seq_lens_tensor": 
attn_metadata.decode_metadata.seq_lens_tensor, + "block_tables": attn_metadata.decode_metadata.block_tables, + } + + def prepare_graph_input_buffers(self, input_buffers, + attn_metadata) -> None: + input_buffers["seq_lens_tensor"].copy_( + attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True) + input_buffers["block_tables"].copy_( + attn_metadata.decode_metadata.block_tables, non_blocking=True) + + def begin_forward(self, model_input) -> None: + return diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 7e36509bff864..e073d616bf01d 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -11,7 +11,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) -from vllm.attention.backends.utils import CommonMetadataBuilder +from vllm.attention.backends.utils import (CommonAttentionState, + CommonMetadataBuilder) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger @@ -37,6 +38,10 @@ def get_metadata_cls() -> Type["AttentionMetadata"]: def get_builder_cls() -> Type["XFormersMetadataBuilder"]: return XFormersMetadataBuilder + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + @staticmethod def get_kv_cache_shape( num_blocks: int, diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index 1bb3b83744fec..053e9203e01eb 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -11,17 +11,6 @@ from vllm.attention.backends.rocm_flash_attn import ( ROCmFlashAttentionMetadata as FlashAttentionMetadata) -try: - from flashinfer import BatchDecodeWithPagedKVCacheWrapper - from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper - from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper - FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 -except ImportError: - BatchDecodeWithPagedKVCacheWrapper = None - CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None - BatchPrefillWithPagedKVCacheWrapper = None - FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 - from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ObservabilityConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig) @@ -90,11 +79,6 @@ def __init__( observability_config=observability_config, ) - self.flashinfer_decode_workspace_buffer = None - self.flashinfer_decode_wrapper = None - self.flashinfer_prefill_workspace_buffer = None - self.flashinfer_prefill_wrapper = None - def _update_sampling_metadata(self, sampling_metadata, num_seqs, num_queries): @@ -270,36 +254,7 @@ def execute_model( model_input.prompt_adapter_requests, model_input.prompt_adapter_mapping) - if self.attn_backend.get_name() == "flashinfer": - assert model_input.attn_metadata is not None - assert model_input.input_tokens is not None - if self.flashinfer_decode_workspace_buffer is None: - self.flashinfer_decode_workspace_buffer = torch.empty( - FLASHINFER_WORKSPACE_BUFFER_SIZE, - dtype=torch.uint8, - device=self.device) - self.flashinfer_decode_wrapper = \ - BatchDecodeWithPagedKVCacheWrapper( - self.flashinfer_decode_workspace_buffer, "NHD") - self.flashinfer_prefill_workspace_buffer = torch.empty( - FLASHINFER_WORKSPACE_BUFFER_SIZE, - dtype=torch.uint8, - device=self.device) - self.flashinfer_prefill_wrapper = \ - BatchPrefillWithPagedKVCacheWrapper( - 
self.flashinfer_prefill_workspace_buffer, "NHD") - - model_input.attn_metadata.prefill_wrapper = \ - self.flashinfer_prefill_wrapper - if model_input.attn_metadata.use_cuda_graph: - batch_size = model_input.input_tokens.shape[0] - model_input.attn_metadata.decode_wrapper = \ - self.graph_runners[model_input. - virtual_engine][batch_size].flashinfer_decode_wrapper - else: - model_input.attn_metadata.decode_wrapper = \ - self.flashinfer_decode_wrapper - model_input.attn_metadata.begin_forward() + self.attn_state.begin_forward(model_input) # Detect exec mode assert model_input.attn_metadata is not None diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 1afda0e45b702..5c700229660c0 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -6,6 +6,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionMetadata) +from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.attention.selector import (_Backend, get_env_variable_attn_backend, get_global_forced_attn_backend, global_force_attn_backend) @@ -20,7 +21,7 @@ from vllm.sequence import (IntermediateTensors, PoolerOutput, SamplerOutput, SequenceGroupMetadata) from vllm.utils import STR_NOT_IMPL_ENC_DEC_BACKEND, make_tensor_with_pad -from vllm.worker.model_runner import (_PAD_SLOT_ID, GPUModelRunnerBase, +from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPUBuilder, ModelInputForGPUWithSamplingMetadata) from vllm.worker.model_runner_base import ( @@ -395,7 +396,7 @@ def _prepare_encoder_model_input_tensors( # initialized yet. In this case, we just use a dummy # slot mapping. # In embeddings, the block tables are {seq_id: None}. - cross_slot_mapping.extend([_PAD_SLOT_ID] * seq_len) + cross_slot_mapping.extend([PAD_SLOT_ID] * seq_len) else: for i in range(0, seq_len): block_number = seq_group_metadata.cross_block_table[ diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 9f27c734efd1e..793f03456e997 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -13,19 +13,10 @@ import torch.distributed import torch.nn as nn -try: - from flashinfer import BatchDecodeWithPagedKVCacheWrapper - from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper - from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper - FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 -except ImportError: - BatchDecodeWithPagedKVCacheWrapper = None - CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None - BatchPrefillWithPagedKVCacheWrapper = None - FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 - import vllm.envs as envs from vllm.attention import AttentionMetadata, get_attn_backend +from vllm.attention.backends.abstract import AttentionState +from vllm.attention.backends.utils import CommonAttentionState from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ObservabilityConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig) @@ -52,8 +43,7 @@ from vllm.sequence import (IntermediateTensors, SamplerOutput, SequenceGroupMetadata) from vllm.utils import (CudaMemoryProfiler, PyObjectCache, async_tensor_h2d, - flatten_2d_lists, get_kv_cache_torch_dtype, is_hip, - is_pin_memory_available) + flatten_2d_lists, is_hip, is_pin_memory_available) from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, _add_attn_metadata_broadcastable_dict, @@ -66,7 +56,6 @@ logger = init_logger(__name__) -_PAD_SLOT_ID = -1 
LORA_WARMUP_RANK = 8 _BATCH_SIZE_ALIGNMENT = 8 # Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256. @@ -858,6 +847,11 @@ def __init__( self.kv_cache_dtype, self.block_size, ) if num_attn_heads else None + if self.attn_backend: + self.attn_state = self.attn_backend.get_state_cls()( + weakref.proxy(self)) + else: + self.attn_state = CommonAttentionState(weakref.proxy(self)) # Multi-modal data support self.input_registry = input_registry @@ -872,11 +866,6 @@ def __init__( self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None self.prompt_adapter_manager: LRUCacheWorkerPromptAdapterManager = None - self.flashinfer_decode_workspace_buffer = None - self.flashinfer_decode_wrapper = None - self.flashinfer_prefill_workspace_buffer = None - self.flashinfer_prefill_wrapper = None - set_cpu_offload_max_bytes( int(self.cache_config.cpu_offload_gb * 1024**3)) @@ -1203,10 +1192,6 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: max_batch_size = max(_BATCH_SIZES_TO_CAPTURE) input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda() input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda() - slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda() - slot_mapping.fill_(_PAD_SLOT_ID) - seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() - block_tables = torch.from_numpy(self.graph_block_tables).cuda() intermediate_inputs = None if not get_pp_group().is_first_rank: intermediate_inputs = self.model.make_empty_intermediate_tensors( @@ -1226,102 +1211,16 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size ] - if self.attn_backend.get_name() == "flashinfer": - # For flashinfer, different batch sizes will share the - # same workspace buffer. - decode_workspace_buffer = \ - torch.empty(FLASHINFER_WORKSPACE_BUFFER_SIZE, - dtype=torch.uint8, - device=self.device) - indices_buffer = torch.empty(max_batch_size * - self.cache_config.num_gpu_blocks, - dtype=torch.int32, - device=self.device) - indptr_buffer = torch.empty(max_batch_size + 1, - dtype=torch.int32, - device=self.device) - last_page_len_buffer = torch.empty(max_batch_size, - dtype=torch.int32, - device=self.device) - - with graph_capture() as graph_capture_context: + with self.attn_state.graph_capture( + max_batch_size), graph_capture() as graph_capture_context: # NOTE: Capturing the largest batch size first may help reduce the # memory usage of CUDA graph. 
for virtual_engine in range( self.parallel_config.pipeline_parallel_size): for batch_size in reversed(batch_size_capture_list): - if self.attn_backend.get_name() == "flashinfer": - _indptr_buffer = indptr_buffer[:batch_size + 1] - _last_page_len_buffer = last_page_len_buffer[: - batch_size] - - num_qo_heads = ( - self.model_config.get_num_attention_heads( - self.parallel_config)) - num_kv_heads = self.model_config.get_num_kv_heads( - self.parallel_config) - if num_qo_heads // num_kv_heads >= 4: - use_tensor_cores = True - else: - use_tensor_cores = False - decode_wrapper = \ - CUDAGraphBatchDecodeWithPagedKVCacheWrapper( - decode_workspace_buffer, _indptr_buffer, - indices_buffer, _last_page_len_buffer, "NHD", - use_tensor_cores) - kv_cache_dtype = get_kv_cache_torch_dtype( - self.kv_cache_dtype, self.model_config.dtype) - - paged_kv_indptr_tensor_host = torch.arange( - 0, batch_size + 1, dtype=torch.int32) - paged_kv_indices_tensor_host = torch.arange( - 0, batch_size, dtype=torch.int32) - paged_kv_last_page_len_tensor_host = torch.full( - (batch_size, ), self.block_size, dtype=torch.int32) - query_start_loc_host = torch.arange(0, - batch_size + 1, - dtype=torch.int32) - - attn_metadata = self.attn_backend.make_metadata( - num_prefills=0, - slot_mapping=slot_mapping[:batch_size], - num_prefill_tokens=0, - num_decode_tokens=batch_size, - max_prefill_seq_len=0, - block_tables=block_tables, - paged_kv_indptr=paged_kv_indptr_tensor_host, - paged_kv_indices=paged_kv_indices_tensor_host, - paged_kv_last_page_len= - paged_kv_last_page_len_tensor_host, - num_qo_heads=num_qo_heads, - num_kv_heads=num_kv_heads, - head_dim=self.model_config.get_head_size(), - page_size=self.block_size, - seq_start_loc=None, - query_start_loc=query_start_loc_host, - device=self.device, - data_type=kv_cache_dtype, - use_cuda_graph=True, - decode_wrapper=decode_wrapper, - prefill_wrapper=None) - attn_metadata.begin_forward() - else: - attn_metadata = self.attn_backend.make_metadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=batch_size, - slot_mapping=slot_mapping[:batch_size], - seq_lens=None, - seq_lens_tensor=seq_lens[:batch_size], - max_query_len=None, - max_prefill_seq_len=0, - max_decode_seq_len=self.max_seq_len_to_capture, - query_start_loc=None, - seq_start_loc=None, - context_lens_tensor=None, - block_tables=block_tables[:batch_size], - use_cuda_graph=True, - ) + attn_metadata = ( + self.attn_state.graph_capture_get_metadata_for_batch( + batch_size)) if self.lora_config: lora_mapping = LoRAMapping( @@ -1339,17 +1238,8 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: set(), prompt_adapter_mapping) graph_runner = CUDAGraphRunner( - self.model, self.attn_backend.get_name()) - - if self.attn_backend.get_name() == "flashinfer": - graph_runner.flashinfer_indptr_buffer = _indptr_buffer - graph_runner.flashinfer_indices_buffer = indices_buffer - graph_runner.flashinfer_last_page_len_buffer = \ - _last_page_len_buffer - graph_runner.flashinfer_decode_workspace_buffer = \ - decode_workspace_buffer - graph_runner.flashinfer_decode_wrapper = \ - decode_wrapper + self.model, self.attn_backend.get_name(), + self.attn_state.graph_clone(batch_size)) capture_inputs = { "input_ids": @@ -1476,36 +1366,7 @@ def execute_model( model_input.prompt_adapter_requests, model_input.prompt_adapter_mapping) - if self.attn_backend.get_name() == "flashinfer": - assert model_input.attn_metadata is not None - assert model_input.input_tokens is not None - if self.flashinfer_decode_workspace_buffer is 
None: - self.flashinfer_decode_workspace_buffer = torch.empty( - FLASHINFER_WORKSPACE_BUFFER_SIZE, - dtype=torch.uint8, - device=self.device) - self.flashinfer_decode_wrapper = \ - BatchDecodeWithPagedKVCacheWrapper( - self.flashinfer_decode_workspace_buffer, "NHD") - self.flashinfer_prefill_workspace_buffer = torch.empty( - FLASHINFER_WORKSPACE_BUFFER_SIZE, - dtype=torch.uint8, - device=self.device) - self.flashinfer_prefill_wrapper = \ - BatchPrefillWithPagedKVCacheWrapper( - self.flashinfer_prefill_workspace_buffer, "NHD") - - model_input.attn_metadata.prefill_wrapper = \ - self.flashinfer_prefill_wrapper - if model_input.attn_metadata.use_cuda_graph: - batch_size = model_input.input_tokens.shape[0] - model_input.attn_metadata.decode_wrapper = self.graph_runners[ - model_input. - virtual_engine][batch_size].flashinfer_decode_wrapper - else: - model_input.attn_metadata.decode_wrapper = \ - self.flashinfer_decode_wrapper - model_input.attn_metadata.begin_forward() + self.attn_state.begin_forward(model_input) # Currently cuda graph is only supported by the decode phase. assert model_input.attn_metadata is not None @@ -1613,22 +1474,17 @@ def execute_model( class CUDAGraphRunner: - def __init__(self, model: nn.Module, backend_name: str): + def __init__(self, model: nn.Module, backend_name: str, + attn_state: AttentionState): self.model = model self.backend_name = backend_name + self.attn_state = attn_state self.input_buffers: Dict[str, torch.Tensor] = {} self.output_buffers: Dict[str, torch.Tensor] = {} self._graph: Optional[torch.cuda.CUDAGraph] = None - self.flashinfer_decode_workspace_buffer: Optional[torch.Tensor] = None - self.flashinfer_indptr_buffer: Optional[torch.Tensor] = None - self.flashinfer_indices_buffer: Optional[torch.Tensor] = None - self.flashinfer_last_page_len_buffer: Optional[torch.Tensor] = None - self.flashinfer_decode_wrapper: Optional[ - CUDAGraphBatchDecodeWithPagedKVCacheWrapper] = None - @property def graph(self): assert self._graph is not None @@ -1693,25 +1549,13 @@ def capture( torch.cuda.synchronize() # Save the input and output buffers. 
- if self.backend_name == "flashinfer": - self.input_buffers = { - "input_ids": input_ids, - "positions": positions, - "kv_caches": kv_caches, - "slot_mapping": attn_metadata.slot_mapping, - **kwargs, - } - else: - self.input_buffers = { - "input_ids": input_ids, - "positions": positions, - "kv_caches": kv_caches, - "slot_mapping": attn_metadata.slot_mapping, - "seq_lens_tensor": - attn_metadata.decode_metadata.seq_lens_tensor, - "block_tables": attn_metadata.decode_metadata.block_tables, - **kwargs, - } + self.input_buffers = { + "input_ids": input_ids, + "positions": positions, + "kv_caches": kv_caches, + **self.attn_state.get_graph_input_buffers(attn_metadata), + **kwargs, + } if intermediate_inputs is not None: self.input_buffers.update(intermediate_inputs.tensors) if get_pp_group().is_last_rank: @@ -1739,12 +1583,8 @@ def forward( self.input_buffers["positions"].copy_(positions, non_blocking=True) self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping, non_blocking=True) - if self.backend_name != "flashinfer": - self.input_buffers["seq_lens_tensor"].copy_( - attn_metadata.decode_metadata.seq_lens_tensor, - non_blocking=True) - self.input_buffers["block_tables"].copy_( - attn_metadata.decode_metadata.block_tables, non_blocking=True) + self.attn_state.prepare_graph_input_buffers(self.input_buffers, + attn_metadata) if "seqlen_agnostic_capture_inputs" in self.input_buffers: self.model.copy_inputs_before_cuda_graphs(self.input_buffers, **kwargs)
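
The per-backend hunks above all follow the same pattern: every AttentionBackend now exposes a get_state_cls() hook, backends with no special CUDA-graph needs return CommonAttentionState, and FlashInfer supplies its own FlashInferState. The sketch below is illustrative only and is not part of this diff; MyAttentionBackend and MyAttentionState are hypothetical names for a backend opting into the hook while reusing the common state.

# Hypothetical example (not vLLM source): a backend wiring itself into the
# AttentionState hook introduced by this change.
from contextlib import contextmanager
from typing import Type

from vllm.attention.backends.abstract import AttentionBackend, AttentionState
from vllm.attention.backends.utils import CommonAttentionState


class MyAttentionState(CommonAttentionState):
    """Reuses the common CUDA-graph bookkeeping, overriding only capture."""

    @contextmanager
    def graph_capture(self, max_batch_size: int):
        # Allocate any backend-specific persistent buffers here, then defer
        # to CommonAttentionState for the slot-mapping, seq-lens and
        # block-table buffers.
        with super().graph_capture(max_batch_size):
            yield


class MyAttentionBackend(AttentionBackend):
    # The remaining abstract methods (get_impl_cls, get_metadata_cls, ...)
    # are omitted here; only the hook added by this diff is shown.

    @staticmethod
    def get_state_cls() -> Type["AttentionState"]:
        return MyAttentionState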
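
On the runner side, the FlashInfer-specific workspace buffers and wrapper setup previously inlined in capture_model() and execute_model() (and duplicated in draft_model_runner.py) now flow through self.attn_state. The helper below is a condensed, hypothetical mirror of that call sequence, not vLLM code; it only strings the new hooks together in the order the diff uses them.

from typing import List

from vllm.attention.backends.abstract import AttentionState


def graph_capture_flow(attn_state: AttentionState,
                       batch_sizes: List[int]) -> None:
    """Sketch of capture_model(): one metadata object and one cloned state
    per captured batch size."""
    with attn_state.graph_capture(max(batch_sizes)):
        for batch_size in reversed(sorted(batch_sizes)):
            attn_metadata = attn_state.graph_capture_get_metadata_for_batch(
                batch_size)
            # CUDAGraphRunner now receives this clone instead of raw
            # FlashInfer buffers ...
            graph_state = attn_state.graph_clone(batch_size)
            # ... and derives its attention input buffers from the state.
            input_buffers = graph_state.get_graph_input_buffers(attn_metadata)
            # On replay, CUDAGraphRunner.forward() refreshes those buffers via
            # graph_state.prepare_graph_input_buffers(input_buffers,
            #                                         attn_metadata), and
            # execute_model() calls attn_state.begin_forward(model_input)
            # in place of the old per-backend flashinfer branch.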