diff --git a/include/cppdl/trace.h b/include/cppdl/trace.h index 3a9613c..bf035c5 100644 --- a/include/cppdl/trace.h +++ b/include/cppdl/trace.h @@ -1,5 +1,7 @@ #pragma once +#include +#include #include #include @@ -109,6 +111,32 @@ class NeuralNetwork { } const Graph &getGraph() { return graph; } + + std::vector topologicalSort() { + // TODO: cycle detection. + std::vector list; + + std::function topoVisit = [&](NodeId nodeId) { + if (std::find(list.begin(), list.end(), nodeId) != list.end()) { + return; + } + Node *node = graph.getNode(nodeId); + for (NodeId successor : node->getSuccessors()) { + topoVisit(successor); + } + list.push_back(nodeId); + }; + + for (auto nodeId : inputTensors) { + topoVisit(nodeId); + } + for (auto nodeId : paramTensors) { + topoVisit(nodeId); + } + + std::reverse(list.begin(), list.end()); + return list; + } }; class LinearLayer {