Skip to content

Commit

Permalink
implement
Browse files Browse the repository at this point in the history
fix lint

save

add test

save

fix lint

update comment

fix build for gpu

Update ir_pass.py
  • Loading branch information
MarisaKirisame committed Dec 31, 2018
1 parent 1e78d41 commit 63270a2
Show file tree
Hide file tree
Showing 5 changed files with 521 additions and 3 deletions.
31 changes: 30 additions & 1 deletion include/tvm/relay/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,6 @@ Expr ForwardRewrite(const Expr& expr,
std::function<NodeRef(const Call&)> fcontext = nullptr,
std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);


/*! \brief A hashing structure in the style of std::hash. */
struct StructuralHash {
/*! \brief Hash a Relay type.
Expand All @@ -212,6 +211,36 @@ struct StructuralHash {
size_t operator()(const Expr& expr) const;
};

/*! \brief turn a dataflow graph into A Normal Form.
*
* It will turn an expression that is in a graph form (with sharing implicit),
* to an expression with explicit sharing (A Normal Form).
*
* All subexpression will be lifted to the least common ancestor of all scope it is referenced in.
*
* If an expression is not referenced anywhere (it is the root expression) it will be lifted to the outmost scope.
*
* If there are multiple subexpression in the same scope, they are lifted by the postDFS order.
*
* \param e the expression to observably share
*
* \param mod The module used for referencing global functions, can be
* None.
*
* \return expression in A Normal Form
*/
Expr ToANF(const Expr& e, const Module& mod);

inline bool IsPrimitiveFunction(const Function& fn) {
NodeRef res = FunctionGetAttr(fn, "Primitive");
const ir::IntImm* pval = res.as<ir::IntImm>();
return pval && (pval->value != 0);
}

inline bool IsPrimitiveFunction(const Expr& e) {
return e.as<FunctionNode>() && IsPrimitiveFunction(Downcast<Function>(e));
}

} // namespace relay
} // namespace tvm

Expand Down
22 changes: 21 additions & 1 deletion python/tvm/relay/ir_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def post_order_visit(expr, fvisit):
----------
expr : tvm.relay.Expr
The input expression.
fvisit : function
The visitor function to be applied.
"""
Expand All @@ -35,7 +36,6 @@ def infer_type(expr, mod=None):
mod: Optional[tvm.relay.Module]
The global module.
Returns
-------
checked_expr : tvm.relay.Expr
Expand Down Expand Up @@ -357,3 +357,23 @@ def alter_op_layout(expr):
Transformed expression with alternated layout.
"""
return _ir_pass.AlterOpLayout(expr)


def to_anf(expr, mod=None):
"""
Turn Graph Normal Form expression into A Normal Form Expression
Parameters
----------
expr : tvm.relay.Expr
The input expression
mod: Optional[tvm.relay.Module]
The global module.
Returns
-------
expr: tvm.relay.Expr
The output expression
"""
return _ir_pass.to_anf(expr, mod)
6 changes: 5 additions & 1 deletion src/relay/pass/let_list.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class LetList {
* \return a Var that hold the inserted expr.
*/
Var Push(Var pv, Expr expr) {
CHECK(!used_);
lets_.emplace_back(std::make_pair(pv, expr));
return pv;
}
Expand Down Expand Up @@ -71,11 +72,13 @@ class LetList {
*
* \return the wrapped expr.
*/
Expr Get(const Expr& body) const {
Expr Get(const Expr& body) {
CHECK(!used_);
Expr ret = body;
for (auto rit = lets_.rbegin(); rit != lets_.rend(); ++rit) {
ret = LetNode::make(std::get<0>(*rit), std::get<1>(*rit), ret);
}
used_ = true;
return ret;
}

Expand Down Expand Up @@ -108,6 +111,7 @@ class LetList {

private:
std::vector<std::pair<Var, Expr> > lets_;
bool used_ = false;
};

} // namespace relay
Expand Down
Loading

0 comments on commit 63270a2

Please sign in to comment.