From 1ebbde36b1c10f2eba84cd84fb923ef2b906065d Mon Sep 17 00:00:00 2001 From: Lequn Chen Date: Mon, 3 Feb 2025 22:56:25 -0800 Subject: [PATCH] feat: Separate QK/VO head dim dispatch for sm90 AOT (#778) --- aot_build_utils/generate.py | 14 ------- aot_build_utils/generate_dispatch_inc.py | 8 ++-- aot_build_utils/generate_sm90.py | 53 +++++++++++++----------- csrc/aot_extension_utils.h | 5 ++- csrc/batch_prefill_sm90_config.inc | 3 +- csrc/pytorch_extension_utils.h | 20 +++++++++ csrc/single_prefill_sm90_config.inc | 3 +- setup.py | 4 +- 8 files changed, 61 insertions(+), 49 deletions(-) diff --git a/aot_build_utils/generate.py b/aot_build_utils/generate.py index c1c3eef6a..fe9b748a1 100644 --- a/aot_build_utils/generate.py +++ b/aot_build_utils/generate.py @@ -24,7 +24,6 @@ generate_batch_paged_decode_inst, generate_batch_paged_prefill_inst, generate_batch_ragged_prefill_inst, - generate_dispatch_inc, generate_single_decode_inst, generate_single_prefill_inst, ) @@ -48,19 +47,6 @@ def write_if_different(path: Path, content: str) -> None: path.mkdir(parents=True, exist_ok=True) - write_if_different( - path / "dispatch.inc", - generate_dispatch_inc.get_dispatch_inc_str( - argparse.Namespace( - head_dims=head_dims, - head_dims_sm90=head_dims, - pos_encoding_modes=[0], - use_fp16_qk_reductions=[0], - mask_modes=mask_modes, - ) - ), - ) - write_if_different( path / "aot_default_additional_params.h", generate_aot_default_additional_params_header.get_aot_default_additional_params_header_str(), diff --git a/aot_build_utils/generate_dispatch_inc.py b/aot_build_utils/generate_dispatch_inc.py index fe4cf5f22..3f7ad94dc 100644 --- a/aot_build_utils/generate_dispatch_inc.py +++ b/aot_build_utils/generate_dispatch_inc.py @@ -35,11 +35,13 @@ def get_dispatch_inc_str(args: argparse.Namespace) -> str: # head dims for sm90 dispatch_head_dims_sm90_entries = "\n".join( [ - " _DISPATCH_CASE({}, case_var, __VA_ARGS__) \\".format(_) - for _ in args.head_dims_sm90 + " _DISPATCH_CASE_U16x2({}, {}, case_var1, case_var2, __VA_ARGS__) \\".format( + qk, vo + ) + for qk, vo in args.head_dims_sm90 ] ) - dispatch_head_dims_sm90_str = f"""#define _DISPATCH_CASES_head_dim_sm90(case_var, ...) \\ + dispatch_head_dims_sm90_str = f"""#define _DISPATCH_CASES_head_dim_sm90(case_var1, case_var2, ...) \\ {dispatch_head_dims_sm90_entries} // EOL """ diff --git a/aot_build_utils/generate_sm90.py b/aot_build_utils/generate_sm90.py index 2466dd219..970f35918 100644 --- a/aot_build_utils/generate_sm90.py +++ b/aot_build_utils/generate_sm90.py @@ -17,7 +17,7 @@ import argparse from itertools import product from pathlib import Path -from typing import List +from typing import List, Tuple from . import ( generate_batch_paged_prefill_sm90_inst, @@ -33,7 +33,7 @@ def write_if_different(path: Path, content: str) -> None: path.write_text(content) path: Path = args.path - head_dims: List[int] = args.head_dims + head_dims: List[Tuple[int, int]] = args.head_dims pos_encoding_modes: List[int] = args.pos_encoding_modes use_fp16_qk_reductions: List[int] = args.use_fp16_qk_reductions mask_modes: List[int] = args.mask_modes @@ -58,7 +58,7 @@ def write_if_different(path: Path, content: str) -> None: # single prefill files single_prefill_sm90_uris = [] for ( - head_dim, + (head_dim_qk, head_dim_vo), pos_encoding_mode, use_fp16_qk_reduction, mask_mode, @@ -69,15 +69,15 @@ def write_if_different(path: Path, content: str) -> None: mask_modes, ): for dtype_q, dtype_kv in list(zip(prefill_dtypes, prefill_dtypes)): - fname = f"single_prefill_head_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_sm90.cu" + fname = f"single_prefill_head_qk_{head_dim_qk}_head_vo_{head_dim_vo}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_sm90.cu" content = generate_single_prefill_sm90_inst.get_cu_file_str( - head_dim, # head_dim_qk - head_dim, # head_dim_vo + head_dim_qk, + head_dim_vo, pos_encoding_mode, use_fp16_qk_reduction, mask_mode, - dtype_q, # dtype_q - dtype_kv, # dtype_kv + dtype_q, + dtype_kv, dtype_q, # dtype_out ) for use_sliding_window in [True, False]: @@ -89,8 +89,8 @@ def write_if_different(path: Path, content: str) -> None: f"single_prefill_with_kv_cache_dtype_q_{dtype_q}_" f"dtype_kv_{dtype_kv}_" f"dtype_o_{dtype_q}_" - f"head_dim_qk_{head_dim}_" - f"head_dim_vo_{head_dim}_" + f"head_dim_qk_{head_dim_qk}_" + f"head_dim_vo_{head_dim_vo}_" f"posenc_{pos_encoding_mode}_" f"use_swa_{use_sliding_window}_" f"use_logits_cap_{use_logits_soft_cap}_" @@ -101,7 +101,7 @@ def write_if_different(path: Path, content: str) -> None: # batch prefill files batch_prefill_sm90_uris = [] for ( - head_dim, + (head_dim_qk, head_dim_vo), pos_encoding_mode, use_fp16_qk_reduction, mask_mode, @@ -114,29 +114,29 @@ def write_if_different(path: Path, content: str) -> None: idtypes, ): for dtype_q, dtype_kv in list(zip(prefill_dtypes, prefill_dtypes)): - fname = f"batch_paged_prefill_head_qk_{head_dim}_head_vo_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}_sm90.cu" + fname = f"batch_paged_prefill_head_qk_{head_dim_qk}_head_vo_{head_dim_vo}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}_sm90.cu" content = generate_batch_paged_prefill_sm90_inst.get_cu_file_str( - head_dim, # head_dim_qk - head_dim, # head_dim_vo + head_dim_qk, + head_dim_vo, pos_encoding_mode, use_fp16_qk_reduction, mask_mode, - dtype_q, # dtype_q - dtype_kv, # dtype_kv + dtype_q, + dtype_kv, dtype_q, # dtype_out idtype, ) write_if_different(path / fname, content) - fname = f"batch_ragged_prefill_head_qk_{head_dim}_head_vo_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}_sm90.cu" + fname = f"batch_ragged_prefill_head_qk_{head_dim_qk}_head_vo_{head_dim_vo}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}_sm90.cu" content = generate_batch_ragged_prefill_sm90_inst.get_cu_file_str( - head_dim, # head_dim_qk - head_dim, # head_dim_vo + head_dim_qk, + head_dim_vo, pos_encoding_mode, use_fp16_qk_reduction, mask_mode, - dtype_q, # dtype_q - dtype_kv, # dtype_kv + dtype_q, + dtype_kv, dtype_q, # dtype_out idtype, ) @@ -152,8 +152,8 @@ def write_if_different(path: Path, content: str) -> None: f"dtype_kv_{dtype_kv}_" f"dtype_o_{dtype_q}_" f"dtype_idx_{idtype}_" - f"head_dim_qk_{head_dim}_" - f"head_dim_vo_{head_dim}_" + f"head_dim_qk_{head_dim_qk}_" + f"head_dim_vo_{head_dim_vo}_" f"posenc_{pos_encoding_mode}_" f"use_swa_{sliding_window}_" f"use_logits_cap_{logits_soft_cap}_" @@ -169,7 +169,11 @@ def write_if_different(path: Path, content: str) -> None: "--path", type=Path, required=True, help="Path to the dispatch inc file" ) parser.add_argument( - "--head_dims", type=int, required=True, nargs="+", help="Head dimensions" + "--head_dims", + type=str, + required=True, + nargs="+", + help="Head dimensions in format of 'head_dim_qk,head_dim_vo'", ) parser.add_argument( "--pos_encoding_modes", @@ -207,4 +211,5 @@ def write_if_different(path: Path, content: str) -> None: help="Enable bf16", ) args = parser.parse_args() + args.head_dims = [tuple(map(int, x.split(","))) for x in args.head_dims] get_sm90_instantiation_cu(args) diff --git a/csrc/aot_extension_utils.h b/csrc/aot_extension_utils.h index 6d398e3e9..acc9ddc5f 100644 --- a/csrc/aot_extension_utils.h +++ b/csrc/aot_extension_utils.h @@ -19,8 +19,9 @@ #define DISPATCH_head_dim(expr, const_expr, ...) \ _DISPATCH_SWITCH("head_dim", expr, _DISPATCH_CASES_head_dim(const_expr, __VA_ARGS__)) -#define DISPATCH_head_dim_sm90(expr, const_expr, ...) \ - _DISPATCH_SWITCH("head_dim", expr, _DISPATCH_CASES_head_dim_sm90(const_expr, __VA_ARGS__)) +#define DISPATCH_head_dim_sm90(expr1, expr2, const_expr1, const_expr2, ...) \ + _DISPATCH_SWITCH_U16x2("head_dim_qk", "head_dim_vo", expr1, expr2, \ + _DISPATCH_CASES_head_dim_sm90(const_expr1, const_expr2, __VA_ARGS__)) #define DISPATCH_pos_encoding_mode(expr, const_expr, ...) \ _DISPATCH_SWITCH("positional encoding mode", expr, \ diff --git a/csrc/batch_prefill_sm90_config.inc b/csrc/batch_prefill_sm90_config.inc index d344915f9..e74f1a40b 100644 --- a/csrc/batch_prefill_sm90_config.inc +++ b/csrc/batch_prefill_sm90_config.inc @@ -41,8 +41,7 @@ using IdType = int32_t; using DTypeO = DTypeQ; \ using RaggedParams = BatchPrefillRaggedParams; \ using PagedParams = BatchPrefillPagedParams; \ - return DISPATCH_head_dim_sm90(head_dim_qk, HEAD_DIM_QK, [&] { \ - [[maybe_unused]] constexpr int HEAD_DIM_VO = HEAD_DIM_QK; \ + return DISPATCH_head_dim_sm90(head_dim_qk, head_dim_vo, HEAD_DIM_QK, HEAD_DIM_VO, [&] { \ return DISPATCH_BOOL(window_left > -1, USE_SLIDING_WINDOW, [&] { \ return DISPATCH_BOOL(logits_soft_cap > 0.f, USE_LOGITS_SOFT_CAP, [&] { \ using AttentionVariant = DefaultAttention; \ diff --git a/csrc/pytorch_extension_utils.h b/csrc/pytorch_extension_utils.h index 2544084b4..ba965a95d 100644 --- a/csrc/pytorch_extension_utils.h +++ b/csrc/pytorch_extension_utils.h @@ -122,12 +122,32 @@ } \ }() +#define _DISPATCH_SWITCH_U16x2(var1_name, var2_name, cond1, cond2, ...) \ + [&]() -> bool { \ + switch (pack_u16(cond1, cond2)) { \ + __VA_ARGS__ \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch (" var1_name ", " var2_name "): (" \ + << int(cond1) << ", " << int(cond2) << ")"; \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + } \ + }() + #define _DISPATCH_CASE(case_expr, case_var, ...) \ case case_expr: { \ constexpr auto case_var = case_expr; \ return __VA_ARGS__(); \ } +#define _DISPATCH_CASE_U16x2(case_expr1, case_expr2, case_var1, case_var2, ...) \ + case pack_u16(case_expr1, case_expr2): { \ + constexpr auto case_var1 = case_expr1; \ + constexpr auto case_var2 = case_expr2; \ + return __VA_ARGS__(); \ + } + #define DISPATCH_BOOL(expr, const_expr, ...) \ [&]() -> bool { \ if (expr) { \ diff --git a/csrc/single_prefill_sm90_config.inc b/csrc/single_prefill_sm90_config.inc index 2ec696b7f..69bc3c63f 100644 --- a/csrc/single_prefill_sm90_config.inc +++ b/csrc/single_prefill_sm90_config.inc @@ -39,8 +39,7 @@ using IdType = int32_t; using DTypeKV = DTypeQ; \ using DTypeO = DTypeQ; \ using Params = SinglePrefillParams; \ - return DISPATCH_head_dim_sm90(head_dim_qk, HEAD_DIM_QK, [&] { \ - [[maybe_unused]] constexpr int HEAD_DIM_VO = HEAD_DIM_QK; \ + return DISPATCH_head_dim_sm90(head_dim_qk, head_dim_vo, HEAD_DIM_QK, HEAD_DIM_VO, [&] { \ return DISPATCH_BOOL(window_left > -1, USE_SLIDING_WINDOW, [&] { \ return DISPATCH_BOOL(logits_soft_cap > 0.f, USE_LOGITS_SOFT_CAP, [&] { \ using AttentionVariant = DefaultAttention; \ diff --git a/setup.py b/setup.py index 827c1c2a3..4c238af57 100644 --- a/setup.py +++ b/setup.py @@ -29,8 +29,8 @@ head_dims = os.environ.get("FLASHINFER_HEAD_DIMS", "64,128,256").split(",") head_dims = list(map(int, head_dims)) -SM90_ALLOWED_HEAD_DIMS = {64, 128, 256} -head_dims_sm90 = [d for d in head_dims if d in SM90_ALLOWED_HEAD_DIMS] +SM90_ALLOWED_HEAD_DIMS = {(64, 64), (128, 128), (256, 256), (192, 128)} +head_dims_sm90 = list(SM90_ALLOWED_HEAD_DIMS) # No support for custom head dims for SM90 mask_modes = [0, 1, 2]