diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index aa62d5850513..63569f342aed 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -676,12 +676,16 @@ class Vectorizer : public StmtMutator, public ExprFunctorcondition.dtype().is_scalable_or_fixed_length_vector()); PrimExpr condition = this->VisitExpr(op->condition); + // need scalarize can be marked as true during visit of condition + bool cond_need_scalarize = false; + std::swap(cond_need_scalarize, need_scalarize_); + // temp clear need_scalarize flag, so VisitStmt + // won't trigger an ICHECK eror Stmt then_case = this->VisitStmt(op->then_case); Optional else_case = NullOpt; if (op->else_case) { else_case = this->VisitStmt(op->else_case.value()); } - // Check if we can rewrite the condition with predicated buffers if (EnableBufferLevelPredication(target_) && condition.dtype().is_scalable_or_fixed_length_vector() && !else_case.defined()) { @@ -693,7 +697,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor(op)); } if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && @@ -710,6 +714,12 @@ class Vectorizer : public StmtMutator, public ExprFunctorVisitExpr(op->value); + // if visit of value triggers need scalarize + // we need to scalarize the let + if (need_scalarize_) { + need_scalarize_ = false; + Scalarize(GetRef(op)); + } ICHECK(!let_binding_.count(op->var)) << "SSA violation, a single var is binded twice"; let_binding_[op->var] = value; diff --git a/tests/python/tir-transform/test_tir_transform_vectorize.py b/tests/python/tir-transform/test_tir_transform_vectorize.py index e02c227b05b7..7523cab54941 100644 --- a/tests/python/tir-transform/test_tir_transform_vectorize.py +++ b/tests/python/tir-transform/test_tir_transform_vectorize.py @@ -14,13 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import pytest + import tvm import tvm.testing from tvm import te from tvm.script import ir as I from tvm.script import tir as T -import pytest - simple_target = tvm.target.Target("llvm -mtriple=x86_64-linux-gnu") sve_target = tvm.target.Target("llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+sve") @@ -312,6 +312,29 @@ def main(A: T.Buffer((25,), "float32"), n: T.int32): tvm.ir.assert_structural_equal(mod, After) +def test_vectorize_let_if_then_else(): + @I.ir_module + class Before: + @T.prim_func + def main(): + for i in T.vectorized(4): + if i < 2: + result: T.int32 = T.if_then_else(i < 1, 1, 2) + + @I.ir_module + class After: + @T.prim_func + def main(): + for i_s in range(4): + if i_s < 2: + result: T.int32 = T.if_then_else(i_s < 1, 1, 2) + T.evaluate(0) + + with tvm.target.Target(simple_target): + mod = tvm.tir.transform.VectorizeLoop()(Before) + tvm.ir.assert_structural_equal(mod, After) + + def test_vectorize_while_fail(): """A while loop inside a vectorized loop should fail."""