Revert "add cuda graph support"
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
LucasWilkinson committed Jan 31, 2025
1 parent 31c34bf commit 433322b
Showing 5 changed files with 17 additions and 87 deletions.
vllm/attention/backends/abstract.py (3 changes: 1 addition, 2 deletions)
@@ -168,8 +168,7 @@ def __init__(self, runner: "ModelRunnerBase"):
 
     @abstractmethod
     @contextmanager
-    def graph_capture(self, max_batch_size: int,
-                      positions: Optional[torch.Tensor]):
+    def graph_capture(self, max_batch_size: int):
         """Context manager used when capturing CUDA graphs."""
         yield
 
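After this revert, the abstract contract is a single-argument context manager again. A minimal sketch of how a caller consumes it, with a hypothetical DummyState standing in for a real AttentionState subclass:

from contextlib import contextmanager


class DummyState:  # hypothetical stand-in for an AttentionState implementation

    @contextmanager
    def graph_capture(self, max_batch_size: int):
        """Context manager used when capturing CUDA graphs."""
        # Real implementations allocate capture-time buffers sized to
        # max_batch_size here (see the CommonAttentionState diff below).
        print(f"begin capture, buffers padded to {max_batch_size}")
        yield
        print("end capture, buffers released")


attn_state = DummyState()
with attn_state.graph_capture(max_batch_size=8):
    pass  # CUDA graphs for each batch size would be captured here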
vllm/attention/backends/flashinfer.py (5 changes: 1 addition, 4 deletions)
@@ -213,10 +213,7 @@ def _get_decode_wrapper(self):
         return self._decode_wrapper
 
     @contextmanager
-    def graph_capture(self, max_batch_size: int,
-                      positions: Optional[torch.Tensor]):
-        assert positions is None
-
+    def graph_capture(self, max_batch_size: int):
         self._is_graph_capturing = True
         self._graph_decode_wrapper = None
         self._graph_slot_mapping = torch.full((max_batch_size, ),
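For context, the capture-time buffers here are deliberately padded to a fixed maximum batch size, since CUDA graphs replay with fixed tensor shapes and addresses. A small self-contained sketch of that padding pattern on CPU, assuming PAD_SLOT_ID is -1 (its value comes from vllm/attention/backends/utils.py and is not shown in this diff):

import torch

PAD_SLOT_ID = -1  # assumed value; defined in vllm/attention/backends/utils.py
max_batch_size = 8

# Fixed-shape buffers: every capture uses the full padded size...
graph_slot_mapping = torch.full((max_batch_size, ),
                                PAD_SLOT_ID,
                                dtype=torch.long)
graph_seq_lens = torch.ones(max_batch_size, dtype=torch.int32)

# ...while a smaller live batch only reads/writes a prefix of each buffer.
batch_size = 3
print(graph_slot_mapping[:batch_size], graph_seq_lens[:batch_size])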
vllm/attention/backends/triton_mla.py (82 changes: 11 additions, 71 deletions)
@@ -90,93 +90,33 @@ class TritonMLAState(AttentionState):
 
     def __init__(self, runner):
         self.runner = runner
-        self._is_graph_capturing = False
 
     @contextmanager
-    def graph_capture(self, max_batch_size: int,
-                      positions: Optional[torch.Tensor]):
-        self._is_graph_capturing = True
-
-        self._graph_slot_mapping = torch.full((max_batch_size, ),
-                                              PAD_SLOT_ID,
-                                              dtype=torch.long,
-                                              device=self.runner.device)
-        self._graph_seq_lens = torch.ones(max_batch_size,
-                                          dtype=torch.int32,
-                                          device=self.runner.device)
-        self._graph_block_tables = torch.from_numpy(
-            self.runner.graph_block_tables).to(device=self.runner.device)
-
-        assert positions is not None
-        self._positions = positions
-
-        yield
-
-        self._is_graph_capturing = False
-        del self._graph_slot_mapping
-        del self._graph_seq_lens
-        del self._graph_block_tables
+    def graph_capture(self, max_batch_size: int):
+        raise NotImplementedError(
+            "TritonMLAState does not support graph capture")
 
     def graph_clone(self, batch_size: int):
-        assert self._is_graph_capturing
-        return self.__class__(self.runner)
+        raise NotImplementedError(
+            "TritonMLAState does not support graph capture")
 
     def graph_capture_get_metadata_for_batch(
             self, batch_size: int, is_encoder_decoder_model: bool = False):
-        assert self._is_graph_capturing
-
-        attn_metadata = self.runner.attn_backend.make_metadata(
-            num_prefills=0,
-            num_prefill_tokens=0,
-            num_decode_tokens=batch_size,
-            slot_mapping=self._graph_slot_mapping[:batch_size],
-            multi_modal_placeholder_index_maps=None,
-            enable_kv_scales_calculation=True,
-            seq_lens=None,
-            seq_lens_tensor=self._graph_seq_lens[:batch_size],
-            max_query_len=1,
-            max_decode_query_len=1,
-            max_prefill_seq_len=0,
-            max_decode_seq_len=self.runner.max_seq_len_to_capture,
-            query_start_loc=None,
-            seq_start_loc=None,
-            context_lens_tensor=None,
-            block_tables=self._graph_block_tables[:batch_size],
-            use_cuda_graph=True,
-            input_positions=self._positions[:batch_size],
-            head_dim=self.runner.model_config.get_head_size())
-
-        if is_encoder_decoder_model:
-            raise NotImplementedError(
-                "TritonMLAState does not support encoder/decoder yet")
-
-        return attn_metadata
+        raise NotImplementedError(
+            "TritonMLAState does not support graph capture")
 
     def get_graph_input_buffers(self,
                                 attn_metadata,
                                 is_encoder_decoder_model: bool = False):
-        input_buffers = {
-            "slot_mapping": attn_metadata.slot_mapping,
-            "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
-            "block_tables": attn_metadata.decode_metadata.block_tables,
-        }
-        if is_encoder_decoder_model:
-            raise NotImplementedError(
-                "TritonMLAState does not support encoder/decoder yet")
-
-        return input_buffers
+        raise NotImplementedError(
+            "TritonMLAState does not support graph capture")
 
     def prepare_graph_input_buffers(self,
                                     input_buffers,
                                     attn_metadata,
                                     is_encoder_decoder_model: bool = False):
-        input_buffers["seq_lens_tensor"].copy_(
-            attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
-        input_buffers["block_tables"].copy_(
-            attn_metadata.decode_metadata.block_tables, non_blocking=True)
-        if is_encoder_decoder_model:
-            raise NotImplementedError(
-                "TritonMLAState does not support encoder/decoder yet")
+        raise NotImplementedError(
+            "TritonMLAState does not support graph capture")
 
     def begin_forward(self, model_input):
         return
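With this revert, every CUDA-graph hook on TritonMLAState fails fast instead of capturing. A trivial sketch of the resulting caller-side behavior, using a hypothetical stub class rather than the real backend:

class StubMLAState:  # hypothetical stand-in for the post-revert TritonMLAState

    def graph_clone(self, batch_size: int):
        raise NotImplementedError(
            "TritonMLAState does not support graph capture")


try:
    StubMLAState().graph_clone(batch_size=1)
except NotImplementedError as exc:
    # Graph-capture paths must be avoided for this backend after the revert.
    print(f"expected failure: {exc}")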
vllm/attention/backends/utils.py (9 changes: 2 additions, 7 deletions)
@@ -2,8 +2,7 @@
 from collections import defaultdict
 from contextlib import contextmanager
 from itertools import accumulate
-from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type,
-                    TypeVar, Union)
+from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union
 
 import numpy as np
 import torch
@@ -289,10 +288,8 @@ def __init__(self, runner: "ModelRunnerBase"):
         self._is_graph_capturing = False
 
     @contextmanager
-    def graph_capture(self, max_batch_size: int,
-                      positions: Optional[torch.Tensor]):
+    def graph_capture(self, max_batch_size: int):
         self._is_graph_capturing = True
-
         self._graph_slot_mapping = torch.full((max_batch_size, ),
                                               PAD_SLOT_ID,
                                               dtype=torch.long,
@@ -302,9 +299,7 @@ def graph_capture(self, max_batch_size: int,
                                           device=self.runner.device)
         self._graph_block_tables = torch.from_numpy(
             self.runner.graph_block_tables).to(device=self.runner.device)
-
         yield
-
         self._is_graph_capturing = False
         del self._graph_slot_mapping
         del self._graph_seq_lens
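The shared implementation above follows a plain allocate/yield/release lifecycle. A condensed, self-contained version of that lifecycle (CPU-only, with an assumed PAD_SLOT_ID of -1 and a toy block table standing in for the runner's):

from contextlib import contextmanager

import numpy as np
import torch

PAD_SLOT_ID = -1  # assumed; matches the constant referenced in the diff


class MiniCommonState:  # condensed sketch of the pattern above

    def __init__(self, device="cpu", num_seqs=8, max_blocks=4):
        self.device = device
        self.graph_block_tables = np.zeros((num_seqs, max_blocks),
                                           dtype=np.int32)
        self._is_graph_capturing = False

    @contextmanager
    def graph_capture(self, max_batch_size: int):
        self._is_graph_capturing = True
        self._graph_slot_mapping = torch.full((max_batch_size, ),
                                              PAD_SLOT_ID,
                                              dtype=torch.long,
                                              device=self.device)
        self._graph_seq_lens = torch.ones(max_batch_size,
                                          dtype=torch.int32,
                                          device=self.device)
        self._graph_block_tables = torch.from_numpy(
            self.graph_block_tables).to(device=self.device)
        yield
        self._is_graph_capturing = False
        # Drop references so the allocator can reclaim capture-only buffers.
        del self._graph_slot_mapping
        del self._graph_seq_lens
        del self._graph_block_tables


with MiniCommonState().graph_capture(max_batch_size=8):
    pass  # capture would happen here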
vllm/worker/model_runner.py (5 changes: 2 additions, 3 deletions)
@@ -1468,9 +1468,8 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
                 dtype=self.model_config.dtype,
                 device=self.device)
 
-        with self.attn_state.graph_capture(
-                max_batch_size, input_positions), graph_capture(
-                self.device) as graph_capture_context:
+        with self.attn_state.graph_capture(max_batch_size), graph_capture(
+                self.device) as graph_capture_context:
             # NOTE: Capturing the largest batch size first may help reduce the
             # memory usage of CUDA graph.
             for virtual_engine in range(
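The call site composes two context managers in one with statement: the attention state's capture context (now taking only max_batch_size) and the device-level graph_capture helper. A minimal sketch of that composition, using stand-in context managers rather than vLLM's real ones:

from contextlib import contextmanager


@contextmanager
def attn_state_graph_capture(max_batch_size: int):  # stand-in
    print(f"attn state: buffers padded to {max_batch_size}")
    yield
    print("attn state: buffers released")


@contextmanager
def graph_capture(device):  # stand-in for the second helper in the diff
    print(f"entering capture context on {device}")
    yield object()  # the real helper yields a context object
    print("leaving capture context")


max_batch_size = 8
# Contexts are entered left to right and exited in reverse order, so the
# attention-state buffers outlive the inner capture context.
with attn_state_graph_capture(max_batch_size), graph_capture("cuda:0") as ctx:
    pass  # graphs are captured largest batch first inside this block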
