Skip to content

Commit

Permalink
[webgpu] Use pushErrorScope()/popErrorScope() once for an inference r…
Browse files Browse the repository at this point in the history
…un (#23438)

The CPU wall time spent waiting for PopErrorScope is non-trivial, and
validation errors are not expected to occur in Release builds.

### Description
<!-- Describe your changes. -->



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
  • Loading branch information
jchen10 authored Feb 7, 2025
1 parent 65008cb commit 0887e36
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 16 deletions.
21 changes: 5 additions & 16 deletions onnxruntime/core/providers/webgpu/compute_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,27 +14,16 @@ ComputeContext::ComputeContext(OpKernelContext& kernel_context)
}

void ComputeContext::PushErrorScope() {
if (webgpu_context_.ValidationMode() >= ValidationMode::Basic) {
webgpu_context_.Device().PushErrorScope(wgpu::ErrorFilter::Validation);
if (webgpu_context_.ValidationMode() >= ValidationMode::Full) {
webgpu_context_.PushErrorScope();
}
}

Status ComputeContext::PopErrorScope() {
Status status{};

if (webgpu_context_.ValidationMode() >= ValidationMode::Basic) {
ORT_RETURN_IF_ERROR(webgpu_context_.Wait(
webgpu_context_.Device().PopErrorScope(
wgpu::CallbackMode::WaitAnyOnly, [](wgpu::PopErrorScopeStatus pop_status, wgpu::ErrorType error_type, char const* message, Status* status) {
ORT_ENFORCE(pop_status == wgpu::PopErrorScopeStatus::Success, "Instance dropped.");
if (error_type == wgpu::ErrorType::NoError) {
return;
}
*status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "WebGPU validation failed. ", message);
},
&status)));
if (webgpu_context_.ValidationMode() >= ValidationMode::Full) {
return webgpu_context_.PopErrorScope();
}
return status;
return Status::OK();
}

} // namespace webgpu
Expand Down
17 changes: 17 additions & 0 deletions onnxruntime/core/providers/webgpu/webgpu_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -624,6 +624,23 @@ void WebGpuContext::EndProfiling(TimePoint /* tp */, profiling::Events& events,
}
}

// Opens an error scope on the device, filtered to validation errors.
void WebGpuContext::PushErrorScope() {
  device_.PushErrorScope(wgpu::ErrorFilter::Validation);
}

// Pops the current error scope from the device and hands the resulting future to Wait().
//
// Returns:
//   - the error from Wait() if waiting itself fails;
//   - a FAIL status carrying the validation message if the scope reported an error;
//   - the default-constructed Status otherwise.
Status WebGpuContext::PopErrorScope() {
  Status validation_result{};

  // Captureless callback: the result is written through the user-data pointer.
  const auto on_scope_popped = [](wgpu::PopErrorScopeStatus pop_status,
                                  wgpu::ErrorType error_type,
                                  char const* message,
                                  Status* result) {
    ORT_ENFORCE(pop_status == wgpu::PopErrorScopeStatus::Success, "Instance dropped.");
    if (error_type != wgpu::ErrorType::NoError) {
      *result = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "WebGPU validation failed. ", message);
    }
  };

  ORT_RETURN_IF_ERROR(Wait(device_.PopErrorScope(wgpu::CallbackMode::WaitAnyOnly,
                                                 on_scope_popped,
                                                 &validation_result)));
  return validation_result;
}

void WebGpuContext::Flush() {
if (!current_command_encoder_) {
return;
Expand Down
14 changes: 14 additions & 0 deletions onnxruntime/core/providers/webgpu/webgpu_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,20 @@ class WebGpuContext final {
void CollectProfilingData(profiling::Events& events);
void EndProfiling(TimePoint, profiling::Events& events, profiling::Events& cached_events);

//
// Push error scope.
//
// Pushes an error scope on the device using wgpu::ErrorFilter::Validation.
//
// This is useful only when "skip_validation" is not set.
//
// NOTE(review): presumably every call must be matched by a later PopErrorScope() — confirm.
//
void PushErrorScope();

//
// Pop error scope.
//
// Pops the error scope from the device and passes the resulting future to Wait().
// Returns a FAIL status with a "WebGPU validation failed. " message when the popped
// scope reports an error type other than NoError; an error from Wait() is returned as-is.
//
// This is useful only when "skip_validation" is not set.
//
Status PopErrorScope();

Status Run(ComputeContext& context, const ProgramBase& program);

private:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -832,6 +832,10 @@ std::unique_ptr<profiling::EpProfiler> WebGpuExecutionProvider::GetProfiler() {
}

Status WebGpuExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) {
if (context_.ValidationMode() >= ValidationMode::Basic) {
context_.PushErrorScope();
}

if (profiler_->Enabled()) {
context_.StartProfiling();
}
Expand All @@ -858,6 +862,10 @@ Status WebGpuExecutionProvider::OnRunEnd(bool /* sync_stream */, const onnxrunti
context_.CollectProfilingData(profiler_->Events());
}

if (context_.ValidationMode() >= ValidationMode::Basic) {
return context_.PopErrorScope();
}

return Status::OK();
}

Expand Down

0 comments on commit 0887e36

Please sign in to comment.