[Bugfix] Fix decode tokens w. CUDA graph (vllm-project#6757)
comaniac authored Jul 25, 2024
1 parent 7ce0623 commit 5bf0afc
Showing 4 changed files with 31 additions and 4 deletions.
tests/worker/test_model_runner.py (1 addition, 0 deletions)
@@ -193,6 +193,7 @@ def test_prepare_decode_cuda_graph(batch_size):
     for _ in range(expected_bs - len(seq_lens)):
         seq_lens.append(1)
     assert attn_metadata.seq_lens == seq_lens
+    assert attn_metadata.num_decode_tokens == len(seq_lens)
     start_idx = 0
     start_loc = [start_idx]
     for _ in context_lens:
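The new assertion pins down the invariant that CUDA-graph padding must preserve: once the decode batch has been padded up to the captured graph size with dummy length-1 sequences, every padded slot still counts as exactly one decode token. A minimal sketch of that invariant, using hypothetical numbers rather than the test's actual fixtures:

    # Hypothetical sizes; the real test derives the padded size from the graph batch size.
    expected_bs = 8                      # padded (captured-graph) batch size
    seq_lens = [10, 12, 9, 7, 11]        # real decode sequences
    seq_lens += [1] * (expected_bs - len(seq_lens))  # pad with dummy length-1 seqs

    num_decode_tokens = len(seq_lens)    # one token per (possibly padded) decode seq
    assert num_decode_tokens == expected_bs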
vllm/attention/backends/flash_attn.py (10 additions, 2 deletions)
@@ -272,7 +272,15 @@ def _add_seq_group(
 
     def build(self, seq_lens: List[int], query_lens: List[int],
               cuda_graph_pad_size: int, batch_size: int):
-        """Build attention metadata with on-device tensors."""
+        """Build attention metadata with on-device tensors.
+
+        Args:
+            seq_lens: The maybe padded sequence lengths of the input sequences.
+            query_lens: The query lengths of the input sequences.
+            cuda_graph_pad_size: The padding size for cuda graph.
+                                 -1 if cuda graph is not used.
+            batch_size: The maybe padded batch size.
+        """
         for inter_data in self.input_builder.inter_data_list:
             self._add_seq_group(inter_data,
                                 self.input_builder.chunked_prefill_enabled)
@@ -297,7 +305,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
         if use_captured_graph:
             self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
             self.block_tables.extend([] * cuda_graph_pad_size)
-            num_decode_tokens = batch_size + cuda_graph_pad_size
+            num_decode_tokens = batch_size
 
             # The shape of graph_block_tables is
             # [max batch size, max context len // block size].
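This is the heart of the fix: per the new docstring, batch_size is already the "maybe padded" batch size, i.e. it has been rounded up to the captured CUDA graph size before build() is called, so adding cuda_graph_pad_size on top double-counted the padding. A rough sketch of the arithmetic, with assumed numbers (the real sizes come from the captured-graph configuration):

    # Assumed example values; only the relationship between them matters.
    real_batch_size = 5
    graph_batch_size = 8                                        # nearest captured graph size
    cuda_graph_pad_size = graph_batch_size - real_batch_size    # 3

    batch_size = graph_batch_size        # build() receives the already padded size

    buggy = batch_size + cuda_graph_pad_size   # 11: padding counted twice
    fixed = batch_size                         # 8: matches the padded decode batch
    assert fixed == graph_batch_size

The same one-line correction is applied to the FlashInfer and common (utils) metadata builders below.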
vllm/attention/backends/flashinfer.py (10 additions, 1 deletion)
@@ -320,6 +320,15 @@ def _update_paged_kv_tensors(self, block_table: List[int], seq_len: int):
 
     def build(self, seq_lens: List[int], query_lens: List[int],
               cuda_graph_pad_size: int, batch_size: int):
+        """Build attention metadata with on-device tensors.
+
+        Args:
+            seq_lens: The maybe padded sequence lengths of the input sequences.
+            query_lens: The query lengths of the input sequences.
+            cuda_graph_pad_size: The padding size for cuda graph.
+                                 -1 if cuda graph is not used.
+            batch_size: The maybe padded batch size.
+        """
         for inter_data in self.input_builder.inter_data_list:
             self._add_seq_group(inter_data,
                                 self.input_builder.chunked_prefill_enabled)
@@ -334,7 +343,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
         if use_captured_graph:
             self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
             self.block_tables.extend([] * cuda_graph_pad_size)
-            num_decode_tokens = batch_size + cuda_graph_pad_size
+            num_decode_tokens = batch_size
 
             # The shape of graph_block_tables is
             # [max batch size, max context len // block size].
vllm/attention/backends/utils.py (10 additions, 1 deletion)
@@ -149,6 +149,15 @@ def _add_seq_group(
 
     def build(self, seq_lens: List[int], query_lens: List[int],
               cuda_graph_pad_size: int, batch_size: int):
+        """Build attention metadata with on-device tensors.
+
+        Args:
+            seq_lens: The maybe padded sequence lengths of the input sequences.
+            query_lens: The query lengths of the input sequences.
+            cuda_graph_pad_size: The padding size for cuda graph.
+                                 -1 if cuda graph is not used.
+            batch_size: The maybe padded batch size.
+        """
         for inter_data in self.input_builder.inter_data_list:
             self._add_seq_group(inter_data,
                                 self.input_builder.chunked_prefill_enabled)
@@ -173,7 +182,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
         if use_captured_graph:
             self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
             self.block_tables.extend([] * cuda_graph_pad_size)
-            num_decode_tokens = batch_size + cuda_graph_pad_size
+            num_decode_tokens = batch_size
 
             # The shape of graph_block_tables is
             # [max batch size, max context len // block size].
