Skip to content

Commit

Permalink
[WebNN] Improve the util function of creating WebNN constant MLOperand (
Browse files Browse the repository at this point in the history
#22935)

Merge the util functions to create or retrieve:
- A WebNN constant MLOperand filled with the specified value, data type,
and shape.
- A WebNN scalar constant MLOperand with the specified value and data
type.
  • Loading branch information
Honry authored Dec 4, 2024
1 parent fbe22fd commit cacd97d
Show file tree
Hide file tree
Showing 8 changed files with 89 additions and 145 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -311,12 +311,12 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
if (input_defs.size() >= 3) {
x_zero_point = model_builder.GetOperand(node.InputDefs()[2]->Name());
} else {
x_zero_point = model_builder.GetZeroConstant(ONNX_NAMESPACE::TensorProto_DataType_UINT8);
x_zero_point = model_builder.CreateOrGetConstant<uint8_t>(ONNX_NAMESPACE::TensorProto_DataType_UINT8, 0);
}
if (input_defs.size() >= 4) {
w_zero_point = model_builder.GetOperand(node.InputDefs()[3]->Name());
} else {
w_zero_point = model_builder.GetZeroConstant(ONNX_NAMESPACE::TensorProto_DataType_UINT8);
w_zero_point = model_builder.CreateOrGetConstant<uint8_t>(ONNX_NAMESPACE::TensorProto_DataType_UINT8, 0);
}
output = model_builder.GetBuilder().call<emscripten::val>("conv2dInteger",
input, x_zero_point, filter, w_zero_point, options);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,22 +59,14 @@ Status DropoutOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
std::vector<int64_t> mask_shape;
ORT_RETURN_IF_NOT(GetShape(*output_defs[1], mask_shape, logger), "Cannot get mask output's shape");
std::vector<uint32_t> dims = GetVecUint32FromVecInt64(mask_shape);

emscripten::val desc = emscripten::val::object();
desc.set("dataType", "uint8");
desc.set("dimensions", emscripten::val::array(dims));
desc.set("shape", emscripten::val::array(dims));
const auto num_elements = narrow<uint32_t>(Product(mask_shape));
emscripten::val ones_buffer = emscripten::val::global("Uint8Array").new_(num_elements);
ones_buffer.call<void>("fill", 1);

emscripten::val mask_output = model_builder.GetBuilder().call<emscripten::val>("constant", desc, ones_buffer);
emscripten::val one_constant = model_builder.CreateOrGetConstant<uint8_t>(
ONNX_NAMESPACE::TensorProto_DataType_BOOL, 1, dims);

emscripten::val options = emscripten::val::object();
options.set("label", output_defs[1]->Name() + "_identity");
// Add an additional identity op in case the mask is the output of a WebNN graph,
// because WebNN does not support a constant operand as output.
mask_output = model_builder.GetBuilder().call<emscripten::val>("identity", mask_output, options);
emscripten::val mask_output = model_builder.GetBuilder().call<emscripten::val>("identity", one_constant, options);
model_builder.AddOperand(output_defs[1]->Name(), std::move(mask_output));
}
return Status::OK();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,12 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
if (input_defs.size() >= 3) {
a_zero_point = model_builder.GetOperand(node.InputDefs()[2]->Name());
} else {
a_zero_point = model_builder.GetZeroConstant(ONNX_NAMESPACE::TensorProto_DataType_UINT8);
a_zero_point = model_builder.CreateOrGetConstant<uint8_t>(ONNX_NAMESPACE::TensorProto_DataType_UINT8, 0);
}
if (input_defs.size() >= 4) {
b_zero_point = model_builder.GetOperand(node.InputDefs()[3]->Name());
} else {
b_zero_point = model_builder.GetZeroConstant(ONNX_NAMESPACE::TensorProto_DataType_UINT8);
b_zero_point = model_builder.CreateOrGetConstant<uint8_t>(ONNX_NAMESPACE::TensorProto_DataType_UINT8, 0);
}
output = model_builder.GetBuilder().call<emscripten::val>("matmulInteger",
a,
Expand Down
11 changes: 6 additions & 5 deletions onnxruntime/core/providers/webnn/builders/impl/lrn_op_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ Status LRNOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
const Node& node,
const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto input_data_type = input_defs[0]->TypeAsProto()->tensor_type().elem_type();
int32_t input_data_type;
ORT_RETURN_IF_NOT(GetType(*input_defs[0], input_data_type, logger), "Cannot get input type");
emscripten::val input = model_builder.GetOperand(input_defs[0]->Name());
const auto node_name = node.Name();
emscripten::val wnn_builder = model_builder.GetBuilder();
Expand All @@ -42,10 +43,10 @@ Status LRNOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,

// Prepare WebNN constants for alpha, beta, bias attributes.
// Assume T is float, because input_data_type has been limited to float32 and float16 in 'hasSupportedInitsImpl'.
emscripten::val alpha_constant = model_builder.CreateOrGetScalarConstant<float>(input_data_type, alpha);
emscripten::val beta_constant = model_builder.CreateOrGetScalarConstant<float>(input_data_type, beta);
emscripten::val bias_constant = model_builder.CreateOrGetScalarConstant<float>(input_data_type, bias);
emscripten::val pow1_constant = model_builder.CreateOrGetScalarConstant<float>(input_data_type, 2);
emscripten::val alpha_constant = model_builder.CreateOrGetConstant<float>(input_data_type, alpha);
emscripten::val beta_constant = model_builder.CreateOrGetConstant<float>(input_data_type, beta);
emscripten::val bias_constant = model_builder.CreateOrGetConstant<float>(input_data_type, bias);
emscripten::val pow1_constant = model_builder.CreateOrGetConstant<float>(input_data_type, 2);

/**
WebNN doesn't support LRN. So decompose it into a series of ops:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,21 +100,15 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
X --> Pow --> ReduceMean --> Add --> Sqrt --> Div -> Mul
^ ^ ^ ^ ^
| | | | |
Y:2 axis B:epsilon A:X A:scale
Y:2 axis B:epsilon A:X A:scale
*/

int32_t input_type;
ORT_RETURN_IF_NOT(GetType(*input_defs[0], input_type, logger), "Cannot get input type");
emscripten::val common_options = emscripten::val::object();

// Pow
emscripten::val pow_constant_desc = emscripten::val::object();
ORT_RETURN_IF_NOT(SetWebnnDataType(pow_constant_desc, input_type), "Unsupported data type");
pow_constant_desc.set("shape", emscripten::val::array());
emscripten::val pow_buffer = emscripten::val::global("Float32Array").new_(1);
pow_buffer.set(0, 2);
emscripten::val pow_constant =
model_builder.GetBuilder().call<emscripten::val>("constant", pow_constant_desc, pow_buffer);
emscripten::val pow_constant = model_builder.CreateOrGetConstant<float>(input_type, 2);
common_options.set("label", node.Name() + "_pow");
emscripten::val pow =
model_builder.GetBuilder().call<emscripten::val>("pow", input, pow_constant, common_options);
Expand All @@ -127,13 +121,7 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
emscripten::val reduce_mean = model_builder.GetBuilder().call<emscripten::val>("reduceMean", pow, reduce_options);

// Add
emscripten::val add_constant_desc = emscripten::val::object();
ORT_RETURN_IF_NOT(SetWebnnDataType(add_constant_desc, input_type), "Unsupported data type");
add_constant_desc.set("shape", emscripten::val::array());
emscripten::val add_buffer = emscripten::val::global("Float32Array").new_(1);
add_buffer.set(0, epsilon);
emscripten::val add_constant =
model_builder.GetBuilder().call<emscripten::val>("constant", add_constant_desc, add_buffer);
emscripten::val add_constant = model_builder.CreateOrGetConstant<float>(input_type, epsilon);
common_options.set("label", node.Name() + "_add");
emscripten::val add =
model_builder.GetBuilder().call<emscripten::val>("add", reduce_mean, add_constant, common_options);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,10 @@ Status QDQOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
// zero_point has the same shape as the scale tensor.
zero_point_shape = GetVecUint32FromVecInt64(scale_shape);
}
zero_point = model_builder.GetZeroConstant(zero_point_type, zero_point_shape);
// Create a zero constant with the same shape as the scale tensor.
// The zero value has been pre-processed in the CreateOrGetConstant function,
// so the type of T is not relevant here.
zero_point = model_builder.CreateOrGetConstant<uint8_t>(zero_point_type, 0, zero_point_shape);
}

