Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BYOC] Support control flow in annotate_target #6641

Merged
merged 19 commits into from
Oct 13, 2020
44 changes: 31 additions & 13 deletions src/relay/transforms/annotate_target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,13 @@ class AnnotateTargetRewriter : public ExprRewriter {

// Peek the first argument. If it is compiler begin then this node had annotated by
// another target before, so we also consider that target as a supported target.
const CallNode* first_arg_call = pre->args[0].as<CallNode>();
if (first_arg_call && first_arg_call->op == CompilerBeginOp()) {
std::string arg_target = first_arg_call->attrs.as<CompilerAttrs>()->compiler;
if (arg_target != "default") {
supported_targets.push_back(arg_target);
if (pre->args.size()) {
comaniac marked this conversation as resolved.
Show resolved Hide resolved
const CallNode* first_arg_call = pre->args[0].as<CallNode>();
if (first_arg_call && first_arg_call->op == CompilerBeginOp()) {
std::string arg_target = first_arg_call->attrs.as<CompilerAttrs>()->compiler;
if (arg_target != "default") {
supported_targets.push_back(arg_target);
}
}
}

Expand Down Expand Up @@ -234,20 +236,36 @@ class AnnotateTargetRewriter : public ExprRewriter {
Expr Rewrite_(const LetNode* op, const Expr& post) final {
auto let = Downcast<Let>(post);

auto target_n_args = AnnotateArgs({let->value, let->body});
auto new_expr = Let(let->var, std::get<1>(target_n_args)[0], std::get<1>(target_n_args)[1]);
op_expr_to_target_[new_expr] = std::get<0>(target_n_args);
Expr new_expr;
std::pair<std::string, Array<Expr>> target_n_args;
Expr new_body = InsertCompilerEnd(let->body);
bool is_functional_literal = let->value.as<FunctionNode>() != nullptr;
if (is_functional_literal) {
comaniac marked this conversation as resolved.
Show resolved Hide resolved
new_expr = Let(let->var, let->value, new_body);
} else {
target_n_args = AnnotateArgs({let->value});
new_expr = Let(let->var, std::get<1>(target_n_args)[0], new_body);
}

return std::move(new_expr);
}

Expr InsertCompilerEnd(const Expr& expr) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • Put this function together with AnnotateArgs.
  • Add formal function description.
  • From its functionality, the name like InsertCompilerEndAndPropogateTarget would be more proper.
  • Use this function in Expr Rewrite_(const FunctionNode* fn, const Expr& post) to avoid redundant logic.

Copy link
Contributor Author

@codeislife99 codeislife99 Oct 12, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. Although using this function in AnnotateArgs makes us do the same condition check twice which I think is unnecessary. For the time being I added it, but let me know your thoughts. Maybe I misinterpreted what you meant by Put this function together with AnnotateArgs ?

Copy link
Contributor

@comaniac comaniac Oct 12, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for not being clear. For the last point, I meant you can use InertCompilerEndAndPropogateTarget in https://github.com/apache/incubator-tvm/blob/main/src/relay/transforms/annotate_target.cc#L225 to replace the same logic. This point has nothing to do with AnnotateArgs tho

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see , oh so thats already done now in the previous commit. I will remove it from AnnotateArgs

Expr new_expr = expr;
if (op_expr_to_target_.find(expr) != op_expr_to_target_.end()) {
new_expr = InsertAnnotation(expr, op_expr_to_target_[expr], make_end_op);
op_expr_to_target_[new_expr] = op_expr_to_target_[expr];
}
return std::move(new_expr);
}

Expr Rewrite_(const IfNode* op, const Expr& post) final {
auto expr = Downcast<If>(post);
Expr new_cond = InsertCompilerEnd(expr->cond);
Expr new_true_branch = InsertCompilerEnd(expr->true_branch);
Expr new_false_branch = InsertCompilerEnd(expr->false_branch);

auto target_n_args = AnnotateArgs({expr->cond, expr->true_branch, expr->false_branch});
CHECK_EQ(std::get<1>(target_n_args).size(), 3U);
auto new_expr = If(std::get<1>(target_n_args)[0], std::get<1>(target_n_args)[1],
std::get<1>(target_n_args)[2]);
op_expr_to_target_[new_expr] = std::get<0>(target_n_args);
auto new_expr = If(new_cond, new_true_branch, new_false_branch);
return std::move(new_expr);
}

Expand Down
157 changes: 157 additions & 0 deletions tests/python/relay/test_pass_annotate_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,161 @@ def before():
assert tvm.ir.structural_equal(expected, mod)


def test_if_else():
target = "test_if_else"

@tvm.ir.register_op_attr("equal", "target." + target)
def relu(attrs, args): # pylint: disable=unused-variable
return True

@tvm.ir.register_op_attr("tanh", "target." + target)
def tanh(attrs, args): # pylint: disable=unused-variable
return True

@tvm.ir.register_op_attr("sigmoid", "target." + target)
def sigmoid(attrs, args): # pylint: disable=unused-variable
return True

@tvm.ir.register_op_attr("erf", "target." + target)
def erf(attrs, args): # pylint: disable=unused-variable
return True

"""Test that If-else nodes compiles correctly when surrounded by supported nodes."""

def before():
data = relay.var("data", shape=(1, 32))
eq1 = relay.var("e1", shape=[], dtype="float32")
eq2 = relay.var("e2", shape=[], dtype="float32")
eq = relay.equal(eq1, eq2)

true_branch = relay.tanh(data)
false_branch = relay.sigmoid(data)
ife = relay.If(eq, true_branch, false_branch)
out = relay.erf(ife)
func = relay.Function([data, eq1, eq2], out)
mod = tvm.IRModule.from_expr(func)

return mod

def after():

data = relay.var("data", shape=(1, 32))
eq1 = relay.var("e1", shape=[], dtype="float32")
eq2 = relay.var("e2", shape=[], dtype="float32")

cb_1 = relay.annotation.compiler_begin(eq1, target)
cb_2 = relay.annotation.compiler_begin(eq2, target)

equality_condition = relay.equal(cb_1, cb_2)
ce_1 = relay.annotation.compiler_end(equality_condition, target)

# if condition
cb_3 = relay.annotation.compiler_begin(data, target)
true_branch = relay.tanh(cb_3)
ce_2 = relay.annotation.compiler_end(true_branch, target)

# else condition
cb_4 = relay.annotation.compiler_begin(data, target)
false_branch = relay.sigmoid(cb_4)
ce_3 = relay.annotation.compiler_end(false_branch, target)

if_condition = relay.If(ce_1, ce_2, ce_3)
cb_5 = relay.annotation.compiler_begin(if_condition, target)
erf_out = relay.erf(cb_5)
ce_4 = relay.annotation.compiler_end(erf_out, target)
func = relay.Function([data, eq1, eq2], ce_4)
mod = tvm.IRModule.from_expr(func)
return mod

result = transform.AnnotateTarget(target)(before())
expected = transform.InferType()(after())
assert tvm.ir.structural_equal(expected, result)


def test_while_let():
target = "test_while_let"

@tvm.ir.register_op_attr("less", "target." + target)
def less(attrs, args): # pylint: disable=unused-variable
return True

@tvm.ir.register_op_attr("add", "target." + target)
def add(attrs, args): # pylint: disable=unused-variable
return True

@tvm.ir.register_op_attr("zeros_like", "target." + target)
def zeros_like(attrs, args): # pylint: disable=unused-variable
return True

"""Test that let nodes compiles correctly when surrounded by other nodes."""

def before():

var1 = relay.var("var1", shape=(2,))
var2 = relay.var("var2", shape=(), dtype="int32")
var3 = relay.var("var3", shape=(2,))
cond = relay.less(var2, relay.const(10, dtype="int32"))

loop = relay.var("while_loop")
ii = var2 + relay.const(1, dtype="int32")
ss = var3 + var1
true_branch = loop(ii, ss)
ife = relay.If(cond, true_branch, var3)
func_1 = relay.Function([var2, var3], ife)

ret = relay.Let(loop, func_1, loop(relay.const(0, dtype="int32"), relay.zeros_like(var1)))
func_2 = relay.Function([var1], ret)
mod = tvm.IRModule.from_expr(func_2)
return mod

def after():
var1 = relay.var("var1", shape=(2,))
var2 = relay.var("var2", shape=(), dtype="int32")
var3 = relay.var("var3", shape=(2,))
var4 = relay.const(10, dtype="int32")

cb_1 = relay.annotation.compiler_begin(var2, target)
cb_2 = relay.annotation.compiler_begin(var4, target)

less_condition = relay.less(cb_1, cb_2)
ce_1 = relay.annotation.compiler_end(less_condition, target)

loop = relay.var("while_loop")

# if condition
cb_3 = relay.annotation.compiler_begin(var2, target)
cb_4 = relay.annotation.compiler_begin(relay.const(1, dtype="int32"), target)
add_op_1 = relay.add(cb_3, cb_4)
ce_2 = relay.annotation.compiler_end(add_op_1, target)
cb_5 = relay.annotation.compiler_begin(ce_2, "default")
cb_6 = relay.annotation.compiler_begin(var3, target)
cb_7 = relay.annotation.compiler_begin(var1, target)
add_op_2 = relay.add(cb_6, cb_7)
ce_3 = relay.annotation.compiler_end(add_op_2, target)
cb_8 = relay.annotation.compiler_begin(ce_3, "default")
true_branch = loop(cb_5, cb_8) # while loop
ce_4 = relay.annotation.compiler_end(true_branch, "default")
if_condition = relay.If(ce_1, ce_4, var3)

cb_9 = relay.annotation.compiler_begin(relay.const(0, dtype="int32"), "default")
cb_10 = relay.annotation.compiler_begin(var1, target)
zeros_like = relay.zeros_like(cb_10)
ce_5 = relay.annotation.compiler_end(zeros_like, target)
cb_11 = relay.annotation.compiler_begin(ce_5, "default")
while_condition = loop(cb_9, cb_11)
ce_6 = relay.annotation.compiler_end(while_condition, "default")

func_1 = relay.Function([var2, var3], if_condition)
ret = relay.Let(loop, func_1, ce_6)
func_2 = relay.Function([var1], ret)
mod = tvm.IRModule.from_expr(func_2)
return mod

result = transform.AnnotateTarget(target)(before())
expected = transform.InferType()(after())
assert tvm.ir.structural_equal(expected, result)


if __name__ == "__main__":
test_extern_dnnl()
test_composite_function()
Expand All @@ -361,3 +516,5 @@ def before():
test_type_propagation()
test_tuple()
test_multiple_runs()
test_if_else()
test_while_let()