Skip to content

Commit

Permalink
[WebNN EP] Support Einsum op (microsoft#19558)
Browse files Browse the repository at this point in the history
Adds support for einsum via WebNN matmul, transpose, reshape, reducesum,
identity and element-wise binary ops.
  • Loading branch information
peishenyan authored and ankitm3k committed Dec 11, 2024
1 parent 5388c4e commit 79ad4e1
Showing 1 changed file with 7 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ class EinsumOpBuilder : public BaseOpBuilder {
// Operator support related.
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
};

// Helper functions, thanks for DML EP's OperatorHelper.
Expand Down Expand Up @@ -735,8 +735,8 @@ bool EinsumOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializ
return true;
}

bool EinsumOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
bool EinsumOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();

const auto& op_type = node.OpType();
Expand Down Expand Up @@ -776,11 +776,11 @@ bool EinsumOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* init
return false;
} else if (recognized_operator_type == RecognizedOperatorType::Pairwise) {
// Map to WebNN's gemm or matmul
return IsDataTypeSupportedByWebNNOp(op_type, "matmul", input0_type, wnn_limits, "a", "inputs", logger);
return IsDataTypeSupportedByOp("MatMul", input0_type, wnn_limits, "a", "inputs", logger);
} else if (recognized_operator_type == RecognizedOperatorType::ReduceSum) {
return IsDataTypeSupportedByWebNNOp(op_type, "reduceSum", input0_type, wnn_limits, "input", "inputs", logger);
return IsDataTypeSupportedByOp("ReduceSum", input0_type, wnn_limits, "input", "inputs", logger);
} else {
return IsDataTypeSupportedByWebNNOp(op_type, "identity", input0_type, wnn_limits, "input", "inputs", logger);
return IsDataTypeSupportedByOp("Identity", input0_type, wnn_limits, "input", "inputs", logger);
}
}

Expand Down

0 comments on commit 79ad4e1

Please sign in to comment.