Skip to content

Commit

Permalink
Check for T==float
Browse files Browse the repository at this point in the history
  • Loading branch information
duanqn committed Jun 17, 2024
1 parent 93dab52 commit c8c12ff
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "core/mlas/inc/mlas.h"
#include "core/mlas/inc/mlas_flashattn.h"

#include <type_traits>
#include <unsupported/Eigen/SpecialFunctions>
#include <vector>

Expand Down Expand Up @@ -142,7 +143,7 @@ Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
ORT_RETURN_IF_ERROR(MaybeTransposeToBNSHAndAddBias<T>(
context, allocator, batch_size, num_heads_, kv_sequence_length, v_head_size, value, bias, v_bias_offset, V));

if(!disable_flash_ && key_padding_mask == nullptr && extra_add_qk == nullptr && past_key == nullptr && past_value == nullptr){
if(std::is_same_v<T, float> && !disable_flash_ && key_padding_mask == nullptr && extra_add_qk == nullptr && past_key == nullptr && past_value == nullptr){
FlashAttentionThreadedArgs args;
args.batch_size = batch_size;
args.num_heads = num_heads_;
Expand All @@ -153,20 +154,20 @@ Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {

const auto& env = Env::Default();
int l2_cache_size = env.GetL2CacheSize();
args.row_size_kv = l2_cache_size / sizeof(T) / 4 / (qk_head_size + v_head_size);
args.row_size_kv = l2_cache_size / sizeof(float) / 4 / (qk_head_size + v_head_size);
args.row_size_q = std::min(args.row_size_kv, qk_head_size + v_head_size);

auto* tp = context->GetOperatorThreadPool();
args.thread_count = concurrency::ThreadPool::DegreeOfParallelism(tp);

args.buffer_size_per_thread = args.row_size_q * 2 + args.row_size_q * args.row_size_kv + args.row_size_q * args.v_head_size;
args.buffer = static_cast<float*>(allocator->AllocArray(args.buffer_size_per_thread * args.thread_count, sizeof(T)));
args.buffer_size_per_thread *= sizeof(T);
args.buffer_size_per_thread *= sizeof(float);

args.query = Q.Get<Tensor>().Data<T>();
args.key = K.Get<Tensor>().Data<T>();
args.value = V.Get<Tensor>().Data<T>();
args.output = output->MutableData<T>();
args.query = Q.Get<Tensor>().Data<float>();
args.key = K.Get<Tensor>().Data<float>();
args.value = V.Get<Tensor>().Data<float>();
args.output = output->MutableData<float>();

concurrency::ThreadPool::TrySimpleParallelFor(tp, args.thread_count, [&](std::ptrdiff_t thread_id){
FlashAttentionThreaded(thread_id, &args);
Expand Down

0 comments on commit c8c12ff

Please sign in to comment.