Skip to content

Commit

Permalink
[WebGPU EP] Batch Norm Implementation (#23525)
Browse files Browse the repository at this point in the history
Increases operator coverage for the WebGPU execution provider (EP) by adding an inference-mode BatchNormalization kernel.

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
2 people authored and ashrit-ms committed Feb 11, 2025
1 parent 9742b37 commit 7e06740
Show file tree
Hide file tree
Showing 4 changed files with 206 additions and 11 deletions.
138 changes: 138 additions & 0 deletions onnxruntime/core/providers/webgpu/nn/batch_norm.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/common/inlined_containers.h"
#include "core/providers/webgpu/nn/batch_norm.h"
#include "core/providers/cpu/nn/batch_norm_helper.h"
#include "core/providers/cpu/tensor/utils.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_supported_types.h"

namespace onnxruntime {
namespace webgpu {

// Registers a BatchNormalization kernel for an opset range [start, end] with the
// WebGPU EP. `is_nhwc` is the template argument selecting the data layout
// (false = NCHW for the standard ONNX domain, true = NHWC for the internal
// NHWC domain). Only float types supported by WebGPU are accepted for "T".
// NOTE: no comments inside the macro bodies — a `//` before a trailing `\`
// would swallow the continuation (line splicing happens before comment removal).
#define WEBGPU_BATCH_NORM_VERSIONED_KERNEL(start, end, domain, is_nhwc) \
ONNX_OPERATOR_VERSIONED_KERNEL_EX( \
BatchNormalization, \
domain, \
start, \
end, \
kWebGpuExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", WebGpuSupportedFloatTypes()), \
BatchNormalization<is_nhwc>);

// Same as above for the current (unversioned-end) opset.
#define WEBGPU_BATCH_NORM_KERNEL(version, domain, is_nhwc) \
ONNX_OPERATOR_KERNEL_EX( \
BatchNormalization, \
domain, \
version, \
kWebGpuExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", WebGpuSupportedFloatTypes()), \
BatchNormalization<is_nhwc>);

// ai.onnx domain: NCHW layout.
WEBGPU_BATCH_NORM_VERSIONED_KERNEL(7, 8, kOnnxDomain, false)
WEBGPU_BATCH_NORM_VERSIONED_KERNEL(9, 13, kOnnxDomain, false)
WEBGPU_BATCH_NORM_VERSIONED_KERNEL(14, 14, kOnnxDomain, false)
WEBGPU_BATCH_NORM_KERNEL(15, kOnnxDomain, false)

// com.ms.internal.nhwc domain: NHWC layout.
WEBGPU_BATCH_NORM_VERSIONED_KERNEL(7, 8, kMSInternalNHWCDomain, true)
WEBGPU_BATCH_NORM_VERSIONED_KERNEL(9, 13, kMSInternalNHWCDomain, true)
WEBGPU_BATCH_NORM_VERSIONED_KERNEL(14, 14, kMSInternalNHWCDomain, true)
WEBGPU_BATCH_NORM_KERNEL(15, kMSInternalNHWCDomain, true)

// Emits the WGSL compute shader for inference-mode BatchNormalization:
//   value = (x - input_mean) * inverseSqrt(input_var + epsilon) * scale + B
// Each invocation handles one (possibly vectorized, width = components_)
// element at `global_idx`. The bulk of the branching below computes `cOffset`,
// the offset into the per-channel parameter tensors (scale/B/mean/var) that
// corresponds to that element.
Status BatchNormalizationProgram::GenerateShaderCode(ShaderHelper& shader) const {
const ShaderVariableHelper& input_tensor = shader.AddInput("input_tensor", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
const ShaderVariableHelper& scale = shader.AddInput("scale", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
const ShaderVariableHelper& B = shader.AddInput("B", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
const ShaderVariableHelper& input_mean = shader.AddInput("input_mean", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
const ShaderVariableHelper& input_var = shader.AddInput("input_var", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
const ShaderVariableHelper& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);

// `idx` is the scalar (un-vectorized) element index; indices are derived from it.
shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")
<< " let idx = global_idx * " << components_ << ";\n"
<< " var outputIndices = " << output.OffsetToIndices("idx") << ";\n";
if (spatial_) {
// Spatial mode: the parameter tensors are indexed by the channel only.
if (input_tensor.Rank() == 1) {
// Rank-1 input has no separate channel axis; everything uses channel 0.
shader.MainFunctionBody() << " let cOffset = 0u;\n";
} else {
if (format_ == DataLayout::NHWC) {
// NHWC: channel is the innermost axis — the vectorized one — so the
// scalar index must be divided by components_ to address the parameters.
shader.MainFunctionBody() << " let cOffset = outputIndices[" << input_tensor.Rank() - 1 << "] / " << components_ << ";\n";
} else {
// NCHW: channel is axis 1.
shader.MainFunctionBody() << " let cOffset = outputIndices[1];\n";
}
}
} else {
// Non-spatial mode: parameters carry the trailing D1..Dn axes as well.
if (format_ == DataLayout::NCHW) {
// Zero the batch index, then reuse the output's indices-to-offset mapping
// to address the [C, D1, ..., Dn]-shaped parameter tensors.
shader.MainFunctionBody() << " " << output.IndicesSet("outputIndices", "0", "0") << "\n"
<< " let cOffset = " << output.IndicesToOffset("outputIndices") << ";\n";
} else {
// update C channel
shader.MainFunctionBody() << " var cIndices = scale_indices_t(0);\n"
<< " cIndices[0] = outputIndices[" << input_tensor.Rank() - 1 << "];\n";
// update D1 x ... x Dn channels
for (int i = 1; i < scale.Rank(); i++) {
shader.MainFunctionBody() << " cIndices[" << i << "] = outputIndices[" << i << "];\n";
}
shader.MainFunctionBody() << " let cOffset = " << scale.IndicesToOffset("cIndices") << ";\n";
}
}

// Load the per-channel parameters once, then normalize the (vectorized)
// input element read/written at `global_idx`.
shader.MainFunctionBody() << " let scale = " << scale.GetByOffset("cOffset") << ";\n"
<< " let B = " << B.GetByOffset("cOffset") << ";\n"
<< " let input_mean = " << input_mean.GetByOffset("cOffset") << ";\n"
<< " let input_var = " << input_var.GetByOffset("cOffset") << ";\n"
<< " let x = " << input_tensor.GetByOffset("global_idx") << ";\n"
<< " let value = (x - input_mean) * inverseSqrt(input_var + " << epsilon_ << ") * scale + B;\n"
<< " " << output.SetByOffset("global_idx", "value") << "\n";

return Status::OK();
}

// Runs inference-mode BatchNormalization on the WebGPU EP.
// Inputs: X, scale, B, input_mean, input_var; output: Y with X's shape.
// Returns INVALID_ARGUMENT for training mode or a wrong input count.
template <bool is_nhwc>
Status BatchNormalization<is_nhwc>::ComputeInternal(ComputeContext& context) const {
  // Only inference is implemented; training would require updating running stats.
  if (training_mode_) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "BatchNormalization trainingMode is not supported yet.");
  }

  if (context.InputCount() != 5) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "BatchNormalization requires 5 inputs.");
  }

  const auto* x = context.Input(0);
  const TensorShape& x_shape = x->Shape();
  const size_t rank = x_shape.NumDimensions();

  // Vectorization width: process 4 (or 2) scalars per invocation when the
  // innermost dimension is divisible; non-spatial mode always runs scalar.
  int components = 1;
  if (spatial_) {
    const int64_t innermost = x_shape[rank - 1];
    if (innermost % 4 == 0) {
      components = 4;
    } else if (innermost % 2 == 0) {
      components = 2;
    }
  }

  // Output shape is identical to the input shape.
  auto* y = context.Output(0, TensorShape(x_shape.AsShapeVector()));
  const int64_t dispatch_size = y->Shape().Size() / static_cast<int64_t>(components);
  if (dispatch_size == 0) {
    return Status::OK();  // empty tensor: nothing to dispatch
  }

  const auto* scale = context.Input<Tensor>(1);
  const auto* bias = context.Input<Tensor>(2);
  const auto* mean = context.Input<Tensor>(3);
  const auto* variance = context.Input<Tensor>(4);

  ORT_RETURN_IF_ERROR(BatchNormHelper::ValidateInputs(x, scale, bias, mean, variance, spatial_ == 1, format_ == DataLayout::NHWC));

  BatchNormalizationProgram program{epsilon_, spatial_, format_, static_cast<int64_t>(components)};
  program
      .AddInputs({{x, ProgramTensorMetadataDependency::TypeAndRank},
                  {scale, ProgramTensorMetadataDependency::TypeAndRank},
                  {bias, ProgramTensorMetadataDependency::TypeAndRank},
                  {mean, ProgramTensorMetadataDependency::TypeAndRank},
                  {variance, ProgramTensorMetadataDependency::TypeAndRank}})
      .AddOutputs({y})
      .SetDispatchGroupSize((dispatch_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
      .AddUniformVariables({{static_cast<uint32_t>(dispatch_size)}});
  return context.RunProgram(program);
}

} // namespace webgpu
} // namespace onnxruntime
54 changes: 54 additions & 0 deletions onnxruntime/core/providers/webgpu/nn/batch_norm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/webgpu/webgpu_kernel.h"
#include "core/providers/webgpu/program.h"

namespace onnxruntime {
namespace webgpu {

// Shader program for inference-mode BatchNormalization:
//   y = (x - input_mean) * inverseSqrt(input_var + epsilon) * scale + B
class BatchNormalizationProgram final : public Program<BatchNormalizationProgram> {
 public:
  BatchNormalizationProgram(float epsilon, int64_t spatial, DataLayout format, int64_t components)
      : Program{"BatchNormalization"},
        epsilon_{epsilon},
        spatial_{spatial},
        format_{format},
        components_{components} {
  }

  Status GenerateShaderCode(ShaderHelper& sh) const override;

  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32});

 private:
  float epsilon_;        // numerical-stability term added to the variance
  int64_t spatial_;      // ONNX "spatial" attribute; selects how channel offsets are computed
  DataLayout format_;    // NCHW or NHWC
  int64_t components_;   // vectorization width used by the shader (1, 2, or 4)
};

template <bool is_nhwc>
class BatchNormalization final : public WebGpuKernel {
public:
BatchNormalization(const OpKernelInfo& info) : WebGpuKernel(info) {
epsilon_ = info.GetAttrOrDefault<float>("epsilon", 1e-5f);
momentum_ = info.GetAttrOrDefault<float>("momentum", 0.9f);
spatial_ = info.GetAttrOrDefault<int64_t>("spatial", 1);
training_mode_ = info.GetAttrOrDefault<int64_t>("training_mode", 0);
// NCHW for ai.onnx domain, NHWC for com.ms.internal.nhwc domain
format_ = is_nhwc ? DataLayout::NHWC : DataLayout::NCHW;
}

Status ComputeInternal(ComputeContext& context) const override;

private:
float epsilon_;
float momentum_;
int64_t spatial_;
int64_t training_mode_;
DataLayout format_;
};

} // namespace webgpu
} // namespace onnxruntime
16 changes: 8 additions & 8 deletions onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -696,14 +696,14 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 18, If)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, If)>,

// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 8, BatchNormalization)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 13, BatchNormalization)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, 14, BatchNormalization)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 15, BatchNormalization)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 7, 8, BatchNormalization)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 9, 13, BatchNormalization)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 14, 14, BatchNormalization)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 15, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 8, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 13, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, 14, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 15, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 7, 8, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 9, 13, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 14, 14, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 15, BatchNormalization)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 13, CumSum)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, CumSum)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, 12, uint8_t, DequantizeLinear)>,
Expand Down
9 changes: 6 additions & 3 deletions onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -924,7 +924,8 @@ TEST(BatchNormTest, ForwardTrainingTestWithSavedOutputsOpset9) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
// TODO(mtavenrath) flakiness of running_mean for CUDA has been fixed, the delta of running_var is still ~0.1
{kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider,
kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider});
kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider,
kWebGpuExecutionProvider});
}

TEST(BatchNormTest, ForwardTrainingTestOpset14) {
Expand Down Expand Up @@ -953,7 +954,8 @@ TEST(BatchNormTest, ForwardTrainingTestOpset14) {
// exclude TRT and OpenVINO for same reasons as seen in TestBatchNorm()
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
{kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider,
kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider});
kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider,
kWebGpuExecutionProvider});
}

TEST(BatchNormTest, ForwardTrainingTestOpset15) {
Expand Down Expand Up @@ -982,7 +984,8 @@ TEST(BatchNormTest, ForwardTrainingTestOpset15) {
// Same exclusions as the opset 14 test
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
{kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider,
kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider});
kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider,
kWebGpuExecutionProvider});
}
#endif // BATCHNORM_INCLUDE_TRAINING_SUPPORT

Expand Down

0 comments on commit 7e06740

Please sign in to comment.