diff --git a/src/frontends/onnx/frontend/src/op/com.microsoft/attention.cpp b/src/frontends/onnx/frontend/src/op/com.microsoft/attention.cpp
index 6a410ebf559aa9..c5ad926f30de80 100644
--- a/src/frontends/onnx/frontend/src/op/com.microsoft/attention.cpp
+++ b/src/frontends/onnx/frontend/src/op/com.microsoft/attention.cpp
@@ -7,33 +7,65 @@
 #include "default_opset.hpp"
 #include "onnx_import/core/null_node.hpp"
 #include "openvino/frontend/exception.hpp"
+#include "openvino/op/add.hpp"
+#include "openvino/op/broadcast.hpp"
+#include "openvino/op/concat.hpp"
+#include "openvino/op/constant.hpp"
+#include "openvino/op/convert.hpp"
+#include "openvino/op/divide.hpp"
+#include "openvino/op/equal.hpp"
+#include "openvino/op/floor.hpp"
+#include "openvino/op/floor_mod.hpp"
+#include "openvino/op/gather.hpp"
+#include "openvino/op/greater.hpp"
+#include "openvino/op/greater_eq.hpp"
+#include "openvino/op/less.hpp"
+#include "openvino/op/log.hpp"
+#include "openvino/op/logical_not.hpp"
+#include "openvino/op/logical_or.hpp"
+#include "openvino/op/matmul.hpp"
+#include "openvino/op/maximum.hpp"
+#include "openvino/op/multiply.hpp"
+#include "openvino/op/pad.hpp"
+#include "openvino/op/range.hpp"
+#include "openvino/op/reshape.hpp"
+#include "openvino/op/select.hpp"
+#include "openvino/op/shape_of.hpp"
+#include "openvino/op/softmax.hpp"
+#include "openvino/op/sqrt.hpp"
+#include "openvino/op/squeeze.hpp"
+#include "openvino/op/subtract.hpp"
+#include "openvino/op/transpose.hpp"
+#include "openvino/op/unsqueeze.hpp"
 #include "ov_models/ov_builders/split.hpp"
 
+using namespace ov::op;
+
 namespace ngraph {
 namespace onnx_import {
 namespace op {
 namespace detail {
 namespace {
-NodeVector split_to_QKV(const std::shared_ptr<ngraph::Node>& node,
+NodeVector split_to_QKV(const std::shared_ptr<ov::Node>& node,
                         int64_t num_heads,
                         const std::vector<int64_t>& qkv_hidden_sizes);
 
-using NodeTuple = std::tuple<std::shared_ptr<ngraph::Node>, std::shared_ptr<ngraph::Node>>;
+using NodeTuple = std::tuple<std::shared_ptr<ov::Node>, std::shared_ptr<ov::Node>>;
 
 NodeTuple get_attention_mask(const OutputVector& op_inputs, bool unidirectional);
 
-std::shared_ptr<ngraph::Node> attention_softmax(const OutputVector& op_inputs,
-                                                const std::shared_ptr<ngraph::Node>& Q,
-                                                std::shared_ptr<ngraph::Node> K,
-                                                std::shared_ptr<ngraph::Node> V,
-                                                const std::shared_ptr<ngraph::Node>& attention_mask,
-                                                const std::shared_ptr<ngraph::Node>& bin_mask,
-                                                const std::shared_ptr<ngraph::Node>& head_size,
-                                                bool unidirectional);
-
-std::shared_ptr<ngraph::Node> get_present_state(const std::shared_ptr<ngraph::Node>& K,
-                                                const std::shared_ptr<ngraph::Node>& V,
-                                                const OutputVector& op_inputs);
+std::shared_ptr<ov::Node> attention_softmax(const OutputVector& op_inputs,
+                                            const std::shared_ptr<ov::Node>& Q,
+                                            std::shared_ptr<ov::Node> K,
+                                            std::shared_ptr<ov::Node> V,
+                                            const std::shared_ptr<ov::Node>& attention_mask,
+                                            const std::shared_ptr<ov::Node>& bin_mask,
+                                            const std::shared_ptr<ov::Node>& head_size,
+                                            bool unidirectional);
+
+std::shared_ptr<ov::Node> get_present_state(const std::shared_ptr<ov::Node>& K,
+                                            const std::shared_ptr<ov::Node>& V,
+                                            const OutputVector& op_inputs);
 }  // namespace
 }  // namespace detail
 
@@ -52,8 +84,8 @@ OutputVector attention(const Node& node) {
     // So the approach here is to do a single big matrix multiply
     // and then split the result into Q, K, V matrices
 
-    auto matmul = std::make_shared<default_opset::MatMul>(input, weights);
-    auto add = std::make_shared<default_opset::Add>(matmul, bias);
+    auto matmul = std::make_shared<v0::MatMul>(input, weights);
+    auto add = std::make_shared<v1::Add>(matmul, bias);
 
     const auto num_heads = node.get_attribute_value<int64_t>("num_heads");
     const auto qkv_hidden_sizes = node.get_attribute_value<std::vector<int64_t>>("qkv_hidden_sizes", {});
@@ -64,7 +96,7 @@ OutputVector attention(const Node& node) {
     // broadcastable to (batch_size, num_heads, sequence_length, past_sequence_length + sequence_length)
    // so it can be added to Q x K' later
     // past_sequence_length can be 0 if 'past' input is not available
-    std::shared_ptr<ngraph::Node> attention_mask = nullptr, bin_mask = nullptr;
+    std::shared_ptr<ov::Node> attention_mask = nullptr, bin_mask = nullptr;
     std::tie(attention_mask, bin_mask) = detail::get_attention_mask(nodes, unidirectional);
 
     const auto& Q = split_result[0];
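Reviewer sketch (not part of the patch): a stand-alone illustration of the fused QKV projection the hunk above builds — one big MatMul + Add, then a 3-way split along the last axis. Shapes and the helper name `fused_qkv` are illustrative only; the real code derives shapes from the graph at runtime.

    // assumes: openvino/op/add.hpp, openvino/op/matmul.hpp, ov_models/ov_builders/split.hpp
    ov::OutputVector fused_qkv(const ov::Output<ov::Node>& input,    // (batch, seq, hidden)
                               const ov::Output<ov::Node>& weights,  // (hidden, 3 * hidden)
                               const ov::Output<ov::Node>& bias) {   // (3 * hidden)
        // single big projection covering Q, K and V at once
        auto proj = std::make_shared<ov::op::v0::MatMul>(input, weights);  // (batch, seq, 3 * hidden)
        auto biased = std::make_shared<ov::op::v1::Add>(proj, bias);
        // split into {Q, K, V}, each (batch, seq, hidden)
        return ov::op::util::split(biased, 3, 2);
    }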
@@ -87,50 +119,48 @@ OutputVector attention(const Node& node) {
 
 namespace detail {
 namespace {
-std::shared_ptr<ngraph::Node> get_dimensions(const std::shared_ptr<default_opset::ShapeOf>& shape,
-                                             const std::vector<int>& dims) {
-    static const auto zero = default_opset::Constant::create(element::i32, Shape{}, {0});
-    const auto dims_const = default_opset::Constant::create(element::i32, Shape{dims.size()}, dims);
-    return std::make_shared<default_opset::Gather>(shape, dims_const, zero);
+std::shared_ptr<ov::Node> get_dimensions(const std::shared_ptr<v3::ShapeOf>& shape, const std::vector<int>& dims) {
+    static const auto zero = v0::Constant::create(element::i32, Shape{}, {0});
+    const auto dims_const = v0::Constant::create(element::i32, Shape{dims.size()}, dims);
+    return std::make_shared<v8::Gather>(shape, dims_const, zero);
 }
 
-std::shared_ptr<ngraph::Node> get_dimensions(const std::shared_ptr<ngraph::Node>& node, const std::vector<int>& dims) {
-    return get_dimensions(std::make_shared<default_opset::ShapeOf>(node), dims);
+std::shared_ptr<ov::Node> get_dimensions(const std::shared_ptr<ov::Node>& node, const std::vector<int>& dims) {
+    return get_dimensions(std::make_shared<v3::ShapeOf>(node), dims);
 }
 
-std::shared_ptr<ngraph::Node> get_hidden_size(const std::shared_ptr<default_opset::ShapeOf>& node_shape) {
+std::shared_ptr<ov::Node> get_hidden_size(const std::shared_ptr<v3::ShapeOf>& node_shape) {
     // node has shape (batch_size, sequence_length, 3 * hidden_size)
-    const auto zero = default_opset::Constant::create(element::i32, Shape{}, {0});
+    const auto zero = v0::Constant::create(element::i32, Shape{}, {0});
     const auto hidden_size_x3 = get_dimensions(node_shape, {2});
-    const auto three = default_opset::Constant::create(element::i64, Shape{}, {3});
-    const auto hidden_size = std::make_shared<default_opset::Divide>(hidden_size_x3, three);
+    const auto three = v0::Constant::create(element::i64, Shape{}, {3});
+    const auto hidden_size = std::make_shared<v1::Divide>(hidden_size_x3, three);
     return hidden_size;
 }
 
-NodeVector split_to_QKV(const std::shared_ptr<ngraph::Node>& node,
+NodeVector split_to_QKV(const std::shared_ptr<ov::Node>& node,
                         int64_t num_heads,
                         const std::vector<int64_t>& qkv_hidden_sizes) {
     OutputVector split;
-    std::shared_ptr<ngraph::Node> head_size = nullptr;
+    std::shared_ptr<ov::Node> head_size = nullptr;
     const auto& node_type = node->get_element_type();
-    const auto node_shape = std::make_shared<default_opset::ShapeOf>(node);
+    const auto node_shape = std::make_shared<v3::ShapeOf>(node);
     // node has shape (batch_size, sequence_length, 3 * hidden_size)
     // fetch the first two dimensions
     const auto batch_size_seq_len = get_dimensions(node_shape, {0, 1});
-    const auto num_heads_node = default_opset::Constant::create(element::i64, Shape{1}, {num_heads});
+    const auto num_heads_node = v0::Constant::create(element::i64, Shape{1}, {num_heads});
     if (qkv_hidden_sizes.size() == 0) {
         const auto hidden_size = get_hidden_size(node_shape);
         // head_size = hidden_size / num_heads
-        head_size = std::make_shared<default_opset::Divide>(hidden_size, num_heads_node);
+        head_size = std::make_shared<v1::Divide>(hidden_size, num_heads_node);
         // split the node into 3 even parts Q, K, V with shape (batch_size, sequence_len, hidden_size)
         split = ov::op::util::split(node, 3, 2);
         // and reshape each part to new shape (batch_size, sequence_len, num_heads, head_size)
-        auto new_shape =
-            std::make_shared<default_opset::Concat>(NodeVector{batch_size_seq_len, num_heads_node, head_size}, 0);
+        auto new_shape = std::make_shared<v0::Concat>(NodeVector{batch_size_seq_len, num_heads_node, head_size}, 0);
         for (size_t i = 0; i < split.size(); i++) {
-            split[i] = std::make_shared<default_opset::Reshape>(split[i], new_shape, false);
+            split[i] = std::make_shared<v1::Reshape>(split[i], new_shape, false);
         }
-        head_size = std::make_shared<default_opset::Convert>(head_size, node_type);
+        head_size = std::make_shared<v0::Convert>(head_size, node_type);
     } else {
         // in this case, weights have shape
         // (input_hidden_size, qkv_hidden_sizes[0] + qkv_hidden_sizes[1] + qkv_hidden_sizes[2])
@@ -145,23 +175,23 @@ NodeVector split_to_QKV(const std::shared_ptr<ngraph::Node>& node,
         split = ov::op::util::split(node, qkv_hidden_sizes, 2);
         // and reshape each part to new shape (batch_size, sequence_len, num_heads, head_size)
         for (size_t i = 0; i < split.size(); i++) {
-            auto new_shape = std::make_shared<default_opset::Concat>(
+            auto new_shape = std::make_shared<v0::Concat>(
                 NodeVector{batch_size_seq_len,
                            num_heads_node,
-                           default_opset::Constant::create(element::i64, Shape{1}, {qkv_hidden_sizes[i] / num_heads})},
+                           v0::Constant::create(element::i64, Shape{1}, {qkv_hidden_sizes[i] / num_heads})},
                 0);
-            split[i] = std::make_shared<default_opset::Reshape>(split[i], new_shape, false);
+            split[i] = std::make_shared<v1::Reshape>(split[i], new_shape, false);
         }
         float head_size_val = qkv_hidden_sizes[0] > 0 ? static_cast<float>(qkv_hidden_sizes[0]) / num_heads
                                                       : static_cast<float>(qkv_hidden_sizes[2]) / num_heads;
-        head_size = default_opset::Constant::create(node_type, Shape{1}, {head_size_val});
+        head_size = v0::Constant::create(node_type, Shape{1}, {head_size_val});
     }
 
     // transpose Q, K and V to (batch_size, num_heads, sequence_len, head_size)
-    auto perm = default_opset::Constant::create(element::i64, Shape{4}, {0, 2, 1, 3});
-    auto Q = std::make_shared<default_opset::Transpose>(split[0], perm);
-    auto K = std::make_shared<default_opset::Transpose>(split[1], perm);
-    auto V = std::make_shared<default_opset::Transpose>(split[2], perm);
+    auto perm = v0::Constant::create(element::i64, Shape{4}, {0, 2, 1, 3});
+    auto Q = std::make_shared<v1::Transpose>(split[0], perm);
+    auto K = std::make_shared<v1::Transpose>(split[1], perm);
+    auto V = std::make_shared<v1::Transpose>(split[2], perm);
 
     return {Q, K, V, head_size};
 }
@@ -189,7 +219,7 @@ NodeVector split_to_QKV(const std::shared_ptr<ngraph::Node>& node,
 // e.g., for batch = 1, -10000 values appear within two ranges [0, mask_index[4]] and [mask_index[1]:5] (or [0:2],[4:5])
 //
 //
-// This is how it's done with nGraph operations:
+// This is how it's done with OpenVINO operations:
 //
 // First the 'base' is generated by range + broadcast:
 //    base = range(0, all_seq_len)
@@ -236,62 +266,51 @@ NodeVector split_to_QKV(const std::shared_ptr<ngraph::Node>& node,
 // Handling both mask_index variants (so (batch_size) and (2 * batch_size)) is tricky since we don't
 // know its dimensions upfront. So we compute both variants and use Select operator to select
 // the right one in the runtime (unless it gets constantfolded before).
-std::shared_ptr<ngraph::Node> attention_mask_from_indices(const Output<ngraph::Node>& mask_index,
-                                                          const element::Type_t& type,
-                                                          const std::shared_ptr<ngraph::Node>& batch_size,
-                                                          const std::shared_ptr<ngraph::Node>& all_seq_len) {
-    const auto zero = default_opset::Constant::create(element::i64, Shape{}, {0});
-    const auto one = default_opset::Constant::create(element::i64, Shape{}, {1});
-    const auto stop = std::make_shared<default_opset::Squeeze>(all_seq_len, zero);
-    std::shared_ptr<ngraph::Node> base =
-        std::make_shared<default_opset::Range>(zero, stop, one, mask_index.get_element_type());
-    const auto target_shape = std::make_shared<default_opset::Concat>(NodeVector{batch_size, all_seq_len}, 0);
+std::shared_ptr<ov::Node> attention_mask_from_indices(const Output<ov::Node>& mask_index,
+                                                      const element::Type_t& type,
+                                                      const std::shared_ptr<ov::Node>& batch_size,
+                                                      const std::shared_ptr<ov::Node>& all_seq_len) {
+    const auto zero = v0::Constant::create(element::i64, Shape{}, {0});
+    const auto one = v0::Constant::create(element::i64, Shape{}, {1});
+    const auto stop = std::make_shared<v0::Squeeze>(all_seq_len, zero);
+    std::shared_ptr<ov::Node> base = std::make_shared<v4::Range>(zero, stop, one, mask_index.get_element_type());
+    const auto target_shape = std::make_shared<v0::Concat>(NodeVector{batch_size, all_seq_len}, 0);
     // broadcast 'base' to (batch_size, all_seq_len)
-    base = std::make_shared<default_opset::Broadcast>(base, target_shape);
-    const auto indices_shape = std::make_shared<default_opset::Concat>(
-        NodeVector{default_opset::Constant::create(element::i64, Shape{1}, {-1}), batch_size},
-        0);
-    std::shared_ptr<ngraph::Node> indices = std::make_shared<default_opset::Reshape>(mask_index, indices_shape, false);
+    base = std::make_shared<v3::Broadcast>(base, target_shape);
+    const auto indices_shape =
+        std::make_shared<v0::Concat>(NodeVector{v0::Constant::create(element::i64, Shape{1}, {-1}), batch_size}, 0);
+    std::shared_ptr<ov::Node> indices = std::make_shared<v1::Reshape>(mask_index, indices_shape, false);
     // fetch first row from indices
-    std::shared_ptr<ngraph::Node> tail_range_indices = std::make_shared<default_opset::Gather>(indices, zero, zero);
+    std::shared_ptr<ov::Node> tail_range_indices = std::make_shared<v8::Gather>(indices, zero, zero);
     tail_range_indices =
-        std::make_shared<default_opset::Reshape>(tail_range_indices,
-                                                 default_opset::Constant::create(element::i32, Shape{2}, {-1, 1}),
-                                                 false);
-    const auto greater_eq = std::make_shared<default_opset::GreaterEqual>(base, tail_range_indices);
-    std::shared_ptr<ngraph::Node> tail_range_mask =
-        std::make_shared<default_opset::Multiply>(std::make_shared<default_opset::Convert>(greater_eq, type),
-                                                  default_opset::Constant::create(type, Shape{}, {-10000}));
+        std::make_shared<v1::Reshape>(tail_range_indices, v0::Constant::create(element::i32, Shape{2}, {-1, 1}), false);
+    const auto greater_eq = std::make_shared<v1::GreaterEqual>(base, tail_range_indices);
+    std::shared_ptr<ov::Node> tail_range_mask =
+        std::make_shared<v1::Multiply>(std::make_shared<v0::Convert>(greater_eq, type),
+                                       v0::Constant::create(type, Shape{}, {-10000}));
     tail_range_mask =
-        std::make_shared<default_opset::Unsqueeze>(tail_range_mask,
-                                                   default_opset::Constant::create(element::i64, Shape{2}, {1, 2}));
+        std::make_shared<v0::Unsqueeze>(tail_range_mask, v0::Constant::create(element::i64, Shape{2}, {1, 2}));
     const auto gather_index =
-        std::make_shared<default_opset::FloorMod>(default_opset::Constant::create(element::i64, Shape{}, {1}),
-                                                  get_dimensions(indices, {0}));
+        std::make_shared<v1::FloorMod>(v0::Constant::create(element::i64, Shape{}, {1}), get_dimensions(indices, {0}));
     // fetch indices from the second row (or first if not available)
-    std::shared_ptr<ngraph::Node> head_range_indices =
-        std::make_shared<default_opset::Gather>(indices, gather_index, zero);
+    std::shared_ptr<ov::Node> head_range_indices = std::make_shared<v8::Gather>(indices, gather_index, zero);
     head_range_indices =
-        std::make_shared<default_opset::Reshape>(head_range_indices,
-                                                 default_opset::Constant::create(element::i32, Shape{2}, {-1, 1}),
-                                                 false);
-    const auto less = std::make_shared<default_opset::Less>(base, head_range_indices);
-    std::shared_ptr<ngraph::Node> mask = std::make_shared<default_opset::LogicalOr>(less, greater_eq);
-    mask = std::make_shared<default_opset::Multiply>(std::make_shared<default_opset::Convert>(mask, type),
-                                                     default_opset::Constant::create(type, Shape{}, {-10000}));
+        std::make_shared<v1::Reshape>(head_range_indices, v0::Constant::create(element::i32, Shape{2}, {-1, 1}), false);
+    const auto less = std::make_shared<v1::Less>(base, head_range_indices);
+    std::shared_ptr<ov::Node> mask = std::make_shared<v1::LogicalOr>(less, greater_eq);
+    mask = std::make_shared<v1::Multiply>(std::make_shared<v0::Convert>(mask, type),
+                                          v0::Constant::create(type, Shape{}, {-10000}));
     // reshape from (batch_size, all_seq_len) to (batch_size, 1, 1, all_seq_len)
-    mask = std::make_shared<default_opset::Unsqueeze>(mask,
-                                                      default_opset::Constant::create(element::i64, Shape{2}, {1, 2}));
+    mask = std::make_shared<v0::Unsqueeze>(mask, v0::Constant::create(element::i64, Shape{2}, {1, 2}));
     const auto mask_index_first_dim = get_dimensions(mask_index.get_node_shared_ptr(), {0});
     // compare mask_index.shape[0] with batch_size value
     // if they're equal - select tail_range_mask
     // else select full mask
-    mask = std::make_shared<default_opset::Select>(
-        std::make_shared<default_opset::Equal>(batch_size, mask_index_first_dim),
-        tail_range_mask,
-        mask);
+    mask = std::make_shared<v1::Select>(std::make_shared<v1::Equal>(batch_size, mask_index_first_dim),
                                        tail_range_mask,
                                        mask);
 
     return mask;
 }
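Reviewer sketch (not part of the patch): the core index-mask trick used above, on concrete values — an iota row broadcast to (batch, all_seq_len) and compared against per-batch boundary indices. All constants here are illustrative; the real code computes batch_size/all_seq_len at runtime.

    // assumes: openvino/op/{broadcast,constant,greater_eq,range}.hpp
    std::shared_ptr<ov::Node> tail_mask_example() {
        using namespace ov;
        const auto start = op::v0::Constant::create(element::i32, Shape{}, {0});
        const auto stop = op::v0::Constant::create(element::i32, Shape{}, {5});  // all_seq_len = 5
        const auto step = op::v0::Constant::create(element::i32, Shape{}, {1});
        auto base = std::make_shared<op::v4::Range>(start, stop, step, element::i32);  // [0,1,2,3,4]
        const auto shape = op::v0::Constant::create(element::i64, Shape{2}, {2, 5});   // batch = 2
        auto rows = std::make_shared<op::v3::Broadcast>(base, shape);
        // per-batch first masked position, e.g. {2, 4}, as a (2, 1) column
        const auto bounds = op::v0::Constant::create(element::i32, Shape{2, 1}, {2, 4});
        return std::make_shared<op::v1::GreaterEqual>(rows, bounds);  // true where j >= bound
    }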
@@ -321,26 +340,24 @@ std::shared_ptr<ngraph::Node> attention_mask_from_indices(const Output<ngraph::
 // (batch_size, num_heads, sequence_length, past_sequence_length + sequence_length)
 NodeTuple unidirectional_mask(const element::Type_t& type,
-                              const std::shared_ptr<ngraph::Node>& seq_len,
-                              const std::shared_ptr<ngraph::Node>& all_seq_len,
-                              const std::shared_ptr<ngraph::Node>& past_seq_len) {
-    const auto zero = default_opset::Constant::create(element::i64, Shape{}, {0});
-    const auto one = default_opset::Constant::create(element::i64, Shape{}, {1});
-    const auto stop = std::make_shared<default_opset::Squeeze>(all_seq_len, zero);
-    std::shared_ptr<ngraph::Node> bin_mask = std::make_shared<default_opset::Range>(zero, stop, one, element::i32);
-    auto target_shape = std::make_shared<default_opset::Concat>(NodeVector{seq_len, all_seq_len}, 0);
-    bin_mask = std::make_shared<default_opset::Broadcast>(bin_mask, target_shape);
-    auto start =
-        std::make_shared<default_opset::Squeeze>(std::make_shared<default_opset::Add>(past_seq_len, one), zero);
-    auto end = std::make_shared<default_opset::Squeeze>(std::make_shared<default_opset::Add>(all_seq_len, one), zero);
-    auto indices = std::make_shared<default_opset::Unsqueeze>(
-        std::make_shared<default_opset::Range>(start, end, one, element::i32),
-        default_opset::Constant::create(element::i32, Shape{1}, {1}));
-    bin_mask = std::make_shared<default_opset::GreaterEqual>(bin_mask, indices);
-    std::shared_ptr<ngraph::Node> attention_mask =
-        std::make_shared<default_opset::Multiply>(std::make_shared<default_opset::Convert>(bin_mask, type),
-                                                  default_opset::Constant::create(type, Shape{}, {-10000}));
-    bin_mask = std::make_shared<default_opset::Convert>(std::make_shared<default_opset::LogicalNot>(bin_mask), type);
+                              const std::shared_ptr<ov::Node>& seq_len,
+                              const std::shared_ptr<ov::Node>& all_seq_len,
+                              const std::shared_ptr<ov::Node>& past_seq_len) {
+    const auto zero = v0::Constant::create(element::i64, Shape{}, {0});
+    const auto one = v0::Constant::create(element::i64, Shape{}, {1});
+    const auto stop = std::make_shared<v0::Squeeze>(all_seq_len, zero);
+    std::shared_ptr<ov::Node> bin_mask = std::make_shared<v4::Range>(zero, stop, one, element::i32);
+    auto target_shape = std::make_shared<v0::Concat>(NodeVector{seq_len, all_seq_len}, 0);
+    bin_mask = std::make_shared<v3::Broadcast>(bin_mask, target_shape);
+    auto start = std::make_shared<v0::Squeeze>(std::make_shared<v1::Add>(past_seq_len, one), zero);
+    auto end = std::make_shared<v0::Squeeze>(std::make_shared<v1::Add>(all_seq_len, one), zero);
+    auto indices = std::make_shared<v0::Unsqueeze>(std::make_shared<v4::Range>(start, end, one, element::i32),
+                                                   v0::Constant::create(element::i32, Shape{1}, {1}));
+    bin_mask = std::make_shared<v1::GreaterEqual>(bin_mask, indices);
+    std::shared_ptr<ov::Node> attention_mask =
+        std::make_shared<v1::Multiply>(std::make_shared<v0::Convert>(bin_mask, type),
+                                       v0::Constant::create(type, Shape{}, {-10000}));
+    bin_mask = std::make_shared<v0::Convert>(std::make_shared<v1::LogicalNot>(bin_mask), type);
 
     return NodeTuple{attention_mask, bin_mask};
 }
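Reviewer note (not part of the patch): a worked example of what unidirectional_mask produces, for seq_len = 3 and no past. bin_mask marks positions a token may not attend to, and attention_mask = bin_mask * -10000 is what later gets added to Q x K' before the softmax:

    // attention_mask = [[     0, -10000, -10000],
    //                   [     0,      0, -10000],
    //                   [     0,      0,      0]]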
@@ -355,29 +372,23 @@ NodeTuple unidirectional_mask(const element::Type_t& type,
 //
 // Shape (batch_size, 1, max_sequence_length, max_sequence_length) is not supported in onnxruntime:
 // https://github.com/microsoft/onnxruntime/blob/851554536ca8185b3413ee57449ea5ac93370193/onnxruntime/contrib_ops/cpu/bert/attention_helper.h#L78
-std::shared_ptr<ngraph::Node> raw_mask(const Output<ngraph::Node>& mask_index,
-                                       ov::Dimension::value_type mask_rank,
-                                       const element::Type_t& type) {
-    std::shared_ptr<ngraph::Node> mask = std::make_shared<default_opset::Convert>(mask_index, type);
-    mask = std::make_shared<default_opset::Convert>(mask, type);
-    mask = std::make_shared<default_opset::Subtract>(default_opset::Constant::create(type, Shape{}, {1}), mask);
-    mask = std::make_shared<default_opset::Multiply>(mask, default_opset::Constant::create(type, Shape{}, {-10000}));
+std::shared_ptr<ov::Node> raw_mask(const Output<ov::Node>& mask_index,
+                                   ov::Dimension::value_type mask_rank,
+                                   const element::Type_t& type) {
+    std::shared_ptr<ov::Node> mask = std::make_shared<v0::Convert>(mask_index, type);
+    mask = std::make_shared<v0::Convert>(mask, type);
+    mask = std::make_shared<v1::Subtract>(v0::Constant::create(type, Shape{}, {1}), mask);
+    mask = std::make_shared<v1::Multiply>(mask, v0::Constant::create(type, Shape{}, {-10000}));
     switch (mask_rank) {
     // Handle mask_index with (batch_size, past_sequence_length + sequence_length) shape
     // Reshape it to (batch_size, 1, 1, past_sequence_length + sequence_length)
     case 2:
-        mask = std::make_shared<default_opset::Reshape>(
-            mask,
-            default_opset::Constant::create(element::i64, Shape{4}, {0, 1, 1, -1}),
-            true);
+        mask = std::make_shared<v1::Reshape>(mask, v0::Constant::create(element::i64, Shape{4}, {0, 1, 1, -1}), true);
         break;
     // Handle mask_index with (batch_size, sequence_length, past_sequence_length + sequence_length) shape
     // Reshape it to (batch_size, 1, sequence_length, past_sequence_length + sequence_length)
     case 3:
-        mask = std::make_shared<default_opset::Reshape>(
-            mask,
-            default_opset::Constant::create(element::i64, Shape{4}, {0, 1, 0, -1}),
-            true);
+        mask = std::make_shared<v1::Reshape>(mask, v0::Constant::create(element::i64, Shape{4}, {0, 1, 0, -1}), true);
         break;
     }
     return mask;
@@ -388,10 +399,10 @@ bool is_past_input_available(const OutputVector& op_inputs) {
 }
 
 NodeTuple get_attention_mask(const OutputVector& op_inputs, bool unidirectional) {
-    const auto zero = default_opset::Constant::create(element::i64, Shape{1}, {0});
-    const auto one = default_opset::Constant::create(element::i64, Shape{1}, {1});
+    const auto zero = v0::Constant::create(element::i64, Shape{1}, {0});
+    const auto one = v0::Constant::create(element::i64, Shape{1}, {1});
 
-    std::shared_ptr<ngraph::Node> past_seq_len;
+    std::shared_ptr<ov::Node> past_seq_len;
     // get the value of past_sequence_length
     if (is_past_input_available(op_inputs)) {
         const auto& past = op_inputs[4];
@@ -402,12 +413,12 @@ NodeTuple get_attention_mask(const OutputVector& op_inputs, bool unidirectional
     }
 
     // 'input' node has shape (batch_size, sequence_length, input_hidden_size)
-    auto input_shape = std::make_shared<default_opset::ShapeOf>(op_inputs[0]);
+    auto input_shape = std::make_shared<v3::ShapeOf>(op_inputs[0]);
     auto seq_len = get_dimensions(input_shape, {1});
-    auto all_seq_len = std::make_shared<default_opset::Add>(seq_len, past_seq_len);
+    auto all_seq_len = std::make_shared<v1::Add>(seq_len, past_seq_len);
     const auto& type = op_inputs[0].get_element_type();
-    std::shared_ptr<ngraph::Node> attention_mask = nullptr;
-    std::shared_ptr<ngraph::Node> bin_mask = nullptr;
+    std::shared_ptr<ov::Node> attention_mask = nullptr;
+    std::shared_ptr<ov::Node> bin_mask = nullptr;
     if (unidirectional) {
         std::tie(attention_mask, bin_mask) = unidirectional_mask(type, seq_len, all_seq_len, past_seq_len);
     }
@@ -418,7 +429,7 @@ NodeTuple get_attention_mask(const OutputVector& op_inputs, bool unidirectional
         const auto mask_rank = mask_index.get_partial_shape().rank();
         FRONT_END_GENERAL_CHECK(mask_rank.is_static(), "'mask_index' rank must be static");
         auto mask_rank_val = mask_rank.get_length();
-        std::shared_ptr<ngraph::Node> mask;
+        std::shared_ptr<ov::Node> mask;
         if (mask_rank_val == 1) {
             // case when mask_index has shape (batch_size) or (2 * batch_size)
             // so it contains positions that specify how mask should be generated
@@ -431,7 +442,7 @@ NodeTuple get_attention_mask(const OutputVector& op_inputs, bool unidirectional
         }
         // add the mask with unidirectional mask if available
         if (attention_mask) {
-            attention_mask = std::make_shared<default_opset::Add>(attention_mask, mask);
+            attention_mask = std::make_shared<v1::Add>(attention_mask, mask);
         } else {
             attention_mask = mask;
         }
@@ -440,15 +451,15 @@ NodeTuple get_attention_mask(const OutputVector& op_inputs, bool unidirectional
 }
 
 // Compute softmax(Q x K' / sqrt(head_size)) x V
-std::shared_ptr<ngraph::Node> attention_softmax(const OutputVector& op_inputs,
-                                                const std::shared_ptr<ngraph::Node>& Q,
-                                                std::shared_ptr<ngraph::Node> K,
-                                                std::shared_ptr<ngraph::Node> V,
-                                                const std::shared_ptr<ngraph::Node>& attention_mask,
-                                                const std::shared_ptr<ngraph::Node>& bin_mask,
-                                                const std::shared_ptr<ngraph::Node>& head_size,
-                                                bool unidirectional) {
-    auto zero = default_opset::Constant::create(element::i64, Shape{}, {0});
+std::shared_ptr<ov::Node> attention_softmax(const OutputVector& op_inputs,
+                                            const std::shared_ptr<ov::Node>& Q,
+                                            std::shared_ptr<ov::Node> K,
+                                            std::shared_ptr<ov::Node> V,
+                                            const std::shared_ptr<ov::Node>& attention_mask,
+                                            const std::shared_ptr<ov::Node>& bin_mask,
+                                            const std::shared_ptr<ov::Node>& head_size,
+                                            bool unidirectional) {
+    auto zero = v0::Constant::create(element::i64, Shape{}, {0});
     if (is_past_input_available(op_inputs)) {
         // concat past K and V with present ones
         const auto& past = op_inputs[4];
@@ -458,46 +469,46 @@ std::shared_ptr<ngraph::Node> attention_softmax(const OutputVector& op_inputs,
         // so we need to split it into two parts, remove first dimension from each part and concatenate first part
         // with current K and second part with current V
         const auto split = ov::op::util::split(past, 2, 0);
-        const auto past_K = std::make_shared<default_opset::Squeeze>(split[0], zero);
-        K = std::make_shared<default_opset::Concat>(NodeVector{past_K, K}, 2);
-        const auto past_V = std::make_shared<default_opset::Squeeze>(split[1], zero);
-        V = std::make_shared<default_opset::Concat>(NodeVector{past_V, V}, 2);
+        const auto past_K = std::make_shared<v0::Squeeze>(split[0], zero);
+        K = std::make_shared<v0::Concat>(NodeVector{past_K, K}, 2);
+        const auto past_V = std::make_shared<v0::Squeeze>(split[1], zero);
+        V = std::make_shared<v0::Concat>(NodeVector{past_V, V}, 2);
     }
 
     // perform Q x K'
-    std::shared_ptr<ngraph::Node> softmax_input = std::make_shared<default_opset::MatMul>(Q, K, false, true);
+    std::shared_ptr<ov::Node> softmax_input = std::make_shared<v0::MatMul>(Q, K, false, true);
     // Q x K' + mask
     if (attention_mask) {
         if (unidirectional) {
             // Perform the equivalent of
             // https://github.com/microsoft/onnxruntime/blob/851554536ca8185b3413ee57449ea5ac93370193/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h#L158-L166
             // For positions where unidirectional_mask has -10000 values - attention_mask is moved to softmax input
-            softmax_input = std::make_shared<default_opset::Multiply>(softmax_input, bin_mask);
+            softmax_input = std::make_shared<v1::Multiply>(softmax_input, bin_mask);
         }
-        softmax_input = std::make_shared<default_opset::Add>(softmax_input, attention_mask);
+        softmax_input = std::make_shared<v1::Add>(softmax_input, attention_mask);
     }
-    const auto sqrt = std::make_shared<default_opset::Sqrt>(head_size);
+    const auto sqrt = std::make_shared<v0::Sqrt>(head_size);
     // (Q x K' + mask) / sqrt(head_size)
-    softmax_input = std::make_shared<default_opset::Divide>(softmax_input, sqrt);
+    softmax_input = std::make_shared<v1::Divide>(softmax_input, sqrt);
     // handle 'extra_add' input
     if (op_inputs.size() > 5 && !ov::op::util::is_null(op_inputs[5])) {
         FRONT_END_GENERAL_CHECK(!is_past_input_available(op_inputs),
                                 "Cannot use both 'past' and 'extra_add' inputs in the same node");
         const auto& extra_add = op_inputs[5];
-        softmax_input = std::make_shared<default_opset::Add>(softmax_input, extra_add);
+        softmax_input = std::make_shared<v1::Add>(softmax_input, extra_add);
     }
     // softmax((Q x K' + mask) / sqrt(head_size))
-    const auto softmax = std::make_shared<default_opset::Softmax>(softmax_input, 3);
+    const auto softmax = std::make_shared<v8::Softmax>(softmax_input, 3);
     // softmax((Q x K' + mask) / sqrt(head_size)) x V
-    std::shared_ptr<ngraph::Node> output = std::make_shared<default_opset::MatMul>(softmax, V);
+    std::shared_ptr<ov::Node> output = std::make_shared<v0::MatMul>(softmax, V);
     // transpose the result from (batch_size, num_heads, sequence_length, head_size)
     // to (batch_size, sequence_length, num_heads, head_size)
-    const auto perm = default_opset::Constant::create(element::i64, Shape{4}, {0, 2, 1, 3});
-    output = std::make_shared<default_opset::Transpose>(output, perm);
-    auto new_shape = default_opset::Constant::create(element::i32, Shape{3}, {0, 0, -1});
+    const auto perm = v0::Constant::create(element::i64, Shape{4}, {0, 2, 1, 3});
+    output = std::make_shared<v1::Transpose>(output, perm);
+    auto new_shape = v0::Constant::create(element::i32, Shape{3}, {0, 0, -1});
     // reshape the result from (batch_size, sequence_length, num_heads, head_size) to (batch_size, sequence_length,
     // num_heads * head_size)
-    output = std::make_shared<default_opset::Reshape>(output, new_shape, true);
+    output = std::make_shared<v1::Reshape>(output, new_shape, true);
 
     return output;
 }
@@ -506,40 +517,35 @@ std::shared_ptr<ngraph::Node> attention_softmax(const OutputVector& op_inputs,
 // (batch_size, num_heads, sequence_length, head_size) to (1, batch_size, num_heads, sequence_length, head_size)
 // and concatenating them along first axis to make 'present' output.
 // If fifth input ('past') is available, it gets concatenated with 'present' output along fourth axis.
-std::shared_ptr<ngraph::Node> get_present_state(const std::shared_ptr<ngraph::Node>& K,
-                                                const std::shared_ptr<ngraph::Node>& V,
-                                                const OutputVector& op_inputs) {
-    auto zero = default_opset::Constant::create(element::i64, Shape{1}, {0});
+std::shared_ptr<ov::Node> get_present_state(const std::shared_ptr<ov::Node>& K,
+                                            const std::shared_ptr<ov::Node>& V,
+                                            const OutputVector& op_inputs) {
+    auto zero = v0::Constant::create(element::i64, Shape{1}, {0});
     // expand K shape (batch_size, num_heads, sequence_length, head_size) to
     // (1, batch_size, num_heads, sequence_length, head_size)
-    auto K_unsqueezed = std::make_shared<default_opset::Unsqueeze>(K, zero);
+    auto K_unsqueezed = std::make_shared<v0::Unsqueeze>(K, zero);
     // similarly expand V shape
-    auto V_unsqueezed = std::make_shared<default_opset::Unsqueeze>(V, zero);
+    auto V_unsqueezed = std::make_shared<v0::Unsqueeze>(V, zero);
 
     // add padding in case K and V have different shapes (it happens when used provided uneven qkv_hidden_sizes)
     // if the shapes are equal (so padding will be zero), Pad gets eliminated in NopElimination pass
-    const auto K_shape = std::make_shared<default_opset::ShapeOf>(K_unsqueezed);
-    const auto V_shape = std::make_shared<default_opset::ShapeOf>(V_unsqueezed);
-    const auto K_pads_end =
-        std::make_shared<default_opset::Maximum>(std::make_shared<default_opset::Subtract>(V_shape, K_shape), zero);
-    const auto V_pads_end =
-        std::make_shared<default_opset::Maximum>(std::make_shared<default_opset::Subtract>(K_shape, V_shape), zero);
-    const auto pads_begin =
-        std::make_shared<default_opset::Broadcast>(zero, std::make_shared<default_opset::ShapeOf>(K_shape));
-    const auto K_padded =
-        std::make_shared<default_opset::Pad>(K_unsqueezed, pads_begin, K_pads_end, ngraph::op::PadMode::CONSTANT);
-    const auto V_padded =
-        std::make_shared<default_opset::Pad>(V_unsqueezed, pads_begin, V_pads_end, ngraph::op::PadMode::CONSTANT);
+    const auto K_shape = std::make_shared<v3::ShapeOf>(K_unsqueezed);
+    const auto V_shape = std::make_shared<v3::ShapeOf>(V_unsqueezed);
+    const auto K_pads_end = std::make_shared<v1::Maximum>(std::make_shared<v1::Subtract>(V_shape, K_shape), zero);
+    const auto V_pads_end = std::make_shared<v1::Maximum>(std::make_shared<v1::Subtract>(K_shape, V_shape), zero);
+    const auto pads_begin = std::make_shared<v3::Broadcast>(zero, std::make_shared<v3::ShapeOf>(K_shape));
+    const auto K_padded = std::make_shared<v1::Pad>(K_unsqueezed, pads_begin, K_pads_end, ov::op::PadMode::CONSTANT);
+    const auto V_padded = std::make_shared<v1::Pad>(V_unsqueezed, pads_begin, V_pads_end, ov::op::PadMode::CONSTANT);
 
     // concat key and value tensors along first axis to make 'present' state
     // after that operation, 'present' has shape (2, batch_size, num_heads, sequence_length, head_size)
-    std::shared_ptr<ngraph::Node> present = std::make_shared<default_opset::Concat>(NodeVector{K_padded, V_padded}, 0);
+    std::shared_ptr<ov::Node> present = std::make_shared<v0::Concat>(NodeVector{K_padded, V_padded}, 0);
     if (is_past_input_available(op_inputs)) {
         const auto& past = op_inputs[4];
         // concat 'past' to 'present' output along fourth axis
         // after that operation, 'present' has shape:
         // (2, batch_size, num_heads, past_sequence_length + sequence_length, head_size)
-        present = std::make_shared<default_opset::Concat>(OutputVector{past, present}, 3);
+        present = std::make_shared<v0::Concat>(OutputVector{past, present}, 3);
     }
     return present;
 }
diff --git a/src/frontends/onnx/frontend/src/op/com.microsoft/bias_gelu.cpp b/src/frontends/onnx/frontend/src/op/com.microsoft/bias_gelu.cpp
index b5a6e58e78dc79..6b929766272bef 100644
--- a/src/frontends/onnx/frontend/src/op/com.microsoft/bias_gelu.cpp
+++ b/src/frontends/onnx/frontend/src/op/com.microsoft/bias_gelu.cpp
@@ -4,8 +4,11 @@
 
 #include "op/com.microsoft/bias_gelu.hpp"
 
-#include "default_opset.hpp"
 #include "openvino/frontend/exception.hpp"
+#include "openvino/op/add.hpp"
+#include "openvino/op/gelu.hpp"
+
+using namespace ov::op;
 
 namespace ngraph {
 namespace onnx_import {
@@ -14,7 +17,7 @@ namespace set_1 {
 OutputVector bias_gelu(const Node& node) {
     auto nodes = node.get_ng_inputs();
     FRONT_END_GENERAL_CHECK(nodes.size() == 2, "BiasGelu takes 2 inputs. Provided " + std::to_string(nodes.size()));
-    return {std::make_shared<default_opset::Gelu>(std::make_shared<default_opset::Add>(nodes.at(0), nodes.at(1)))};
+    return {std::make_shared<v7::Gelu>(std::make_shared<v1::Add>(nodes.at(0), nodes.at(1)))};
 }
 }  // namespace set_1
 }  // namespace op
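Reviewer sketch (not part of the patch): BiasGelu is just Gelu(x + bias); a self-contained model using the same two ops, with parameter shapes that are illustrative only:

    // assumes: openvino/core/model.hpp, openvino/op/{add,gelu,parameter}.hpp
    std::shared_ptr<ov::Model> bias_gelu_example() {
        using namespace ov;
        auto x = std::make_shared<op::v0::Parameter>(element::f32, Shape{2, 8});
        auto bias = std::make_shared<op::v0::Parameter>(element::f32, Shape{8});
        auto gelu = std::make_shared<op::v7::Gelu>(std::make_shared<op::v1::Add>(x, bias));
        return std::make_shared<Model>(OutputVector{gelu}, ParameterVector{x, bias});
    }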
diff --git a/src/frontends/onnx/frontend/src/op/com.microsoft/embed_layer_normalization.cpp b/src/frontends/onnx/frontend/src/op/com.microsoft/embed_layer_normalization.cpp
index 05d1a8e47bba50..13e63051a2dc53 100644
--- a/src/frontends/onnx/frontend/src/op/com.microsoft/embed_layer_normalization.cpp
+++ b/src/frontends/onnx/frontend/src/op/com.microsoft/embed_layer_normalization.cpp
@@ -4,9 +4,19 @@
 
 #include "op/com.microsoft/embed_layer_normalization.hpp"
 
-#include "default_opset.hpp"
 #include "onnx_import/core/null_node.hpp"
 #include "openvino/frontend/exception.hpp"
+#include "openvino/op/add.hpp"
+#include "openvino/op/broadcast.hpp"
+#include "openvino/op/constant.hpp"
+#include "openvino/op/gather.hpp"
+#include "openvino/op/multiply.hpp"
+#include "openvino/op/mvn.hpp"
+#include "openvino/op/reduce_sum.hpp"
+#include "openvino/op/shape_of.hpp"
+#include "openvino/op/slice.hpp"
+
+using namespace ov::op;
 
 namespace ngraph {
 namespace onnx_import {
@@ -28,15 +38,15 @@ OutputVector embed_layer_normalization(const Node& node) {
     const auto& gamma = nodes[5];
     const auto& beta = nodes[6];
 
-    const auto zero = default_opset::Constant::create(element::i32, Shape{1}, {0});
-    std::shared_ptr<ngraph::Node> input = std::make_shared<default_opset::Gather>(word_embeddings, input_ids, zero, 0);
+    const auto zero = v0::Constant::create(element::i32, Shape{1}, {0});
+    std::shared_ptr<ov::Node> input = std::make_shared<v8::Gather>(word_embeddings, input_ids, zero, 0);
     // add position embeddings
     if (num_nodes > 8 && !ov::op::util::is_null(nodes[8])) {
         // if we have position_ids
         const auto& position_ids = nodes[8];
         const auto gathered_position_embeddings =
-            std::make_shared<default_opset::Gather>(position_embeddings, position_ids, zero, 0);
-        input = std::make_shared<default_opset::Add>(input, gathered_position_embeddings);
+            std::make_shared<v8::Gather>(position_embeddings, position_ids, zero, 0);
+        input = std::make_shared<v1::Add>(input, gathered_position_embeddings);
     } else {
         // input_ids' shape is [batchsize, sequence_length]
         // input's shape is [batchsize, sequence_length, hidden_size]
@@ -44,21 +54,20 @@ OutputVector embed_layer_normalization(const Node& node) {
         // therefore input and position_embeddings cannot be added together
         // so we need slice the position_embeddings to [sequence_length, hidden_size] first
         // then add it with input.
-        const auto one = default_opset::Constant::create(element::i32, Shape{1}, {1});
-        const auto input_ids_shape = std::make_shared<default_opset::ShapeOf>(input_ids, element::i32);
-        const auto seqlen = std::make_shared<default_opset::Gather>(input_ids_shape, one, zero, 0);
+        const auto one = v0::Constant::create(element::i32, Shape{1}, {1});
+        const auto input_ids_shape = std::make_shared<v3::ShapeOf>(input_ids, element::i32);
+        const auto seqlen = std::make_shared<v8::Gather>(input_ids_shape, one, zero, 0);
         const auto gathered_position_embeddings =
-            std::make_shared<default_opset::Slice>(position_embeddings, zero, seqlen, one, zero);
-        input = std::make_shared<default_opset::Add>(input, gathered_position_embeddings);
+            std::make_shared<v8::Slice>(position_embeddings, zero, seqlen, one, zero);
+        input = std::make_shared<v1::Add>(input, gathered_position_embeddings);
     }
     // add segment embeddings if available
     if (!ov::op::util::is_null(segment_ids)) {
         FRONT_END_GENERAL_CHECK(!ov::op::util::is_null(segment_embeddings),
                                 "segment_ids provided, but segment_embedding input is missing");
         FRONT_END_GENERAL_CHECK(nodes[1].get_element_type() == element::i32, "segment_ids must have int32 type");
-        auto gathered_segment_embeddings =
-            std::make_shared<default_opset::Gather>(segment_embeddings, segment_ids, zero, 0);
-        input = std::make_shared<default_opset::Add>(input, gathered_segment_embeddings);
+        auto gathered_segment_embeddings = std::make_shared<v8::Gather>(segment_embeddings, segment_ids, zero, 0);
+        input = std::make_shared<v1::Add>(input, gathered_segment_embeddings);
     }
 
     float eps = node.get_attribute_value<float>("epsilon");
@@ -66,25 +75,25 @@ OutputVector embed_layer_normalization(const Node& node) {
     // hidden_size dimension is 2 here, because the shape after Gather(word_embedding, input_ids)
     // is (batch_size, seq_len, hidden_size)
     int hidden_size_dim = 2;
-    const auto reduction_axes = default_opset::Constant::create(element::i32, Shape{1}, {hidden_size_dim});
-    std::shared_ptr<ngraph::Node> result =
-        std::make_shared<default_opset::MVN>(input, reduction_axes, true, eps, ov::op::MVNEpsMode::INSIDE_SQRT);
+    const auto reduction_axes = v0::Constant::create(element::i32, Shape{1}, {hidden_size_dim});
+    std::shared_ptr<ov::Node> result =
+        std::make_shared<v6::MVN>(input, reduction_axes, true, eps, ov::op::MVNEpsMode::INSIDE_SQRT);
 
     // result = gamma * result + beta
-    result = std::make_shared<default_opset::Multiply>(result, gamma);
-    result = std::make_shared<default_opset::Add>(result, beta);
+    result = std::make_shared<v1::Multiply>(result, gamma);
+    result = std::make_shared<v1::Add>(result, beta);
 
     // compute mask_index output
-    std::shared_ptr<ngraph::Node> mask_index;
+    std::shared_ptr<ov::Node> mask_index;
     if (num_nodes > 7 && !ov::op::util::is_null(nodes[7])) {
         FRONT_END_GENERAL_CHECK(nodes[7].get_element_type() == element::i32, "mask must have int32 type");
-        auto axis = default_opset::Constant::create(element::i32, Shape{}, {1});
-        mask_index = std::make_shared<default_opset::ReduceSum>(nodes[7], axis, false);
+        auto axis = v0::Constant::create(element::i32, Shape{}, {1});
+        mask_index = std::make_shared<v1::ReduceSum>(nodes[7], axis, false);
     } else {
-        auto batch_size =
-            std::make_shared<default_opset::Gather>(std::make_shared<default_opset::ShapeOf>(nodes[0]),
-                                                    zero,   // indices
-                                                    zero);  // axis
-        mask_index = std::make_shared<default_opset::Broadcast>(zero, batch_size);
+        auto batch_size = std::make_shared<v8::Gather>(std::make_shared<v3::ShapeOf>(nodes[0]),
                                                       zero,   // indices
                                                       zero);  // axis
+        mask_index = std::make_shared<v3::Broadcast>(zero, batch_size);
     }
     return {result, mask_index};
 }
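Reviewer note (not part of the patch): the data flow above, traced on illustrative shapes (batch = 2, seq = 8, hidden = 16, vocab = V):

    // Gather(word_embeddings (V, 16), input_ids (2, 8))       -> (2, 8, 16)
    // + position embeddings (gathered, or sliced to (8, 16))  -> (2, 8, 16)
    // + segment embeddings (optional)                         -> (2, 8, 16)
    // MVN over axis 2 (INSIDE_SQRT, eps), then * gamma + beta -> (2, 8, 16)
    // mask_index = ReduceSum(mask (2, 8), axis 1)             -> (2), valid tokens per row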
diff --git a/src/frontends/onnx/frontend/src/op/com.microsoft/fused_conv.cpp b/src/frontends/onnx/frontend/src/op/com.microsoft/fused_conv.cpp
index 1feafa08e4a1bb..38c120b332621d 100644
--- a/src/frontends/onnx/frontend/src/op/com.microsoft/fused_conv.cpp
+++ b/src/frontends/onnx/frontend/src/op/com.microsoft/fused_conv.cpp
@@ -7,9 +7,19 @@
 #include <memory>
 #include <vector>
 
-#include "default_opset.hpp"
 #include "exceptions.hpp"
 #include "op/conv.hpp"
+#include "openvino/op/add.hpp"
+#include "openvino/op/clamp.hpp"
+#include "openvino/op/constant.hpp"
+#include "openvino/op/hard_sigmoid.hpp"
+#include "openvino/op/prelu.hpp"
+#include "openvino/op/relu.hpp"
+#include "openvino/op/sigmoid.hpp"
+#include "openvino/op/tan.hpp"
+#include "openvino/op/tanh.hpp"
+
+using namespace ov::op;
 
 namespace ngraph {
 namespace onnx_import {
@@ -19,36 +29,36 @@ OutputVector fused_conv(const Node& node) {
     auto conv_res = conv(node).at(0);
 
     if (node.get_ng_inputs().size() == 4) {  // Z input provided
-        conv_res = std::make_shared<default_opset::Add>(conv_res, node.get_ng_inputs()[3]);
+        conv_res = std::make_shared<v1::Add>(conv_res, node.get_ng_inputs()[3]);
     }
 
     const auto activation_type = node.get_attribute_value<std::string>("activation");
     const auto activation_params = node.get_attribute_value<std::vector<float>>("activation_params", {});
 
     if (activation_type == "Relu") {
-        return {std::make_shared<default_opset::Relu>(conv_res)};
+        return {std::make_shared<v0::Relu>(conv_res)};
     } else if (activation_type == "Tanh") {
-        return {std::make_shared<default_opset::Tanh>(conv_res)};
+        return {std::make_shared<v0::Tanh>(conv_res)};
     } else if (activation_type == "Sigmoid") {
-        return {std::make_shared<default_opset::Sigmoid>(conv_res)};
+        return {std::make_shared<v0::Sigmoid>(conv_res)};
     } else if (activation_type == "Clip") {
         CHECK_VALID_NODE(node,
                          activation_params.size() == 2,
                          "min and max attributes of Clip activation function were not provided");
-        return {std::make_shared<default_opset::Clamp>(conv_res, activation_params[0], activation_params[1])};
+        return {std::make_shared<v0::Clamp>(conv_res, activation_params[0], activation_params[1])};
     } else if (activation_type == "LeakyRelu") {
         CHECK_VALID_NODE(node,
                          activation_params.size() == 1,
                          "activation_alpha attribute of LeakyRelu activation function was not provided");
-        const auto activation_alpha_node = default_opset::Constant::create(element::f32, Shape{}, activation_params);
-        return {std::make_shared<default_opset::PRelu>(conv_res, activation_alpha_node)};
+        const auto activation_alpha_node = v0::Constant::create(element::f32, Shape{}, activation_params);
+        return {std::make_shared<v0::PRelu>(conv_res, activation_alpha_node)};
     } else if (activation_type == "HardSigmoid") {
         CHECK_VALID_NODE(node,
                          activation_params.size() == 2,
                          "alpha and beta attributes of HardSigmoid activation function were not provided");
-        const auto alpha = default_opset::Constant::create(element::f32, Shape{}, {activation_params[0]});
-        const auto beta = default_opset::Constant::create(element::f32, Shape{}, {activation_params[1]});
-        return {std::make_shared<default_opset::HardSigmoid>(conv_res, alpha, beta)};
+        const auto alpha = v0::Constant::create(element::f32, Shape{}, {activation_params[0]});
+        const auto beta = v0::Constant::create(element::f32, Shape{}, {activation_params[1]});
+        return {std::make_shared<v0::HardSigmoid>(conv_res, alpha, beta)};
     }
     CHECK_VALID_NODE(node,
                      !activation_type.empty(),
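Reviewer sketch (not part of the patch): LeakyRelu has no dedicated op here, so the patch expresses it as PRelu with a scalar alpha constant. A stand-alone equivalent (the default alpha value is illustrative):

    // assumes: openvino/op/{constant,prelu}.hpp
    std::shared_ptr<ov::Node> leaky_relu(const ov::Output<ov::Node>& x, float alpha = 0.01f) {
        using namespace ov;
        const auto alpha_node = op::v0::Constant::create(element::f32, Shape{}, {alpha});
        return std::make_shared<op::v0::PRelu>(x, alpha_node);  // max(0, x) + alpha * min(0, x)
    }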
diff --git a/src/frontends/onnx/frontend/src/op/com.microsoft/fusedgemm.cpp b/src/frontends/onnx/frontend/src/op/com.microsoft/fusedgemm.cpp
index 4af42e8263bcb5..6f6039e5496f4c 100644
--- a/src/frontends/onnx/frontend/src/op/com.microsoft/fusedgemm.cpp
+++ b/src/frontends/onnx/frontend/src/op/com.microsoft/fusedgemm.cpp
@@ -6,13 +6,16 @@
 
 #include <memory>
 
-#include "default_opset.hpp"
-#include "ngraph/op/add.hpp"
-#include "ngraph/op/constant.hpp"
-#include "ngraph/op/matmul.hpp"
-#include "ngraph/op/multiply.hpp"
 #include "onnx_import/core/null_node.hpp"
 #include "openvino/frontend/exception.hpp"
+#include "openvino/op/add.hpp"
+#include "openvino/op/constant.hpp"
+#include "openvino/op/matmul.hpp"
+#include "openvino/op/multiply.hpp"
+#include "openvino/op/prelu.hpp"
+#include "openvino/op/relu.hpp"
+
+using namespace ov::op;
 
 namespace ngraph {
 namespace onnx_import {
@@ -24,14 +27,14 @@ OutputVector fusedgemm(const Node& node) {
     FRONT_END_GENERAL_CHECK(num_inputs == 2 || num_inputs == 3,
                             "FusedGemm takes 2/3 inputs. Provided " + std::to_string(num_inputs));
 
-    Output<ngraph::Node> input_a = inputs.at(0);
-    Output<ngraph::Node> input_b = inputs.at(1);
-    Output<ngraph::Node> input_c;
+    Output<ov::Node> input_a = inputs.at(0);
+    Output<ov::Node> input_b = inputs.at(1);
+    Output<ov::Node> input_c;
 
     if (num_inputs == 3 && !ov::op::util::is_null(inputs[2])) {
         input_c = inputs.at(2);
     } else {
-        input_c = default_opset::Constant::create(input_b.get_element_type(), ngraph::Shape{}, {0});
+        input_c = v0::Constant::create(input_b.get_element_type(), ov::Shape{}, {0});
     }
 
     const auto alpha_node = node.get_attribute_as_constant<float>("alpha", 1, input_b.get_element_type());
@@ -40,22 +43,22 @@ OutputVector fusedgemm(const Node& node) {
     const bool trans_a = node.get_attribute_value<int64_t>("transA", 0);
     const bool trans_b = node.get_attribute_value<int64_t>("transB", 0);
 
-    const auto matmul_node = std::make_shared<default_opset::MatMul>(input_a, input_b, trans_a, trans_b);
-    const auto matmul_times_alpha = std::make_shared<default_opset::Multiply>(matmul_node, alpha_node);
+    const auto matmul_node = std::make_shared<v0::MatMul>(input_a, input_b, trans_a, trans_b);
+    const auto matmul_times_alpha = std::make_shared<v1::Multiply>(matmul_node, alpha_node);
 
-    const auto beta_times_input_c = std::make_shared<default_opset::Multiply>(beta_node, input_c);
+    const auto beta_times_input_c = std::make_shared<v1::Multiply>(beta_node, input_c);
     const std::string onnx_name = !node.get_name().empty() ? node.get_name() : node.output(0);
     matmul_node->set_friendly_name(onnx_name + "/WithoutBiases");
-    const auto gemm_res = std::make_shared<default_opset::Add>(matmul_times_alpha, beta_times_input_c);
+    const auto gemm_res = std::make_shared<v1::Add>(matmul_times_alpha, beta_times_input_c);
 
     const auto activation_type = node.get_attribute_value<std::string>("activation", "Relu");
     if (activation_type == "LeakyRelu") {
         double activation_alpha = node.get_attribute_value<double>("activation_alpha", 0.01);
-        std::shared_ptr<ngraph::Node> activation_alpha_node =
-            default_opset::Constant::create(input_c.get_element_type(), Shape{1}, {activation_alpha});
-        return {std::make_shared<default_opset::PRelu>(gemm_res, activation_alpha_node)};
+        std::shared_ptr<ov::Node> activation_alpha_node =
+            v0::Constant::create(input_c.get_element_type(), Shape{1}, {activation_alpha});
+        return {std::make_shared<v0::PRelu>(gemm_res, activation_alpha_node)};
     }
-    return {std::make_shared<default_opset::Relu>(gemm_res)};
+    return {std::make_shared<v0::Relu>(gemm_res)};
 }
 
 }  // namespace set_1
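Reviewer note (not part of the patch): FusedGemm computes activation(alpha * op(A) x op(B) + beta * C), where op() optionally transposes per transA/transB. The body above maps this one-to-one onto the decomposition

    // MatMul(A, B, transA, transB) -> Multiply(alpha) -> Add(Multiply(beta, C))
    // followed by Relu (the default) or, for "LeakyRelu", PRelu with a scalar slope.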
diff --git a/src/frontends/onnx/frontend/src/op/com.microsoft/skip_layer_normalization.cpp b/src/frontends/onnx/frontend/src/op/com.microsoft/skip_layer_normalization.cpp
index aed16be77b9c6b..72d8dc57fb5d36 100644
--- a/src/frontends/onnx/frontend/src/op/com.microsoft/skip_layer_normalization.cpp
+++ b/src/frontends/onnx/frontend/src/op/com.microsoft/skip_layer_normalization.cpp
@@ -4,8 +4,13 @@
 
 #include "op/com.microsoft/skip_layer_normalization.hpp"
 
-#include "default_opset.hpp"
 #include "openvino/frontend/exception.hpp"
+#include "openvino/op/add.hpp"
+#include "openvino/op/constant.hpp"
+#include "openvino/op/multiply.hpp"
+#include "openvino/op/mvn.hpp"
+
+using namespace ov::op;
 
 namespace ngraph {
 namespace onnx_import {
@@ -18,22 +23,22 @@ OutputVector skip_layer_normalization(const Node& node) {
                             "SkipLayerNormalization takes 3, 4 or 5 inputs. Provided " + std::to_string(num_nodes));
 
     // input + skip
-    std::shared_ptr<ngraph::Node> input = std::make_shared<default_opset::Add>(nodes[0], nodes[1]);
+    std::shared_ptr<ov::Node> input = std::make_shared<v1::Add>(nodes[0], nodes[1]);
     // add bias if available
     if (num_nodes == 5) {
-        input = std::make_shared<default_opset::Add>(input, nodes[4]);
+        input = std::make_shared<v1::Add>(input, nodes[4]);
     }
 
     float eps = node.get_attribute_value<float>("epsilon");
     // reduce over hidden_size
     int hidden_size_dim = 2;
-    const auto reduction_axes = default_opset::Constant::create(element::i32, Shape{1}, {hidden_size_dim});
-    std::shared_ptr<ngraph::Node> result =
-        std::make_shared<default_opset::MVN>(input, reduction_axes, true, eps, ov::op::MVNEpsMode::INSIDE_SQRT);
+    const auto reduction_axes = v0::Constant::create(element::i32, Shape{1}, {hidden_size_dim});
+    std::shared_ptr<ov::Node> result =
+        std::make_shared<v6::MVN>(input, reduction_axes, true, eps, ov::op::MVNEpsMode::INSIDE_SQRT);
     // multiply by gamma
-    result = std::make_shared<default_opset::Multiply>(result, nodes[2]);
+    result = std::make_shared<v1::Multiply>(result, nodes[2]);
     // add beta if available
     if (num_nodes > 3) {
-        result = std::make_shared<default_opset::Add>(result, nodes[3]);
+        result = std::make_shared<v1::Add>(result, nodes[3]);
     }
     // spec mentions three outputs (output, mean, inv_std_var) while we support only first one, but:
     // - onnxruntime also doesn't support the last two
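Reviewer sketch (not part of the patch): SkipLayerNormalization is LayerNorm(input + skip [+ bias]) expressed through MVN, exactly as in the hunk above. A minimal stand-alone equivalent; the default eps and the fixed hidden axis 2 are illustrative:

    // assumes: openvino/op/{add,constant,multiply,mvn}.hpp
    std::shared_ptr<ov::Node> skip_layer_norm(const ov::Output<ov::Node>& input,
                                              const ov::Output<ov::Node>& skip,
                                              const ov::Output<ov::Node>& gamma,
                                              const ov::Output<ov::Node>& beta,
                                              float eps = 1e-12f) {
        using namespace ov;
        auto sum = std::make_shared<op::v1::Add>(input, skip);
        const auto axes = op::v0::Constant::create(element::i32, Shape{1}, {2});  // hidden dim
        auto norm = std::make_shared<op::v6::MVN>(sum, axes, true, eps, op::MVNEpsMode::INSIDE_SQRT);
        auto scaled = std::make_shared<op::v1::Multiply>(norm, gamma);
        return std::make_shared<op::v1::Add>(scaled, beta);
    }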