Skip to content

Commit

Permalink
[CPU] Reimplement TopologicalSort (#21911)
Browse files Browse the repository at this point in the history
to get rid of extra Node class member variables
  • Loading branch information
EgorDuplensky authored Jan 16, 2024
1 parent f4e1ef5 commit 2126eea
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 47 deletions.
66 changes: 27 additions & 39 deletions src/plugins/intel_cpu/src/graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1299,54 +1299,42 @@ void Graph::Infer(SyncInferRequest* request) {
if (infer_count != -1) infer_count++;
}

void Graph::VisitNode(NodePtr node, std::vector<NodePtr>& sortedNodes) {
if (node->temporary) {
return;
}

if (node->permanent) {
return;
}

node->temporary = true;

for (size_t i = 0; i < node->getChildEdges().size(); i++) {
VisitNode(node->getChildEdgeAt(i)->getChild(), sortedNodes);
}

node->permanent = true;
node->temporary = false;

sortedNodes.insert(sortedNodes.begin(), node);
}

void Graph::SortTopologically() {
OV_ITT_SCOPE(FIRST_INFERENCE, itt::domains::intel_cpu_LT, "Graph::SortTopologically");

std::vector<NodePtr> unsorted;
std::vector<NodePtr> sorted;
auto sort = [](const std::vector<NodePtr>& nodes) {
std::unordered_set<NodePtr> visited;
visited.reserve(nodes.size());
std::vector<NodePtr> sorted;
sorted.reserve(nodes.size());

for (size_t i = 0; i < graphNodes.size(); i++) {
NodePtr node = graphNodes[i];
std::function<void(const NodePtr)> visit;
visit = [&visited, &sorted, &visit](const NodePtr node) {
const bool inserted = visited.insert(node).second;
if (!inserted)
return; // already visited

node->permanent = false;
node->temporary = false;

unsorted.push_back(node);
}
for (size_t i = 0; i < node->getChildEdges().size(); i++) {
visit(node->getChildEdgeAt(i)->getChild());
}

while (!unsorted.empty()) {
NodePtr node = unsorted.at(0);
unsorted.erase(unsorted.begin());
sorted.push_back(node);
};

VisitNode(node, sorted);
}
for (const auto& node : nodes) {
visit(node);
}

for (size_t i = 0; i < sorted.size(); i++)
sorted[i]->execIndex = static_cast<int>(i);
return sorted;
};

graphNodes.erase(graphNodes.begin(), graphNodes.end());
graphNodes.assign(sorted.begin(), sorted.end());
// as a first step sort in reversed topological order to avoid an insertion into the front of the vector
graphNodes = sort(graphNodes);
// reverse to the actual topological order
std::reverse(graphNodes.begin(), graphNodes.end());
// number the nodes based on topological order
for (size_t i = 0; i < graphNodes.size(); i++)
graphNodes[i]->execIndex = static_cast<int>(i);

// TODO: Sort in/out edges by port index because of backward compatibility
// A lot of plugin logic are build on top of assumption that index in
Expand Down
2 changes: 0 additions & 2 deletions src/plugins/intel_cpu/src/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,6 @@ class Graph {
}

protected:
void VisitNode(NodePtr node, std::vector<NodePtr>& sortedNodes);

void ForgetGraphData() {
status = Status::NotReady;

Expand Down
4 changes: 0 additions & 4 deletions src/plugins/intel_cpu/src/node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,6 @@ Node::Node(const std::shared_ptr<ov::Node>& op,
const GraphContext::CPtr ctx,
const ShapeInferFactory& shapeInferFactory)
: selectedPrimitiveDescriptorIndex(-1),
permanent(false),
temporary(false),
constant(ConstantType::NoConst),
context(ctx),
algorithm(Algorithm::Default),
Expand Down Expand Up @@ -182,8 +180,6 @@ Node::Node(const std::shared_ptr<ov::Node>& op,

Node::Node(const std::string& type, const std::string& name, const GraphContext::CPtr ctx)
: selectedPrimitiveDescriptorIndex(-1),
permanent(false),
temporary(false),
constant(ConstantType::NoConst),
context(ctx),
fusingPort(-1),
Expand Down
2 changes: 0 additions & 2 deletions src/plugins/intel_cpu/src/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -609,8 +609,6 @@ class Node {
Node(const std::string& type, const std::string& name, const GraphContext::CPtr ctx);

int selectedPrimitiveDescriptorIndex = -1;
bool permanent = false;
bool temporary = false;

enum class InPlaceType {
Unknown,
Expand Down

0 comments on commit 2126eea

Please sign in to comment.