diff --git a/CMakeLists.txt b/CMakeLists.txt index 0945905104f32..5039ac2448f83 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -24,9 +24,6 @@ include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake) # Suppress potential warnings about unused manually-specified variables set(ignoreMe "${VLLM_PYTHON_PATH}") -# Prevent installation of dependencies (cutlass) by default. -install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS) - # # Supported python versions. These versions will be searched in order, the # first match will be selected. These should be kept in sync with setup.py. @@ -535,7 +532,7 @@ if(VLLM_GPU_LANG STREQUAL "HIP") endif() # vllm-flash-attn currently only supported on CUDA -if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda") +if (NOT VLLM_GPU_LANG STREQUAL "CUDA") return() endif () @@ -558,7 +555,7 @@ endif() # They should be identical but if they aren't, this is a massive footgun. # # The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place. -# To only install vllm-flash-attn, use --component vllm_flash_attn_c. +# To only install vllm-flash-attn, use --component _vllm_fa2_C (for FA2) or --component _vllm_fa3_C (for FA3). # If no component is specified, vllm-flash-attn is still installed. # If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading. @@ -570,43 +567,41 @@ if (DEFINED ENV{VLLM_FLASH_ATTN_SRC_DIR}) endif() if(VLLM_FLASH_ATTN_SRC_DIR) - FetchContent_Declare(vllm-flash-attn SOURCE_DIR ${VLLM_FLASH_ATTN_SRC_DIR}) + FetchContent_Declare( + vllm-flash-attn SOURCE_DIR + ${VLLM_FLASH_ATTN_SRC_DIR} + BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn + ) else() FetchContent_Declare( vllm-flash-attn GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git - GIT_TAG 96266b1111111f3d11aabefaf3bacbab6a89d03c + GIT_TAG 90eacc1af2a7c3de62ea249e929ed5faccf38954 GIT_PROGRESS TRUE # Don't share the vllm-flash-attn build between build types BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn ) endif() -# Set the parent build flag so that the vllm-flash-attn library does not redo compile flag and arch initialization. 
-set(VLLM_PARENT_BUILD ON) - -# Ensure the vllm/vllm_flash_attn directory exists before installation -install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn\")" COMPONENT vllm_flash_attn_c) - -# Make sure vllm-flash-attn install rules are nested under vllm/ -install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY FALSE)" COMPONENT vllm_flash_attn_c) -install(CODE "set(OLD_CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}\")" COMPONENT vllm_flash_attn_c) -install(CODE "set(CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}/vllm/\")" COMPONENT vllm_flash_attn_c) # Fetch the vllm-flash-attn library FetchContent_MakeAvailable(vllm-flash-attn) message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}") -# Restore the install prefix -install(CODE "set(CMAKE_INSTALL_PREFIX \"\${OLD_CMAKE_INSTALL_PREFIX}\")" COMPONENT vllm_flash_attn_c) -install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" COMPONENT vllm_flash_attn_c) +# Copy over the vllm-flash-attn python files (duplicated for fa2 and fa3 in +# case only one is built; if both are built, redundant work is done) +install( + DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/ + DESTINATION vllm_flash_attn + COMPONENT _vllm_fa2_C + FILES_MATCHING PATTERN "*.py" +) -# Copy over the vllm-flash-attn python files install( - DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/ - DESTINATION vllm/vllm_flash_attn - COMPONENT vllm_flash_attn_c - FILES_MATCHING PATTERN "*.py" + DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/ + DESTINATION vllm_flash_attn + COMPONENT _vllm_fa3_C + FILES_MATCHING PATTERN "*.py" ) # Nothing after vllm-flash-attn, see comment about macros above diff --git a/setup.py b/setup.py index dde5072139660..36c89d435c7b7 100644 --- a/setup.py +++ b/setup.py @@ -228,8 +228,11 @@ def target_name(s: str) -> str: # CMake appends the extension prefix to the install path, # and outdir already contains that prefix, so we need to remove it. + # We assume that only the final component of the extension prefix is + # added by CMake; this holds for the current extensions but may not + # always be the case. prefix = outdir - for i in range(ext.name.count('.')): + if '.' 
in ext.name: prefix = prefix.parent # prefix here should actually be the same for all components @@ -298,7 +301,8 @@ def run(self) -> None: files_to_copy = [ "vllm/_C.abi3.so", "vllm/_moe_C.abi3.so", - "vllm/vllm_flash_attn/vllm_flash_attn_c.abi3.so", + "vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so", + "vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so", "vllm/vllm_flash_attn/flash_attn_interface.py", "vllm/vllm_flash_attn/__init__.py", "vllm/cumem_allocator.abi3.so", @@ -593,8 +597,8 @@ def _read_requirements(filename: str) -> List[str]: ext_modules.append(CMakeExtension(name="vllm._rocm_C")) if _is_cuda(): - ext_modules.append( - CMakeExtension(name="vllm.vllm_flash_attn.vllm_flash_attn_c")) + ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa2_C")) + ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C")) ext_modules.append(CMakeExtension(name="vllm.cumem_allocator")) if _build_custom_ops(): diff --git a/tests/kernels/test_cascade_flash_attn.py b/tests/kernels/test_cascade_flash_attn.py index 45ec6df4e711e..00eb927205d46 100644 --- a/tests/kernels/test_cascade_flash_attn.py +++ b/tests/kernels/test_cascade_flash_attn.py @@ -78,6 +78,7 @@ def test_merge_kernel( @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("soft_cap", [None, 50]) @pytest.mark.parametrize("num_blocks", [2048]) +@pytest.mark.parametrize("fa_version", [2, 3]) @torch.inference_mode() def test_cascade( seq_lens_and_common_prefix: Tuple[List[Tuple[int, int]], int], @@ -87,8 +88,14 @@ def test_cascade( block_size: int, soft_cap: Optional[float], num_blocks: int, + fa_version: int, ) -> None: torch.set_default_device("cuda") + if fa_version == 3 and (torch.cuda.get_device_capability() == (8, 6) + or torch.cuda.get_device_capability() == (8, 9)): + pytest.skip("Flash attention version 3 fails on 8.6 and 8.9 due to " + "insufficient shared memory for some shapes") + current_platform.seed_everything(0) window_size = (-1, -1) @@ -118,9 +125,7 @@ def test_cascade( cu_query_lens = torch.tensor([0] + query_lens, dtype=torch.int32).cumsum(dim=0, dtype=torch.int32) - cu_kv_lens = torch.tensor([0] + kv_lens, - dtype=torch.int32).cumsum(dim=0, - dtype=torch.int32) + kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32) max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size block_tables = torch.randint(0, num_blocks, @@ -140,7 +145,7 @@ def test_cascade( k=key_cache, v=value_cache, cu_seqlens_q=cu_query_lens, - cu_seqlens_k=cu_kv_lens, + seqused_k=kv_lens_tensor, max_seqlen_q=max_query_len, max_seqlen_k=max_kv_len, softmax_scale=scale, @@ -154,10 +159,8 @@ def test_cascade( assert all(common_prefix_len < kv_len for kv_len in kv_lens) cu_prefix_query_lens = torch.tensor([0, total_num_query_tokens], dtype=torch.int32) - cu_prefix_kv_lens = torch.tensor([0, common_prefix_len], dtype=torch.int32) - cu_suffix_kv_lens = ( - cu_kv_lens - - torch.arange(num_seqs + 1, dtype=torch.int32) * common_prefix_len) + prefix_kv_lens = torch.tensor([common_prefix_len], dtype=torch.int32) + suffix_kv_lens = kv_lens_tensor - common_prefix_len output = torch.empty_like(query) cascade_attention( output=output, @@ -167,8 +170,8 @@ def test_cascade( cu_query_lens=cu_query_lens, max_query_len=max_query_len, cu_prefix_query_lens=cu_prefix_query_lens, - cu_prefix_kv_lens=cu_prefix_kv_lens, - cu_suffix_kv_lens=cu_suffix_kv_lens, + prefix_kv_lens=prefix_kv_lens, + suffix_kv_lens=suffix_kv_lens, max_kv_len=max_kv_len, softmax_scale=scale, alibi_slopes=None, @@ -176,6 +179,7 @@ def test_cascade( 
logits_soft_cap=soft_cap if soft_cap is not None else 0, block_table=block_tables, common_prefix_len=common_prefix_len, + fa_version=fa_version, ) # Compare the results. diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py index 1ae78d7b46c5b..b22153c86b25f 100644 --- a/tests/kernels/test_flash_attn.py +++ b/tests/kernels/test_flash_attn.py @@ -80,6 +80,7 @@ def ref_paged_attn( @pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0]) @pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @pytest.mark.parametrize("sliding_window", [None, 256]) +@pytest.mark.parametrize("fa_version", [2, 3]) @torch.inference_mode() def test_flash_attn_with_paged_kv( use_out: bool, @@ -91,8 +92,14 @@ def test_flash_attn_with_paged_kv( soft_cap: Optional[float], num_blocks: int, sliding_window: Optional[int], + fa_version: int, ) -> None: torch.set_default_device("cuda") + if fa_version == 3 and (torch.cuda.get_device_capability() == (8, 6) + or torch.cuda.get_device_capability() == (8, 9)): + pytest.skip("Flash attention version 3 fails on 8.6 and 8.9 due to " + "insufficient shared memory for some shapes") + current_platform.seed_everything(0) num_seqs = len(kv_lens) num_query_heads = num_heads[0] @@ -131,6 +138,7 @@ def test_flash_attn_with_paged_kv( cache_seqlens=kv_lens_tensor, softcap=soft_cap if soft_cap is not None else 0, window_size=window_size, + fa_version=fa_version, ) output = output if not use_out else out output = output.squeeze(1) @@ -159,6 +167,7 @@ def test_flash_attn_with_paged_kv( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0]) @pytest.mark.parametrize("num_blocks", NUM_BLOCKS) +@pytest.mark.parametrize("fa_version", [2, 3]) @torch.inference_mode() def test_varlen_with_paged_kv( use_out: bool, @@ -170,8 +179,14 @@ def test_varlen_with_paged_kv( block_size: int, soft_cap: Optional[float], num_blocks: int, + fa_version: int, ) -> None: torch.set_default_device("cuda") + if fa_version == 3 and (torch.cuda.get_device_capability() == (8, 6) + or torch.cuda.get_device_capability() == (8, 9)): + pytest.skip("Flash attention version 3 fails on 8.6 and 8.9 due to " + "insufficient shared memory for some shapes") + current_platform.seed_everything(0) num_seqs = len(seq_lens) query_lens = [x[0] for x in seq_lens] @@ -198,9 +213,7 @@ def test_varlen_with_paged_kv( cu_query_lens = torch.tensor([0] + query_lens, dtype=torch.int32).cumsum(dim=0, dtype=torch.int32) - cu_kv_lens = torch.tensor([0] + kv_lens, - dtype=torch.int32).cumsum(dim=0, - dtype=torch.int32) + kv_lens = torch.tensor(kv_lens, dtype=torch.int32) max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size block_tables = torch.randint(0, @@ -215,7 +228,7 @@ def test_varlen_with_paged_kv( v=value_cache, out=out, cu_seqlens_q=cu_query_lens, - cu_seqlens_k=cu_kv_lens, + seqused_k=kv_lens, max_seqlen_q=max_query_len, max_seqlen_k=max_kv_len, softmax_scale=scale, @@ -223,6 +236,7 @@ def test_varlen_with_paged_kv( window_size=window_size, block_table=block_tables, softcap=soft_cap if soft_cap is not None else 0, + fa_version=fa_version, ) output = output if not use_out else out diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 60ed09d0cc44f..18acfb82fac58 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -17,7 +17,9 @@ compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens, get_seq_len_block_table_args, is_all_cross_attn_metadata_set, 
is_all_encoder_attn_metadata_set, is_block_tables_empty) +from vllm.envs import VLLM_FLASH_ATTN_VERSION from vllm.multimodal import MultiModalPlaceholderMap +from vllm.platforms import current_platform from vllm.utils import async_tensor_h2d, make_tensor_with_pad if TYPE_CHECKING: @@ -25,7 +27,8 @@ ModelInputForGPUWithSamplingMetadata) from vllm.vllm_flash_attn import (flash_attn_varlen_func, - flash_attn_with_kvcache) + flash_attn_with_kvcache, + is_fa_version_supported) class FlashAttentionBackend(AttentionBackend): @@ -634,6 +637,20 @@ def __init__( f"Supported head sizes are: {support_head_sizes}.") self.attn_type = attn_type + # if hopper default to FA3, otherwise stick to FA2 for now + # TODO(lucas): profile FA3 on ampere to see if it makes sense to + # use FA3 as default for both + if current_platform.get_device_capability()[0] >= 9: + self.fa_version = 3 if is_fa_version_supported(3) else 2 + else: + self.fa_version = 2 + + if VLLM_FLASH_ATTN_VERSION is not None: + assert VLLM_FLASH_ATTN_VERSION in [2, 3] + self.fa_version = VLLM_FLASH_ATTN_VERSION + + assert is_fa_version_supported(self.fa_version) + def forward( self, layer: AttentionLayer, @@ -752,6 +769,7 @@ def forward( alibi_slopes=alibi_slopes, softcap=logits_soft_cap, out=prefill_output, + fa_version=self.fa_version, ) else: # prefix-enabled attention @@ -765,7 +783,7 @@ def forward( v=value_cache, cu_seqlens_q=prefill_meta.query_start_loc, max_seqlen_q=prefill_meta.max_query_len, - cu_seqlens_k=prefill_meta.seq_start_loc, + seqused_k=prefill_meta.seq_lens_tensor, max_seqlen_k=max_seq_len, softmax_scale=softmax_scale, causal=True, @@ -774,6 +792,7 @@ def forward( block_table=prefill_meta.block_tables, softcap=logits_soft_cap, out=prefill_output, + fa_version=self.fa_version, ) if decode_meta := attn_metadata.decode_metadata: @@ -793,7 +812,7 @@ def forward( v=value_cache, cu_seqlens_q=decode_meta.query_start_loc, max_seqlen_q=decode_meta.max_decode_query_len, - cu_seqlens_k=decode_meta.seq_start_loc, + seqused_k=decode_meta.seq_lens_tensor, max_seqlen_k=decode_meta.max_decode_seq_len, softmax_scale=softmax_scale, causal=True, @@ -802,6 +821,7 @@ def forward( softcap=logits_soft_cap, block_table=decode_meta.block_tables, out=decode_output, + fa_version=self.fa_version, ) else: # Use flash_attn_with_kvcache for normal decoding. @@ -822,6 +842,7 @@ def forward( alibi_slopes=alibi_slopes, softcap=logits_soft_cap, out=decode_output.unsqueeze(1), + fa_version=self.fa_version, ) return output diff --git a/vllm/envs.py b/vllm/envs.py index 3a15e00e7b50a..b72e9141ac792 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -11,6 +11,7 @@ VLLM_NCCL_SO_PATH: Optional[str] = None LD_LIBRARY_PATH: Optional[str] = None VLLM_USE_TRITON_FLASH_ATTN: bool = False + VLLM_FLASH_ATTN_VERSION: Optional[int] = None LOCAL_RANK: int = 0 CUDA_VISIBLE_DEVICES: Optional[str] = None VLLM_ENGINE_ITERATION_TIMEOUT_S: int = 60 @@ -90,6 +91,12 @@ def get_default_config_root(): ) +def maybe_convert_int(value: Optional[str]) -> Optional[int]: + if value is None: + return None + return int(value) + + # The begin-* and end* here are used by the documentation generator # to extract the used env vars. @@ -203,6 +210,11 @@ def get_default_config_root(): lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in ("true", "1")), + # Force vllm to use a specific flash-attention version (2 or 3), only valid + # when using the flash-attention backend. 
+ "VLLM_FLASH_ATTN_VERSION": + lambda: maybe_convert_int(os.environ.get("VLLM_FLASH_ATTN_VERSION", None)), + # Internal flag to enable Dynamo fullgraph capture "VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE": lambda: bool( diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index fd36ea8d8806b..1806fec8833a3 100644 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -9,8 +9,11 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) +from vllm.envs import VLLM_FLASH_ATTN_VERSION +from vllm.platforms import current_platform from vllm.utils import cdiv -from vllm.vllm_flash_attn import flash_attn_varlen_func +from vllm.vllm_flash_attn import (flash_attn_varlen_func, + is_fa_version_supported) class FlashAttentionBackend(AttentionBackend): @@ -63,7 +66,7 @@ class FlashAttentionMetadata: max_query_len: int query_start_loc: torch.Tensor max_seq_len: int - seq_start_loc: torch.Tensor + seq_lens: torch.Tensor block_table: torch.Tensor slot_mapping: torch.Tensor @@ -71,8 +74,8 @@ class FlashAttentionMetadata: use_cascade: bool common_prefix_len: int cu_prefix_query_lens: Optional[torch.Tensor] - cu_prefix_kv_lens: Optional[torch.Tensor] - cu_suffix_kv_lens: Optional[torch.Tensor] + prefix_kv_lens: Optional[torch.Tensor] + suffix_kv_lens: Optional[torch.Tensor] # For logging. num_input_tokens: int = 0 # Number of tokens including padding. @@ -128,6 +131,20 @@ def __init__( "are not implemented for " "FlashAttentionImpl") + # if hopper default to FA3, otherwise stick to FA2 for now + # TODO(lucas): profile FA3 on ampere to see if it makes sense to + # use FA3 as default for both + if current_platform.get_device_capability()[0] >= 9: + self.fa_version = 3 if is_fa_version_supported(3) else 2 + else: + self.fa_version = 2 + + if VLLM_FLASH_ATTN_VERSION is not None: + assert VLLM_FLASH_ATTN_VERSION in [2, 3] + self.fa_version = VLLM_FLASH_ATTN_VERSION + + assert is_fa_version_supported(self.fa_version) + def forward( self, layer: torch.nn.Module, @@ -196,7 +213,7 @@ def forward( out=output[:num_actual_tokens], cu_seqlens_q=attn_metadata.query_start_loc, max_seqlen_q=attn_metadata.max_query_len, - cu_seqlens_k=attn_metadata.seq_start_loc, + seqused_k=attn_metadata.seq_lens, max_seqlen_k=attn_metadata.max_seq_len, softmax_scale=self.scale, causal=True, @@ -204,6 +221,7 @@ def forward( window_size=self.sliding_window, block_table=attn_metadata.block_table, softcap=self.logits_soft_cap, + fa_version=self.fa_version, ) return output @@ -216,8 +234,8 @@ def forward( cu_query_lens=attn_metadata.query_start_loc, max_query_len=attn_metadata.max_query_len, cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens, - cu_prefix_kv_lens=attn_metadata.cu_prefix_kv_lens, - cu_suffix_kv_lens=attn_metadata.cu_suffix_kv_lens, + prefix_kv_lens=attn_metadata.prefix_kv_lens, + suffix_kv_lens=attn_metadata.suffix_kv_lens, max_kv_len=attn_metadata.max_seq_len, softmax_scale=self.scale, alibi_slopes=self.alibi_slopes, @@ -225,6 +243,7 @@ def forward( logits_soft_cap=self.logits_soft_cap, block_table=attn_metadata.block_table, common_prefix_len=attn_metadata.common_prefix_len, + fa_version=self.fa_version, ) return output @@ -305,8 +324,8 @@ def cascade_attention( cu_query_lens: torch.Tensor, max_query_len: int, cu_prefix_query_lens: torch.Tensor, - cu_prefix_kv_lens: torch.Tensor, - cu_suffix_kv_lens: torch.Tensor, + prefix_kv_lens: torch.Tensor, + suffix_kv_lens: torch.Tensor, max_kv_len: 
int, softmax_scale: float, alibi_slopes: Optional[torch.Tensor], @@ -314,6 +333,7 @@ def cascade_attention( logits_soft_cap: float, block_table: torch.Tensor, common_prefix_len: int, + fa_version: int, ) -> torch.Tensor: assert alibi_slopes is None, ("Cascade attention does not support ALiBi.") # TODO: Support sliding window. @@ -332,7 +352,7 @@ def cascade_attention( k=key_cache, v=value_cache, cu_seqlens_q=cu_prefix_query_lens, - cu_seqlens_k=cu_prefix_kv_lens, + seqused_k=prefix_kv_lens, max_seqlen_q=num_tokens, max_seqlen_k=common_prefix_len, softmax_scale=softmax_scale, @@ -341,6 +361,7 @@ def cascade_attention( block_table=block_table[:1], softcap=logits_soft_cap, return_softmax_lse=True, + fa_version=fa_version, ) # Process suffix per query. @@ -349,7 +370,7 @@ def cascade_attention( k=key_cache, v=value_cache, cu_seqlens_q=cu_query_lens, - cu_seqlens_k=cu_suffix_kv_lens, + seqused_k=suffix_kv_lens, max_seqlen_q=max_query_len, max_seqlen_k=max_kv_len - common_prefix_len, softmax_scale=softmax_scale, @@ -358,6 +379,7 @@ def cascade_attention( block_table=block_table[:, num_common_kv_blocks:], softcap=logits_soft_cap, return_softmax_lse=True, + fa_version=fa_version, ) # Merge prefix and suffix outputs, and store the result in output. diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index fdf39449a2c59..99d463b940923 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -199,11 +199,11 @@ def __init__( device="cpu", pin_memory=self.pin_memory) self.query_start_loc_np = self.query_start_loc_cpu.numpy() - self.seq_start_loc_cpu = torch.zeros(self.max_num_reqs + 1, - dtype=torch.int32, - device="cpu", - pin_memory=self.pin_memory) - self.seq_start_loc_np = self.seq_start_loc_cpu.numpy() + self.seq_lens_cpu = torch.zeros(self.max_num_reqs, + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + self.seq_lens_np = self.seq_lens_cpu.numpy() def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Remove stopped requests from the cached states. @@ -412,11 +412,10 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): np.cumsum(num_scheduled_tokens, out=self.query_start_loc_np[1:num_reqs + 1]) - seq_lens = (self.input_batch.num_computed_tokens_cpu[:num_reqs] + - num_scheduled_tokens) - max_seq_len = seq_lens.max() - self.seq_start_loc_np[0] = 0 - np.cumsum(seq_lens, out=self.seq_start_loc_np[1:num_reqs + 1]) + self.seq_lens_np[:num_reqs] = ( + self.input_batch.num_computed_tokens_cpu[:num_reqs] + + num_scheduled_tokens) + max_seq_len = self.seq_lens_np[:num_reqs].max() # Copy the tensors to the GPU. 
self.input_ids[:total_num_scheduled_tokens].copy_( @@ -433,8 +432,8 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): non_blocking=True) query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to( self.device, non_blocking=True) - seq_start_loc = self.seq_start_loc_cpu[:num_reqs + 1].to( - self.device, non_blocking=True) + seq_lens = self.seq_lens_cpu[:num_reqs].to(self.device, + non_blocking=True) slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to( self.device, non_blocking=True).long() @@ -506,33 +505,30 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): [0, total_num_scheduled_tokens], dtype=torch.int32, device=self.device) - cu_prefix_kv_lens = torch.tensor([0, common_prefix_len], - dtype=torch.int32, - device=self.device) - cu_suffix_kv_lens = ( - self.seq_start_loc_np[:num_reqs + 1] - - self.arange_np[:num_reqs + 1] * common_prefix_len) - cu_suffix_kv_lens = torch.from_numpy(cu_suffix_kv_lens).to( - self.device) + prefix_kv_lens = torch.tensor([common_prefix_len], + dtype=torch.int32, + device=self.device) + suffix_kv_lens = (self.seq_lens_np[:num_reqs] - common_prefix_len) + suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(self.device) else: cu_prefix_query_lens = None - cu_prefix_kv_lens = None - cu_suffix_kv_lens = None + prefix_kv_lens = None + suffix_kv_lens = None attn_metadata = FlashAttentionMetadata( num_actual_tokens=total_num_scheduled_tokens, max_query_len=max_num_scheduled_tokens, query_start_loc=query_start_loc, max_seq_len=max_seq_len, - seq_start_loc=seq_start_loc, + seq_lens=seq_lens, block_table=( self.input_batch.block_table.get_device_tensor()[:num_reqs]), slot_mapping=slot_mapping, use_cascade=use_cascade, common_prefix_len=common_prefix_len, cu_prefix_query_lens=cu_prefix_query_lens, - cu_prefix_kv_lens=cu_prefix_kv_lens, - cu_suffix_kv_lens=cu_suffix_kv_lens, + prefix_kv_lens=prefix_kv_lens, + suffix_kv_lens=suffix_kv_lens, ) # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial # request in the batch. While we should not sample any token from this
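For context, the two conventions that recur throughout this diff can be summarized in a short standalone sketch. This is a minimal illustration in plain PyTorch with hypothetical helper names (not vLLM or flash-attention API), assuming only what the diff itself shows: the varlen attention calls now take per-sequence KV lengths (seqused_k, plus prefix_kv_lens/suffix_kv_lens for cascade attention) instead of cumulative cu_seqlens_k offsets, and the backends default to FA3 on compute capability >= 9.0 when the FA3 extension is available, falling back to FA2, with VLLM_FLASH_ATTN_VERSION as an explicit override.

# Sketch only: restates the length conventions and FA-version selection
# used in this diff with illustrative helper names.
import os
from typing import Optional, Tuple

import torch


def kv_length_tensors(kv_lens: list, common_prefix_len: int = 0):
    """Old convention: cumulative cu_seqlens_k offsets (shape [num_seqs + 1]).
    New convention: per-sequence seqused_k lengths (shape [num_seqs]), with
    prefix/suffix lengths derived for cascade attention."""
    cu_seqlens_k = torch.tensor([0] + kv_lens, dtype=torch.int32).cumsum(
        dim=0, dtype=torch.int32)
    seqused_k = torch.tensor(kv_lens, dtype=torch.int32)
    # Cascade attention splits each sequence into a shared prefix and a
    # per-sequence suffix.
    prefix_kv_lens = torch.tensor([common_prefix_len], dtype=torch.int32)
    suffix_kv_lens = seqused_k - common_prefix_len
    return cu_seqlens_k, seqused_k, prefix_kv_lens, suffix_kv_lens


def pick_fa_version(device_capability: Tuple[int, int],
                    fa3_supported: bool,
                    env_override: Optional[str] = None) -> int:
    """Default to FA3 on capability >= 9 (Hopper) when FA3 is available,
    otherwise FA2; an explicit VLLM_FLASH_ATTN_VERSION value wins."""
    version = 3 if device_capability[0] >= 9 and fa3_supported else 2
    if env_override is not None:
        version = int(env_override)
        assert version in (2, 3)
    return version


if __name__ == "__main__":
    cu, used, prefix, suffix = kv_length_tensors([5, 7, 9],
                                                 common_prefix_len=4)
    # cu = [0, 5, 12, 21], used = [5, 7, 9], prefix = [4], suffix = [1, 3, 5]
    print(cu.tolist(), used.tolist(), prefix.tolist(), suffix.tolist())
    print(pick_fa_version((9, 0), fa3_supported=True,
                          env_override=os.environ.get(
                              "VLLM_FLASH_ATTN_VERSION")))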