Skip to content

Commit

Permalink
add support for subgraphs. (apache#1221)
Browse files Browse the repository at this point in the history
* add support for subgraphs.

* fix.

* fix.

* Fix compilation error

* Fix compilation error

* add comments.

* update comments.

* Sanity check on subgraphs when creating IndexedGraph

* avoid the overhead of sanity check.

* Stop using non-recursive DFS

* Trigger CI

* trigger CI
  • Loading branch information
zheng-da authored and sergei-mironov committed Aug 8, 2018
1 parent 53d65da commit 95a27c0
Show file tree
Hide file tree
Showing 5 changed files with 219 additions and 74 deletions.
16 changes: 16 additions & 0 deletions nnvm/include/nnvm/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ namespace nnvm {

// Forward declare node.
class Node;
class Symbol;

/*!
* \brief we always used NodePtr for a reference pointer
Expand Down Expand Up @@ -90,6 +91,21 @@ struct NodeAttrs {
* The object can be used to quickly access attributes.
*/
any parsed;
/*!
* \brief Some operators take graphs as input. These operators include
* control flow operators and high-order functions.
* These graphs don't change when the operators are invoked for different
* mini-batches. In this sense, the subgraphs are kind of similar to
* the parameters and show be kept as node attributes.
*
* Users need to make sure the subgraphs are disjoint with the main graph.
* If a graph shares nodes with subgraphs, loading the graph from LoadJSON
* may generate a graph that has a different structure from the original graph
* (some of the nodes are duplicated). If nodes are shared between two graphs,
* shared nodes might be executed multiple times, which can be a problem for
* stateful operators.
*/
std::vector<std::shared_ptr<Symbol> > subgraphs;
};

/*!
Expand Down
12 changes: 12 additions & 0 deletions nnvm/include/nnvm/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,18 @@ using FCorrectLayout = std::function<bool(
const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts)>;

/*!
* \brief Get a list of inputs that represent graphs instead of data.
* Normally, input symbols are considered as data to the operator. However,
* control flow operators and high-order functions need to interpret symbols
* as graphs.
* \param attrs The attributes of this node.
* \return a list of input index that are interpreted as symbols by the operator.
*
* \note Register under "FInputGraph".
*/
using FInputGraph = std::function<std::vector<uint32_t>(const NodeAttrs& attrs)>;

} // namespace nnvm

#endif // NNVM_OP_ATTR_TYPES_H_
40 changes: 39 additions & 1 deletion nnvm/src/core/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,51 @@ const IndexedGraph& Graph::indexed_graph() const {
return *indexed_graph_;
}

// a subgraph should not refer to any nodes with higher level
// where "level" refers to the nested depth of the subgraph
// e.g. the main graph is level 0
// subgraphs of the main graph is level 1
// subgraphs of the subgraphs of the main graph is level 2
static void SubgraphSanityCheck(const std::vector<std::shared_ptr<Symbol>> &subgraphs) {
std::vector<const std::vector<nnvm::NodeEntry>*> curr_level;
std::vector<const std::vector<nnvm::NodeEntry>*> next_level;
std::unordered_map<nnvm::Node*, uint32_t> node2level;
for (auto &subgraph : subgraphs)
next_level.push_back(&subgraph->outputs);
for (uint32_t level = 0; !next_level.empty(); ++level) {
curr_level.swap(next_level);
next_level.clear();
for (const std::vector<NodeEntry> *graph_ptr : curr_level) {
const std::vector<NodeEntry> &graph = *graph_ptr;
DFSVisit(graph, [&next_level, &node2level, level](const NodePtr& n) {
nnvm::Node *node = n.get();
// if the node is visited, but on a different level, then check failed
// if check failed here or before, we stop doing anything, but raise an error
CHECK(!node2level.count(node) || node2level[node] == level)
<< "A subgraph should not depend on the outputs of nodes on higher levels";
// otherwise, this node belongs to the current level
node2level[node] = level;
// subgraphs of current node belongs to next level
for (const auto& subgraph : n->attrs.subgraphs) {
next_level.push_back(&subgraph->outputs);
}
});
}
}
}

