From 1aff4d49463a3c0421d2acf86390d09cde587cc5 Mon Sep 17 00:00:00 2001
From: Satya Kumar Jandhyala
Date: Sun, 1 Dec 2024 21:20:52 -0800
Subject: [PATCH] Revert "Reapply "Calculate output chunk size based on whether the kernel is GQA or not.""

This reverts commit b494c732515976098e9cf8ec0180960fa245c872.
---
 onnxruntime/contrib_ops/webgpu/bert/attention.cc | 11 +++++------
 onnxruntime/contrib_ops/webgpu/bert/attention.h  | 10 ++++------
 2 files changed, 9 insertions(+), 12 deletions(-)

diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc
index 2a839d2966976..86dc959cf2e83 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc
+++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc
@@ -153,7 +153,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
                             << "}\n";
 
   shader.MainFunctionBody() << "if (global_id.y < uniforms.M && global_id.x < total_sequence_length) {\n"
-                            << "  let headOffset = workgroup_id.z * uniforms.M * " << (is_gqa_ ? "uniforms.present_sequence_length" : "uniforms.N") << ";\n"
+                            << "  let headOffset = workgroup_id.z * uniforms.M * uniforms.N;\n"
                             << "  let outputIdx = headOffset + global_id.y * uniforms.N + global_id.x;\n"
                             << "  var sum: f32 = " << (components_ == 4 ? "value.x + value.y + value.z + value.w" : (components_ == 2 ? "value.x + value.y" : "value")) << ";\n";
 
@@ -181,7 +181,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
   const int components = parameters.head_size_ % 4 == 0 ? 4 : (parameters.head_size_ % 2 == 0 ? 2 : 1);
 
   AttentionProbsProgram program{"AttentionProbs", feed_past_key, has_present_key, has_attention_bias, tile_size,
-                                components, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_, parameters.is_gqa_};
+                                components, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_};
   program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components},
                      {K, ProgramTensorMetadataDependency::TypeAndRank, components}});
   if (feed_past_key) {
@@ -375,9 +375,8 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {
 
   shader.MainFunctionBody() << "// we need to transpose output from BNSH_v to BSND_v\n"
                             << "if (m < uniforms.M && n < uniforms.N) {\n"
-                            << "  let tmp = " << (is_gqa_ ? "uniforms.num_heads * uniforms.present_sequence_length" : "uniforms.v_hidden_size") << ";\n"
-                            << "  let outputIdx = batch_idx * uniforms.M * tmp + "
-                            << " m * tmp + head_idx * uniforms.N + n;\n"
+                            << "  let outputIdx = batch_idx * uniforms.M * uniforms.v_hidden_size + "
+                            << " m * uniforms.v_hidden_size + head_idx * uniforms.N + n;\n"
                             << "  output[outputIdx] = value;\n"
                             << "}\n";
 
@@ -398,7 +397,7 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int
   const bool has_present_value = output_count > 1 && past_value != nullptr;
 
   constexpr int tile_size = 12;
-  VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_, parameters.is_gqa_};
+  VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_};
   program.AddInputs({{probs, ProgramTensorMetadataDependency::TypeAndRank},
                      {V, ProgramTensorMetadataDependency::TypeAndRank}});
   if (feed_past_value) {
diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.h b/onnxruntime/contrib_ops/webgpu/bert/attention.h
index 350f2387920f0..03279fffbc3ef 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/attention.h
+++ b/onnxruntime/contrib_ops/webgpu/bert/attention.h
@@ -34,8 +34,8 @@ class TransferBSDToBNSHProgram final : public Program<TransferBSDToBNSHProgram>
 class AttentionProbsProgram final : public Program<AttentionProbsProgram> {
  public:
   AttentionProbsProgram(const std::string& kernel_name, bool feed_past_key, bool has_present_key,
-                        bool has_attention_bias, int tile_size, int components, bool is_first_prompt, int n_reps, const Tensor* seqlen_k, bool past_present_share_buffer, bool is_gqa)
-      : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt), is_gqa_(is_gqa) {
+                        bool has_attention_bias, int tile_size, int components, bool is_first_prompt, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false)
+      : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) {
   }
 
   Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -63,7 +63,6 @@ class AttentionProbsProgram final : public Program<AttentionProbsProgram> {
   const Tensor* seqlen_k_;
   bool past_present_share_buffer_;
   bool is_first_prompt_;
-  bool is_gqa_;
 };
 
 class InPlaceSoftmaxProgram final : public Program<InPlaceSoftmaxProgram> {
@@ -90,8 +89,8 @@ class InPlaceSoftmaxProgram final : public Program<InPlaceSoftmaxProgram> {
 
 class VxAttentionScoreProgram final : public Program<VxAttentionScoreProgram> {
  public:
-  VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, bool is_first_prompt, int n_reps, const Tensor* seqlen_k, bool past_present_share_buffer, bool is_gqa)
-      : Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt), is_gqa_(is_gqa) {
+  VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, bool is_first_prompt, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false)
+      : Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) {
   }
 
   Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -117,7 +116,6 @@ class VxAttentionScoreProgram final : public Program<VxAttentionScoreProgram> {
   const Tensor* seqlen_k_;
   bool past_present_share_buffer_;
   bool is_first_prompt_;
-  bool is_gqa_;
 };
 
 }  // namespace webgpu
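
Note (illustrative sketch, not part of the patch): the revert swaps a GQA-aware output stride back to a fixed one. Below is a minimal standalone C++ sketch of the two index computations the diff toggles between. Names mirror the shader uniforms (M, N, v_hidden_size, num_heads, present_sequence_length); the helper functions themselves are hypothetical and exist only to spell out the arithmetic.

#include <cstdint>

// AttentionProbs output index, as restored by this revert:
// every head owns a fixed M x N chunk.
uint32_t ProbsOffsetReverted(uint32_t head_idx, uint32_t M, uint32_t N,
                             uint32_t row, uint32_t col) {
  uint32_t head_offset = head_idx * M * N;
  return head_offset + row * N + col;
}

// AttentionProbs output index under the reverted-away change:
// GQA kernels sized the per-head chunk by present_sequence_length.
uint32_t ProbsOffsetGqaAware(uint32_t head_idx, uint32_t M, uint32_t N,
                             uint32_t present_sequence_length, bool is_gqa,
                             uint32_t row, uint32_t col) {
  uint32_t chunk = is_gqa ? present_sequence_length : N;
  uint32_t head_offset = head_idx * M * chunk;
  return head_offset + row * N + col;
}

// VxAttentionScore output index (BNSH_v -> BSND_v transpose), as restored
// by this revert: rows are strided by v_hidden_size on all paths.
uint32_t ScoreOffsetReverted(uint32_t batch_idx, uint32_t head_idx,
                             uint32_t M, uint32_t N, uint32_t v_hidden_size,
                             uint32_t m, uint32_t n) {
  return batch_idx * M * v_hidden_size + m * v_hidden_size + head_idx * N + n;
}

// VxAttentionScore output index under the reverted-away change: GQA kernels
// used num_heads * present_sequence_length as the row stride instead.
uint32_t ScoreOffsetGqaAware(uint32_t batch_idx, uint32_t head_idx,
                             uint32_t M, uint32_t N, uint32_t v_hidden_size,
                             uint32_t num_heads,
                             uint32_t present_sequence_length, bool is_gqa,
                             uint32_t m, uint32_t n) {
  uint32_t stride = is_gqa ? num_heads * present_sequence_length
                           : v_hidden_size;
  return batch_idx * M * stride + m * stride + head_idx * N + n;
}

For non-GQA attention the two schemes produce identical offsets, since the GQA-aware variants only change the stride when is_gqa is true; the revert therefore only affects the output layout of the GQA path.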