Skip to content

Commit

Permalink
add basic topological sort
Browse files Browse the repository at this point in the history
  • Loading branch information
fotcorn committed Apr 17, 2024
1 parent c38d083 commit e2ebd84
Showing 1 changed file with 28 additions and 0 deletions.
28 changes: 28 additions & 0 deletions include/cppdl/trace.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#pragma once

#include <algorithm>
#include <functional>
#include <iostream>
#include <vector>

Expand Down Expand Up @@ -109,6 +111,32 @@ class NeuralNetwork {
}

const Graph &getGraph() { return graph; }

std::vector<NodeId> topologicalSort() {
// TODO: cycle detection.
std::vector<NodeId> list;

std::function<void(NodeId)> topoVisit = [&](NodeId nodeId) {
if (std::find(list.begin(), list.end(), nodeId) != list.end()) {
return;
}
Node *node = graph.getNode(nodeId);
for (NodeId successor : node->getSuccessors()) {
topoVisit(successor);
}
list.push_back(nodeId);
};

for (auto nodeId : inputTensors) {
topoVisit(nodeId);
}
for (auto nodeId : paramTensors) {
topoVisit(nodeId);
}

std::reverse(list.begin(), list.end());
return list;
}
};

class LinearLayer {
Expand Down

0 comments on commit e2ebd84

Please sign in to comment.