From fa0bd86ce3d7883fe2ac50908d916e8907c241db Mon Sep 17 00:00:00 2001 From: Ritwik Das Date: Wed, 7 Oct 2020 02:52:06 -0700 Subject: [PATCH 01/19] Change annotate target --- src/relay/transforms/annotate_target.cc | 49 +++++++++++++++++-------- 1 file changed, 33 insertions(+), 16 deletions(-) diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index 74c236ae3280..add997c41a26 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -35,10 +35,9 @@ namespace tvm { namespace relay { namespace annotate_target { -static const PackedFunc* make_begin_op = +const PackedFunc* make_begin_op = runtime::Registry::Get("relay.op.annotation._make.compiler_begin"); -static const PackedFunc* make_end_op = - runtime::Registry::Get("relay.op.annotation._make.compiler_end"); +const PackedFunc* make_end_op = runtime::Registry::Get("relay.op.annotation._make.compiler_end"); // A helper class to insert annotation boundaries for a program region that will // be handled by a specific compiler. @@ -130,11 +129,13 @@ class AnnotateTargetRewriter : public ExprRewriter { // Peek the first argument. If it is compiler begin then this node had annotated by // another target before, so we also consider that target as a supported target. - const CallNode* first_arg_call = pre->args[0].as(); - if (first_arg_call && first_arg_call->op == CompilerBeginOp()) { - std::string arg_target = first_arg_call->attrs.as()->compiler; - if (arg_target != "default") { - supported_targets.push_back(arg_target); + if (pre->args.size()) { + const CallNode* first_arg_call = pre->args[0].as(); + if (first_arg_call && first_arg_call->op == CompilerBeginOp()) { + std::string arg_target = first_arg_call->attrs.as()->compiler; + if (arg_target != "default") { + supported_targets.push_back(arg_target); + } } } @@ -234,20 +235,36 @@ class AnnotateTargetRewriter : public ExprRewriter { Expr Rewrite_(const LetNode* op, const Expr& post) final { auto let = Downcast(post); - auto target_n_args = AnnotateArgs({let->value, let->body}); - auto new_expr = Let(let->var, std::get<1>(target_n_args)[0], std::get<1>(target_n_args)[1]); - op_expr_to_target_[new_expr] = std::get<0>(target_n_args); + Expr new_expr; + std::pair> target_n_args; + Expr new_body = InsertCompilerEnd(let->body); + bool is_functional_literal = let->value.as() != nullptr; + if (is_functional_literal) { + new_expr = Let(let->var, let->value, new_body); + } else { + target_n_args = AnnotateArgs({let->value}); + new_expr = Let(let->var, std::get<1>(target_n_args)[0], new_body); + } + + return std::move(new_expr); + } + + Expr InsertCompilerEnd(const Expr& expr) { + Expr new_expr = expr; + if (op_expr_to_target_.find(expr) != op_expr_to_target_.end()) { + new_expr = InsertAnnotation(expr, op_expr_to_target_[expr], make_end_op); + op_expr_to_target_[new_expr] = op_expr_to_target_[expr]; + } return std::move(new_expr); } Expr Rewrite_(const IfNode* op, const Expr& post) final { auto expr = Downcast(post); + Expr new_cond = InsertCompilerEnd(expr->cond); + Expr new_true_branch = InsertCompilerEnd(expr->true_branch); + Expr new_false_branch = InsertCompilerEnd(expr->false_branch); - auto target_n_args = AnnotateArgs({expr->cond, expr->true_branch, expr->false_branch}); - CHECK_EQ(std::get<1>(target_n_args).size(), 3U); - auto new_expr = If(std::get<1>(target_n_args)[0], std::get<1>(target_n_args)[1], - std::get<1>(target_n_args)[2]); - op_expr_to_target_[new_expr] = std::get<0>(target_n_args); + auto new_expr = If(new_cond, new_true_branch, new_false_branch); return std::move(new_expr); } From c468a30ca709cb85a29a4d9423448cdba534a7c7 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 7 Oct 2020 18:46:55 +0000 Subject: [PATCH 02/19] Annotate_target --- src/relay/transforms/annotate_target.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index add997c41a26..dd93dc51d0ae 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -18,7 +18,7 @@ */ /*! - * \file src/relay/transforms/annotate_target.cc + * \file src/relay/transforms/.cc * \brief Wraps an expr with compiler_begin and compiler_end to indicate that * this expr should be handled by the external compiler. */ @@ -33,7 +33,7 @@ namespace tvm { namespace relay { -namespace annotate_target { +namespace { const PackedFunc* make_begin_op = runtime::Registry::Get("relay.op.annotation._make.compiler_begin"); @@ -307,14 +307,14 @@ Expr AnnotateTarget(const Expr& expr, const Array& targets) { return PostOrderRewrite(expr, &rewriter); } -} // namespace annotate_target +} // namespace namespace transform { Pass AnnotateTarget(const Array& targets) { runtime::TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { - return Downcast(relay::annotate_target::AnnotateTarget(f, targets)); + return Downcast(relay::AnnotateTarget(f, targets)); }; auto func_pass = CreateFunctionPass(pass_func, 0, "AnnotateTargetFunc", {"InferType"}); return transform::Sequential({func_pass, InferType()}, "AnnotateTarget"); From 5f74c0c1f8fb74351c4862f3e6bde8bcecc9cda6 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 7 Oct 2020 19:00:28 +0000 Subject: [PATCH 03/19] Revert namespace changes --- src/relay/transforms/annotate_target.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index dd93dc51d0ae..add997c41a26 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -18,7 +18,7 @@ */ /*! - * \file src/relay/transforms/.cc + * \file src/relay/transforms/annotate_target.cc * \brief Wraps an expr with compiler_begin and compiler_end to indicate that * this expr should be handled by the external compiler. */ @@ -33,7 +33,7 @@ namespace tvm { namespace relay { -namespace { +namespace annotate_target { const PackedFunc* make_begin_op = runtime::Registry::Get("relay.op.annotation._make.compiler_begin"); @@ -307,14 +307,14 @@ Expr AnnotateTarget(const Expr& expr, const Array& targets) { return PostOrderRewrite(expr, &rewriter); } -} // namespace +} // namespace annotate_target namespace transform { Pass AnnotateTarget(const Array& targets) { runtime::TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { - return Downcast(relay::AnnotateTarget(f, targets)); + return Downcast(relay::annotate_target::AnnotateTarget(f, targets)); }; auto func_pass = CreateFunctionPass(pass_func, 0, "AnnotateTargetFunc", {"InferType"}); return transform::Sequential({func_pass, InferType()}, "AnnotateTarget"); From 718daea046cac62f054b4b55ffa660800ce63f26 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 7 Oct 2020 20:20:40 +0000 Subject: [PATCH 04/19] Add tests for if-else node --- .../python/relay/test_pass_annotate_target.py | 72 +++++++++++++++++++ 1 file changed, 72 insertions(+) diff --git a/tests/python/relay/test_pass_annotate_target.py b/tests/python/relay/test_pass_annotate_target.py index b7c43498a69a..51f2015b83c1 100644 --- a/tests/python/relay/test_pass_annotate_target.py +++ b/tests/python/relay/test_pass_annotate_target.py @@ -352,6 +352,77 @@ def before(): expected = transform.AnnotateTarget(["A", "B"])(before()) assert tvm.ir.structural_equal(expected, mod) +def test_if_else(): + target = "test_if_else" + + @tvm.ir.register_op_attr("equal", "target." + target) + def relu(attrs, args): # pylint: disable=unused-variable + return True + + @tvm.ir.register_op_attr("tanh", "target." + target) + def tanh(attrs, args): # pylint: disable=unused-variable + return True + + @tvm.ir.register_op_attr("sigmoid", "target." + target) + def sigmoid(attrs, args): # pylint: disable=unused-variable + return True + + @tvm.ir.register_op_attr("If", "target." + target) + def If(attrs, args): # pylint: disable=unused-variable + return True + + @tvm.ir.register_op_attr("erf", "target." + target) + def erf(attrs, args): # pylint: disable=unused-variable + return True + + """Test that If-else nodes compiles correctly when surrounded by supported nodes.""" + + def before(): + data = relay.var('data', shape=(1, 32)) + eq1 = relay.var('e1', shape=[], dtype='int32') + eq2 = relay.var('e2', shape=[], dtype='int32') + eq = relay.equal(eq1, eq2) + + true_branch = relay.tanh(data) + false_branch = relay.sigmoid(data) + ife = relay.If(eq, true_branch, false_branch) + out = relay.erf(ife) + func = relay.Function([data, eq1, eq2], out) + mod = tvm.IRModule.from_expr(func) + + return mod + + def after(): + + data = relay.var('data', shape=(1, 32)) + eq1 = relay.var('e1', shape=[], dtype='int32') + eq2 = relay.var('e2', shape=[], dtype='int32') + + cb_1 = relay.annotation.compiler_begin(eq1, target) + cb_2 = relay.annotation.compiler_begin(eq2, target) + + equality_condition = relay.equal(cb_1, cb_2) + ce_1 = relay.annotation.compiler_end(equality_condition, target) + + cb_3 = relay.annotation.compiler_begin(data, target) + true_branch = relay.tanh(cb_3) + ce_2 = relay.annotation.compiler_end(true_branch, target) + + cb_4 = relay.annotation.compiler_begin(data, target) + false_branch = relay.sigmoid(cb_4) + ce_3 = relay.annotation.compiler_end(false_branch, target) + + if_condition = relay.If(ce_1, ce_2, ce_3) + cb_5 = relay.annotation.compiler_begin(if_condition, target) + erf_out = relay.erf(cb_5) + ce_4 = relay.annotation.compiler_end(erf_out, target) + func = relay.Function([data, eq1, eq2], ce_4) + mod = tvm.IRModule.from_expr(func) + return mod + + result = transform.AnnotateTarget(target)(before()) + expected = transform.InferType()(after()) + assert tvm.ir.structural_equal(expected, result) if __name__ == "__main__": test_extern_dnnl() @@ -361,3 +432,4 @@ def before(): test_type_propagation() test_tuple() test_multiple_runs() + test_if_else() From 3386ef1bcd7c63ea2e43f8abbdbc1131f95732bb Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 7 Oct 2020 23:07:32 +0000 Subject: [PATCH 05/19] Add while_let testcase --- .../python/relay/test_pass_annotate_target.py | 112 +++++++++++++++++- 1 file changed, 111 insertions(+), 1 deletion(-) diff --git a/tests/python/relay/test_pass_annotate_target.py b/tests/python/relay/test_pass_annotate_target.py index 51f2015b83c1..54645da44b38 100644 --- a/tests/python/relay/test_pass_annotate_target.py +++ b/tests/python/relay/test_pass_annotate_target.py @@ -404,10 +404,12 @@ def after(): equality_condition = relay.equal(cb_1, cb_2) ce_1 = relay.annotation.compiler_end(equality_condition, target) + # if condition cb_3 = relay.annotation.compiler_begin(data, target) true_branch = relay.tanh(cb_3) ce_2 = relay.annotation.compiler_end(true_branch, target) + # else condition cb_4 = relay.annotation.compiler_begin(data, target) false_branch = relay.sigmoid(cb_4) ce_3 = relay.annotation.compiler_end(false_branch, target) @@ -420,10 +422,117 @@ def after(): mod = tvm.IRModule.from_expr(func) return mod - result = transform.AnnotateTarget(target)(before()) + seq = tvm.transform.Sequential( + [ + transform.AnnotateTarget(target), + transform.MergeCompilerRegions(), + ]) + result = seq(before()) + + # result = transform.AnnotateTarget(target)(before()) + expected = transform.InferType()(after()) + assert tvm.ir.structural_equal(expected, result) + +def test_while_let(): + target = "test_while_let" + + @tvm.ir.register_op_attr("less", "target." + target) + def less(attrs, args): # pylint: disable=unused-variable + return True + + @tvm.ir.register_op_attr("add", "target." + target) + def add(attrs, args): # pylint: disable=unused-variable + return True + + @tvm.ir.register_op_attr("If", "target." + target) + def If(attrs, args): # pylint: disable=unused-variable + return True + + @tvm.ir.register_op_attr("scope_builder", "target." + target) + def scope_builder(attrs, args): # pylint: disable=unused-variable + return True + + @tvm.ir.register_op_attr("Let", "target." + target) + def Let(attrs, args): # pylint: disable=unused-variable + return True + + """Test that let nodes compiles correctly when surrounded by other nodes.""" + + def before(): + + var1 = relay.var("var1", shape=(2,)) + var2 = relay.var("var2", shape=(), dtype="int32") + var3 = relay.var("var3", shape=(2,)) + cond = var2 < relay.const(10, dtype="int32") + + loop = relay.var("while_loop") + ii = var2 + relay.const(1, dtype="int32") + ss = var3 + var1 + true_branch = loop(ii,ss) + ife = relay.If(cond, true_branch, var3) + func_1 = relay.Function([var2, var3], ife) + + ret = relay.Let( + loop, func_1, loop(relay.const(0, dtype="int32"), relay.zeros_like(var1)) + ) + func_2 = relay.Function([var1], ret) + mod = tvm.IRModule.from_expr(func_2) + return mod + + def after(): + var1 = relay.var("var1", shape=(2,)) + var2 = relay.var("var2", shape=(), dtype="int32") + var3 = relay.var("var3", shape=(2,)) + var4 = relay.const(10, dtype="int32") + + cb_1 = relay.annotation.compiler_begin(var2, target) + cb_2 = relay.annotation.compiler_begin(var4, target) + + less_condition = cb_1 < cb_2 + ce_1 = relay.annotation.compiler_end(less_condition, target) + + loop = relay.var("while_loop") + + # if condition + cb_3 = relay.annotation.compiler_begin(var2, target) + cb_4 = relay.annotation.compiler_begin(relay.const(1, dtype="int32"), target) + add_op_1 = relay.add(cb_3, cb_4) + ce_2 = relay.annotation.compiler_end(add_op_1, target) + cb_5 = relay.annotation.compiler_begin(ce_2, target) + cb_6 = relay.annotation.compiler_begin(var3, target) + cb_7 = relay.annotation.compiler_begin(var1, target) + add_op_2 = relay.add(cb_6, cb_7) + ce_3 = relay.annotation.compiler_end(add_op_2, target) + cb_8 = relay.annotation.compiler_begin(ce_3, target) + true_branch = loop(cb_5, cb_8) # while loop + ce_4 = relay.annotation.compiler_end(true_branch, target) + if_condition = relay.If(ce_1, ce_4, var3) + + cb_9 = relay.annotation.compiler_begin(relay.const(0, dtype="int32"), target) + cb_10 = relay.annotation.compiler_begin(var1, target) + zeros_like = relay.zeros_like(cb_10) + while_condition = loop(cb_9, zeros_like) + ce_5 = relay.annotation.compiler_end(while_condition, target) + + func_1 = relay.Function([var2, var3], if_condition) + ret = relay.Let( + loop, func_1, ce_5 + ) + func_2 = relay.Function([var1], ret) + mod = tvm.IRModule.from_expr(func_2) + return mod + + seq = tvm.transform.Sequential( + [ + transform.AnnotateTarget(target), + transform.MergeCompilerRegions(), + ]) + result = seq(before()) expected = transform.InferType()(after()) + print(expected, result) assert tvm.ir.structural_equal(expected, result) + if __name__ == "__main__": test_extern_dnnl() test_composite_function() @@ -433,3 +542,4 @@ def after(): test_tuple() test_multiple_runs() test_if_else() + test_while_let() From c907927587b13ac33ed7d4c4412d7e86755b11cc Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 7 Oct 2020 23:08:13 +0000 Subject: [PATCH 06/19] No merging in ifelse --- tests/python/relay/test_pass_annotate_target.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/tests/python/relay/test_pass_annotate_target.py b/tests/python/relay/test_pass_annotate_target.py index 54645da44b38..2b49eb66a94d 100644 --- a/tests/python/relay/test_pass_annotate_target.py +++ b/tests/python/relay/test_pass_annotate_target.py @@ -422,14 +422,7 @@ def after(): mod = tvm.IRModule.from_expr(func) return mod - seq = tvm.transform.Sequential( - [ - transform.AnnotateTarget(target), - transform.MergeCompilerRegions(), - ]) - result = seq(before()) - - # result = transform.AnnotateTarget(target)(before()) + result = transform.AnnotateTarget(target)(before()) expected = transform.InferType()(after()) assert tvm.ir.structural_equal(expected, result) From c51bc1ad8bf4af2fb951951ec9412a477c353983 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 7 Oct 2020 23:22:32 +0000 Subject: [PATCH 07/19] Remove scope builder --- tests/python/relay/test_pass_annotate_target.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/python/relay/test_pass_annotate_target.py b/tests/python/relay/test_pass_annotate_target.py index 2b49eb66a94d..6029bccb8eea 100644 --- a/tests/python/relay/test_pass_annotate_target.py +++ b/tests/python/relay/test_pass_annotate_target.py @@ -441,10 +441,6 @@ def add(attrs, args): # pylint: disable=unused-variable def If(attrs, args): # pylint: disable=unused-variable return True - @tvm.ir.register_op_attr("scope_builder", "target." + target) - def scope_builder(attrs, args): # pylint: disable=unused-variable - return True - @tvm.ir.register_op_attr("Let", "target." + target) def Let(attrs, args): # pylint: disable=unused-variable return True @@ -523,7 +519,7 @@ def after(): result = seq(before()) expected = transform.InferType()(after()) print(expected, result) - assert tvm.ir.structural_equal(expected, result) + assert tvm.ir.structural_equal(expected, result, map_free_vars=True) if __name__ == "__main__": From e34a56210f694692c0953229207ae5a13a72a1cf Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 7 Oct 2020 23:33:50 +0000 Subject: [PATCH 08/19] Add ops --- tests/python/relay/test_pass_annotate_target.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/tests/python/relay/test_pass_annotate_target.py b/tests/python/relay/test_pass_annotate_target.py index 6029bccb8eea..18526e5eef85 100644 --- a/tests/python/relay/test_pass_annotate_target.py +++ b/tests/python/relay/test_pass_annotate_target.py @@ -445,6 +445,14 @@ def If(attrs, args): # pylint: disable=unused-variable def Let(attrs, args): # pylint: disable=unused-variable return True + @tvm.ir.register_op_attr("const", "target." + target) + def const(attrs, args): # pylint: disable=unused-variable + return True + + @tvm.ir.register_op_attr("zeros_like", "target." + target) + def zeros_like(attrs, args): # pylint: disable=unused-variable + return True + """Test that let nodes compiles correctly when surrounded by other nodes.""" def before(): @@ -500,12 +508,14 @@ def after(): cb_9 = relay.annotation.compiler_begin(relay.const(0, dtype="int32"), target) cb_10 = relay.annotation.compiler_begin(var1, target) zeros_like = relay.zeros_like(cb_10) - while_condition = loop(cb_9, zeros_like) - ce_5 = relay.annotation.compiler_end(while_condition, target) + ce_5 = relay.annotation.compiler_end(zeros_like, target) + cb_11 = relay.annotation.compiler_begin(ce_5, target) + while_condition = loop(cb_9, cb_11) + ce_6 = relay.annotation.compiler_end(while_condition, target) func_1 = relay.Function([var2, var3], if_condition) ret = relay.Let( - loop, func_1, ce_5 + loop, func_1, ce_6 ) func_2 = relay.Function([var1], ret) mod = tvm.IRModule.from_expr(func_2) @@ -514,7 +524,6 @@ def after(): seq = tvm.transform.Sequential( [ transform.AnnotateTarget(target), - transform.MergeCompilerRegions(), ]) result = seq(before()) expected = transform.InferType()(after()) From fe8bd29290221f017fd9be144ee1e3177aa10684 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 8 Oct 2020 00:24:26 +0000 Subject: [PATCH 09/19] Replace < with less --- tests/python/relay/test_pass_annotate_target.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/python/relay/test_pass_annotate_target.py b/tests/python/relay/test_pass_annotate_target.py index 18526e5eef85..bfef53dc6ad2 100644 --- a/tests/python/relay/test_pass_annotate_target.py +++ b/tests/python/relay/test_pass_annotate_target.py @@ -460,7 +460,7 @@ def before(): var1 = relay.var("var1", shape=(2,)) var2 = relay.var("var2", shape=(), dtype="int32") var3 = relay.var("var3", shape=(2,)) - cond = var2 < relay.const(10, dtype="int32") + cond = relay.less(var2, relay.const(10, dtype="int32")) loop = relay.var("while_loop") ii = var2 + relay.const(1, dtype="int32") @@ -485,7 +485,7 @@ def after(): cb_1 = relay.annotation.compiler_begin(var2, target) cb_2 = relay.annotation.compiler_begin(var4, target) - less_condition = cb_1 < cb_2 + less_condition = relay.less(cb_1, cb_2) ce_1 = relay.annotation.compiler_end(less_condition, target) loop = relay.var("while_loop") @@ -527,7 +527,7 @@ def after(): ]) result = seq(before()) expected = transform.InferType()(after()) - print(expected, result) + # print(expected, result) assert tvm.ir.structural_equal(expected, result, map_free_vars=True) From 39f6220c25e7540457f159bf9caecb178ffeb739 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 8 Oct 2020 17:13:58 +0000 Subject: [PATCH 10/19] Linter --- .../python/relay/test_pass_annotate_target.py | 37 +++++++++---------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/tests/python/relay/test_pass_annotate_target.py b/tests/python/relay/test_pass_annotate_target.py index bfef53dc6ad2..def897fd187d 100644 --- a/tests/python/relay/test_pass_annotate_target.py +++ b/tests/python/relay/test_pass_annotate_target.py @@ -352,6 +352,7 @@ def before(): expected = transform.AnnotateTarget(["A", "B"])(before()) assert tvm.ir.structural_equal(expected, mod) + def test_if_else(): target = "test_if_else" @@ -378,9 +379,9 @@ def erf(attrs, args): # pylint: disable=unused-variable """Test that If-else nodes compiles correctly when surrounded by supported nodes.""" def before(): - data = relay.var('data', shape=(1, 32)) - eq1 = relay.var('e1', shape=[], dtype='int32') - eq2 = relay.var('e2', shape=[], dtype='int32') + data = relay.var("data", shape=(1, 32)) + eq1 = relay.var("e1", shape=[], dtype="int32") + eq2 = relay.var("e2", shape=[], dtype="int32") eq = relay.equal(eq1, eq2) true_branch = relay.tanh(data) @@ -394,9 +395,9 @@ def before(): def after(): - data = relay.var('data', shape=(1, 32)) - eq1 = relay.var('e1', shape=[], dtype='int32') - eq2 = relay.var('e2', shape=[], dtype='int32') + data = relay.var("data", shape=(1, 32)) + eq1 = relay.var("e1", shape=[], dtype="int32") + eq2 = relay.var("e2", shape=[], dtype="int32") cb_1 = relay.annotation.compiler_begin(eq1, target) cb_2 = relay.annotation.compiler_begin(eq2, target) @@ -426,6 +427,7 @@ def after(): expected = transform.InferType()(after()) assert tvm.ir.structural_equal(expected, result) + def test_while_let(): target = "test_while_let" @@ -464,14 +466,12 @@ def before(): loop = relay.var("while_loop") ii = var2 + relay.const(1, dtype="int32") - ss = var3 + var1 - true_branch = loop(ii,ss) + ss = var3 + var1 + true_branch = loop(ii, ss) ife = relay.If(cond, true_branch, var3) func_1 = relay.Function([var2, var3], ife) - ret = relay.Let( - loop, func_1, loop(relay.const(0, dtype="int32"), relay.zeros_like(var1)) - ) + ret = relay.Let(loop, func_1, loop(relay.const(0, dtype="int32"), relay.zeros_like(var1))) func_2 = relay.Function([var1], ret) mod = tvm.IRModule.from_expr(func_2) return mod @@ -481,7 +481,7 @@ def after(): var2 = relay.var("var2", shape=(), dtype="int32") var3 = relay.var("var3", shape=(2,)) var4 = relay.const(10, dtype="int32") - + cb_1 = relay.annotation.compiler_begin(var2, target) cb_2 = relay.annotation.compiler_begin(var4, target) @@ -501,7 +501,7 @@ def after(): add_op_2 = relay.add(cb_6, cb_7) ce_3 = relay.annotation.compiler_end(add_op_2, target) cb_8 = relay.annotation.compiler_begin(ce_3, target) - true_branch = loop(cb_5, cb_8) # while loop + true_branch = loop(cb_5, cb_8) # while loop ce_4 = relay.annotation.compiler_end(true_branch, target) if_condition = relay.If(ce_1, ce_4, var3) @@ -514,17 +514,16 @@ def after(): ce_6 = relay.annotation.compiler_end(while_condition, target) func_1 = relay.Function([var2, var3], if_condition) - ret = relay.Let( - loop, func_1, ce_6 - ) + ret = relay.Let(loop, func_1, ce_6) func_2 = relay.Function([var1], ret) mod = tvm.IRModule.from_expr(func_2) return mod seq = tvm.transform.Sequential( - [ - transform.AnnotateTarget(target), - ]) + [ + transform.AnnotateTarget(target), + ] + ) result = seq(before()) expected = transform.InferType()(after()) # print(expected, result) From 627ccd0c11f8f4549224ccfa2d740c1c2fd1f74b Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 9 Oct 2020 20:30:02 +0000 Subject: [PATCH 11/19] Pass Tests --- .../python/relay/test_pass_annotate_target.py | 46 +++++-------------- 1 file changed, 12 insertions(+), 34 deletions(-) diff --git a/tests/python/relay/test_pass_annotate_target.py b/tests/python/relay/test_pass_annotate_target.py index def897fd187d..ba1de7416384 100644 --- a/tests/python/relay/test_pass_annotate_target.py +++ b/tests/python/relay/test_pass_annotate_target.py @@ -368,10 +368,6 @@ def tanh(attrs, args): # pylint: disable=unused-variable def sigmoid(attrs, args): # pylint: disable=unused-variable return True - @tvm.ir.register_op_attr("If", "target." + target) - def If(attrs, args): # pylint: disable=unused-variable - return True - @tvm.ir.register_op_attr("erf", "target." + target) def erf(attrs, args): # pylint: disable=unused-variable return True @@ -380,8 +376,8 @@ def erf(attrs, args): # pylint: disable=unused-variable def before(): data = relay.var("data", shape=(1, 32)) - eq1 = relay.var("e1", shape=[], dtype="int32") - eq2 = relay.var("e2", shape=[], dtype="int32") + eq1 = relay.var("e1", shape=[], dtype="float32") + eq2 = relay.var("e2", shape=[], dtype="float32") eq = relay.equal(eq1, eq2) true_branch = relay.tanh(data) @@ -396,8 +392,8 @@ def before(): def after(): data = relay.var("data", shape=(1, 32)) - eq1 = relay.var("e1", shape=[], dtype="int32") - eq2 = relay.var("e2", shape=[], dtype="int32") + eq1 = relay.var("e1", shape=[], dtype="float32") + eq2 = relay.var("e2", shape=[], dtype="float32") cb_1 = relay.annotation.compiler_begin(eq1, target) cb_2 = relay.annotation.compiler_begin(eq2, target) @@ -439,18 +435,6 @@ def less(attrs, args): # pylint: disable=unused-variable def add(attrs, args): # pylint: disable=unused-variable return True - @tvm.ir.register_op_attr("If", "target." + target) - def If(attrs, args): # pylint: disable=unused-variable - return True - - @tvm.ir.register_op_attr("Let", "target." + target) - def Let(attrs, args): # pylint: disable=unused-variable - return True - - @tvm.ir.register_op_attr("const", "target." + target) - def const(attrs, args): # pylint: disable=unused-variable - return True - @tvm.ir.register_op_attr("zeros_like", "target." + target) def zeros_like(attrs, args): # pylint: disable=unused-variable return True @@ -495,23 +479,23 @@ def after(): cb_4 = relay.annotation.compiler_begin(relay.const(1, dtype="int32"), target) add_op_1 = relay.add(cb_3, cb_4) ce_2 = relay.annotation.compiler_end(add_op_1, target) - cb_5 = relay.annotation.compiler_begin(ce_2, target) + cb_5 = relay.annotation.compiler_begin(ce_2, "default") cb_6 = relay.annotation.compiler_begin(var3, target) cb_7 = relay.annotation.compiler_begin(var1, target) add_op_2 = relay.add(cb_6, cb_7) ce_3 = relay.annotation.compiler_end(add_op_2, target) - cb_8 = relay.annotation.compiler_begin(ce_3, target) + cb_8 = relay.annotation.compiler_begin(ce_3, "default") true_branch = loop(cb_5, cb_8) # while loop - ce_4 = relay.annotation.compiler_end(true_branch, target) + ce_4 = relay.annotation.compiler_end(true_branch, "default") if_condition = relay.If(ce_1, ce_4, var3) - cb_9 = relay.annotation.compiler_begin(relay.const(0, dtype="int32"), target) + cb_9 = relay.annotation.compiler_begin(relay.const(0, dtype="int32"), "default") cb_10 = relay.annotation.compiler_begin(var1, target) zeros_like = relay.zeros_like(cb_10) ce_5 = relay.annotation.compiler_end(zeros_like, target) - cb_11 = relay.annotation.compiler_begin(ce_5, target) + cb_11 = relay.annotation.compiler_begin(ce_5, "default") while_condition = loop(cb_9, cb_11) - ce_6 = relay.annotation.compiler_end(while_condition, target) + ce_6 = relay.annotation.compiler_end(while_condition, "default") func_1 = relay.Function([var2, var3], if_condition) ret = relay.Let(loop, func_1, ce_6) @@ -519,15 +503,9 @@ def after(): mod = tvm.IRModule.from_expr(func_2) return mod - seq = tvm.transform.Sequential( - [ - transform.AnnotateTarget(target), - ] - ) - result = seq(before()) + result = transform.AnnotateTarget(target)(before()) expected = transform.InferType()(after()) - # print(expected, result) - assert tvm.ir.structural_equal(expected, result, map_free_vars=True) + assert tvm.ir.structural_equal(expected, result) if __name__ == "__main__": From 11cb3b0b8510aa8d003a954c8ecca6a094c6693a Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 9 Oct 2020 20:35:43 +0000 Subject: [PATCH 12/19] Change back to static const --- src/relay/transforms/annotate_target.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index add997c41a26..b67c61dc9857 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -35,9 +35,9 @@ namespace tvm { namespace relay { namespace annotate_target { -const PackedFunc* make_begin_op = +static const PackedFunc* make_begin_op = runtime::Registry::Get("relay.op.annotation._make.compiler_begin"); -const PackedFunc* make_end_op = runtime::Registry::Get("relay.op.annotation._make.compiler_end"); +static const PackedFunc* make_end_op = runtime::Registry::Get("relay.op.annotation._make.compiler_end"); // A helper class to insert annotation boundaries for a program region that will // be handled by a specific compiler. From aed106121b5d7af056c542fa0faacd0ea10740d3 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 9 Oct 2020 20:42:25 +0000 Subject: [PATCH 13/19] Cpplinter --- src/relay/transforms/annotate_target.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index b67c61dc9857..2534a536ddc7 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -37,7 +37,8 @@ namespace annotate_target { static const PackedFunc* make_begin_op = runtime::Registry::Get("relay.op.annotation._make.compiler_begin"); -static const PackedFunc* make_end_op = runtime::Registry::Get("relay.op.annotation._make.compiler_end"); +static const PackedFunc* make_end_op = + runtime::Registry::Get("relay.op.annotation._make.compiler_end"); // A helper class to insert annotation boundaries for a program region that will // be handled by a specific compiler. From 3a9faea1825016886f4e759a6b8d5f0f51de98b4 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 12 Oct 2020 19:46:40 +0000 Subject: [PATCH 14/19] address PR comments' --- src/relay/transforms/annotate_target.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index 2534a536ddc7..4bd912073ce4 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -128,9 +128,9 @@ class AnnotateTargetRewriter : public ExprRewriter { return InsertAnnotation(input_expr, op_expr_to_target_[input_expr], make_end_op); } - // Peek the first argument. If it is compiler begin then this node had annotated by - // another target before, so we also consider that target as a supported target. if (pre->args.size()) { + // Peek the first argument. If it is compiler begin then this node had annotated by + // another target before, so we also consider that target as a supported target. const CallNode* first_arg_call = pre->args[0].as(); if (first_arg_call && first_arg_call->op == CompilerBeginOp()) { std::string arg_target = first_arg_call->attrs.as()->compiler; From cd86d23bae0b4c07e839f0730d85097e69fa176b Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 12 Oct 2020 19:53:20 +0000 Subject: [PATCH 15/19] PR Comments --- src/relay/transforms/annotate_target.cc | 32 +++++++++++-------------- 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index 4bd912073ce4..cd48fc798e7d 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -108,6 +108,15 @@ class AnnotateTargetRewriter : public ExprRewriter { return new_op; } + Expr InsertCompilerEndAndPropogateTarget(const Expr& expr) { + Expr new_expr = expr; + if (op_expr_to_target_.find(expr) != op_expr_to_target_.end()) { + new_expr = InsertAnnotation(expr, op_expr_to_target_[expr], make_end_op); + op_expr_to_target_[new_expr] = op_expr_to_target_[expr]; + } + return std::move(new_expr); + } + Expr Rewrite_(const CallNode* pre, const Expr& post) final { // Supported targets for this node. The order implies the priority. std::vector supported_targets; @@ -224,11 +233,7 @@ class AnnotateTargetRewriter : public ExprRewriter { new_body = func->body; } else { func = Downcast(post); - new_body = func->body; - if (op_expr_to_target_.find(func->body) != op_expr_to_target_.end()) { - new_body = InsertAnnotation(func->body, op_expr_to_target_[func->body], make_end_op); - op_expr_to_target_[new_body] = op_expr_to_target_[func->body]; - } + new_body = InsertCompilerEndAndPropogateTarget(func->body); } return Function(func->params, new_body, func->ret_type, func->type_params, func->attrs); } @@ -238,7 +243,7 @@ class AnnotateTargetRewriter : public ExprRewriter { Expr new_expr; std::pair> target_n_args; - Expr new_body = InsertCompilerEnd(let->body); + Expr new_body = InsertCompilerEndAndPropogateTarget(let->body); bool is_functional_literal = let->value.as() != nullptr; if (is_functional_literal) { new_expr = Let(let->var, let->value, new_body); @@ -250,20 +255,11 @@ class AnnotateTargetRewriter : public ExprRewriter { return std::move(new_expr); } - Expr InsertCompilerEnd(const Expr& expr) { - Expr new_expr = expr; - if (op_expr_to_target_.find(expr) != op_expr_to_target_.end()) { - new_expr = InsertAnnotation(expr, op_expr_to_target_[expr], make_end_op); - op_expr_to_target_[new_expr] = op_expr_to_target_[expr]; - } - return std::move(new_expr); - } - Expr Rewrite_(const IfNode* op, const Expr& post) final { auto expr = Downcast(post); - Expr new_cond = InsertCompilerEnd(expr->cond); - Expr new_true_branch = InsertCompilerEnd(expr->true_branch); - Expr new_false_branch = InsertCompilerEnd(expr->false_branch); + Expr new_cond = InsertCompilerEndAndPropogateTarget(expr->cond); + Expr new_true_branch = InsertCompilerEndAndPropogateTarget(expr->true_branch); + Expr new_false_branch = InsertCompilerEndAndPropogateTarget(expr->false_branch); auto new_expr = If(new_cond, new_true_branch, new_false_branch); return std::move(new_expr); From 85ea7715c63ba1441ffef887fac976674b12ff8a Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 12 Oct 2020 20:41:31 +0000 Subject: [PATCH 16/19] Clang-format check --- src/relay/transforms/annotate_target.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index cd48fc798e7d..dd5e8a180ab9 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -138,8 +138,8 @@ class AnnotateTargetRewriter : public ExprRewriter { } if (pre->args.size()) { - // Peek the first argument. If it is compiler begin then this node had annotated by - // another target before, so we also consider that target as a supported target. + // Peek the first argument. If it is compiler begin then this node had annotated by + // another target before, so we also consider that target as a supported target. const CallNode* first_arg_call = pre->args[0].as(); if (first_arg_call && first_arg_call->op == CompilerBeginOp()) { std::string arg_target = first_arg_call->attrs.as()->compiler; @@ -244,8 +244,8 @@ class AnnotateTargetRewriter : public ExprRewriter { Expr new_expr; std::pair> target_n_args; Expr new_body = InsertCompilerEndAndPropogateTarget(let->body); - bool is_functional_literal = let->value.as() != nullptr; - if (is_functional_literal) { + // Do not annotate function literal with let binding. + if (let->value->IsInstance()) { new_expr = Let(let->var, let->value, new_body); } else { target_n_args = AnnotateArgs({let->value}); From a2ebe0e2a7a2e07f37799b2d1d8fcc3fecf88f2a Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 12 Oct 2020 21:47:50 +0000 Subject: [PATCH 17/19] PR Comments --- src/relay/transforms/annotate_target.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index dd5e8a180ab9..331220cf13b6 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -77,7 +77,7 @@ class AnnotateTargetRewriter : public ExprRewriter { compiler_ends.push_back(call->args[0]); } else if (op_expr_to_target_.find(arg) != op_expr_to_target_.end()) { arg_target = op_expr_to_target_[arg]; - compiler_ends.push_back(InsertAnnotation(arg, arg_target, make_end_op)); + compiler_ends.push_back(InsertCompilerEndAndPropogateTarget(arg)); } else { // Input vars. compiler_ends.push_back(arg); @@ -136,7 +136,7 @@ class AnnotateTargetRewriter : public ExprRewriter { CHECK(op_expr_to_target_.find(input_expr) != op_expr_to_target_.end()); return InsertAnnotation(input_expr, op_expr_to_target_[input_expr], make_end_op); } - + // Check prior to peeking first argument if (pre->args.size()) { // Peek the first argument. If it is compiler begin then this node had annotated by // another target before, so we also consider that target as a supported target. From d17ec1138f9f4269b741564b3d5482d4a88f1753 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 12 Oct 2020 21:55:07 +0000 Subject: [PATCH 18/19] PR Comments --- src/relay/transforms/annotate_target.cc | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index 331220cf13b6..5dff4b19c2cd 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -109,6 +109,16 @@ class AnnotateTargetRewriter : public ExprRewriter { } Expr InsertCompilerEndAndPropogateTarget(const Expr& expr) { + /*! + * \brief This function inserts compiler end to expr and maps the corresponding target to the + * new expression. + * + * This function checks for expr existence within the map and inserts the annotation + * Further, it propagates the target to the new expression and returns it + * + * \param expr A relay expression + * \return An annotated and target-propagated relay expression. + */ Expr new_expr = expr; if (op_expr_to_target_.find(expr) != op_expr_to_target_.end()) { new_expr = InsertAnnotation(expr, op_expr_to_target_[expr], make_end_op); From 110602fac26284effe66136f511779006c60e15f Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 12 Oct 2020 22:25:32 +0000 Subject: [PATCH 19/19] Change back to Insert Ann in AnnotateARgs --- src/relay/transforms/annotate_target.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index 5dff4b19c2cd..015489dd0857 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -77,7 +77,7 @@ class AnnotateTargetRewriter : public ExprRewriter { compiler_ends.push_back(call->args[0]); } else if (op_expr_to_target_.find(arg) != op_expr_to_target_.end()) { arg_target = op_expr_to_target_[arg]; - compiler_ends.push_back(InsertCompilerEndAndPropogateTarget(arg)); + compiler_ends.push_back(InsertAnnotation(arg, arg_target, make_end_op)); } else { // Input vars. compiler_ends.push_back(arg);