From 54ba87d99078335ccbf4c7475133591165552467 Mon Sep 17 00:00:00 2001
From: Lucas Wilkinson
Date: Thu, 30 Jan 2025 21:23:09 +0000
Subject: [PATCH] add cuda graph support

Signed-off-by: Lucas Wilkinson
---
 vllm/attention/backends/abstract.py   |  3 +-
 vllm/attention/backends/flashinfer.py |  5 +-
 vllm/attention/backends/triton_mla.py | 82 +++++++++++++++++++++++----
 vllm/attention/backends/utils.py      | 11 +++-
 vllm/worker/model_runner.py           |  5 +-
 5 files changed, 89 insertions(+), 17 deletions(-)

diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py
index e9e4dd871827b..105b1bf8eb18d 100644
--- a/vllm/attention/backends/abstract.py
+++ b/vllm/attention/backends/abstract.py
@@ -168,7 +168,8 @@ def __init__(self, runner: "ModelRunnerBase"):
 
     @abstractmethod
     @contextmanager
-    def graph_capture(self, max_batch_size: int):
+    def graph_capture(self, max_batch_size: int,
+                      positions: Optional[torch.Tensor]):
         """Context manager used when capturing CUDA graphs."""
         yield
diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py
index 7cccef9608218..d50a727c0de4b 100644
--- a/vllm/attention/backends/flashinfer.py
+++ b/vllm/attention/backends/flashinfer.py
@@ -213,7 +213,10 @@ def _get_decode_wrapper(self):
         return self._decode_wrapper
 
     @contextmanager
