Skip to content

Commit

Permalink
[Arith] Simplify nested if_then_else (#12749)
Browse files Browse the repository at this point in the history
[Arith] Simplify nested if_then_else

Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
  • Loading branch information
vinx13 and spectrometerHBH authored Sep 15, 2022
1 parent 1f8b5de commit 9b10425
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 0 deletions.
20 changes: 20 additions & 0 deletions src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1654,6 +1654,26 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) {
}
}

if (op->op.same_as(tir::builtin::if_then_else())) {
// Simplify nested if_then_else
// if (cond) { if (inner_cond) { inner_then_expr } else { inner_else_expr } } else { else_expr }
// => if (cond && inner_cond) { inner_then_expr } else { else_expr }
const PrimExpr& cond = op->args[0];
const PrimExpr& then_expr = op->args[1];
const PrimExpr& else_expr = op->args[2];
const CallNode* inner_call = then_expr.as<CallNode>();
if (inner_call != nullptr && inner_call->op.same_as(tir::builtin::if_then_else())) {
const PrimExpr& inner_cond = inner_call->args[0];
const PrimExpr& inner_then_expr = inner_call->args[1];
const PrimExpr& inner_else_expr = inner_call->args[2];
// Only check constant cases to avoid recursion
if (is_const_number(inner_else_expr) && is_const_number(else_expr) &&
analyzer_->CanProve(inner_else_expr == else_expr)) {
return if_then_else(cond && inner_cond, inner_then_expr, else_expr);
}
}
}

return ret;
}

Expand Down
10 changes: 10 additions & 0 deletions tests/python/unittest/test_arith_rewrite_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -992,5 +992,15 @@ def test_sub_bufferload():
ck.verify(expr, 0.0)


def test_if_then_else_simplify():
ck = RewriteChecker()
x = te.var("x", "int32")
z = tvm.tir.if_then_else(x < 5, tvm.tir.if_then_else(x > 1, 1, 0), 0)
ck.verify(z, tvm.tir.if_then_else(tvm.tir.And(tvm.tir.LT(x, 5), tvm.tir.LT(1, x)), 1, 0))

z = tvm.tir.if_then_else(x > 2, tvm.tir.if_then_else(x > 1, 1, 0), 0)
ck.verify(z, tvm.tir.if_then_else(tvm.tir.LT(2, x), 1, 0))


if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit 9b10425

Please sign in to comment.