diff --git a/onnxruntime/core/providers/webgpu/nn/batch_norm.cc b/onnxruntime/core/providers/webgpu/nn/batch_norm.cc
new file mode 100644
index 0000000000000..687f8cb0c684b
--- /dev/null
+++ b/onnxruntime/core/providers/webgpu/nn/batch_norm.cc
@@ -0,0 +1,138 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/common/inlined_containers.h"
+#include "core/providers/webgpu/nn/batch_norm.h"
+#include "core/providers/cpu/nn/batch_norm_helper.h"
+#include "core/providers/cpu/tensor/utils.h"
+#include "core/providers/webgpu/shader_helper.h"
+#include "core/providers/webgpu/webgpu_supported_types.h"
+
+namespace onnxruntime {
+namespace webgpu {
+
+#define WEBGPU_BATCH_NORM_VERSIONED_KERNEL(start, end, domain, is_nhwc) \
+  ONNX_OPERATOR_VERSIONED_KERNEL_EX(                                    \
+      BatchNormalization,                                               \
+      domain,                                                           \
+      start,                                                            \
+      end,                                                              \
+      kWebGpuExecutionProvider,                                         \
+      (*KernelDefBuilder::Create())                                     \
+          .TypeConstraint("T", WebGpuSupportedFloatTypes()),            \
+      BatchNormalization<is_nhwc>);
+
+#define WEBGPU_BATCH_NORM_KERNEL(version, domain, is_nhwc)   \
+  ONNX_OPERATOR_KERNEL_EX(                                   \
+      BatchNormalization,                                    \
+      domain,                                                \
+      version,                                               \
+      kWebGpuExecutionProvider,                              \
+      (*KernelDefBuilder::Create())                          \
+          .TypeConstraint("T", WebGpuSupportedFloatTypes()), \
+      BatchNormalization<is_nhwc>);
+
+WEBGPU_BATCH_NORM_VERSIONED_KERNEL(7, 8, kOnnxDomain, false)
+WEBGPU_BATCH_NORM_VERSIONED_KERNEL(9, 13, kOnnxDomain, false)
+WEBGPU_BATCH_NORM_VERSIONED_KERNEL(14, 14, kOnnxDomain, false)
+WEBGPU_BATCH_NORM_KERNEL(15, kOnnxDomain, false)
+
+WEBGPU_BATCH_NORM_VERSIONED_KERNEL(7, 8, kMSInternalNHWCDomain, true)
+WEBGPU_BATCH_NORM_VERSIONED_KERNEL(9, 13, kMSInternalNHWCDomain, true)
+WEBGPU_BATCH_NORM_VERSIONED_KERNEL(14, 14, kMSInternalNHWCDomain, true)
+WEBGPU_BATCH_NORM_KERNEL(15, kMSInternalNHWCDomain, true)
+
+Status BatchNormalizationProgram::GenerateShaderCode(ShaderHelper& shader) const {
+  const ShaderVariableHelper& input_tensor = shader.AddInput("input_tensor", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
+  const ShaderVariableHelper& scale = shader.AddInput("scale", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
+  const ShaderVariableHelper& B = shader.AddInput("B", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
+  const ShaderVariableHelper& input_mean = shader.AddInput("input_mean", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
+  const ShaderVariableHelper& input_var = shader.AddInput("input_var", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
+  const ShaderVariableHelper& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
+
+  shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")
+                            << "  let idx = global_idx * " << components_ << ";\n"
+                            << "  var outputIndices = " << output.OffsetToIndices("idx") << ";\n";
+  if (spatial_) {
+    if (input_tensor.Rank() == 1) {
+      shader.MainFunctionBody() << "  let cOffset = 0u;\n";
+    } else {
+      if (format_ == DataLayout::NHWC) {
+        shader.MainFunctionBody() << "  let cOffset = outputIndices[" << input_tensor.Rank() - 1 << "] / " << components_ << ";\n";
+      } else {
+        shader.MainFunctionBody() << "  let cOffset = outputIndices[1];\n";
+      }
+    }
+  } else {
+    if (format_ == DataLayout::NCHW) {
+      shader.MainFunctionBody() << "  " << output.IndicesSet("outputIndices", "0", "0") << "\n"
+                                << "  let cOffset = " << output.IndicesToOffset("outputIndices") << ";\n";
+    } else {
+      // update C channel
+      shader.MainFunctionBody() << "  var cIndices = scale_indices_t(0);\n"
+                                << "  cIndices[0] = outputIndices[" << input_tensor.Rank() - 1 << "];\n";
+      // update D1 x ... x Dn channels
+      for (int i = 1; i < scale.Rank(); i++) {
+        shader.MainFunctionBody() << "  cIndices[" << i << "] = outputIndices[" << i << "];\n";
+      }
+      shader.MainFunctionBody() << "  let cOffset = " << scale.IndicesToOffset("cIndices") << ";\n";
+    }
+  }
+
+  shader.MainFunctionBody() << "  let scale = " << scale.GetByOffset("cOffset") << ";\n"
+                            << "  let B = " << B.GetByOffset("cOffset") << ";\n"
+                            << "  let input_mean = " << input_mean.GetByOffset("cOffset") << ";\n"
+                            << "  let input_var = " << input_var.GetByOffset("cOffset") << ";\n"
+                            << "  let x = " << input_tensor.GetByOffset("global_idx") << ";\n"
+                            << "  let value = (x - input_mean) * inverseSqrt(input_var + " << epsilon_ << ") * scale + B;\n"
+                            << "  " << output.SetByOffset("global_idx", "value") << "\n";
+
+  return Status::OK();
+}
+
+template <bool is_nhwc>
+Status BatchNormalization<is_nhwc>::ComputeInternal(ComputeContext& context) const {
+  if (training_mode_) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "BatchNormalization trainingMode is not supported yet.");
+  }
+
+  if (context.InputCount() != 5) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "BatchNormalization requires 5 inputs.");
+  }
+
+  const auto* input_tensor = context.Input(0);
+  const TensorShape& input_shape = input_tensor->Shape();
+  size_t input_rank = input_shape.NumDimensions();
+  const int components = spatial_ ? ((input_shape[input_rank - 1] % 4 == 0) ? 4 : ((input_shape[input_rank - 1] % 2 == 0) ? 2 : 1)) : 1;
+
+  auto output_dims = input_shape.AsShapeVector();
+  TensorShape output_shape(output_dims);
+  auto* output_tensor = context.Output(0, output_shape);
+  int64_t output_size = output_tensor->Shape().Size() / static_cast<int64_t>(components);
+
+  if (output_size == 0) {
+    return Status::OK();
+  }
+
+  const auto* scale = context.Input(1);
+  const auto* B = context.Input(2);
+  const auto* input_mean = context.Input(3);
+  const auto* input_var = context.Input(4);
+
+  ORT_RETURN_IF_ERROR(BatchNormHelper::ValidateInputs(input_tensor, scale, B, input_mean, input_var, spatial_ == 1, format_ == DataLayout::NHWC));
+
+  BatchNormalizationProgram program{epsilon_, spatial_, format_, static_cast<int64_t>(components)};
+  program
+      .AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank},
+                  {scale, ProgramTensorMetadataDependency::TypeAndRank},
+                  {B, ProgramTensorMetadataDependency::TypeAndRank},
+                  {input_mean, ProgramTensorMetadataDependency::TypeAndRank},
+                  {input_var, ProgramTensorMetadataDependency::TypeAndRank}})
+      .AddOutputs({output_tensor})
+      .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
+      .AddUniformVariables({{static_cast<uint32_t>(output_size)}});
+  return context.RunProgram(program);
+}
+
+}  // namespace webgpu
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/webgpu/nn/batch_norm.h b/onnxruntime/core/providers/webgpu/nn/batch_norm.h
new file mode 100644
index 0000000000000..00dc7679620fb
--- /dev/null
+++ b/onnxruntime/core/providers/webgpu/nn/batch_norm.h
@@ -0,0 +1,54 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/providers/webgpu/webgpu_kernel.h"
+#include "core/providers/webgpu/program.h"
+
+namespace onnxruntime {
+namespace webgpu {
+
+class BatchNormalizationProgram final : public Program<BatchNormalizationProgram> {
+ public:
+  BatchNormalizationProgram(float epsilon, int64_t spatial, DataLayout format, int64_t components) : Program{"BatchNormalization"},
+                                                                                                     epsilon_{epsilon},
+                                                                                                     spatial_{spatial},
+                                                                                                     format_{format},
+                                                                                                     components_{components} {}
+
+  Status GenerateShaderCode(ShaderHelper& sh) const override;
+
+  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32});
+
+ private:
+  float epsilon_;
+  int64_t spatial_;
+  DataLayout format_;
+  int64_t components_;
+};
+
+template <bool is_nhwc>
+class BatchNormalization final : public WebGpuKernel {
+ public:
+  BatchNormalization(const OpKernelInfo& info) : WebGpuKernel(info) {
+    epsilon_ = info.GetAttrOrDefault<float>("epsilon", 1e-5f);
+    momentum_ = info.GetAttrOrDefault<float>("momentum", 0.9f);
+    spatial_ = info.GetAttrOrDefault<int64_t>("spatial", 1);
+    training_mode_ = info.GetAttrOrDefault<int64_t>("training_mode", 0);
+    // NCHW for ai.onnx domain, NHWC for com.ms.internal.nhwc domain
+    format_ = is_nhwc ? DataLayout::NHWC : DataLayout::NCHW;
+  }
+
+  Status ComputeInternal(ComputeContext& context) const override;
+
+ private:
+  float epsilon_;
+  float momentum_;
+  int64_t spatial_;
+  int64_t training_mode_;
+  DataLayout format_;
+};
+
+}  // namespace webgpu
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc
index dec7e48786bf5..f517ef9d36458 100644
--- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc
+++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc
@@ -696,14 +696,14 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
       //  BuildKernelCreateInfo,
       //  BuildKernelCreateInfo,
       //  BuildKernelCreateInfo,
