diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc index 329db75316e82..52fcc39ae5418 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -311,12 +311,12 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N if (input_defs.size() >= 3) { x_zero_point = model_builder.GetOperand(node.InputDefs()[2]->Name()); } else { - x_zero_point = model_builder.GetZeroConstant(ONNX_NAMESPACE::TensorProto_DataType_UINT8); + x_zero_point = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_UINT8, 0); } if (input_defs.size() >= 4) { w_zero_point = model_builder.GetOperand(node.InputDefs()[3]->Name()); } else { - w_zero_point = model_builder.GetZeroConstant(ONNX_NAMESPACE::TensorProto_DataType_UINT8); + w_zero_point = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_UINT8, 0); } output = model_builder.GetBuilder().call("conv2dInteger", input, x_zero_point, filter, w_zero_point, options); diff --git a/onnxruntime/core/providers/webnn/builders/impl/dropout_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/dropout_op_builder.cc index 5434194a214ac..9bb930c63b009 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/dropout_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/dropout_op_builder.cc @@ -59,22 +59,14 @@ Status DropoutOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, std::vector mask_shape; ORT_RETURN_IF_NOT(GetShape(*output_defs[1], mask_shape, logger), "Cannot get mask output's shape"); std::vector dims = GetVecUint32FromVecInt64(mask_shape); - - emscripten::val desc = emscripten::val::object(); - desc.set("dataType", "uint8"); - desc.set("dimensions", emscripten::val::array(dims)); - desc.set("shape", emscripten::val::array(dims)); - const auto num_elements = 
narrow(Product(mask_shape)); - emscripten::val ones_buffer = emscripten::val::global("Uint8Array").new_(num_elements); - ones_buffer.call("fill", 1); - - emscripten::val mask_output = model_builder.GetBuilder().call("constant", desc, ones_buffer); + emscripten::val one_constant = model_builder.CreateOrGetConstant( + ONNX_NAMESPACE::TensorProto_DataType_BOOL, 1, dims); emscripten::val options = emscripten::val::object(); options.set("label", output_defs[1]->Name() + "_identity"); // Add additional identity op in case the mask is the output of a WebNN graph, // beacuse WebNN does not support a constant operand as output. - mask_output = model_builder.GetBuilder().call("identity", mask_output, options); + emscripten::val mask_output = model_builder.GetBuilder().call("identity", one_constant, options); model_builder.AddOperand(output_defs[1]->Name(), std::move(mask_output)); } return Status::OK(); diff --git a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc index 1477530ce1894..252d49a2f4d4d 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc @@ -113,12 +113,12 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N if (input_defs.size() >= 3) { a_zero_point = model_builder.GetOperand(node.InputDefs()[2]->Name()); } else { - a_zero_point = model_builder.GetZeroConstant(ONNX_NAMESPACE::TensorProto_DataType_UINT8); + a_zero_point = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_UINT8, 0); } if (input_defs.size() >= 4) { b_zero_point = model_builder.GetOperand(node.InputDefs()[3]->Name()); } else { - b_zero_point = model_builder.GetZeroConstant(ONNX_NAMESPACE::TensorProto_DataType_UINT8); + b_zero_point = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_UINT8, 0); } output = 
model_builder.GetBuilder().call("matmulInteger", a, diff --git a/onnxruntime/core/providers/webnn/builders/impl/lrn_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/lrn_op_builder.cc index bdd1283c720f3..19f6d6aff8f97 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/lrn_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/lrn_op_builder.cc @@ -29,7 +29,8 @@ Status LRNOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); - const auto input_data_type = input_defs[0]->TypeAsProto()->tensor_type().elem_type(); + int32_t input_data_type; + ORT_RETURN_IF_NOT(GetType(*input_defs[0], input_data_type, logger), "Cannot get input type"); emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); const auto node_name = node.Name(); emscripten::val wnn_builder = model_builder.GetBuilder(); @@ -42,10 +43,10 @@ Status LRNOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, // Prepare WebNN constants for alpha, beta, bias attributes. // Assume T is float, because input_data_type has been limited to float32 and float16 in 'hasSupportedInitsImpl'. 
- emscripten::val alpha_constant = model_builder.CreateOrGetScalarConstant(input_data_type, alpha); - emscripten::val beta_constant = model_builder.CreateOrGetScalarConstant(input_data_type, beta); - emscripten::val bias_constant = model_builder.CreateOrGetScalarConstant(input_data_type, bias); - emscripten::val pow1_constant = model_builder.CreateOrGetScalarConstant(input_data_type, 2); + emscripten::val alpha_constant = model_builder.CreateOrGetConstant(input_data_type, alpha); + emscripten::val beta_constant = model_builder.CreateOrGetConstant(input_data_type, beta); + emscripten::val bias_constant = model_builder.CreateOrGetConstant(input_data_type, bias); + emscripten::val pow1_constant = model_builder.CreateOrGetConstant(input_data_type, 2); /** WebNN doesn't support LRN. So decompose it into a series of ops: diff --git a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc index fa82c2f85f0d8..79ed0393e3044 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc @@ -100,7 +100,7 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder X --> Pow --> ReduceMean --> Add --> Sqrt --> Div -> Mul ^ ^ ^ ^ ^ | | | | | - Y:2 axis B:epsilon A:X A:scale + Y:2 axis B:epsilon A:X A:scale */ int32_t input_type; @@ -108,13 +108,7 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder emscripten::val common_options = emscripten::val::object(); // Pow - emscripten::val pow_constant_desc = emscripten::val::object(); - ORT_RETURN_IF_NOT(SetWebnnDataType(pow_constant_desc, input_type), "Unsupported data type"); - pow_constant_desc.set("shape", emscripten::val::array()); - emscripten::val pow_buffer = emscripten::val::global("Float32Array").new_(1); - pow_buffer.set(0, 2); - emscripten::val pow_constant = - 
model_builder.GetBuilder().call("constant", pow_constant_desc, pow_buffer); + emscripten::val pow_constant = model_builder.CreateOrGetConstant(input_type, 2); common_options.set("label", node.Name() + "_pow"); emscripten::val pow = model_builder.GetBuilder().call("pow", input, pow_constant, common_options); @@ -127,13 +121,7 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder emscripten::val reduce_mean = model_builder.GetBuilder().call("reduceMean", pow, reduce_options); // Add - emscripten::val add_constant_desc = emscripten::val::object(); - ORT_RETURN_IF_NOT(SetWebnnDataType(add_constant_desc, input_type), "Unsupported data type"); - add_constant_desc.set("shape", emscripten::val::array()); - emscripten::val add_buffer = emscripten::val::global("Float32Array").new_(1); - add_buffer.set(0, epsilon); - emscripten::val add_constant = - model_builder.GetBuilder().call("constant", add_constant_desc, add_buffer); + emscripten::val add_constant = model_builder.CreateOrGetConstant(input_type, epsilon); common_options.set("label", node.Name() + "_add"); emscripten::val add = model_builder.GetBuilder().call("add", reduce_mean, add_constant, common_options); diff --git a/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc index 88fb79b146cd9..ca15e123d0999 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc @@ -100,7 +100,10 @@ Status QDQOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, // zero_point has the same shape as the scale tensor. zero_point_shape = GetVecUint32FromVecInt64(scale_shape); } - zero_point = model_builder.GetZeroConstant(zero_point_type, zero_point_shape); + // Create a zero constant with the same shape as the scale tensor. 
+ // The zero value has been pre-processed in the CreateOrGetConstant function, + // so the type of T is not relevant here. + zero_point = model_builder.CreateOrGetConstant(zero_point_type, 0, zero_point_shape); } emscripten::val options = emscripten::val::object(); diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index 8a82fce42189d..e8f116d390199 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -14,7 +14,6 @@ #include "core/providers/common.h" #include "core/providers/shared/utils/utils.h" -#include #include namespace onnxruntime { @@ -385,73 +384,6 @@ void ModelBuilder::AddOperand(const std::string& name, const emscripten::val& op wnn_operands_.insert(std::make_pair(name, operand)); } -// Get the zero constant with shape. -const emscripten::val& ModelBuilder::GetZeroConstant(const int32_t& data_type, - const std::vector& shape) { - std::string name = "webnn_zero_constant_" + std::to_string(data_type); - emscripten::val dims = emscripten::val::array(); - if (!shape.empty()) { - dims = emscripten::val::array(shape); - std::ostringstream name_stream; - name_stream << name; - for (const auto& dim : shape) { - name_stream << "_" << dim; - } - name = name_stream.str(); - } - // If the operand does not exist, create it. 
- if (wnn_operands_.find(name) == wnn_operands_.end()) { - emscripten::val desc = emscripten::val::object(); - desc.set("dimensions", dims); - desc.set("shape", dims); - emscripten::val zero_buffer = emscripten::val::undefined(); - if (!SetWebnnDataType(desc, data_type)) { - ORT_THROW("Unsupported data type: " + std::to_string(data_type)); - } - auto num_elements = Product(shape); - switch (data_type) { - case ONNX_NAMESPACE::TensorProto_DataType_INT4: - case ONNX_NAMESPACE::TensorProto_DataType_UINT4: - // For WebNN int4 and uint4 tensors are stored in Uint8Array, - // so we need to adjust the number of elements. - num_elements = (num_elements + 1) / 2; - zero_buffer = emscripten::val::global("Uint8Array").new_(num_elements); - break; - case ONNX_NAMESPACE::TensorProto_DataType_BOOL: - case ONNX_NAMESPACE::TensorProto_DataType_UINT8: - zero_buffer = emscripten::val::global("Uint8Array").new_(num_elements); - break; - case ONNX_NAMESPACE::TensorProto_DataType_INT8: - zero_buffer = emscripten::val::global("Int8Array").new_(num_elements); - break; - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: - zero_buffer = emscripten::val::global("Uint16Array").new_(num_elements); - break; - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: - zero_buffer = emscripten::val::global("Float32Array").new_(num_elements); - break; - case ONNX_NAMESPACE::TensorProto_DataType_INT32: - zero_buffer = emscripten::val::global("Int32Array").new_(num_elements); - break; - case ONNX_NAMESPACE::TensorProto_DataType_INT64: - zero_buffer = emscripten::val::global("BigInt64Array").new_(num_elements); - break; - case ONNX_NAMESPACE::TensorProto_DataType_UINT32: - zero_buffer = emscripten::val::global("Uint32Array").new_(num_elements); - break; - case ONNX_NAMESPACE::TensorProto_DataType_UINT64: - zero_buffer = emscripten::val::global("BigUint64Array").new_(num_elements); - break; - default: - break; - } - - emscripten::val zero_constant = wnn_builder_.call("constant", desc, zero_buffer); - 
wnn_operands_.insert(std::make_pair(name, zero_constant)); - } - return wnn_operands_.at(name); -} - void ModelBuilder::AddInitializerToSkip(const std::string& tensor_name) { skipped_initializers_.insert(tensor_name); } diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.h b/onnxruntime/core/providers/webnn/builders/model_builder.h index c482e9d05b301..0fc2fa20670c7 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.h +++ b/onnxruntime/core/providers/webnn/builders/model_builder.h @@ -11,6 +11,7 @@ #include "core/framework/execution_provider.h" #include "core/providers/webnn/builders/helper.h" +#include #include #include @@ -38,11 +39,10 @@ class ModelBuilder { const emscripten::val& GetOpSupportLimits() const { return wnn_limits_; } void AddOperand(const std::string& name, const emscripten::val& operand); - const emscripten::val& GetZeroConstant( - const int32_t& data_type, const std::vector& shape = {}); template - const emscripten::val& CreateOrGetScalarConstant(const int32_t& data_type, T value); + const emscripten::val& CreateOrGetConstant(const int32_t& data_type, T value, + const std::vector& shape = {}); // Use the buffers to persist WebNN allocated data like transposed weight. // It ensures the validity during inference session. @@ -103,11 +103,12 @@ class ModelBuilder { static const IOpBuilder* GetOpBuilder(const Node& node); }; -// Create a scalar constant MLOperand of the specified value and data type. -// Workaround for builer.constant(type, value) method since it has not been implemented now. +// Create or retrieve one of the following: +// - A WebNN constant MLOperand filled with the specified value, data type, and shape. +// - A WebNN scalar constant MLOperand with the specified value and data type. +// For scalar constant, it is a workaround for the builder.constant(type, value) method since +// it has not been implemented yet.
// https://webmachinelearning.github.io/webnn/#api-mlgraphbuilder-constant-type-value -// BTW, the spec is discussing if the builder.constant(type, value) should be dropped at -// https://github.com/webmachinelearning/webnn/issues/475. Fix me according to the spec decision. // // This function enforces a mapping between the data_type and the value types: // - TensorProto_DataType_INT4 <-> int8_t @@ -122,69 +123,96 @@ class ModelBuilder { // - TensorProto_DataType_UINT32 <-> uint32_t // - TensorProto_DataType_UINT64 <-> uint64_t template -const emscripten::val& ModelBuilder::CreateOrGetScalarConstant(const int32_t& data_type, T value) { - std::string name = "webnn_scalar_constant_" + std::to_string(data_type) + "_" + std::to_string(value); - emscripten::val desc = emscripten::val::object(); - desc.set("shape", emscripten::val::array()); - emscripten::val scalar_buffer = emscripten::val::undefined(); - uint16_t value_uint16 = 0; - uint8_t value_uint8 = 0; - if (!SetWebnnDataType(desc, data_type)) { - ORT_THROW("Unsupported data type: " + std::to_string(data_type)); +const emscripten::val& ModelBuilder::CreateOrGetConstant(const int32_t& data_type, T value, + const std::vector& shape) { + std::string name = "webnn_constant_" + std::to_string(data_type) + "_" + std::to_string(value); + emscripten::val dims = emscripten::val::array(); + if (!shape.empty()) { + dims = emscripten::val::array(shape); + std::ostringstream name_stream; + name_stream << name; + for (const auto& dim : shape) { + name_stream << "_" << dim; + } + name = name_stream.str(); } // If the operand does not exist, create it. 
if (wnn_operands_.find(name) == wnn_operands_.end()) { + emscripten::val desc = emscripten::val::object(); + desc.set("shape", dims); + desc.set("dimensions", dims); + emscripten::val buffer = emscripten::val::undefined(); + if (!SetWebnnDataType(desc, data_type)) { + ORT_THROW("Unsupported data type: " + std::to_string(data_type)); + } + auto num_elements = Product(shape); switch (data_type) { case ONNX_NAMESPACE::TensorProto_DataType_INT4: case ONNX_NAMESPACE::TensorProto_DataType_UINT4: - scalar_buffer = emscripten::val::global("Uint8Array").new_(1); - value_uint8 = PackInt8ToUint8AsNibble(value, data_type); - scalar_buffer.call("fill", emscripten::val(value_uint8)); + // For WebNN int4 and uint4 tensors are stored in Uint8Array, + // so we need to adjust the number of elements. + num_elements = (num_elements + 1) / 2; + buffer = emscripten::val::global("Uint8Array").new_(num_elements); + if (value) { + buffer.call("fill", emscripten::val(PackInt8ToUint8AsNibble(value, data_type))); + } break; case ONNX_NAMESPACE::TensorProto_DataType_BOOL: - scalar_buffer = emscripten::val::global("Uint8Array").new_(1); - scalar_buffer.call("fill", emscripten::val(value ? 
1 : 0)); - break; case ONNX_NAMESPACE::TensorProto_DataType_UINT8: - scalar_buffer = emscripten::val::global("Uint8Array").new_(1); - scalar_buffer.call("fill", emscripten::val(value)); + buffer = emscripten::val::global("Uint8Array").new_(num_elements); + if (value) { + buffer.call("fill", emscripten::val(value)); + } break; case ONNX_NAMESPACE::TensorProto_DataType_INT8: - scalar_buffer = emscripten::val::global("Int8Array").new_(1); - scalar_buffer.call("fill", emscripten::val(value)); + buffer = emscripten::val::global("Int8Array").new_(num_elements); + if (value) { + buffer.call("fill", emscripten::val(value)); + } break; case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: - scalar_buffer = emscripten::val::global("Uint16Array").new_(1); - value_uint16 = PackFloat32ToUint16AsFloat16(value); - scalar_buffer.call("fill", emscripten::val(value_uint16)); + buffer = emscripten::val::global("Uint16Array").new_(num_elements); + if (value) { + buffer.call("fill", emscripten::val(PackFloat32ToUint16AsFloat16(value))); + } break; case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: - scalar_buffer = emscripten::val::global("Float32Array").new_(1); - scalar_buffer.call("fill", emscripten::val(value)); + buffer = emscripten::val::global("Float32Array").new_(num_elements); + if (value) { + buffer.call("fill", emscripten::val(value)); + } break; case ONNX_NAMESPACE::TensorProto_DataType_INT32: - scalar_buffer = emscripten::val::global("Int32Array").new_(1); - scalar_buffer.call("fill", emscripten::val(value)); + buffer = emscripten::val::global("Int32Array").new_(num_elements); + if (value) { + buffer.call("fill", emscripten::val(value)); + } break; case ONNX_NAMESPACE::TensorProto_DataType_UINT32: - scalar_buffer = emscripten::val::global("Uint32Array").new_(1); - scalar_buffer.call("fill", emscripten::val(value)); + buffer = emscripten::val::global("Uint32Array").new_(num_elements); + if (value) { + buffer.call("fill", emscripten::val(value)); + } break; case 
ONNX_NAMESPACE::TensorProto_DataType_INT64: - scalar_buffer = emscripten::val::global("BigInt64Array").new_(1); - scalar_buffer.call("fill", emscripten::val::global("BigInt")(value)); + buffer = emscripten::val::global("BigInt64Array").new_(num_elements); + if (value) { + buffer.call("fill", emscripten::val::global("BigInt")(value)); + } break; case ONNX_NAMESPACE::TensorProto_DataType_UINT64: - scalar_buffer = emscripten::val::global("BigUint64Array").new_(1); - scalar_buffer.call("fill", emscripten::val::global("BigInt")(value)); + buffer = emscripten::val::global("BigUint64Array").new_(num_elements); + if (value) { + buffer.call("fill", emscripten::val::global("BigInt")(value)); + } break; default: break; } - const emscripten::val scalar_constant = wnn_builder_.call("constant", desc, scalar_buffer); - wnn_operands_.insert(std::make_pair(name, scalar_constant)); + const emscripten::val constant = wnn_builder_.call("constant", desc, buffer); + wnn_operands_.insert(std::make_pair(name, constant)); } return wnn_operands_.at(name);