Skip to content

Commit c38d083

Browse files
committed
graph: add successors to node
1 parent 9155ee3 commit c38d083

File tree

5 files changed

+62
-43
lines changed

5 files changed

+62
-43
lines changed

examples/graph.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,5 +20,5 @@ int main() {
2020
g.addNode<AddNode>(r0, biasL0, std::vector<std::size_t>({30, 16}));
2121
NodeId output = g.addNode<ReLUNode>(r1, std::vector<std::size_t>({30, 16}));
2222

23-
printGraph(output);
23+
printGraphBackwards(g, output);
2424
}

examples/mnist_trace.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ int main() {
2929
auto z2 = layer2.forward(a1);
3030
auto result = z2.softmax(1);
3131

32-
printGraph(result.nodeId);
32+
printGraphBackwards(nn.getGraph(), result.nodeId);
3333

3434
return 0;
3535
}

examples/trace.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,5 @@ int main() {
1212
auto h1 = x.matmul(w1).add(b1).relu();
1313
auto y = h1.matmul(w2).add(b2).relu();
1414

15-
printGraph(y.nodeId);
15+
printGraphBackwards(nn.getGraph(), y.nodeId);
1616
}

include/cppdl/graph.h

+57-40
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,13 @@ class Graph {
2424
NodeId addNode(Args &&...args) {
2525
static_assert(std::is_base_of<Node, GraphNodeType>::value,
2626
"GraphNodeType must be a derivative of Node");
27-
new (nextMemory) GraphNodeType(std::forward<Args>(args)...);
27+
new (nextMemory) GraphNodeType(*this, std::forward<Args>(args)...);
2828
NodeId id = reinterpret_cast<NodeId>(nextMemory);
2929
nextMemory += sizeof(GraphNodeType);
3030
return id;
3131
}
32+
33+
Node *getNode(NodeId id) const { return reinterpret_cast<Node *>(id); }
3234
};
3335

