From e2ebd84ee7a51c6dac9f53c8b6c0ce32b990b4ed Mon Sep 17 00:00:00 2001 From: fotcorn Date: Wed, 17 Apr 2024 23:20:05 +0100 Subject: [PATCH] add basic topological sort --- include/cppdl/trace.h | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/include/cppdl/trace.h b/include/cppdl/trace.h index 3a9613c..bf035c5 100644 --- a/include/cppdl/trace.h +++ b/include/cppdl/trace.h @@ -1,5 +1,7 @@ #pragma once +#include +#include #include #include @@ -109,6 +111,32 @@ class NeuralNetwork { } const Graph &getGraph() { return graph; } + + std::vector topologicalSort() { + // TODO: cycle detection. + std::vector list; + + std::function topoVisit = [&](NodeId nodeId) { + if (std::find(list.begin(), list.end(), nodeId) != list.end()) { + return; + } + Node *node = graph.getNode(nodeId); + for (NodeId successor : node->getSuccessors()) { + topoVisit(successor); + } + list.push_back(nodeId); + }; + + for (auto nodeId : inputTensors) { + topoVisit(nodeId); + } + for (auto nodeId : paramTensors) { + topoVisit(nodeId); + } + + std::reverse(list.begin(), list.end()); + return list; + } }; class LinearLayer {