[ONNX] Frontend refactoring: operations (openvinotoolkit#22044)
* Updated com.microsoft/attention.cpp
* Updated com.microsoft/bias_gelu.cpp
* Updated com.microsoft/embed_layer_normalization.cpp
* Updated com.microsoft/fused_conv.cpp
* Updated com.microsoft/fusedgemm.cpp
* Updated com.microsoft/skip_layer_normalization.cpp
gkrivor authored Jan 10, 2024
1 parent 916fcab commit 0cf87a1
Showing 6 changed files with 276 additions and 240 deletions.
360 changes: 183 additions & 177 deletions src/frontends/onnx/frontend/src/op/com.microsoft/attention.cpp

Large diffs are not rendered by default.
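
The same mechanical pattern runs through all six files, including the large attention.cpp diff that is not shown: the version-agnostic default_opset alias is dropped in favor of explicitly versioned opset namespaces pulled in through dedicated openvino/op headers. A minimal sketch of the pattern, assuming OpenVINO's public C++ API (the helper below is illustrative, not code from this commit):

// Illustrative sketch of the refactoring pattern; make_add is a
// hypothetical helper, not taken from the commit.
#include <memory>

#include "openvino/op/add.hpp"

using namespace ov::op;

// Before: std::make_shared<default_opset::Add>(a, b);
// After: the op class is named through its concrete versioned namespace.
ov::Output<ov::Node> make_add(const ov::Output<ov::Node>& a, const ov::Output<ov::Node>& b) {
    return std::make_shared<v1::Add>(a, b);
}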

src/frontends/onnx/frontend/src/op/com.microsoft/bias_gelu.cpp

@@ -4,8 +4,11 @@

#include "op/com.microsoft/bias_gelu.hpp"

#include "default_opset.hpp"
#include "openvino/frontend/exception.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/gelu.hpp"

using namespace ov::op;

namespace ngraph {
namespace onnx_import {
Expand All @@ -14,7 +17,7 @@ namespace set_1 {
OutputVector bias_gelu(const Node& node) {
auto nodes = node.get_ng_inputs();
FRONT_END_GENERAL_CHECK(nodes.size() == 2, "BiasGelu takes 2 inputs. Provided " + std::to_string(nodes.size()));
return {std::make_shared<default_opset::Gelu>(std::make_shared<default_opset::Add>(nodes.at(0), nodes.at(1)))};
return {std::make_shared<v7::Gelu>(std::make_shared<v1::Add>(nodes.at(0), nodes.at(1)))};
}
} // namespace set_1
} // namespace op
Expand Down
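
The operator itself is unchanged: BiasGelu still lowers to Gelu(Add(x, bias)); only the op classes now come from their versioned namespaces. A self-contained sketch of the same decomposition, assuming the public ov::Model API (make_bias_gelu is a hypothetical helper, not part of the frontend):

// Hedged sketch: builds the BiasGelu decomposition Gelu(x + bias)
// as a standalone ov::Model.
#include <memory>

#include "openvino/core/model.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/gelu.hpp"
#include "openvino/op/parameter.hpp"

using namespace ov::op;

std::shared_ptr<ov::Model> make_bias_gelu(const ov::PartialShape& shape) {
    auto x = std::make_shared<v0::Parameter>(ov::element::f32, shape);
    auto bias = std::make_shared<v0::Parameter>(ov::element::f32, shape);
    // Same graph the converter emits: elementwise Add feeding v7::Gelu.
    auto gelu = std::make_shared<v7::Gelu>(std::make_shared<v1::Add>(x, bias));
    return std::make_shared<ov::Model>(ov::OutputVector{gelu}, ov::ParameterVector{x, bias});
}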
src/frontends/onnx/frontend/src/op/com.microsoft/embed_layer_normalization.cpp

@@ -4,9 +4,19 @@

#include "op/com.microsoft/embed_layer_normalization.hpp"

#include "default_opset.hpp"
#include "onnx_import/core/null_node.hpp"
#include "openvino/frontend/exception.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/mvn.hpp"
#include "openvino/op/reduce_sum.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/slice.hpp"

using namespace ov::op;

namespace ngraph {
namespace onnx_import {
Expand All @@ -28,63 +38,62 @@ OutputVector embed_layer_normalization(const Node& node) {
const auto& gamma = nodes[5];
const auto& beta = nodes[6];

const auto zero = default_opset::Constant::create(element::i32, Shape{1}, {0});
std::shared_ptr<ngraph::Node> input = std::make_shared<default_opset::Gather>(word_embeddings, input_ids, zero, 0);
const auto zero = v0::Constant::create(element::i32, Shape{1}, {0});
std::shared_ptr<ov::Node> input = std::make_shared<v8::Gather>(word_embeddings, input_ids, zero, 0);
// add position embeddings
if (num_nodes > 8 && !ov::op::util::is_null(nodes[8])) {
// if we have position_ids
const auto& position_ids = nodes[8];
const auto gathered_position_embeddings =
std::make_shared<default_opset::Gather>(position_embeddings, position_ids, zero, 0);
input = std::make_shared<default_opset::Add>(input, gathered_position_embeddings);
std::make_shared<v8::Gather>(position_embeddings, position_ids, zero, 0);
input = std::make_shared<v1::Add>(input, gathered_position_embeddings);
} else {
// input_ids' shape is [batchsize, sequence_length]
// input's shape is [batchsize, sequence_length, hidden_size]
// position_embeddings's shape is [max_sequence_length, hidden_size]
// therefore input and position_embeddings cannot be added together
// so we need slice the position_embeddings to [sequence_length, hidden_size] first
// then add it with input.
const auto one = default_opset::Constant::create(element::i32, Shape{1}, {1});
const auto input_ids_shape = std::make_shared<default_opset::ShapeOf>(input_ids, element::i32);
const auto seqlen = std::make_shared<default_opset::Gather>(input_ids_shape, one, zero, 0);
const auto one = v0::Constant::create(element::i32, Shape{1}, {1});
const auto input_ids_shape = std::make_shared<v3::ShapeOf>(input_ids, element::i32);
const auto seqlen = std::make_shared<v8::Gather>(input_ids_shape, one, zero, 0);
const auto gathered_position_embeddings =
std::make_shared<default_opset::Slice>(position_embeddings, zero, seqlen, one, zero);
input = std::make_shared<default_opset::Add>(input, gathered_position_embeddings);
std::make_shared<v8::Slice>(position_embeddings, zero, seqlen, one, zero);
input = std::make_shared<v1::Add>(input, gathered_position_embeddings);
}
// add segment embeddings if available
if (!ov::op::util::is_null(segment_ids)) {
FRONT_END_GENERAL_CHECK(!ov::op::util::is_null(segment_embeddings),
"segment_ids provided, but segment_embedding input is missing");
FRONT_END_GENERAL_CHECK(nodes[1].get_element_type() == element::i32, "segment_ids must have int32 type");
auto gathered_segment_embeddings =
std::make_shared<default_opset::Gather>(segment_embeddings, segment_ids, zero, 0);
input = std::make_shared<default_opset::Add>(input, gathered_segment_embeddings);
auto gathered_segment_embeddings = std::make_shared<v8::Gather>(segment_embeddings, segment_ids, zero, 0);
input = std::make_shared<v1::Add>(input, gathered_segment_embeddings);
}

float eps = node.get_attribute_value<float>("epsilon");
// reduce over hidden_size
// hidden_size dimension is 2 here, because the shape after Gather(word_embedding, input_ids)
// is (batch_size, seq_len, hidden_size)
int hidden_size_dim = 2;
const auto reduction_axes = default_opset::Constant::create(element::i32, Shape{1}, {hidden_size_dim});
std::shared_ptr<ngraph::Node> result =
std::make_shared<default_opset::MVN>(input, reduction_axes, true, eps, ov::op::MVNEpsMode::INSIDE_SQRT);
const auto reduction_axes = v0::Constant::create(element::i32, Shape{1}, {hidden_size_dim});
std::shared_ptr<ov::Node> result =
std::make_shared<v6::MVN>(input, reduction_axes, true, eps, ov::op::MVNEpsMode::INSIDE_SQRT);

// result = gamma * result + beta
result = std::make_shared<default_opset::Multiply>(result, gamma);
result = std::make_shared<default_opset::Add>(result, beta);
result = std::make_shared<v1::Multiply>(result, gamma);
result = std::make_shared<v1::Add>(result, beta);

// compute mask_index output
std::shared_ptr<ngraph::Node> mask_index;
std::shared_ptr<ov::Node> mask_index;
if (num_nodes > 7 && !ov::op::util::is_null(nodes[7])) {
FRONT_END_GENERAL_CHECK(nodes[7].get_element_type() == element::i32, "mask must have int32 type");
auto axis = default_opset::Constant::create(element::i32, Shape{}, {1});
mask_index = std::make_shared<default_opset::ReduceSum>(nodes[7], axis, false);
auto axis = v0::Constant::create(element::i32, Shape{}, {1});
mask_index = std::make_shared<v1::ReduceSum>(nodes[7], axis, false);
} else {
auto batch_size = std::make_shared<default_opset::Gather>(std::make_shared<default_opset::ShapeOf>(nodes[0]),
zero, // indices
zero); // axis
mask_index = std::make_shared<default_opset::Broadcast>(zero, batch_size);
auto batch_size = std::make_shared<v8::Gather>(std::make_shared<v3::ShapeOf>(nodes[0]),
zero, // indices
zero); // axis
mask_index = std::make_shared<v3::Broadcast>(zero, batch_size);
}
return {result, mask_index};
}
Expand Down
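
The layer-normalization tail of this decomposition is worth calling out: v6::MVN normalizes over the hidden_size axis (axis 2 of [batch, seq_len, hidden]) in INSIDE_SQRT epsilon mode, then gamma and beta are applied as Multiply and Add. A minimal sketch of just that tail, assuming the public API (layer_norm_tail is a hypothetical helper, not frontend code):

// Hedged sketch of the MVN-based layer-norm tail used above.
#include <memory>

#include "openvino/op/add.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/mvn.hpp"

using namespace ov::op;

ov::Output<ov::Node> layer_norm_tail(const ov::Output<ov::Node>& input,
                                     const ov::Output<ov::Node>& gamma,
                                     const ov::Output<ov::Node>& beta,
                                     float eps) {
    // Normalize over the hidden_size axis (axis 2 of [batch, seq_len, hidden]).
    const auto axes = v0::Constant::create(ov::element::i32, ov::Shape{1}, {2});
    auto norm = std::make_shared<v6::MVN>(input, axes, true, eps, ov::op::MVNEpsMode::INSIDE_SQRT);
    // result = gamma * norm + beta
    auto scaled = std::make_shared<v1::Multiply>(norm, gamma);
    return std::make_shared<v1::Add>(scaled, beta);
}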
32 changes: 21 additions & 11 deletions src/frontends/onnx/frontend/src/op/com.microsoft/fused_conv.cpp
@@ -7,9 +7,19 @@
#include <memory>
#include <vector>

-#include "default_opset.hpp"
#include "exceptions.hpp"
#include "op/conv.hpp"
+#include "openvino/op/add.hpp"
+#include "openvino/op/clamp.hpp"
+#include "openvino/op/constant.hpp"
+#include "openvino/op/hard_sigmoid.hpp"
+#include "openvino/op/prelu.hpp"
+#include "openvino/op/relu.hpp"
+#include "openvino/op/sigmoid.hpp"
+#include "openvino/op/tan.hpp"
+#include "openvino/op/tanh.hpp"
+
+using namespace ov::op;

namespace ngraph {
namespace onnx_import {

@@ -19,36 +29,36 @@ OutputVector fused_conv(const Node& node) {
    auto conv_res = conv(node).at(0);

    if (node.get_ng_inputs().size() == 4) {  // Z input provided
-       conv_res = std::make_shared<default_opset::Add>(conv_res, node.get_ng_inputs()[3]);
+       conv_res = std::make_shared<v1::Add>(conv_res, node.get_ng_inputs()[3]);
    }

    const auto activation_type = node.get_attribute_value<std::string>("activation");
    const auto activation_params = node.get_attribute_value<std::vector<float>>("activation_params", {});

    if (activation_type == "Relu") {
-       return {std::make_shared<default_opset::Relu>(conv_res)};
+       return {std::make_shared<v0::Relu>(conv_res)};
    } else if (activation_type == "Tanh") {
-       return {std::make_shared<default_opset::Tanh>(conv_res)};
+       return {std::make_shared<v0::Tanh>(conv_res)};
    } else if (activation_type == "Sigmoid") {
-       return {std::make_shared<default_opset::Sigmoid>(conv_res)};
+       return {std::make_shared<v0::Sigmoid>(conv_res)};
    } else if (activation_type == "Clip") {
        CHECK_VALID_NODE(node,
                         activation_params.size() == 2,
                         "min and max attributes of Clip activation function were not provided");
-       return {std::make_shared<default_opset::Clamp>(conv_res, activation_params[0], activation_params[1])};
+       return {std::make_shared<v0::Clamp>(conv_res, activation_params[0], activation_params[1])};
    } else if (activation_type == "LeakyRelu") {
        CHECK_VALID_NODE(node,
                         activation_params.size() == 1,
                         "activation_alpha attribute of LeakyRelu activation function was not provided");
-       const auto activation_alpha_node = default_opset::Constant::create(element::f32, Shape{}, activation_params);
-       return {std::make_shared<default_opset::PRelu>(conv_res, activation_alpha_node)};
+       const auto activation_alpha_node = v0::Constant::create(element::f32, Shape{}, activation_params);
+       return {std::make_shared<v0::PRelu>(conv_res, activation_alpha_node)};
    } else if (activation_type == "HardSigmoid") {
        CHECK_VALID_NODE(node,
                         activation_params.size() == 2,
                         "alpha and beta attributes of HardSigmoid activation function were not provided");
-       const auto alpha = default_opset::Constant::create<float>(element::f32, Shape{}, {activation_params[0]});
-       const auto beta = default_opset::Constant::create<float>(element::f32, Shape{}, {activation_params[1]});
-       return {std::make_shared<default_opset::HardSigmoid>(conv_res, alpha, beta)};
+       const auto alpha = v0::Constant::create<float>(element::f32, Shape{}, {activation_params[0]});
+       const auto beta = v0::Constant::create<float>(element::f32, Shape{}, {activation_params[1]});
+       return {std::make_shared<v0::HardSigmoid>(conv_res, alpha, beta)};
    }
    CHECK_VALID_NODE(node,
                     !activation_type.empty(),
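
Note how LeakyRelu is handled here: rather than a dedicated LeakyRelu op, the converter feeds a scalar alpha Constant into v0::PRelu, which computes the same max(x, alpha * x). A sketch of that mapping in isolation (leaky_relu is a hypothetical helper, not frontend code):

// Hedged sketch of the LeakyRelu-as-PRelu mapping seen above.
#include <memory>

#include "openvino/op/constant.hpp"
#include "openvino/op/prelu.hpp"

using namespace ov::op;

ov::Output<ov::Node> leaky_relu(const ov::Output<ov::Node>& x, float alpha) {
    // A rank-0 slope constant broadcast by PRelu: max(x, alpha * x).
    const auto alpha_node = v0::Constant::create(ov::element::f32, ov::Shape{}, {alpha});
    return std::make_shared<v0::PRelu>(x, alpha_node);
}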
37 changes: 20 additions & 17 deletions src/frontends/onnx/frontend/src/op/com.microsoft/fusedgemm.cpp
@@ -6,13 +6,16 @@

#include <memory>

-#include "default_opset.hpp"
-#include "ngraph/op/add.hpp"
-#include "ngraph/op/constant.hpp"
-#include "ngraph/op/matmul.hpp"
-#include "ngraph/op/multiply.hpp"
#include "onnx_import/core/null_node.hpp"
#include "openvino/frontend/exception.hpp"
+#include "openvino/op/add.hpp"
+#include "openvino/op/constant.hpp"
+#include "openvino/op/matmul.hpp"
+#include "openvino/op/multiply.hpp"
+#include "openvino/op/prelu.hpp"
+#include "openvino/op/relu.hpp"
+
+using namespace ov::op;

namespace ngraph {
namespace onnx_import {

@@ -24,14 +27,14 @@ OutputVector fusedgemm(const Node& node) {
    FRONT_END_GENERAL_CHECK(num_inputs == 2 || num_inputs == 3,
                            "FusedGemm takes 2/3 inputs. Provided " + std::to_string(num_inputs));

-   Output<ngraph::Node> input_a = inputs.at(0);
-   Output<ngraph::Node> input_b = inputs.at(1);
-   Output<ngraph::Node> input_c;
+   Output<ov::Node> input_a = inputs.at(0);
+   Output<ov::Node> input_b = inputs.at(1);
+   Output<ov::Node> input_c;

    if (num_inputs == 3 && !ov::op::util::is_null(inputs[2])) {
        input_c = inputs.at(2);
    } else {
-       input_c = default_opset::Constant::create(input_b.get_element_type(), ngraph::Shape{}, {0});
+       input_c = v0::Constant::create(input_b.get_element_type(), ov::Shape{}, {0});
    }

    const auto alpha_node = node.get_attribute_as_constant<float>("alpha", 1, input_b.get_element_type());

@@ -40,22 +43,22 @@
    const bool trans_a = node.get_attribute_value<int64_t>("transA", 0);
    const bool trans_b = node.get_attribute_value<int64_t>("transB", 0);

-   const auto matmul_node = std::make_shared<default_opset::MatMul>(input_a, input_b, trans_a, trans_b);
-   const auto matmul_times_alpha = std::make_shared<default_opset::Multiply>(matmul_node, alpha_node);
+   const auto matmul_node = std::make_shared<v0::MatMul>(input_a, input_b, trans_a, trans_b);
+   const auto matmul_times_alpha = std::make_shared<v1::Multiply>(matmul_node, alpha_node);

-   const auto beta_times_input_c = std::make_shared<default_opset::Multiply>(beta_node, input_c);
+   const auto beta_times_input_c = std::make_shared<v1::Multiply>(beta_node, input_c);
    const std::string onnx_name = !node.get_name().empty() ? node.get_name() : node.output(0);
    matmul_node->set_friendly_name(onnx_name + "/WithoutBiases");
-   const auto gemm_res = std::make_shared<default_opset::Add>(matmul_times_alpha, beta_times_input_c);
+   const auto gemm_res = std::make_shared<v1::Add>(matmul_times_alpha, beta_times_input_c);

    const auto activation_type = node.get_attribute_value<std::string>("activation", "Relu");
    if (activation_type == "LeakyRelu") {
        double activation_alpha = node.get_attribute_value<double>("activation_alpha", 0.01);
-       std::shared_ptr<ngraph::Node> activation_alpha_node =
-           default_opset::Constant::create(input_c.get_element_type(), Shape{1}, {activation_alpha});
-       return {std::make_shared<default_opset::PRelu>(gemm_res, activation_alpha_node)};
+       std::shared_ptr<ov::Node> activation_alpha_node =
+           v0::Constant::create(input_c.get_element_type(), Shape{1}, {activation_alpha});
+       return {std::make_shared<v0::PRelu>(gemm_res, activation_alpha_node)};
    }
-   return {std::make_shared<default_opset::Relu>(gemm_res)};
+   return {std::make_shared<v0::Relu>(gemm_res)};
}

}  // namespace set_1
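
The graph being built is the standard Gemm contract Y = alpha * op(A) * op(B) + beta * C, with transA/transB folded into v0::MatMul and a missing C defaulting to a zero scalar Constant. A compact sketch of that core, assuming f32 scalar attributes (fused_gemm_core is a hypothetical helper, not frontend code):

// Hedged sketch of the FusedGemm arithmetic above, before the activation.
#include <memory>

#include "openvino/op/add.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/multiply.hpp"

using namespace ov::op;

ov::Output<ov::Node> fused_gemm_core(const ov::Output<ov::Node>& a,
                                     const ov::Output<ov::Node>& b,
                                     const ov::Output<ov::Node>& c,
                                     float alpha, float beta,
                                     bool trans_a, bool trans_b) {
    const auto alpha_node = v0::Constant::create(ov::element::f32, ov::Shape{}, {alpha});
    const auto beta_node = v0::Constant::create(ov::element::f32, ov::Shape{}, {beta});
    // alpha * MatMul(A, B) + beta * C, transposes handled inside MatMul.
    const auto ab = std::make_shared<v0::MatMul>(a, b, trans_a, trans_b);
    return std::make_shared<v1::Add>(std::make_shared<v1::Multiply>(ab, alpha_node),
                                     std::make_shared<v1::Multiply>(beta_node, c));
}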
src/frontends/onnx/frontend/src/op/com.microsoft/skip_layer_normalization.cpp

@@ -4,8 +4,13 @@

#include "op/com.microsoft/skip_layer_normalization.hpp"

#include "default_opset.hpp"
#include "openvino/frontend/exception.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/mvn.hpp"

using namespace ov::op;

namespace ngraph {
namespace onnx_import {
Expand All @@ -18,22 +23,22 @@ OutputVector skip_layer_normalization(const Node& node) {
"SkipLayerNormalization takes 3, 4 or 5 inputs. Provided " + std::to_string(num_nodes));

// input + skip
std::shared_ptr<ngraph::Node> input = std::make_shared<default_opset::Add>(nodes[0], nodes[1]);
std::shared_ptr<ov::Node> input = std::make_shared<v1::Add>(nodes[0], nodes[1]);
// add bias if available
if (num_nodes == 5) {
input = std::make_shared<default_opset::Add>(input, nodes[4]);
input = std::make_shared<v1::Add>(input, nodes[4]);
}
float eps = node.get_attribute_value<float>("epsilon");
// reduce over hidden_size
int hidden_size_dim = 2;
const auto reduction_axes = default_opset::Constant::create(element::i32, Shape{1}, {hidden_size_dim});
std::shared_ptr<ngraph::Node> result =
std::make_shared<default_opset::MVN>(input, reduction_axes, true, eps, ov::op::MVNEpsMode::INSIDE_SQRT);
const auto reduction_axes = v0::Constant::create(element::i32, Shape{1}, {hidden_size_dim});
std::shared_ptr<ov::Node> result =
std::make_shared<v6::MVN>(input, reduction_axes, true, eps, ov::op::MVNEpsMode::INSIDE_SQRT);
// multiply by gamma
result = std::make_shared<default_opset::Multiply>(result, nodes[2]);
result = std::make_shared<v1::Multiply>(result, nodes[2]);
// add beta if available
if (num_nodes > 3) {
result = std::make_shared<default_opset::Add>(result, nodes[3]);
result = std::make_shared<v1::Add>(result, nodes[3]);
}
// spec mentions three outputs (output, mean, inv_std_var) while we support only first one, but:
// - onnxruntime also doesn't support the last two
Expand Down
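
Structurally this is the embed_layer_normalization tail again, fed by Add(input, skip) with an optional bias. A trimmed sketch of the path through v6::MVN, assuming gamma only (skip_layer_norm is a hypothetical helper; the optional bias and beta adds are omitted):

// Hedged sketch of the SkipLayerNormalization decomposition above.
#include <memory>

#include "openvino/op/add.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/mvn.hpp"

using namespace ov::op;

ov::Output<ov::Node> skip_layer_norm(const ov::Output<ov::Node>& input,
                                     const ov::Output<ov::Node>& skip,
                                     const ov::Output<ov::Node>& gamma,
                                     float eps) {
    // input + skip, then normalize over the hidden_size axis (axis 2).
    auto sum = std::make_shared<v1::Add>(input, skip);
    const auto axes = v0::Constant::create(ov::element::i32, ov::Shape{1}, {2});
    auto norm = std::make_shared<v6::MVN>(sum, axes, true, eps, ov::op::MVNEpsMode::INSIDE_SQRT);
    return std::make_shared<v1::Multiply>(norm, gamma);
}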
