
Commit

bugfix: fix JIT compilation of batch prefill attention kernels (#670)
The batch prefill attention JIT template was broken in #635 because we messed up some jinja syntax. This PR fixes the issue.
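For context: these templates are plain Python strings containing Jinja tags; FlashInfer renders them at JIT-compile time and hands the result to the CUDA compiler, so a statement tag that gets split across lines stops being a tag at all. The sketch below is illustrative only (it is not the repository's actual rendering code and assumes nothing beyond an installed jinja2); it contrasts an intact {% set ... %} tag with the split form that #635 introduced.

import jinja2

# Intact tag: Jinja evaluates the conditional and defines use_alibi.
fixed = jinja2.Template(
    '{% set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %}'
    "constexpr bool use_alibi = {{ use_alibi }};"
)
print(fixed.render(pos_encoding_mode="PosEncodingMode::kNone"))
# -> constexpr bool use_alibi = false;

# Split tag: "{", "% set ... %" and "}" are just literal text to Jinja, so the
# stray "% set ..." line leaks into the output and use_alibi renders empty,
# leaving C++ that nvcc rejects.
broken = jinja2.Template(
    "{\n"
    '% set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %\n'
    "}\n"
    "constexpr bool use_alibi = {{ use_alibi }};"
)
print(broken.render(pos_encoding_mode="PosEncodingMode::kNone"))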
yzh119 authored Dec 17, 2024
1 parent 73aa00e commit 8f92670
Showing 1 changed file with 14 additions and 31 deletions.
45 changes: 14 additions & 31 deletions flashinfer/jit/batch_prefill_templ.py
@@ -1,8 +1,7 @@
"""
- Copyright(c) 2024 by FlashInfer team.
+ Copyright (c) 2024 by FlashInfer team.
- Licensed under the Apache License,
- Version 2.0(the "License");
+ Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
@@ -32,18 +31,13 @@ def ragged_prefill_inst_templ(mask_mode: str) -> str:
#include <flashinfer/attention/variants.cuh>
namespace flashinfer {
- {
- % set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %
- }
+ {% set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %}
using RaggedParamsT =
BatchPrefillRaggedParams<{{dtype_q}}, {{dtype_kv}}, {{dtype_o}}, {{dtype_idx}}>;
constexpr bool use_custom_mask =
""
"
+ mask_mode +
r
""
" == MaskMode::kCustom;
"""
+ mask_mode
+ r""" == MaskMode::kCustom;
using RaggedAttentionVariant =
ComposedAttention<RaggedParamsT,
get_variant_code(use_custom_mask, {{use_sliding_window}},
@@ -84,18 +78,13 @@ def paged_prefill_inst_templ(mask_mode: str) -> str:
#include <flashinfer/attention/variants.cuh>
namespace flashinfer {
- {
- % set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %
- }
+ {% set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %}
using PagedParamsT =
BatchPrefillPagedParams<{{dtype_q}}, {{dtype_kv}}, {{dtype_o}}, {{dtype_idx}}>;
constexpr bool use_custom_mask =
""
"
+ mask_mode +
r
""
" == MaskMode::kCustom;
"""
+ mask_mode
+ r""" == MaskMode::kCustom;
using PagedAttentionVariant =
ComposedAttention<PagedParamsT, get_variant_code(use_custom_mask, {{use_sliding_window}},
{{use_logits_soft_cap}}, {{use_alibi}})>;
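The other change in the two hunks above restores the Python-level splicing of mask_mode: the template string is closed, the mask_mode argument (e.g. "MaskMode::kCustom") is concatenated in, and a raw string reopens, so the mask mode is baked into the C++ text before Jinja ever runs. A minimal sketch of that pattern (function name and surrounding text abbreviated for illustration, not the file's exact contents):

def inst_templ(mask_mode: str) -> str:
    # mask_mode is pasted into the C++ source at template-construction time.
    return (
        r"""constexpr bool use_custom_mask = """
        + mask_mode
        + r""" == MaskMode::kCustom;"""
    )

print(inst_templ("MaskMode::kNone"))
# -> constexpr bool use_custom_mask = MaskMode::kNone == MaskMode::kCustom;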
@@ -183,8 +172,7 @@ def paged_prefill_inst_templ(mask_mode: str) -> str:
using namespace flashinfer;
- {
- % set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %}
+ {% set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %}
using RaggedParamsT = BatchPrefillRaggedParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}, {{ dtype_idx }}>;
namespace flashinfer {
@@ -245,9 +233,7 @@ def paged_prefill_inst_templ(mask_mode: str) -> str:
(maybe_qk_indptr ? static_cast<{{dtype_idx}}*>(maybe_qk_indptr->data_ptr()) : nullptr),
/*q_offset=*/nullptr, /*k_rope_pos_offset=*/nullptr, static_cast<{{dtype_o}}*>(o.data_ptr()),
/*lse=*/(maybe_lse ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr),
- { % if use_alibi == "true" % } static_cast<float*>(maybe_alibi_slopes->data_ptr()) {
- % else %
- } nullptr { % endif % },
+ {% if use_alibi == "true" %} static_cast<float*>(maybe_alibi_slopes->data_ptr()) {% else %} nullptr {% endif %},
num_qo_heads, num_kv_heads, q_stride_n, q_stride_h, kv_stride_n, kv_stride_h, window_left,
logits_soft_cap, sm_scale, rope_scale, rope_theta);
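The restored line above also relies on Jinja's {% if %}/{% else %}/{% endif %} tags: they are resolved when the template is rendered, so only one of the two argument expressions (the alibi-slopes pointer or nullptr) ever reaches the generated C++. A tiny sketch (assuming jinja2; it passes use_alibi directly instead of defining it with {% set %} as the real template does):

import jinja2

arg = jinja2.Template(
    '{% if use_alibi == "true" %}maybe_alibi_slopes->data_ptr(){% else %}nullptr{% endif %}'
)
print(arg.render(use_alibi="true"))   # -> maybe_alibi_slopes->data_ptr()
print(arg.render(use_alibi="false"))  # -> nullptr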
@@ -316,8 +302,7 @@ def paged_prefill_inst_templ(mask_mode: str) -> str:
using namespace flashinfer;
- {
- % set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %}
+ {% set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %}
using PagedParamsT = BatchPrefillPagedParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}, {{ dtype_idx }}>;
namespace flashinfer {
@@ -397,9 +382,7 @@ def paged_prefill_inst_templ(mask_mode: str) -> str:
(maybe_qk_indptr ? static_cast<{{dtype_idx}}*>(maybe_qk_indptr->data_ptr()) : nullptr),
/*q_offset=*/nullptr, static_cast<{{dtype_o}}*>(o.data_ptr()),
/*lse=*/(maybe_lse ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr),
- { % if use_alibi == "true" % } static_cast<float*>(maybe_alibi_slopes->data_ptr()) {
- % else %
- } nullptr { % endif % },
+ {% if use_alibi == "true" %} static_cast<float*>(maybe_alibi_slopes->data_ptr()) {% else %} nullptr {% endif %},
num_qo_heads, q_stride_n, q_stride_h, window_left, logits_soft_cap, sm_scale, rope_scale,
rope_theta);
