From aebd189286f4dc7118af3d06a065dfcbd9436b8a Mon Sep 17 00:00:00 2001
From: Cody Yu
Date: Wed, 24 Jul 2024 22:33:56 -0700
Subject: [PATCH] [Bugfix] Fix decode tokens w. CUDA graph (#6757)

---
 tests/worker/test_model_runner.py     |  1 +
 vllm/attention/backends/flash_attn.py | 12 ++++++++++--
 vllm/attention/backends/flashinfer.py | 11 ++++++++++-
 vllm/attention/backends/utils.py      | 11 ++++++++++-
 4 files changed, 31 insertions(+), 4 deletions(-)

diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py
index b5742c4338616..4a0e2b4184936 100644
--- a/tests/worker/test_model_runner.py
+++ b/tests/worker/test_model_runner.py
@@ -193,6 +193,7 @@ def test_prepare_decode_cuda_graph(batch_size):
     for _ in range(expected_bs - len(seq_lens)):
         seq_lens.append(1)
     assert attn_metadata.seq_lens == seq_lens
+    assert attn_metadata.num_decode_tokens == len(seq_lens)
     start_idx = 0
     start_loc = [start_idx]
     for _ in context_lens:
diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py
index 949bd973cf3c4..7d7aff9dc3cdc 100644
--- a/vllm/attention/backends/flash_attn.py
+++ b/vllm/attention/backends/flash_attn.py
@@ -272,7 +272,15 @@ def _add_seq_group(
 
     def build(self, seq_lens: List[int], query_lens: List[int],
               cuda_graph_pad_size: int, batch_size: int):
-        """Build attention metadata with on-device tensors."""
+        """Build attention metadata with on-device tensors.
+
+        Args:
+            seq_lens: The maybe padded sequence lengths of the input sequences.
+            query_lens: The query lengths of the input sequences.
+            cuda_graph_pad_size: The padding size for cuda graph.
+                                 -1 if cuda graph is not used.
+            batch_size: The maybe padded batch size.
+        """
         for inter_data in self.input_builder.inter_data_list:
             self._add_seq_group(inter_data,
                                 self.input_builder.chunked_prefill_enabled)
@@ -297,7 +305,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
         if use_captured_graph:
             self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
             self.block_tables.extend([] * cuda_graph_pad_size)
-            num_decode_tokens = batch_size + cuda_graph_pad_size
+            num_decode_tokens = batch_size
 
             # The shape of graph_block_tables is
             # [max batch size, max context len // block size].
diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py
index 9746304347d6e..83a420d76834b 100644
--- a/vllm/attention/backends/flashinfer.py
+++ b/vllm/attention/backends/flashinfer.py
@@ -320,6 +320,15 @@ def _update_paged_kv_tensors(self, block_table: List[int], seq_len: int):
 
     def build(self, seq_lens: List[int], query_lens: List[int],
               cuda_graph_pad_size: int, batch_size: int):
+        """Build attention metadata with on-device tensors.
+
+        Args:
+            seq_lens: The maybe padded sequence lengths of the input sequences.
+            query_lens: The query lengths of the input sequences.
+            cuda_graph_pad_size: The padding size for cuda graph.
+                                 -1 if cuda graph is not used.
+            batch_size: The maybe padded batch size.
+        """
         for inter_data in self.input_builder.inter_data_list:
             self._add_seq_group(inter_data,
                                 self.input_builder.chunked_prefill_enabled)
@@ -334,7 +343,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
         if use_captured_graph:
             self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
             self.block_tables.extend([] * cuda_graph_pad_size)
-            num_decode_tokens = batch_size + cuda_graph_pad_size
+            num_decode_tokens = batch_size
 
             # The shape of graph_block_tables is
             # [max batch size, max context len // block size].
diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py
index 5877712b9b7d3..dcd10ed410a79 100644
--- a/vllm/attention/backends/utils.py
+++ b/vllm/attention/backends/utils.py
@@ -149,6 +149,15 @@ def _add_seq_group(
 
     def build(self, seq_lens: List[int], query_lens: List[int],
               cuda_graph_pad_size: int, batch_size: int):
+        """Build attention metadata with on-device tensors.
+
+        Args:
+            seq_lens: The maybe padded sequence lengths of the input sequences.
+            query_lens: The query lengths of the input sequences.
+            cuda_graph_pad_size: The padding size for cuda graph.
+                                 -1 if cuda graph is not used.
+            batch_size: The maybe padded batch size.
+        """
         for inter_data in self.input_builder.inter_data_list:
             self._add_seq_group(inter_data,
                                 self.input_builder.chunked_prefill_enabled)
@@ -173,7 +182,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
         if use_captured_graph:
             self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
             self.block_tables.extend([] * cuda_graph_pad_size)
-            num_decode_tokens = batch_size + cuda_graph_pad_size
+            num_decode_tokens = batch_size
 
             # The shape of graph_block_tables is
             # [max batch size, max context len // block size].
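
Why the one-line change in each backend is correct: when CUDA graphs are used, the batch_size argument passed to build() has already been rounded up to a captured graph size, so the old expression batch_size + cuda_graph_pad_size counted the padding twice and inflated num_decode_tokens; each decode sequence (real or dummy) contributes exactly one token per step. Below is a minimal sketch of that arithmetic; pad_batch_for_cuda_graph and _CAPTURE_SIZES are hypothetical stand-ins for vLLM's own graph-batch-size rounding, not its API.

from typing import List

_CAPTURE_SIZES = [1, 2, 4, 8, 16, 32]  # assumed CUDA graph capture sizes


def pad_batch_for_cuda_graph(num_seqs: int) -> int:
    """Round the decode batch up to the nearest captured graph size."""
    return next(bs for bs in _CAPTURE_SIZES if bs >= num_seqs)


seq_lens: List[int] = [5, 7, 9]                       # 3 real decode sequences
batch_size = pad_batch_for_cuda_graph(len(seq_lens))  # already padded -> 4
cuda_graph_pad_size = batch_size - len(seq_lens)      # 1 dummy sequence

num_decode_tokens = batch_size                               # fixed code: 4
buggy_num_decode_tokens = batch_size + cuda_graph_pad_size   # old code: 5

# Mirrors the new test assertion: one decode token per (padded) sequence.
assert num_decode_tokens == len(seq_lens) + cuda_graph_pad_size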