diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 1d3475b13dad..0087866ea4f8 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -474,6 +474,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) { if ((div(ramp(b1, c1, lanes), broadcast(c2, lanes))).Match(ret)) { int64_t c1val = c1.Eval()->value; int64_t c2val = c2.Eval()->value; + ICHECK(c2val != 0) << "division by zero"; if (c1val % c2val == 0) { return ramp(div(b1, c2), div(c1, c2), lanes).Eval(); } diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index 641eed51d5cf..9ff9ff18e5b5 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import pytest import tvm from tvm import te @@ -931,20 +932,13 @@ def test_shift_left_simplify(): ck.verify(z, tvm.tir.const(1 << 10, "int32")) +def test_div_zero_simplify(): + ck = RewriteChecker() + + with pytest.raises(tvm.error.TVMError) as cm: + ck.analyzer.rewrite_simplify(tvm.tir.Div(tvm.tir.Ramp(1, 1, 2), tvm.tir.Broadcast(0, 2))) + assert "division by zero" in str(cm.execption) + + if __name__ == "__main__": - test_floordiv_index_simplify() - test_floormod_index_simplify() - test_cmp_simplify() - test_vector_simplify() - test_add_index_simplify() - test_sub_index_simplify() - test_mul_index_simplify() - test_div_index_simplify() - test_max_index_simplify() - test_min_index_simplify() - test_mod_index_simplify() - test_select_simplify() - test_logical_simplify() - test_let_simplify() - test_cast_simplify() - test_shift_left_simplify() + pytest.main([__file__])