// implement constructor from graph
IndexedGraph::IndexedGraph(const Graph &g) {
entry_rptr_.push_back(0);
std::vector<size_t> inputs_rptr{0}, control_rptr{0};
std::vector<std::shared_ptr<Symbol>> subgraphs;

DFSVisit(g.outputs, [this, &inputs_rptr, &control_rptr]
DFSVisit(g.outputs, [this, &inputs_rptr, &control_rptr, &subgraphs]
(const NodePtr& n) {
CHECK_LT(nodes_.size(), std::numeric_limits<uint32_t>::max());
uint32_t nid = static_cast<uint32_t>(nodes_.size());
for (const auto &subgraph : n->attrs.subgraphs)
subgraphs.push_back(subgraph);
// nodes_
IndexedGraph::Node new_node;
new_node.source = n.get();
Expand Down Expand Up @@ -53,6 +89,8 @@ IndexedGraph::IndexedGraph(const Graph &g) {
}
control_rptr.push_back(control_deps_.size());
});
if (!subgraphs.empty())
SubgraphSanityCheck(subgraphs);

for (const auto& e : g.outputs) {
outputs_.emplace_back(NodeEntry{
Expand Down
81 changes: 62 additions & 19 deletions nnvm/src/core/symbolic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -267,43 +267,86 @@ void Symbol::Compose(const array_view<const Symbol*>& args,
const std::string& name) {
static auto& flist_inputs = Op::GetAttr<FListInputNames>("FListInputNames");
static auto& fset_attrs = Op::GetAttr<FSetInputVarAttrOnCompose>("FSetInputVarAttrOnCompose");
static auto& fgraph = Op::GetAttr<FInputGraph>("FInputGraph");

// The arguments that contain graphs.
Node* n = outputs[0].node.get();
FInputGraph fng = fgraph.get(n->op(), nullptr);
std::vector<uint32_t> garg_idx;
if (fng != nullptr)
garg_idx = fng(n->attrs);

// The names of the arguments that contain graphs.
FListInputNames name_fn = flist_inputs.get(n->op(), nullptr);
auto arg_names = (name_fn == nullptr) ? std::vector<std::string>{"data"} : name_fn(n->attrs);
std::vector<std::string> garg_names(garg_idx.size());
for (size_t i = 0; i < garg_idx.size(); i++) {
size_t idx = garg_idx[i];
if (idx < arg_names.size())
garg_names[i] = arg_names[idx];
}

// parameter check.
for (size_t i = 0; i < args.size(); ++i) {
CHECK_EQ(args[i]->outputs.size(), 1U)
// If the argument isn't a graph, it should have only one output.
if (garg_idx.empty() || std::find(garg_idx.begin(), garg_idx.end(), i) == garg_idx.end())
CHECK_EQ(args[i]->outputs.size(), 1U)
<< "Argument " << i << " is a tuple, single value is required";
}
for (const auto& kv : kwargs) {
CHECK_EQ(kv.second->outputs.size(), 1U)
if (garg_names.empty()
|| std::find(garg_names.begin(), garg_names.end(), kv.first) == garg_names.end())
CHECK_EQ(kv.second->outputs.size(), 1U)
<< "Keyword Argument " << kv.first << " is a tuple, single value is required";
}
// assign new name
if (!name.empty()) outputs[0].node->attrs.name = name;

// Atomic functor composition.
if (IsAtomic(outputs)) {
Node* n = outputs[0].node.get();
uint32_t n_req = n->num_inputs();
std::vector<const Symbol *> arg_vec(args.begin(), args.end());
std::unordered_map<std::string, const Symbol*> kwarg_map(kwargs.begin(), kwargs.end());
// If one of the input arguments is a graph, we need to remove it from the
// list.
if (fng != nullptr) {
std::vector<uint32_t> idxes = fng(n->attrs);
for (auto idx : idxes) {
const Symbol *sym;
if (idx < arg_vec.size()) {
sym = arg_vec[idx];
arg_vec.erase(arg_vec.begin() + idx);
} else {
auto it = kwarg_map.find(arg_names[idx]);
CHECK(it != kwarg_map.end());
sym = it->second;
kwarg_map.erase(it);
}

if (n_req != kVarg)
n_req--;
arg_names.erase(arg_names.begin() + idx);
n->attrs.subgraphs.push_back(std::make_shared<Symbol>(*sym));
}
}

if (n_req != kVarg) {
n->inputs.resize(n_req);
CHECK_LE(args.size(), n_req)
CHECK_LE(arg_vec.size(), n_req)
<< "Incorrect number of arguments, requires " << n_req
<< ", provided " << args.size();
for (size_t i = 0; i < args.size(); ++i) {
n->inputs[i] = args[i]->outputs[0];
<< ", provided " << arg_vec.size();
for (size_t i = 0; i < arg_vec.size(); ++i) {
n->inputs[i] = arg_vec[i]->outputs[0];
}
// switch to keyword argument matching
if (args.size() != n_req) {
FListInputNames fn = flist_inputs.get(n->op(), nullptr);
auto arg_names = (fn == nullptr) ? std::vector<std::string>{"data"} : fn(n->attrs);
if (arg_vec.size() != n_req) {
if (arg_names.size() != n_req) {
LOG(FATAL) << "Not enough argument to call operator " << outputs[0].node->op()->name;
}
size_t nmatched = 0;
for (size_t i = args.size(); i < n_req; ++i) {
auto it = kwargs.find(arg_names[i]);
if (it != kwargs.end() && it->first == arg_names[i]) {
for (size_t i = arg_vec.size(); i < n_req; ++i) {
auto it = kwarg_map.find(arg_names[i]);
if (it != kwarg_map.end() && it->first == arg_names[i]) {
n->inputs[i] = it->second->outputs[0];
++nmatched;
} else {
Expand All @@ -314,18 +357,18 @@ void Symbol::Compose(const array_view<const Symbol*>& args,
}
}

if (nmatched != kwargs.size()) {
if (nmatched != kwarg_map.size()) {
n->inputs.clear();
std::vector<std::string> keys = GetKeys(kwargs);
array_view<std::string> view(dmlc::BeginPtr(arg_names) + args.size(),
std::vector<std::string> keys = GetKeys(kwarg_map);
array_view<std::string> view(dmlc::BeginPtr(arg_names) + arg_vec.size(),
dmlc::BeginPtr(arg_names) + arg_names.size());
KeywordArgumentMismatch("Symbol.Compose", keys, view);
}
}
} else {
CHECK_EQ(kwargs.size(), 0U) << "Variable length function do not accept kwargs";
n->inputs.reserve(args.size());
for (const Symbol* s : args) {
CHECK_EQ(kwarg_map.size(), 0U) << "Variable length function do not accept kwargs";
n->inputs.reserve(arg_vec.size());
for (const Symbol* s : arg_vec) {
n->inputs.push_back(s->outputs[0]);
}
}
Expand Down
Loading

0 comments on commit 95a27c0

Please sign in to comment.