From 80f8c59ae8436a64024e269ccfb55a4c31831b74 Mon Sep 17 00:00:00 2001 From: Evgeny Kotov Date: Tue, 24 Oct 2023 18:18:50 +0200 Subject: [PATCH] fix random layer names and count (#20323) * add sorting for fix sporadic failure in SharedOpOptimization shared_node_optimization * fix Output and Input comparison * remove unneed sorting from transformation * add unit test * code review fixes * code review fixes * code review fixes * code review fixes --------- Co-authored-by: Ivan Tikhonov --- .../shared_ops_optimization.cpp | 71 +++++++++++++++++++ src/core/src/node_input.cpp | 13 ++-- src/core/src/node_output.cpp | 12 ++-- 3 files changed, 88 insertions(+), 8 deletions(-) diff --git a/src/common/transformations/tests/common_optimizations/shared_ops_optimization.cpp b/src/common/transformations/tests/common_optimizations/shared_ops_optimization.cpp index 698973740e08e6..b0e327e4d4bad4 100644 --- a/src/common/transformations/tests/common_optimizations/shared_ops_optimization.cpp +++ b/src/common/transformations/tests/common_optimizations/shared_ops_optimization.cpp @@ -433,3 +433,74 @@ TEST_F(SharedTransformationTestsF, SharedShapeOfTestMixed) { model_ref = std::make_shared(NodeVector{concat}, ParameterVector{input}); } } + +namespace { +OutputVector createShapeNodesInMemory(const std::vector& node_order_in_memory, + std::shared_ptr& memory, + const std::string& node_name_prefix, + const std::shared_ptr& input, + element::Type output_type) { + OutputVector outputs; + memory.reset(::malloc(node_order_in_memory.size() * sizeof(v3::ShapeOf)), ::free); + for (size_t i = 0; i < node_order_in_memory.size(); ++i) { + v3::ShapeOf* node_addr = static_cast(memory.get()) + node_order_in_memory[i]; + auto node_ptr = + std::shared_ptr(new (node_addr) v3::ShapeOf(input, output_type), [](v3::ShapeOf* node) { + node->v3::ShapeOf::~ShapeOf(); + }); + std::stringstream ss; + ss << node_name_prefix << i; + node_ptr->set_friendly_name(ss.str()); + outputs.push_back(node_ptr->output(0)); + } + + return outputs; +} + +std::shared_ptr createModelWithShapes(const Shape& input_shape, + const std::vector& node_order_in_memory, + const std::string& node_name_prefix, + std::shared_ptr& buffer) { + auto input = std::make_shared(element::f32, input_shape); + auto shape_nodes = createShapeNodesInMemory(node_order_in_memory, buffer, node_name_prefix, input, element::i64); + + NodeVector inputs_of_concat; + for (const auto& shape_node : shape_nodes) { + auto node = std::make_shared(shape_node, element::i64); + inputs_of_concat.push_back(node); + } + + auto concat = std::make_shared(inputs_of_concat, 0); + return std::make_shared(NodeVector{concat}, ParameterVector{input}); +} +} // namespace + +/** + * @brief Check that node address is not influenced on the transformation result + */ +TEST(TransformationTests, SharedShapeOfTestRandomOrder) { + Shape input_shape{120, 4}; + std::shared_ptr buffer; + // nodes are placed into pre-allocated memory in order that is specified in next variable + std::vector> node_orders_in_memory = {{0, 1}, {1, 0}}; + + std::vector> models; + for (const auto& node_order_in_memory : node_orders_in_memory) { + auto model = createModelWithShapes(input_shape, node_order_in_memory, "Shape_", buffer); + + ov::pass::Manager manager; + manager.register_pass(); + manager.run_passes(model); + + const auto model_ops = model->get_ops(); + const auto op_it = std::find_if(model_ops.begin(), model_ops.end(), [](const std::shared_ptr& node) { + return node->get_friendly_name() == "Shape_0"; + }); + ASSERT_TRUE(op_it != model_ops.end()) << "node Shape_0 is not found in model"; + // we need to clone while memory will be reused on the next iteration for the new model + models.push_back(model->clone()); + } + + FunctionsComparator comparator = FunctionsComparator::with_default(); + comparator.compare(models[0], models[1]); +} diff --git a/src/core/src/node_input.cpp b/src/core/src/node_input.cpp index 7c6b8a9ff2102c..11a353cb765b49 100644 --- a/src/core/src/node_input.cpp +++ b/src/core/src/node_input.cpp @@ -60,12 +60,15 @@ bool Input::operator==(const Input& other) const { bool Input::operator!=(const Input& other) const { return !(*this == other); } + bool Input::operator<(const Input& other) const { - return m_node < other.m_node || (m_node == other.m_node && m_index < other.m_index); + return m_node->get_instance_id() < other.m_node->get_instance_id() || + (m_node == other.m_node && m_index < other.m_index); } bool Input::operator>(const Input& other) const { - return m_node > other.m_node || (m_node == other.m_node && m_index > other.m_index); + return m_node->get_instance_id() > other.m_node->get_instance_id() || + (m_node == other.m_node && m_index > other.m_index); } bool Input::operator<=(const Input& other) const { @@ -135,11 +138,13 @@ bool Input::operator!=(const Input& other) const { return !(*this == other); } bool Input::operator<(const Input& other) const { - return m_node < other.m_node || (m_node == other.m_node && m_index < other.m_index); + return m_node->get_instance_id() < other.m_node->get_instance_id() || + (m_node == other.m_node && m_index < other.m_index); } bool Input::operator>(const Input& other) const { - return m_node > other.m_node || (m_node == other.m_node && m_index > other.m_index); + return m_node->get_instance_id() > other.m_node->get_instance_id() || + (m_node == other.m_node && m_index > other.m_index); } bool Input::operator<=(const Input& other) const { diff --git a/src/core/src/node_output.cpp b/src/core/src/node_output.cpp index fbd7d3f172280c..4d5de39b75132a 100644 --- a/src/core/src/node_output.cpp +++ b/src/core/src/node_output.cpp @@ -137,10 +137,12 @@ bool Output::operator!=(const Output& other) const { return !(*this == other); } bool Output::operator<(const Output& other) const { - return m_node < other.m_node || (m_node == other.m_node && m_index < other.m_index); + return m_node->get_instance_id() < other.m_node->get_instance_id() || + (m_node == other.m_node && m_index < other.m_index); } bool Output::operator>(const Output& other) const { - return m_node > other.m_node || (m_node == other.m_node && m_index > other.m_index); + return m_node->get_instance_id() > other.m_node->get_instance_id() || + (m_node == other.m_node && m_index > other.m_index); } bool Output::operator<=(const Output& other) const { return !(*this > other); @@ -211,10 +213,12 @@ bool Output::operator!=(const Output& other) const { return !(*this == other); } bool Output::operator<(const Output& other) const { - return m_node < other.m_node || (m_node == other.m_node && m_index < other.m_index); + return m_node->get_instance_id() < other.m_node->get_instance_id() || + (m_node == other.m_node && m_index < other.m_index); } bool Output::operator>(const Output& other) const { - return m_node > other.m_node || (m_node == other.m_node && m_index > other.m_index); + return m_node->get_instance_id() > other.m_node->get_instance_id() || + (m_node == other.m_node && m_index > other.m_index); } bool Output::operator<=(const Output& other) const { return !(*this > other);