feat: Separate QK/VO head dim dispatch for sm90 AOT #778

Merged (1 commit, Feb 4, 2025)
14 changes: 0 additions & 14 deletions aot_build_utils/generate.py
@@ -24,7 +24,6 @@
     generate_batch_paged_decode_inst,
     generate_batch_paged_prefill_inst,
     generate_batch_ragged_prefill_inst,
-    generate_dispatch_inc,
     generate_single_decode_inst,
     generate_single_prefill_inst,
 )
@@ -48,19 +48,6 @@ def write_if_different(path: Path, content: str) -> None:
     path.mkdir(parents=True, exist_ok=True)
 
-    write_if_different(
-        path / "dispatch.inc",
-        generate_dispatch_inc.get_dispatch_inc_str(
-            argparse.Namespace(
-                head_dims=head_dims,
-                head_dims_sm90=head_dims,
-                pos_encoding_modes=[0],
-                use_fp16_qk_reductions=[0],
-                mask_modes=mask_modes,
-            )
-        ),
-    )
-
     write_if_different(
         path / "aot_default_additional_params.h",
         generate_aot_default_additional_params_header.get_aot_default_additional_params_header_str(),
8 changes: 5 additions & 3 deletions aot_build_utils/generate_dispatch_inc.py
@@ -35,11 +35,13 @@ def get_dispatch_inc_str(args: argparse.Namespace) -> str:
     # head dims for sm90
     dispatch_head_dims_sm90_entries = "\n".join(
         [
-            " _DISPATCH_CASE({}, case_var, __VA_ARGS__) \\".format(_)
-            for _ in args.head_dims_sm90
+            " _DISPATCH_CASE_U16x2({}, {}, case_var1, case_var2, __VA_ARGS__) \\".format(
+                qk, vo
+            )
+            for qk, vo in args.head_dims_sm90
         ]
     )
-    dispatch_head_dims_sm90_str = f"""#define _DISPATCH_CASES_head_dim_sm90(case_var, ...) \\
+    dispatch_head_dims_sm90_str = f"""#define _DISPATCH_CASES_head_dim_sm90(case_var1, case_var2, ...) \\
 {dispatch_head_dims_sm90_entries}
 // EOL
 """
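For illustration (not part of the diff): with `head_dims_sm90 = [(64, 64), (192, 128)]`, two of the pairs allowed by the setup.py change below, `get_dispatch_inc_str` would now emit a dispatch.inc block along these lines, one `_DISPATCH_CASE_U16x2` entry per (qk, vo) pair:

```cpp
#define _DISPATCH_CASES_head_dim_sm90(case_var1, case_var2, ...) \
 _DISPATCH_CASE_U16x2(64, 64, case_var1, case_var2, __VA_ARGS__) \
 _DISPATCH_CASE_U16x2(192, 128, case_var1, case_var2, __VA_ARGS__) \
// EOL
```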
53 changes: 29 additions & 24 deletions aot_build_utils/generate_sm90.py
@@ -17,7 +17,7 @@
 import argparse
 from itertools import product
 from pathlib import Path
-from typing import List
+from typing import List, Tuple
 
 from . import (
     generate_batch_paged_prefill_sm90_inst,
@@ -33,7 +33,7 @@ def write_if_different(path: Path, content: str) -> None:
     path.write_text(content)
 
     path: Path = args.path
-    head_dims: List[int] = args.head_dims
+    head_dims: List[Tuple[int, int]] = args.head_dims
     pos_encoding_modes: List[int] = args.pos_encoding_modes
     use_fp16_qk_reductions: List[int] = args.use_fp16_qk_reductions
     mask_modes: List[int] = args.mask_modes
@@ -58,7 +58,7 @@ def write_if_different(path: Path, content: str) -> None:
     # single prefill files
     single_prefill_sm90_uris = []
     for (
-        head_dim,
+        (head_dim_qk, head_dim_vo),
         pos_encoding_mode,
         use_fp16_qk_reduction,
         mask_mode,
@@ -69,15 +69,15 @@ def write_if_different(path: Path, content: str) -> None:
         mask_modes,
     ):
         for dtype_q, dtype_kv in list(zip(prefill_dtypes, prefill_dtypes)):
-            fname = f"single_prefill_head_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_sm90.cu"
+            fname = f"single_prefill_head_qk_{head_dim_qk}_head_vo_{head_dim_vo}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_sm90.cu"
             content = generate_single_prefill_sm90_inst.get_cu_file_str(
-                head_dim,  # head_dim_qk
-                head_dim,  # head_dim_vo
+                head_dim_qk,
+                head_dim_vo,
                 pos_encoding_mode,
                 use_fp16_qk_reduction,
                 mask_mode,
-                dtype_q,  # dtype_q
-                dtype_kv,  # dtype_kv
+                dtype_q,
+                dtype_kv,
                 dtype_q,  # dtype_out
             )
             for use_sliding_window in [True, False]:
@@ -89,8 +89,8 @@ def write_if_different(path: Path, content: str) -> None:
                     f"single_prefill_with_kv_cache_dtype_q_{dtype_q}_"
                     f"dtype_kv_{dtype_kv}_"
                     f"dtype_o_{dtype_q}_"
-                    f"head_dim_qk_{head_dim}_"
-                    f"head_dim_vo_{head_dim}_"
+                    f"head_dim_qk_{head_dim_qk}_"
+                    f"head_dim_vo_{head_dim_vo}_"
                     f"posenc_{pos_encoding_mode}_"
                     f"use_swa_{use_sliding_window}_"
                     f"use_logits_cap_{use_logits_soft_cap}_"
@@ -101,7 +101,7 @@ def write_if_different(path: Path, content: str) -> None:
     # batch prefill files
     batch_prefill_sm90_uris = []
     for (
-        head_dim,
+        (head_dim_qk, head_dim_vo),
         pos_encoding_mode,
         use_fp16_qk_reduction,
         mask_mode,
@@ -114,29 +114,29 @@ def write_if_different(path: Path, content: str) -> None:
         idtypes,
     ):
         for dtype_q, dtype_kv in list(zip(prefill_dtypes, prefill_dtypes)):
-            fname = f"batch_paged_prefill_head_qk_{head_dim}_head_vo_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}_sm90.cu"
+            fname = f"batch_paged_prefill_head_qk_{head_dim_qk}_head_vo_{head_dim_vo}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}_sm90.cu"
             content = generate_batch_paged_prefill_sm90_inst.get_cu_file_str(
-                head_dim,  # head_dim_qk
-                head_dim,  # head_dim_vo
+                head_dim_qk,
+                head_dim_vo,
                 pos_encoding_mode,
                 use_fp16_qk_reduction,
                 mask_mode,
-                dtype_q,  # dtype_q
-                dtype_kv,  # dtype_kv
+                dtype_q,
+                dtype_kv,
                 dtype_q,  # dtype_out
                 idtype,
             )
             write_if_different(path / fname, content)
 
-            fname = f"batch_ragged_prefill_head_qk_{head_dim}_head_vo_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}_sm90.cu"
+            fname = f"batch_ragged_prefill_head_qk_{head_dim_qk}_head_vo_{head_dim_vo}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}_sm90.cu"
             content = generate_batch_ragged_prefill_sm90_inst.get_cu_file_str(
-                head_dim,  # head_dim_qk
-                head_dim,  # head_dim_vo
+                head_dim_qk,
+                head_dim_vo,
                 pos_encoding_mode,
                 use_fp16_qk_reduction,
                 mask_mode,
-                dtype_q,  # dtype_q
-                dtype_kv,  # dtype_kv
+                dtype_q,
+                dtype_kv,
                 dtype_q,  # dtype_out
                 idtype,
             )
@@ -152,8 +152,8 @@ def write_if_different(path: Path, content: str) -> None:
                     f"dtype_kv_{dtype_kv}_"
                     f"dtype_o_{dtype_q}_"
                     f"dtype_idx_{idtype}_"
-                    f"head_dim_qk_{head_dim}_"
-                    f"head_dim_vo_{head_dim}_"
+                    f"head_dim_qk_{head_dim_qk}_"
+                    f"head_dim_vo_{head_dim_vo}_"
                     f"posenc_{pos_encoding_mode}_"
                     f"use_swa_{sliding_window}_"
                     f"use_logits_cap_{logits_soft_cap}_"
@@ -169,7 +169,11 @@ def write_if_different(path: Path, content: str) -> None:
         "--path", type=Path, required=True, help="Path to the dispatch inc file"
     )
     parser.add_argument(
-        "--head_dims", type=int, required=True, nargs="+", help="Head dimensions"
+        "--head_dims",
+        type=str,
+        required=True,
+        nargs="+",
+        help="Head dimensions in format of 'head_dim_qk,head_dim_vo'",
     )
     parser.add_argument(
         "--pos_encoding_modes",
@@ -207,4 +211,5 @@ def write_if_different(path: Path, content: str) -> None:
         help="Enable bf16",
     )
     args = parser.parse_args()
+    args.head_dims = [tuple(map(int, x.split(","))) for x in args.head_dims]
     get_sm90_instantiation_cu(args)
5 changes: 3 additions & 2 deletions csrc/aot_extension_utils.h
@@ -19,8 +19,9 @@
 #define DISPATCH_head_dim(expr, const_expr, ...) \
   _DISPATCH_SWITCH("head_dim", expr, _DISPATCH_CASES_head_dim(const_expr, __VA_ARGS__))
 
-#define DISPATCH_head_dim_sm90(expr, const_expr, ...) \
-  _DISPATCH_SWITCH("head_dim", expr, _DISPATCH_CASES_head_dim_sm90(const_expr, __VA_ARGS__))
+#define DISPATCH_head_dim_sm90(expr1, expr2, const_expr1, const_expr2, ...) \
+  _DISPATCH_SWITCH_U16x2("head_dim_qk", "head_dim_vo", expr1, expr2, \
+                         _DISPATCH_CASES_head_dim_sm90(const_expr1, const_expr2, __VA_ARGS__))
 
 #define DISPATCH_pos_encoding_mode(expr, const_expr, ...) \
   _DISPATCH_SWITCH("positional encoding mode", expr, \
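For intuition, here is a rough hand-expansion (a sketch, not part of the diff) of a call like the ones in the sm90 config headers below, `DISPATCH_head_dim_sm90(head_dim_qk, head_dim_vo, HEAD_DIM_QK, HEAD_DIM_VO, f)`, using the U16x2 helpers added in csrc/pytorch_extension_utils.h further down; `f` stands in for the lambda passed at the call site, and the (64, 64) / (192, 128) cases are an assumed case list:

```cpp
// Rough expansion; the caller's lambda body is pasted where f() appears,
// so it sees HEAD_DIM_QK / HEAD_DIM_VO as compile-time constants.
[&]() -> bool {
  switch (pack_u16(head_dim_qk, head_dim_vo)) {
    case pack_u16(64, 64): {
      constexpr auto HEAD_DIM_QK = 64;
      constexpr auto HEAD_DIM_VO = 64;
      return f();
    }
    case pack_u16(192, 128): {
      constexpr auto HEAD_DIM_QK = 192;
      constexpr auto HEAD_DIM_VO = 128;
      return f();
    }
    default:
      // The real macro raises TORCH_CHECK(false, ...) naming head_dim_qk / head_dim_vo.
      return false;
  }
}()
```

The key difference from the old single-dim dispatch is that HEAD_DIM_VO is no longer forced equal to HEAD_DIM_QK, which is what enables asymmetric pairs like (192, 128).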
3 changes: 1 addition & 2 deletions csrc/batch_prefill_sm90_config.inc
@@ -41,8 +41,7 @@ using IdType = int32_t;
   using DTypeO = DTypeQ; \
   using RaggedParams = BatchPrefillRaggedParams<DTypeQ, DTypeKV, DTypeO, IdType>; \
   using PagedParams = BatchPrefillPagedParams<DTypeQ, DTypeKV, DTypeO, IdType>; \
-  return DISPATCH_head_dim_sm90(head_dim_qk, HEAD_DIM_QK, [&] { \
-    [[maybe_unused]] constexpr int HEAD_DIM_VO = HEAD_DIM_QK; \
+  return DISPATCH_head_dim_sm90(head_dim_qk, head_dim_vo, HEAD_DIM_QK, HEAD_DIM_VO, [&] { \
     return DISPATCH_BOOL(window_left > -1, USE_SLIDING_WINDOW, [&] { \
       return DISPATCH_BOOL(logits_soft_cap > 0.f, USE_LOGITS_SOFT_CAP, [&] { \
         using AttentionVariant = DefaultAttention<USE_LOGITS_SOFT_CAP>; \
20 changes: 20 additions & 0 deletions csrc/pytorch_extension_utils.h
@@ -122,12 +122,32 @@
   } \
 }()
 
+#define _DISPATCH_SWITCH_U16x2(var1_name, var2_name, cond1, cond2, ...) \
+  [&]() -> bool { \
+    switch (pack_u16(cond1, cond2)) { \
+      __VA_ARGS__ \
+      default: \
+        std::ostringstream oss; \
+        oss << __PRETTY_FUNCTION__ << " failed to dispatch (" var1_name ", " var2_name "): (" \
+            << int(cond1) << ", " << int(cond2) << ")"; \
+        TORCH_CHECK(false, oss.str()); \
+        return false; \
+    } \
+  }()
+
 #define _DISPATCH_CASE(case_expr, case_var, ...) \
   case case_expr: { \
     constexpr auto case_var = case_expr; \
     return __VA_ARGS__(); \
   }
 
+#define _DISPATCH_CASE_U16x2(case_expr1, case_expr2, case_var1, case_var2, ...) \
+  case pack_u16(case_expr1, case_expr2): { \
+    constexpr auto case_var1 = case_expr1; \
+    constexpr auto case_var2 = case_expr2; \
+    return __VA_ARGS__(); \
+  }
+
 #define DISPATCH_BOOL(expr, const_expr, ...) \
   [&]() -> bool { \
     if (expr) { \
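`pack_u16` is referenced but not defined in this diff; presumably it packs two 16-bit values into a single 32-bit key so one switch statement can cover both head dims at once. A minimal sketch of such a helper (an assumption about its shape, not the repository's actual definition):

```cpp
#include <cstdint>

// Hypothetical pack_u16: fold (a, b) into one u32 key, constexpr so it is
// valid in case labels like `case pack_u16(192, 128):`.
constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
  return (uint32_t(a) << 16) | uint32_t(b);
}

// Distinct (qk, vo) pairs must map to distinct keys for the switch to work.
static_assert(pack_u16(192, 128) != pack_u16(128, 192), "order matters");
static_assert(pack_u16(64, 64) != pack_u16(128, 128), "pairs are distinct");
```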
3 changes: 1 addition & 2 deletions csrc/single_prefill_sm90_config.inc
@@ -39,8 +39,7 @@ using IdType = int32_t;
   using DTypeKV = DTypeQ; \
   using DTypeO = DTypeQ; \
   using Params = SinglePrefillParams<DTypeQ, DTypeKV, DTypeO>; \
-  return DISPATCH_head_dim_sm90(head_dim_qk, HEAD_DIM_QK, [&] { \
-    [[maybe_unused]] constexpr int HEAD_DIM_VO = HEAD_DIM_QK; \
+  return DISPATCH_head_dim_sm90(head_dim_qk, head_dim_vo, HEAD_DIM_QK, HEAD_DIM_VO, [&] { \
     return DISPATCH_BOOL(window_left > -1, USE_SLIDING_WINDOW, [&] { \
       return DISPATCH_BOOL(logits_soft_cap > 0.f, USE_LOGITS_SOFT_CAP, [&] { \
         using AttentionVariant = DefaultAttention<USE_LOGITS_SOFT_CAP>; \
4 changes: 2 additions & 2 deletions setup.py
@@ -29,8 +29,8 @@
 
 head_dims = os.environ.get("FLASHINFER_HEAD_DIMS", "64,128,256").split(",")
 head_dims = list(map(int, head_dims))
-SM90_ALLOWED_HEAD_DIMS = {64, 128, 256}
-head_dims_sm90 = [d for d in head_dims if d in SM90_ALLOWED_HEAD_DIMS]
+SM90_ALLOWED_HEAD_DIMS = {(64, 64), (128, 128), (256, 256), (192, 128)}
Collaborator (review comment on the line above): (192, 128) is not compatible with paged attention kernels; can we separate head_dim for paged attention and ragged attention kernels?
+head_dims_sm90 = list(SM90_ALLOWED_HEAD_DIMS)  # No support for custom head dims for SM90
 
 mask_modes = [0, 1, 2]