Skip to content

Commit

Permalink
implement
Browse files Browse the repository at this point in the history
fix lint

save

add test

save

fix lint

update comment

fix build for gpu

Update ir_pass.py

save

fix error

fix lint

add test

fix lint

fix test

fix test

reboot pytest

Update to_anf.cc

address review comment

save

fused topo

remove dead code

save

save

do it
  • Loading branch information
MarisaKirisame committed Jan 11, 2019
1 parent 547a091 commit 464fe86
Show file tree
Hide file tree
Showing 10 changed files with 684 additions and 17 deletions.
34 changes: 32 additions & 2 deletions include/tvm/relay/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,11 @@ Expr FoldConstant(const Expr& expr);
/*!
* \brief Fuse operations into expr into seperate functions.
* \param expr The expression.
* \param mod The global module.
* \param fuse_opt_level Optimization level.
* \return The optimized expression.
*/
Expr FuseOps(const Expr& expr, int fuse_opt_level);
Expr FuseOps(const Expr& expr, const Module& mod, int fuse_opt_level);

/*!
* \brief Apply rewrite rules to rewrite the expr in post DFS order.
Expand Down Expand Up @@ -188,7 +189,6 @@ Expr ForwardRewrite(const Expr& expr,
std::function<NodeRef(const Call&)> fcontext = nullptr,
std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);


/*! \brief A hashing structure in the style of std::hash. */
struct StructuralHash {
/*! \brief Hash a Relay type.
Expand All @@ -212,6 +212,36 @@ struct StructuralHash {
size_t operator()(const Expr& expr) const;
};

/*! \brief turn a dataflow graph into A Normal Form.
*
* It will turn an expression that is in a graph form (with sharing implicit),
* to an expression with explicit sharing (A Normal Form).
*
* The scope of the root expression is the global scope.
* The scope of any non root expression is the least common ancestor of all it's scope.
*
* Values are ordered by post-DFS order in each scope.
*
* \param e the expression to observably share
*
* \param mod The module used for referencing global functions, can be
* None.
*
* \return expression in A Normal Form
*/
Expr ToANF(const Expr& e, const Module& mod);

inline bool IsPrimitiveFunction(const Function& fn) {
NodeRef res = FunctionGetAttr(fn, "Primitive");
const ir::IntImm* pval = res.as<ir::IntImm>();
return pval && (pval->value != 0);
}

inline bool IsPrimitiveFunction(const Expr& e) {
return e.as<FunctionNode>() && IsPrimitiveFunction(Downcast<Function>(e));
}

} // namespace relay
} // namespace tvm

Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/backend/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def optimize(self, expr):
"""
# TODO: We need to move this optimization code into the optimizer/pass manager
ck_expr = ir_pass.infer_type(expr, mod=self.mod)
fused_expr = ir_pass.fuse_ops(ck_expr)
fused_expr = ir_pass.fuse_ops(ck_expr, mod=self.mod)
ck_fused = ir_pass.infer_type(fused_expr, mod=self.mod)
return ck_fused

Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def build(func,
func = optimize(func, target, params)
# Fuse ops before running code gen
func = ir_pass.infer_type(func)
func = ir_pass.fuse_ops(func, cfg.opt_level)
func = ir_pass.fuse_ops(func, mod=None, opt_level=cfg.opt_level)
# Graph code generation
func = ir_pass.infer_type(func)
graph_gen = _graph_gen.GraphRuntimeCodegen(mod=None, target=target)
Expand Down
41 changes: 35 additions & 6 deletions python/tvm/relay/ir_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def post_order_visit(expr, fvisit):
----------
expr : tvm.relay.Expr
The input expression.
fvisit : function
The visitor function to be applied.
"""
Expand All @@ -35,7 +36,6 @@ def infer_type(expr, mod=None):
mod: Optional[tvm.relay.Module]
The global module.
Returns
-------
checked_expr : tvm.relay.Expr
Expand Down Expand Up @@ -112,11 +112,11 @@ def check_kind(t, mod=None):
Parameters
----------
t: tvm.relay.Type
t : tvm.relay.Type
The type to check
mod: tvm.relay.Module, optional
The global module
mod : Optional[tvm.relay.Module]
The global module.
Returns
-------
Expand Down Expand Up @@ -305,14 +305,17 @@ def fold_constant(expr):
return _ir_pass.FoldConstant(expr)


def fuse_ops(expr, opt_level=1):
def fuse_ops(expr, mod=None, opt_level=1):
"""Fuse operators in expr together.
Parameters
----------
expr : tvm.relay.Expr
The input expression.
mod : Optional[tvm.relay.Module]
The global module.
opt_level : int
The level of fuse optimization.
Expand All @@ -321,7 +324,7 @@ def fuse_ops(expr, opt_level=1):
transformed_expr : tvm.relay.Expr
Transformed expression, containing fused result.
"""
return _ir_pass.FuseOps(expr, opt_level)
return _ir_pass.FuseOps(expr, mod, opt_level)


