@@ -24,11 +24,13 @@ class Graph {
24
24
NodeId addNode (Args &&...args) {
25
25
static_assert (std::is_base_of<Node, GraphNodeType>::value,
26
26
" GraphNodeType must be a derivative of Node" );
27
- new (nextMemory) GraphNodeType (std::forward<Args>(args)...);
27
+ new (nextMemory) GraphNodeType (* this , std::forward<Args>(args)...);
28
28
NodeId id = reinterpret_cast <NodeId>(nextMemory);
29
29
nextMemory += sizeof (GraphNodeType);
30
30
return id;
31
31
}
32
+
33
+ Node *getNode (NodeId id) const { return reinterpret_cast <Node *>(id); }
32
34
};
33
35
34
36
enum class NodeKind {
@@ -48,6 +50,7 @@ enum class NodeKind {
48
50
class Node {
49
51
NodeKind nodeKind;
50
52
std::vector<std::size_t > shape;
53
+ std::vector<NodeId> successors;
51
54
52
55
protected:
53
56
Node (NodeKind type, std::vector<std::size_t > shape)
@@ -56,6 +59,8 @@ class Node {
56
59
public:
57
60
NodeKind getKind () const { return nodeKind; }
58
61
std::vector<std::size_t > getShape () const { return shape; }
62
+ void addSuccessor (NodeId successor) { successors.push_back (successor); }
63
+ const std::vector<NodeId> &getSuccessors () const { return successors; }
59
64
};
60
65
61
66
class BinaryNode : public Node {
@@ -67,9 +72,12 @@ class BinaryNode : public Node {
67
72
NodeId getInputB () const { return inputB; }
68
73
69
74
protected:
70
- BinaryNode (NodeKind type, std::vector<std::size_t > shape, NodeId inputA,
71
- NodeId inputB)
72
- : Node(type, shape), inputA(inputA), inputB(inputB) {}
75
+ BinaryNode (Graph &graph, NodeKind type, std::vector<std::size_t > shape,
76
+ NodeId inputA, NodeId inputB)
77
+ : Node(type, shape), inputA(inputA), inputB(inputB) {
78
+ graph.getNode (inputA)->addSuccessor (reinterpret_cast <NodeId>(this ));
79
+ graph.getNode (inputB)->addSuccessor (reinterpret_cast <NodeId>(this ));
80
+ }
73
81
};
74
82
75
83
class UnaryNode : public Node {
@@ -79,62 +87,69 @@ class UnaryNode : public Node {
79
87
NodeId getInput () const { return input; }
80
88
81
89
protected:
82
- UnaryNode (NodeKind type, std::vector<std::size_t > shape, NodeId input)
83
- : Node(type, shape), input(input) {}
90
+ UnaryNode (Graph &graph, NodeKind type, std::vector<std::size_t > shape,
91
+ NodeId input)
92
+ : Node(type, shape), input(input) {
93
+ graph.getNode (input)->addSuccessor (reinterpret_cast <NodeId>(this ));
94
+ }
84
95
};
85
96
86
97
class AddNode : public BinaryNode {
87
98
public:
88
- AddNode (NodeId inputA, NodeId inputB, std::vector<std::size_t > shape)
89
- : BinaryNode(NodeKind::Add, shape, inputA, inputB) {}
99
+ AddNode (Graph &graph, NodeId inputA, NodeId inputB,
100
+ std::vector<std::size_t > shape)
101
+ : BinaryNode(graph, NodeKind::Add, shape, inputA, inputB) {}
90
102
static bool classof (const Node *node) {
91
103
return node->getKind () == NodeKind::Add;
92
104
}
93
105
};
94
106
95
107
class SubNode : public BinaryNode {
96
108
public:
97
- SubNode (NodeId inputA, NodeId inputB, std::vector<std::size_t > shape)
98
- : BinaryNode(NodeKind::Sub, shape, inputA, inputB) {}
109
+ SubNode (Graph &graph, NodeId inputA, NodeId inputB,
110
+ std::vector<std::size_t > shape)
111
+ : BinaryNode(graph, NodeKind::Sub, shape, inputA, inputB) {}
99
112
static bool classof (const Node *node) {
100
113
return node->getKind () == NodeKind::Sub;
101
114
}
102
115
};
103
116
104
117
class MulNode : public BinaryNode {
105
118
public:
106
- MulNode (NodeId inputA, NodeId inputB, std::vector<std::size_t > shape)
107
- : BinaryNode(NodeKind::Mul, shape, inputA, inputB) {}
119
+ MulNode (Graph &graph, NodeId inputA, NodeId inputB,
120
+ std::vector<std::size_t > shape)
121
+ : BinaryNode(graph, NodeKind::Mul, shape, inputA, inputB) {}
108
122
static bool classof (const Node *node) {
109
123
return node->getKind () == NodeKind::Mul;
110
124
}
111
125
};
112
126
113
127
class DivNode : public BinaryNode {
114
128
public:
115
- DivNode (NodeId inputA, NodeId inputB, std::vector<std::size_t > shape)
116
- : BinaryNode(NodeKind::Div, shape, inputA, inputB) {}
129
+ DivNode (Graph &graph, NodeId inputA, NodeId inputB,
130
+ std::vector<std::size_t > shape)
131
+ : BinaryNode(graph, NodeKind::Div, shape, inputA, inputB) {}
117
132
static bool classof (const Node *node) {
118
133
return node->getKind () == NodeKind::Div;
119
134
}
120
135
};
121
136
122
137
class MatMulNode : public BinaryNode {
123
138
public:
124
- MatMulNode (NodeId inputA, NodeId inputB, std::vector<std::size_t > shape)
125
- : BinaryNode(NodeKind::MatMul, shape, inputA, inputB) {}
139
+ MatMulNode (Graph &graph, NodeId inputA, NodeId inputB,
140
+ std::vector<std::size_t > shape)
141
+ : BinaryNode(graph, NodeKind::MatMul, shape, inputA, inputB) {}
126
142
static bool classof (const Node *node) {
127
143
return node->getKind () == NodeKind::MatMul;
128
144
}
129
145
};
130
146
131
- class TensorNode : public UnaryNode {
147
+ class TensorNode : public Node {
132
148
std::string name;
133
149
134
150
public:
135
- TensorNode (std::string name, std::vector<std::size_t > shape)
136
- : UnaryNode(NodeKind::Tensor, shape, std::numeric_limits<NodeId>::max()),
137
- name (std::move(name)) {}
151
+ TensorNode (Graph &, std::string name, std::vector<std::size_t > shape)
152
+ : Node(NodeKind::Tensor, shape), name(name) {}
138
153
static bool classof (const Node *node) {
139
154
return node->getKind () == NodeKind::Tensor;
140
155
}
@@ -143,25 +158,25 @@ class TensorNode : public UnaryNode {
143
158
144
159
class ReLUNode : public UnaryNode {
145
160
public:
146
- ReLUNode (NodeId input, std::vector<std::size_t > shape)
147
- : UnaryNode(NodeKind::ReLU, shape, input) {}
161
+ ReLUNode (Graph &graph, NodeId input, std::vector<std::size_t > shape)
162
+ : UnaryNode(graph, NodeKind::ReLU, shape, input) {}
148
163
static bool classof (const Node *node) {
149
164
return node->getKind () == NodeKind::ReLU;
150
165
}
151
166
};
152
167
class TransposeNode : public UnaryNode {
153
168
public:
154
- TransposeNode (NodeId input, std::vector<std::size_t > shape)
155
- : UnaryNode(NodeKind::Transpose, shape, input) {}
169
+ TransposeNode (Graph &graph, NodeId input, std::vector<std::size_t > shape)
170
+ : UnaryNode(graph, NodeKind::Transpose, shape, input) {}
156
171
static bool classof (const Node *node) {
157
172
return node->getKind () == NodeKind::Transpose;
158
173
}
159
174
};
160
175
class ReshapeNode : public UnaryNode {
161
176
162
177
public:
163
- ReshapeNode (NodeId input, std::vector<std::size_t > shape)
164
- : UnaryNode(NodeKind::Reshape, shape, input) {}
178
+ ReshapeNode (Graph &graph, NodeId input, std::vector<std::size_t > shape)
179
+ : UnaryNode(graph, NodeKind::Reshape, shape, input) {}
165
180
static bool classof (const Node *node) {
166
181
return node->getKind () == NodeKind::Reshape;
167
182
}
@@ -170,8 +185,9 @@ class SumNode : public UnaryNode {
170
185
std::size_t dim;
171
186
172
187
public:
173
- SumNode (NodeId input, std::size_t dim, std::vector<std::size_t > shape)
174
- : UnaryNode(NodeKind::Sum, shape, input), dim(dim) {}
188
+ SumNode (Graph &graph, NodeId input, std::size_t dim,
189
+ std::vector<std::size_t > shape)
190
+ : UnaryNode(graph, NodeKind::Sum, shape, input), dim(dim) {}
175
191
static bool classof (const Node *node) {
176
192
return node->getKind () == NodeKind::Sum;
177
193
}
@@ -181,8 +197,9 @@ class SoftmaxNode : public UnaryNode {
181
197
std::size_t dim;
182
198
183
199
public:
184
- SoftmaxNode (NodeId input, std::size_t dim, std::vector<std::size_t > shape)
185
- : UnaryNode(NodeKind::Softmax, shape, input), dim(dim) {}
200
+ SoftmaxNode (Graph &graph, NodeId input, std::size_t dim,
201
+ std::vector<std::size_t > shape)
202
+ : UnaryNode(graph, NodeKind::Softmax, shape, input), dim(dim) {}
186
203
static bool classof (const Node *node) {
187
204
return node->getKind () == NodeKind::Softmax;
188
205
}
@@ -195,33 +212,33 @@ ToType *cast(FromType *object) {
195
212
return reinterpret_cast <ToType *>(object);
196
213
}
197
214
198
- void printNode ( NodeId nodeId) {
199
- Node *node = reinterpret_cast <Node *> (nodeId);
215
+ void printNodeBackwards ( const Graph &graph, NodeId nodeId) {
216
+ Node *node = graph. getNode (nodeId);
200
217
201
218
switch (node->getKind ()) {
202
219
case NodeKind::Add: {
203
220
AddNode *addNode = cast<AddNode>(node);
204
221
fmt::println (" {} [label=\" Add {}\" ];" , nodeId, addNode->getShape ());
205
222
fmt::println (" {} -> {};" , addNode->getInputA (), nodeId);
206
223
fmt::println (" {} -> {};" , addNode->getInputB (), nodeId);
207
- printNode ( addNode->getInputA ());
208
- printNode ( addNode->getInputB ());
224
+ printNodeBackwards (graph, addNode->getInputA ());
225
+ printNodeBackwards (graph, addNode->getInputB ());
209
226
break ;
210
227
}
211
228
case NodeKind::MatMul: {
212
229
MatMulNode *matMulNode = cast<MatMulNode>(node);
213
230
fmt::println (" {} [label=\" MatMul {}\" ];" , nodeId, matMulNode->getShape ());
214
231
fmt::println (" {} -> {};" , matMulNode->getInputA (), nodeId);
215
232
fmt::println (" {} -> {};" , matMulNode->getInputB (), nodeId);
216
- printNode ( matMulNode->getInputA ());
217
- printNode ( matMulNode->getInputB ());
233
+ printNodeBackwards (graph, matMulNode->getInputA ());
234
+ printNodeBackwards (graph, matMulNode->getInputB ());
218
235
break ;
219
236
}
220
237
case NodeKind::ReLU: {
221
238
ReLUNode *reluNode = cast<ReLUNode>(node);
222
239
fmt::println (" {} [label=\" ReLU {}\" ];" , nodeId, reluNode->getShape ());
223
240
fmt::println (" {} -> {};" , reluNode->getInput (), nodeId);
224
- printNode ( reluNode->getInput ());
241
+ printNodeBackwards (graph, reluNode->getInput ());
225
242
break ;
226
243
}
227
244
case NodeKind::Tensor: {
@@ -234,7 +251,7 @@ void printNode(NodeId nodeId) {
234
251
SoftmaxNode *softmaxNode = cast<SoftmaxNode>(node);
235
252
fmt::println (" {} [label=\" Softmax {}\" ];" , nodeId, softmaxNode->getShape ());
236
253
fmt::println (" {} -> {};" , softmaxNode->getInput (), nodeId);
237
- printNode ( softmaxNode->getInput ());
254
+ printNodeBackwards (graph, softmaxNode->getInput ());
238
255
break ;
239
256
}
240
257
default :
@@ -243,8 +260,8 @@ void printNode(NodeId nodeId) {
243
260
}
244
261
}
245
262
246
- void printGraph ( NodeId outputNodeId) {
263
+ void printGraphBackwards ( const Graph &graph, NodeId outputNodeId) {
247
264
fmt::println (" digraph {{" );
248
- printNode ( outputNodeId);
265
+ printNodeBackwards (graph, outputNodeId);
249
266
fmt::println (" }}" );
250
267
}
0 commit comments