Skip to content

Commit

Permalink
[WebNN] Support steps >= 1 for slice operator (microsoft#22708)
Browse files Browse the repository at this point in the history
Co-authored-by: Wanming Lin <wanming.lin@intel.com>
  • Loading branch information
2 people authored and ankitm3k committed Dec 11, 2024
1 parent c3f5d6e commit 46f1d2a
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 17 deletions.
2 changes: 1 addition & 1 deletion js/web/docs/webnn-operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim
| Softplus | ai.onnx(7+) | softplus ||| |
| Softsign | ai.onnx(7+) | softsign ||| |
| Sin | ai.onnx(7+) | sin ||| |
| Slice | ai.onnx(7-9, 10, 11-12, 13+) | slice ||| Input 'starts', 'ends', 'axes', and 'steps' if present must be a constant, only supports 'steps' value 1 |
| Slice | ai.onnx(7-9, 10, 11-12, 13+) | slice ||| Input 'starts', 'ends', 'axes', and 'steps' if present must be a constant, only supports 'steps' value >= 1 |
| Softmax | ai.onnx(7-10, 11-12, 13+) | softmax ||| |
| Split | ai.onnx(7-10, 11-12, 13-17, 18+) | split ||| Input 'split' if present should be a constant |
| Sqrt | ai.onnx(7-12, 13+) | sqrt ||| |
Expand Down
16 changes: 8 additions & 8 deletions js/web/test/suite-test-list.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -2362,14 +2362,14 @@
// "test_sinh",
// // "test_size_example",
// // "test_size",
// "test_slice_default_axes",
// "test_slice_default_steps",
// "test_slice_end_out_of_bounds",
// "test_slice_neg_steps",
// "test_slice_neg",
// "test_slice_negative_axes",
// "test_slice_start_out_of_bounds",
// "test_slice",
"test_slice_default_axes",
"test_slice_default_steps",
"test_slice_end_out_of_bounds",
"test_slice_neg_steps",
"test_slice_neg",
"test_slice_negative_axes",
"test_slice_start_out_of_bounds",
"test_slice",
// "test_softmax_axis_0_expanded",
"test_softmax_axis_0",
// "test_softmax_axis_1_expanded",
Expand Down
21 changes: 13 additions & 8 deletions onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
emscripten::val inputs = model_builder.GetOperand(input_defs[0]->Name());
std::vector<int32_t> starts(rank);
std::vector<int32_t> sizes(rank);
std::vector<int32_t> steps(rank);

// Copy the data from the starts/ends/axes/steps initializers.
std::vector<int64_t> input_starts;
Expand Down Expand Up @@ -94,8 +95,11 @@ Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
std::transform(compute_metadata.ends_.cbegin(), compute_metadata.ends_.cend(), compute_metadata.starts_.cbegin(),
sizes.begin(),
[](int64_t i, int64_t j) { return SafeInt<uint32_t>(i - j); });
std::transform(compute_metadata.steps_.cbegin(), compute_metadata.steps_.cend(), steps.begin(),
[](int64_t i) { return SafeInt<uint32_t>(i); });

emscripten::val options = emscripten::val::object();
options.set("strides", emscripten::val::array(steps));
options.set("label", node.Name());
emscripten::val output = model_builder.GetBuilder().call<emscripten::val>("slice", inputs,
emscripten::val::array(starts),
Expand Down Expand Up @@ -144,18 +148,19 @@ bool SliceOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
return false;
}
const auto data_type = steps_tensor.data_type();
// WebNN doesn't support steps other than 1.
// WebNN doesn't support steps less than 1.
if (data_type == ONNX_NAMESPACE::TensorProto_DataType_INT64) {
if (!std::all_of(reinterpret_cast<int64_t*>(unpacked_tensor.data()),
reinterpret_cast<int64_t*>(unpacked_tensor.data() + unpacked_tensor.size()),
[](int64_t i) { return i == 1; })) {
if (std::any_of(reinterpret_cast<int64_t*>(unpacked_tensor.data()),
reinterpret_cast<int64_t*>(unpacked_tensor.data() + unpacked_tensor.size()),
[](int64_t i) { return i < 1; })) {
LOGS(logger, VERBOSE) << "WebNN slice doesn't support steps less than 1";
return false;
}
} else if (data_type == ONNX_NAMESPACE::TensorProto_DataType_INT32) {
if (!std::all_of(reinterpret_cast<int32_t*>(unpacked_tensor.data()),
reinterpret_cast<int32_t*>(unpacked_tensor.data()) +
unpacked_tensor.size() / sizeof(int32_t),
[](int32_t i) { return i == 1; })) {
if (std::any_of(reinterpret_cast<int32_t*>(unpacked_tensor.data()),
reinterpret_cast<int32_t*>(unpacked_tensor.data()) + unpacked_tensor.size() / sizeof(int32_t),
[](int32_t i) { return i < 1; })) {
LOGS(logger, VERBOSE) << "WebNN slice doesn't support steps less than 1";
return false;
}
}
Expand Down

0 comments on commit 46f1d2a

Please sign in to comment.