3436
enum class NodeKind {
@@ -48,6 +50,7 @@ enum class NodeKind {
4850
class Node {
4951
NodeKind nodeKind;
5052
std::vector<std::size_t> shape;
53+
std::vector<NodeId> successors;
5154

5255
protected:
5356
Node(NodeKind type, std::vector<std::size_t> shape)
@@ -56,6 +59,8 @@ class Node {
5659
public:
5760
NodeKind getKind() const { return nodeKind; }
5861
std::vector<std::size_t> getShape() const { return shape; }
62+
void addSuccessor(NodeId successor) { successors.push_back(successor); }
63+
const std::vector<NodeId> &getSuccessors() const { return successors; }
5964
};
6065

6166
class BinaryNode : public Node {
@@ -67,9 +72,12 @@ class BinaryNode : public Node {
6772
NodeId getInputB() const { return inputB; }
6873

6974
protected:
70-
BinaryNode(NodeKind type, std::vector<std::size_t> shape, NodeId inputA,
71-
NodeId inputB)
72-
: Node(type, shape), inputA(inputA), inputB(inputB) {}
75+
BinaryNode(Graph &graph, NodeKind type, std::vector<std::size_t> shape,
76+
NodeId inputA, NodeId inputB)
77+
: Node(type, shape), inputA(inputA), inputB(inputB) {
78+
graph.getNode(inputA)->addSuccessor(reinterpret_cast<NodeId>(this));
79+
graph.getNode(inputB)->addSuccessor(reinterpret_cast<NodeId>(this));
80+
}
7381
};
7482

7583
class UnaryNode : public Node {
@@ -79,62 +87,69 @@ class UnaryNode : public Node {
7987
NodeId getInput() const { return input; }
8088

8189
protected:
82-
UnaryNode(NodeKind type, std::vector<std::size_t> shape, NodeId input)
83-
: Node(type, shape), input(input) {}
90+
UnaryNode(Graph &graph, NodeKind type, std::vector<std::size_t> shape,
91+
NodeId input)
92+
: Node(type, shape), input(input) {
93+
graph.getNode(input)->addSuccessor(reinterpret_cast<NodeId>(this));
94+
}
8495
};
8596

8697
class AddNode : public BinaryNode {
8798
public:
88-
AddNode(NodeId inputA, NodeId inputB, std::vector<std::size_t> shape)
89-
: BinaryNode(NodeKind::Add, shape, inputA, inputB) {}
99+
AddNode(Graph &graph, NodeId inputA, NodeId inputB,
100+
std::vector<std::size_t> shape)
101+
: BinaryNode(graph, NodeKind::Add, shape, inputA, inputB) {}
90102
static bool classof(const Node *node) {
91103
return node->getKind() == NodeKind::Add;
92104
}
93105
};
94106

95107
class SubNode : public BinaryNode {
96108
public:
97-
SubNode(NodeId inputA, NodeId inputB, std::vector<std::size_t> shape)
98-
: BinaryNode(NodeKind::Sub, shape, inputA, inputB) {}
109+
SubNode(Graph &graph, NodeId inputA, NodeId inputB,
110+
std::vector<std::size_t> shape)
111+
: BinaryNode(graph, NodeKind::Sub, shape, inputA, inputB) {}
99112
static bool classof(const Node *node) {
100113
return node->getKind() == NodeKind::Sub;
101114
}
102115
};
103116

104117
class MulNode : public BinaryNode {
105118
public:
106-
MulNode(NodeId inputA, NodeId inputB, std::vector<std::size_t> shape)
107-
: BinaryNode(NodeKind::Mul, shape, inputA, inputB) {}
119+
MulNode(Graph &graph, NodeId inputA, NodeId inputB,
120+
std::vector<std::size_t> shape)
121+
: BinaryNode(graph, NodeKind::Mul, shape, inputA, inputB) {}
108122
static bool classof(const Node *node) {
109123
return node->getKind() == NodeKind::Mul;
110124
}
111125
};
112126

113127
class DivNode : public BinaryNode {
114128
public:
115-
DivNode(NodeId inputA, NodeId inputB, std::vector<std::size_t> shape)
116-
: BinaryNode(NodeKind::Div, shape, inputA, inputB) {}
129+
DivNode(Graph &graph, NodeId inputA, NodeId inputB,
130+
std::vector<std::size_t> shape)
131+
: BinaryNode(graph, NodeKind::Div, shape, inputA, inputB) {}
117132
static bool classof(const Node *node) {
118133
return node->getKind() == NodeKind::Div;
119134
}
120135
};
121136

122137
class MatMulNode : public BinaryNode {
123138
public:
124-
MatMulNode(NodeId inputA, NodeId inputB, std::vector<std::size_t> shape)
125-
: BinaryNode(NodeKind::MatMul, shape, inputA, inputB) {}
139+
MatMulNode(Graph &graph, NodeId inputA, NodeId inputB,
140+
std::vector<std::size_t> shape)
141+
: BinaryNode(graph, NodeKind::MatMul, shape, inputA, inputB) {}
126142
static bool classof(const Node *node) {
127143
return node->getKind() == NodeKind::MatMul;
128144
}
129145
};
130146

131-
class TensorNode : public UnaryNode {
147+
class TensorNode : public Node {
132148
std::string name;
133149

134150
public:
135-
TensorNode(std::string name, std::vector<std::size_t> shape)
136-
: UnaryNode(NodeKind::Tensor, shape, std::numeric_limits<NodeId>::max()),
137-
name(std::move(name)) {}
151+
TensorNode(Graph &, std::string name, std::vector<std::size_t> shape)
152+
: Node(NodeKind::Tensor, shape), name(name) {}
138153
static bool classof(const Node *node) {
139154
return node->getKind() == NodeKind::Tensor;
140155
}
@@ -143,25 +158,25 @@ class TensorNode : public UnaryNode {
143158

144159
class ReLUNode : public UnaryNode {
145160
public:
146-
ReLUNode(NodeId input, std::vector<std::size_t> shape)
147-
: UnaryNode(NodeKind::ReLU, shape, input) {}
161+
ReLUNode(Graph &graph, NodeId input, std::vector<std::size_t> shape)
162+
: UnaryNode(graph, NodeKind::ReLU, shape, input) {}
148163
static bool classof(const Node *node) {
149164
return node->getKind() == NodeKind::ReLU;
150165
}
151166
};
152167
class TransposeNode : public UnaryNode {
153168
public:
154-
TransposeNode(NodeId input, std::vector<std::size_t> shape)
155-
: UnaryNode(NodeKind::Transpose, shape, input) {}
169+
TransposeNode(Graph &graph, NodeId input, std::vector<std::size_t> shape)
170+
: UnaryNode(graph, NodeKind::Transpose, shape, input) {}
156171
static bool classof(const Node *node) {
157172
return node->getKind() == NodeKind::Transpose;
158173
}
159174
};
160175
class ReshapeNode : public UnaryNode {
161176

162177
public:
163-
ReshapeNode(NodeId input, std::vector<std::size_t> shape)
164-
: UnaryNode(NodeKind::Reshape, shape, input) {}
178+
ReshapeNode(Graph &graph, NodeId input, std::vector<std::size_t> shape)
179+
: UnaryNode(graph, NodeKind::Reshape, shape, input) {}
165180
static bool classof(const Node *node) {
166181
return node->getKind() == NodeKind::Reshape;
167182
}
@@ -170,8 +185,9 @@ class SumNode : public UnaryNode {
170185
std::size_t dim;
171186

172187
public:
173-
SumNode(NodeId input, std::size_t dim, std::vector<std::size_t> shape)
174-
: UnaryNode(NodeKind::Sum, shape, input), dim(dim) {}
188+
SumNode(Graph &graph, NodeId input, std::size_t dim,
189+
std::vector<std::size_t> shape)
190+
: UnaryNode(graph, NodeKind::Sum, shape, input), dim(dim) {}
175191
static bool classof(const Node *node) {
176192
return node->getKind() == NodeKind::Sum;
177193
}
@@ -181,8 +197,9 @@ class SoftmaxNode : public UnaryNode {
181197
std::size_t dim;
182198

183199
public:
184-
SoftmaxNode(NodeId input, std::size_t dim, std::vector<std::size_t> shape)
185-
: UnaryNode(NodeKind::Softmax, shape, input), dim(dim) {}
200+
SoftmaxNode(Graph &graph, NodeId input, std::size_t dim,
201+
std::vector<std::size_t> shape)
202+
: UnaryNode(graph, NodeKind::Softmax, shape, input), dim(dim) {}
186203
static bool classof(const Node *node) {
187204
return node->getKind() == NodeKind::Softmax;
188205
}
@@ -195,33 +212,33 @@ ToType *cast(FromType *object) {
195212
return reinterpret_cast<ToType *>(object);
196213
}
197214

198-
void printNode(NodeId nodeId) {
199-
Node *node = reinterpret_cast<Node *>(nodeId);
215+
void printNodeBackwards(const Graph &graph, NodeId nodeId) {
216+
Node *node = graph.getNode(nodeId);
200217

201218
switch (node->getKind()) {
202219
case NodeKind::Add: {
203220
AddNode *addNode = cast<AddNode>(node);
204221
fmt::println("{} [label=\"Add {}\"];", nodeId, addNode->getShape());
205222
fmt::println("{} -> {};", addNode->getInputA(), nodeId);
206223
fmt::println("{} -> {};", addNode->getInputB(), nodeId);
207-
printNode(addNode->getInputA());
208-
printNode(addNode->getInputB());
224+
printNodeBackwards(graph, addNode->getInputA());
225+
printNodeBackwards(graph, addNode->getInputB());
209226
break;
210227
}
211228
case NodeKind::MatMul: {
212229
MatMulNode *matMulNode = cast<MatMulNode>(node);
213230
fmt::println("{} [label=\"MatMul {}\"];", nodeId, matMulNode->getShape());
214231
fmt::println("{} -> {};", matMulNode->getInputA(), nodeId);
215232
fmt::println("{} -> {};", matMulNode->getInputB(), nodeId);
216-
printNode(matMulNode->getInputA());
217-
printNode(matMulNode->getInputB());
233+
printNodeBackwards(graph, matMulNode->getInputA());
234+
printNodeBackwards(graph, matMulNode->getInputB());
218235
break;
219236
}
220237
case NodeKind::ReLU: {
221238
ReLUNode *reluNode = cast<ReLUNode>(node);
222239
fmt::println("{} [label=\"ReLU {}\"];", nodeId, reluNode->getShape());
223240
fmt::println("{} -> {};", reluNode->getInput(), nodeId);
224-
printNode(reluNode->getInput());
241+
printNodeBackwards(graph, reluNode->getInput());
225242
break;
226243
}
227244
case NodeKind::Tensor: {
@@ -234,7 +251,7 @@ void printNode(NodeId nodeId) {
234251
SoftmaxNode *softmaxNode = cast<SoftmaxNode>(node);
235252
fmt::println("{} [label=\"Softmax {}\"];", nodeId, softmaxNode->getShape());
236253
fmt::println("{} -> {};", softmaxNode->getInput(), nodeId);
237-
printNode(softmaxNode->getInput());
254+
printNodeBackwards(graph, softmaxNode->getInput());
238255
break;
239256
}
240257
default:
@@ -243,8 +260,8 @@ void printNode(NodeId nodeId) {
243260
}
244261
}
245262

246-
void printGraph(NodeId outputNodeId) {
263+
void printGraphBackwards(const Graph &graph, NodeId outputNodeId) {
247264
fmt::println("digraph {{");
248-
printNode(outputNodeId);
265+
printNodeBackwards(graph, outputNodeId);
249266
fmt::println("}}");
250267
}

include/cppdl/trace.h

+2
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,8 @@ class NeuralNetwork {
107107
inputTensors.push_back(id);
108108
return TraceTensor(shape, graph, id);
109109
}
110+
111+
const Graph &getGraph() { return graph; }
110112
};
111113

112114
class LinearLayer {

0 commit comments

Comments
 (0)