feat: Separate QK/VO head dim dispatch for sm90 AOT (#778)

abcdabcd987 authored Feb 4, 2025
1 parent fc03772 commit 1ebbde3
Showing 8 changed files with 61 additions and 49 deletions.
14 changes: 0 additions & 14 deletions aot_build_utils/generate.py
@@ -24,7 +24,6 @@
generate_batch_paged_decode_inst,
generate_batch_paged_prefill_inst,
generate_batch_ragged_prefill_inst,
- generate_dispatch_inc,
generate_single_decode_inst,
generate_single_prefill_inst,
)
@@ -48,19 +47,6 @@ def write_if_different(path: Path, content: str) -> None:

path.mkdir(parents=True, exist_ok=True)

- write_if_different(
- path / "dispatch.inc",
- generate_dispatch_inc.get_dispatch_inc_str(
- argparse.Namespace(
- head_dims=head_dims,
- head_dims_sm90=head_dims,
- pos_encoding_modes=[0],
- use_fp16_qk_reductions=[0],
- mask_modes=mask_modes,
- )
- ),
- )
-
write_if_different(
path / "aot_default_additional_params.h",
generate_aot_default_additional_params_header.get_aot_default_additional_params_header_str(),
8 changes: 5 additions & 3 deletions aot_build_utils/generate_dispatch_inc.py
@@ -35,11 +35,13 @@ def get_dispatch_inc_str(args: argparse.Namespace) -> str:
# head dims for sm90
dispatch_head_dims_sm90_entries = "\n".join(
[
" _DISPATCH_CASE({}, case_var, __VA_ARGS__) \\".format(_)
for _ in args.head_dims_sm90
" _DISPATCH_CASE_U16x2({}, {}, case_var1, case_var2, __VA_ARGS__) \\".format(
qk, vo
)
for qk, vo in args.head_dims_sm90
]
)
- dispatch_head_dims_sm90_str = f"""#define _DISPATCH_CASES_head_dim_sm90(case_var, ...) \\
+ dispatch_head_dims_sm90_str = f"""#define _DISPATCH_CASES_head_dim_sm90(case_var1, case_var2, ...) \\
{dispatch_head_dims_sm90_entries}
// EOL
"""
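To make the generated format concrete: a minimal sketch of what the emitted dispatch.inc block could look like, assuming head_dims_sm90 = [(64, 64), (192, 128)] (pairs chosen for illustration from the allowed set in setup.py further down):

#define _DISPATCH_CASES_head_dim_sm90(case_var1, case_var2, ...) \
  _DISPATCH_CASE_U16x2(64, 64, case_var1, case_var2, __VA_ARGS__) \
  _DISPATCH_CASE_U16x2(192, 128, case_var1, case_var2, __VA_ARGS__) \
// EOL

Each (QK, VO) pair becomes exactly one packed case, so the QK and VO head dimensions no longer need to match.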
53 changes: 29 additions & 24 deletions aot_build_utils/generate_sm90.py
@@ -17,7 +17,7 @@
import argparse
from itertools import product
from pathlib import Path
- from typing import List
+ from typing import List, Tuple

from . import (
generate_batch_paged_prefill_sm90_inst,
@@ -33,7 +33,7 @@ def write_if_different(path: Path, content: str) -> None:
path.write_text(content)

path: Path = args.path
- head_dims: List[int] = args.head_dims
+ head_dims: List[Tuple[int, int]] = args.head_dims
pos_encoding_modes: List[int] = args.pos_encoding_modes
use_fp16_qk_reductions: List[int] = args.use_fp16_qk_reductions
mask_modes: List[int] = args.mask_modes
@@ -58,7 +58,7 @@ def write_if_different(path: Path, content: str) -> None:
# single prefill files
single_prefill_sm90_uris = []
for (
- head_dim,
+ (head_dim_qk, head_dim_vo),
pos_encoding_mode,
use_fp16_qk_reduction,
mask_mode,
@@ -69,15 +69,15 @@ def write_if_different(path: Path, content: str) -> None:
mask_modes,
):
for dtype_q, dtype_kv in list(zip(prefill_dtypes, prefill_dtypes)):
fname = f"single_prefill_head_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_sm90.cu"
fname = f"single_prefill_head_qk_{head_dim_qk}_head_vo_{head_dim_vo}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_sm90.cu"
content = generate_single_prefill_sm90_inst.get_cu_file_str(
- head_dim, # head_dim_qk
- head_dim, # head_dim_vo
+ head_dim_qk,
+ head_dim_vo,
pos_encoding_mode,
use_fp16_qk_reduction,
mask_mode,
- dtype_q, # dtype_q
- dtype_kv, # dtype_kv
+ dtype_q,
+ dtype_kv,
dtype_q, # dtype_out
)
for use_sliding_window in [True, False]:
@@ -89,8 +89,8 @@ def write_if_different(path: Path, content: str) -> None:
f"single_prefill_with_kv_cache_dtype_q_{dtype_q}_"
f"dtype_kv_{dtype_kv}_"
f"dtype_o_{dtype_q}_"
f"head_dim_qk_{head_dim}_"
f"head_dim_vo_{head_dim}_"
f"head_dim_qk_{head_dim_qk}_"
f"head_dim_vo_{head_dim_vo}_"
f"posenc_{pos_encoding_mode}_"
f"use_swa_{use_sliding_window}_"
f"use_logits_cap_{use_logits_soft_cap}_"
@@ -101,7 +101,7 @@ def write_if_different(path: Path, content: str) -> None:
# batch prefill files
batch_prefill_sm90_uris = []
for (
- head_dim,
+ (head_dim_qk, head_dim_vo),
pos_encoding_mode,
use_fp16_qk_reduction,
mask_mode,
@@ -114,29 +114,29 @@ def write_if_different(path: Path, content: str) -> None:
idtypes,
):
for dtype_q, dtype_kv in list(zip(prefill_dtypes, prefill_dtypes)):
fname = f"batch_paged_prefill_head_qk_{head_dim}_head_vo_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}_sm90.cu"
fname = f"batch_paged_prefill_head_qk_{head_dim_qk}_head_vo_{head_dim_vo}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}_sm90.cu"
content = generate_batch_paged_prefill_sm90_inst.get_cu_file_str(
- head_dim, # head_dim_qk
- head_dim, # head_dim_vo
+ head_dim_qk,
+ head_dim_vo,
pos_encoding_mode,
use_fp16_qk_reduction,
mask_mode,
- dtype_q, # dtype_q
- dtype_kv, # dtype_kv
+ dtype_q,
+ dtype_kv,
dtype_q, # dtype_out
idtype,
)
write_if_different(path / fname, content)

fname = f"batch_ragged_prefill_head_qk_{head_dim}_head_vo_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}_sm90.cu"
fname = f"batch_ragged_prefill_head_qk_{head_dim_qk}_head_vo_{head_dim_vo}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}_sm90.cu"
content = generate_batch_ragged_prefill_sm90_inst.get_cu_file_str(
- head_dim, # head_dim_qk
- head_dim, # head_dim_vo
+ head_dim_qk,
+ head_dim_vo,
pos_encoding_mode,
use_fp16_qk_reduction,
mask_mode,
- dtype_q, # dtype_q
- dtype_kv, # dtype_kv
+ dtype_q,
+ dtype_kv,
dtype_q, # dtype_out
idtype,
)
@@ -152,8 +152,8 @@ def write_if_different(path: Path, content: str) -> None:
f"dtype_kv_{dtype_kv}_"
f"dtype_o_{dtype_q}_"
f"dtype_idx_{idtype}_"
f"head_dim_qk_{head_dim}_"
f"head_dim_vo_{head_dim}_"
f"head_dim_qk_{head_dim_qk}_"
f"head_dim_vo_{head_dim_vo}_"
f"posenc_{pos_encoding_mode}_"
f"use_swa_{sliding_window}_"
f"use_logits_cap_{logits_soft_cap}_"
@@ -169,7 +169,11 @@ def write_if_different(path: Path, content: str) -> None:
"--path", type=Path, required=True, help="Path to the dispatch inc file"
)
parser.add_argument(
"--head_dims", type=int, required=True, nargs="+", help="Head dimensions"
"--head_dims",
type=str,
required=True,
nargs="+",
help="Head dimensions in format of 'head_dim_qk,head_dim_vo'",
)
parser.add_argument(
"--pos_encoding_modes",
@@ -207,4 +211,5 @@ def write_if_different(path: Path, content: str) -> None:
help="Enable bf16",
)
args = parser.parse_args()
+ args.head_dims = [tuple(map(int, x.split(","))) for x in args.head_dims]
get_sm90_instantiation_cu(args)
5 changes: 3 additions & 2 deletions csrc/aot_extension_utils.h
@@ -19,8 +19,9 @@
#define DISPATCH_head_dim(expr, const_expr, ...) \
_DISPATCH_SWITCH("head_dim", expr, _DISPATCH_CASES_head_dim(const_expr, __VA_ARGS__))

- #define DISPATCH_head_dim_sm90(expr, const_expr, ...) \
- _DISPATCH_SWITCH("head_dim", expr, _DISPATCH_CASES_head_dim_sm90(const_expr, __VA_ARGS__))
+ #define DISPATCH_head_dim_sm90(expr1, expr2, const_expr1, const_expr2, ...) \
+ _DISPATCH_SWITCH_U16x2("head_dim_qk", "head_dim_vo", expr1, expr2, \
+ _DISPATCH_CASES_head_dim_sm90(const_expr1, const_expr2, __VA_ARGS__))

#define DISPATCH_pos_encoding_mode(expr, const_expr, ...) \
_DISPATCH_SWITCH("positional encoding mode", expr, \
3 changes: 1 addition & 2 deletions csrc/batch_prefill_sm90_config.inc
@@ -41,8 +41,7 @@ using IdType = int32_t;
using DTypeO = DTypeQ; \
using RaggedParams = BatchPrefillRaggedParams<DTypeQ, DTypeKV, DTypeO, IdType>; \
using PagedParams = BatchPrefillPagedParams<DTypeQ, DTypeKV, DTypeO, IdType>; \
- return DISPATCH_head_dim_sm90(head_dim_qk, HEAD_DIM_QK, [&] { \
- [[maybe_unused]] constexpr int HEAD_DIM_VO = HEAD_DIM_QK; \
+ return DISPATCH_head_dim_sm90(head_dim_qk, head_dim_vo, HEAD_DIM_QK, HEAD_DIM_VO, [&] { \
return DISPATCH_BOOL(window_left > -1, USE_SLIDING_WINDOW, [&] { \
return DISPATCH_BOOL(logits_soft_cap > 0.f, USE_LOGITS_SOFT_CAP, [&] { \
using AttentionVariant = DefaultAttention<USE_LOGITS_SOFT_CAP>; \
20 changes: 20 additions & 0 deletions csrc/pytorch_extension_utils.h
@@ -122,12 +122,32 @@
} \
}()

+ #define _DISPATCH_SWITCH_U16x2(var1_name, var2_name, cond1, cond2, ...) \
+ [&]() -> bool { \
+ switch (pack_u16(cond1, cond2)) { \
+ __VA_ARGS__ \
+ default: \
+ std::ostringstream oss; \
+ oss << __PRETTY_FUNCTION__ << " failed to dispatch (" var1_name ", " var2_name "): (" \
+ << int(cond1) << ", " << int(cond2) << ")"; \
+ TORCH_CHECK(false, oss.str()); \
+ return false; \
+ } \
+ }()
+
#define _DISPATCH_CASE(case_expr, case_var, ...) \
case case_expr: { \
constexpr auto case_var = case_expr; \
return __VA_ARGS__(); \
}

+ #define _DISPATCH_CASE_U16x2(case_expr1, case_expr2, case_var1, case_var2, ...) \
+ case pack_u16(case_expr1, case_expr2): { \
+ constexpr auto case_var1 = case_expr1; \
+ constexpr auto case_var2 = case_expr2; \
+ return __VA_ARGS__(); \
+ }
+
#define DISPATCH_BOOL(expr, const_expr, ...) \
[&]() -> bool { \
if (expr) { \
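The new switch keys on pack_u16(cond1, cond2). pack_u16 itself is defined elsewhere in the repository and is not part of this diff; a minimal sketch of the idea, assuming it packs two 16-bit values into a single 32-bit key:

#include <cstdint>

// Sketch only, not the repository's definition: fuse two 16-bit head dims
// into one 32-bit value so a single switch can dispatch on the (QK, VO) pair.
constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
  return (uint32_t(a) << 16) | uint32_t(b);
}

Because such a helper is constexpr, _DISPATCH_CASE_U16x2(192, 128, ...) expands to `case pack_u16(192, 128):`, a valid compile-time case label, and (192, 128) stays distinct from (128, 192).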
3 changes: 1 addition & 2 deletions csrc/single_prefill_sm90_config.inc
@@ -39,8 +39,7 @@ using IdType = int32_t;
using DTypeKV = DTypeQ; \
using DTypeO = DTypeQ; \
using Params = SinglePrefillParams<DTypeQ, DTypeKV, DTypeO>; \
- return DISPATCH_head_dim_sm90(head_dim_qk, HEAD_DIM_QK, [&] { \
- [[maybe_unused]] constexpr int HEAD_DIM_VO = HEAD_DIM_QK; \
+ return DISPATCH_head_dim_sm90(head_dim_qk, head_dim_vo, HEAD_DIM_QK, HEAD_DIM_VO, [&] { \
return DISPATCH_BOOL(window_left > -1, USE_SLIDING_WINDOW, [&] { \
return DISPATCH_BOOL(logits_soft_cap > 0.f, USE_LOGITS_SOFT_CAP, [&] { \
using AttentionVariant = DefaultAttention<USE_LOGITS_SOFT_CAP>; \
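Putting the pieces together: a hedged sketch of roughly what the new DISPATCH_head_dim_sm90(head_dim_qk, head_dim_vo, HEAD_DIM_QK, HEAD_DIM_VO, ...) call expands to once the macros above are applied, assuming the (64, 64) and (192, 128) pairs used for illustration earlier:

// Approximate expansion, for illustration only:
switch (pack_u16(head_dim_qk, head_dim_vo)) {
  case pack_u16(64, 64): {
    constexpr auto HEAD_DIM_QK = 64;
    constexpr auto HEAD_DIM_VO = 64;
    return /* dispatched lambda */();
  }
  case pack_u16(192, 128): {
    constexpr auto HEAD_DIM_QK = 192;
    constexpr auto HEAD_DIM_VO = 128;
    return /* dispatched lambda */();
  }
  default:
    // TORCH_CHECK(false, ...) reports the unsupported (head_dim_qk, head_dim_vo) pair.
    return false;
}

This is why the `[[maybe_unused]] constexpr int HEAD_DIM_VO = HEAD_DIM_QK;` line could be deleted from both config files: HEAD_DIM_VO is now a real dispatch variable rather than a copy of HEAD_DIM_QK.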
4 changes: 2 additions & 2 deletions setup.py
@@ -29,8 +29,8 @@

head_dims = os.environ.get("FLASHINFER_HEAD_DIMS", "64,128,256").split(",")
head_dims = list(map(int, head_dims))
- SM90_ALLOWED_HEAD_DIMS = {64, 128, 256}
- head_dims_sm90 = [d for d in head_dims if d in SM90_ALLOWED_HEAD_DIMS]
+ SM90_ALLOWED_HEAD_DIMS = {(64, 64), (128, 128), (256, 256), (192, 128)}
+ head_dims_sm90 = list(SM90_ALLOWED_HEAD_DIMS) # No support for custom head dims for SM90

mask_modes = [0, 1, 2]