def combine_parallel_conv2d(expr):
Expand Down Expand Up @@ -357,3 +360,29 @@ def alter_op_layout(expr):
Transformed expression with alternated layout.
"""
return _ir_pass.AlterOpLayout(expr)


def to_anf(expr, mod=None):
"""
Turn Graph Normal Form expression into A Normal Form Expression.
The scope of the root expression is the global scope.
The scope of any non root expression is the least common ancestor of all it's scope.
Values are ordered by post-DFS order in each scope.
Parameters
----------
expr : tvm.relay.Expr
The input expression
mod: Optional[tvm.relay.Module]
The global module.
Returns
-------
expr: tvm.relay.Expr
The output expression
"""
return _ir_pass.to_anf(expr, mod)
2 changes: 1 addition & 1 deletion src/relay/pass/fold_constant.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ class ConstantFolder : public ExprMutator {
// Constant evaluate a expression.
Expr ConstEvaluate(Expr expr) {
expr = InferType(expr, Module(nullptr));
expr = FuseOps(expr, 0);
expr = FuseOps(expr, Module(nullptr), 0);
expr = InferType(expr, Module(nullptr));
return ValueToExpr(executor_(expr));
}
Expand Down
28 changes: 25 additions & 3 deletions src/relay/pass/fuse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,8 @@ GraphPartitioner::Partition(const IndexedForwardGraph& graph) {
return std::move(groups_);
}

Expr FuseOps(const Expr& expr, const Module& m, int fuse_opt_level, std::set<GlobalVar>* gv);

class FuseMutator : private ExprMutator {
public:
// Run the transform
Expand All @@ -667,8 +669,13 @@ class FuseMutator : private ExprMutator {
return this->Mutate(body);
}

FuseMutator(const Module& mod, int fuse_opt_level, std::set<GlobalVar>* visited) :
mod_(mod), fuse_opt_level_(fuse_opt_level), visited_(visited) { }

private:
Module mod_;
int fuse_opt_level_;
std::set<GlobalVar>* visited_;
/*! \brief Temporary information from each group. */
struct GroupInfo {
public:
Expand Down Expand Up @@ -751,6 +758,16 @@ class FuseMutator : private ExprMutator {
return new_tuple;
}

Expr VisitExpr_(const GlobalVarNode* node) {
GlobalVar gv = GetRef<GlobalVar>(node);
if (visited_->count(gv) == 0) {
visited_->insert(gv);
mod_->Update(gv,
Downcast<Function>(FuseOps(mod_->Lookup(gv), mod_, fuse_opt_level_, visited_)));
}
return gv;
}

Expr MakeNewFunction(GraphPartitioner::Group* group, Type ret_type, Expr body) {
const GroupInfo& ginfo = ginfo_[group];
auto func = FunctionNode::make(ginfo.params, body, ret_type, {});
Expand Down Expand Up @@ -790,17 +807,22 @@ class FuseMutator : private ExprMutator {
};


Expr FuseOps(const Expr& expr, int fuse_opt_level) {
Expr FuseOps(const Expr& expr, const Module& m, int fuse_opt_level) {
std::set<GlobalVar> gv;
return FuseOps(expr, m, fuse_opt_level, &gv);
}

Expr FuseOps(const Expr& expr, const Module& m, int fuse_opt_level, std::set<GlobalVar>* gv) {
// First we convert all chains of fusable ops into
// abstracted functions which we mark as primtive
// then we convert these primtive functions into
// new operators.
return FuseMutator().Transform(expr, fuse_opt_level);
return FuseMutator(m, fuse_opt_level, gv).Transform(expr, fuse_opt_level);
}

TVM_REGISTER_API("relay._ir_pass.FuseOps")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = FuseOps(args[0], args[1]);
*ret = FuseOps(args[0], args[1], args[2]);
});
} // namespace relay
} // namespace tvm
6 changes: 5 additions & 1 deletion src/relay/pass/let_list.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class LetList {
* \return a Var that hold the inserted expr.
*/
Var Push(Var pv, Expr expr) {
CHECK(!used_);
lets_.emplace_back(std::make_pair(pv, expr));
return pv;
}
Expand Down Expand Up @@ -71,11 +72,13 @@ class LetList {
*
* \return the wrapped expr.
*/
Expr Get(const Expr& body) const {
Expr Get(const Expr& body) {
CHECK(!used_);
Expr ret = body;
for (auto rit = lets_.rbegin(); rit != lets_.rend(); ++rit) {
ret = LetNode::make(std::get<0>(*rit), std::get<1>(*rit), ret);
}
used_ = true;
return ret;
}

Expand Down Expand Up @@ -108,6 +111,7 @@ class LetList {

private:
std::vector<std::pair<Var, Expr> > lets_;
bool used_ = false;
};

} // namespace relay
Expand Down
Loading

0 comments on commit 464fe86

Please sign in to comment.