diff --git a/python/flashinfer/jit/activation.py b/python/flashinfer/jit/activation.py
index 1e6a655f5..c0dc59b09 100644
--- a/python/flashinfer/jit/activation.py
+++ b/python/flashinfer/jit/activation.py
@@ -62,8 +62,7 @@ def get_act_and_mul_cu_str(act_func_name: str, act_func_def: str) -> str:
 
 def gen_act_and_mul_module(act_func_name: str, act_func_def: str) -> None:
     gen_directory = FLASHINFER_GEN_SRC_DIR
-    if not os.path.exists(gen_directory):
-        os.makedirs(gen_directory)
+    os.makedirs(gen_directory, exist_ok=True)
     sources = [gen_directory / f"{act_func_name}_and_mul.cu"]
     write_if_different(
         sources[0],
diff --git a/python/flashinfer/jit/attention.py b/python/flashinfer/jit/attention.py
index 7bc5fed57..72c2fca26 100644
--- a/python/flashinfer/jit/attention.py
+++ b/python/flashinfer/jit/attention.py
@@ -155,8 +155,6 @@ def get_batch_decode_uri(
 
 def gen_batch_decode_module(*args):
     gen_directory = FLASHINFER_GEN_SRC_DIR
-    if not os.path.exists(gen_directory):
-        os.makedirs(gen_directory)
     uri = get_batch_decode_uri(*args)
     sources = get_batch_decode_sources(*args)
     source_paths = []
@@ -214,8 +212,6 @@ def get_batch_decode_mla_uri(
 
 def gen_batch_decode_mla_module(*args):
     gen_directory = FLASHINFER_GEN_SRC_DIR
-    if not os.path.exists(gen_directory):
-        os.makedirs(gen_directory)
     uri = get_batch_decode_mla_uri(*args)
     sources = get_batch_decode_mla_sources(*args)
     source_paths = []
@@ -275,8 +271,6 @@ def get_single_prefill_uri(
 
 def gen_single_prefill_module(*args):
     gen_directory = FLASHINFER_GEN_SRC_DIR
-    if not os.path.exists(gen_directory):
-        os.makedirs(gen_directory)
     uri = get_single_prefill_uri(*args)
     sources = get_single_prefill_sources(*args)
     source_paths = []
@@ -341,8 +335,6 @@ def get_batch_prefill_uri(
 
 def gen_batch_prefill_module(*args):
     gen_directory = FLASHINFER_GEN_SRC_DIR
-    if not os.path.exists(gen_directory):
-        os.makedirs(gen_directory)
     uri = get_batch_prefill_uri(*args)
     sources = get_batch_prefill_sources(*args)
     source_paths = []
@@ -518,8 +510,6 @@ def get_customize_single_prefill_sources(
 
 def gen_customize_single_decode_module(module_name, *args):
     gen_directory = FLASHINFER_GEN_SRC_DIR
-    if not os.path.exists(gen_directory):
-        os.makedirs(gen_directory)
     sources = get_customize_single_decode_sources(*args)
     source_paths = []
     for suffix, source in zip(single_decode_suffix, sources):
@@ -532,8 +522,6 @@ def gen_customize_single_decode_module(module_name, *args):
 
 def gen_customize_single_prefill_module(module_name, *args):
     gen_directory = FLASHINFER_GEN_SRC_DIR
-    if not os.path.exists(gen_directory):
-        os.makedirs(gen_directory)
     sources = get_customize_single_prefill_sources(*args)
     source_paths = []
     for suffix, source in zip(single_prefill_suffix, sources):
diff --git a/python/flashinfer/jit/core.py b/python/flashinfer/jit/core.py
index 8b9490ca8..32ab7ca68 100644
--- a/python/flashinfer/jit/core.py
+++ b/python/flashinfer/jit/core.py
@@ -14,8 +14,8 @@
 from .env import FLASHINFER_JIT_DIR as FLASHINFER_JIT_DIR
 from .env import FLASHINFER_WORKSPACE_DIR as FLASHINFER_WORKSPACE_DIR
 
-if not os.path.exists(FLASHINFER_WORKSPACE_DIR):
-    os.makedirs(FLASHINFER_WORKSPACE_DIR)
+os.makedirs(FLASHINFER_WORKSPACE_DIR, exist_ok=True)
+os.makedirs(FLASHINFER_CSRC_DIR, exist_ok=True)
 
 
 class FlashInferJITLogger(logging.Logger):
@@ -99,8 +99,7 @@ def load_cuda_ops(
     logger.info(f"Loading JIT ops: {name}")
     check_cuda_arch()
     build_directory = FLASHINFER_JIT_DIR / name
-    if not os.path.exists(build_directory):
-        os.makedirs(build_directory, exist_ok=True)
+    os.makedirs(build_directory, exist_ok=True)
     if extra_include_paths is None:
         extra_include_paths = [
             FLASHINFER_INCLUDE_DIR,
diff --git a/python/flashinfer/jit/utils.py b/python/flashinfer/jit/utils.py
index 63a87c5f8..01a698706 100644
--- a/python/flashinfer/jit/utils.py
+++ b/python/flashinfer/jit/utils.py
@@ -16,7 +16,7 @@
 
 import pathlib
 import threading
-from typing import Callable, List
+from typing import Any, Callable, List, Tuple
 
 import torch
 
@@ -35,19 +35,19 @@ def write_if_different(path: pathlib.Path, content: str) -> None:
 
 
 def parallel_load_modules(
-    load_module_funcs: List[Callable],
+    load_module_func_args: List[Tuple[Callable, List[Any]]],
 ):
     threads = []
     exceptions = []
 
-    def wrapper(func):
+    def wrapper(func, args):
         try:
-            func()
+            func(*args)
         except Exception as e:
             exceptions.append((func, e))
 
-    for func in load_module_funcs:
-        thread = threading.Thread(target=wrapper, args=(func,))
+    for func, args in load_module_func_args:
+        thread = threading.Thread(target=wrapper, args=(func, args))
         thread.start()
         threads.append(thread)
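With this change, parallel_load_modules takes a list of (loader, argument list) pairs instead of bare callables, so call sites no longer need lambda wrappers to bind arguments (see the updated tests/test_jit_warmup.py further down). A minimal usage sketch of the new contract; the module choices here are illustrative:

import flashinfer

# Each entry pairs a JIT loader with its argument list; loaders that take no
# arguments are paired with an empty list.
flashinfer.jit.parallel_load_modules(
    [
        (flashinfer.norm.get_norm_module, []),
        (flashinfer.activation.get_act_and_mul_module, ["silu"]),
    ]
)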
diff --git a/tests/jit_utils.py b/tests/jit_utils.py
new file mode 100644
index 000000000..9e62dad38
--- /dev/null
+++ b/tests/jit_utils.py
@@ -0,0 +1,149 @@
+"""
+Copyright (c) 2023 by FlashInfer team.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import itertools
+
+import torch
+
+import flashinfer
+
+
+def jit_decode_attention_func_args(
+    q_dtypes,
+    kv_dtypes,
+    head_dims,
+    pos_encoding_modes,
+    use_sliding_window_options,
+    use_logits_soft_cap_options,
+):
+    load_module_func_args = []
+
+    for (
+        q_dtype,
+        kv_dtype,
+        head_dim,
+        pos_encoding_mode,
+        use_sliding_window,
+        use_logits_soft_cap,
+    ) in itertools.product(
+        q_dtypes,
+        kv_dtypes,
+        head_dims,
+        pos_encoding_modes,
+        use_sliding_window_options,
+        use_logits_soft_cap_options,
+    ):
+        load_module_func_args.append(
+            (
+                flashinfer.decode.get_single_decode_module,
+                (
+                    q_dtype,
+                    kv_dtype,
+                    q_dtype,
+                    head_dim,
+                    pos_encoding_mode,
+                    use_sliding_window,
+                    use_logits_soft_cap,
+                ),
+            )
+        )
+        load_module_func_args.append(
+            (
+                flashinfer.decode.get_batch_decode_module,
+                (
+                    q_dtype,
+                    kv_dtype,
+                    q_dtype,
+                    torch.int32,
+                    head_dim,
+                    pos_encoding_mode,
+                    use_sliding_window,
+                    use_logits_soft_cap,
+                ),
+            )
+        )
+
+    return load_module_func_args
+
+
+def jit_prefill_attention_func_args(
+    q_dtypes,
+    kv_dtypes,
+    head_dims,
+    pos_encoding_modes,
+    use_sliding_window_options,
+    use_logits_soft_cap_options,
+    allow_fp16_qk_reduction_options,
+):
+    load_module_func_args = []
+
+    for (
+        q_dtype,
+        kv_dtype,
+        head_dim,
+        pos_encoding_mode,
+        use_sliding_window,
+        use_logits_soft_cap,
+        allow_fp16_qk_reduction,
+    ) in itertools.product(
+        q_dtypes,
+        kv_dtypes,
+        head_dims,
+        pos_encoding_modes,
+        use_sliding_window_options,
+        use_logits_soft_cap_options,
+        allow_fp16_qk_reduction_options,
+    ):
+        load_module_func_args.append(
+            (
+                flashinfer.prefill.gen_single_prefill_module,
+                (
+                    q_dtype,
+                    kv_dtype,
+                    q_dtype,
+                    head_dim,
+                    pos_encoding_mode,
+                    use_sliding_window,
+                    use_logits_soft_cap,
+                    allow_fp16_qk_reduction,
+                ),
+            )
+        )
+        load_module_func_args.append(
+            (
+                flashinfer.prefill.gen_batch_prefill_module,
+                (
+                    q_dtype,
+                    kv_dtype,
+                    q_dtype,
+                    torch.int32,
+                    head_dim,
+                    pos_encoding_mode,
+                    use_sliding_window,
+                    use_logits_soft_cap,
+                    allow_fp16_qk_reduction,
+                ),
+            )
+        )
+
+    load_module_func_args.append(
+        (
+            flashinfer.quantization.get_quantization_module,
+            [],
+        )  # required for attention with custom mask
+    )
+
+    return load_module_func_args
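The helpers above expand a cross product of dtype / head-dim / mode options into (loader, args) entries for flashinfer.jit.parallel_load_modules: each decode configuration contributes a single-decode and a batch-decode entry, and each prefill configuration contributes a single-prefill and a batch-prefill entry, plus one shared quantization entry. A minimal sketch of what a caller receives, assuming it runs from the tests/ directory so jit_utils is importable; the single-element option lists are illustrative:

import torch

from jit_utils import jit_decode_attention_func_args

# One decode configuration expands to exactly two entries:
# (get_single_decode_module, args) and (get_batch_decode_module, args).
entries = jit_decode_attention_func_args(
    [torch.float16],  # q_dtypes
    [torch.float16],  # kv_dtypes
    [128],  # head_dims
    [0],  # pos_encoding_modes
    [False],  # use_sliding_windows
    [False],  # use_logits_soft_caps
)
assert len(entries) == 2
for func, args in entries:
    print(func, args)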
diff --git a/tests/test_alibi.py b/tests/test_alibi.py
index 67ed7abb0..f01811ecf 100644
--- a/tests/test_alibi.py
+++ b/tests/test_alibi.py
@@ -18,10 +18,42 @@
 import pytest
 import torch
 from alibi_reference import alibi_attention
+from jit_utils import jit_decode_attention_func_args, jit_prefill_attention_func_args
 
 import flashinfer
 
 
+@pytest.fixture(autouse=True, scope="module")
+def warmup_jit():
+    if flashinfer.jit.has_prebuilt_ops:
+        return
+    try:
+        flashinfer.jit.parallel_load_modules(
+            jit_decode_attention_func_args(
+                [torch.float16],  # q_dtypes
+                [torch.float16],  # kv_dtypes
+                [128, 256],  # head_dims
+                [0, 2],  # pos_encoding_modes
+                [False],  # use_sliding_windows
+                [False],  # use_logits_soft_caps
+            )
+            + jit_prefill_attention_func_args(
+                [torch.float16],  # q_dtypes
+                [torch.float16],  # kv_dtypes
+                [128, 256],  # head_dims
+                [0, 2],  # pos_encoding_modes
+                [False],  # use_sliding_windows
+                [False],  # use_logits_soft_caps
+                [False],  # allow_fp16_qk_reductions
+            )
+        )
+    except Exception as e:
+        # abort the test session if warmup fails
+        pytest.exit(str(e))
+    finally:
+        yield
+
+
 @pytest.mark.parametrize("seq_len", [1, 9, 81, 729])
 @pytest.mark.parametrize("num_heads", [4, 8, 32])
 @pytest.mark.parametrize("head_dim", [128, 256])
diff --git a/tests/test_batch_decode_kernels.py b/tests/test_batch_decode_kernels.py
index 9cd272f78..4d2d67c61 100644
--- a/tests/test_batch_decode_kernels.py
+++ b/tests/test_batch_decode_kernels.py
@@ -16,10 +16,42 @@
 
 import pytest
 import torch
+from jit_utils import jit_decode_attention_func_args, jit_prefill_attention_func_args
 
 import flashinfer
 
 
+@pytest.fixture(autouse=True, scope="module")
+def warmup_jit():
+    if flashinfer.jit.has_prebuilt_ops:
+        return
+    try:
+        flashinfer.jit.parallel_load_modules(
+            jit_decode_attention_func_args(
+                [torch.float16],  # q_dtypes
+                [torch.float16, torch.float8_e4m3fn, torch.float8_e5m2],  # kv_dtypes
+                [128, 256],  # head_dims
+                [0, 1, 2],  # pos_encoding_modes
+                [False],  # use_sliding_windows
+                [False, True],  # use_logits_soft_caps
+            )
+            + jit_prefill_attention_func_args(
+                [torch.float16],  # q_dtypes
+                [torch.float16, torch.float8_e4m3fn, torch.float8_e5m2],  # kv_dtypes
+                [128, 256],  # head_dims
+                [0, 1, 2],  # pos_encoding_modes
+                [False],  # use_sliding_windows
+                [False, True],  # use_logits_soft_caps
+                [False],  # allow_fp16_qk_reductions
+            )
+        )
+    except Exception as e:
+        # abort the test session if warmup fails
+        pytest.exit(str(e))
+    finally:
+        yield
+
+
 @pytest.mark.parametrize("batch_size", [12, 17])
 @pytest.mark.parametrize("kv_len", [54, 97, 512])
 @pytest.mark.parametrize("page_size", [1, 8, 16])
diff --git a/tests/test_batch_prefill_kernels.py b/tests/test_batch_prefill_kernels.py
index b7b0c981a..f9ceadee6 100644
--- a/tests/test_batch_prefill_kernels.py
+++ b/tests/test_batch_prefill_kernels.py
@@ -16,10 +16,34 @@
 
 import pytest
 import torch
+from jit_utils import jit_prefill_attention_func_args
 
 import flashinfer
 
 
+@pytest.fixture(autouse=True, scope="module")
+def warmup_jit():
+    if flashinfer.jit.has_prebuilt_ops:
+        return
+    try:
+        flashinfer.jit.parallel_load_modules(
+            jit_prefill_attention_func_args(
+                [torch.float16],  # q_dtypes
+                [torch.float16, torch.float8_e4m3fn, torch.float8_e5m2],  # kv_dtypes
+                [128, 256],  # head_dims
+                [0, 1, 2],  # pos_encoding_modes
+                [False],  # use_sliding_windows
+                [False, True],  # use_logits_soft_caps
+                [False],  # allow_fp16_qk_reductions
+            )
+        )
+    except Exception as e:
+        # abort the test session if warmup fails
+        pytest.exit(str(e))
+    finally:
+        yield
+
+
 @pytest.mark.parametrize("batch_size", [12, 17])
 @pytest.mark.parametrize("kv_len", [54, 97])
 @pytest.mark.parametrize("qo_len", [37, 17])
diff --git a/tests/test_block_sparse.py b/tests/test_block_sparse.py
index 01950e090..8672dbb0f 100644
--- a/tests/test_block_sparse.py
+++ b/tests/test_block_sparse.py
@@ -18,10 +18,42 @@
 import pytest
 import scipy as sp
 import torch
+from jit_utils import jit_decode_attention_func_args, jit_prefill_attention_func_args
 
 import flashinfer
 
 
+@pytest.fixture(autouse=True, scope="module")
+def warmup_jit():
+    if flashinfer.jit.has_prebuilt_ops:
+        return
+    try:
+        flashinfer.jit.parallel_load_modules(
+            jit_decode_attention_func_args(
+                [torch.float16],  # q_dtypes
+                [torch.float16],  # kv_dtypes
+                [128, 256],  # head_dims
+                [0],  # pos_encoding_modes
+                [False],  # use_sliding_windows
+                [False],  # use_logits_soft_caps
+            )
+            + jit_prefill_attention_func_args(
+                [torch.float16],  # q_dtypes
+                [torch.float16],  # kv_dtypes
+                [128, 256],  # head_dims
+                [0],  # pos_encoding_modes
+                [False],  # use_sliding_windows
+                [False],  # use_logits_soft_caps
+                [False],  # allow_fp16_qk_reductions
+            )
+        )
+    except Exception as e:
+        # abort the test session if warmup fails
+        pytest.exit(str(e))
+    finally:
+        yield
+
+
 def bsr_attention_ref(
     q,
     k,
diff --git a/tests/test_jit_warmup.py b/tests/test_jit_warmup.py
index a664c0262..dc0d5b892 100644
--- a/tests/test_jit_warmup.py
+++ b/tests/test_jit_warmup.py
@@ -24,31 +24,37 @@ def test_warmpup_llama():
     parallel_load_modules(
         [
-            lambda: flashinfer.activation.get_act_and_mul_module("silu"),
-            flashinfer.norm.get_norm_module,
-            flashinfer.sampling.get_sampling_module,
-            flashinfer.quantization.get_quantization_module,
-            flashinfer.page.get_page_module,
-            lambda: flashinfer.decode.get_batch_decode_module(
-                torch.float16,
-                torch.float16,
-                torch.float16,
-                torch.int32,
-                128,
-                PosEncodingMode.NONE.value,
-                False,  # use_sliding_window
-                False,  # use_logits_soft_cap
+            (flashinfer.activation.get_act_and_mul_module, ["silu"]),
+            (flashinfer.norm.get_norm_module, []),
+            (flashinfer.sampling.get_sampling_module, []),
+            (flashinfer.quantization.get_quantization_module, []),
+            (flashinfer.page.get_page_module, []),
+            (
+                flashinfer.decode.get_batch_decode_module,
+                [
+                    torch.float16,
+                    torch.float16,
+                    torch.float16,
+                    torch.int32,
+                    128,
+                    PosEncodingMode.NONE.value,
+                    False,  # use_sliding_window
+                    False,  # use_logits_soft_cap
+                ],
             ),
-            lambda: flashinfer.prefill.gen_batch_prefill_module(
-                torch.float16,
-                torch.float16,
-                torch.float16,
-                torch.int32,
-                128,
-                PosEncodingMode.NONE.value,
-                False,  # use_sliding_window
-                False,  # use_logits_soft_cap
-                False,  # allow_fp16_qk_reduction
+            (
+                flashinfer.prefill.gen_batch_prefill_module,
+                [
+                    torch.float16,
+                    torch.float16,
+                    torch.float16,
+                    torch.int32,
+                    128,
+                    PosEncodingMode.NONE.value,
+                    False,  # use_sliding_window
+                    False,  # use_logits_soft_cap
+                    False,  # allow_fp16_qk_reduction
+                ],
             ),
         ]
     )
diff --git a/tests/test_logits_cap.py b/tests/test_logits_cap.py
index 2cd38b18a..c42278aa1 100644
--- a/tests/test_logits_cap.py
+++ b/tests/test_logits_cap.py
@@ -18,10 +18,42 @@
 
 import pytest
 import torch
+from jit_utils import jit_decode_attention_func_args, jit_prefill_attention_func_args
 
 import flashinfer
 
 
+@pytest.fixture(autouse=True, scope="module")
+def warmup_jit():
+    if flashinfer.jit.has_prebuilt_ops:
+        return
+    try:
+        flashinfer.jit.parallel_load_modules(
+            jit_decode_attention_func_args(
+                [torch.float16],  # q_dtypes
+                [torch.float16],  # kv_dtypes
+                [128, 256],  # head_dims
+                [0],  # pos_encoding_modes
+                [False],  # use_sliding_windows
+                [False, True],  # use_logits_soft_caps
+            )
+            + jit_prefill_attention_func_args(
+                [torch.float16],  # q_dtypes
+                [torch.float16],  # kv_dtypes
+                [128, 256],  # head_dims
+                [0],  # pos_encoding_modes
+                [False],  # use_sliding_windows
+                [False, True],  # use_logits_soft_caps
+                [False],  # allow_fp16_qk_reductions
+            )
+        )
+    except Exception as e:
+        # abort the test session if warmup fails
+        pytest.exit(str(e))
+    finally:
+        yield
+
+
 def attention_logits_soft_cap_torch(q, k, v, soft_cap):
     q_len, num_heads, head_dim = q.shape
     kv_len = k.shape[0]
diff --git a/tests/test_non_contiguous_decode.py b/tests/test_non_contiguous_decode.py
index 8fdc0ac63..22db5f87a 100644
--- a/tests/test_non_contiguous_decode.py
+++ b/tests/test_non_contiguous_decode.py
@@ -1,9 +1,41 @@
 import pytest
 import torch
+from jit_utils import jit_decode_attention_func_args, jit_prefill_attention_func_args
 
 import flashinfer
 
 
+@pytest.fixture(autouse=True, scope="module")
+def warmup_jit():
+    if flashinfer.jit.has_prebuilt_ops:
+        return
+    try:
+        flashinfer.jit.parallel_load_modules(
+            jit_decode_attention_func_args(
+                [torch.float16],  # q_dtypes
+                [torch.float16],  # kv_dtypes
+                [64, 128, 256],  # head_dims
+                [0],  # pos_encoding_modes
+                [False],  # use_sliding_windows
+                [False],  # use_logits_soft_caps
+            )
+            + jit_prefill_attention_func_args(
+                [torch.float16],  # q_dtypes
+                [torch.float16],  # kv_dtypes
+                [64, 128, 256],  # head_dims
+                [0],  # pos_encoding_modes
+                [False],  # use_sliding_windows
+                [False],  # use_logits_soft_caps
+                [False],  # allow_fp16_qk_reductions
+            )
+        )
+    except Exception as e:
+        # abort the test session if warmup fails
+        pytest.exit(str(e))
+    finally:
+        yield
+
+
 @pytest.mark.parametrize("batch_size", [1, 19, 99])
 @pytest.mark.parametrize("page_size", [1, 5])
 @pytest.mark.parametrize("seq_len", [1])
diff --git a/tests/test_non_contiguous_prefill.py b/tests/test_non_contiguous_prefill.py
index ccf0a8308..a45c09adc 100644
--- a/tests/test_non_contiguous_prefill.py
+++ b/tests/test_non_contiguous_prefill.py
@@ -16,10 +16,34 @@
 
 import pytest
 import torch
+from jit_utils import jit_prefill_attention_func_args
 
 import flashinfer
 
 
+@pytest.fixture(autouse=True, scope="module")
+def warmup_jit():
+    if flashinfer.jit.has_prebuilt_ops:
+        return
+    try:
+        flashinfer.jit.parallel_load_modules(
+            jit_prefill_attention_func_args(
+                [torch.float16],  # q_dtypes
+                [torch.float16],  # kv_dtypes
+                [64, 128, 256],  # head_dims
+                [0],  # pos_encoding_modes
+                [False],  # use_sliding_windows
+                [False],  # use_logits_soft_caps
+                [False],  # allow_fp16_qk_reductions
+            )
+        )
+    except Exception as e:
+        # abort the test session if warmup fails
+        pytest.exit(str(e))
+    finally:
+        yield
+
+
 @pytest.mark.parametrize("seq_len", [1, 7, 127, 999, 3579])
 @pytest.mark.parametrize("num_kv_heads", [1, 4, 8])
 @pytest.mark.parametrize("num_qo_heads", [4, 8, 32])
diff --git a/tests/test_shared_prefix_kernels.py b/tests/test_shared_prefix_kernels.py
index 8ec9deb32..9ef2103c9 100644
--- a/tests/test_shared_prefix_kernels.py
+++ b/tests/test_shared_prefix_kernels.py
@@ -16,10 +16,42 @@
 
 import pytest
 import torch
+from jit_utils import jit_decode_attention_func_args, jit_prefill_attention_func_args
 
 import flashinfer
 
 
+@pytest.fixture(autouse=True, scope="module")
+def warmup_jit():
+    if flashinfer.jit.has_prebuilt_ops:
+        return
+    try:
+        flashinfer.jit.parallel_load_modules(
+            jit_decode_attention_func_args(
+                [torch.float16],  # q_dtypes
+                [torch.float16],  # kv_dtypes
+                [128, 256],  # head_dims
+                [0],  # pos_encoding_modes
+                [False],  # use_sliding_windows
+                [False],  # use_logits_soft_caps
+            )
+            + jit_prefill_attention_func_args(
+                [torch.float16],  # q_dtypes
+                [torch.float16],  # kv_dtypes
+                [128, 256],  # head_dims
+                [0],  # pos_encoding_modes
+                [False],  # use_sliding_windows
+                [False],  # use_logits_soft_caps
+                [False],  # allow_fp16_qk_reductions
+            )
+        )
+    except Exception as e:
+        # abort the test session if warmup fails
+        pytest.exit(str(e))
+    finally:
+        yield
+
+
 def ceil_div(a, b):
     return (a + b - 1) // b
diff --git a/tests/test_sliding_window.py b/tests/test_sliding_window.py
index cf389e2ee..e9a700790 100644
--- a/tests/test_sliding_window.py
+++ b/tests/test_sliding_window.py
@@ -16,10 +16,42 @@
 
 import pytest
 import torch
+from jit_utils import jit_decode_attention_func_args, jit_prefill_attention_func_args
 
 import flashinfer
 
 
+@pytest.fixture(autouse=True, scope="module")
+def warmup_jit():
+    if flashinfer.jit.has_prebuilt_ops:
+        return
+    try:
+        flashinfer.jit.parallel_load_modules(
+            jit_decode_attention_func_args(
+                [torch.float16],  # q_dtypes
+                [torch.float16],  # kv_dtypes
+                [64, 128, 256],  # head_dims
+                [0],  # pos_encoding_modes
+                [False, True],  # use_sliding_windows
+                [False],  # use_logits_soft_caps
+            )
+            + jit_prefill_attention_func_args(
+                [torch.float16],  # q_dtypes
+                [torch.float16],  # kv_dtypes
+                [64, 128, 256],  # head_dims
+                [0],  # pos_encoding_modes
+                [False, True],  # use_sliding_windows
+                [False],  # use_logits_soft_caps
+                [False],  # allow_fp16_qk_reductions
+            )
+        )
+    except Exception as e:
+        # abort the test session if warmup fails
+        pytest.exit(str(e))
+    finally:
+        yield
+
+
 @pytest.mark.parametrize("seq_len", [1, 3, 19, 99, 199, 1999])
 @pytest.mark.parametrize("window_left", [3, 13, 23, 43])
 @pytest.mark.parametrize("num_kv_heads", [1, 4])
diff --git a/tests/test_tensor_cores_decode.py b/tests/test_tensor_cores_decode.py
index fc2c45f5c..bf312fb85 100644
--- a/tests/test_tensor_cores_decode.py
+++ b/tests/test_tensor_cores_decode.py
@@ -16,10 +16,42 @@
 
 import pytest
 import torch
+from jit_utils import jit_decode_attention_func_args, jit_prefill_attention_func_args
 
 import flashinfer
 
 
+@pytest.fixture(autouse=True, scope="module")
+def warmup_jit():
+    if flashinfer.jit.has_prebuilt_ops:
+        return
+    try:
+        flashinfer.jit.parallel_load_modules(
+            jit_decode_attention_func_args(
+                [torch.float16],  # q_dtypes
+                [torch.float16],  # kv_dtypes
+                [64, 128, 256],  # head_dims
+                [0, 1, 2],  # pos_encoding_modes
+                [False],  # use_sliding_windows
+                [False],  # use_logits_soft_caps
+            )
+            + jit_prefill_attention_func_args(
+                [torch.float16],  # q_dtypes
+                [torch.float16],  # kv_dtypes
+                [64, 128, 256],  # head_dims
+                [0, 1, 2],  # pos_encoding_modes
+                [False],  # use_sliding_windows
+                [False],  # use_logits_soft_caps
+                [False],  # allow_fp16_qk_reductions
+            )
+        )
+    except Exception as e:
+        # abort the test session if warmup fails
+        pytest.exit(str(e))
+    finally:
+        yield
+
+
 @pytest.mark.parametrize("kv_len", [54, 128, 999, 32789])
 @pytest.mark.parametrize("num_kv_heads", [4, 8])
 @pytest.mark.parametrize("group_size", [1, 4, 8])