diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index 3cbe3ce5a631a..301d795912443 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -97,6 +97,7 @@ static const InlinedHashMap op_map = { {"Erf", "erf"}, {"Not", "logicalNot"}, {"Floor", "floor"}, + {"Flatten", "flattenTo2d"}, {"Sin", "sin"}, {"Sqrt", "sqrt"}, {"Relu", "relu"}, diff --git a/onnxruntime/core/providers/webnn/builders/impl/flatten_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/flatten_op_builder.cc new file mode 100644 index 0000000000000..6c59ca451f333 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/flatten_op_builder.cc @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/safeint.h" +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/webnn/builders/helper.h" +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/op_builder_factory.h" + +#include "base_op_builder.h" + +namespace onnxruntime { +namespace webnn { + +class FlattenOpBuilder : public BaseOpBuilder { + // Add operator related. + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; +}; + +// Add operator related. + +Status FlattenOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, + const Node& node, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + std::vector input_shape; + ORT_RETURN_IF(input_defs.size() < 1, "Flatten has no input tensor"); + if (!GetShape(*input_defs[0], input_shape, logger)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "FlattenOpBuilder::AddToModelBuilderImpl, cannot get input shape"); + } + int64_t rank = input_shape.size(); + NodeAttrHelper helper(node); + int64_t axis = helper.Get("axis", 1); + ORT_ENFORCE(axis >= -rank && axis <= rank, "axis ", axis, + " is not in valid range [-", rank, ",", rank, "]"); + if (axis < 0) { + axis += rank; + } + emscripten::val inputs = model_builder.GetOperand(input_defs[0]->Name()); + emscripten::val output = model_builder.GetBuilder().call("flattenTo2d", inputs, + static_cast(axis)); + + model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); + return Status::OK(); +} + +void CreateFlattenOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.builders.push_back(std::make_unique()); + op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc index a43fe4a41b92d..82a45719ec1c3 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc @@ -68,6 +68,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateGatherOpBuilder("Gather", op_registrations); } + { // Flatten + CreateFlattenOpBuilder("Flatten", op_registrations); + } + { // Gemm/MatMul CreateGemmOpBuilder("Gemm", op_registrations); CreateGemmOpBuilder("MatMul", op_registrations); diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.h b/onnxruntime/core/providers/webnn/builders/op_builder_factory.h index b8cf7322f2fe9..8f8299e5138a4 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.h +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.h @@ -27,6 +27,7 @@ void CreateClipOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_ void CreateConvOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateConcatOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateExpandOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateFlattenOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateGatherOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateGemmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreatePoolOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);