diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 8fff7016a827b..246921da43b50 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -188,7 +188,6 @@ Expr ForwardRewrite(const Expr& expr, std::function fcontext = nullptr, std::function fmulti_ref_trigger = nullptr); - /*! \brief A hashing structure in the style of std::hash. */ struct StructuralHash { /*! \brief Hash a Relay type. @@ -212,6 +211,36 @@ struct StructuralHash { size_t operator()(const Expr& expr) const; }; +/*! \brief turn a dataflow graph into A Normal Form. + * + * It will turn an expression that is in a graph form (with sharing implicit), + * to an expression with explicit sharing (A Normal Form). + * + * All subexpression will be lifted to the least common ancestor of all scope it is referenced in. + * + * If an expression is not referenced anywhere (it is the root expression) it will be lifted to the outmost scope. + * + * If there are multiple subexpression in the same scope, they are lifted by the postDFS order. + * + * \param e the expression to observably share + * + * \param mod The module used for referencing global functions, can be + * None. + * + * \return expression in A Normal Form + */ +Expr ToANF(const Expr& e, const Module& mod); + +inline bool IsPrimitiveFunction(const Function& fn) { + NodeRef res = FunctionGetAttr(fn, "Primitive"); + const ir::IntImm* pval = res.as(); + return pval && (pval->value != 0); +} + +inline bool IsPrimitiveFunction(const Expr& e) { + return e.as() && IsPrimitiveFunction(Downcast(e)); +} + } // namespace relay } // namespace tvm diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index 53fa59cd053da..38fa26645e17f 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -19,6 +19,7 @@ def post_order_visit(expr, fvisit): ---------- expr : tvm.relay.Expr The input expression. + fvisit : function The visitor function to be applied. """ @@ -35,7 +36,6 @@ def infer_type(expr, mod=None): mod: Optional[tvm.relay.Module] The global module. - Returns ------- checked_expr : tvm.relay.Expr @@ -357,3 +357,23 @@ def alter_op_layout(expr): Transformed expression with alternated layout. """ return _ir_pass.AlterOpLayout(expr) + + +def to_anf(expr, mod=None): + """ + Turn Graph Normal Form expression into A Normal Form Expression + + Parameters + ---------- + expr : tvm.relay.Expr + The input expression + + mod: Optional[tvm.relay.Module] + The global module. + + Returns + ------- + expr: tvm.relay.Expr + The output expression + """ + return _ir_pass.to_anf(expr, mod) diff --git a/src/relay/pass/let_list.h b/src/relay/pass/let_list.h index 904ceab36c3d4..2fecc8ba3727e 100644 --- a/src/relay/pass/let_list.h +++ b/src/relay/pass/let_list.h @@ -36,6 +36,7 @@ class LetList { * \return a Var that hold the inserted expr. */ Var Push(Var pv, Expr expr) { + CHECK(!used_); lets_.emplace_back(std::make_pair(pv, expr)); return pv; } @@ -71,11 +72,13 @@ class LetList { * * \return the wrapped expr. */ - Expr Get(const Expr& body) const { + Expr Get(const Expr& body) { + CHECK(!used_); Expr ret = body; for (auto rit = lets_.rbegin(); rit != lets_.rend(); ++rit) { ret = LetNode::make(std::get<0>(*rit), std::get<1>(*rit), ret); } + used_ = true; return ret; } @@ -108,6 +111,7 @@ class LetList { private: std::vector > lets_; + bool used_ = false; }; } // namespace relay diff --git a/src/relay/pass/to_anf.cc b/src/relay/pass/to_anf.cc new file mode 100644 index 0000000000000..9f90bebbb41a4 --- /dev/null +++ b/src/relay/pass/to_anf.cc @@ -0,0 +1,433 @@ +/*! + * Copyright (c) 2018 by Contributors + * + * \file to_anf.cc + * + * \brief Turn implicit sharing into observable sharing. + */ +#include +#include +#include "let_list.h" + +namespace tvm { +namespace relay { + +/* The algorithm is very tricky. + * Maybe you'll think, 'cant I just lift everything to the lowest common ancestor?' + * But what is the lowest common ancestor? + * The problem is here because you also need to lift lambda. + * And lambda bind variable. + * So while you are lifting stuff, + * the scope that dictate how you should lift stuff is also being lifted too. + * + * The insight to the algorithm is that, we gotta think backward. + * Instead of the usual recursion, where parent scope are built from children scope, + * Children scope should be built from the parent scope. + * + * The scope of the whole expr is global. + * The scope of any subexpr, is the lowest common ancestor of all it's dependent, + * so they can all use it. + * + * In order to do this, we need to: + * 0: find dependency between expr. + * 1: toposort it. + * 2: for all new scope, create a LetList ll. + * 3: flush every value to the ll of the lowest common ancestor. + * 4: collect the values. + */ +Expr ToANF(const Expr& e, const Module& m, std::set* gv); + +struct ScopeNode; +using Scope = std::shared_ptr; +using OptionalScope = Scope; + +/* Invariant: when parent is null level is 0 + * + * Invariant: when parent is notempty level is 1 + parent->level + */ +struct ScopeNode { + size_t level; + Scope parent; + std::shared_ptr ll = std::make_shared(); + explicit ScopeNode(const Scope& parent) : level(1 + parent->level), parent(parent) { } + ScopeNode() : level(0) { } +}; + +Scope ChildScope(const Scope& s) { + return std::make_shared(s); +} + +OptionalScope LCA(OptionalScope lhs, OptionalScope rhs) { + if (!lhs) { + return rhs; + } else if (!rhs) { + return lhs; + } else { + while (lhs != rhs) { + if (lhs->level > rhs->level) { + lhs = lhs->parent; + } else if (lhs->level < rhs->level) { + rhs = rhs->parent; + } else { + lhs = lhs->parent; + rhs = rhs->parent; + } + } + return lhs; + } +} + +struct ExprPairHash { + size_t operator()(const std::pair& p) const { + return dmlc::HashCombine(NodeHash()(p.first), NodeHash()(p.second)); + } +}; + +using Graph = std::unordered_map, + NodeHash, + NodeEqual>; + +struct DependentGraph { + Graph dependency; + Graph dependent; + DependentGraph(const Graph& dependency, + const Graph& dependent) : dependency(dependency), + dependent(dependent) { } + DependentGraph() = default; +}; + +/* Invariant: after visited dependency of such expr will be filled. + * + * Invariant: dependency and dependent is inverse of each other. + */ +class ExprDep : ExprFunctor { + public: + static DependentGraph Dependency(const Expr& e) { + ExprDep ed; + ed(e); + return ed.dg_; + } + private: + DependentGraph dg_; + + void VisitExpr(const Expr& e) final { + if (dg_.dependency.count(e) == 0) { + dg_.dependency.insert({e, {}}); + ExprFunctor::VisitExpr(e); + } + } + + void Depend(const Expr& parent, const Expr& child) { + CHECK_NE(dg_.dependency.count(parent), 0); + dg_.dependency.at(parent).insert(child); + + if (dg_.dependent.count(child) == 0) { + dg_.dependent.insert({child, {}}); + } + dg_.dependent.at(child).insert(parent); + + VisitExpr(child); + } + + void VisitExpr_(const CallNode* c) final { + Expr e = GetRef(c); + Depend(e, c->op); + for (const auto& a : c->args) { + Depend(e, a); + } + } + + void VisitExpr_(const TupleNode* t) final { + Expr e = GetRef(t); + for (const auto& a : t->fields) { + Depend(e, a); + } + } + + void VisitExpr_(const TupleGetItemNode* t) final { + Expr e = GetRef(t); + Depend(e, t->tuple); + } + + void VisitExpr_(const IfNode* i) final { + Expr e = GetRef(i); + Depend(e, i->cond); + Depend(e, i->true_branch); + Depend(e, i->false_branch); + } + + void VisitExpr_(const FunctionNode* f) final { + Expr e = GetRef(f); + Depend(e, f->body); + for (const auto& p : f->params) { + Depend(e, p); + } + } + + void VisitExpr_(const LetNode* l) final { + Expr e = GetRef(l); + Depend(e, l->var); + Depend(e, l->value); + Depend(e, l->body); + } + + void VisitExpr_(const VarNode* v) final { } + + void VisitExpr_(const GlobalVarNode* v) final { } + + void VisitExpr_(const ConstantNode* c) final { } + + void VisitExpr_(const OpNode* o) final { } +}; + +struct ExprScopeMap { + std::unordered_map expr_to_scope; + std::unordered_map, NodeHash, NodeEqual> expr_scope_alloc; + std::unordered_map, Scope, ExprPairHash> dep_final; + + Scope global_scope = std::make_shared(); + + Scope get_scope(const Expr& e) const { + return expr_to_scope.count(e) == 0 ? global_scope : expr_to_scope.at(e); + } +}; + +class Topo { + public: + static std::vector Sort(const DependentGraph& dg) { + Topo t(dg); + for (const auto& p : dg.dependency) { + t(p.first); + } + return t.topo_; + } + + private: + std::vector topo_; + std::unordered_set visited_; // make sure each expr appear only once + const DependentGraph& dg_; + explicit Topo(const DependentGraph& dg) : dg_(dg) { } + void operator()(const Expr& e) { + if (visited_.count(e) == 0) { + visited_.insert(e); + if (dg_.dependent.count(e) != 0) { + for (const auto & p : dg_.dependent.at(e)) { + (*this)(p); + } + } + topo_.push_back(e); + } + } +}; + +class CalcScope : ExprFunctor { + public: + static ExprScopeMap Calculate(const DependentGraph& dg, const std::vector& topo) { + CalcScope cs(dg); + for (const Expr& expr : topo) { + cs(expr); + } + return cs.esm_; + } + + private: + void VisitExpr_(const FunctionNode* f) final { + Expr e = GetRef(f); + Scope s = ChildScope(esm_.expr_to_scope.at(e)); + esm_.expr_scope_alloc.insert({e, {s}}); + esm_.dep_final.insert({std::pair(e, f->body), s}); + } + + void VisitExpr_(const LetNode* l) final { + Expr e = GetRef(l); + Scope s = ChildScope(esm_.expr_to_scope.at(e)); + esm_.expr_scope_alloc.insert({e, {s}}); + esm_.dep_final.insert({std::pair(e, l->body), s}); + } + + void VisitExpr_(const IfNode* i) final { + Expr e = GetRef(i); + Scope t = ChildScope(esm_.expr_to_scope.at(e)); + Scope f = ChildScope(esm_.expr_to_scope.at(e)); + esm_.expr_scope_alloc.insert({e, {t, f}}); + esm_.dep_final.insert({std::pair(e, i->true_branch), t}); + esm_.dep_final.insert({std::pair(e, i->false_branch), f}); + } + + void VisitExprDefault_(const Node* e) final { } + + void VisitExpr(const Expr& e) final { + OptionalScope os; + if (dg_.dependent.count(e) != 0) { + for (const auto& p : dg_.dependent.at(e)) { + os = LCA(os, GetScope(p, e)); + } + } + esm_.expr_to_scope.insert({e, os ? os : esm_.global_scope}); + ExprFunctor::VisitExpr(e); + } + + OptionalScope GetScope(const Expr& parent, const Expr& child) { + std::pair ep(parent, child); + if (esm_.dep_final.count(ep) != 0) { + return esm_.dep_final.at(ep); + } + CHECK_NE(esm_.expr_to_scope.count(parent), 0); + return esm_.expr_to_scope.at(parent); + } + + const DependentGraph& dg_; + + CalcScope(const DependentGraph& dg) : dg_(dg) { } + + ExprScopeMap esm_; +}; + +class Fill : ExprFunctor { + public: + static Expr ToANF(const ExprScopeMap& esm, + const Module& m, + std::set* gv, + const Expr& e) { + Fill fi(esm, m, gv); + return esm.global_scope->ll->Get(fi.VisitExpr(e)); + } + + private: + const ExprScopeMap& esm_; + Module mod_; + std::set* visited_; + + Fill(const ExprScopeMap& esm, + Module mod, + std::set* visited) : + esm_(esm), + mod_(mod), + visited_(visited) { } + + std::unordered_map memo; + + Expr VisitExpr(const Expr& e, const Var& v) final { + if (memo.count(e) == 0) { + memo.insert({e, ExprFunctor::VisitExpr(e, v)}); + } + return memo.at(e); + } + + Expr VisitExpr(const Expr& e) { + Var v = VarNode::make(std::string("x"), IncompleteTypeNode::make(TypeVarNode::kType)); + return this->VisitExpr(e, v); + } + + Expr Compound(const Expr& orig, const Expr& now, const Var& v) { + return esm_.get_scope(orig)->ll->Push(v, now); + } + + Expr VisitExpr_(const CallNode* c, const Var& v) final { + Expr e = GetRef(c); + std::vector args; + for (const auto& a : c->args) { + args.push_back(VisitExpr(a)); + } + return Compound(e, CallNode::make(VisitExpr(c->op), args, c->attrs, c->type_args), v); + } + + Expr VisitExpr_(const TupleNode* t, const Var& v) final { + Expr e = GetRef(t); + std::vector fields; + for (const auto& a : t->fields) { + fields.push_back(VisitExpr(a)); + } + return Compound(e, TupleNode::make(fields), v); + } + + Expr VisitExpr_(const TupleGetItemNode* t, const Var& v) final { + Expr e = GetRef(t); + return Compound(e, TupleGetItemNode::make(VisitExpr(t->tuple), t->index), v); + } + + Expr VisitExpr_(const IfNode* i, const Var& v) final { + Expr e = GetRef(i); + Expr ret = IfNode::make(VisitExpr(i->cond), + esm_.expr_scope_alloc.at(e)[0]->ll->Get(VisitExpr(i->true_branch)), + esm_.expr_scope_alloc.at(e)[1]->ll->Get(VisitExpr(i->false_branch))); + return Compound(e, ret, v); + } + + Expr VisitExpr_(const FunctionNode* f, const Var& v) final { + Expr e = GetRef(f); + Expr ret; + if (IsPrimitiveFunction(e)) { + ret = e; + } else { + ret = FunctionNode::make(f->params, + esm_.expr_scope_alloc.at(e)[0]->ll->Get(VisitExpr(f->body)), + f->ret_type, + f->type_params, + f->attrs); + } + return Compound(e, ret, v); + } + + Expr VisitExpr_(const LetNode* l, const Var& v) final { + Expr e = GetRef(l); + VisitExpr(l->value, l->var); + Expr ret = esm_.expr_scope_alloc.at(e)[0]->ll->Get(VisitExpr(l->body)); + return Compound(e, ret, v); + } + + Expr VisitExpr_(const ConstantNode* c, const Var& v) final { + Expr e = GetRef(c); + return Compound(e, e, v); + } + + Expr VisitExpr_(const VarNode* vn, const Var& v) final { + return GetRef(vn); + } + + Expr VisitExpr_(const GlobalVarNode* gvn, const Var& v) final { + GlobalVar gv = GetRef(gvn); + if (visited_->count(gv) == 0) { + visited_->insert(gv); + mod_->Update(gv, Downcast(relay::ToANF(mod_->Lookup(gv), mod_, visited_))); + } + return gv; + } + + Expr VisitExpr_(const OpNode* op, const Var& v) final { + return GetRef(op); + } +}; + +Expr ToANFAux(const Expr& e, const Module& m, std::set* gv) { + DependentGraph dg = ExprDep::Dependency(e); + std::vector topo = Topo::Sort(dg); + ExprScopeMap esm = CalcScope::Calculate(dg, topo); + return Fill::ToANF(esm, m, gv, e); +} + +Expr ToANF(const Expr& e, const Module& m, std::set* gv) { + if (auto f = e.as()) { + return FunctionNode::make(f->params, + ToANFAux(f->body, m, gv), + f->ret_type, + f->type_params, + f->attrs); + } else { + return ToANFAux(e, m, gv); + } +} + +Expr ToANF(const Expr& e, const Module& m) { + std::set gv; + return ToANF(e, m, &gv); +} + +TVM_REGISTER_API("relay._ir_pass.to_anf") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = ToANF(args[0], args[1]); + }); + +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_to_anf.py b/tests/python/relay/test_to_anf.py new file mode 100644 index 0000000000000..49a112064c75e --- /dev/null +++ b/tests/python/relay/test_to_anf.py @@ -0,0 +1,32 @@ +import numpy as np +import tvm +from tvm import relay +from tvm.relay.ir_pass import to_anf, alpha_equal +from tvm.relay import op, create_executor +from tvm.relay.backend.interpreter import Value, TupleValue + + +def check_eval(expr, expected_result, mod=None, rtol=1e-07): + if mod is None: + mod = relay.Module() + + ctx = tvm.context("llvm", 0) + intrp = create_executor(mod=mod, ctx=ctx, target="llvm") + + result = intrp.evaluate(expr) + np.testing.assert_allclose(result.asnumpy(), expected_result, rtol=rtol) + + +def test_to_anf(): + x = relay.const(1) + y = op.add(x, x) + z = op.add(y, y) + f = relay.Function([], op.add(z, z)) + assert not "let" in f.astext() + taf = to_anf(f) + assert "let" in taf.astext() + check_eval(f(), 8.0) + check_eval(taf(), 8.0) + +if __name__ == '__main__': + test_to_anf()