From 546e2083acf098edb3857b3380205aa33692dbaa Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 28 May 2024 21:01:55 -0400 Subject: [PATCH 1/2] [TIR] Fix Bug in VectorizeLoop This PR fixes a bug in vectorize loop introduced related to recent change. The visit to condition can write need scalarize to true then the followup visit to then case can trigger an ICHECK. The visit to let value can also write need scalarize flag in which case we need to immediately scalarize. --- src/tir/transforms/vectorize_loop.cc | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index aa62d5850513..63569f342aed 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -676,12 +676,16 @@ class Vectorizer : public StmtMutator, public ExprFunctorcondition.dtype().is_scalable_or_fixed_length_vector()); PrimExpr condition = this->VisitExpr(op->condition); + // need scalarize can be marked as true during visit of condition + bool cond_need_scalarize = false; + std::swap(cond_need_scalarize, need_scalarize_); + // temp clear need_scalarize flag, so VisitStmt + // won't trigger an ICHECK eror Stmt then_case = this->VisitStmt(op->then_case); Optional else_case = NullOpt; if (op->else_case) { else_case = this->VisitStmt(op->else_case.value()); } - // Check if we can rewrite the condition with predicated buffers if (EnableBufferLevelPredication(target_) && condition.dtype().is_scalable_or_fixed_length_vector() && !else_case.defined()) { @@ -693,7 +697,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor(op)); } if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && @@ -710,6 +714,12 @@ class Vectorizer : public StmtMutator, public ExprFunctorVisitExpr(op->value); + // if visit of value triggers need scalarize + // we need to scalarize the let + if (need_scalarize_) { + need_scalarize_ = false; + Scalarize(GetRef(op)); + } ICHECK(!let_binding_.count(op->var)) << "SSA violation, a single var is binded twice"; let_binding_[op->var] = value; From b22feb5449bd5821db99e370a0ea621f02ef99df Mon Sep 17 00:00:00 2001 From: Charlie Ruan <53290280+CharlieFRuan@users.noreply.github.com> Date: Wed, 29 May 2024 01:08:53 -0400 Subject: [PATCH 2/2] Add unit test --- .../test_tir_transform_vectorize.py | 27 +++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/tests/python/tir-transform/test_tir_transform_vectorize.py b/tests/python/tir-transform/test_tir_transform_vectorize.py index e02c227b05b7..7523cab54941 100644 --- a/tests/python/tir-transform/test_tir_transform_vectorize.py +++ b/tests/python/tir-transform/test_tir_transform_vectorize.py @@ -14,13 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import pytest + import tvm import tvm.testing from tvm import te from tvm.script import ir as I from tvm.script import tir as T -import pytest - simple_target = tvm.target.Target("llvm -mtriple=x86_64-linux-gnu") sve_target = tvm.target.Target("llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+sve") @@ -312,6 +312,29 @@ def main(A: T.Buffer((25,), "float32"), n: T.int32): tvm.ir.assert_structural_equal(mod, After) +def test_vectorize_let_if_then_else(): + @I.ir_module + class Before: + @T.prim_func + def main(): + for i in T.vectorized(4): + if i < 2: + result: T.int32 = T.if_then_else(i < 1, 1, 2) + + @I.ir_module + class After: + @T.prim_func + def main(): + for i_s in range(4): + if i_s < 2: + result: T.int32 = T.if_then_else(i_s < 1, 1, 2) + T.evaluate(0) + + with tvm.target.Target(simple_target): + mod = tvm.tir.transform.VectorizeLoop()(Before) + tvm.ir.assert_structural_equal(mod, After) + + def test_vectorize_while_fail(): """A while loop inside a vectorized loop should fail."""