Renaming
duanqn committed Jul 11, 2024
1 parent e1cf289 commit 852fd98
Showing 3 changed files with 25 additions and 25 deletions.
18 changes: 9 additions & 9 deletions onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
@@ -171,7 +171,7 @@ Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
args.v_head_size = v_head_size;
args.scale = (scale_ == 0.0f) ? 1.0f / sqrt(static_cast<float>(qk_head_size)) : scale_;
/*
- block_size_q, block_size_kv correspond to Br, Bc in the FlashAttention paper.
+ q_block_size, kv_block_size correspond to Br, Bc in the FlashAttention paper.
Let M = l2_cache_size / sizeof(float)
In the FlashAttention kernel, there are 5 big matrices that we need to keep in L2 cache:
slice of Q -- [Br, qk_head_size]
@@ -190,17 +190,17 @@ Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
1. storing small tensors l and m
2. instruction (code)
*/
- args.block_size_kv = l2_cache_size_ / (static_cast<int>(sizeof(float)) * 4 * (qk_head_size + v_head_size));
- args.block_size_kv = std::max(args.block_size_kv, 1); // avoid block_size_kv = 0
- args.block_size_q = std::min(args.block_size_kv, qk_head_size + v_head_size);
- args.block_size_kv = std::min(args.block_size_kv, kv_sequence_length); // No point to have block_size_kv > kv_sequence_length
- args.block_size_q = std::min(args.block_size_q, q_sequence_length); // No point to have block_size_q > q_sequence_length
+ args.kv_block_size = l2_cache_size_ / (static_cast<int>(sizeof(float)) * 4 * (qk_head_size + v_head_size));
+ args.kv_block_size = std::max(args.kv_block_size, 1); // avoid kv_block_size = 0
+ args.q_block_size = std::min(args.kv_block_size, qk_head_size + v_head_size);
+ args.kv_block_size = std::min(args.kv_block_size, kv_sequence_length); // No point to have kv_block_size > kv_sequence_length
+ args.q_block_size = std::min(args.q_block_size, q_sequence_length); // No point to have q_block_size > q_sequence_length

auto* tp = context->GetOperatorThreadPool();
args.thread_count = concurrency::ThreadPool::DegreeOfParallelism(tp);
- args.buffer_size_per_thread = (static_cast<size_t>(args.block_size_q) * 2 +
- static_cast<size_t>(args.block_size_q) * static_cast<size_t>(args.block_size_kv) +
- static_cast<size_t>(args.block_size_q) * static_cast<size_t>(args.v_head_size)) *
+ args.buffer_size_per_thread = (static_cast<size_t>(args.q_block_size) * 2 +
+ static_cast<size_t>(args.q_block_size) * static_cast<size_t>(args.kv_block_size) +
+ static_cast<size_t>(args.q_block_size) * static_cast<size_t>(args.v_head_size)) *
sizeof(float);
size_t buffer_bytes = args.buffer_size_per_thread * args.thread_count;
IAllocatorUniquePtr<void> buffer = IAllocator::MakeUniquePtr<void>(allocator, buffer_bytes);
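
To make the cache-budget comment above concrete, here is a small, self-contained sketch (not part of commit 852fd98) that reproduces the block-size and per-thread-buffer arithmetic with hypothetical numbers; the 1 MiB L2 cache, 64-element heads, and 512-token sequences are assumptions chosen only for illustration:

// Illustration only -- not code from this repository.
// l2_cache_size stands in for the member l2_cache_size_ used above.
#include <algorithm>
#include <cstdio>

int main() {
    const int l2_cache_size = 1 << 20;   // assumed 1 MiB L2 cache, in bytes
    const int qk_head_size = 64;         // assumed head sizes
    const int v_head_size = 64;
    const int q_sequence_length = 512;   // assumed sequence lengths
    const int kv_sequence_length = 512;

    // Budget roughly 4 float tiles of width (qk_head_size + v_head_size) in L2.
    int kv_block_size = l2_cache_size / (static_cast<int>(sizeof(float)) * 4 * (qk_head_size + v_head_size));
    kv_block_size = std::max(kv_block_size, 1);                               // 1048576 / 2048 = 512
    int q_block_size = std::min(kv_block_size, qk_head_size + v_head_size);   // min(512, 128) = 128
    kv_block_size = std::min(kv_block_size, kv_sequence_length);              // stays 512
    q_block_size = std::min(q_block_size, q_sequence_length);                 // stays 128

    // Per-thread scratch: l and m (q_block_size floats each), a q_block_size x kv_block_size
    // score tile, and a q_block_size x v_head_size partial-output tile.
    size_t buffer_size_per_thread =
        (static_cast<size_t>(q_block_size) * 2 +
         static_cast<size_t>(q_block_size) * static_cast<size_t>(kv_block_size) +
         static_cast<size_t>(q_block_size) * static_cast<size_t>(v_head_size)) * sizeof(float);

    std::printf("q_block_size=%d kv_block_size=%d buffer_size_per_thread=%zu\n",
                q_block_size, kv_block_size, buffer_size_per_thread);         // 128, 512, 295936
    return 0;
}

With these assumed numbers each thread gets about 289 KiB of scratch, and the Q, K and V slices streamed during the loop account for most of the remaining L2 budget.
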
4 changes: 2 additions & 2 deletions onnxruntime/core/mlas/inc/mlas.h
@@ -1833,8 +1833,8 @@ struct MlasFlashAttentionThreadedArgs {
int kv_sequence_length;
int qk_head_size;
int v_head_size;
- int block_size_q;
- int block_size_kv;
+ int q_block_size;
+ int kv_block_size;
float scale;
int thread_count;
float* buffer;
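
For orientation, the sketch below shows how a caller might populate the renamed fields; every concrete value is a placeholder, and only members visible in this commit (plus batch_size, num_heads and q_sequence_length, which flashattn.cpp reads from the same struct) are filled. The remaining members hidden behind the fold above, such as the input/output pointers, are deliberately elided.

// Sketch only -- placeholder values, not code from this repository.
// Assumes mlas.h is on the include path.
#include <cmath>
#include <vector>
#include "mlas.h"

void PrepareArgsSketch(std::vector<float>& scratch) {
    MlasFlashAttentionThreadedArgs args = {};
    args.batch_size = 1;
    args.num_heads = 12;
    args.q_sequence_length = 512;
    args.kv_sequence_length = 512;
    args.qk_head_size = 64;
    args.v_head_size = 64;
    args.scale = 1.0f / std::sqrt(static_cast<float>(args.qk_head_size));
    args.q_block_size = 128;   // Br in the FlashAttention paper
    args.kv_block_size = 512;  // Bc in the FlashAttention paper
    args.thread_count = 8;
    args.buffer_size_per_thread = 295936;  // from the formula in multihead_attention.cc above
    scratch.resize(args.buffer_size_per_thread * args.thread_count / sizeof(float));
    args.buffer = scratch.data();
    // Query/key/value/output pointers and any other members are elided here.
}
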
28 changes: 14 additions & 14 deletions onnxruntime/core/mlas/lib/flashattn.cpp
@@ -9,8 +9,8 @@ MlasFlashAttentionThreaded(
)
{
const MlasFlashAttentionThreadedArgs* args = reinterpret_cast<MlasFlashAttentionThreadedArgs*>(argptr);
- ptrdiff_t block_size_q = static_cast<ptrdiff_t>(args->block_size_q);
- ptrdiff_t block_size_kv = static_cast<ptrdiff_t>(args->block_size_kv);
+ ptrdiff_t q_block_size = static_cast<ptrdiff_t>(args->q_block_size);
+ ptrdiff_t kv_block_size = static_cast<ptrdiff_t>(args->kv_block_size);
ptrdiff_t batch_size = static_cast<ptrdiff_t>(args->batch_size);
ptrdiff_t num_heads = static_cast<ptrdiff_t>(args->num_heads);
ptrdiff_t q_sequence_length = static_cast<ptrdiff_t>(args->q_sequence_length);
@@ -29,7 +29,7 @@
auto&& mlas_platform = GetMlasPlatform();
#endif

- ptrdiff_t q_chunk_count = (q_sequence_length + (block_size_q - 1)) / block_size_q;
+ ptrdiff_t q_chunk_count = (q_sequence_length + (q_block_size - 1)) / q_block_size;

ptrdiff_t task_start = 0;
ptrdiff_t task_end = 0;
@@ -46,38 +46,38 @@

for (ptrdiff_t task_index = task_start; task_index < task_end; ++task_index) {
ptrdiff_t batch_idx = task_index;
- ptrdiff_t q_idx = (batch_idx % q_chunk_count) * block_size_q;
+ ptrdiff_t q_idx = (batch_idx % q_chunk_count) * q_block_size;
batch_idx /= q_chunk_count;
ptrdiff_t head_idx = batch_idx % num_heads;
batch_idx /= num_heads;

char* buffer_current_thread = reinterpret_cast<char*>(buffer) + thread_id * buffer_size_per_thread;
float* l = reinterpret_cast<float*>(buffer_current_thread);
- float* m = l + block_size_q;
- for (ptrdiff_t t = 0; t < block_size_q; ++t) {
+ float* m = l + q_block_size;
+ for (ptrdiff_t t = 0; t < q_block_size; ++t) {
m[t] = std::numeric_limits<float>::lowest();
}
- float* intermediate = m + block_size_q;
- float* temp_output = intermediate + block_size_q * block_size_kv;
+ float* intermediate = m + q_block_size;
+ float* temp_output = intermediate + q_block_size * kv_block_size;
float negmax = 0;

- for (ptrdiff_t ir = 0; ir < kv_sequence_length; ir += block_size_kv) {
+ for (ptrdiff_t ir = 0; ir < kv_sequence_length; ir += kv_block_size) {
/*
- S = Q[batch_idx, head_idx, q_idx:q_idx+block_size_q, :] * (K[batch_idx, head_idx, ir:ir+block_size_kv, :]).T
+ S = Q[batch_idx, head_idx, q_idx:q_idx+q_block_size, :] * (K[batch_idx, head_idx, ir:ir+kv_block_size, :]).T
old_m = m
m = max(m, rowmax(S))
diff = old_m - m
S = exp(S - m)
l = exp(diff) * l + rowsum(S)
- O = diag(exp(diff)) * O + S * V[batch_idx, head_idx, ir:ir+block_size_kv, :]
+ O = diag(exp(diff)) * O + S * V[batch_idx, head_idx, ir:ir+kv_block_size, :]
*/
ptrdiff_t h = batch_idx * num_heads + head_idx;
const float* inputQ = query + (h * q_sequence_length + q_idx) * qk_head_size;
const float* inputK = key + (h * kv_sequence_length + ir) * qk_head_size;
const float* inputV = value + (h * kv_sequence_length + ir) * v_head_size;

- size_t row_size_q_capped = static_cast<size_t>(std::min(block_size_q, q_sequence_length - q_idx));
- size_t row_size_kv_capped = static_cast<size_t>(std::min(block_size_kv, kv_sequence_length - ir));
+ size_t row_size_q_capped = static_cast<size_t>(std::min(q_block_size, q_sequence_length - q_idx));
+ size_t row_size_kv_capped = static_cast<size_t>(std::min(kv_block_size, kv_sequence_length - ir));

MlasSgemmOperation(CBLAS_TRANSPOSE::CblasNoTrans,
CBLAS_TRANSPOSE::CblasTrans,
@@ -141,7 +141,7 @@
}

float* output_row = output + ((batch_idx * q_sequence_length + q_idx) * num_heads + head_idx) * v_head_size;
- ptrdiff_t row_size_q_valid = std::min(block_size_q, q_sequence_length - q_idx);
+ ptrdiff_t row_size_q_valid = std::min(q_block_size, q_sequence_length - q_idx);
// TODO: leverage advanced instruction sets
for (ptrdiff_t irow = 0; irow < row_size_q_valid; ++irow) {
for (ptrdiff_t icol = 0; icol < v_head_size; ++icol) {
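
The comment block inside the kv loop above states the streaming-softmax recurrence that the kernel applies to whole q-block by kv-block tiles via MlasSgemmOperation. As a plain scalar reference for that math only (an illustrative sketch, not the MLAS code path; all names below are invented for the example), one kv-block update for a single query row looks like this:

// Illustrative sketch only -- scalar reference for the recurrence above.
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

// One kv-block update for a single query row. Q_row has qk_head_size elements,
// K_block is rows_kv x qk_head_size, V_block is rows_kv x v_head_size, and
// O_row (v_head_size elements) plus the scalars l and m carry the running
// state across kv blocks. Start with l = 0, O_row zeroed, and
// m = std::numeric_limits<float>::lowest() (as the kernel initializes m);
// divide O_row by l after the last block to finish the softmax normalization.
void StreamingSoftmaxUpdate(const float* Q_row, const float* K_block, const float* V_block,
                            int rows_kv, int qk_head_size, int v_head_size, float scale,
                            float& l, float& m, float* O_row) {
    std::vector<float> S(rows_kv);
    float row_max = std::numeric_limits<float>::lowest();
    for (int c = 0; c < rows_kv; ++c) {      // S = scale * Q_row * K_block.T
        float dot = 0.0f;                    // (args.scale is folded in here)
        for (int k = 0; k < qk_head_size; ++k) {
            dot += Q_row[k] * K_block[c * qk_head_size + k];
        }
        S[c] = dot * scale;
        row_max = std::max(row_max, S[c]);
    }
    float old_m = m;
    m = std::max(m, row_max);                // m = max(m, rowmax(S))
    float correction = std::exp(old_m - m);  // exp(diff), <= 1
    float row_sum = 0.0f;
    for (int c = 0; c < rows_kv; ++c) {      // S = exp(S - m)
        S[c] = std::exp(S[c] - m);
        row_sum += S[c];
    }
    l = correction * l + row_sum;            // l = exp(diff) * l + rowsum(S)
    for (int j = 0; j < v_head_size; ++j) {  // O = diag(exp(diff)) * O + S * V_block
        float acc = correction * O_row[j];
        for (int c = 0; c < rows_kv; ++c) {
            acc += S[c] * V_block[c * v_head_size + j];
        }
        O_row[j] = acc;
    }
}

The kernel above performs the same update for a whole q_block_size x kv_block_size tile at once, using MlasSgemmOperation for the two matrix products and the per-thread l, m, intermediate and temp_output scratch laid out earlier in the loop.
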
