[Core] Add AttentionState abstraction (vllm-project#7663)
Yard1 authored Aug 20, 2024
1 parent c6af027 commit 3b68217
Showing 16 changed files with 372 additions and 247 deletions.
7 changes: 6 additions & 1 deletion tests/worker/test_model_input.py
@@ -5,6 +5,7 @@

from vllm.attention import AttentionMetadata, AttentionMetadataBuilder
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.utils import CommonAttentionState
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.worker.embedding_model_runner import (
@@ -29,7 +30,11 @@ def get_builder_cls() -> Type["AttentionMetadataBuilder"]:

@staticmethod
def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
raise AttentionMetadataBuilder
return AttentionMetadataBuilder

@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState

@staticmethod
def get_kv_cache_shape(
3 changes: 2 additions & 1 deletion vllm/attention/__init__.py
@@ -1,7 +1,7 @@
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata,
AttentionMetadataBuilder,
AttentionType)
AttentionState, AttentionType)
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend

@@ -12,5 +12,6 @@
"AttentionType",
"AttentionMetadataBuilder",
"Attention",
"AttentionState",
"get_attn_backend",
]
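
With `AttentionState` re-exported from the package root above, a backend's state class can be resolved without importing backend modules directly. A minimal, hypothetical usage sketch (not code from this commit; the `runner` argument stands for whatever model-runner object the state expects):

from vllm.attention import AttentionState

def build_attn_state(backend, runner) -> AttentionState:
    # Backends now advertise a state class alongside their impl,
    # metadata, and metadata-builder classes.
    return backend.get_state_cls()(runner)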
51 changes: 50 additions & 1 deletion vllm/attention/backends/abstract.py
@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass, fields
from enum import Enum, auto
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
@@ -7,7 +8,9 @@
import torch

if TYPE_CHECKING:
from vllm.worker.model_runner_base import ModelRunnerInputBuilderBase
from vllm.worker.model_runner_base import (ModelRunnerBase,
ModelRunnerInputBase,
ModelRunnerInputBuilderBase)


class AttentionType(Enum):
@@ -34,6 +37,11 @@ def get_impl_cls() -> Type["AttentionImpl"]:
def get_metadata_cls() -> Type["AttentionMetadata"]:
raise NotImplementedError

@staticmethod
@abstractmethod
def get_state_cls() -> Type["AttentionState"]:
raise NotImplementedError

@classmethod
def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
return cls.get_metadata_cls()(*args, **kwargs)
@@ -126,6 +134,47 @@ def asdict_zerocopy(self,
T = TypeVar("T", bound=AttentionMetadata)


class AttentionState(ABC, Generic[T]):
"""Holds attention backend-specific objects reused during the
lifetime of the model runner."""

@abstractmethod
def __init__(self, runner: "ModelRunnerBase"):
...

@abstractmethod
@contextmanager
def graph_capture(self, max_batch_size: int):
"""Context manager used when capturing CUDA graphs."""
yield

@abstractmethod
def graph_clone(self, batch_size: int) -> "AttentionState[T]":
"""Clone attention state to save in CUDA graph metadata."""
...

@abstractmethod
def graph_capture_get_metadata_for_batch(self, batch_size: int) -> T:
"""Get attention metadata for CUDA graph capture of batch_size."""
...

@abstractmethod
def get_graph_input_buffers(self, attn_metadata: T) -> Dict[str, Any]:
"""Get attention-specific input buffers for CUDA graph capture."""
...

@abstractmethod
def prepare_graph_input_buffers(self, input_buffers: Dict[str, Any],
attn_metadata: T) -> None:
"""In-place modify input buffers dict for CUDA graph replay."""
...

@abstractmethod
def begin_forward(self, model_input: "ModelRunnerInputBase") -> None:
"""Prepare state for forward pass."""
...


class AttentionMetadataBuilder(ABC, Generic[T]):
"""Abstract class for attention metadata builders."""

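As an illustration of the interface above (hypothetical code, not part of this diff), a backend with no CUDA-graph-specific buffers of its own could implement `AttentionState` roughly as follows; the way `runner` is used here is an assumption:

from contextlib import contextmanager
from typing import Any, Dict

from vllm.attention.backends.abstract import AttentionState


class MinimalAttentionState(AttentionState):
    """Hypothetical sketch: the least a backend-specific state must provide."""

    def __init__(self, runner):
        self.runner = runner  # kept so metadata can be derived from runner config

    @contextmanager
    def graph_capture(self, max_batch_size: int):
        yield  # nothing backend-specific to allocate for CUDA graph capture

    def graph_clone(self, batch_size: int) -> "MinimalAttentionState":
        return self.__class__(self.runner)

    def graph_capture_get_metadata_for_batch(self, batch_size: int):
        # A real backend builds decode-only metadata sized for batch_size here
        # (see CommonAttentionState and FlashInferState in this commit).
        raise NotImplementedError

    def get_graph_input_buffers(self, attn_metadata) -> Dict[str, Any]:
        return {"slot_mapping": attn_metadata.slot_mapping}

    def prepare_graph_input_buffers(self, input_buffers: Dict[str, Any],
                                    attn_metadata) -> None:
        # Copy fresh metadata into the buffers recorded at capture time.
        input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping)

    def begin_forward(self, model_input) -> None:
        pass  # no per-step preparation needed for this sketch
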
7 changes: 6 additions & 1 deletion vllm/attention/backends/blocksparse_attn.py
@@ -5,7 +5,8 @@

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonMetadataBuilder
from vllm.attention.backends.utils import (CommonAttentionState,
CommonMetadataBuilder)
from vllm.attention.ops.blocksparse_attention.interface import (
LocalStridedBlockSparseAttn, get_head_sliding_step)
from vllm.attention.ops.paged_attn import PagedAttention
@@ -98,6 +99,10 @@ def get_metadata_cls() -> Type["AttentionMetadata"]:
def get_builder_cls() -> Type["BlocksparseFlashAttentionMetadataBuilder"]:
return BlocksparseFlashAttentionMetadataBuilder

@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState

@staticmethod
def get_kv_cache_shape(
num_blocks: int,
7 changes: 6 additions & 1 deletion vllm/attention/backends/flash_attn.py
@@ -9,7 +9,8 @@
AttentionMetadata,
AttentionMetadataBuilder,
AttentionType)
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
compute_slot_mapping,
compute_slot_mapping_start_idx,
is_block_tables_empty)
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
@@ -142,6 +143,10 @@ def get_metadata_cls() -> Type["AttentionMetadata"]:
def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]:
return FlashAttentionMetadataBuilder

@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState

@staticmethod
def get_kv_cache_shape(
num_blocks: int,
165 changes: 164 additions & 1 deletion vllm/attention/backends/flashinfer.py
@@ -1,22 +1,27 @@
from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type

try:
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper

import vllm.attention.backends.flash_attn # noqa
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
except ImportError:
BatchDecodeWithPagedKVCacheWrapper = None
CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
BatchPrefillWithPagedKVCacheWrapper = None
FLASHINFER_WORKSPACE_BUFFER_SIZE = 0

import torch

from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata,
AttentionMetadataBuilder,
AttentionType)
AttentionState, AttentionType)
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx,
is_block_tables_empty)
@@ -46,6 +51,10 @@ def get_metadata_cls() -> Type["AttentionMetadata"]:
def get_builder_cls() -> Type["FlashInferMetadataBuilder"]:
return FlashInferMetadataBuilder

@staticmethod
def get_state_cls() -> Type["FlashInferState"]:
return FlashInferState

@staticmethod
def get_kv_cache_shape(
num_blocks: int,
@@ -75,6 +84,160 @@ def get_supported_head_sizes() -> List[int]:
return [64, 128, 256]


class FlashInferState(AttentionState):

def __init__(self, runner):
self.runner = runner
self._is_graph_capturing = False
self._workspace_buffer = None
self._decode_wrapper = None
self._prefill_wrapper = None

def _get_workspace_buffer(self):
if self._workspace_buffer is None:
self._workspace_buffer = torch.empty(
FLASHINFER_WORKSPACE_BUFFER_SIZE,
dtype=torch.uint8,
device=self.runner.device)
return self._workspace_buffer

def _get_prefill_wrapper(self):
if self._prefill_wrapper is None:
self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
self._get_workspace_buffer(), "NHD")
return self._prefill_wrapper

def _get_decode_wrapper(self):
if self._decode_wrapper is None:
num_qo_heads = (self.runner.model_config.get_num_attention_heads(
self.runner.parallel_config))
num_kv_heads = self.runner.model_config.get_num_kv_heads(
self.runner.parallel_config)
use_tensor_cores = num_qo_heads // num_kv_heads >= 4
self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
self._get_workspace_buffer(),
"NHD",
use_tensor_cores=use_tensor_cores)
return self._decode_wrapper

@contextmanager
def graph_capture(self, max_batch_size: int):
self._is_graph_capturing = True
self._graph_decode_wrapper = None
self._graph_slot_mapping = torch.full((max_batch_size, ),
PAD_SLOT_ID,
dtype=torch.long,
device=self.runner.device)
self._graph_seq_lens = torch.ones(max_batch_size,
dtype=torch.int32,
device=self.runner.device)
self._graph_block_tables = torch.from_numpy(
self.runner.graph_block_tables).to(device=self.runner.device)
self._graph_decode_workspace_buffer = self._get_workspace_buffer()
self._graph_indices_buffer = torch.empty(
max_batch_size * self.runner.cache_config.num_gpu_blocks,
dtype=torch.int32,
device=self.runner.device)
self._graph_indptr_buffer = torch.empty(max_batch_size + 1,
dtype=torch.int32,
device=self.runner.device)
self._graph_last_page_len_buffer = torch.empty(
max_batch_size, dtype=torch.int32, device=self.runner.device)
yield
self._is_graph_capturing = False
del self._graph_slot_mapping
del self._graph_seq_lens
del self._graph_block_tables
del self._graph_decode_workspace_buffer
del self._graph_indices_buffer
del self._graph_indptr_buffer
del self._graph_last_page_len_buffer
del self._graph_decode_wrapper

def graph_clone(self, batch_size: int):
assert self._is_graph_capturing
state = self.__class__(self.runner)
state._workspace_buffer = self._graph_decode_workspace_buffer
state._decode_wrapper = self._graph_decode_wrapper
state._prefill_wrapper = self._get_prefill_wrapper()
return state

def graph_capture_get_metadata_for_batch(self, batch_size: int):
assert self._is_graph_capturing
_indptr_buffer = self._graph_indptr_buffer[:batch_size + 1]
_last_page_len_buffer = self._graph_last_page_len_buffer[:batch_size]

num_qo_heads = (self.runner.model_config.get_num_attention_heads(
self.runner.parallel_config))
num_kv_heads = self.runner.model_config.get_num_kv_heads(
self.runner.parallel_config)
use_tensor_cores = num_qo_heads // num_kv_heads >= 4
self._graph_decode_wrapper = \
CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
self._graph_decode_workspace_buffer, _indptr_buffer,
self._graph_indices_buffer, _last_page_len_buffer, "NHD",
use_tensor_cores)
kv_cache_dtype = get_kv_cache_torch_dtype(
self.runner.kv_cache_dtype, self.runner.model_config.dtype)

paged_kv_indptr_tensor_host = torch.arange(0,
batch_size + 1,
dtype=torch.int32)
paged_kv_indices_tensor_host = torch.arange(0,
batch_size,
dtype=torch.int32)
paged_kv_last_page_len_tensor_host = torch.full((batch_size, ),
self.runner.block_size,
dtype=torch.int32)
query_start_loc_host = torch.arange(0,
batch_size + 1,
dtype=torch.int32)

attn_metadata = self.runner.attn_backend.make_metadata(
num_prefills=0,
slot_mapping=self._graph_slot_mapping[:batch_size],
num_prefill_tokens=0,
num_decode_tokens=batch_size,
max_prefill_seq_len=0,
block_tables=self._graph_block_tables,
paged_kv_indptr=paged_kv_indptr_tensor_host,
paged_kv_indices=paged_kv_indices_tensor_host,
paged_kv_last_page_len=paged_kv_last_page_len_tensor_host,
num_qo_heads=num_qo_heads,
num_kv_heads=num_kv_heads,
head_dim=self.runner.model_config.get_head_size(),
page_size=self.runner.block_size,
seq_start_loc=None,
query_start_loc=query_start_loc_host,
device=self.runner.device,
data_type=kv_cache_dtype,
use_cuda_graph=True,
decode_wrapper=self._graph_decode_wrapper,
prefill_wrapper=None)
attn_metadata.begin_forward()
return attn_metadata

def get_graph_input_buffers(self, attn_metadata):
return {
"slot_mapping": attn_metadata.slot_mapping,
}

def prepare_graph_input_buffers(self, input_buffers, attn_metadata):
return

def begin_forward(self, model_input):
assert not self._is_graph_capturing
state = self
if model_input.attn_metadata.use_cuda_graph:
batch_size = model_input.input_tokens.shape[0]
state = (self.runner.graph_runners[model_input.virtual_engine]
[batch_size].attn_state)
model_input.attn_metadata.prefill_wrapper = state._get_prefill_wrapper(
)
model_input.attn_metadata.decode_wrapper = state._get_decode_wrapper()
model_input.attn_metadata.begin_forward()


@dataclass
class FlashInferMetadata(AttentionMetadata):
# Maximum sequence length among prefill batch. 0 if there are decoding
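Putting the pieces together, the intended call order from the model-runner side looks roughly like the sketch below (illustrative pseudocode; only the `AttentionState` methods come from this commit, the surrounding names are assumptions):

# Construction: the backend chooses its state implementation.
attn_state = attn_backend.get_state_cls()(runner)

# CUDA graph capture: allocate shared buffers once, then capture per batch size.
with attn_state.graph_capture(max_batch_size):
    for batch_size in batch_sizes_to_capture:
        attn_metadata = attn_state.graph_capture_get_metadata_for_batch(batch_size)
        # The graph runner records the attention-specific input buffers...
        input_buffers = attn_state.get_graph_input_buffers(attn_metadata)
        # ...and keeps a clone of the state alongside the captured graph.
        captured_state = attn_state.graph_clone(batch_size)

# Execution: let the backend prepare, then refresh the recorded buffers
# before replaying the captured graph.
attn_state.begin_forward(model_input)
attn_state.prepare_graph_input_buffers(input_buffers, model_input.attn_metadata)

For FlashInfer, `begin_forward` is the step that re-attaches the prefill and decode wrappers to the metadata, as `FlashInferState.begin_forward` above shows.
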
5 changes: 5 additions & 0 deletions vllm/attention/backends/ipex_attn.py
@@ -8,6 +8,7 @@
from vllm._ipex_ops import ipex_ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)

@@ -28,6 +29,10 @@ def get_impl_cls() -> Type["IpexAttnBackendImpl"]:
def get_metadata_cls() -> Type["IpexAttnMetadata"]:
return IpexAttnMetadata

@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState

@staticmethod
def get_kv_cache_shape(
num_blocks: int,
(Diffs for the remaining 9 changed files are not shown here.)