-      //  BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 8, BatchNormalization)>,
-      //  BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 13, BatchNormalization)>,
-      //  BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, 14, BatchNormalization)>,
-      //  BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 15, BatchNormalization)>,
-      //  BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 7, 8, BatchNormalization)>,
-      //  BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 9, 13, BatchNormalization)>,
-      //  BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 14, 14, BatchNormalization)>,
-      //  BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 15, BatchNormalization)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 8, BatchNormalization)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 13, BatchNormalization)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, 14, BatchNormalization)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 15, BatchNormalization)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 7, 8, BatchNormalization)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 9, 13, BatchNormalization)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 14, 14, BatchNormalization)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 15, BatchNormalization)>,
       //  BuildKernelCreateInfo,
       //  BuildKernelCreateInfo,
       //  BuildKernelCreateInfo,
diff --git a/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc b/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc
index 08c4e608aada3..f8ebca5ff9a1b 100644
--- a/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc
+++ b/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc
@@ -924,7 +924,8 @@ TEST(BatchNormTest, ForwardTrainingTestWithSavedOutputsOpset9) {
   test.Run(OpTester::ExpectResult::kExpectSuccess, "",
            // TODO(mtavenrath) flakiness of running_mean for CUDA has been fixed, the delta of running_var is still ~0.1
            {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider,
-            kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider});
+            kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider,
+            kWebGpuExecutionProvider});
 }
 
 TEST(BatchNormTest, ForwardTrainingTestOpset14) {
@@ -953,7 +954,8 @@ TEST(BatchNormTest, ForwardTrainingTestOpset14) {
   // exclude TRT and OpenVINO for same reasons as seen in TestBatchNorm()
   test.Run(OpTester::ExpectResult::kExpectSuccess, "",
            {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider,
-            kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider});
+            kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider,
+            kWebGpuExecutionProvider});
 }
 
 TEST(BatchNormTest, ForwardTrainingTestOpset15) {
@@ -982,7 +984,8 @@ TEST(BatchNormTest, ForwardTrainingTestOpset15) {
   // Same exclusions as the opset 14 test
   test.Run(OpTester::ExpectResult::kExpectSuccess, "",
            {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider,
-            kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider});
+            kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider,
+            kWebGpuExecutionProvider});
 }
 
 #endif  // BATCHNORM_INCLUDE_TRAINING_SUPPORT