[WIP] Support fused masking in Attention #1924

Draft · wants to merge 5 commits into main
77 changes: 68 additions & 9 deletions mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h
@@ -1,4 +1,4 @@
// Copyright © 2024 Apple Inc.
// Copyright © 2024-25 Apple Inc.

using namespace mlx::steel;

@@ -9,6 +9,9 @@ using namespace mlx::steel;
constant bool align_Q [[function_constant(200)]];
constant bool align_K [[function_constant(201)]];

constant bool has_mask [[function_constant(300)]];
constant bool do_causal [[function_constant(301)]];

template <typename T>
struct TransformScale {
T scale;
@@ -69,13 +72,16 @@ template <
int BD,
int WM,
int WN,
typename MaskType = float,
typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void attention(
const device T* Q [[buffer(0)]],
const device T* K [[buffer(1)]],
const device T* V [[buffer(2)]],
device T* O [[buffer(3)]],
const constant AttnParams* params [[buffer(4)]],
const constant AttnMaskParams* mask_params [[buffer(5), function_constant(has_mask)]],
const device MaskType* mask [[buffer(6), function_constant(has_mask)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
@@ -102,6 +108,12 @@ template <
tidl.y * params->O_strides[1] + // Head
tidl.x * BQ * params->O_strides[2]; // Sequence

if (has_mask) {
mask += tidl.z * mask_params->M_strides[0] + // Batch
tidl.y * mask_params->M_strides[1] + // Head
tidl.x * BQ * mask_params->M_strides[2]; // Sequence
}

// Prepare threadgroup memory
constexpr short padQ = 16 / sizeof(T);
constexpr short padK = 16 / sizeof(T);
Expand Down Expand Up @@ -203,7 +215,7 @@ template <

// Load Q blocks apply scale
if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
loader_q.load_safe(short2(BD, params->qL - params->NQ_aligned * BQ));
loader_q.load_safe(short2(BD, params->qL_rem));
} else {
loader_q.load_unsafe();
}
@@ -221,12 +233,39 @@ template <
max_score[i] = Limits<AccumType>::min;
}

int kb_lim = params->NK;

if (do_causal) {
int q_max = (tid.x + 1) * BQ + params->qL_off;
kb_lim = (q_max + BK - 1) / BK;

// Exit early
if (kb_lim <= 0) {
// Store results
O += (tm + sm) * params->O_strides[2] + sn;

if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
auto dst_tile_dims = short2(BD - sn, params->qL_rem - (tm + sm));

if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
return;

Otile.template store_safe<T, 1, 1>(
O, params->O_strides[2], dst_tile_dims);
} else {
Otile.template store<T, 1, 1>(O, params->O_strides[2]);
}

return;
}
}

// Loop over KV seq length
for (int kb = 0; kb < params->NK; kb++) {
for (int kb = 0; kb < kb_lim; kb++) {
// Load K block and apply scale
threadgroup_barrier(mem_flags::mem_threadgroup);
if (!align_K && kb == (params->NK_aligned)) {
loader_k.load_safe(short2(BD, params->kL - params->NK_aligned * BK));
loader_k.load_safe(short2(BD, params->kL_rem));
} else {
loader_k.load_unsafe();
}
@@ -255,7 +294,6 @@ template <
using stile_t = decltype(Stile);
using selem_t = typename stile_t::elem_type;
constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
const short lim = params->kL - params->NK_aligned * BK;

STEEL_PRAGMA_UNROLL
for (short i = 0; i < stile_t::kTileRows; i++) {
@@ -264,7 +302,29 @@ template <
short col_pos = sn + (j * stile_t::kFragCols);
STEEL_PRAGMA_UNROLL
for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) {
if ((col_pos + jj) >= lim) {
if ((col_pos + jj) >= params->kL_rem) {
Stile.frag_at(i, j)[jj] = neg_inf;
}
}
}
}
}

// Mask out entries outside the causal window
if (do_causal) {
using stile_t = decltype(Stile);
using selem_t = typename stile_t::elem_type;
constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();

STEEL_PRAGMA_UNROLL
for (short i = 0; i < stile_t::kTileRows; i++) {
short row_pos = tid.x * BQ + tm + sm + (i * stile_t::kFragRows);
STEEL_PRAGMA_UNROLL
for (short j = 0; j < stile_t::kTileCols; j++) {
short col_pos = kb * BK + sn + (j * stile_t::kFragCols);
STEEL_PRAGMA_UNROLL
for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) {
if (row_pos < (col_pos + jj)) {
Stile.frag_at(i, j)[jj] = neg_inf;
}
}
@@ -276,7 +336,7 @@ template <

// Load V blocks
if (!align_K && kb == (params->NK_aligned)) {
loader_v.load_safe(short2(BD, params->kL - params->NK_aligned * BK));
loader_v.load_safe(short2(BD, params->kL_rem));
} else {
loader_v.load_unsafe();
}
@@ -367,8 +427,7 @@ template <
O += (tm + sm) * params->O_strides[2] + sn;

if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
auto dst_tile_dims =
short2(BD - sn, params->qL - BQ * params->NQ_aligned - (tm + sm));
auto dst_tile_dims = short2(BD - sn, params->qL_rem - (tm + sm));

if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
return;
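For reference, the causal block-limit arithmetic added above can be exercised on the host. This is a minimal standalone sketch (the function name and the clamping to the block count are illustrative, not part of the diff), assuming `qL_off = kL - qL` as set by the host code later in this PR:

```cpp
#include <algorithm>
#include <cstdio>

// Sketch: number of key/value blocks a causal query block must visit.
// Mirrors q_max = (tid.x + 1) * BQ + qL_off and kb_lim = (q_max + BK - 1) / BK,
// with qL_off = kL - qL (queries aligned to the end of the key sequence).
int causal_kv_block_limit(int q_block, int BQ, int BK, int qL, int kL, int NK) {
  int qL_off = kL - qL;
  int q_max = (q_block + 1) * BQ + qL_off; // one past the last visible key column
  int kb_lim = (q_max + BK - 1) / BK;      // ceil-divide into whole KV blocks
  return std::min(std::max(kb_lim, 0), NK); // clamp for this sketch
}

int main() {
  // qL = kL = 128, BQ = BK = 32: query block 0 needs 1 KV block, block 3 needs 4.
  for (int qb = 0; qb < 4; ++qb) {
    std::printf("query block %d -> kb_lim %d\n",
                qb, causal_kv_block_limit(qb, 32, 32, 128, 128, 4));
  }
  return 0;
}
```

Blocks at or past `kb_lim` lie entirely above the causal diagonal, so the per-element causal mask never needs to run for them.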
37 changes: 17 additions & 20 deletions mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal
@@ -1,31 +1,28 @@
// Copyright © 2024 Apple Inc.
// Copyright © 2024-25 Apple Inc.

// clang-format off
#include "mlx/backend/metal/kernels/utils.h"

#include "mlx/backend/metal/kernels/steel/attn/attn.h"
#include "mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h"

#define instantiate_attn(tname, dtype, bq, bk, bd, wm, wn) \
template [[host_name("steel_attention_" #tname "_bq" #bq "_bk" #bk "_bd" #bd "_wm" #wm "_wn" #wn)]] \
[[kernel]] void attention<dtype, bq, bk, bd, wm, wn, float>( \
const device dtype* Q [[buffer(0)]], \
const device dtype* K [[buffer(1)]], \
const device dtype* V [[buffer(2)]], \
device dtype* O [[buffer(3)]],\
const constant AttnParams* params [[buffer(4)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
#define instantiate_attn(tname, dtype, bq, bk, bd, wm, wn, mname, mtype) \
instantiate_kernel( \
"steel_attention_" #tname "_bq" #bq "_bk" #bk "_bd" #bd \
"_wm" #wm "_wn" #wn "_mask" #mname, \
attention, dtype, bq, bk, bd, wm, wn, mtype, float)

#define instantiate_attn_shapes_helper(iname, itype) \
instantiate_attn(iname, itype, 32, 16, 128, 4, 1) \
instantiate_attn(iname, itype, 32, 32, 80, 4, 1) \
instantiate_attn(iname, itype, 32, 32, 64, 4, 1)
#define instantiate_attn_shapes_helper(iname, itype, mname, mtype) \
instantiate_attn(iname, itype, 32, 16, 128, 4, 1, mname, mtype) \
instantiate_attn(iname, itype, 32, 32, 80, 4, 1, mname, mtype) \
instantiate_attn(iname, itype, 32, 32, 64, 4, 1, mname, mtype)

instantiate_attn_shapes_helper(float16, half);
instantiate_attn_shapes_helper(bfloat16, bfloat16_t);
#define instantiate_attn_mask_helper(iname, itype) \
instantiate_attn_shapes_helper(iname, itype, iname, itype) \
instantiate_attn_shapes_helper(iname, itype, bool_, bool)

instantiate_attn_shapes_helper(float32, float);
instantiate_attn_mask_helper(float16, half);
instantiate_attn_mask_helper(bfloat16, bfloat16_t);

instantiate_attn_mask_helper(float32, float);
// clang-format on
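The host selects one of these instantiations by composing the same name string; below is a small sketch of that composition (the helper name is hypothetical — the real code builds the string with an `std::ostringstream` in `scaled_dot_product_attention.cpp`, shown further down):

```cpp
#include <iostream>
#include <sstream>
#include <string>

// Sketch: rebuild the base kernel name registered by instantiate_attn, e.g.
// "steel_attention_float16_bq32_bk16_bd128_wm4_wn1_maskbool_".
std::string attention_kernel_name(
    const std::string& tname, int bq, int bk, int bd, int wm, int wn,
    const std::string& mname) {
  std::ostringstream kname;
  kname << "steel_attention_" << tname
        << "_bq" << bq << "_bk" << bk << "_bd" << bd
        << "_wm" << wm << "_wn" << wn
        << "_mask" << mname;
  return kname.str();
}

int main() {
  // A boolean mask on a half-precision kernel with head dimension 128.
  std::cout << attention_kernel_name("float16", 32, 16, 128, 4, 1, "bool_") << "\n";
  return 0;
}
```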
8 changes: 8 additions & 0 deletions mlx/backend/metal/kernels/steel/attn/params.h
@@ -26,11 +26,19 @@ struct AttnParams {
int NQ_aligned; ///< Number of full query blocks
int NK_aligned; ///< Number of full key/value blocks

int qL_rem; ///< Remainder in last query block
int kL_rem; ///< Remainder in last key/value block
int qL_off; ///< Offset of query sequence start

int64_t Q_strides[3]; ///< Query strides (B, H, L, D = 1)
int64_t K_strides[3]; ///< Key strides (B, H, L, D = 1)
int64_t V_strides[3]; ///< Value strides (B, H, L, D = 1)
int64_t O_strides[3]; ///< Output strides (B, H, L, D = 1)
};

struct AttnMaskParams {
int64_t M_strides[3]; ///< Mask strides (B, H, qL, kL = 1)
};

} // namespace steel
} // namespace mlx
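A sketch of how the three new fields relate to the sequence lengths and block sizes, matching the expressions the host uses below (the helper struct and function names are illustrative):

```cpp
// Sketch: the three new AttnParams fields expressed in terms of sequence
// lengths and block sizes, matching the host-side initializers below.
struct RemainderFields {
  int qL_rem; // rows in the final, possibly partial, query block
  int kL_rem; // rows in the final, possibly partial, key/value block
  int qL_off; // causal offset when the key sequence is longer than the query
};

inline RemainderFields compute_remainders(int qL, int kL, int bq, int bk) {
  int NQ_aligned = qL / bq; // number of full query blocks
  int NK_aligned = kL / bk; // number of full key/value blocks
  return {qL - NQ_aligned * bq, kL - NK_aligned * bk, kL - qL};
}
```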
38 changes: 32 additions & 6 deletions mlx/backend/metal/scaled_dot_product_attention.cpp
@@ -21,7 +21,9 @@ void sdpa_full_self_attention_metal(
const array& k,
const array& v,
const float scale,
array& o) {
array& o,
bool do_causal_ = false,
const std::optional<array>& mask = std::nullopt) {
using namespace mlx::steel;

int wm = 4;
@@ -41,11 +43,14 @@

const bool align_Q = (qL % bq) == 0;
const bool align_K = (kL % bk) == 0;
const bool has_mask = !!mask;
const bool do_causal = do_causal_;

metal::MTLFCList func_consts = {
{&align_Q, MTL::DataType::DataTypeBool, 200},
{&align_K, MTL::DataType::DataTypeBool, 201},
};
{&has_mask, MTL::DataType::DataTypeBool, 300},
{&do_causal, MTL::DataType::DataTypeBool, 301}};

std::ostringstream kname;
// clang-format off
@@ -54,13 +59,17 @@
<< "_bq" << bq
<< "_bk" << bk
<< "_bd" << bd
<< "_wm" << wm << "_wn" << wn; // clang-format on
<< "_wm" << wm
<< "_wn" << wn
<< "_mask" << (type_to_name(mask ? *mask : q)); // clang-format on

std::string base_name = kname.str();

// clang-format off
kname << "_align_Q_" << (align_Q ? 't' : 'n')
<< "_align_K_" << (align_K ? 't' : 'n'); // clang-format on
<< "_align_K_" << (align_K ? 't' : 'n')
<< "_has_mask_" << (has_mask ? 't' : 'n')
<< "_do_causal_" << (do_causal ? 't' : 'n'); // clang-format on

std::string hash_name = kname.str();

@@ -91,6 +100,10 @@
/* int NQ_aligned = */ NQ_aligned,
/* int NK_aligned = */ NK_aligned,

/* int qL_rem = */ (qL - NQ_aligned * bq),
/* int kL_rem = */ (kL - NK_aligned * bk),
/* int qL_off = */ (kL - qL),

/* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
/* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
/* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
@@ -102,6 +115,15 @@
compute_encoder.set_output_array(o, 3);
compute_encoder.set_bytes(params, 4);

if (mask) {
auto m = *mask;
AttnMaskParams mask_params{/* int64_t M_strides[3] = */ {
m.strides(0), m.strides(1), m.strides(2)}};

compute_encoder.set_bytes(mask_params, 5);
compute_encoder.set_input_array(m, 6);
}

MTL::Size grid_dims = MTL::Size(NQ, H, B);
MTL::Size group_dims = MTL::Size(32, wm, wn);

@@ -345,7 +367,7 @@

// Checks that the headdim dimension has stride 1.
auto is_matrix_contiguous = [](const array& arr) {
return arr.strides(3) == 1;
return arr.strides(-1) == 1;
};

// We are in vector mode ie single query
Expand Down Expand Up @@ -414,7 +436,11 @@ void ScaledDotProductAttention::eval_gpu(
{str_oB, str_oH, str_oL, str_oD},
flags);

sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o);
auto mask = inputs.size() > 3
? std::optional<array>{copy_unless(is_matrix_contiguous, inputs[3])}
: std::nullopt;

sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o, do_causal_, mask);
}

d.add_temporaries(std::move(copies), s.index);
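To make the mask addressing concrete, here is a scalar sketch of the indexing implied by `AttnMaskParams` (illustrative only; it assumes the last mask dimension is contiguous and that broadcast batch/head dimensions arrive with zero strides — neither of which is spelled out in this diff):

```cpp
#include <cstdint>

// Sketch: locate mask element (b, h, q, k) from the three strides carried in
// AttnMaskParams; the key dimension is taken to be contiguous (stride 1).
template <typename MaskT>
const MaskT* mask_element(
    const MaskT* mask, const int64_t M_strides[3], int b, int h, int q, int k) {
  return mask + b * M_strides[0] + h * M_strides[1] + q * M_strides[2] + k;
}

// Presumably (standard SDPA convention; not shown in the hunks above) a bool
// mask marks positions that may attend, while a float/half mask is added to
// the scores before the softmax. A mask broadcast over batch and heads would
// then carry zero in M_strides[0] and M_strides[1].
```

The kernel hunk above advances `mask` by the batch, head, and query-block offsets once per threadgroup and then walks the key dimension with unit stride.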