-    def graph_capture(self, max_batch_size: int):
+    def graph_capture(self, max_batch_size: int,
+                      positions: Optional[torch.Tensor]):
+        assert positions is None
+
         self._is_graph_capturing = True
         self._graph_decode_wrapper = None
         self._graph_slot_mapping = torch.full((max_batch_size, ),
diff --git a/vllm/attention/backends/triton_mla.py b/vllm/attention/backends/triton_mla.py
index 40f14e4eae0db..00b157eb35009 100644
--- a/vllm/attention/backends/triton_mla.py
+++ b/vllm/attention/backends/triton_mla.py
@@ -90,33 +90,93 @@ class TritonMLAState(AttentionState):
 
     def __init__(self, runner):
         self.runner = runner
+        self._is_graph_capturing = False
 
     @contextmanager
-    def graph_capture(self, max_batch_size: int):
-        raise NotImplementedError(
-            "TritonMLAState does not support graph capture")
+    def graph_capture(self, max_batch_size: int,
+                      positions: Optional[torch.Tensor]):
+        self._is_graph_capturing = True
+
+        self._graph_slot_mapping = torch.full((max_batch_size, ),
+                                              PAD_SLOT_ID,
+                                              dtype=torch.long,
+                                              device=self.runner.device)
+        self._graph_seq_lens = torch.ones(max_batch_size,
+                                          dtype=torch.int32,
+                                          device=self.runner.device)
+        self._graph_block_tables = torch.from_numpy(
+            self.runner.graph_block_tables).to(device=self.runner.device)
+
+        assert positions is not None
+        self._positions = positions
+
+        yield
+
+        self._is_graph_capturing = False
+        del self._graph_slot_mapping
+        del self._graph_seq_lens
+        del self._graph_block_tables
 
     def graph_clone(self, batch_size: int):
-        raise NotImplementedError(
-            "TritonMLAState does not support graph capture")
+        assert self._is_graph_capturing
+        return self.__class__(self.runner)
 
     def graph_capture_get_metadata_for_batch(
             self, batch_size: int, is_encoder_decoder_model: bool = False):
-        raise NotImplementedError(
-            "TritonMLAState does not support graph capture")
+        assert self._is_graph_capturing
+
+        attn_metadata = self.runner.attn_backend.make_metadata(
+            num_prefills=0,
+            num_prefill_tokens=0,
+            num_decode_tokens=batch_size,
+            slot_mapping=self._graph_slot_mapping[:batch_size],
+            multi_modal_placeholder_index_maps=None,
+            enable_kv_scales_calculation=True,
+            seq_lens=None,
+            seq_lens_tensor=self._graph_seq_lens[:batch_size],
+            max_query_len=1,
+            max_decode_query_len=1,
+            max_prefill_seq_len=0,
+            max_decode_seq_len=self.runner.max_seq_len_to_capture,
+            query_start_loc=None,
+            seq_start_loc=None,
+            context_lens_tensor=None,
+            block_tables=self._graph_block_tables[:batch_size],
+            use_cuda_graph=True,
+            input_positions=self._positions[:batch_size],
+            head_dim=self.runner.model_config.get_head_size())
+
+        if is_encoder_decoder_model:
+            raise NotImplementedError(
+                "TritonMLAState does not support encoder/decoder yet")
+
+        return attn_metadata
 
     def get_graph_input_buffers(self,
                                 attn_metadata,
                                 is_encoder_decoder_model: bool = False):
-        raise NotImplementedError(
-            "TritonMLAState does not support graph capture")
+        input_buffers = {
+            "slot_mapping": attn_metadata.slot_mapping,
+            "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
+            "block_tables": attn_metadata.decode_metadata.block_tables,
+        }
+        if is_encoder_decoder_model:
+            raise NotImplementedError(
+                "TritonMLAState does not support encoder/decoder yet")
+
+        return input_buffers
 
     def prepare_graph_input_buffers(self,
                                     input_buffers,
                                     attn_metadata,
                                     is_encoder_decoder_model: bool = False):
-        raise NotImplementedError(
-            "TritonMLAState does not support graph capture")
+        input_buffers["seq_lens_tensor"].copy_(
+            attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
+        input_buffers["block_tables"].copy_(
+            attn_metadata.decode_metadata.block_tables, non_blocking=True)
+        if is_encoder_decoder_model:
+            raise NotImplementedError(
+                "TritonMLAState does not support encoder/decoder yet")
 
     def begin_forward(self, model_input):
         return
diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py
index 84fe89b7df360..655111e8a3cf7 100644
--- a/vllm/attention/backends/utils.py
+++ b/vllm/attention/backends/utils.py
@@ -2,7 +2,8 @@
 from collections import defaultdict
 from contextlib import contextmanager
 from itertools import accumulate
-from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union
+from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type,
+                    TypeVar, Union)
 
 import numpy as np
 import torch
@@ -288,8 +289,12 @@ def __init__(self, runner: "ModelRunnerBase"):
         self._is_graph_capturing = False
 
     @contextmanager
-    def graph_capture(self, max_batch_size: int):
+    def graph_capture(self, max_batch_size: int,
+                      positions: Optional[torch.Tensor]):
+        assert positions is None
+
         self._is_graph_capturing = True
+
         self._graph_slot_mapping = torch.full((max_batch_size, ),
                                               PAD_SLOT_ID,
                                               dtype=torch.long,
@@ -299,7 +304,9 @@ def graph_capture(self, max_batch_size: int):
                                           device=self.runner.device)
         self._graph_block_tables = torch.from_numpy(
             self.runner.graph_block_tables).to(device=self.runner.device)
+
         yield
+
         self._is_graph_capturing = False
         del self._graph_slot_mapping
         del self._graph_seq_lens
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index b6ed3abab4247..0f654576aa67d 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -1468,8 +1468,9 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
                                           dtype=self.model_config.dtype,
                                           device=self.device)
 
-        with self.attn_state.graph_capture(max_batch_size), graph_capture(
-                self.device) as graph_capture_context:
+        with self.attn_state.graph_capture(
+                max_batch_size, input_positions), graph_capture(
+                    self.device) as graph_capture_context:
             # NOTE: Capturing the largest batch size first may help reduce the
             # memory usage of CUDA graph.
             for virtual_engine in range(