From ed066e1f1135a903bcad7d691197e863ba214d6d Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar Date: Sun, 26 Jan 2025 14:55:15 -0800 Subject: [PATCH 01/11] Port over FA --- .../webgpu/bert/flash_attention.cc | 422 ++++++++++++++++++ .../contrib_ops/webgpu/bert/flash_attention.h | 73 +++ .../webgpu/bert/multihead_attention.cc | 6 + 3 files changed, 501 insertions(+) create mode 100644 onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc create mode 100644 onnxruntime/contrib_ops/webgpu/bert/flash_attention.h diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc new file mode 100644 index 0000000000000..45b2576a231cf --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -0,0 +1,422 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/cpu/bert/multihead_attention_helper.h" +#include "contrib_ops/webgpu/bert/flash_attention.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" + +#include "core/providers/webgpu/webgpu_supported_types.h" + +using namespace onnxruntime::webgpu; +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; +using namespace onnxruntime::contrib::multihead_attention_helper; + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const { + // Expectations are + // qkv have same number of heads and hidden dimension (head size). + // qkv are in BSNH format. + // B - batch size but shader only supports batch_size 1. + // S - current sequence length but shader supports only S = 1. + // N - number of heads. + // H - head size or hidden dimension for each qkv head. + // KV cache is stored as BN(total_sequence_length)H + // Attention bias is in BN(total_sequence_length) + shader.AddInput("key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + shader.AddInput("value", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + if (has_past_) { + shader.AddInput("past_key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + shader.AddInput("past_value", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + } + shader.AddOutput("present_key", ShaderUsage::UseUniform); + shader.AddOutput("present_value", ShaderUsage::UseUniform); + + shader.MainFunctionBody() << "let headIdx = workgroup_id.z;\n" + << "let kIdx = workgroup_id.x;\n" + << "let presentKeyOffset = headIdx * num_workgroups.x * uniforms.vectorized_head_size + (kIdx)*uniforms.vectorized_head_size;\n"; + if (has_past_) { + shader.MainFunctionBody() << "if (kIdx < uniforms.past_sequence_length) {\n" + << " let pastKeyOffset = headIdx * uniforms.past_sequence_length * uniforms.vectorized_head_size + (kIdx)*uniforms.vectorized_head_size;\n" + << " for (var w: u32 = 0u; w < uniforms.vectorized_head_size; w ++) {\n" + << " present_key[presentKeyOffset+w] = past_key[pastKeyOffset+w];\n" + << " present_value[presentKeyOffset+w] = past_value[pastKeyOffset+w];\n" + << " }\n" + << "}\n" + << "else if (kIdx >= uniforms.past_sequence_length) {\n"; + } else { + shader.MainFunctionBody() << "if (kIdx >= uniforms.past_sequence_length) {\n"; + } + shader.MainFunctionBody() << " let nkIdx = kIdx - uniforms.past_sequence_length;\n" + << " // Assumes kv have BSNH layout. 
num_workgroups.z is the num_head as per the dispatch requirement.\n" + << " let nOffset = nkIdx * uniforms.vectorized_head_size * num_workgroups.z + headIdx*uniforms.vectorized_head_size;\n" + << " // Assumes kv have BNSH layout.\n" + << " // let nOffset = headIdx * uniforms.kv_sequence_length * uniforms.vectorized_head_size + nkIdx * uniforms.vectorized_head_size;\n" + << " for (var w: u32 = 0u; w < uniforms.vectorized_head_size; w ++) {\n" + << " present_key[presentKeyOffset+w] = key[nOffset+w];\n" + << " present_value[presentKeyOffset+w] = value[nOffset+w];\n" + << " }\n" + << "}\n"; + + return Status::OK(); +} + +Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& parameters, + const Tensor* K, const Tensor* past_key, Tensor* present_key, + const Tensor* V, const Tensor* past_value, Tensor* present_value, + int past_sequence_length, int total_sequence_length) { + // CopyKVCache takes past key/value and current key/value and copies them to present key and value. + // This makes it so that FlashAttention only needs to look at present key and value, and saves + // number of input buffers in the shader, which we run out of (<=8) without this optimization. + const int components = parameters.head_size_ % 4 == 0 ? 4 : (parameters.head_size_ % 2 == 0 ? 2 : 1); + bool has_past = (past_sequence_length != 0); + CopyKVCacheProgram program{"CopyKVCache", components, has_past}; + if (has_past) { + program.AddInputs({{K, ProgramTensorMetadataDependency::TypeAndRank, components}, + {V, ProgramTensorMetadataDependency::TypeAndRank, components}, + {past_key, ProgramTensorMetadataDependency::TypeAndRank, components}, + {past_value, ProgramTensorMetadataDependency::TypeAndRank, components}}); + } else { + program.AddInputs({{K, ProgramTensorMetadataDependency::TypeAndRank, components}, + {V, ProgramTensorMetadataDependency::TypeAndRank, components}}); + } + + program.AddOutputs({{present_key, ProgramTensorMetadataDependency::Rank, components}, + {present_value, ProgramTensorMetadataDependency::Rank, components}}); + + program.SetDispatchGroupSize(total_sequence_length, 1, parameters.num_heads_) + .SetWorkgroupSize(1) + .CacheHint(std::to_string(components) + std::to_string(has_past)) + .AddUniformVariables({{static_cast(past_sequence_length)}, + {static_cast(parameters.kv_sequence_length_)}, + {static_cast(parameters.head_size_ / components)}}); + + return context.RunProgram(program); +} + +Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { + // Expectations are + // qkv have same number of heads and hidden dimension (head size). + // qkv are in BSNH format. + // B - batch size but shader only supports batch_size 1. + // S - current sequence length but shader supports only S = 1. + // N - number of heads. + // H - head size or hidden dimension for each qkv head. + // KV cache is stored as BN(total_sequence_length)H + // Attention bias is in BN(new_sequence_length)(total_sequence_length) + // + // Expectation is that present_key, and present_value contain past key and values since + // we are out of storage buffers a shader can have and both past/present cant be passed. + // The hidden size of each q head should be a multiple of 4 because shader uses vectorized loads. 
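As a side note on the layouts described in the comment block above, here is a minimal index-math sketch (plain C++, illustrative only; the function names are invented and are not part of this patch). It mirrors the packed BSNH query layout and the BNSH present K/V layout, addressed in vec4 units, under the shader's batch_size == 1 assumption:

    #include <cstdio>

    // Flat vec4 offset of (token, head) in packed BSNH query storage:
    // float16[batch=1, seq_len, num_heads * head_size], addressed in vec4<f16> units.
    unsigned QOffsetVec4(unsigned token, unsigned head, unsigned num_heads, unsigned head_size_vec) {
      return token * num_heads * head_size_vec + head * head_size_vec;
    }

    // Flat vec4 offset of (head, token) in present K/V storage:
    // float16[batch=1, num_heads, total_sequence_length, head_size].
    unsigned PresentKVOffsetVec4(unsigned head, unsigned token, unsigned total_seq, unsigned head_size_vec) {
      return head * total_seq * head_size_vec + token * head_size_vec;
    }

    int main() {
      const unsigned head_size = 96, head_size_vec = head_size / 4;  // head_size must be a multiple of 4
      const unsigned num_heads = 32, total_seq = 1024;               // 32 * 96 = 3072, as in the layout comments
      std::printf("q offset   = %u\n", QOffsetVec4(5, 3, num_heads, head_size_vec));
      std::printf("k/v offset = %u\n", PresentKVOffsetVec4(3, 5, total_seq, head_size_vec));
      return 0;
    }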
+ constexpr int vectorization_size = 4; + shader.AddInput("q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + shader.AddInput("present_key", ShaderUsage::UseUniform); + shader.AddInput("present_value", ShaderUsage::UseUniform); + if (has_attention_bias_) { + shader.AddInput("attention_bias", ShaderUsage::UseUniform); + } + shader.AddOutput("output", ShaderUsage::UseUniform); + + // SUBGROUP_SIZE has to be the same as sg_size. For intel this will be 16. + // TILE_SIZE is the number of groups sharing the k_tile. + // TILE_SIZE has to be <= SUBGROUP_SIZE. Ideal perf of computeSoftMax is when + // TILE_SIZE == SUBGROUP_SIZE. This is a sperate constant from SUBGROUP_SIZE + // because SUBGROUP_SIZE * TILE_SIZE has to be <= 256 as per webgpu + // gpu limits. For Intel this TILE_SIZE will be 16. + // Change precision_t to be f32 below to run dotproduct/ softmax in fp32 precision. + shader.AdditionalImplementation() << "const SUBGROUP_SIZE: u32 = " << subgroup_size_ << ";\n" + << "const TILE_SIZE: u32 = " << tile_size_ << ";\n" + << "const VECTOR_SIZE: u32 = " << vectorization_size << ";\n" + << "const QKV_HEAD_SIZE: u32 = " << qkv_head_size_ << ";\n" + << "const QKV_HEAD_VECTORIZED_SIZE: u32 = QKV_HEAD_SIZE / VECTOR_SIZE;\n" + << "const NUM_HEADS: u32 = " << qkv_num_heads_ << ";\n" + << "alias precision_t = q_element_t;\n" + << "const MIN_VALUE : precision_t = precision_t(-65504.0h);\n"; + + // Best to keep SHM usage per workgroup < 128KB, from intel docs for Intel Iris Xe GPU. + // "The SLM is a 128KB High Bandwidth Memory (HBM) accessible from the EUs in the subslice" + // GPU afterwhich workgroups will be unscheduled to make space for memory. + shader.AdditionalImplementation() << "" + << "var q_tile : array, TILE_SIZE>; // 96 * 2 * 16 = 3KB.\n" + << "var k_tile : array, TILE_SIZE>; // 96 * 2 * 16 = 3KB.\n" + << "var v_tile : array, TILE_SIZE>; // 96 * 2 * 16 = 3KB.\n" + << "var o_tile : array, TILE_SIZE>; // 96 * 2 * 16 = 3KB.\n" + << "var qk_tile : array, TILE_SIZE>; // 16 * 2 * 16 = 512\n" + << "var max_tile : array; // 2 * 16 = 32\n" + << "var denom_tile : array; // 2 * 16 = 32\n" + << "var o_ratio : array; // 2 * 16 = 32\n"; + + shader.AdditionalImplementation() << R"HELPER_FN( +fn loadq(slot: u32, q_idx_global : u32, head_idx: u32, sg_id : u32, sg_size : u32) +{ + // Stored as float16[batch_size,sequence_length,3072] the inputs as per onnx MHA + // This is the layout if TransferBSDToBNSH has not been run. + let offset = q_idx_global * (QKV_HEAD_VECTORIZED_SIZE) * NUM_HEADS + QKV_HEAD_VECTORIZED_SIZE * head_idx; + // Stored as BNSH - which is what webgpu uses after TransferBSDToBNSH has been run. 
+ // let offset = head_idx * uniforms.new_sequence_length * QKV_HEAD_VECTORIZED_SIZE + q_idx_global * QKV_HEAD_VECTORIZED_SIZE; + for (var idx:u32 = sg_id; idx < QKV_HEAD_VECTORIZED_SIZE; idx+= sg_size) + { + var value = q[idx+offset]; + q_tile[slot][idx] = value; + } +} +fn loadk(slot: u32, k_idx_global : u32, head_idx: u32, sg_id: u32, sg_size: u32) +{ + // Stored as float16[batch_size,num_heads,present_sequence_length,96] + let offset = head_idx * uniforms.present_sequence_length * QKV_HEAD_VECTORIZED_SIZE + k_idx_global * QKV_HEAD_VECTORIZED_SIZE; + for (var idx:u32 = sg_id; idx < QKV_HEAD_VECTORIZED_SIZE; idx+=sg_size) + { + var value = present_key[idx+offset]; + k_tile[slot][idx] = value; + } +} +fn loadv(slot: u32, v_idx_global : u32, head_idx: u32, sg_id: u32, sg_size: u32) +{ + // Stored as float16[batch_size,num_heads,present_sequence_length,96] + let offset = head_idx * uniforms.present_sequence_length * QKV_HEAD_VECTORIZED_SIZE + v_idx_global * QKV_HEAD_VECTORIZED_SIZE; + for (var idx:u32 = sg_id; idx < QKV_HEAD_VECTORIZED_SIZE; idx+=sg_size) + { + v_tile[slot][idx] = present_value[idx+offset]; + } +} +fn loadAttentionBias(q_row: u32, q_idx_global : u32, k_col: u32, k_idx_global : u32, head_idx: u32) +{ + // Stored as float16[batch_size,num_heads,new_seq_length,total_sequence_length] + if (q_idx_global >= uniforms.new_sequence_length || k_idx_global >= uniforms.present_sequence_length || k_col >= TILE_SIZE) { + qk_tile[q_row][k_col] = 0.0; + return; + } + let offset = head_idx * uniforms.new_sequence_length * uniforms.present_sequence_length + q_idx_global * uniforms.present_sequence_length + k_idx_global; + qk_tile[q_row][k_col] = precision_t(attention_bias[offset]); +} +fn writeo(slot: u32, o_idx_global : u32, head_idx: u32, sg_id : u32, sg_size : u32) +{ + // Stored as float16[batch_size,sequence_length,3072] + let offset = o_idx_global * NUM_HEADS * QKV_HEAD_VECTORIZED_SIZE + head_idx * QKV_HEAD_VECTORIZED_SIZE; + for (var idx:u32 = sg_id; idx < QKV_HEAD_VECTORIZED_SIZE; idx += sg_size) + { + let value = o_tile[slot][idx]; + output[offset+idx] = value; + } +} +fn computeDotProduct(q_idx: u32, k_idx: u32, sg_id: u32, sg_size : u32) +{ + var sum:vec4 = vec4(0, 0, 0, 0); + // idx is not initialized to sg_id to ensure uniformity because the loop uses + // subgroupAdd and unused lanes need to be initialized with 0 for correctness. + for (var idx:u32 = 0; idx < QKV_HEAD_VECTORIZED_SIZE; idx+= sg_size) + { + var result = vec4(0); + let sg_idx = idx+sg_id; + if (sg_idx < QKV_HEAD_VECTORIZED_SIZE) + { + result = vec4(q_tile[q_idx][sg_idx])*vec4(k_tile[k_idx][sg_idx]); + } + sum += subgroupAdd(result); + } + if (sg_id == 0) + { + let single_sum : precision_t = sum.x + sum.y + sum.z + sum.w; + let sqrt_dk = precision_t(uniforms.alpha); + let value = single_sum * sqrt_dk; + qk_tile[q_idx][k_idx] += value; + } +} +// +// Crux of Flash Attention is here, that allows for partial softmax computation, +// direct update of output and merging with previous results. +// https://courses.cs.washington.edu/courses/cse599m/23sp/notes/flashattn.pdf +// Where b is the block size of the tile. Xi is storing QKtranspose for the ith tile. +// mi_local is the max of Xi. Note: _ in this notation means what follows is a +// subscript. max_j=1:b (Xi[j]) is the max of Xi[j] for j=1 to b. 
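The pseudocode that follows can be checked end to end with a small scalar sketch (plain C++ rather than WGSL, illustrative only, not part of the patch; the variable names are invented). It streams the scores tile by tile and verifies that the running output matches a direct softmax-weighted sum:

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main() {
      const std::vector<float> scores = {0.2f, 1.5f, -0.7f, 2.1f, 0.0f, 1.1f};
      const std::vector<float> values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
      const size_t block = 2;      // tile size b

      float running_max = -1e30f;  // M_(i-1)
      float running_denom = 0.0f;  // d'_(i-1)
      float output = 0.0f;         // o'_(i-1)

      for (size_t start = 0; start < scores.size(); start += block) {
        const size_t end = std::min(start + block, scores.size());

        // mi_local and Mi for this tile.
        float tile_max = -1e30f;
        for (size_t j = start; j < end; ++j) tile_max = std::fmax(tile_max, scores[j]);
        const float new_max = std::fmax(running_max, tile_max);

        // d'_i = d'_(i-1) * e^(M_(i-1)-M_i) + sum_j e^(x_j - M_i)
        const float dleft = running_denom * std::exp(running_max - new_max);
        float sum = 0.0f;
        for (size_t j = start; j < end; ++j) sum += std::exp(scores[j] - new_max);
        const float denom = dleft + sum;

        // o'_i = o'_(i-1) * (dleft / d'_i) + sum_j (e^(x_j - M_i) / d'_i) * v_j
        output *= dleft / denom;
        for (size_t j = start; j < end; ++j) output += std::exp(scores[j] - new_max) / denom * values[j];

        running_max = new_max;
        running_denom = denom;
      }

      // Direct softmax-weighted sum for comparison.
      float m = -1e30f, d = 0.0f, ref = 0.0f;
      for (float s : scores) m = std::fmax(m, s);
      for (float s : scores) d += std::exp(s - m);
      for (size_t j = 0; j < scores.size(); ++j) ref += std::exp(scores[j] - m) / d * values[j];

      std::printf("streaming = %f, direct = %f\n", output, ref);
      return 0;
    }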
+// +// for i = 1, #tiles do +// Xi = Q[k,:] Kt[:, (i-1) b : i b] +// mi_local= max_j=1:b (Xi[j]) +// Mi = max(M_(i-1), mi_local) +// d'_i = d'_(i-1) * e^(M_(i-1)-M_i) + Σ_j=1:b e^(Xi[j]-Mi) +// o'_i = o'_(i-1) * d'_(i-1) * e^(M_(i-1)-M_i) / d'_i + Σ_j=1:b (e^(Xi[j]-Mi) / d'_i) V[j + (i - 1)b,:] +// end +// +fn computeSoftMax(q_idx: u32, sg_id:u32, enabled:bool) +{ + var x : precision_t = MIN_VALUE; + if (enabled){ + x = qk_tile[q_idx][sg_id]; + } + var max_value = subgroupMax(x); + max_value = max(max_tile[q_idx], max_value); + let sub = x - max_value; + var value:precision_t = 0; + if (enabled) { + value = exp(sub); + } + let sum = subgroupAdd(value); + // Compute lhs term of update di prime and the compute di prime. + let dleft = denom_tile[q_idx] * exp(max_tile[q_idx]-max_value); + var d = dleft + sum; + if (d == 0) + { + // Avoid division by zero by setting d to a really small value. + // Note: Removing this protection has had no negative effect on any + // of the prompts tried so far. This is a safety net. + d = precision_t(0.0000001h); + } + qk_tile[q_idx][sg_id] = value / d; + if (sg_id == 0) + { + max_tile[q_idx] = max_value; + denom_tile[q_idx] = d; + o_ratio[q_idx] = dleft / d; + } +} +fn computeO(q_idx: u32, sg_id:u32, enabled:bool) +{ + var attn = precision_t(0); + if (enabled) + { + attn = qk_tile[q_idx][sg_id]; + } + for (var i:u32 = 0; i < QKV_HEAD_VECTORIZED_SIZE; i++) + { + let val = vec4(v_tile[sg_id][i]); + var intermediate = attn * val; + let sum = subgroupAdd(intermediate); + if (sg_id == 0) + { + let o_ratio = o_ratio[q_idx]; + let old_o = vec4(o_tile[q_idx][i]); + let new_o = ( o_ratio * old_o) + sum; + o_tile[q_idx][i] = q_value_t(new_o); + } + } +} +)HELPER_FN"; + + // Shader is designed to be dispatched as Dispatch(num_heads, new_sequence_length / TILE_SIZE, 1) + // Each workgroup is responsible for a range of q values (TILE_SIZE) and visits all Ks for those q's. + // Each workgroup has TILE_SIZE waves, with each wave having subgroup size number of lanes (threads). + // Synchronization between lanes in a wave is free, with various subgroup* functions, and this shader + // uses that. Synchronization between waves requires calling workgroupBarrier. + shader.MainFunctionBody() << R"MAIN_FN( +let head_idx = workgroup_id.x; +// It is always the case that 0 <= wave_id < TILE_SIZE +// Each wave has sg_size lanes (subgroup threads). +let wave_id = u32(local_idx / sg_size); + +let q_idx_start = workgroup_id.y * TILE_SIZE; +let q_idx_global = q_idx_start + wave_id; +let q_idx_global_using_wave_valid = q_idx_global < uniforms.new_sequence_length; +if (q_idx_global_using_wave_valid) +{ + // Each invocation (wave_id) gets lane threads (subgroup threads) and is responsible for 1 query. + loadq(wave_id, q_idx_global, head_idx, sg_id, sg_size); +} +if (sg_id == 0) +{ + max_tile[wave_id] = MIN_VALUE; +} +for(var k_start = 0u; k_start < uniforms.present_sequence_length; k_start+=TILE_SIZE) +{ + // Insert barrier before updating shared memory the workgroup shares. + workgroupBarrier(); + let k_idx_global = k_start+wave_id; + let k_idx_global_using_wave_valid = k_idx_global < uniforms.present_sequence_length; + if (k_idx_global_using_wave_valid) { + // Leveraging the subgroup lanes for parallelism, load into slot wave_id + // K/V values from k_start+wave_id. + loadk(wave_id, k_idx_global, head_idx, sg_id, sg_size); + loadv(wave_id, k_idx_global, head_idx, sg_id, sg_size); + } + // Next, we want for every q row (wave_id) to populate bias for new sequence length + // (k_start+sg_id). 
loadAttentionBias handles range checking q_idx_global, + // and sg_id, (k_start+sg_id). + loadAttentionBias(wave_id, q_idx_global, sg_id, k_start+sg_id, head_idx); + // Insert barrier before workgroup starts reading the shared memory. + workgroupBarrier(); + + //if (k_idx_global_using_wave_valid) + { + // Iterate over Q rather than K because for the case of new_seq 1, there is a single query + // and context length of K by iterating over Q using the waves for K, this step can use all + // the waves in the workgroup, instead of leaving them idle. + for (var q_idx = 0u; q_idx < TILE_SIZE && q_idx_start + q_idx < uniforms.new_sequence_length; q_idx++) + { + // Leveraging the subgroups for parallelism, compute dot product of QK. + // We validate q_idx,wave_id to be less than TILE_SIZE, computeDotProduct only needs to + // validate sg_id as being less than QKV_HEAD_VECTORIZED_SIZE. + computeDotProduct(q_idx, wave_id, sg_id, sg_size); + } + } + // Insert barrier before SoftMax reads the dot product values across K. + workgroupBarrier(); + + let wave_lane_valid:bool = q_idx_global_using_wave_valid && sg_id < TILE_SIZE && sg_id + k_start < uniforms.present_sequence_length; + computeSoftMax(wave_id, sg_id, wave_lane_valid); + computeO(wave_id, sg_id, wave_lane_valid); +} +workgroupBarrier(); +if (q_idx_global_using_wave_valid) +{ + writeo(wave_id, q_idx_global, head_idx, sg_id, sg_size); +} +)MAIN_FN"; + + return Status::OK(); +} + +Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias, + Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value, + const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) { + ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, parameters.past_sequence_length_, parameters.total_sequence_length_)); + + const uint32_t subgroup_size = 16; + const uint32_t tile_size = subgroup_size; + bool has_attention_bias = attention_bias != nullptr; + FlashAttentionProgram program{"FlashAttention", has_attention_bias, subgroup_size, tile_size, parameters.head_size_, parameters.num_heads_}; + program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, 4}, + {present_key, ProgramTensorMetadataDependency::TypeAndRank, 4}, + {present_value, ProgramTensorMetadataDependency::TypeAndRank, 4}, + {attention_bias, ProgramTensorMetadataDependency::TypeAndRank}}); + program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank, 4}}); + const float alpha = parameters.scale_ == 0.0f ? 
1.f / sqrt(static_cast(parameters.head_size_)) + : parameters.scale_; + std::string cache_hint = std::to_string(has_attention_bias) + + std::to_string(subgroup_size) + + std::to_string(tile_size) + + std::to_string(parameters.head_size_) + + std::to_string(parameters.num_heads_); + program.SetDispatchGroupSize(parameters.num_heads_, (parameters.sequence_length_ + tile_size - 1) / tile_size, 1) + .SetWorkgroupSize(subgroup_size * subgroup_size) + .CacheHint(cache_hint) + .AddUniformVariables({{static_cast(parameters.sequence_length_)}, + {static_cast(parameters.total_sequence_length_)}, + {alpha}}); + + return context.RunProgram(program); +} + +bool CanApplyFlashAttention(const Tensor* bias, const Tensor* present_key, const Tensor* present_value, + const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) { + // The min subgroup size affects the block size while going through the sequence length. + // 16 is the smallest size tested, smaller sized would impact performance. + // Checking for this also ensures that we dont run flash attention where subgroup is not supported. + constexpr int kMinSupportedSubgroupSize = 16; + // Workgroup size is set to be (subgroup_size * subgroup_size), check that it is allowed. + // Flash attention is written only to support batch_size of 1, algorithm can be extended to support + // batch_size > 1. What bias is used for is not clear, so it is not implemented in the shader. + // The Flash attention implementation is vectorized, to keep things simple, only vec4 is implemented - + // this implies that head_size has to be a multiple of 4. + return context.DeviceLimits().maxComputeWorkgroupSizeX >= (kMinSupportedSubgroupSize * kMinSupportedSubgroupSize) && + parameters.batch_size_ == 1 && + bias == nullptr && + present_key != nullptr && present_value != nullptr && present_key->SizeInBytes() > 0 && + present_value->SizeInBytes() > 0 && parameters.head_size_ % 4 == 0; +} + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h new file mode 100644 index 0000000000000..ed09c705299d8 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h @@ -0,0 +1,73 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include "contrib_ops/webgpu/bert/attention_common.h" +#include "core/providers/webgpu/compute_context.h" +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; + +class CopyKVCacheProgram final : public Program { + public: + CopyKVCacheProgram(const std::string& kernel_name, int components, bool has_past) + : Program{kernel_name}, components_(components), has_past_(has_past) { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"past_sequence_length", ProgramUniformVariableDataType::Uint32}, + {"kv_sequence_length", ProgramUniformVariableDataType::Uint32}, + {"vectorized_head_size", ProgramUniformVariableDataType::Uint32}); + + private: + int components_; + bool has_past_; +}; + +class FlashAttentionProgram final : public Program { + public: + FlashAttentionProgram(const std::string& kernel_name, + bool has_attention_bias, + uint32_t subgroup_size, + uint32_t tile_size, + int qkv_head_size, + int qkv_num_heads) + : Program{kernel_name}, + has_attention_bias_(has_attention_bias), + subgroup_size_(subgroup_size), + tile_size_(tile_size), + qkv_head_size_(qkv_head_size), + qkv_num_heads_(qkv_num_heads) { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"new_sequence_length", ProgramUniformVariableDataType::Uint32}, + {"present_sequence_length", ProgramUniformVariableDataType::Uint32}, + {"alpha", ProgramUniformVariableDataType::Float32}); + + private: + bool has_attention_bias_; + uint32_t subgroup_size_; + uint32_t tile_size_; + int qkv_head_size_; + int qkv_num_heads_; +}; + +Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias, + Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value, + const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context); + +bool CanApplyFlashAttention(const Tensor* bias, const Tensor* present_key, const Tensor* present_value, + const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context); +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc index 424556c66bd9d..ffa0f56ca126b 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc @@ -5,6 +5,7 @@ #include "contrib_ops/webgpu/bert/attention_common.h" #include "contrib_ops/webgpu/bert/multihead_attention.h" #include "contrib_ops/webgpu/webgpu_contrib_kernels.h" +#include "contrib_ops/webgpu/bert/flash_attention.h" #include "core/providers/webgpu/webgpu_supported_types.h" @@ -74,6 +75,11 @@ Status MultiHeadAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& Tensor* present_key = context.Output(1, present_shape); Tensor* present_value = context.Output(2, present_shape); + if (CanApplyFlashAttention(bias, present_key, present_value, parameters, context)) { + return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value, + present_value, parameters, context); + } + TensorShapeVector q_new_dims({parameters.batch_size_, 
parameters.num_heads_, parameters.sequence_length_, parameters.head_size_}); TensorShape q_new_shape(q_new_dims); From 6b978cbd36370cb0de56e9e07ef9aec4188ed5be Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar Date: Sun, 2 Feb 2025 12:21:38 -0800 Subject: [PATCH 02/11] Attempt FA2 --- .../webgpu/bert/flash_attention.cc | 409 ++++++++---------- .../contrib_ops/webgpu/bert/flash_attention.h | 6 - 2 files changed, 170 insertions(+), 245 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index 45b2576a231cf..0a73bd02af10a 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -111,7 +111,6 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { // Expectation is that present_key, and present_value contain past key and values since // we are out of storage buffers a shader can have and both past/present cant be passed. // The hidden size of each q head should be a multiple of 4 because shader uses vectorized loads. - constexpr int vectorization_size = 4; shader.AddInput("q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); shader.AddInput("present_key", ShaderUsage::UseUniform); shader.AddInput("present_value", ShaderUsage::UseUniform); @@ -120,249 +119,184 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { } shader.AddOutput("output", ShaderUsage::UseUniform); - // SUBGROUP_SIZE has to be the same as sg_size. For intel this will be 16. - // TILE_SIZE is the number of groups sharing the k_tile. - // TILE_SIZE has to be <= SUBGROUP_SIZE. Ideal perf of computeSoftMax is when - // TILE_SIZE == SUBGROUP_SIZE. This is a sperate constant from SUBGROUP_SIZE - // because SUBGROUP_SIZE * TILE_SIZE has to be <= 256 as per webgpu - // gpu limits. For Intel this TILE_SIZE will be 16. - // Change precision_t to be f32 below to run dotproduct/ softmax in fp32 precision. - shader.AdditionalImplementation() << "const SUBGROUP_SIZE: u32 = " << subgroup_size_ << ";\n" - << "const TILE_SIZE: u32 = " << tile_size_ << ";\n" - << "const VECTOR_SIZE: u32 = " << vectorization_size << ";\n" - << "const QKV_HEAD_SIZE: u32 = " << qkv_head_size_ << ";\n" - << "const QKV_HEAD_VECTORIZED_SIZE: u32 = QKV_HEAD_SIZE / VECTOR_SIZE;\n" - << "const NUM_HEADS: u32 = " << qkv_num_heads_ << ";\n" - << "alias precision_t = q_element_t;\n" - << "const MIN_VALUE : precision_t = precision_t(-65504.0h);\n"; - - // Best to keep SHM usage per workgroup < 128KB, from intel docs for Intel Iris Xe GPU. - // "The SLM is a 128KB High Bandwidth Memory (HBM) accessible from the EUs in the subslice" - // GPU afterwhich workgroups will be unscheduled to make space for memory. 
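For reference, the 3KB figures quoted next to the tile declarations that follow come from simple arithmetic (assuming f16 elements, head size 96 and a 16-row tile, as in those comments; illustrative C++ only):

    #include <cstdio>

    int main() {
      const unsigned head_size = 96;      // elements per q/k/v row
      const unsigned bytes_per_elem = 2;  // f16
      const unsigned tile_rows = 16;      // TILE_SIZE
      const unsigned bytes_per_tile = head_size * bytes_per_elem * tile_rows;  // 96 * 2 * 16 = 3072 bytes
      const unsigned qkvo_tiles = 4 * bytes_per_tile;                          // q, k, v and o tiles together
      std::printf("one tile: %u bytes, q+k+v+o tiles: %u bytes\n", bytes_per_tile, qkvo_tiles);
      return 0;
    }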
- shader.AdditionalImplementation() << "" - << "var q_tile : array, TILE_SIZE>; // 96 * 2 * 16 = 3KB.\n" - << "var k_tile : array, TILE_SIZE>; // 96 * 2 * 16 = 3KB.\n" - << "var v_tile : array, TILE_SIZE>; // 96 * 2 * 16 = 3KB.\n" - << "var o_tile : array, TILE_SIZE>; // 96 * 2 * 16 = 3KB.\n" - << "var qk_tile : array, TILE_SIZE>; // 16 * 2 * 16 = 512\n" - << "var max_tile : array; // 2 * 16 = 32\n" - << "var denom_tile : array; // 2 * 16 = 32\n" - << "var o_ratio : array; // 2 * 16 = 32\n"; + shader.AdditionalImplementation() << "const qkv_head_size: u32 = " << qkv_head_size_ << ";\n" + << "const num_heads: u32 =" << qkv_num_heads_ << ";\n"; shader.AdditionalImplementation() << R"HELPER_FN( -fn loadq(slot: u32, q_idx_global : u32, head_idx: u32, sg_id : u32, sg_size : u32) -{ - // Stored as float16[batch_size,sequence_length,3072] the inputs as per onnx MHA - // This is the layout if TransferBSDToBNSH has not been run. - let offset = q_idx_global * (QKV_HEAD_VECTORIZED_SIZE) * NUM_HEADS + QKV_HEAD_VECTORIZED_SIZE * head_idx; - // Stored as BNSH - which is what webgpu uses after TransferBSDToBNSH has been run. - // let offset = head_idx * uniforms.new_sequence_length * QKV_HEAD_VECTORIZED_SIZE + q_idx_global * QKV_HEAD_VECTORIZED_SIZE; - for (var idx:u32 = sg_id; idx < QKV_HEAD_VECTORIZED_SIZE; idx+= sg_size) - { - var value = q[idx+offset]; - q_tile[slot][idx] = value; - } -} -fn loadk(slot: u32, k_idx_global : u32, head_idx: u32, sg_id: u32, sg_size: u32) -{ - // Stored as float16[batch_size,num_heads,present_sequence_length,96] - let offset = head_idx * uniforms.present_sequence_length * QKV_HEAD_VECTORIZED_SIZE + k_idx_global * QKV_HEAD_VECTORIZED_SIZE; - for (var idx:u32 = sg_id; idx < QKV_HEAD_VECTORIZED_SIZE; idx+=sg_size) - { - var value = present_key[idx+offset]; - k_tile[slot][idx] = value; - } -} -fn loadv(slot: u32, v_idx_global : u32, head_idx: u32, sg_id: u32, sg_size: u32) -{ - // Stored as float16[batch_size,num_heads,present_sequence_length,96] - let offset = head_idx * uniforms.present_sequence_length * QKV_HEAD_VECTORIZED_SIZE + v_idx_global * QKV_HEAD_VECTORIZED_SIZE; - for (var idx:u32 = sg_id; idx < QKV_HEAD_VECTORIZED_SIZE; idx+=sg_size) - { - v_tile[slot][idx] = present_value[idx+offset]; - } -} -fn loadAttentionBias(q_row: u32, q_idx_global : u32, k_col: u32, k_idx_global : u32, head_idx: u32) -{ - // Stored as float16[batch_size,num_heads,new_seq_length,total_sequence_length] - if (q_idx_global >= uniforms.new_sequence_length || k_idx_global >= uniforms.present_sequence_length || k_col >= TILE_SIZE) { - qk_tile[q_row][k_col] = 0.0; - return; - } - let offset = head_idx * uniforms.new_sequence_length * uniforms.present_sequence_length + q_idx_global * uniforms.present_sequence_length + k_idx_global; - qk_tile[q_row][k_col] = precision_t(attention_bias[offset]); -} -fn writeo(slot: u32, o_idx_global : u32, head_idx: u32, sg_id : u32, sg_size : u32) -{ - // Stored as float16[batch_size,sequence_length,3072] - let offset = o_idx_global * NUM_HEADS * QKV_HEAD_VECTORIZED_SIZE + head_idx * QKV_HEAD_VECTORIZED_SIZE; - for (var idx:u32 = sg_id; idx < QKV_HEAD_VECTORIZED_SIZE; idx += sg_size) - { - let value = o_tile[slot][idx]; - output[offset+idx] = value; - } -} -fn computeDotProduct(q_idx: u32, k_idx: u32, sg_id: u32, sg_size : u32) -{ - var sum:vec4 = vec4(0, 0, 0, 0); - // idx is not initialized to sg_id to ensure uniformity because the loop uses - // subgroupAdd and unused lanes need to be initialized with 0 for correctness. 
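The uniformity requirement described in the comment above can be pictured with a scalar stand-in for the subgroup loop (plain C++, illustrative only; the inner loop plays the role of subgroupAdd across lanes). The outer trip count does not depend on the lane id, and lanes that fall past the end of the vector contribute 0 to the reduction:

    #include <cstdio>
    #include <vector>

    int main() {
      const unsigned sg_size = 16;  // lanes per subgroup
      const unsigned vec_len = 24;  // stands in for QKV_HEAD_VECTORIZED_SIZE
      const std::vector<float> q(vec_len, 1.0f), k(vec_len, 2.0f);

      float dot = 0.0f;
      // Loop bound is uniform across lanes, so every lane reaches the reduction together.
      for (unsigned idx = 0; idx < vec_len; idx += sg_size) {
        float partial_sum = 0.0f;                   // emulates subgroupAdd
        for (unsigned lane = 0; lane < sg_size; ++lane) {
          float contribution = 0.0f;                // inactive lanes must contribute 0
          if (idx + lane < vec_len) {
            contribution = q[idx + lane] * k[idx + lane];
          }
          partial_sum += contribution;
        }
        dot += partial_sum;
      }
      std::printf("dot = %f (expected %f)\n", dot, 2.0f * vec_len);
      return 0;
    }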
- for (var idx:u32 = 0; idx < QKV_HEAD_VECTORIZED_SIZE; idx+= sg_size) - { - var result = vec4(0); - let sg_idx = idx+sg_id; - if (sg_idx < QKV_HEAD_VECTORIZED_SIZE) - { - result = vec4(q_tile[q_idx][sg_idx])*vec4(k_tile[k_idx][sg_idx]); - } - sum += subgroupAdd(result); - } - if (sg_id == 0) - { - let single_sum : precision_t = sum.x + sum.y + sum.z + sum.w; - let sqrt_dk = precision_t(uniforms.alpha); - let value = single_sum * sqrt_dk; - qk_tile[q_idx][k_idx] += value; - } -} -// -// Crux of Flash Attention is here, that allows for partial softmax computation, -// direct update of output and merging with previous results. -// https://courses.cs.washington.edu/courses/cse599m/23sp/notes/flashattn.pdf -// Where b is the block size of the tile. Xi is storing QKtranspose for the ith tile. -// mi_local is the max of Xi. Note: _ in this notation means what follows is a -// subscript. max_j=1:b (Xi[j]) is the max of Xi[j] for j=1 to b. -// -// for i = 1, #tiles do -// Xi = Q[k,:] Kt[:, (i-1) b : i b] -// mi_local= max_j=1:b (Xi[j]) -// Mi = max(M_(i-1), mi_local) -// d'_i = d'_(i-1) * e^(M_(i-1)-M_i) + Σ_j=1:b e^(Xi[j]-Mi) -// o'_i = o'_(i-1) * d'_(i-1) * e^(M_(i-1)-M_i) / d'_i + Σ_j=1:b (e^(Xi[j]-Mi) / d'_i) V[j + (i - 1)b,:] -// end -// -fn computeSoftMax(q_idx: u32, sg_id:u32, enabled:bool) -{ - var x : precision_t = MIN_VALUE; - if (enabled){ - x = qk_tile[q_idx][sg_id]; - } - var max_value = subgroupMax(x); - max_value = max(max_tile[q_idx], max_value); - let sub = x - max_value; - var value:precision_t = 0; - if (enabled) { - value = exp(sub); - } - let sum = subgroupAdd(value); - // Compute lhs term of update di prime and the compute di prime. - let dleft = denom_tile[q_idx] * exp(max_tile[q_idx]-max_value); - var d = dleft + sum; - if (d == 0) - { - // Avoid division by zero by setting d to a really small value. - // Note: Removing this protection has had no negative effect on any - // of the prompts tried so far. This is a safety net. - d = precision_t(0.0000001h); - } - qk_tile[q_idx][sg_id] = value / d; - if (sg_id == 0) - { - max_tile[q_idx] = max_value; - denom_tile[q_idx] = d; - o_ratio[q_idx] = dleft / d; - } -} -fn computeO(q_idx: u32, sg_id:u32, enabled:bool) -{ - var attn = precision_t(0); - if (enabled) - { - attn = qk_tile[q_idx][sg_id]; - } - for (var i:u32 = 0; i < QKV_HEAD_VECTORIZED_SIZE; i++) - { - let val = vec4(v_tile[sg_id][i]); - var intermediate = attn * val; - let sum = subgroupAdd(intermediate); - if (sg_id == 0) - { - let o_ratio = o_ratio[q_idx]; - let old_o = vec4(o_tile[q_idx][i]); - let new_o = ( o_ratio * old_o) + sum; - o_tile[q_idx][i] = q_value_t(new_o); - } - } -} + const k_step: u32 = 16u; + const vec_factor: u32 = 4u; + const qkv_head_size_vec: u32 = qkv_head_size / vec_factor; + const min_value : q_element_t = q_element_t(-65504.0h); + + // Default SHM usage limit is 16KB in Dawn. + var k_tile : array, k_step>; // 96 * 2 * 16 = 3KB. + var v_tile : array, k_step>; // 96 * 2 * 16 = 3KB. + + // Private memory per lane. + var q_tile : array; + var o_tile : array; + + fn loadq(q_idx_global : u32, head_idx: u32) + { + // Stored as float16[batch_size,sequence_length,3072] the inputs as per onnx MHA + // This is the layout if TransferBSDToBNSH has not been run. + let offset = q_idx_global * (qkv_head_size_vec) * num_heads + qkv_head_size_vec * head_idx; + // Stored as BNSH - which is what webgpu uses after TransferBSDToBNSH has been run. 
+ //let offset = head_idx * uniforms.new_sequence_length * qkv_head_size_vec + q_idx_global * qkv_head_size_vec; + for (var idx:u32 = 0; idx < qkv_head_size_vec; idx++) + { + q_tile[idx] = q[idx+offset]; + } + } + fn loadk(k_start : u32, head_idx: u32, local_idx: u32) + { + // Stored as float16[batch_size,num_heads,present_sequence_length,96] + let offset = head_idx * uniforms.present_sequence_length * qkv_head_size_vec + k_start * qkv_head_size_vec; + for (var idx:u32 = local_idx; idx < qkv_head_size_vec*k_step; idx+=workgroup_size_x) + { + k_tile[u32(idx/qkv_head_size_vec)][idx%qkv_head_size_vec] = present_key[offset+idx]; + } + } + fn loadv(v_start : u32, head_idx: u32, local_idx: u32) + { + // Stored as float16[batch_size,num_heads,present_sequence_length,96] + let offset = head_idx * uniforms.present_sequence_length * qkv_head_size_vec + v_start * qkv_head_size_vec; + for (var idx:u32 = local_idx; idx < qkv_head_size_vec*k_step; idx+=workgroup_size_x) + { + v_tile[u32(idx/qkv_head_size_vec)][idx%qkv_head_size_vec] = present_value[offset+idx]; + } + } + fn loadAttentionBias(q_idx_global : u32, k_idx_global : u32, head_idx: u32) -> vec4 + { + // Stored as float16[batch_size,num_heads,new_seq_length,total_sequence_length] + if (q_idx_global >= uniforms.new_sequence_length || k_idx_global >= uniforms.present_sequence_length) { + return vec4(0); + } + let offset_base = head_idx * uniforms.new_sequence_length * uniforms.present_sequence_length + q_idx_global * uniforms.present_sequence_length; + let offset = offset_base + k_idx_global; + let offset_max = offset_base + uniforms.present_sequence_length; + let c1 = q_element_t(attention_bias[min(offset, offset_max)]); + let c2 = q_element_t(attention_bias[min(offset+1, offset_max)]); + let c3 = q_element_t(attention_bias[min(offset+2, offset_max)]); + let c4 = q_element_t(attention_bias[min(offset+3, offset_max)]); + return vec4(c1,c2,c3,c4); + } + fn writeo(o_idx_global: u32, head_idx: u32) + { + // Stored as float16[batch_size,sequence_length,3072] + let offset = o_idx_global * num_heads * qkv_head_size_vec + head_idx * qkv_head_size_vec; + for (var idx:u32 = 0; idx < qkv_head_size_vec; idx ++) + { + output[offset+idx] = o_tile[idx]; + } + } )HELPER_FN"; - // Shader is designed to be dispatched as Dispatch(num_heads, new_sequence_length / TILE_SIZE, 1) - // Each workgroup is responsible for a range of q values (TILE_SIZE) and visits all Ks for those q's. - // Each workgroup has TILE_SIZE waves, with each wave having subgroup size number of lanes (threads). - // Synchronization between lanes in a wave is free, with various subgroup* functions, and this shader - // uses that. Synchronization between waves requires calling workgroupBarrier. + // Shader is designed to be dispatched as Dispatch(num_heads, new_sequence_length / workgroup_size_x, 1) + // Each lane/thread is responsible for a single q. shader.MainFunctionBody() << R"MAIN_FN( -let head_idx = workgroup_id.x; -// It is always the case that 0 <= wave_id < TILE_SIZE -// Each wave has sg_size lanes (subgroup threads). -let wave_id = u32(local_idx / sg_size); - -let q_idx_start = workgroup_id.y * TILE_SIZE; -let q_idx_global = q_idx_start + wave_id; -let q_idx_global_using_wave_valid = q_idx_global < uniforms.new_sequence_length; -if (q_idx_global_using_wave_valid) -{ - // Each invocation (wave_id) gets lane threads (subgroup threads) and is responsible for 1 query. 
- loadq(wave_id, q_idx_global, head_idx, sg_id, sg_size); -} -if (sg_id == 0) -{ - max_tile[wave_id] = MIN_VALUE; -} -for(var k_start = 0u; k_start < uniforms.present_sequence_length; k_start+=TILE_SIZE) -{ - // Insert barrier before updating shared memory the workgroup shares. + let head_idx = workgroup_id.x; + + // Load Q + let q_idx_global = workgroup_id.y * workgroup_size_x + local_idx; + let valid_q = q_idx_global < uniforms.new_sequence_length; + if (valid_q) + { + loadq(q_idx_global, head_idx); + } + + var previous_max : q_element_t = min_value; + var previous_denom : q_element_t = 0; + + for(var k_start = 0u; k_start < uniforms.present_sequence_length; k_start+=k_step) + { workgroupBarrier(); - let k_idx_global = k_start+wave_id; - let k_idx_global_using_wave_valid = k_idx_global < uniforms.present_sequence_length; - if (k_idx_global_using_wave_valid) { - // Leveraging the subgroup lanes for parallelism, load into slot wave_id - // K/V values from k_start+wave_id. - loadk(wave_id, k_idx_global, head_idx, sg_id, sg_size); - loadv(wave_id, k_idx_global, head_idx, sg_id, sg_size); - } - // Next, we want for every q row (wave_id) to populate bias for new sequence length - // (k_start+sg_id). loadAttentionBias handles range checking q_idx_global, - // and sg_id, (k_start+sg_id). - loadAttentionBias(wave_id, q_idx_global, sg_id, k_start+sg_id, head_idx); - // Insert barrier before workgroup starts reading the shared memory. + loadk(k_start, head_idx, local_idx); + loadv(k_start, head_idx, local_idx); workgroupBarrier(); - //if (k_idx_global_using_wave_valid) + // Compute QKt + var qk_1:vec4 = loadAttentionBias(q_idx_global, k_start, head_idx); + var qk_2:vec4 = loadAttentionBias(q_idx_global, k_start+4, head_idx); + var qk_3:vec4 = loadAttentionBias(q_idx_global, k_start+8, head_idx); + var qk_4:vec4 = loadAttentionBias(q_idx_global, k_start+12, head_idx); + for (var i:u32 = 0u; i < qkv_head_size_vec; i++) { - // Iterate over Q rather than K because for the case of new_seq 1, there is a single query - // and context length of K by iterating over Q using the waves for K, this step can use all - // the waves in the workgroup, instead of leaving them idle. - for (var q_idx = 0u; q_idx < TILE_SIZE && q_idx_start + q_idx < uniforms.new_sequence_length; q_idx++) - { - // Leveraging the subgroups for parallelism, compute dot product of QK. - // We validate q_idx,wave_id to be less than TILE_SIZE, computeDotProduct only needs to - // validate sg_id as being less than QKV_HEAD_VECTORIZED_SIZE. 
- computeDotProduct(q_idx, wave_id, sg_id, sg_size); - } + var k_local = k_tile[sg_id][i]; + var q_own = q_tile[i]; + qk_1[0] += dot(q_own, subgroupShuffle(k_local, 0)); + qk_1[1] += dot(q_own, subgroupShuffle(k_local, 1)); + qk_1[2] += dot(q_own, subgroupShuffle(k_local, 2)); + qk_1[3] += dot(q_own, subgroupShuffle(k_local, 3)); + qk_2[0] += dot(q_own, subgroupShuffle(k_local, 4)); + qk_2[1] += dot(q_own, subgroupShuffle(k_local, 5)); + qk_2[2] += dot(q_own, subgroupShuffle(k_local, 6)); + qk_2[3] += dot(q_own, subgroupShuffle(k_local, 7)); + qk_3[0] += dot(q_own, subgroupShuffle(k_local, 8)); + qk_3[1] += dot(q_own, subgroupShuffle(k_local, 9)); + qk_3[2] += dot(q_own, subgroupShuffle(k_local, 10)); + qk_3[3] += dot(q_own, subgroupShuffle(k_local, 11)); + qk_4[0] += dot(q_own, subgroupShuffle(k_local, 12)); + qk_4[1] += dot(q_own, subgroupShuffle(k_local, 13)); + qk_4[2] += dot(q_own, subgroupShuffle(k_local, 14)); + qk_4[3] += dot(q_own, subgroupShuffle(k_local, 15)); } - // Insert barrier before SoftMax reads the dot product values across K. - workgroupBarrier(); + qk_1 = qk_1 * q_element_t(uniforms.alpha); + qk_2 = qk_2 * q_element_t(uniforms.alpha); + qk_3 = qk_3 * q_element_t(uniforms.alpha); + qk_4 = qk_4 * q_element_t(uniforms.alpha); - let wave_lane_valid:bool = q_idx_global_using_wave_valid && sg_id < TILE_SIZE && sg_id + k_start < uniforms.present_sequence_length; - computeSoftMax(wave_id, sg_id, wave_lane_valid); - computeO(wave_id, sg_id, wave_lane_valid); -} -workgroupBarrier(); -if (q_idx_global_using_wave_valid) -{ - writeo(wave_id, q_idx_global, head_idx, sg_id, sg_size); -} + // Compute SoftMax + var local_max_temp = max(qk_1, qk_2); + local_max_temp = max(local_max_temp, qk_3); + local_max_temp = max(local_max_temp, qk_4); + let local_max = max(max(local_max_temp.x, local_max_temp.y),max(local_max_temp.z, local_max_temp.w)); + let new_max = max(previous_max, local_max); + qk_1 = exp(qk_1 - new_max); + qk_2 = exp(qk_2 - new_max); + qk_3 = exp(qk_3 - new_max); + qk_4 = exp(qk_4 - new_max); + let sum_vec = qk_1 + qk_2 + qk_3 + qk_4; + let sum = sum_vec.x + sum_vec.y + sum_vec.z + sum_vec.w; + // Compute lhs term of update di prime and the compute di prime. 
+ let dleft = previous_denom * exp(previous_max-new_max); + var d = dleft + sum; + d = select(d,q_element_t(0.0000001h),d==0); + qk_1 = qk_1 / d; + qk_2 = qk_2 / d; + qk_3 = qk_3 / d; + qk_4 = qk_4 / d; + previous_max = new_max; + previous_denom = d; + let o_ratio = dleft / d; + + + for (var i:u32 = 0; i < qkv_head_size_vec; i++) + { + var val = vec4(v_tile[sg_id][i]); + var sum = subgroupShuffle(val, 0) * qk_1[0]; + sum += subgroupShuffle(val, 1) * qk_1[1]; + sum += subgroupShuffle(val, 2) * qk_1[2]; + sum += subgroupShuffle(val, 3) * qk_1[3]; + sum += subgroupShuffle(val, 4) * qk_2[0]; + sum += subgroupShuffle(val, 5) * qk_2[1]; + sum += subgroupShuffle(val, 6) * qk_2[2]; + sum += subgroupShuffle(val, 7) * qk_2[3]; + sum += subgroupShuffle(val, 8) * qk_3[0]; + sum += subgroupShuffle(val, 9) * qk_3[1]; + sum += subgroupShuffle(val, 10) * qk_3[2]; + sum += subgroupShuffle(val, 11) * qk_3[3]; + sum += subgroupShuffle(val, 12) * qk_4[0]; + sum += subgroupShuffle(val, 13) * qk_4[1]; + sum += subgroupShuffle(val, 14) * qk_4[2]; + sum += subgroupShuffle(val, 14) * qk_4[3]; + o_tile[i] = o_tile[i] * o_ratio + sum; + } + } + + if (valid_q) { + writeo(q_idx_global, head_idx); + } )MAIN_FN"; return Status::OK(); @@ -373,10 +307,9 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) { ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, parameters.past_sequence_length_, parameters.total_sequence_length_)); - const uint32_t subgroup_size = 16; - const uint32_t tile_size = subgroup_size; + const uint32_t tile_size = 64; bool has_attention_bias = attention_bias != nullptr; - FlashAttentionProgram program{"FlashAttention", has_attention_bias, subgroup_size, tile_size, parameters.head_size_, parameters.num_heads_}; + FlashAttentionProgram program{"FlashAttention", has_attention_bias, parameters.head_size_, parameters.num_heads_}; program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, 4}, {present_key, ProgramTensorMetadataDependency::TypeAndRank, 4}, {present_value, ProgramTensorMetadataDependency::TypeAndRank, 4}, @@ -385,12 +318,10 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co const float alpha = parameters.scale_ == 0.0f ? 
1.f / sqrt(static_cast(parameters.head_size_)) : parameters.scale_; std::string cache_hint = std::to_string(has_attention_bias) + - std::to_string(subgroup_size) + - std::to_string(tile_size) + std::to_string(parameters.head_size_) + std::to_string(parameters.num_heads_); program.SetDispatchGroupSize(parameters.num_heads_, (parameters.sequence_length_ + tile_size - 1) / tile_size, 1) - .SetWorkgroupSize(subgroup_size * subgroup_size) + .SetWorkgroupSize(64) .CacheHint(cache_hint) .AddUniformVariables({{static_cast(parameters.sequence_length_)}, {static_cast(parameters.total_sequence_length_)}, diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h index ed09c705299d8..c80b4d3f3fc39 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h @@ -36,14 +36,10 @@ class FlashAttentionProgram final : public Program { public: FlashAttentionProgram(const std::string& kernel_name, bool has_attention_bias, - uint32_t subgroup_size, - uint32_t tile_size, int qkv_head_size, int qkv_num_heads) : Program{kernel_name}, has_attention_bias_(has_attention_bias), - subgroup_size_(subgroup_size), - tile_size_(tile_size), qkv_head_size_(qkv_head_size), qkv_num_heads_(qkv_num_heads) { } @@ -56,8 +52,6 @@ class FlashAttentionProgram final : public Program { private: bool has_attention_bias_; - uint32_t subgroup_size_; - uint32_t tile_size_; int qkv_head_size_; int qkv_num_heads_; }; From 655c8b6123a19a5f8a2088f8bfbf87eca5d78626 Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar Date: Mon, 3 Feb 2025 13:30:40 -0800 Subject: [PATCH 03/11] attempt to fix k-index --- .../webgpu/bert/flash_attention.cc | 21 ++++++++++++------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index 0a73bd02af10a..6af6e6bbe92af 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -127,6 +127,9 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { const vec_factor: u32 = 4u; const qkv_head_size_vec: u32 = qkv_head_size / vec_factor; const min_value : q_element_t = q_element_t(-65504.0h); + // min_value_frac is a small min value that when accumulated + // qkv_head_size_vec times will leave us with a value close to min value. + const min_value_frac : q_element_t = q_element_t(-10.0); // Default SHM usage limit is 16KB in Dawn. var k_tile : array, k_step>; // 96 * 2 * 16 = 3KB. @@ -135,7 +138,6 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { // Private memory per lane. 
var q_tile : array; var o_tile : array; - fn loadq(q_idx_global : u32, head_idx: u32) { // Stored as float16[batch_size,sequence_length,3072] the inputs as per onnx MHA @@ -154,7 +156,9 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { let offset = head_idx * uniforms.present_sequence_length * qkv_head_size_vec + k_start * qkv_head_size_vec; for (var idx:u32 = local_idx; idx < qkv_head_size_vec*k_step; idx+=workgroup_size_x) { - k_tile[u32(idx/qkv_head_size_vec)][idx%qkv_head_size_vec] = present_key[offset+idx]; + let slot = u32(idx/qkv_head_size_vec); + let val = select(q_value_t(min_value_frac), present_key[offset+idx], k_start + slot < uniforms.present_sequence_length); + k_tile[slot][idx%qkv_head_size_vec] = val; } } fn loadv(v_start : u32, head_idx: u32, local_idx: u32) @@ -163,7 +167,9 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { let offset = head_idx * uniforms.present_sequence_length * qkv_head_size_vec + v_start * qkv_head_size_vec; for (var idx:u32 = local_idx; idx < qkv_head_size_vec*k_step; idx+=workgroup_size_x) { - v_tile[u32(idx/qkv_head_size_vec)][idx%qkv_head_size_vec] = present_value[offset+idx]; + let slot = u32(idx/qkv_head_size_vec); + let val = select(q_value_t(min_value_frac), present_value[offset+idx], v_start + slot < uniforms.present_sequence_length); + v_tile[slot][idx%qkv_head_size_vec] = val; } } fn loadAttentionBias(q_idx_global : u32, k_idx_global : u32, head_idx: u32) -> vec4 @@ -252,10 +258,10 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { local_max_temp = max(local_max_temp, qk_4); let local_max = max(max(local_max_temp.x, local_max_temp.y),max(local_max_temp.z, local_max_temp.w)); let new_max = max(previous_max, local_max); - qk_1 = exp(qk_1 - new_max); - qk_2 = exp(qk_2 - new_max); - qk_3 = exp(qk_3 - new_max); - qk_4 = exp(qk_4 - new_max); + qk_1 = q_value_t(exp(vec4(qk_1) - f32(new_max))); + qk_2 = q_value_t(exp(vec4(qk_2) - f32(new_max))); + qk_3 = q_value_t(exp(vec4(qk_3) - f32(new_max))); + qk_4 = q_value_t(exp(vec4(qk_4) - f32(new_max))); let sum_vec = qk_1 + qk_2 + qk_3 + qk_4; let sum = sum_vec.x + sum_vec.y + sum_vec.z + sum_vec.w; // Compute lhs term of update di prime and the compute di prime. @@ -270,7 +276,6 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { previous_denom = d; let o_ratio = dleft / d; - for (var i:u32 = 0; i < qkv_head_size_vec; i++) { var val = vec4(v_tile[sg_id][i]); From 3df32ed2234c6fd8f14511c14254f162aacfccc4 Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar Date: Mon, 3 Feb 2025 14:22:47 -0800 Subject: [PATCH 04/11] This FA works --- .../webgpu/bert/flash_attention.cc | 28 ++++++++++++++----- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index 6af6e6bbe92af..974b8db303ac3 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -127,9 +127,6 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { const vec_factor: u32 = 4u; const qkv_head_size_vec: u32 = qkv_head_size / vec_factor; const min_value : q_element_t = q_element_t(-65504.0h); - // min_value_frac is a small min value that when accumulated - // qkv_head_size_vec times will leave us with a value close to min value. 
- const min_value_frac : q_element_t = q_element_t(-10.0); // Default SHM usage limit is 16KB in Dawn. var k_tile : array, k_step>; // 96 * 2 * 16 = 3KB. @@ -157,7 +154,7 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { for (var idx:u32 = local_idx; idx < qkv_head_size_vec*k_step; idx+=workgroup_size_x) { let slot = u32(idx/qkv_head_size_vec); - let val = select(q_value_t(min_value_frac), present_key[offset+idx], k_start + slot < uniforms.present_sequence_length); + let val = select(q_value_t(0), present_key[offset+idx], k_start + slot < uniforms.present_sequence_length); k_tile[slot][idx%qkv_head_size_vec] = val; } } @@ -168,7 +165,7 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { for (var idx:u32 = local_idx; idx < qkv_head_size_vec*k_step; idx+=workgroup_size_x) { let slot = u32(idx/qkv_head_size_vec); - let val = select(q_value_t(min_value_frac), present_value[offset+idx], v_start + slot < uniforms.present_sequence_length); + let val = select(q_value_t(0), present_value[offset+idx], v_start + slot < uniforms.present_sequence_length); v_tile[slot][idx%qkv_head_size_vec] = val; } } @@ -252,6 +249,23 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { qk_3 = qk_3 * q_element_t(uniforms.alpha); qk_4 = qk_4 * q_element_t(uniforms.alpha); + // Neuter out of bounds qk values + qk_1[1] = select(min_value, qk_1[1], k_start+1 < uniforms.present_sequence_length); + qk_1[2] = select(min_value, qk_1[2], k_start+2 < uniforms.present_sequence_length); + qk_1[3] = select(min_value, qk_1[3], k_start+3 < uniforms.present_sequence_length); + qk_2[0] = select(min_value, qk_2[0], k_start+4 < uniforms.present_sequence_length); + qk_2[1] = select(min_value, qk_2[1], k_start+5 < uniforms.present_sequence_length); + qk_2[2] = select(min_value, qk_2[2], k_start+6 < uniforms.present_sequence_length); + qk_2[3] = select(min_value, qk_2[3], k_start+7 < uniforms.present_sequence_length); + qk_3[0] = select(min_value, qk_3[0], k_start+8 < uniforms.present_sequence_length); + qk_3[1] = select(min_value, qk_3[1], k_start+9 < uniforms.present_sequence_length); + qk_3[2] = select(min_value, qk_3[2], k_start+10 < uniforms.present_sequence_length); + qk_3[3] = select(min_value, qk_3[3], k_start+11 < uniforms.present_sequence_length); + qk_4[0] = select(min_value, qk_4[0], k_start+12 < uniforms.present_sequence_length); + qk_4[1] = select(min_value, qk_4[1], k_start+13 < uniforms.present_sequence_length); + qk_4[2] = select(min_value, qk_4[2], k_start+14 < uniforms.present_sequence_length); + qk_4[3] = select(min_value, qk_4[3], k_start+15 < uniforms.present_sequence_length); + // Compute SoftMax var local_max_temp = max(qk_1, qk_2); local_max_temp = max(local_max_temp, qk_3); @@ -278,7 +292,7 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { for (var i:u32 = 0; i < qkv_head_size_vec; i++) { - var val = vec4(v_tile[sg_id][i]); + var val = select(vec4(0), v_tile[sg_id][i], k_start + sg_id < uniforms.present_sequence_length); var sum = subgroupShuffle(val, 0) * qk_1[0]; sum += subgroupShuffle(val, 1) * qk_1[1]; sum += subgroupShuffle(val, 2) * qk_1[2]; @@ -294,7 +308,7 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { sum += subgroupShuffle(val, 12) * qk_4[0]; sum += subgroupShuffle(val, 13) * qk_4[1]; sum += subgroupShuffle(val, 14) * qk_4[2]; - sum += subgroupShuffle(val, 14) * qk_4[3]; + sum += subgroupShuffle(val, 15) * qk_4[3]; o_tile[i] = o_tile[i] * o_ratio + 
sum; } } From 71c8d59df6905ce5894b7e979be868dc6cd1bf59 Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar Date: Mon, 3 Feb 2025 16:29:32 -0800 Subject: [PATCH 05/11] Add comments --- .../webgpu/bert/flash_attention.cc | 26 +++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index 974b8db303ac3..706800fe7ec26 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -249,7 +249,7 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { qk_3 = qk_3 * q_element_t(uniforms.alpha); qk_4 = qk_4 * q_element_t(uniforms.alpha); - // Neuter out of bounds qk values + // Neuter qk values where K is out of bounds. qk_1[1] = select(min_value, qk_1[1], k_start+1 < uniforms.present_sequence_length); qk_1[2] = select(min_value, qk_1[2], k_start+2 < uniforms.present_sequence_length); qk_1[3] = select(min_value, qk_1[3], k_start+3 < uniforms.present_sequence_length); @@ -266,7 +266,29 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { qk_4[2] = select(min_value, qk_4[2], k_start+14 < uniforms.present_sequence_length); qk_4[3] = select(min_value, qk_4[3], k_start+15 < uniforms.present_sequence_length); - // Compute SoftMax + // + // Compute SoftMax as per Flash Attention technique. + // + // Crux of Flash Attention is here, that allows for partial softmax computation, + // direct update of output and merging with previous results. + // https://courses.cs.washington.edu/courses/cse599m/23sp/notes/flashattn.pdf + // Where b is the block size of the tile. Xi is storing QKtranspose for the ith tile. + // mi_local is the max of Xi. Note: _ in this notation means what follows is a + // subscript. max_j=1:b (Xi[j]) is the max of Xi[j] for j=1 to b. + // + // for i = 1, #tiles do + // Xi = Q[k,:] Kt[:, (i-1) b : i b] + // mi_local= max_j=1:b (Xi[j]) + // Mi = max(M_(i-1), mi_local) + // d'_i = d'_(i-1) * e^(M_(i-1)-M_i) + Σ_j=1:b e^(Xi[j]-Mi) + // o'_i = o'_(i-1) * d'_(i-1) * e^(M_(i-1)-M_i) / d'_i + Σ_j=1:b (e^(Xi[j]-Mi) / d'_i) V[j + (i - 1)b,:] + // end + // + // In the code below: + // dleft is the first term of d'_i expression above : d'_(i-1) * e^(M_(i-1)-M_i). + // sum is the second term of the same expression : Σ_j=1:b e^(Xi[j]-Mi) + // o_ratio is the part of the first term of o'_i expression above : d'_(i-1) * e^(M_(i-1)-M_i) / d'_i + // var local_max_temp = max(qk_1, qk_2); local_max_temp = max(local_max_temp, qk_3); local_max_temp = max(local_max_temp, qk_4); From 05b0f250e3cb93d2a4e25f42fbc7199e1001be95 Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar Date: Mon, 3 Feb 2025 18:45:14 -0800 Subject: [PATCH 06/11] Support all sg_size and restrict FA to prefill only. On ADL, WU drivers - TTFT for 1K tokens is avg (us): 4.27937e+06. 
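A rough host-side sketch of the K-stepping scheme this patch switches to (plain C++, illustrative only; max_k_step mirrors the shader constant, the rest of the names are invented and this is not the actual kernel). The per-iteration step is the subgroup size capped at max_k_step, so subgroups narrower or wider than 16 lanes still cover the whole sequence:

    #include <algorithm>
    #include <cstdio>

    int main() {
      const unsigned max_k_step = 16;        // register budget allows qk_1..qk_4, i.e. 16 scores per step
      const unsigned present_seq_len = 100;  // total K/V length to cover
      for (unsigned sg_size : {8u, 16u, 32u}) {
        const unsigned capped = std::min(sg_size, max_k_step);
        unsigned iterations = 0;
        for (unsigned k_start = 0; k_start < present_seq_len; k_start += capped) ++iterations;
        std::printf("sg_size=%2u -> step=%2u, iterations=%u\n", sg_size, capped, iterations);
      }
      return 0;
    }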
--- .../webgpu/bert/flash_attention.cc | 246 +++++++++++------- .../contrib_ops/webgpu/bert/flash_attention.h | 2 +- .../webgpu/bert/multihead_attention.cc | 2 +- 3 files changed, 156 insertions(+), 94 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index 706800fe7ec26..a92b00bd7bf88 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -123,14 +123,16 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { << "const num_heads: u32 =" << qkv_num_heads_ << ";\n"; shader.AdditionalImplementation() << R"HELPER_FN( - const k_step: u32 = 16u; + // For max performance max_k_step should be the same as sg_size, however we might run out of registers + // for qk_1, qk_2 .. qk_(sg_size). So we cap it at max_k_step (16). + const max_k_step: u32 = 16u; const vec_factor: u32 = 4u; const qkv_head_size_vec: u32 = qkv_head_size / vec_factor; const min_value : q_element_t = q_element_t(-65504.0h); // Default SHM usage limit is 16KB in Dawn. - var k_tile : array, k_step>; // 96 * 2 * 16 = 3KB. - var v_tile : array, k_step>; // 96 * 2 * 16 = 3KB. + var k_tile : array, max_k_step>; // 96 * 2 * 16 = 3KB. + var v_tile : array, max_k_step>; // 96 * 2 * 16 = 3KB. // Private memory per lane. var q_tile : array; @@ -147,7 +149,7 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { q_tile[idx] = q[idx+offset]; } } - fn loadk(k_start : u32, head_idx: u32, local_idx: u32) + fn loadk(k_start : u32, head_idx: u32, local_idx: u32, k_step: u32) { // Stored as float16[batch_size,num_heads,present_sequence_length,96] let offset = head_idx * uniforms.present_sequence_length * qkv_head_size_vec + k_start * qkv_head_size_vec; @@ -158,7 +160,7 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { k_tile[slot][idx%qkv_head_size_vec] = val; } } - fn loadv(v_start : u32, head_idx: u32, local_idx: u32) + fn loadv(v_start : u32, head_idx: u32, local_idx: u32, k_step: u32) { // Stored as float16[batch_size,num_heads,present_sequence_length,96] let offset = head_idx * uniforms.present_sequence_length * qkv_head_size_vec + v_start * qkv_head_size_vec; @@ -169,21 +171,6 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { v_tile[slot][idx%qkv_head_size_vec] = val; } } - fn loadAttentionBias(q_idx_global : u32, k_idx_global : u32, head_idx: u32) -> vec4 - { - // Stored as float16[batch_size,num_heads,new_seq_length,total_sequence_length] - if (q_idx_global >= uniforms.new_sequence_length || k_idx_global >= uniforms.present_sequence_length) { - return vec4(0); - } - let offset_base = head_idx * uniforms.new_sequence_length * uniforms.present_sequence_length + q_idx_global * uniforms.present_sequence_length; - let offset = offset_base + k_idx_global; - let offset_max = offset_base + uniforms.present_sequence_length; - let c1 = q_element_t(attention_bias[min(offset, offset_max)]); - let c2 = q_element_t(attention_bias[min(offset+1, offset_max)]); - let c3 = q_element_t(attention_bias[min(offset+2, offset_max)]); - let c4 = q_element_t(attention_bias[min(offset+3, offset_max)]); - return vec4(c1,c2,c3,c4); - } fn writeo(o_idx_global: u32, head_idx: u32) { // Stored as float16[batch_size,sequence_length,3072] @@ -195,10 +182,41 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { } )HELPER_FN"; + if (has_attention_bias_) { + 
shader.AdditionalImplementation() << R"HELPER_FN( + fn loadAttentionBias(q_idx_global : u32, k_idx_global : u32, head_idx: u32) -> vec4 + { + // Stored as float16[batch_size,num_heads,new_seq_length,total_sequence_length] + if (q_idx_global >= uniforms.new_sequence_length || k_idx_global >= uniforms.present_sequence_length) { + return vec4(0); + } + let offset_base = head_idx * uniforms.new_sequence_length * uniforms.present_sequence_length + q_idx_global * uniforms.present_sequence_length; + let offset = offset_base + k_idx_global; + let offset_max = offset_base + uniforms.present_sequence_length; + let c1 = q_element_t(attention_bias[min(offset, offset_max)]); + let c2 = q_element_t(attention_bias[min(offset+1, offset_max)]); + let c3 = q_element_t(attention_bias[min(offset+2, offset_max)]); + let c4 = q_element_t(attention_bias[min(offset+3, offset_max)]); + return vec4(c1,c2,c3,c4); + } + )HELPER_FN"; + } + else + { + shader.AdditionalImplementation() << R"HELPER_FN( + fn loadAttentionBias(q_idx_global : u32, k_idx_global : u32, head_idx: u32) -> vec4 + { + return vec4(0); + } + )HELPER_FN"; + } + // Shader is designed to be dispatched as Dispatch(num_heads, new_sequence_length / workgroup_size_x, 1) // Each lane/thread is responsible for a single q. shader.MainFunctionBody() << R"MAIN_FN( let head_idx = workgroup_id.x; + let capped_sg_id = min(sg_id, max_k_step); + let capped_sg_size = min(sg_size, max_k_step); // Load Q let q_idx_global = workgroup_id.y * workgroup_size_x + local_idx; @@ -211,43 +229,68 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { var previous_max : q_element_t = min_value; var previous_denom : q_element_t = 0; - for(var k_start = 0u; k_start < uniforms.present_sequence_length; k_start+=k_step) + for(var k_start = 0u; k_start < uniforms.present_sequence_length; k_start+=capped_sg_size) { workgroupBarrier(); - loadk(k_start, head_idx, local_idx); - loadv(k_start, head_idx, local_idx); + loadk(k_start, head_idx, local_idx, capped_sg_size); + loadv(k_start, head_idx, local_idx, capped_sg_size); workgroupBarrier(); // Compute QKt var qk_1:vec4 = loadAttentionBias(q_idx_global, k_start, head_idx); var qk_2:vec4 = loadAttentionBias(q_idx_global, k_start+4, head_idx); - var qk_3:vec4 = loadAttentionBias(q_idx_global, k_start+8, head_idx); - var qk_4:vec4 = loadAttentionBias(q_idx_global, k_start+12, head_idx); - for (var i:u32 = 0u; i < qkv_head_size_vec; i++) + var qk_3:vec4; + var qk_4:vec4; + if (sg_size > 8) + { + qk_3 = loadAttentionBias(q_idx_global, k_start+8, head_idx); + qk_4 = loadAttentionBias(q_idx_global, k_start+12, head_idx); + for (var i:u32 = 0u; i < qkv_head_size_vec; i++) + { + var k_local = k_tile[capped_sg_id][i]; + var q_own = q_tile[i]; + qk_1[0] += dot(q_own, subgroupShuffle(k_local, 0)); + qk_1[1] += dot(q_own, subgroupShuffle(k_local, 1)); + qk_1[2] += dot(q_own, subgroupShuffle(k_local, 2)); + qk_1[3] += dot(q_own, subgroupShuffle(k_local, 3)); + qk_2[0] += dot(q_own, subgroupShuffle(k_local, 4)); + qk_2[1] += dot(q_own, subgroupShuffle(k_local, 5)); + qk_2[2] += dot(q_own, subgroupShuffle(k_local, 6)); + qk_2[3] += dot(q_own, subgroupShuffle(k_local, 7)); + qk_3[0] += dot(q_own, subgroupShuffle(k_local, 8)); + qk_3[1] += dot(q_own, subgroupShuffle(k_local, 9)); + qk_3[2] += dot(q_own, subgroupShuffle(k_local, 10)); + qk_3[3] += dot(q_own, subgroupShuffle(k_local, 11)); + qk_4[0] += dot(q_own, subgroupShuffle(k_local, 12)); + qk_4[1] += dot(q_own, subgroupShuffle(k_local, 13)); + qk_4[2] += dot(q_own, 
subgroupShuffle(k_local, 14)); + qk_4[3] += dot(q_own, subgroupShuffle(k_local, 15)); + } + } + else { - var k_local = k_tile[sg_id][i]; - var q_own = q_tile[i]; - qk_1[0] += dot(q_own, subgroupShuffle(k_local, 0)); - qk_1[1] += dot(q_own, subgroupShuffle(k_local, 1)); - qk_1[2] += dot(q_own, subgroupShuffle(k_local, 2)); - qk_1[3] += dot(q_own, subgroupShuffle(k_local, 3)); - qk_2[0] += dot(q_own, subgroupShuffle(k_local, 4)); - qk_2[1] += dot(q_own, subgroupShuffle(k_local, 5)); - qk_2[2] += dot(q_own, subgroupShuffle(k_local, 6)); - qk_2[3] += dot(q_own, subgroupShuffle(k_local, 7)); - qk_3[0] += dot(q_own, subgroupShuffle(k_local, 8)); - qk_3[1] += dot(q_own, subgroupShuffle(k_local, 9)); - qk_3[2] += dot(q_own, subgroupShuffle(k_local, 10)); - qk_3[3] += dot(q_own, subgroupShuffle(k_local, 11)); - qk_4[0] += dot(q_own, subgroupShuffle(k_local, 12)); - qk_4[1] += dot(q_own, subgroupShuffle(k_local, 13)); - qk_4[2] += dot(q_own, subgroupShuffle(k_local, 14)); - qk_4[3] += dot(q_own, subgroupShuffle(k_local, 15)); + for (var i:u32 = 0u; i < qkv_head_size_vec; i++) + { + var k_local = k_tile[capped_sg_id][i]; + var q_own = q_tile[i]; + qk_1[0] += dot(q_own, subgroupShuffle(k_local, 0)); + qk_1[1] += dot(q_own, subgroupShuffle(k_local, 1)); + qk_1[2] += dot(q_own, subgroupShuffle(k_local, 2)); + qk_1[3] += dot(q_own, subgroupShuffle(k_local, 3)); + qk_2[0] += dot(q_own, subgroupShuffle(k_local, 4)); + qk_2[1] += dot(q_own, subgroupShuffle(k_local, 5)); + qk_2[2] += dot(q_own, subgroupShuffle(k_local, 6)); + qk_2[3] += dot(q_own, subgroupShuffle(k_local, 7)); + } } + qk_1 = qk_1 * q_element_t(uniforms.alpha); qk_2 = qk_2 * q_element_t(uniforms.alpha); - qk_3 = qk_3 * q_element_t(uniforms.alpha); - qk_4 = qk_4 * q_element_t(uniforms.alpha); + if (sg_size > 8) + { + qk_3 = qk_3 * q_element_t(uniforms.alpha); + qk_4 = qk_4 * q_element_t(uniforms.alpha); + } // Neuter qk values where K is out of bounds. 
qk_1[1] = select(min_value, qk_1[1], k_start+1 < uniforms.present_sequence_length); @@ -257,14 +300,17 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { qk_2[1] = select(min_value, qk_2[1], k_start+5 < uniforms.present_sequence_length); qk_2[2] = select(min_value, qk_2[2], k_start+6 < uniforms.present_sequence_length); qk_2[3] = select(min_value, qk_2[3], k_start+7 < uniforms.present_sequence_length); - qk_3[0] = select(min_value, qk_3[0], k_start+8 < uniforms.present_sequence_length); - qk_3[1] = select(min_value, qk_3[1], k_start+9 < uniforms.present_sequence_length); - qk_3[2] = select(min_value, qk_3[2], k_start+10 < uniforms.present_sequence_length); - qk_3[3] = select(min_value, qk_3[3], k_start+11 < uniforms.present_sequence_length); - qk_4[0] = select(min_value, qk_4[0], k_start+12 < uniforms.present_sequence_length); - qk_4[1] = select(min_value, qk_4[1], k_start+13 < uniforms.present_sequence_length); - qk_4[2] = select(min_value, qk_4[2], k_start+14 < uniforms.present_sequence_length); - qk_4[3] = select(min_value, qk_4[3], k_start+15 < uniforms.present_sequence_length); + if (sg_size > 8) + { + qk_3[0] = select(min_value, qk_3[0], k_start+8 < uniforms.present_sequence_length); + qk_3[1] = select(min_value, qk_3[1], k_start+9 < uniforms.present_sequence_length); + qk_3[2] = select(min_value, qk_3[2], k_start+10 < uniforms.present_sequence_length); + qk_3[3] = select(min_value, qk_3[3], k_start+11 < uniforms.present_sequence_length); + qk_4[0] = select(min_value, qk_4[0], k_start+12 < uniforms.present_sequence_length); + qk_4[1] = select(min_value, qk_4[1], k_start+13 < uniforms.present_sequence_length); + qk_4[2] = select(min_value, qk_4[2], k_start+14 < uniforms.present_sequence_length); + qk_4[3] = select(min_value, qk_4[3], k_start+15 < uniforms.present_sequence_length); + } // // Compute SoftMax as per Flash Attention technique. @@ -290,14 +336,19 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { // o_ratio is the part of the first term of o'_i expression above : d'_(i-1) * e^(M_(i-1)-M_i) / d'_i // var local_max_temp = max(qk_1, qk_2); - local_max_temp = max(local_max_temp, qk_3); - local_max_temp = max(local_max_temp, qk_4); + if (sg_size > 8) + { + local_max_temp = max(local_max_temp, qk_3); + local_max_temp = max(local_max_temp, qk_4); + } let local_max = max(max(local_max_temp.x, local_max_temp.y),max(local_max_temp.z, local_max_temp.w)); let new_max = max(previous_max, local_max); qk_1 = q_value_t(exp(vec4(qk_1) - f32(new_max))); qk_2 = q_value_t(exp(vec4(qk_2) - f32(new_max))); - qk_3 = q_value_t(exp(vec4(qk_3) - f32(new_max))); - qk_4 = q_value_t(exp(vec4(qk_4) - f32(new_max))); + if (sg_size > 8) { + qk_3 = q_value_t(exp(vec4(qk_3) - f32(new_max))); + qk_4 = q_value_t(exp(vec4(qk_4) - f32(new_max))); + } let sum_vec = qk_1 + qk_2 + qk_3 + qk_4; let sum = sum_vec.x + sum_vec.y + sum_vec.z + sum_vec.w; // Compute lhs term of update di prime and the compute di prime. 
@@ -306,32 +357,52 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { d = select(d,q_element_t(0.0000001h),d==0); qk_1 = qk_1 / d; qk_2 = qk_2 / d; - qk_3 = qk_3 / d; - qk_4 = qk_4 / d; + if (sg_size > 8) { + qk_3 = qk_3 / d; + qk_4 = qk_4 / d; + } previous_max = new_max; previous_denom = d; let o_ratio = dleft / d; - for (var i:u32 = 0; i < qkv_head_size_vec; i++) + if (sg_size > 8) { + for (var i:u32 = 0; i < qkv_head_size_vec; i++) + { + var val = select(vec4(0), v_tile[capped_sg_id][i], k_start + capped_sg_id < uniforms.present_sequence_length); + var sum = subgroupShuffle(val, 0) * qk_1[0]; + sum += subgroupShuffle(val, 1) * qk_1[1]; + sum += subgroupShuffle(val, 2) * qk_1[2]; + sum += subgroupShuffle(val, 3) * qk_1[3]; + sum += subgroupShuffle(val, 4) * qk_2[0]; + sum += subgroupShuffle(val, 5) * qk_2[1]; + sum += subgroupShuffle(val, 6) * qk_2[2]; + sum += subgroupShuffle(val, 7) * qk_2[3]; + sum += subgroupShuffle(val, 8) * qk_3[0]; + sum += subgroupShuffle(val, 9) * qk_3[1]; + sum += subgroupShuffle(val, 10) * qk_3[2]; + sum += subgroupShuffle(val, 11) * qk_3[3]; + sum += subgroupShuffle(val, 12) * qk_4[0]; + sum += subgroupShuffle(val, 13) * qk_4[1]; + sum += subgroupShuffle(val, 14) * qk_4[2]; + sum += subgroupShuffle(val, 15) * qk_4[3]; + o_tile[i] = o_tile[i] * o_ratio + sum; + } + } + else { - var val = select(vec4(0), v_tile[sg_id][i], k_start + sg_id < uniforms.present_sequence_length); - var sum = subgroupShuffle(val, 0) * qk_1[0]; - sum += subgroupShuffle(val, 1) * qk_1[1]; - sum += subgroupShuffle(val, 2) * qk_1[2]; - sum += subgroupShuffle(val, 3) * qk_1[3]; - sum += subgroupShuffle(val, 4) * qk_2[0]; - sum += subgroupShuffle(val, 5) * qk_2[1]; - sum += subgroupShuffle(val, 6) * qk_2[2]; - sum += subgroupShuffle(val, 7) * qk_2[3]; - sum += subgroupShuffle(val, 8) * qk_3[0]; - sum += subgroupShuffle(val, 9) * qk_3[1]; - sum += subgroupShuffle(val, 10) * qk_3[2]; - sum += subgroupShuffle(val, 11) * qk_3[3]; - sum += subgroupShuffle(val, 12) * qk_4[0]; - sum += subgroupShuffle(val, 13) * qk_4[1]; - sum += subgroupShuffle(val, 14) * qk_4[2]; - sum += subgroupShuffle(val, 15) * qk_4[3]; - o_tile[i] = o_tile[i] * o_ratio + sum; + for (var i:u32 = 0; i < qkv_head_size_vec; i++) + { + var val = select(vec4(0), v_tile[capped_sg_id][i], k_start + capped_sg_id < uniforms.present_sequence_length); + var sum = subgroupShuffle(val, 0) * qk_1[0]; + sum += subgroupShuffle(val, 1) * qk_1[1]; + sum += subgroupShuffle(val, 2) * qk_1[2]; + sum += subgroupShuffle(val, 3) * qk_1[3]; + sum += subgroupShuffle(val, 4) * qk_2[0]; + sum += subgroupShuffle(val, 5) * qk_2[1]; + sum += subgroupShuffle(val, 6) * qk_2[2]; + sum += subgroupShuffle(val, 7) * qk_2[3]; + o_tile[i] = o_tile[i] * o_ratio + sum; + } } } @@ -362,7 +433,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co std::to_string(parameters.head_size_) + std::to_string(parameters.num_heads_); program.SetDispatchGroupSize(parameters.num_heads_, (parameters.sequence_length_ + tile_size - 1) / tile_size, 1) - .SetWorkgroupSize(64) + .SetWorkgroupSize(tile_size) .CacheHint(cache_hint) .AddUniformVariables({{static_cast(parameters.sequence_length_)}, {static_cast(parameters.total_sequence_length_)}, @@ -372,19 +443,10 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co } bool CanApplyFlashAttention(const Tensor* bias, const Tensor* present_key, const Tensor* present_value, - const WebgpuAttentionParameters& parameters, 
onnxruntime::webgpu::ComputeContext& context) { - // The min subgroup size affects the block size while going through the sequence length. - // 16 is the smallest size tested, smaller sized would impact performance. - // Checking for this also ensures that we dont run flash attention where subgroup is not supported. - constexpr int kMinSupportedSubgroupSize = 16; - // Workgroup size is set to be (subgroup_size * subgroup_size), check that it is allowed. - // Flash attention is written only to support batch_size of 1, algorithm can be extended to support - // batch_size > 1. What bias is used for is not clear, so it is not implemented in the shader. - // The Flash attention implementation is vectorized, to keep things simple, only vec4 is implemented - - // this implies that head_size has to be a multiple of 4. - return context.DeviceLimits().maxComputeWorkgroupSizeX >= (kMinSupportedSubgroupSize * kMinSupportedSubgroupSize) && - parameters.batch_size_ == 1 && + const WebgpuAttentionParameters& parameters) { + return parameters.batch_size_ == 1 && bias == nullptr && + parameters.sequence_length_ > 1 && present_key != nullptr && present_value != nullptr && present_key->SizeInBytes() > 0 && present_value->SizeInBytes() > 0 && parameters.head_size_ % 4 == 0; } diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h index c80b4d3f3fc39..801ae42d864e6 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h @@ -61,7 +61,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context); bool CanApplyFlashAttention(const Tensor* bias, const Tensor* present_key, const Tensor* present_value, - const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context); + const WebgpuAttentionParameters& parameters); } // namespace webgpu } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc index ffa0f56ca126b..0caac58247ece 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc @@ -75,7 +75,7 @@ Status MultiHeadAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& Tensor* present_key = context.Output(1, present_shape); Tensor* present_value = context.Output(2, present_shape); - if (CanApplyFlashAttention(bias, present_key, present_value, parameters, context)) { + if (CanApplyFlashAttention(bias, present_key, present_value, parameters)) { return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value, present_value, parameters, context); } From 362e969d4493af2a05e20735601a1e9a6e919840 Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar Date: Mon, 3 Feb 2025 20:39:57 -0800 Subject: [PATCH 07/11] lint runner --- onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc | 6 ++---- onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc | 4 ++-- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index a92b00bd7bf88..98f0263cd1041 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -120,7 
+120,7 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { shader.AddOutput("output", ShaderUsage::UseUniform); shader.AdditionalImplementation() << "const qkv_head_size: u32 = " << qkv_head_size_ << ";\n" - << "const num_heads: u32 =" << qkv_num_heads_ << ";\n"; + << "const num_heads: u32 =" << qkv_num_heads_ << ";\n"; shader.AdditionalImplementation() << R"HELPER_FN( // For max performance max_k_step should be the same as sg_size, however we might run out of registers @@ -200,9 +200,7 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { return vec4(c1,c2,c3,c4); } )HELPER_FN"; - } - else - { + } else { shader.AdditionalImplementation() << R"HELPER_FN( fn loadAttentionBias(q_idx_global : u32, k_idx_global : u32, head_idx: u32) -> vec4 { diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc index 0caac58247ece..8b8f0bd71917e 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc @@ -76,8 +76,8 @@ Status MultiHeadAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& Tensor* present_value = context.Output(2, present_shape); if (CanApplyFlashAttention(bias, present_key, present_value, parameters)) { - return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value, - present_value, parameters, context); + return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value, + present_value, parameters, context); } TensorShapeVector q_new_dims({parameters.batch_size_, parameters.num_heads_, From 4ab90c31b1e782f120a8604f0b9884262bb3cd86 Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar Date: Wed, 5 Feb 2025 17:33:37 -0800 Subject: [PATCH 08/11] Remove components --- onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc | 2 +- onnxruntime/contrib_ops/webgpu/bert/flash_attention.h | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index 98f0263cd1041..a349a1cba35d4 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -73,7 +73,7 @@ Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtt // number of input buffers in the shader, which we run out of (<=8) without this optimization. const int components = parameters.head_size_ % 4 == 0 ? 4 : (parameters.head_size_ % 2 == 0 ? 
2 : 1); bool has_past = (past_sequence_length != 0); - CopyKVCacheProgram program{"CopyKVCache", components, has_past}; + CopyKVCacheProgram program{"CopyKVCache", has_past}; if (has_past) { program.AddInputs({{K, ProgramTensorMetadataDependency::TypeAndRank, components}, {V, ProgramTensorMetadataDependency::TypeAndRank, components}, diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h index 801ae42d864e6..6b902651c2d6d 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h @@ -17,8 +17,8 @@ using namespace onnxruntime::webgpu; class CopyKVCacheProgram final : public Program { public: - CopyKVCacheProgram(const std::string& kernel_name, int components, bool has_past) - : Program{kernel_name}, components_(components), has_past_(has_past) { + CopyKVCacheProgram(const std::string& kernel_name, bool has_past) + : Program{kernel_name}, has_past_(has_past) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -28,7 +28,6 @@ class CopyKVCacheProgram final : public Program { {"vectorized_head_size", ProgramUniformVariableDataType::Uint32}); private: - int components_; bool has_past_; }; From 635cd2198bf169b7d10b4b371b4c48236354c3bc Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar Date: Wed, 5 Feb 2025 20:14:01 -0800 Subject: [PATCH 09/11] remove half float notation from constants --- onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index a349a1cba35d4..6ca765be36be8 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -128,7 +128,7 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { const max_k_step: u32 = 16u; const vec_factor: u32 = 4u; const qkv_head_size_vec: u32 = qkv_head_size / vec_factor; - const min_value : q_element_t = q_element_t(-65504.0h); + const min_value : q_element_t = q_element_t(-65504.0); // Default SHM usage limit is 16KB in Dawn. var k_tile : array, max_k_step>; // 96 * 2 * 16 = 3KB. @@ -352,7 +352,7 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { // Compute lhs term of update di prime and the compute di prime. let dleft = previous_denom * exp(previous_max-new_max); var d = dleft + sum; - d = select(d,q_element_t(0.0000001h),d==0); + d = select(d,q_element_t(0.0000001),d==0); qk_1 = qk_1 / d; qk_2 = qk_2 / d; if (sg_size > 8) { From f289c64420cd74514d426caf5d9625f783ebde5f Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar Date: Wed, 5 Feb 2025 23:52:02 -0800 Subject: [PATCH 10/11] Fix Attention bias. 
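In scalar terms, the change below moves the attention bias out of the accumulator initializer and adds it after the alpha scaling, so the score becomes alpha * QK^T + bias rather than alpha * (QK^T + bias). A minimal illustrative sketch, with hypothetical names standing in for one lane's scalar score (not the real shader variables):

    // Hedged scalar sketch of the hunk below; qk_dot, bias and alpha are illustrative.
    float ScoreBefore(float qk_dot, float bias, float alpha) {
      // Bias used to seed the accumulator, so it was scaled by alpha as well.
      return (qk_dot + bias) * alpha;
    }
    float ScoreAfter(float qk_dot, float bias, float alpha) {
      // Bias is now applied once, after scaling.
      return qk_dot * alpha + bias;
    }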
--- .../contrib_ops/webgpu/bert/flash_attention.cc | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index 6ca765be36be8..4092f678f94c9 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -235,14 +235,12 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { workgroupBarrier(); // Compute QKt - var qk_1:vec4 = loadAttentionBias(q_idx_global, k_start, head_idx); - var qk_2:vec4 = loadAttentionBias(q_idx_global, k_start+4, head_idx); + var qk_1:vec4; + var qk_2:vec4; var qk_3:vec4; var qk_4:vec4; if (sg_size > 8) { - qk_3 = loadAttentionBias(q_idx_global, k_start+8, head_idx); - qk_4 = loadAttentionBias(q_idx_global, k_start+12, head_idx); for (var i:u32 = 0u; i < qkv_head_size_vec; i++) { var k_local = k_tile[capped_sg_id][i]; @@ -282,12 +280,12 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { } } - qk_1 = qk_1 * q_element_t(uniforms.alpha); - qk_2 = qk_2 * q_element_t(uniforms.alpha); + qk_1 = qk_1 * q_element_t(uniforms.alpha) + loadAttentionBias(q_idx_global, k_start, head_idx); + qk_2 = qk_2 * q_element_t(uniforms.alpha) + loadAttentionBias(q_idx_global, k_start+4, head_idx); if (sg_size > 8) { - qk_3 = qk_3 * q_element_t(uniforms.alpha); - qk_4 = qk_4 * q_element_t(uniforms.alpha); + qk_3 = qk_3 * q_element_t(uniforms.alpha) + loadAttentionBias(q_idx_global, k_start+8, head_idx); + qk_4 = qk_4 * q_element_t(uniforms.alpha) + loadAttentionBias(q_idx_global, k_start+12, head_idx); } // Neuter qk values where K is out of bounds. From a5fd8a67e0a3438cda6a972936372e1502035226 Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar Date: Thu, 6 Feb 2025 09:53:15 -0800 Subject: [PATCH 11/11] exclude fa on devices without subgroups --- onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc | 3 ++- onnxruntime/contrib_ops/webgpu/bert/flash_attention.h | 2 +- onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index 4092f678f94c9..b51c2fbe27e1d 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -439,10 +439,11 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co } bool CanApplyFlashAttention(const Tensor* bias, const Tensor* present_key, const Tensor* present_value, - const WebgpuAttentionParameters& parameters) { + const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) { return parameters.batch_size_ == 1 && bias == nullptr && parameters.sequence_length_ > 1 && + context.Device().HasFeature(wgpu::FeatureName::Subgroups) && present_key != nullptr && present_value != nullptr && present_key->SizeInBytes() > 0 && present_value->SizeInBytes() > 0 && parameters.head_size_ % 4 == 0; } diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h index 6b902651c2d6d..489ae7375ecc3 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h @@ -60,7 +60,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co const WebgpuAttentionParameters& 
parameters, onnxruntime::webgpu::ComputeContext& context); bool CanApplyFlashAttention(const Tensor* bias, const Tensor* present_key, const Tensor* present_value, - const WebgpuAttentionParameters& parameters); + const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context); } // namespace webgpu } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc index 8b8f0bd71917e..72931a7310a75 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc @@ -75,7 +75,7 @@ Status MultiHeadAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& Tensor* present_key = context.Output(1, present_shape); Tensor* present_value = context.Output(2, present_shape); - if (CanApplyFlashAttention(bias, present_key, present_value, parameters)) { + if (CanApplyFlashAttention(bias, present_key, present_value, parameters, context)) { return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value, present_value, parameters, context); }
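With the subgroup feature check added above, the eligibility gate at the end of this series can be summarized as below. This is a condensed restatement of CanApplyFlashAttention from the hunks above, with explanatory comments added; it introduces no new conditions:

    bool CanApplyFlashAttention(const Tensor* bias, const Tensor* present_key, const Tensor* present_value,
                                const WebgpuAttentionParameters& parameters,
                                onnxruntime::webgpu::ComputeContext& context) {
      return parameters.batch_size_ == 1 &&                               // shader supports batch_size 1 only
             bias == nullptr &&                                           // the bias input is not supported (attention_bias is handled separately)
             parameters.sequence_length_ > 1 &&                           // prefill only, per patch 06
             context.Device().HasFeature(wgpu::FeatureName::Subgroups) && // subgroupShuffle is required, per patch 11
             present_key != nullptr && present_value != nullptr &&
             present_key->SizeInBytes() > 0 && present_value->SizeInBytes() > 0 &&
             parameters.head_size_ % 4 == 0;                              // vectorized (vec4) loads need head_size % 4 == 0
    }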