Skip to content

Commit

Permalink
save
Browse files Browse the repository at this point in the history
  • Loading branch information
MarisaKirisame committed Jan 14, 2019
1 parent 028f49a commit df176f2
Showing 1 changed file with 13 additions and 18 deletions.
31 changes: 13 additions & 18 deletions src/relay/pass/to_anf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ class DependencyGraph {
};
/*! \brief The node map that maps node to graph */
std::unordered_map<Expr, Node*, NodeHash, NodeEqual> node_map;
std::unordered_map<Node*, Expr> node_unmap;

/*! \brief All the nodes in post DFS order */
std::vector<Node*> post_dfs_order;
/*!
* \brief create a dependency graph.
* \param arena The arena used for data allocation.
Expand Down Expand Up @@ -58,21 +60,14 @@ class DependencyGraph::Creator : private ExprVisitor {
DependencyGraph graph_;
// Update the message stored at the node.
void Depend(const Expr& parent, const Expr& child) {
if (graph_.node_map.count(parent) == 0) {
graph_.node_map[parent] = arena_->make<DependencyGraph::Node>();
graph_.node_unmap[graph_.node_map[parent]] = parent;
}
VisitExpr(child);

if (graph_.node_map.count(child) == 0) {
graph_.node_map[child] = arena_->make<DependencyGraph::Node>();
graph_.node_unmap[graph_.node_map[child]] = child;
}
CHECK_NE(graph_.node_map.count(parent), 0);
CHECK_NE(graph_.node_map.count(child), 0);

auto* parent_link = arena_->make<LinkNode<DependencyGraph::Node*> >();
parent_link->value = graph_.node_map[parent];
graph_.node_map[child]->output.Push(parent_link);

VisitExpr(child);
}

std::unordered_set<Expr, NodeHash, NodeEqual> visited_;
Expand All @@ -81,7 +76,7 @@ class DependencyGraph::Creator : private ExprVisitor {
if (visited_.count(e) == 0) {
if (graph_.node_map.count(e) == 0) {
graph_.node_map[e] = arena_->make<DependencyGraph::Node>();
graph_.node_unmap[graph_.node_map[e]] = e;
graph_.post_dfs_order.push_back(graph_.node_map[e]);
}
visited_.insert(e);
ExprFunctor<void(const Expr&)>::VisitExpr(e);
Expand Down Expand Up @@ -178,17 +173,17 @@ Scope LCA(Scope lhs, Scope rhs) {
return lhs;
}

using Edge = std::pair<Expr, Expr>;
using Edge = std::pair<DependencyGraph::Node*, DependencyGraph::Node*>;

struct EdgeHash {
size_t operator()(const Edge& p) const {
return dmlc::HashCombine(NodeHash()(p.first), NodeHash()(p.second));
return dmlc::HashCombine(std::hash<DependencyGraph::Node*>()(p.first), std::hash<DependencyGraph::Node*>()(p.second));
}
};

class ExprScopeMap {
public:
Scope GetScope(const Expr& e) const {
Scope GetScope(const DependencyGraph::Node* e) const {
CHECK_NE(expr_scope_.count(e), 0);
return expr_scope_.at(e);
}
Expand All @@ -197,7 +192,7 @@ class ExprScopeMap {
return edge_scope_.count(edge) != 0 ? edge_scope_.at(edge) : GetScope(edge.first);
}

Scope GetScope(const Expr& e, const LinkedList<DependencyGraph::Node*>& es) const {
Scope GetScope(const DependencyGraph::Node* e, const LinkedList<DependencyGraph::Node*>& es) const {
auto it = es.head;
if (it == nullptr) {
return global_scope_;
Expand Down Expand Up @@ -239,11 +234,11 @@ class ExprScopeMap {

private:
// Scope of an expression.
std::unordered_map<Expr, Scope, NodeHash, NodeEqual> expr_scope_;
std::unordered_map<Node*, Scope> expr_scope_;
// subscopes of an expression.
// For example, conditional create two subscopes, one for each case.
// The conditional use the original scope.
std::unordered_map<Expr, std::vector<Scope>, NodeHash, NodeEqual> expr_subscope_;
std::unordered_map<Node*, std::vector<Scope>> expr_subscope_;
// Scope of an edge.
// Note that it might not be stored here if it is the same as the parent's scope.
std::unordered_map<Edge, Scope, EdgeHash> edge_scope_;
Expand Down

0 comments on commit df176f2

Please sign in to comment.