Skip to content

Commit

Permalink
Merge pull request #3 from BruceDai/support_cast
Browse files Browse the repository at this point in the history
Support Cast op
  • Loading branch information
Honry authored May 10, 2023
2 parents 208565c + 7698b8c commit 4becf61
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 0 deletions.
1 change: 1 addition & 0 deletions onnxruntime/core/providers/webnn/builders/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ static const InlinedHashMap<std::string, std::string> op_map = {
{"LeakyRelu", "leakyRelu"},
{"Sigmoid", "sigmoid"},
{"Softmax", "softmax"},
{"Cast", "cast"},
{"Clip", "clamp"},
{"Conv", "conv2d"},
{"ConvTranspose", "convTranspose2d"},
Expand Down
78 changes: 78 additions & 0 deletions onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Copyright (c) Intel Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/common.h"
#include "core/providers/webnn/builders/helper.h"
#include "core/providers/webnn/builders/model_builder.h"
#include "core/providers/webnn/builders/op_builder_factory.h"
#include "core/providers/shared/utils/utils.h"

#include "base_op_builder.h"

namespace onnxruntime {
namespace webnn {

class CastOpBuilder : public BaseOpBuilder {
// Add operator related.
private:
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;

// Operator support related.
private:
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
const logging::Logger& logger) const override;

int GetMinSupportedOpSet(const Node& node) const override;
};

// Add operator related.

Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
const Node& node,
const logging::Logger& logger) const {
const auto& input_name = node.InputDefs()[0]->Name();
emscripten::val input = model_builder.GetOperand(input_name);

NodeAttrHelper helper(node);
// We already checked the "to" type in IsOpSupportedImpl.
const auto to_type = helper.Get("to", ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
std::string operand_type =
to_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT ? "float32" : "float16";

emscripten::val output =
model_builder.GetBuilder().call<emscripten::val>("cast", input, emscripten::val(operand_type));

model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));
return Status::OK();
}

// Operator support related.

int CastOpBuilder::GetMinSupportedOpSet(const Node& /* node */) const {
// Since opset 6, Cast uses attribute "to" as int type.
return 6;
}

bool CastOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
const logging::Logger& logger) const {
NodeAttrHelper helper(node);
// Check cast output type.
const auto to_type = helper.Get("to", ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED);
if (!IsSupportedDataType(to_type)) {
LOGS(logger, VERBOSE) << "Invalid cast to type " << to_type
<< " . Current WebNN only support cast to float32 or float16.";
return false;
}

return true;
}

void CreateCastOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
op_registrations.builders.push_back(std::make_unique<CastOpBuilder>());
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
}

} // namespace webnn
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
CreateActivationOpBuilder("Sigmoid", op_registrations);
}

{ // Cast
CreateCastOpBuilder("Cast", op_registrations);
}

{ // Clip
CreateClipOpBuilder("Clip", op_registrations);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ const InlinedHashMap<std::string, const IOpBuilder*>& GetOpBuilders();

void CreateActivationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateBinaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateCastOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateClipOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateConvOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateConcatOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
Expand Down

0 comments on commit 4becf61

Please sign in to comment.