[Native WebGPU EP] Add packedQKV and do_rotary attribute support to GroupQueryAttention operator #23386

Open · wants to merge 24 commits into base: main

Commits (24)
fdd5ceb
Added GroupQueryAttention do_rotary attribute.
satyajandhyala Jan 15, 2025
f6b0222
Added packed QKV and rotary embedding support for GQA
satyajandhyala Jan 16, 2025
ae87526
Fix lint errors.
satyajandhyala Jan 16, 2025
df90ffa
Fixed shader code compilation errors.
satyajandhyala Jan 16, 2025
0704462
more lint stuff
satyajandhyala Jan 16, 2025
177f535
Fixed shader code issues.
satyajandhyala Jan 17, 2025
0b94f10
Added split functionality to unpack packed-QKV.
satyajandhyala Jan 21, 2025
f0d238a
Removed unnecessary uniforms in GeneratePositionIdsProgram
satyajandhyala Jan 21, 2025
1009fc9
Apply split and rotary embedding before converting input from BSD to BNSH
satyajandhyala Jan 21, 2025
a4d8482
Fix the input_output_stride for 4-dim input.
satyajandhyala Jan 21, 2025
4d42d06
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
satyajandhyala Jan 22, 2025
e406b81
Allocate position_ids tensor size/shape even for the first prompt
satyajandhyala Feb 3, 2025
9e38efd
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
satyajandhyala Feb 5, 2025
de0d4b0
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
satyajandhyala Feb 12, 2025
0b08117
Fixed the input_output_strides
satyajandhyala Feb 18, 2025
a7328f5
Added is_first_prompt to the shader that generates position ids…
satyajandhyala Feb 19, 2025
531c6e3
Fixed position_ids generation code.
satyajandhyala Feb 20, 2025
29819ed
Check is_first_prompt and is_subsequence_prompt flags in the C++ code…
satyajandhyala Feb 20, 2025
91e8801
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
satyajandhyala Feb 20, 2025
6bbef62
lint
satyajandhyala Feb 21, 2025
ff84b7b
Removed unused variable.
satyajandhyala Feb 21, 2025
d4e4f29
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
satyajandhyala Feb 26, 2025
e468128
Added condition to check do_rotary before calling fa2
satyajandhyala Feb 26, 2025
9f2782c
typo
satyajandhyala Feb 26, 2025
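Background for the shader changes below: with packed QKV, the query, key, and value heads for each token live back to back in the single q input, so the kernels compute per-head offsets into that one buffer instead of reading separate key and value tensors. The following host-side sketch mirrors the offset arithmetic the new WGSL emits; the struct and function names are illustrative only (not from this PR), and it assumes the prompt case where kv_sequence_length equals the query sequence length M.

// Illustrative mirror of the packed-QKV offsets computed in the generated
// WGSL (hypothetical names; M = sequence length, K = head size in vector
// components, n_reps = num_heads / kv_num_heads for grouped-query attention).
#include <cstdint>

struct PackedQkvOffsets {
  uint32_t q_offset;
  uint32_t k_offset;
  uint32_t v_offset;
};

PackedQkvOffsets ComputePackedQkvOffsets(uint32_t batch_idx, uint32_t head_idx,
                                         uint32_t num_heads, uint32_t n_reps,
                                         uint32_t M, uint32_t K) {
  const uint32_t kv_num_heads = num_heads / n_reps;
  // One batch of packed QKV holds num_heads Q heads plus kv_num_heads K heads
  // and kv_num_heads V heads, each of M * K elements.
  const uint32_t packed_batch_stride = (num_heads + 2 * kv_num_heads) * M * K;
  const uint32_t kv_head_idx = head_idx % kv_num_heads;  // Q heads share K/V heads
  return PackedQkvOffsets{
      batch_idx * packed_batch_stride + head_idx * M * K,                            // Q
      batch_idx * packed_batch_stride + (num_heads + kv_head_idx) * M * K,           // K
      batch_idx * packed_batch_stride + (num_heads + kv_num_heads + kv_head_idx) * M * K  // V
  };
}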
62 changes: 41 additions & 21 deletions onnxruntime/contrib_ops/webgpu/bert/attention.cc
@@ -80,7 +80,9 @@ void InitVarStub(std::ostringstream& ss, const Tensor* seqlen_k) {

Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
shader.AddInput("q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
shader.AddInput("key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
if (!is_packed_qkv_) {
shader.AddInput("key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
}
if (feed_past_key_) {
shader.AddInput("past_key", ShaderUsage::UseUniform);
}
@@ -96,42 +98,51 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
}

shader.AdditionalImplementation() << "var<workgroup> tileQ: array<q_value_t, " << tile_size_ * tile_size_ << ">;\n"
<< "var<workgroup> tileK: array<key_value_t, " << tile_size_ * tile_size_ << ">;\n"
<< "var<workgroup> tileK: array<" << (is_packed_qkv_ ? "q_value_t" : "key_value_t") << ", " << tile_size_ * tile_size_ << ">;\n"
<< "alias f32_val_t = " << (components_ == 4 ? "vec4<f32>" : (components_ == 2 ? "vec2<f32>" : "f32")) << ";\n";
shader.MainFunctionBody() << "// x holds the N and y holds the M\n"
<< "let m = workgroup_id.y * TILE_SIZE;\n"
<< "let n = workgroup_id.x * TILE_SIZE;\n"
<< "let batch_idx = workgroup_id.z / uniforms.num_heads;\n"
<< "let qOffset = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K;\n"
<< "let sequence_length = uniforms.M;\n"
<< "var total_sequence_length = uniforms.N;\n";
+  if (is_packed_qkv_) {
+    shader.MainFunctionBody() << "let head_idx = workgroup_id.z % uniforms.num_heads;\n"
+                              << "let kv_num_heads = uniforms.num_heads /" << n_reps_ << ";\n"
+                              << "let packed_batch_stride = (uniforms.num_heads + 2 * kv_num_heads) * uniforms.M * uniforms.K;\n"
+                              << "let qOffset = batch_idx * packed_batch_stride + head_idx * uniforms.M * uniforms.K;\n"
+                              << "let kvHeadIdx = head_idx % kv_num_heads;\n"
+                              << "let kOffset = batch_idx * packed_batch_stride + (uniforms.num_heads + kvHeadIdx) * uniforms.kv_sequence_length * uniforms.K;\n";
+  } else {
+    shader.MainFunctionBody() << "let qOffset = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K;\n"
+                              << "let kOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.kv_sequence_length * uniforms.K;\n";
+  }
std::ostringstream oss;
InitVarStub(oss, seqlen_k_);
shader.MainFunctionBody() << oss.str();
shader.MainFunctionBody() << "let kOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.kv_sequence_length * uniforms.K;\n";
if (has_present_key_) {
shader.MainFunctionBody() << "let presentKeyOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.present_sequence_length * uniforms.K;\n";
}

shader.MainFunctionBody() << "var value = f32_val_t(0);\n"
"for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {\n"
" if (global_id.y < uniforms.M && w + local_id.x < uniforms.K) {\n"
" tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * uniforms.K + w + local_id.x];\n"
" }\n"
" if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) {\n"
" var idx = TILE_SIZE * local_id.y + local_id.x;\n";
<< "for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {\n"
<< " if (global_id.y < uniforms.M && w + local_id.x < uniforms.K) {\n"
<< " tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * uniforms.K + w + local_id.x];\n"
<< " }\n"
<< " if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) {\n"
<< " var idx = TILE_SIZE * local_id.y + local_id.x;\n";

if ((feed_past_key_ && has_present_key_) || (past_present_share_buffer_ && !is_first_prompt_)) {
shader.MainFunctionBody() << " if (n + local_id.y < past_sequence_length) {\n"
<< " let pastKeyOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.past_sequence_length * uniforms.K;\n"
<< " tileK[idx] = " << (past_present_share_buffer_ ? "present_key" : "past_key") << "[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n"
<< " } else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n"
<< " tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];\n"
<< " tileK[idx] = " << (is_packed_qkv_ ? "q" : "key") << "[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];\n"
<< " }\n";
} else {
shader.MainFunctionBody() << " if (n + local_id.y < uniforms.kv_sequence_length) {\n"
" tileK[idx] = key[kOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n"
" }\n";
<< " tileK[idx] = " << (is_packed_qkv_ ? "q" : "key") << "[kOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n"
<< " }\n";
}

if (has_present_key_) {
@@ -181,9 +192,11 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
const int components = parameters.head_size_ % 4 == 0 ? 4 : (parameters.head_size_ % 2 == 0 ? 2 : 1);

AttentionProbsProgram program{"AttentionProbs", feed_past_key, has_present_key, has_attention_bias, tile_size,
-                                components, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_};
-  program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components},
-                     {K, ProgramTensorMetadataDependency::TypeAndRank, components}});
+                                components, parameters.is_first_prompt_, parameters.is_packed_qkv_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_};
+  program.AddInput({Q, ProgramTensorMetadataDependency::TypeAndRank, components});
+  if (K != nullptr) {
+    program.AddInput({K, ProgramTensorMetadataDependency::TypeAndRank, components});
+  }
if (feed_past_key) {
program.AddInput({past_key, ProgramTensorMetadataDependency::TypeAndRank, components});
}
@@ -203,7 +216,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
(parameters.sequence_length_ + tile_size - 1) / tile_size,
parameters.batch_size_ * parameters.num_heads_)
.SetWorkgroupSize(tile_size, tile_size)
-      .CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_key, has_present_key, has_attention_bias, seqlen_k != nullptr, components, parameters.is_first_prompt_)
+      .CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_key, has_present_key, has_attention_bias, seqlen_k != nullptr, components, parameters.is_first_prompt_, parameters.is_packed_qkv_)
.AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length_)},
{static_cast<uint32_t>(vectorized_head_size)},
{static_cast<uint32_t>(total_sequence_length)},
@@ -331,7 +344,14 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {
std::ostringstream oss;
InitVarStub(oss, seqlen_k_);
shader.MainFunctionBody() << oss.str();
shader.MainFunctionBody() << "let vOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.N * uniforms.kv_sequence_length + n;\n";
if (is_packed_qkv_) {
shader.MainFunctionBody() << "let kv_num_heads = uniforms.num_heads / " << n_reps_ << ";\n"
<< "let packed_batch_stride = (uniforms.num_heads + 2 * kv_num_heads) * uniforms.M * uniforms.K;\n"
<< "let kvHeadIdx = head_idx % kv_num_heads;\n"
<< "let vOffset = batch_idx * packed_batch_stride + (uniforms.num_heads + kv_num_heads + kvHeadIdx) * uniforms.N * uniforms.kv_sequence_length + n;\n";
} else {
shader.MainFunctionBody() << "let vOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.N * uniforms.kv_sequence_length + n;\n";
}
if (has_present_value_) {
shader.MainFunctionBody() << "let presentValueOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.N * uniforms.present_sequence_length + n;\n";
}
@@ -400,7 +420,7 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int
const int components = parameters.v_head_size_ % 4 == 0 ? 4 : (parameters.v_head_size_ % 2 == 0 ? 2 : 1);
constexpr int tile_size = 12;
int tile_n_size = tile_size * components;
-  VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_};
+  VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, parameters.is_first_prompt_, parameters.is_packed_qkv_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_};
program.AddInputs({{probs, ProgramTensorMetadataDependency::TypeAndRank},
{V, ProgramTensorMetadataDependency::TypeAndRank, components}});
if (feed_past_value) {
@@ -417,7 +437,7 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int
program.SetDispatchGroupSize((parameters.v_head_size_ + tile_n_size - 1) / tile_n_size,
(parameters.sequence_length_ + tile_size - 1) / tile_size,
parameters.batch_size_ * parameters.num_heads_)
-      .CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_value, has_present_value, seqlen_k != nullptr, parameters.is_first_prompt_)
+      .CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_value, has_present_value, seqlen_k != nullptr, parameters.is_first_prompt_, parameters.is_packed_qkv_)
.SetWorkgroupSize(tile_size, tile_size)
.AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length_)},
{static_cast<uint32_t>(total_sequence_length)},
@@ -452,7 +472,7 @@ Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const T
ORT_RETURN_IF_ERROR(ComputeInPlaceSoftmax(context, &probs,
parameters.batch_size_, parameters.num_heads_, parameters.past_sequence_length_, parameters.sequence_length_, total_sequence_length, seqlen_k, parameters.is_first_prompt_));

-  ORT_RETURN_IF_ERROR(ComputeVxAttentionScore(context, output_count, &probs, V, past_value, output, present_value,
+  ORT_RETURN_IF_ERROR(ComputeVxAttentionScore(context, output_count, &probs, parameters.is_packed_qkv_ ? Q : V, past_value, output, present_value,
parameters, past_sequence_length, total_sequence_length, seqlen_k));

return Status::OK();
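The diff above covers only attention.cc; per the commit list, the do_rotary path also adds a GeneratePositionIdsProgram that derives rotary position ids from seqlens_k, with separate handling for first and subsequent prompts. Below is a hypothetical host-side reference for what that shader is expected to produce; the function name, the seqlens_k convention (index of the last token, i.e. total sequence length minus one), and the subsequent-prompt formula are assumptions inferred from the commit messages, not code from this diff.

#include <cstdint>
#include <vector>

// Hypothetical reference (names and conventions assumed): per-batch rotary
// position ids, given seqlens_k[b] == total_sequence_length - 1 for batch b.
std::vector<int64_t> GeneratePositionIds(int batch_size, int sequence_length,
                                         const std::vector<int32_t>& seqlens_k,
                                         bool is_first_prompt) {
  std::vector<int64_t> position_ids(static_cast<size_t>(batch_size) * sequence_length);
  for (int b = 0; b < batch_size; ++b) {
    // First prompt: positions count from 0. Later calls: continue after the
    // tokens already held in the KV cache.
    const int64_t past_len = is_first_prompt ? 0 : seqlens_k[b] + 1 - sequence_length;
    for (int s = 0; s < sequence_length; ++s) {
      position_ids[static_cast<size_t>(b) * sequence_length + s] = past_len + s;
    }
  }
  return position_ids;
}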
10 changes: 6 additions & 4 deletions onnxruntime/contrib_ops/webgpu/bert/attention.h
@@ -34,8 +34,8 @@
class AttentionProbsProgram final : public Program<AttentionProbsProgram> {
public:
AttentionProbsProgram(const std::string& kernel_name, bool feed_past_key, bool has_present_key,
-                        bool has_attention_bias, int tile_size, int components, bool is_first_prompt, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false)
-      : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) {
+                        bool has_attention_bias, int tile_size, int components, bool is_first_prompt, bool is_packed_qkv, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false)
+      : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt), is_packed_qkv_(is_packed_qkv) {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;
Expand Down Expand Up @@ -64,6 +64,7 @@
const Tensor* seqlen_k_;
bool past_present_share_buffer_;
bool is_first_prompt_;
+  bool is_packed_qkv_;
};

class InPlaceSoftmaxProgram final : public Program<InPlaceSoftmaxProgram> {
@@ -90,8 +91,8 @@

class VxAttentionScoreProgram final : public Program<VxAttentionScoreProgram> {
public:
-  VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, bool is_first_prompt, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false)
-      : Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) {
+  VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, bool is_first_prompt, bool is_packed_qkv, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false)
+      : Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt), is_packed_qkv_(is_packed_qkv) {

[GitHub Actions / Optional Lint C++ (cpplint), attention.h:94]: Add #include <string> for string [build/include_what_you_use] [4]
}

Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -118,6 +119,7 @@
const Tensor* seqlen_k_;
bool past_present_share_buffer_;
bool is_first_prompt_;
+  bool is_packed_qkv_;
};

} // namespace webgpu
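Both program constructors now take the is_packed_qkv flag immediately after is_first_prompt, ahead of the defaulted parameters, so every call site must be updated positionally. A minimal sketch of the updated construction follows, using placeholder variables that mirror the call sites in attention.cc above; per the cpplint note, attention.h also needs the string header.

#include <string>  // flagged by cpplint above: needed for std::string in these signatures

// Sketch only: placeholder variables, mirroring the updated call sites in
// attention.cc. With packed QKV, K is passed as nullptr and the value tensor
// read in ComputeVxAttentionScore is the packed Q buffer itself.
AttentionProbsProgram probs_program{"AttentionProbs", feed_past_key, has_present_key,
                                    has_attention_bias, tile_size, components,
                                    parameters.is_first_prompt_, parameters.is_packed_qkv_,
                                    parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_};

VxAttentionScoreProgram score_program{"VxAttentionScore", feed_past_value, has_present_value,
                                      tile_size, parameters.is_first_prompt_, parameters.is_packed_qkv_,
                                      parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_};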