Skip to content

Commit

Permalink
Revert "Reapply "Calculate output chunk size based on whether the kernel is GQA or not.""
Browse files Browse the repository at this point in the history

This reverts commit b494c73.
  • Loading branch information
satyajandhyala committed Dec 2, 2024
1 parent f209a38 commit 1aff4d4
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 12 deletions.
11 changes: 5 additions & 6 deletions onnxruntime/contrib_ops/webgpu/bert/attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
<< "}\n";

shader.MainFunctionBody() << "if (global_id.y < uniforms.M && global_id.x < total_sequence_length) {\n"
<< " let headOffset = workgroup_id.z * uniforms.M * " << (is_gqa_ ? "uniforms.present_sequence_length" : "uniforms.N") << ";\n"
<< " let headOffset = workgroup_id.z * uniforms.M * uniforms.N;\n"
<< " let outputIdx = headOffset + global_id.y * uniforms.N + global_id.x;\n"
<< " var sum: f32 = " << (components_ == 4 ? "value.x + value.y + value.z + value.w" : (components_ == 2 ? "value.x + value.y" : "value")) << ";\n";

Expand Down Expand Up @@ -181,7 +181,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
const int components = parameters.head_size_ % 4 == 0 ? 4 : (parameters.head_size_ % 2 == 0 ? 2 : 1);

AttentionProbsProgram program{"AttentionProbs", feed_past_key, has_present_key, has_attention_bias, tile_size,
components, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_, parameters.is_gqa_};
components, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_};
program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components},
{K, ProgramTensorMetadataDependency::TypeAndRank, components}});
if (feed_past_key) {
Expand Down Expand Up @@ -375,9 +375,8 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {

shader.MainFunctionBody() << "// we need to transpose output from BNSH_v to BSND_v\n"
<< "if (m < uniforms.M && n < uniforms.N) {\n"
<< " let tmp = " << (is_gqa_ ? "uniforms.num_heads * uniforms.present_sequence_length" : "uniforms.v_hidden_size") << ";\n"
<< " let outputIdx = batch_idx * uniforms.M * tmp + "
<< " m * tmp + head_idx * uniforms.N + n;\n"
<< " let outputIdx = batch_idx * uniforms.M * uniforms.v_hidden_size + "
<< " m * uniforms.v_hidden_size + head_idx * uniforms.N + n;\n"
<< " output[outputIdx] = value;\n"
<< "}\n";

Expand All @@ -398,7 +397,7 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int
const bool has_present_value = output_count > 1 && past_value != nullptr;
constexpr int tile_size = 12;

VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_, parameters.is_gqa_};
VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_};
program.AddInputs({{probs, ProgramTensorMetadataDependency::TypeAndRank},
{V, ProgramTensorMetadataDependency::TypeAndRank}});
if (feed_past_value) {
Expand Down
10 changes: 4 additions & 6 deletions onnxruntime/contrib_ops/webgpu/bert/attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ class TransferBSDToBNSHProgram final : public Program<TransferBSDToBNSHProgram>
class AttentionProbsProgram final : public Program<AttentionProbsProgram> {
public:
AttentionProbsProgram(const std::string& kernel_name, bool feed_past_key, bool has_present_key,
bool has_attention_bias, int tile_size, int components, bool is_first_prompt, int n_reps, const Tensor* seqlen_k, bool past_present_share_buffer, bool is_gqa)
: Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt), is_gqa_(is_gqa) {
bool has_attention_bias, int tile_size, int components, bool is_first_prompt, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false)
: Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;
Expand Down Expand Up @@ -63,7 +63,6 @@ class AttentionProbsProgram final : public Program<AttentionProbsProgram> {
const Tensor* seqlen_k_;
bool past_present_share_buffer_;
bool is_first_prompt_;
bool is_gqa_;
};

class InPlaceSoftmaxProgram final : public Program<InPlaceSoftmaxProgram> {
Expand All @@ -90,8 +89,8 @@ class InPlaceSoftmaxProgram final : public Program<InPlaceSoftmaxProgram> {

class VxAttentionScoreProgram final : public Program<VxAttentionScoreProgram> {
public:
VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, bool is_first_prompt, int n_reps, const Tensor* seqlen_k, bool past_present_share_buffer, bool is_gqa)
: Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt), is_gqa_(is_gqa) {
VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, bool is_first_prompt, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false)

Check warning on line 92 in onnxruntime/contrib_ops/webgpu/bert/attention.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/contrib_ops/webgpu/bert/attention.h:92: Add #include <string> for string [build/include_what_you_use] [4]
: Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;
Expand All @@ -117,7 +116,6 @@ class VxAttentionScoreProgram final : public Program<VxAttentionScoreProgram> {
const Tensor* seqlen_k_;
bool past_present_share_buffer_;
bool is_first_prompt_;
bool is_gqa_;
};

} // namespace webgpu
Expand Down

0 comments on commit 1aff4d4

Please sign in to comment.