From 40abc863f6954c86298a68052d84b42f5568bfc0 Mon Sep 17 00:00:00 2001
From: Haozheng Fan
Date: Fri, 8 Jan 2021 00:24:31 +0800
Subject: [PATCH] [Arith] Simplify cast (#7045)

---
 src/arith/canonical_simplify.cc               | 161 ++++++++++++++++++
 .../unittest/test_arith_canonical_simplify.py |  41 +++++
 2 files changed, 202 insertions(+)

diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc
index d0a0702a0fb0f..ba549959ac985 100644
--- a/src/arith/canonical_simplify.cc
+++ b/src/arith/canonical_simplify.cc
@@ -77,6 +77,27 @@ inline PrimExpr DivImpl(PrimExpr a, PrimExpr b, DivMode mode) {
   }
 }
 
+/*!
+ * \brief check if value fits in dtype
+ * \param dtype The target dtype
+ * \param value The value to be analyzed
+ * \param analyzer The analyzer
+ * \return whether value fits in dtype
+ */
+bool CastIsSafe(DataType dtype, PrimExpr value, Analyzer* analyzer) {
+  if (!IsIndexType(dtype)) {
+    return false;
+  }
+  ConstIntBound bound = analyzer->const_int_bound(value);
+  int64_t ubound = Downcast<IntImm>(max_value(dtype))->value;
+  int64_t lbound = Downcast<IntImm>(min_value(dtype))->value;
+  if (value.dtype().bits() <= dtype.bits() ||  // upcast is safe
+      (bound->max_value <= ubound && bound->min_value >= lbound)) {
+    return true;
+  }
+  return false;
+}
+
 /*!
  * \brief Internal "Split normal form" of expression.
  *
@@ -128,6 +149,58 @@ class SplitExprNode : public CanonicalExprNode {
 
   void MulToSelf(int64_t scale) { this->scale *= scale; }
 
+  /*!
+   * \brief check if cast can be pushed to sub-expressions
+   * \param dtype The target datatype
+   * \param analyzer The analyzer
+   * \return whether the cast can be safely pushed to children
+   */
+  bool CanPushCastToChildren(DataType dtype, Analyzer* analyzer) const {
+    // cast(dtype, index % upper_factor / lower_factor * scale) ==
+    // cast(dtype, index) % upper_factor / lower_factor * scale
+    // iff it is an upcast (dtype.bits >= self.dtype.bits) or all of
+    // its intermediate results fit in the range of dtype
+    if (dtype.bits() >= this->dtype.bits()) {
+      return true;  // upcast is safe
+    }
+    PrimExpr res = this->index;
+    if (this->scale == 0) {
+      return true;
+    }
+    if (!CastIsSafe(dtype, res, analyzer)) {
+      return false;
+    }
+    if (this->upper_factor != SplitExprNode::kPosInf) {
+      res = ModImpl(res, make_const(this->dtype, this->upper_factor), div_mode);
+      if (!CastIsSafe(dtype, res, analyzer)) {
+        return false;
+      }
+    }
+    if (this->lower_factor != 1) {
+      res = DivImpl(res, make_const(this->dtype, this->lower_factor), div_mode);
+      if (!CastIsSafe(dtype, res, analyzer)) {
+        return false;
+      }
+    }
+    if (this->scale != 1) {
+      ICHECK(!this->dtype.is_uint() || this->scale > 0);
+      res = res * make_const(this->dtype, this->scale);
+      if (!CastIsSafe(dtype, res, analyzer)) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  /*!
+   * \brief self = cast(dtype, self)
+   * \param dtype The target datatype
+   */
+  void PushCastToChildren(DataType dtype) {
+    this->index = cast(dtype, this->index);
+    this->dtype = dtype;
+  }
+
   inline bool IndexEqual(const SplitExpr& other) const;
   inline bool DivModeCompatibleTo(DivMode mode) const;
 
@@ -255,6 +328,69 @@ class SumExprNode : public CanonicalExprNode {
 
   void AddToSelf(const SumExpr& other, int64_t scale);
 
+  /*!
+   * \brief check if cast can be pushed to sub-expressions
+   * \param dtype The target datatype
+   * \param analyzer The analyzer
+   * \return whether the cast can be safely pushed to children
+   */
+  bool CanPushCastToChildren(DataType dtype, Analyzer* analyzer) const {
+    // cast(dtype, arg_1 + arg_2 + ... arg_n) ==
+    // cast(dtype, arg_1) + ... + cast(dtype, arg_n)
+    // iff it is an upcast (dtype.bits >= self.dtype.bits) or all of
+    // its intermediate results fit in the range of dtype
+    if (dtype.bits() >= this->dtype.bits()) {
+      return true;  // upcast is safe
+    }
+    PrimExpr res = make_const(dtype, 0);
+    for (size_t i = 0; i < args.size(); ++i) {
+      if (args[i]->scale > 0) {
+        res = res + args[i]->Normalize();
+        if (!CastIsSafe(dtype, res, analyzer)) {
+          return false;
+        }
+      }
+    }
+    if (base > 0) {
+      res = res + make_const(dtype, base);
+      if (!CastIsSafe(dtype, res, analyzer)) {
+        return false;
+      }
+    }
+    // negative scales follows using sub.
+    for (size_t i = 0; i < args.size(); ++i) {
+      if (args[i]->scale < 0) {
+        res = res - args[i]->NormalizeWithScale(-1);
+        if (!CastIsSafe(dtype, res, analyzer)) {
+          return false;
+        }
+      }
+    }
+    if (base < 0) {
+      res = res - make_const(dtype, -base);
+      if (!CastIsSafe(dtype, res, analyzer)) {
+        return false;
+      }
+    }
+    for (const auto& arg : args) {
+      if (!arg->CanPushCastToChildren(dtype, analyzer)) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  /*!
+   * \brief self = cast(dtype, self)
+   * \param dtype The target datatype
+   */
+  void PushCastToChildren(DataType dtype) {
+    for (auto& arg : args) {
+      arg.CopyOnWrite()->PushCastToChildren(dtype);
+    }
+    this->dtype = dtype;
+  }
+
   static constexpr const char* _type_key = "arith.SumExpr";
   TVM_DECLARE_FINAL_OBJECT_INFO(SumExprNode, CanonicalExprNode);
 
@@ -430,6 +566,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
   PrimExpr VisitExpr_(const FloorDivNode* op) final;
   PrimExpr VisitExpr_(const FloorModNode* op) final;
   PrimExpr VisitExpr_(const ReduceNode* op) final;
+  PrimExpr VisitExpr_(const CastNode* op) final;
 
  private:
   /*!
@@ -1071,6 +1208,30 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ReduceNode* op) {
   return ret;
 }
 
+PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const CastNode* op) {
+  if (!IsIndexType(op->dtype)) {
+    return Rewriter::VisitExpr_(op);
+  }
+  // normalize
+  PrimExpr value = this->CanonicalMutate(op->value);
+  // PushCastToChildren
+  if (value.as<SumExprNode>()) {
+    SumExpr se = Downcast<SumExpr>(value);
+    if (se->CanPushCastToChildren(op->dtype, analyzer_)) {
+      se.CopyOnWrite()->PushCastToChildren(op->dtype);
+      return std::move(se);
+    }
+  }
+  if (value.as<SplitExprNode>()) {
+    SplitExpr se = Downcast<SplitExpr>(value);
+    if (se->CanPushCastToChildren(op->dtype, analyzer_)) {
+      se.CopyOnWrite()->PushCastToChildren(op->dtype);
+      return std::move(se);
+    }
+  }
+  return Rewriter::VisitExpr_(op);
+}
+
 PrimExpr CanonicalSimplifier::operator()(const PrimExpr& expr) {
   return impl_->CanonicalSimplify(expr);
 }
diff --git a/tests/python/unittest/test_arith_canonical_simplify.py b/tests/python/unittest/test_arith_canonical_simplify.py
index 65c8ec3dfe02c..c241b81da986b 100644
--- a/tests/python/unittest/test_arith_canonical_simplify.py
+++ b/tests/python/unittest/test_arith_canonical_simplify.py
@@ -310,6 +310,46 @@ def test_complex_cases():
     ck.verify(res3, tdiv((x * 1024) + y, 256) - tdiv(y, 256) - (x * 4))
 
 
+def test_simplify_cast():
+    ck = CanonicalChecker()
+    tcast = tvm.tir.Cast
+    fld = tvm.te.floordiv
+    flm = tvm.te.floormod
+    # cast(i64, i + j + 1) - cast(i64, i)
+    i = te.var("i", dtype="int32")
+    j = te.var("j", dtype="int32")
+    res = tcast("int64", i + j + 1) - tcast("int64", i)
+    ck.verify(res, tcast("int64", j) + tvm.tir.const(1, "int64"))
+    # cast(i32, i + j + 1) - cast(i32, i)
+    i = te.var("i", dtype="int64")
+    j = te.var("j", dtype="int64")
+    ck.analyzer.update(i, tvm.arith.ConstIntBound(0, 10))
+    ck.analyzer.update(j, tvm.arith.ConstIntBound(0, 10))
+    res = tcast("int32", i + j + 1) - tcast("int32", i)
+    ck.verify(res, tcast("int32", j) + 1)
+    # cast(i32, i + j - 100)
+    i = te.var("i", dtype="int64")
+    j = te.var("j", dtype="int64")
+    ck.analyzer.update(i, tvm.arith.ConstIntBound(0, 2 ** 31 - 1))
+    ck.analyzer.update(j, tvm.arith.ConstIntBound(0, 10))
+    res = tcast("int32", i + j - 100)
+    ck.verify(res, res)
+    # cast(i32, flm(axis, 7i64) * 2i64 + 1i64) + 1i32
+    # - cast(i32, flm(axis, 7i64) * 2i64)
+    axis = te.var("axis", dtype="int64")
+    ck.analyzer.update(axis, tvm.arith.ConstIntBound(0, 42))
+    res = (
+        tcast(
+            "int32",
+            flm(axis, tvm.tir.const(7, "int64")) * tvm.tir.const(2, "int64")
+            + tvm.tir.const(1, "int64"),
+        )
+        + tvm.tir.const(1, "int32")
+        - tcast("int32", flm(axis, tvm.tir.const(7, "int64")) * tvm.tir.const(2, "int64"))
+    )
+    ck.verify(res, 2)
+
+
 if __name__ == "__main__":
     test_floormod_simplify()
     test_mul_sum_simplify()
@@ -321,3 +361,4 @@ def test_complex_cases():
     test_split_index_simplify()
     test_canonical_mixed()
     test_complex_cases()
+    test_simplify_cast()