emscripten::val options = emscripten::val::object();
Expand Down
68 changes: 0 additions & 68 deletions onnxruntime/core/providers/webnn/builders/model_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
#include "core/providers/common.h"
#include "core/providers/shared/utils/utils.h"

#include <sstream>
#include <utility>

namespace onnxruntime {
Expand Down Expand Up @@ -385,73 +384,6 @@ void ModelBuilder::AddOperand(const std::string& name, const emscripten::val& op
wnn_operands_.insert(std::make_pair(name, operand));
}

// Get (or lazily create) a zero-filled WebNN constant MLOperand with the given
// ONNX data type and shape. Created operands are cached in wnn_operands_ under
// a name derived from the data type and shape, so repeated requests for the
// same (type, shape) pair return the same cached operand.
// @param data_type  ONNX TensorProto_DataType enum value for the constant.
// @param shape      Tensor dimensions; an empty shape produces a scalar.
// @return           Reference to the cached constant operand.
// @throws           Via ORT_THROW when data_type is not a supported WebNN type.
const emscripten::val& ModelBuilder::GetZeroConstant(const int32_t& data_type,
                                                     const std::vector<uint32_t>& shape) {
  // Cache key: "webnn_zero_constant_<type>[_<dim>...]"; scalars (empty shape)
  // omit the dimension suffix.
  std::string name = "webnn_zero_constant_" + std::to_string(data_type);
  emscripten::val dims = emscripten::val::array();
  if (!shape.empty()) {
    dims = emscripten::val::array(shape);
    std::ostringstream name_stream;
    name_stream << name;
    for (const auto& dim : shape) {
      name_stream << "_" << dim;
    }
    name = name_stream.str();
  }
  // If the operand does not exist, create it.
  if (wnn_operands_.find(name) == wnn_operands_.end()) {
    emscripten::val desc = emscripten::val::object();
    // NOTE(review): both "dimensions" and "shape" are set on the descriptor —
    // presumably for compatibility across WebNN implementations that accept
    // either key; confirm against the targeted browser versions.
    desc.set("dimensions", dims);
    desc.set("shape", dims);
    emscripten::val zero_buffer = emscripten::val::undefined();
    if (!SetWebnnDataType(desc, data_type)) {
      ORT_THROW("Unsupported data type: " + std::to_string(data_type));
    }
    auto num_elements = Product(shape);
    // Allocate a typed array of the matching JS element type. JavaScript typed
    // arrays are zero-initialized on construction, so no explicit fill is needed.
    switch (data_type) {
      case ONNX_NAMESPACE::TensorProto_DataType_INT4:
      case ONNX_NAMESPACE::TensorProto_DataType_UINT4:
        // For WebNN int4 and uint4 tensors are stored in Uint8Array,
        // so we need to adjust the number of elements.
        num_elements = (num_elements + 1) / 2;
        zero_buffer = emscripten::val::global("Uint8Array").new_(num_elements);
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
      case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
        zero_buffer = emscripten::val::global("Uint8Array").new_(num_elements);
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_INT8:
        zero_buffer = emscripten::val::global("Int8Array").new_(num_elements);
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
        // float16 has no native JS typed array; raw 16-bit storage is used.
        zero_buffer = emscripten::val::global("Uint16Array").new_(num_elements);
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
        zero_buffer = emscripten::val::global("Float32Array").new_(num_elements);
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_INT32:
        zero_buffer = emscripten::val::global("Int32Array").new_(num_elements);
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_INT64:
        zero_buffer = emscripten::val::global("BigInt64Array").new_(num_elements);
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_UINT32:
        zero_buffer = emscripten::val::global("Uint32Array").new_(num_elements);
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_UINT64:
        zero_buffer = emscripten::val::global("BigUint64Array").new_(num_elements);
        break;
      default:
        // Unreachable in practice: SetWebnnDataType above rejects any type not
        // handled here. zero_buffer would remain undefined if ever reached.
        break;
    }

    emscripten::val zero_constant = wnn_builder_.call<emscripten::val>("constant", desc, zero_buffer);
    wnn_operands_.insert(std::make_pair(name, zero_constant));
  }
  return wnn_operands_.at(name);
}

void ModelBuilder::AddInitializerToSkip(const std::string& tensor_name) {
skipped_initializers_.insert(tensor_name);
}
Expand Down
110 changes: 69 additions & 41 deletions onnxruntime/core/providers/webnn/builders/model_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "core/framework/execution_provider.h"
#include "core/providers/webnn/builders/helper.h"

#include <sstream>
#include <emscripten.h>
#include <emscripten/val.h>

Expand Down Expand Up @@ -38,11 +39,10 @@ class ModelBuilder {
const emscripten::val& GetOpSupportLimits() const { return wnn_limits_; }

void AddOperand(const std::string& name, const emscripten::val& operand);
const emscripten::val& GetZeroConstant(
const int32_t& data_type, const std::vector<uint32_t>& shape = {});

template <typename T>
const emscripten::val& CreateOrGetScalarConstant(const int32_t& data_type, T value);
const emscripten::val& CreateOrGetConstant(const int32_t& data_type, T value,
const std::vector<uint32_t>& shape = {});

// Use the buffers to persist WebNN allocated data like transposed weight.
// It ensures the validity during inference session.
Expand Down Expand Up @@ -103,11 +103,12 @@ class ModelBuilder {
static const IOpBuilder* GetOpBuilder(const Node& node);
};

// Create a scalar constant MLOperand of the specified value and data type.
// Workaround for builer.constant(type, value) method since it has not been implemented now.
// Create or retrieve one of the following:
// - A WebNN constant MLOperand filled with the specified value, data type, and shape.
// - A WebNN scalar constant MLOperand with the specified value and data type.
// For a scalar constant, this is a workaround for the builder.constant(type, value)
// method, since it has not been implemented yet.
// https://webmachinelearning.github.io/webnn/#api-mlgraphbuilder-constant-type-value
// BTW, the spec is discussing if the builder.constant(type, value) should be dropped at
// https://github.com/webmachinelearning/webnn/issues/475. Fix me according to the spec decision.
//
// This function enforces a mapping between the data_type and the value types:
// - TensorProto_DataType_INT4 <-> int8_t
Expand All @@ -122,69 +123,96 @@ class ModelBuilder {
// - TensorProto_DataType_UINT32 <-> uint32_t
// - TensorProto_DataType_UINT64 <-> uint64_t
template <typename T>
const emscripten::val& ModelBuilder::CreateOrGetScalarConstant(const int32_t& data_type, T value) {
std::string name = "webnn_scalar_constant_" + std::to_string(data_type) + "_" + std::to_string(value);
emscripten::val desc = emscripten::val::object();
desc.set("shape", emscripten::val::array());
emscripten::val scalar_buffer = emscripten::val::undefined();
uint16_t value_uint16 = 0;
uint8_t value_uint8 = 0;
if (!SetWebnnDataType(desc, data_type)) {
ORT_THROW("Unsupported data type: " + std::to_string(data_type));
const emscripten::val& ModelBuilder::CreateOrGetConstant(const int32_t& data_type, T value,
const std::vector<uint32_t>& shape) {
std::string name = "webnn_constant_" + std::to_string(data_type) + "_" + std::to_string(value);
emscripten::val dims = emscripten::val::array();
if (!shape.empty()) {
dims = emscripten::val::array(shape);
std::ostringstream name_stream;
name_stream << name;
for (const auto& dim : shape) {
name_stream << "_" << dim;
}
name = name_stream.str();
}

// If the operand does not exist, create it.
if (wnn_operands_.find(name) == wnn_operands_.end()) {
emscripten::val desc = emscripten::val::object();
desc.set("shape", dims);
desc.set("dimensions", dims);
emscripten::val buffer = emscripten::val::undefined();
if (!SetWebnnDataType(desc, data_type)) {
ORT_THROW("Unsupported data type: " + std::to_string(data_type));
}
auto num_elements = Product(shape);
switch (data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_INT4:
case ONNX_NAMESPACE::TensorProto_DataType_UINT4:
scalar_buffer = emscripten::val::global("Uint8Array").new_(1);
value_uint8 = PackInt8ToUint8AsNibble(value, data_type);
scalar_buffer.call<void>("fill", emscripten::val(value_uint8));
// For WebNN int4 and uint4 tensors are stored in Uint8Array,
// so we need to adjust the number of elements.
num_elements = (num_elements + 1) / 2;
buffer = emscripten::val::global("Uint8Array").new_(num_elements);
if (value) {
buffer.call<void>("fill", emscripten::val(PackInt8ToUint8AsNibble(value, data_type)));
}
break;
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
scalar_buffer = emscripten::val::global("Uint8Array").new_(1);
scalar_buffer.call<void>("fill", emscripten::val(value ? 1 : 0));
break;
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
scalar_buffer = emscripten::val::global("Uint8Array").new_(1);
scalar_buffer.call<void>("fill", emscripten::val(value));
buffer = emscripten::val::global("Uint8Array").new_(num_elements);
if (value) {
buffer.call<void>("fill", emscripten::val(value));
}
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
scalar_buffer = emscripten::val::global("Int8Array").new_(1);
scalar_buffer.call<void>("fill", emscripten::val(value));
buffer = emscripten::val::global("Int8Array").new_(num_elements);
if (value) {
buffer.call<void>("fill", emscripten::val(value));
}
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
scalar_buffer = emscripten::val::global("Uint16Array").new_(1);
value_uint16 = PackFloat32ToUint16AsFloat16(value);
scalar_buffer.call<void>("fill", emscripten::val(value_uint16));
buffer = emscripten::val::global("Uint16Array").new_(num_elements);
if (value) {
buffer.call<void>("fill", emscripten::val(PackFloat32ToUint16AsFloat16(value)));
}
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
scalar_buffer = emscripten::val::global("Float32Array").new_(1);
scalar_buffer.call<void>("fill", emscripten::val(value));
buffer = emscripten::val::global("Float32Array").new_(num_elements);
if (value) {
buffer.call<void>("fill", emscripten::val(value));
}
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
scalar_buffer = emscripten::val::global("Int32Array").new_(1);
scalar_buffer.call<void>("fill", emscripten::val(value));
buffer = emscripten::val::global("Int32Array").new_(num_elements);
if (value) {
buffer.call<void>("fill", emscripten::val(value));
}
break;
case ONNX_NAMESPACE::TensorProto_DataType_UINT32:
scalar_buffer = emscripten::val::global("Uint32Array").new_(1);
scalar_buffer.call<void>("fill", emscripten::val(value));
buffer = emscripten::val::global("Uint32Array").new_(num_elements);
if (value) {
buffer.call<void>("fill", emscripten::val(value));
}
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
scalar_buffer = emscripten::val::global("BigInt64Array").new_(1);
scalar_buffer.call<void>("fill", emscripten::val::global("BigInt")(value));
buffer = emscripten::val::global("BigInt64Array").new_(num_elements);
if (value) {
buffer.call<void>("fill", emscripten::val::global("BigInt")(value));
}
break;
case ONNX_NAMESPACE::TensorProto_DataType_UINT64:
scalar_buffer = emscripten::val::global("BigUint64Array").new_(1);
scalar_buffer.call<void>("fill", emscripten::val::global("BigInt")(value));
buffer = emscripten::val::global("BigUint64Array").new_(num_elements);
if (value) {
buffer.call<void>("fill", emscripten::val::global("BigInt")(value));
}
break;
default:
break;
}

const emscripten::val scalar_constant = wnn_builder_.call<emscripten::val>("constant", desc, scalar_buffer);
wnn_operands_.insert(std::make_pair(name, scalar_constant));
const emscripten::val constant = wnn_builder_.call<emscripten::val>("constant", desc, buffer);
wnn_operands_.insert(std::make_pair(name, constant));
}

return wnn_operands_.at(name);
Expand Down

0 comments on commit cacd97d

Please sign in to comment.