diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
index 077bbfed49e08..9677c30f22d8a 100644
--- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
@@ -171,7 +171,7 @@ Status MultiHeadAttention::Compute(OpKernelContext* context) const {
   args.v_head_size = v_head_size;
   args.scale = (scale_ == 0.0f) ? 1.0f / sqrt(static_cast<float>(qk_head_size)) : scale_;
   /*
-  block_size_q, block_size_kv correspond to Br, Bc in the FlashAttention paper.
+  q_block_size, kv_block_size correspond to Br, Bc in the FlashAttention paper.
   Let M = l2_cache_size / sizeof(float)
   In the FlashAttention kernel, there are 5 big matrices that we need to keep in L2 cache:
     slice of Q -- [Br, qk_head_size]
@@ -190,17 +190,17 @@ Status MultiHeadAttention::Compute(OpKernelContext* context) const {
   1. storing small tensors l and m
   2. instruction (code)
   */
-  args.block_size_kv = l2_cache_size_ / (static_cast<int>(sizeof(float)) * 4 * (qk_head_size + v_head_size));
-  args.block_size_kv = std::max(args.block_size_kv, 1);  // avoid block_size_kv = 0
-  args.block_size_q = std::min(args.block_size_kv, qk_head_size + v_head_size);
-  args.block_size_kv = std::min(args.block_size_kv, kv_sequence_length);  // No point to have block_size_kv > kv_sequence_length
-  args.block_size_q = std::min(args.block_size_q, q_sequence_length);  // No point to have block_size_q > q_sequence_length
+  args.kv_block_size = l2_cache_size_ / (static_cast<int>(sizeof(float)) * 4 * (qk_head_size + v_head_size));
+  args.kv_block_size = std::max(args.kv_block_size, 1);  // avoid kv_block_size = 0
+  args.q_block_size = std::min(args.kv_block_size, qk_head_size + v_head_size);
+  args.kv_block_size = std::min(args.kv_block_size, kv_sequence_length);  // No point to have kv_block_size > kv_sequence_length
+  args.q_block_size = std::min(args.q_block_size, q_sequence_length);  // No point to have q_block_size > q_sequence_length

   auto* tp = context->GetOperatorThreadPool();
   args.thread_count = concurrency::ThreadPool::DegreeOfParallelism(tp);
-  args.buffer_size_per_thread = (static_cast<size_t>(args.block_size_q) * 2 +
-                                 static_cast<size_t>(args.block_size_q) * static_cast<size_t>(args.block_size_kv) +
-                                 static_cast<size_t>(args.block_size_q) * static_cast<size_t>(args.v_head_size)) *
+  args.buffer_size_per_thread = (static_cast<size_t>(args.q_block_size) * 2 +
+                                 static_cast<size_t>(args.q_block_size) * static_cast<size_t>(args.kv_block_size) +
+                                 static_cast<size_t>(args.q_block_size) * static_cast<size_t>(args.v_head_size)) *
                                 sizeof(float);
   size_t buffer_bytes = args.buffer_size_per_thread * args.thread_count;
   IAllocatorUniquePtr<void> buffer = IAllocator::MakeUniquePtr<void>(allocator, buffer_bytes);
diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h
index ffff5fb211e66..675f7c7a13e8c 100644
--- a/onnxruntime/core/mlas/inc/mlas.h
+++ b/onnxruntime/core/mlas/inc/mlas.h
@@ -1833,8 +1833,8 @@ struct MlasFlashAttentionThreadedArgs {
     int kv_sequence_length;
     int qk_head_size;
     int v_head_size;
-    int block_size_q;
-    int block_size_kv;
+    int q_block_size;
+    int kv_block_size;
     float scale;
     int thread_count;
     float* buffer;
diff --git a/onnxruntime/core/mlas/lib/flashattn.cpp b/onnxruntime/core/mlas/lib/flashattn.cpp
index 3cde54e9ba6af..fe5402ed144aa 100644
--- a/onnxruntime/core/mlas/lib/flashattn.cpp
+++ b/onnxruntime/core/mlas/lib/flashattn.cpp
@@ -9,8 +9,8 @@ MlasFlashAttentionThreaded(
 )
 {
     const MlasFlashAttentionThreadedArgs* args = reinterpret_cast<const MlasFlashAttentionThreadedArgs*>(argptr);
-    ptrdiff_t block_size_q = static_cast<ptrdiff_t>(args->block_size_q);
-    ptrdiff_t block_size_kv = static_cast<ptrdiff_t>(args->block_size_kv);
+    ptrdiff_t q_block_size = static_cast<ptrdiff_t>(args->q_block_size);
+    ptrdiff_t kv_block_size = static_cast<ptrdiff_t>(args->kv_block_size);
     ptrdiff_t batch_size = static_cast<ptrdiff_t>(args->batch_size);
     ptrdiff_t num_heads = static_cast<ptrdiff_t>(args->num_heads);
     ptrdiff_t q_sequence_length = static_cast<ptrdiff_t>(args->q_sequence_length);
@@ -29,7 +29,7 @@ MlasFlashAttentionThreaded(
     auto&& mlas_platform = GetMlasPlatform();
 #endif

-    ptrdiff_t q_chunk_count = (q_sequence_length + (block_size_q - 1)) / block_size_q;
+    ptrdiff_t q_chunk_count = (q_sequence_length + (q_block_size - 1)) / q_block_size;

     ptrdiff_t task_start = 0;
     ptrdiff_t task_end = 0;
@@ -46,38 +46,38 @@ MlasFlashAttentionThreaded(
     for (ptrdiff_t task_index = task_start; task_index < task_end; ++task_index) {
         ptrdiff_t batch_idx = task_index;
-        ptrdiff_t q_idx = (batch_idx % q_chunk_count) * block_size_q;
+        ptrdiff_t q_idx = (batch_idx % q_chunk_count) * q_block_size;
         batch_idx /= q_chunk_count;
         ptrdiff_t head_idx = batch_idx % num_heads;
         batch_idx /= num_heads;

         char* buffer_current_thread = reinterpret_cast<char*>(buffer) + thread_id * buffer_size_per_thread;
         float* l = reinterpret_cast<float*>(buffer_current_thread);
-        float* m = l + block_size_q;
-        for (ptrdiff_t t = 0; t < block_size_q; ++t) {
+        float* m = l + q_block_size;
+        for (ptrdiff_t t = 0; t < q_block_size; ++t) {
             m[t] = std::numeric_limits<float>::lowest();
         }
-        float* intermediate = m + block_size_q;
-        float* temp_output = intermediate + block_size_q * block_size_kv;
+        float* intermediate = m + q_block_size;
+        float* temp_output = intermediate + q_block_size * kv_block_size;
         float negmax = 0;

-        for (ptrdiff_t ir = 0; ir < kv_sequence_length; ir += block_size_kv) {
+        for (ptrdiff_t ir = 0; ir < kv_sequence_length; ir += kv_block_size) {
            /*
-               S = Q[batch_idx, head_idx, q_idx:q_idx+block_size_q, :] * (K[batch_idx, head_idx, ir:ir+block_size_kv, :]).T
+               S = Q[batch_idx, head_idx, q_idx:q_idx+q_block_size, :] * (K[batch_idx, head_idx, ir:ir+kv_block_size, :]).T
                old_m = m
                m = max(m, rowmax(S))
                diff = old_m - m
                S = exp(S - m)
                l = exp(diff) * l + rowsum(S)
-               O = diag(exp(diff)) * O + S * V[batch_idx, head_idx, ir:ir+block_size_kv, :]
+               O = diag(exp(diff)) * O + S * V[batch_idx, head_idx, ir:ir+kv_block_size, :]
            */
            ptrdiff_t h = batch_idx * num_heads + head_idx;
            const float* inputQ = query + (h * q_sequence_length + q_idx) * qk_head_size;
            const float* inputK = key + (h * kv_sequence_length + ir) * qk_head_size;
            const float* inputV = value + (h * kv_sequence_length + ir) * v_head_size;

-            size_t row_size_q_capped = static_cast<size_t>(std::min(block_size_q, q_sequence_length - q_idx));
-            size_t row_size_kv_capped = static_cast<size_t>(std::min(block_size_kv, kv_sequence_length - ir));
+            size_t row_size_q_capped = static_cast<size_t>(std::min(q_block_size, q_sequence_length - q_idx));
+            size_t row_size_kv_capped = static_cast<size_t>(std::min(kv_block_size, kv_sequence_length - ir));

             MlasSgemmOperation(CBLAS_TRANSPOSE::CblasNoTrans,
                                CBLAS_TRANSPOSE::CblasTrans,
@@ -141,7 +141,7 @@ MlasFlashAttentionThreaded(
         }

         float* output_row = output + ((batch_idx * q_sequence_length + q_idx) * num_heads + head_idx) * v_head_size;
-        ptrdiff_t row_size_q_valid = std::min(block_size_q, q_sequence_length - q_idx);
+        ptrdiff_t row_size_q_valid = std::min(q_block_size, q_sequence_length - q_idx);
         // TODO: leverage advanced instruction sets
         for (ptrdiff_t irow = 0; irow < row_size_q_valid; ++irow) {
             for (ptrdiff_t icol = 0; icol < v_head_size; ++icol) {
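
Note: as a reference for the renamed fields, the standalone sketch below mirrors the Br/Bc selection and per-thread scratch sizing shown in the multihead_attention.cc hunk above. It is illustrative only and assumes the L2 cache size is known in bytes; ChooseBlockSizes and BlockSizes are hypothetical helper names, not part of ONNX Runtime or MLAS.

// Standalone sketch of the block-size heuristic touched by this rename (illustrative only).
#include <algorithm>
#include <cstddef>
#include <cstdio>

struct BlockSizes {            // hypothetical helper type, not an ORT type
  int q_block_size;            // Br in the FlashAttention paper
  int kv_block_size;           // Bc in the FlashAttention paper
  size_t buffer_size_per_thread;
};

BlockSizes ChooseBlockSizes(int l2_cache_size, int qk_head_size, int v_head_size,
                            int q_sequence_length, int kv_sequence_length) {
  BlockSizes b{};
  // Same formula as the diff: derive Bc from the L2 budget of M = l2_cache_size / sizeof(float) floats.
  b.kv_block_size = l2_cache_size / (static_cast<int>(sizeof(float)) * 4 * (qk_head_size + v_head_size));
  b.kv_block_size = std::max(b.kv_block_size, 1);                    // avoid a zero block size
  b.q_block_size = std::min(b.kv_block_size, qk_head_size + v_head_size);
  b.kv_block_size = std::min(b.kv_block_size, kv_sequence_length);   // no point exceeding the KV sequence length
  b.q_block_size = std::min(b.q_block_size, q_sequence_length);      // no point exceeding the Q sequence length

  // Per-thread scratch: l and m (Br floats each), a Br x Bc intermediate, and a Br x v_head_size output block.
  b.buffer_size_per_thread =
      (static_cast<size_t>(b.q_block_size) * 2 +
       static_cast<size_t>(b.q_block_size) * static_cast<size_t>(b.kv_block_size) +
       static_cast<size_t>(b.q_block_size) * static_cast<size_t>(v_head_size)) *
      sizeof(float);
  return b;
}

int main() {
  // Example (assumed values): 1 MiB L2 cache, head sizes 64/64, sequence lengths 512/512.
  BlockSizes b = ChooseBlockSizes(1 << 20, 64, 64, 512, 512);
  std::printf("Br=%d Bc=%d scratch=%zu bytes\n", b.q_block_size, b.kv_block_size,
              b.buffer_size_per_thread);
  return 0;
}

With these example inputs the sketch yields Bc = 512 and Br = 128, consistent with the intent that Br is additionally capped by qk_head_size + v_head_size while Bc is capped by the KV sequence length.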