Skip to content

Commit

Permalink
fix random layer names and count (openvinotoolkit#20323)
Browse files Browse the repository at this point in the history
* add sorting for fix sporadic failure in SharedOpOptimization shared_node_optimization

* fix Output and Input comparison

* remove unneed sorting from transformation

* add unit test

* code review fixes

* code review fixes

* code review fixes

* code review fixes

---------

Co-authored-by: Ivan Tikhonov <ivan.tikhonov@intel.com>
  • Loading branch information
2 people authored and allnes committed Nov 23, 2023
1 parent e280432 commit 80f8c59
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -433,3 +433,74 @@ TEST_F(SharedTransformationTestsF, SharedShapeOfTestMixed) {
model_ref = std::make_shared<Model>(NodeVector{concat}, ParameterVector{input});
}
}

namespace {
OutputVector createShapeNodesInMemory(const std::vector<size_t>& node_order_in_memory,
std::shared_ptr<void>& memory,
const std::string& node_name_prefix,
const std::shared_ptr<Node>& input,
element::Type output_type) {
OutputVector outputs;
memory.reset(::malloc(node_order_in_memory.size() * sizeof(v3::ShapeOf)), ::free);
for (size_t i = 0; i < node_order_in_memory.size(); ++i) {
v3::ShapeOf* node_addr = static_cast<v3::ShapeOf*>(memory.get()) + node_order_in_memory[i];
auto node_ptr =
std::shared_ptr<v3::ShapeOf>(new (node_addr) v3::ShapeOf(input, output_type), [](v3::ShapeOf* node) {
node->v3::ShapeOf::~ShapeOf();
});
std::stringstream ss;
ss << node_name_prefix << i;
node_ptr->set_friendly_name(ss.str());
outputs.push_back(node_ptr->output(0));
}

return outputs;
}

std::shared_ptr<Model> createModelWithShapes(const Shape& input_shape,
const std::vector<size_t>& node_order_in_memory,
const std::string& node_name_prefix,
std::shared_ptr<void>& buffer) {
auto input = std::make_shared<v0::Parameter>(element::f32, input_shape);
auto shape_nodes = createShapeNodesInMemory(node_order_in_memory, buffer, node_name_prefix, input, element::i64);

NodeVector inputs_of_concat;
for (const auto& shape_node : shape_nodes) {
auto node = std::make_shared<v0::Convert>(shape_node, element::i64);
inputs_of_concat.push_back(node);
}

auto concat = std::make_shared<v0::Concat>(inputs_of_concat, 0);
return std::make_shared<Model>(NodeVector{concat}, ParameterVector{input});
}
} // namespace

/**
* @brief Check that node address is not influenced on the transformation result
*/
TEST(TransformationTests, SharedShapeOfTestRandomOrder) {
Shape input_shape{120, 4};
std::shared_ptr<void> buffer;
// nodes are placed into pre-allocated memory in order that is specified in next variable
std::vector<std::vector<size_t>> node_orders_in_memory = {{0, 1}, {1, 0}};

std::vector<std::shared_ptr<Model>> models;
for (const auto& node_order_in_memory : node_orders_in_memory) {
auto model = createModelWithShapes(input_shape, node_order_in_memory, "Shape_", buffer);

ov::pass::Manager manager;
manager.register_pass<pass::SharedOpOptimization>();
manager.run_passes(model);

const auto model_ops = model->get_ops();
const auto op_it = std::find_if(model_ops.begin(), model_ops.end(), [](const std::shared_ptr<Node>& node) {
return node->get_friendly_name() == "Shape_0";
});
ASSERT_TRUE(op_it != model_ops.end()) << "node Shape_0 is not found in model";
// we need to clone while memory will be reused on the next iteration for the new model
models.push_back(model->clone());
}

FunctionsComparator comparator = FunctionsComparator::with_default();
comparator.compare(models[0], models[1]);
}
13 changes: 9 additions & 4 deletions src/core/src/node_input.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,15 @@ bool Input<Node>::operator==(const Input& other) const {
bool Input<Node>::operator!=(const Input& other) const {
return !(*this == other);
}

bool Input<Node>::operator<(const Input& other) const {
return m_node < other.m_node || (m_node == other.m_node && m_index < other.m_index);
return m_node->get_instance_id() < other.m_node->get_instance_id() ||
(m_node == other.m_node && m_index < other.m_index);
}

bool Input<Node>::operator>(const Input& other) const {
return m_node > other.m_node || (m_node == other.m_node && m_index > other.m_index);
return m_node->get_instance_id() > other.m_node->get_instance_id() ||
(m_node == other.m_node && m_index > other.m_index);
}

bool Input<Node>::operator<=(const Input& other) const {
Expand Down Expand Up @@ -135,11 +138,13 @@ bool Input<const Node>::operator!=(const Input& other) const {
return !(*this == other);
}
bool Input<const Node>::operator<(const Input& other) const {
return m_node < other.m_node || (m_node == other.m_node && m_index < other.m_index);
return m_node->get_instance_id() < other.m_node->get_instance_id() ||
(m_node == other.m_node && m_index < other.m_index);
}

bool Input<const Node>::operator>(const Input& other) const {
return m_node > other.m_node || (m_node == other.m_node && m_index > other.m_index);
return m_node->get_instance_id() > other.m_node->get_instance_id() ||
(m_node == other.m_node && m_index > other.m_index);
}

bool Input<const Node>::operator<=(const Input& other) const {
Expand Down
12 changes: 8 additions & 4 deletions src/core/src/node_output.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,12 @@ bool Output<Node>::operator!=(const Output& other) const {
return !(*this == other);
}
bool Output<Node>::operator<(const Output& other) const {
return m_node < other.m_node || (m_node == other.m_node && m_index < other.m_index);
return m_node->get_instance_id() < other.m_node->get_instance_id() ||
(m_node == other.m_node && m_index < other.m_index);
}
bool Output<Node>::operator>(const Output& other) const {
return m_node > other.m_node || (m_node == other.m_node && m_index > other.m_index);
return m_node->get_instance_id() > other.m_node->get_instance_id() ||
(m_node == other.m_node && m_index > other.m_index);
}
bool Output<Node>::operator<=(const Output& other) const {
return !(*this > other);
Expand Down Expand Up @@ -211,10 +213,12 @@ bool Output<const Node>::operator!=(const Output& other) const {
return !(*this == other);
}
bool Output<const Node>::operator<(const Output& other) const {
return m_node < other.m_node || (m_node == other.m_node && m_index < other.m_index);
return m_node->get_instance_id() < other.m_node->get_instance_id() ||
(m_node == other.m_node && m_index < other.m_index);
}
bool Output<const Node>::operator>(const Output& other) const {
return m_node > other.m_node || (m_node == other.m_node && m_index > other.m_index);
return m_node->get_instance_id() > other.m_node->get_instance_id() ||
(m_node == other.m_node && m_index > other.m_index);
}
bool Output<const Node>::operator<=(const Output& other) const {
return !(*this > other);
Expand Down

0 comments on commit 80f8c59

Please sign in to comment.