Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: Dense and sparse customizable flashattention-3 template #667

Merged
merged 1 commit into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions LICENSE
Original file line number Diff line number Diff line change
Expand Up @@ -199,3 +199,25 @@
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

-------------------------------------------------------------------------------------------------
Some of the code in this project are adapted from other open-source projects with different
licenses. This product also bundles some third-party components under other open source licenses.
This section summarizes those components and their licenses.
See licenses/ for text of these licenses.

BSD 3-Clause License
--------------------

include/flashinfer/attention/hopper/epilogue.cuh
include/flashinfer/attention/hopper/mainloop.cuh
include/flashinfer/attention/hopper/kernel_traits.cuh
include/flashinfer/attention/hopper/named_barrier.cuh
include/flashinfer/attention/hopper/tile_scheduler.cuh
include/flashinfer/attention/hopper/utils.cuh

BSD 3-Clause "New" License
--------------------------

3rdparty/cutlass
include/flashinfer/attention/hopper/block_sparse_gather.cuh
96 changes: 96 additions & 0 deletions aot_build_utils/generate_batch_paged_prefill_sm90_inst.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
"""
Copyright (c) 2024 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import re
import sys
from pathlib import Path

from .literal_map import (
dtype_literal,
idtype_literal,
mask_mode_literal,
pos_encoding_mode_literal,
)


def get_cu_file_str(
head_dim,
pos_encoding_mode,
allow_fp16_qk_reduction,
mask_mode,
dtype_q,
dtype_kv,
dtype_out,
idtype,
):
def get_insts(attention_variant):
return "\n".join(
[
"""template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, {attention_variant}>(
Params& params,
cudaStream_t stream);

template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, {attention_variant}>(
Params& params,
cudaStream_t stream);
""".format(
head_dim=head_dim,
pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)],
allow_fp16_qk_reduction=allow_fp16_qk_reduction,
mask_mode=mask_mode_literal[int(mask_mode)],
attention_variant=attention_variant,
)
]
)

dtype_q = dtype_literal[dtype_q]
dtype_kv = dtype_literal[dtype_kv]
dtype_out = dtype_literal[dtype_out]
idtype = idtype_literal[idtype]

content = f"""#include <flashinfer/attention/hopper/prefill_sm90.cuh>
#include <flashinfer/attention/hopper/variants.cuh>
#include <flashinfer/cutlass_utils.cuh>


namespace flashinfer {{

using DTypeQ = cutlass_dtype_t<{dtype_q}>;
using DTypeKV = cutlass_dtype_t<{dtype_kv}>;
using DTypeO = cutlass_dtype_t<{dtype_out}>;

using Params = BatchPrefillPagedParams<DTypeQ, DTypeKV, DTypeO, {idtype}>;

{get_insts("LogitsSoftCap")}

{get_insts("StandardAttention")}

}}"""
return content


if __name__ == "__main__":
pattern = (
r"batch_paged_prefill_head_([0-9]+)_posenc_([0-9]+)_"
r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)_sm90\.cu"
)
compiled_pattern = re.compile(pattern)
path = Path(sys.argv[1])
fname = path.name
match = compiled_pattern.match(fname)

with open(path, "w") as f:
f.write(get_cu_file_str(*match.groups()))
97 changes: 97 additions & 0 deletions aot_build_utils/generate_batch_ragged_prefill_sm90_inst.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
"""
Copyright (c) 2024 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import re
import sys
from pathlib import Path

from .literal_map import (
dtype_literal,
idtype_literal,
mask_mode_literal,
pos_encoding_mode_literal,
)


def get_cu_file_str(
head_dim,
pos_encoding_mode,
allow_fp16_qk_reduction,
mask_mode,
dtype_q,
dtype_kv,
dtype_out,
idtype,
):

def get_insts(attention_variant):
return "\n".join(
[
"""template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, {attention_variant}>(
Params& params,
cudaStream_t stream);

template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, {attention_variant}>(
Params& params,
cudaStream_t stream);
""".format(
head_dim=head_dim,
pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)],
allow_fp16_qk_reduction=allow_fp16_qk_reduction,
mask_mode=mask_mode_literal[int(mask_mode)],
attention_variant=attention_variant,
)
]
)

dtype_q = dtype_literal[dtype_q]
dtype_kv = dtype_literal[dtype_kv]
dtype_out = dtype_literal[dtype_out]
idtype = idtype_literal[idtype]

content = f"""#include <flashinfer/attention/hopper/prefill_sm90.cuh>
#include <flashinfer/attention/hopper/variants.cuh>
#include <flashinfer/cutlass_utils.cuh>


namespace flashinfer {{

using DTypeQ = cutlass_dtype_t<{dtype_q}>;
using DTypeKV = cutlass_dtype_t<{dtype_kv}>;
using DTypeO = cutlass_dtype_t<{dtype_out}>;

using Params = BatchPrefillRaggedParams<DTypeQ, DTypeKV, DTypeO, {idtype}>;

{get_insts("LogitsSoftCap")}

{get_insts("StandardAttention")}

}}
"""
return content


if __name__ == "__main__":
pattern = (
r"batch_ragged_prefill_head_([0-9]+)_posenc_([0-9]+)_"
r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)_sm90\.cu"
)
compiled_pattern = re.compile(pattern)
path = Path(sys.argv[1])
fname = path.name
match = compiled_pattern.match(fname)
with open(path, "w") as f:
f.write(get_cu_file_str(*match.groups()))
85 changes: 85 additions & 0 deletions aot_build_utils/generate_single_prefill_sm90_inst.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""
Copyright (c) 2024 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import re
import sys
from pathlib import Path

from .literal_map import dtype_literal, mask_mode_literal, pos_encoding_mode_literal


def get_cu_file_str(
head_dim,
pos_encoding_mode,
allow_fp16_qk_reduction,
mask_mode,
dtype_q,
dtype_kv,
dtype_out,
):
content = """#include <flashinfer/attention/hopper/prefill_sm90.cuh>
#include <flashinfer/attention/hopper/variants.cuh>
#include <flashinfer/cutlass_utils.cuh>

namespace flashinfer {{

using DTypeQ = cutlass_dtype_t<{dtype_q}>;
using DTypeKV = cutlass_dtype_t<{dtype_kv}>;
using DTypeO = cutlass_dtype_t<{dtype_out}>;

using Params = SinglePrefillParams<DTypeQ, DTypeKV, DTypeO>;

template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, LogitsSoftCap>(
Params& params,
cudaStream_t stream);

template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, LogitsSoftCap>(
Params& params,
cudaStream_t stream);

template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, StandardAttention>(
Params& params,
cudaStream_t stream);

template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, StandardAttention>(
Params& params,
cudaStream_t stream);
}}
""".format(
head_dim=head_dim,
pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)],
allow_fp16_qk_reduction=allow_fp16_qk_reduction,
mask_mode=mask_mode_literal[int(mask_mode)],
dtype_q=dtype_literal[dtype_q],
dtype_kv=dtype_literal[dtype_kv],
dtype_out=dtype_literal[dtype_out],
use_custom_mask="true" if int(mask_mode) == 2 else "false",
)
return content


if __name__ == "__main__":
pattern = (
r"single_prefill_head_([0-9]+)_posenc_([0-9]+)_"
r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_sm90\.cu"
)

compiled_pattern = re.compile(pattern)
path = Path(sys.argv[1])
fname = path.name
match = compiled_pattern.match(fname)
with open(path, "w") as f:
f.write(get_cu_file_str(*match.groups()))
Loading