From 4c36d01005bc3e4b7240c93e98d3041757527a2b Mon Sep 17 00:00:00 2001 From: wrongtest Date: Thu, 12 May 2022 20:00:41 +0800 Subject: [PATCH 1/6] simplify (x * 96) % 64 to (x * 32) % 64 --- src/arith/rewrite_simplify.cc | 14 +++--- src/tir/schedule/primitive/compute_at.cc | 2 +- .../unittest/test_arith_rewrite_simplify.py | 4 +- .../unittest/test_tir_schedule_compute_at.py | 44 +++++++++++++++++++ 4 files changed, 55 insertions(+), 9 deletions(-) diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index dab78c77a0a1..916069153045 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -930,22 +930,22 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { if (IsIndexType(op->dtype)) { // Be-aware of the division rules: we use floordiv/floormod here - TVM_TRY_REWRITE_IF(floormod(x * c1, c2), ZeroWithTypeLike(x), - c2.Eval()->value != 0 && c1.Eval()->value % c2.Eval()->value == 0); - - TVM_TRY_REWRITE_IF(floormod(x * c1 + y, c2), floormod(y, c2), - c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + TVM_TRY_REWRITE_IF(floormod(x * c1, c2), floormod(x * floormod(c1, c2), c2), + c2.Eval()->value != 0); TVM_TRY_REWRITE_IF(floormod(x * c1 + y, c2), floormod(x, floordiv(c2, c1)) * c1 + y, c1.Eval()->value > 0 && c2.Eval()->value > 0 && c2.Eval()->value % c1.Eval()->value == 0 && CanProveEqual(floordiv(y.Eval(), c1.Eval()), 0)); + TVM_TRY_REWRITE_IF(floormod(x * c1 + y, c2), floormod(x * floormod(c1, c2) + y, c2), + c2.Eval()->value > 0); + TVM_TRY_REWRITE_IF(floormod(x + c1, c2), floormod(x, c2), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); - TVM_TRY_REWRITE_IF(floormod(x + y * c1, c2), floormod(x, c2), - c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + TVM_TRY_REWRITE_IF(floormod(x + y * c1, c2), floormod(x + y * floormod(c1, c2), c2), + c2.Eval()->value > 0); TVM_TRY_REWRITE_IF(floormod(x * c1, x * c2), x * floormod(c1, c2), c2.Eval()->value != 0); diff --git a/src/tir/schedule/primitive/compute_at.cc b/src/tir/schedule/primitive/compute_at.cc index 2a349f8fe61e..7f1d74ac2021 100644 --- a/src/tir/schedule/primitive/compute_at.cc +++ b/src/tir/schedule/primitive/compute_at.cc @@ -244,7 +244,7 @@ class ScopeReconstructor : private StmtMutator { if (preserve_unit_loops || !is_one(iter_dom->extent)) { Var var("ax" + std::to_string(loop_vars.size()), DataType::Int(32)); loop_vars.push_back(var); - loop_extents.push_back(iter_dom->extent); + loop_extents.push_back(analyzer->Simplify(iter_dom->extent)); iter_values.push_back(iter_dom->min + var); analyzer->Bind(var, Range::FromMinExtent(0, iter_dom->extent)); } else { diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index 8d26710f40db..4627677cfd52 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -549,15 +549,17 @@ def test_mod_index_simplify(): def test_floormod_index_simplify(): # short name for floordiv flm = tvm.te.floormod - ck = RewriteChecker() x, y, z = te.var("x"), te.var("y"), te.var("z") ck = RewriteChecker() x, y, nx, ny, z = te.var("x"), te.var("y"), te.var("nx"), te.var("ny"), te.var("z") ck.verify(flm(x * 10, 2), 0) + ck.verify(flm(x * 9600, 6400), flm(x * 3200, 6400)) ck.verify(flm(x * 10 + y, 2), flm(y, 2)) + ck.verify(flm(x * 360 + y, 16), flm(x * 8 + y, 16)) ck.verify(flm(x + 10, 2), flm(x, 2)) ck.verify(flm(x + y * 10, 2), flm(x, 2)) + ck.verify(flm(x + y * 360, 16), flm(x + y * 8, 16)) ck.verify(flm(x * 10 + 1 + y * 2 + 2, 2), 1) ck.verify(flm(x * (-10), 2), 0) ck.verify(flm(x * (-10) + y, 2), flm(y, 2)) diff --git a/tests/python/unittest/test_tir_schedule_compute_at.py b/tests/python/unittest/test_tir_schedule_compute_at.py index b06dcebe1d1c..9204ee90d7bb 100644 --- a/tests/python/unittest/test_tir_schedule_compute_at.py +++ b/tests/python/unittest/test_tir_schedule_compute_at.py @@ -1249,6 +1249,50 @@ def test_compute_at_simplify_static_bound(): verify_trace_roundtrip(sch=sch, mod=static_bound) +def test_compute_at_non_perfect_channel_group(): + @T.prim_func + def grouped_channel_bias( + X: T.Buffer[(720, 8, 8), "float32"], Y: T.Buffer[(720, 8, 8), "float32"] + ): + B = T.alloc_buffer([45], dtype="float32", scope="") + for i in T.grid(45): + with T.block("init"): + vi = T.axis.remap("S", [i]) + B[vi] = vi + for c_o, h, w, c_i in T.grid(2, 8, 8, 360): + with T.block("compute"): + hh, ww = T.axis.remap("SS", [h, w]) + cc = T.axis.spatial(720, c_o * 360 + c_i) + Y[cc, hh, ww] = X[cc, hh, ww] + B[cc // 16] + + @T.prim_func + def grouped_channel_bias_non_perfect_tiled( + X: T.Buffer[(720, 8, 8), "float32"], Y: T.Buffer[(720, 8, 8), "float32"] + ): + B = T.alloc_buffer([45], dtype="float32") + for c_o in range(2): + for ax0 in range(23): + with T.block("init"): + vi = T.axis.spatial(45, c_o * 360 // 16 + ax0) + B[vi] = vi + for h, w, c_i in T.grid(8, 8, 360): + with T.block("compute"): + hh, ww = T.axis.remap("SS", [h, w]) + cc = T.axis.spatial(720, c_o * 360 + c_i) + Y[cc, hh, ww] = X[cc, hh, ww] + B[cc // 16] + + def check_sched(debug_mask): + sch = tir.Schedule(grouped_channel_bias, debug_mask=debug_mask) + loop = sch.get_loops(sch.get_block("compute"))[0] + sch.compute_at(sch.get_block("init"), loop) + tvm.ir.assert_structural_equal(sch.mod["main"], grouped_channel_bias_non_perfect_tiled) + + check_sched("none") + with pytest.raises(tvm.TVMError, match="region_cover"): + # TODO: try fix region cover proof + check_sched("all") + + def test_fail_subtree_complete_block(): sch = tir.Schedule(fail_subtree_compact_dataflow, debug_mask="all") block = sch.get_block("B_0") From 8ce069daedd7010e38119ff157031279911a8f65 Mon Sep 17 00:00:00 2001 From: wrongtest Date: Fri, 13 May 2022 12:28:08 +0800 Subject: [PATCH 2/6] adapt merge mulmod opt for OffsetOf computation --- src/tir/ir/buffer.cc | 17 ++++++++++------- tests/python/unittest/test_tir_buffer.py | 14 ++++++++++++-- 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index ccf186634b8a..dffb8b499285 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -75,13 +75,15 @@ inline std::vector ExprSplitAddition(const PrimExpr& expr) { } // Searches for the following types of expr: -// mult_expr = (a1 + a2 + ... + aj + c / (k1 * k2 * ... * ki) * k1 * ... * kt-1 ) * kt * ... * ki -// mod_l_expr = c +// mult_expr = (a1 + a2 + ... + aj + c1 / (k1 * k2 * ... * ki) * k1 * ... * kt-1 ) * kt * ... * ki +// mod_l_expr = c2 // mod_r_expr = k1 * k2 * ... * ki -// If it can be optimized, returns (true, (a1 + a2 + ... + aj) * kt * ... * ki + c) +// where c1 ~= c2 mod k1 * k2 * ... * ki +// If it can be optimized, returns (true, (a1 + a2 + ... + aj) * kt * ... * ki + c1) // Currently the we will not search the add/mult combinations exhaustively // as it will take too much computation. -inline std::pair MergeMulModInner(const PrimExpr& mult_expr, +inline std::pair MergeMulModInner(arith::Analyzer* analyzer, + const PrimExpr& mult_expr, const PrimExpr& mod_l_expr, const PrimExpr& mod_r_expr) { using namespace tir; @@ -119,9 +121,10 @@ inline std::pair MergeMulModInner(const PrimExpr& mult_expr, } else if (inner_div_ptr) { PrimExpr overall_mult = mult_inner.get() ? mult_inner * mult_outer : mult_outer; if (expr_equal(overall_mult, inner_div_ptr->b) && expr_equal(overall_mult, mod_r_expr) && - expr_equal(inner_div_ptr->a, mod_l_expr)) { + analyzer->CanProveEqual(floormod(inner_div_ptr->a - mod_l_expr, mod_r_expr), 0)) { // Found! - PrimExpr ret = no_opt_sum.get() ? no_opt_sum * mult_outer + mod_l_expr : mod_l_expr; + PrimExpr ret = + no_opt_sum.get() ? no_opt_sum * mult_outer + inner_div_ptr->a : inner_div_ptr->a; return std::make_pair(true, ret); } else { return std::make_pair(false, PrimExpr()); @@ -204,7 +207,7 @@ inline PrimExpr MergeMulMod(arith::Analyzer* analyzer, const PrimExpr& base) { bool inner_find_opt = false; while (mult_it != mult_exprs.end()) { std::pair ret = - MergeMulModInner(*mult_it, search_mod_it->first, search_mod_it->second); + MergeMulModInner(analyzer, *mult_it, search_mod_it->first, search_mod_it->second); if (ret.first) { inner_find_opt = true; auto temp_mod_it = search_mod_it; diff --git a/tests/python/unittest/test_tir_buffer.py b/tests/python/unittest/test_tir_buffer.py index 337f9cbc0722..10e827978cc0 100644 --- a/tests/python/unittest/test_tir_buffer.py +++ b/tests/python/unittest/test_tir_buffer.py @@ -137,6 +137,7 @@ def assert_simplified_equal(index_simplified, index_direct): idxd = tvm.tir.indexdiv idxm = tvm.tir.indexmod + # Test Case1 index_simplified = A_stride.offset_of( (idxd(idxm(k0, k1), s), idxm(idxm(k0, k1), s) + idxd(k0, k1) * k1) @@ -174,7 +175,7 @@ def assert_simplified_equal(index_simplified, index_direct): j = te.size_var("j") k = te.size_var("k") - index_simplified = B.offset_of( + index_simplified1 = B.offset_of( ( idxd(idxd(idxd((i * 50176 + j * 28672 + k), 1024), 14), 14), idxm(idxd(idxd((i * 50176 + j * 28672 + k), 1024), 14), 14), @@ -182,8 +183,17 @@ def assert_simplified_equal(index_simplified, index_direct): idxm((i * 50176 + j * 28672 + k), 1024), ) ) + index_simplified2 = B.offset_of( + ( + idxd(idxd(i * 49 + j * 28 + idxd(k, 1024), 14), 14), + idxm(idxd(i * 49 + j * 28 + idxd(k, 1024), 14), 14), + idxm(i * 7 + idxd(k, 1024), 14), + idxm(k, 1024), + ) + ) index_direct = B.offset_of((0, 0, 0, (i * 50176 + j * 28672 + k))) - assert_simplified_equal(index_simplified, index_direct) + assert_simplified_equal(index_simplified1, index_direct) + assert_simplified_equal(index_simplified2, index_direct) @tvm.testing.requires_llvm From e02844340a4f0df9ce30e9f3854d021e2d6d9a6b Mon Sep 17 00:00:00 2001 From: wrongtest Date: Mon, 23 May 2022 20:57:14 +0800 Subject: [PATCH 3/6] merge DetectIterMap and DetectIterMapPadded --- include/tvm/arith/iter_affine_map.h | 103 ++--- python/tvm/arith/iter_affine_map.py | 6 +- src/arith/int_set.cc | 5 +- src/arith/iter_affine_map.cc | 391 ++++++++++-------- src/arith/pattern_match.h | 2 + src/arith/rewrite_simplify.cc | 58 ++- src/arith/rewrite_simplify.h | 2 + src/tir/ir/index_map.cc | 22 +- src/tir/schedule/analysis/analysis.cc | 8 +- src/tir/schedule/analysis/layout.cc | 7 +- src/tir/schedule/primitive/compute_inline.cc | 5 +- .../unittest/test_arith_iter_affine_map.py | 89 +++- .../unittest/test_arith_rewrite_simplify.py | 10 +- .../unittest/test_tir_schedule_compute_at.py | 7 +- 14 files changed, 416 insertions(+), 299 deletions(-) diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h index 4cf6f086d1ed..7881b7ca714f 100644 --- a/include/tvm/arith/iter_affine_map.h +++ b/include/tvm/arith/iter_affine_map.h @@ -259,53 +259,29 @@ class IterSumExpr : public IterMapExpr { TVM_DEFINE_OBJECT_REF_COW_METHOD(IterSumExprNode); }; +/*! \brief Mapping level for iterators. */ +enum IterMapLevel { + // Require the mapping to be bijective. + Bijective = 0, + // Require the mapping to be subjective. + Surjective = 1, + // Require the mapping to be injective. + Injective = 2 +}; + /*! - * \brief Detect if indices can be written as - * [y_0 + c_0, y_1 + c_1, ..., y_n + c_n] - * - * Here y = some-quasi-affine-iter-map(input_iters) - * and c are symbolic constants. - * - * We also requires that y_i and y_j to be independent for i != j. - * - * For returned value rv, the following is always true: - * - rv[i]->args.size() <=1: only one iterator per element. - * - * \param indices The indices to detect pattern for. - * \param input_iters Map from variable to iterator's range. - * \param predicate The predicate constraints on the input iterators - * \param require_bijective A boolean flag that indicates whether the mapping should be bijective. - * \param analyzer Analyzer used to get context information. - * \param simplify_trivial_iterators If true, iterators with extent of - * 1 will be replaced with a constant value. - * - * \return The detected pattern if a match exists, - * otherwise return an empty array. + * \brief Result of DetectIterMap. */ -Array DetectIterMap(const Array& indices, const Map& input_iters, - const PrimExpr& predicate, bool require_bijective, - arith::Analyzer* analyzer, bool simplify_trivial_iterators = true); +class IterMapResultNode : public Object { + public: + // The detected pattern if a match exists. + Array indices; -/*! \brief A utility struct for return values from DetectPaddedIterMap - */ -struct PaddedIterMapResult { // Any errors that occurred while converting the input indices. If // the array is empty, the conversion was successful. Array errors; - // The detected pattern if a match exists. - Array indices; - - /* \brief Boolean expression indicating if padding was required - * - * `requires_padding` evaluates to true if the returned indices - * contain padding relative to the provided expressions, and false - * otherwise. If `input_iters` contains a variable extent, this - * expression may be in terms of those variables. - */ - PrimExpr requires_padding; - - /* \brief Boolean expression indicating if a specific value w + /*! \brief Boolean expression indicating if a specific value w * * `padding_predicate` evaluates to true for a set of indices that * are outside the bounds of the provided index iterators, but @@ -314,43 +290,54 @@ struct PaddedIterMapResult { * `input_iters`. */ PrimExpr padding_predicate; + + // overrides + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("errors", &errors); + v->Visit("indices", &indices); + v->Visit("padding_predicate", &padding_predicate); + } + + static constexpr const char* _type_key = "arith.IterMapResult"; + TVM_DECLARE_FINAL_OBJECT_INFO(IterMapResultNode, Object); +}; + +/*! + * \brief Managed reference to IterMapResultNode. + * \sa IterMapResultNode + */ +class IterMapResult : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(IterMapResult, ObjectRef, IterMapResultNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(IterMapResultNode); }; /*! * \brief Detect if indices can be written as * [y_0 + c_0, y_1 + c_1, ..., y_n + c_n] * - * Here y = some-quasi-affine-iter-map(input_iters) and c are - * symbolic constants. The y_i iterators may be padded to fit this - * representation. + * Here y = some-quasi-affine-iter-map(input_iters) + * and c are symbolic constants. * * We also requires that y_i and y_j to be independent for i != j. * * For returned value rv, the following is always true: - * - rv.indices[i]->args.size() <=1: only one iterator per element. + * - rv[i]->args.size() <=1: only one iterator per element. * * \param indices The indices to detect pattern for. - * * \param input_iters Map from variable to iterator's range. - * * \param predicate The predicate constraints on the input iterators - * - * \param require_bijective A boolean flag that indicates whether the - * mapping should be bijective. If true, no padding may be - * introduced. - * + * \param check_level The iter mapping check level. * \param analyzer Analyzer used to get context information. - * * \param simplify_trivial_iterators If true, iterators with extent of * 1 will be replaced with a constant value. * - * \return An instance of PaddedIterMapResult. + * \return The detected iteration result. + * The return object's .indices is empty on failure. */ -PaddedIterMapResult DetectPaddedIterMap(const Array& indices, - const Map& input_iters, - const PrimExpr& predicate, bool require_bijective, - arith::Analyzer* analyzer, - bool simplify_trivial_iterators = true); +IterMapResult DetectIterMap(const Array& indices, const Map& input_iters, + const PrimExpr& predicate, IterMapLevel check_level, + arith::Analyzer* analyzer, bool simplify_trivial_iterators = true); /*! * \brief Use IterVarMap detector to rewrite and simplify the indices diff --git a/python/tvm/arith/iter_affine_map.py b/python/tvm/arith/iter_affine_map.py index 2be939a12277..701474a92b6a 100644 --- a/python/tvm/arith/iter_affine_map.py +++ b/python/tvm/arith/iter_affine_map.py @@ -117,14 +117,14 @@ def detect_iter_map( Returns ------- - results : List[IterSumExpr] + results : IterMapResult The iter map matching result. - Empty array if no match can be found. + The result's .indices is empty array if no match can be found. """ return _ffi_api.DetectIterMap( indices, input_iters, predicate, require_bijective, simplify_trivial_iterators - ) + ).indices def normalize_iter_map_to_expr(expr): diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index a3fa879afa27..48fae479b042 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -867,9 +867,10 @@ Optional> EstimateRegionLowerBound(const Array& region, for (const Range& range : region) { affine_indices.push_back(range->min); } - iter_sum_exprs = DetectIterMap( + auto res = DetectIterMap( /*indices=*/affine_indices, /*input_iters=*/var_dom, - /*predicate=*/predicate, /*require_bijective=*/false, analyzer); + /*predicate=*/predicate, /*check_level=*/IterMapLevel::Surjective, analyzer); + iter_sum_exprs = res->indices; } if (iter_sum_exprs.empty()) { return NullOpt; diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 9fad3b2816a1..e29afc4c82d8 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -241,7 +241,7 @@ class IterMapRewriter : public ExprMutator { * - bindings = [x / 3] will not pass because x / 3 can not be one split of x * \return whether the bindings are valid */ - bool CheckMapping(const Array& bindings, bool require_bijective) { + bool CheckMapping(const Array& bindings, IterMapLevel check_level) { IterMarkSplitCollector collector; // We can check that for each iter mark: // All the splits that refers to the iter_mark covers its extent. @@ -249,11 +249,11 @@ class IterMapRewriter : public ExprMutator { collector.Collect(bindings); for (const IterMark& mark : collector.visited_) { - if (TryNormalizeSplits(mark, collector.mark2splits_[mark], require_bijective).empty()) { + if (TryNormalizeSplits(mark, collector.mark2splits_[mark], check_level).empty()) { return false; } } - if (require_bijective) { + if (check_level == IterMapLevel::Bijective) { // all input marks must be visited for (const IterMark& mark : input_marks_) { if (collector.visited_.count(mark) == 0 && !is_one(mark->extent)) { @@ -375,11 +375,11 @@ class IterMapRewriter : public ExprMutator { }; struct IterPaddingInfo { - // Used and collected during first pass - std::vector divisors; + // GCD of padding factor collected during first pass + PrimExpr padding_factor{1}; // Defined on first encounter in second pass - IterSplitExpr padded; + IterMark padded; PrimExpr left_pad; PrimExpr right_pad; }; @@ -427,12 +427,12 @@ class IterMapRewriter : public ExprMutator { // input iter marks std::vector input_marks_; - // Map from a normal PrimExpr to the padded iterator information for + // Map from a split iter to the padded iterator information for // it. This is necessary for introducing the same padding in all // usage of an input iterator. (e.g. (i-1) occurring in the // expressions [(i-1)%8, ((i-1)//8)%4, (i-1)//32] should be // left-padded by 31 for each occurrence.) - std::unordered_map padded_iter_map_; + std::unordered_map padded_iter_map_; /* If allow_padding_ is true, allow the extents of the IterMap to be * padded beyond the original iterators. @@ -504,6 +504,25 @@ class IterMapRewriter : public ExprMutator { // The flattened forms of constrained iters std::vector constrained_iters_flattened_; + /*! + * \brief Extract original iteration mark's extent before padding, return NullOpt is + * there is no extra padding. + */ + Optional ExtractExtentBeforePadding(const IterMark& mark, Analyzer* analyzer) { + const IterSumExprNode* sum = mark->source.as(); + if (!sum || sum->args.size() != 1) { + return NullOpt; + } + IterSplitExpr split = sum->args[0]; + if (!analyzer->CanProveEqual(split->extent, mark->extent) && + analyzer->CanProveEqual(split->scale, 1) && + analyzer->CanProveEqual(split->lower_factor, 1) && + analyzer->CanProveEqual(split->source->extent, split->extent)) { + return sum->args[0]->extent; + } + return NullOpt; + } + /*! * \brief Look for a split in splits that is not used such that its lower_factor is smallest. * Note that here we use division to compare lower_factor. @@ -538,13 +557,12 @@ class IterMapRewriter : public ExprMutator { * If not, return an empty array. * \param mark The iterator of interest. * \param splits The splits to be verified. - * \param require_bijective A boolean flag that indicates whether the bindings should be - * bijective. + * \param check_level Iteration mapping's check level. * \return The normalized splits. */ Array TryNormalizeSplits(const IterMark& mark, const std::vector& splits, - bool require_bijective) { + IterMapLevel check_level) { std::vector used(splits.size(), false); std::vector iters; PrimExpr expected_lower_factor = make_const(mark->source->dtype, 1); @@ -559,7 +577,7 @@ class IterMapRewriter : public ExprMutator { } if (j == splits.size()) { // we do not allow incomplete split if the bindings should be bijective - if (require_bijective) { + if (check_level == IterMapLevel::Bijective) { return Array(); } // look for the next split skipping this lower factor @@ -578,17 +596,51 @@ class IterMapRewriter : public ExprMutator { expected_lower_factor = splits[j]->lower_factor * splits[j]->extent; } + // Extract padding info of the iteration mark, extent before padding + // is only defined when padding exists. + Optional extent_before_padding = ExtractExtentBeforePadding(mark, analyzer_); + + bool match_full_iter = analyzer_->CanProveEqual(expected_lower_factor, mark->extent); + bool match_iter_divisor = + match_full_iter || CanProveDivisible(mark->extent, expected_lower_factor); + // Case 1. bijective is required. - // We check the extent we calculate is consistent with the extent of the mark - // Case 2. bijective is not required. + // We check the extent we calculate is consistent with the extent of the mark and + // iteration mark's padding is not allowed. + // + // Case 2. bijective is not required and there is no padding. // We check the extent we calculate is a factor of the extent of the mark // For example, y \in [0, 24) [(y / 2) % 6, y % 2] is valid, but y \in [0, 25) is not. - if (require_bijective) { - if (!analyzer_->CanProveEqual(expected_lower_factor, mark->extent)) { + // + // Case 3. bijective is not required and there exists padding. We check either + // (3.1) The extent we calculate is consistent with the extent of the padded mark and it is + // single split. + // For example, padded iter p in [0, 24), [(p / 12)] is valid because it is surjective + // according to how we pad the original iteration mark. + // (3.2) The extent we calculate is a factor of the extent of the padded mark, and the extent + // before + // padding is greater or equal than the extent we calculate. + // For example, padded iter p in [0, 24), the original extent is 14, [(p % 12)] is + // valid. + // + if (check_level == IterMapLevel::Bijective) { + if (extent_before_padding.defined() || !match_full_iter) { return Array(); } - } else { - if (!CanProveDivisible(mark->extent, expected_lower_factor)) { + } else if (!extent_before_padding.defined()) { + if (!match_iter_divisor) { + return Array(); + } + } else if (check_level == IterMapLevel::Surjective) { + if (match_full_iter) { + if (splits.size() != 1) { + return Array(); + } + } else if (match_iter_divisor) { + if (!analyzer_->CanProve(extent_before_padding.value() >= expected_lower_factor)) { + return Array(); + } + } else { return Array(); } } @@ -1018,40 +1070,25 @@ bool IterRangeSanityCheck(const Map& iter_ranges) { return true; } -Array DetectIterMap(const Array& indices, const Map& input_iters, - const PrimExpr& predicate, bool require_bijective, - arith::Analyzer* analyzer, bool simplify_trivial_iterators) { - auto padded_result = DetectPaddedIterMap(indices, input_iters, predicate, require_bijective, - analyzer, simplify_trivial_iterators); - if (padded_result.errors.size()) { - return Array(); - } - if (!analyzer->CanProve(!padded_result.requires_padding)) { - return Array(); - } - return padded_result.indices; -} - -PaddedIterMapResult DetectPaddedIterMap(const Array& indices, - const Map& input_iters, - const PrimExpr& predicate, bool require_bijective, - arith::Analyzer* analyzer, - bool simplify_trivial_iterators) { - PaddedIterMapResult result; +IterMapResult DetectIterMap(const Array& indices, const Map& input_iters, + const PrimExpr& predicate, IterMapLevel check_level, + arith::Analyzer* analyzer, bool simplify_trivial_iterators) { + IterMapResult result_obj = IterMapResult(make_object()); + auto result = result_obj.CopyOnWrite(); // Overall detection algorithm is divided into two steps: // - Step0: IterMapRewriter rewrites the expression to use IterMapExpr patterns. // - Step1: IterIndependenceChecker checks if the iterator are independent. if (!IterRangeSanityCheck(input_iters)) { - result.errors.push_back("Invalid iterators. Iterators may not be expressions of each other."); - return result; + result->errors.push_back("Invalid iterators. Iterators may not be expressions of each other."); + return result_obj; } Map constrained_input_iters = input_iters; std::vector constraints; if (!is_one(predicate) && !MatchBoundConstraints(predicate, &constrained_input_iters, &constraints)) { - result.errors.push_back("Could not parse predicate as constraints on the input iterators."); - return result; + result->errors.push_back("Could not parse predicate as constraints on the input iterators."); + return result_obj; } // We have to make sure when we visit an iterator, all the constraints related with its successors // in the iter var graph has been visited, where the expression of this iterator will contain the @@ -1065,22 +1102,22 @@ PaddedIterMapResult DetectPaddedIterMap(const Array& indices, [](const IterConstraint& a, const IterConstraint& b) { return a.expr_size < b.expr_size; }); IterMapRewriter rewriter(analyzer, constrained_input_iters, simplify_trivial_iterators, - &result.errors); + &result->errors); // Step0.0: rewrite constraints in the order from size-small ones to size-big ones for (const IterConstraint& constraint : constraints) { auto res = rewriter.RewriteIterConstraint(constraint.iter, constraint.lower_bound, constraint.upper_bound); - if (result.errors.size()) { - return result; + if (result->errors.size()) { + return result_obj; } } if (!rewriter.CheckConstraints()) { - result.errors.push_back("Invalid constraints."); - return result; + result->errors.push_back("Invalid constraints."); + return result_obj; } // Step0.1: Check each index to determine required padding - bool allow_padding = !require_bijective; + bool allow_padding = check_level != IterMapLevel::Bijective; if (allow_padding) { for (PrimExpr value : indices) { rewriter.UpdatePadding(value); @@ -1088,27 +1125,27 @@ PaddedIterMapResult DetectPaddedIterMap(const Array& indices, } // Step0.2: rewrite indices + Array rewrite_indices; for (PrimExpr value : indices) { - result.indices.push_back(rewriter.Rewrite(value)); - if (result.errors.size()) { - return result; + rewrite_indices.push_back(rewriter.Rewrite(value)); + if (result->errors.size()) { + return result_obj; } } - result.requires_padding = rewriter.requires_padding(); - result.padding_predicate = rewriter.padding_predicate(); + result->padding_predicate = rewriter.padding_predicate(); // Step1: IterIndependenceChecker checks if the iterator are independent. - if (!rewriter.CheckMapping(result.indices, require_bijective)) { - if (require_bijective) { - result.errors.push_back("Index mapping does not form a bijective transform."); + if (!rewriter.CheckMapping(rewrite_indices, check_level)) { + if (check_level == IterMapLevel::Bijective) { + result->errors.push_back("Index mapping does not form a bijective transform."); } else { - result.errors.push_back("Mapped indices are not independent."); + result->errors.push_back("Mapped indices are not independent."); } - return result; + return result_obj; } - - return result; + result->indices = rewrite_indices; + return result_obj; } TVM_REGISTER_GLOBAL("arith.DetectIterMap") @@ -1116,7 +1153,8 @@ TVM_REGISTER_GLOBAL("arith.DetectIterMap") const PrimExpr& input_pred, bool is_bijective, bool simplify_trivial_iterators) { arith::Analyzer ana; - return DetectIterMap(indices, input_iters, input_pred, is_bijective, &ana, + auto check_level = is_bijective ? IterMapLevel::Bijective : IterMapLevel::Surjective; + return DetectIterMap(indices, input_iters, input_pred, check_level, &ana, simplify_trivial_iterators); }); @@ -1246,15 +1284,17 @@ IterSumExpr IterMapRewriter::PreprocessDividend(IterMapExpr dividend, PrimExpr o auto split = Downcast(dividend); return IterSumExpr({split}, make_zero(split.dtype())); } else if (dividend->IsInstance()) { - auto opt_fused = TryFuseIters(Downcast(dividend)); + auto sum = Downcast(dividend); + if (sum->args.size() <= 1) { + return sum; + } + auto opt_fused = TryFuseIters(sum); if (!opt_fused) { ErrorLogger(this) << "Dividend " << tvm::PrettyPrint(original_dividend) << ", can't be written as a single fused IterSum"; return IterSumExpr(); } - IterSumExpr fused = opt_fused.value(); - ICHECK_EQ(fused->args.size(), 1U); return fused; } else { @@ -1263,140 +1303,132 @@ IterSumExpr IterMapRewriter::PreprocessDividend(IterMapExpr dividend, PrimExpr o } } +PrimExpr NearLeastCommonMultiple(const PrimExpr& a, const PrimExpr& b, Analyzer* analyzer) { + auto fsplit = [](const PrimExpr& e) -> std::pair { + if (const IntImmNode* imm = e.as()) { + return {1, imm->value}; + } + PVar pv; + PVar pc; + if ((pv * pc).Match(e) || (pc * pv).Match(e)) { + return {pv.Eval(), pc.Eval()->value}; + } else { + return {e, 1}; + } + }; + + auto p1 = fsplit(a); + auto p2 = fsplit(b); + auto const_lcm = Integer(LeastCommonMultiple(p1.second, p2.second)); + if (analyzer->CanProveEqual(p1.first, p2.first)) { + return p1.first * const_lcm; + } else { + return (p1.first * p2.first) * const_lcm; + } +} + std::pair IterMapRewriter::PadDividendToDivisor(IterSplitExpr split, PrimExpr base, PrimExpr divisor) { // If FloorDiv: (((source//lower_factor) % extent) + base) // divisor // If FloorMod: (((source//lower_factor) % extent) + base) % divisor - PrimExpr lookup_key = split; - - auto modified_divisor = [&]() { - if (update_iterator_padding_) { - return divisor; - } - - auto it = padded_iter_map_.find(lookup_key); - if (it == padded_iter_map_.end()) { - return divisor; - } - - const std::vector& divisors = it->second.divisors; - PrimExpr largest_divisor = divisor; - for (const auto& other : divisors) { - if (CanProveDivisible(other, largest_divisor)) { - // New one is bigger, use it - largest_divisor = other; - } else if (CanProveDivisible(largest_divisor, other)) { - // Current is bigger, keep it - } else { - ErrorLogger(this) << "Iterator appears in multiple terms with incompatible divisors " - << tvm::PrettyPrint(largest_divisor) << " and " - << tvm::PrettyPrint(other); - } - } - return largest_divisor; - }(); - - divisor = modified_divisor; - + // Update current iteration split's padding. // First, adding any padding that is on the lower side of a // FloorDiv/FloorMod, such that floormod(iter-left_pad,divisor) == 0 // when iter==0. - - PrimExpr left_pad; - - if (is_zero(base)) { - // Padding on the left is unnecessary if base is known to be zero. - left_pad = make_zero(base->dtype); - } else { - left_pad = analyzer_->Simplify(floormod(base, divisor)); - } + PrimExpr left_pad = analyzer_->Simplify(floormod(base, divisor)); // Next, adding any padding that is on the upper side of a // FloorDiv/FloorMod, such that floormod(left_pad + iter + right_pad, divisor) == 0 // when iter==extent. - PrimExpr right_edge = left_pad + split->extent; PrimExpr right_pad; - if (CanProveDivisible(right_edge, divisor)) { - // Padding on the right is unnecessary if the extent is a multiple of - // the divisor. right_pad = 0; } else { - right_pad = analyzer_->Simplify(floormod(-right_edge, divisor)); - } - - if (is_zero(left_pad) && is_zero(right_pad)) { - return {split, left_pad}; + right_pad = analyzer_->Simplify(floormod(-right_edge, divisor), 9); } if (update_iterator_padding_) { + IterMark mark = split->source; + auto& info = padded_iter_map_[mark]; + info.padding_factor = + NearLeastCommonMultiple(info.padding_factor, divisor * split->lower_factor, analyzer_); + + if (is_zero(left_pad) && is_zero(right_pad)) { + return {split, 0}; + } + // In the first pass, the primary goal is to collect all the divisors // that may be used for padding. These will impact the divisor used // to determine padding in the second pass. - IterPaddingInfo& info = padded_iter_map_[lookup_key]; - - info.divisors.push_back(divisor); + PrimExpr padded_extent = analyzer_->Simplify(left_pad + split->extent + right_pad); - PrimExpr padded_extent = left_pad + split->extent + right_pad; - - IterSumExpr as_sum({split}, left_pad); - IterMark mark(as_sum, padded_extent); - IterSplitExpr new_split(mark); - - return {new_split, left_pad}; + PrimExpr mark_left_pad = left_pad * split->lower_factor; + if (!is_zero(left_pad)) { + if (info.left_pad.defined()) { + info.left_pad = max(info.left_pad, mark_left_pad); + } else { + info.left_pad = mark_left_pad; + } + } + split.CopyOnWrite()->extent = padded_extent; + return {split, left_pad}; } - // Any padding that is required during parsing should have been found - // during the first pass that determines the GCD. - auto it = padded_iter_map_.find(lookup_key); + // In the second pass, update iteration mark's to padded + const IterMark& mark = split->source; + auto it = padded_iter_map_.find(mark); if (it == padded_iter_map_.end()) { - ErrorLogger(this) << "Dividend has extent " << tvm::PrettyPrint(split->extent) << " and offset " - << tvm::PrettyPrint(base) << ", which requires padding for divisor " - << tvm::PrettyPrint(divisor) << "."; - return {IterSplitExpr(), left_pad}; + return {split, left_pad}; } - IterPaddingInfo& info = it->second; - - if (info.padded.defined()) { - // A previous visit already applied padding to this iterator. - // (e.g. Visiting `(i+1)//4`, then visiting `(i+1)%4`). - ICHECK(analyzer_->CanProveEqual(info.left_pad, left_pad)); - ICHECK(analyzer_->CanProveEqual(info.right_pad, right_pad)); - - return {info.padded, left_pad}; + auto& info = it->second; + if (is_zero(info.left_pad.defined() ? info.left_pad : 0) && + CanProveDivisible(mark->extent, info.padding_factor)) { + return {split, left_pad}; } - // This is the first encounter with the iterator during the second pass. - IterSumExpr as_sum({split}, left_pad); - IterMark mark(as_sum, left_pad + split->extent + right_pad); - info.padded = IterSplitExpr(mark); - info.left_pad = left_pad; - info.right_pad = right_pad; - - auto left_padding_introduced = (left_pad != 0); - // Equivalent to (0 <= split < left_pad), but easier to simplify in - // terms of the transformed variables. - auto left_padding_predicate = - left_padding_introduced && (floordiv(info.padded, divisor) == floordiv(base, divisor) && - floormod(info.padded, divisor) < left_pad); - - PrimExpr nparts = ceildiv(right_edge, divisor); - - auto right_padding_introduced = (right_pad != 0); - - // Equivalent to (right_edge <= split < right_edge+right_pad), but - // easier to simplify in terms of the transformed variables. - auto right_padding_predicate = right_padding_introduced && - (floordiv(info.padded, divisor) == floordiv(right_edge, divisor) && - floormod(info.padded, divisor) >= floormod(right_edge, divisor)); - - requires_padding_ = requires_padding_ || (left_padding_introduced || right_padding_introduced); - padding_predicate_ = padding_predicate_ || (left_padding_predicate || right_padding_predicate); - - return {info.padded, left_pad}; + if (!info.padded.defined()) { + PrimExpr mark_left_pad = info.left_pad.defined() ? info.left_pad : 0; + PrimExpr mark_right_pad; + if (CanProveDivisible(mark->extent + mark_left_pad, info.padding_factor)) { + mark_right_pad = 0; + } else { + mark_right_pad = floormod(-(mark->extent + mark_left_pad), info.padding_factor); + } + PrimExpr padded_extent = analyzer_->Simplify(mark_left_pad + mark->extent + mark_right_pad); + info.right_pad = mark_right_pad; + info.padded = IterMark(IterSumExpr({IterSplitExpr(mark)}, mark_left_pad), padded_extent); + + auto left_padding_introduced = (mark_left_pad != 0); + PrimExpr divisor = info.padding_factor; + PrimExpr right_edge = mark_left_pad + mark->extent; + + // Equivalent to (0 <= split < left_pad), but easier to simplify in + // terms of the transformed variables. + auto left_padding_predicate = + left_padding_introduced && (floordiv(info.padded->source, divisor) == 0 && + floormod(info.padded->source, divisor) < mark_left_pad); + + auto right_padding_introduced = (mark_right_pad != 0); + + // Equivalent to (right_edge <= split < right_edge+right_pad), but + // easier to simplify in terms of the transformed variables. + auto right_padding_predicate = + right_padding_introduced && + (floordiv(info.padded->source, divisor) == floordiv(right_edge, divisor) && + floormod(info.padded->source, divisor) >= floormod(right_edge, divisor)); + + requires_padding_ = requires_padding_ || (left_padding_introduced || right_padding_introduced); + padding_predicate_ = padding_predicate_ || (left_padding_predicate || right_padding_predicate); + } + // ICHECK(CanProveDivisible(info.padded->extent, split->lower_factor)); + // ICHECK(CanProveDivisible(info.padded->extent, divisor)) << info.padded->extent << " " << + // divisor; + split.CopyOnWrite()->source = info.padded; + split.CopyOnWrite()->extent = floordiv(info.padded->extent, split->lower_factor); + return {split, left_pad}; } PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr base, PrimExpr rhs) { @@ -1462,7 +1494,7 @@ PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr base, P /* extent = */ analyzer_->Simplify(floordiv(padded->extent, rhs)), /* scale = */ padded->scale); - auto new_base = floordiv(base - left_pad, rhs); + auto new_base = analyzer_->Simplify(floordiv(base - left_pad, rhs), 6); if (is_zero(new_base)) { return std::move(new_split); } else { @@ -1659,7 +1691,7 @@ bool IterMapRewriter::CanProveDivisible(const PrimExpr& lhs, const PrimExpr& rhs PrimExpr divisor = normalizer.Convert(rhs); return analyzer_->CanProveEqual(dividend, divisor) || - analyzer_->CanProve(floormod(dividend, divisor) == 0); + analyzer_->CanProve(analyzer_->Simplify(floormod(dividend, divisor), 8) == 0); } PrimExpr NormalizeIterMapToExpr(const PrimExpr& expr) { @@ -1674,16 +1706,18 @@ Array IterMapSimplify(const Array& indices, const Map rewrite = - DetectIterMap(indices, input_iters, input_pred, require_bijective, &analyzer); + auto check_level = require_bijective ? IterMapLevel::Bijective : IterMapLevel::Surjective; + auto res = DetectIterMap(indices, input_iters, input_pred, check_level, &analyzer); + Array rewrite = res->indices; + if (rewrite.empty()) { return indices; } - Array res; - res.reserve(rewrite.size()); + Array simplified; + simplified.reserve(rewrite.size()); IterMapToExprNormalizer converter(&analyzer); - for (const auto& expr : rewrite) res.push_back(converter.Convert(expr)); - return res; + for (const auto& expr : rewrite) simplified.push_back(converter.Convert(expr)); + return simplified; } /*! @@ -1965,8 +1999,9 @@ Array> SubspaceDivide(const Array& bindings, const Array& sub_iters, const PrimExpr& predicate, bool require_bijective, arith::Analyzer* analyzer) { if (!IterRangeSanityCheck(input_iters)) return Array>(); - const Array& maps = - DetectIterMap(bindings, input_iters, predicate, require_bijective, analyzer); + auto check_level = require_bijective ? IterMapLevel::Bijective : IterMapLevel::Surjective; + auto res = DetectIterMap(bindings, input_iters, predicate, check_level, analyzer); + const Array& maps = res->indices; if (maps.empty()) return {}; std::unordered_set inner_iter_set; @@ -2128,5 +2163,7 @@ Map InverseAffineIterMap(const Array& iter_map, TVM_REGISTER_GLOBAL("arith.InverseAffineIterMap").set_body_typed(InverseAffineIterMap); +TVM_REGISTER_NODE_TYPE(IterMapResultNode); + } // namespace arith } // namespace tvm diff --git a/src/arith/pattern_match.h b/src/arith/pattern_match.h index 7d1f315b3cb3..6abcc728fc8d 100644 --- a/src/arith/pattern_match.h +++ b/src/arith/pattern_match.h @@ -203,6 +203,8 @@ class PVar : public Pattern> { return value_; } + T EvalOr(const T& default_value) const { return filled_ ? value_ : default_value; } + protected: /*! \brief The matched value */ mutable T value_; diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 916069153045..f9e38dee48e5 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -776,26 +776,32 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { TVM_TRY_REWRITE_IF(floordiv(floordiv(x, c1) + c2, c3), floordiv(x + c1 * c2, c1 * c3), c1.Eval()->value > 0 && c3.Eval()->value > 0); - if (floordiv(x * c1, c2).Match(ret)) { + if (floordiv(x * c1 + y, c2).Match(ret) || floordiv(x * c1, c2).Match(ret) || + floordiv(y + x * c1, c2).Match(ret)) { int64_t c1val = c1.Eval()->value; int64_t c2val = c2.Eval()->value; - if (c1val > 0 && c2val > 0) { - if (c1val % c2val == 0) return (x * floordiv(c1, c2)).Eval(); - if (c2val % c1val == 0) return floordiv(x, floordiv(c2, c1)).Eval(); + PrimExpr yval = y.EvalOr(Integer(0)); + if (c2val == 0) return ret; + + // try eliminate residue part + PrimExpr residue = + floordiv(x.Eval() * floormod(c1.Eval(), c2val) + floormod(yval, c2val), c2val); + PrimExpr y_div = CanProveEqual(floordiv(yval, c2val), 0) ? 0 : floordiv(yval, c2val); + auto bound = analyzer_->const_int_bound(residue); + if (bound.defined() && bound->max_value == bound->min_value) { + return x.Eval() * floordiv(c1val, c2.Eval()) + (y_div + Integer(bound->max_value)); } - } - if (floordiv(x * c1 + c2, c3).Match(ret)) { - int64_t c1val = c1.Eval()->value; - int64_t c2val = c2.Eval()->value; - int64_t c3val = c3.Eval()->value; - if (c1val > 0 && c3val > 0 && c3val % c1val == 0 && floormod(c2val, c3val) < c1val) { - // assume c3 == a * c1, x == a * y + b, c2 = d * c3 + e then - // (x * c1 + c2) // c3 - // ==> ((a * y + b) * c1 + d * a * c1 + e) // (a * c1) - // ==> y + d + (b * c1 + e) // c3 - // ==> y + d since 0 <= b * c1 <= (a-1) * c1, 0 <= e < c1 - // ==> x // (c3 // c1) + (c2 // c3) - return (floordiv(x, floordiv(c3, c1)) + floordiv(c2, c3)).Eval(); + + // try simplify divisor + if (c1val > 0 && c2val > 0 && c2val % c1val == 0 && + CanProveLess(floormod(yval, c2val), c1val)) { + // assume c2 == a * c1, x == a * x' + b, y = d * c2 + e then + // (x * c1 + y) // c2 + // ==> ((a * x' + b) * c1 + d * a * c1 + e) // (a * c1) + // ==> x' + d + (b * c1 + e) // c2 + // ==> x' + d since 0 <= b * c1 <= (a-1) * c1, 0 <= e < c1 + // ==> x // (c2 // c1) + (y // c2) + return floordiv(x.Eval(), floordiv(c2val, c1val)) + y_div; } } @@ -804,28 +810,12 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { TVM_TRY_REWRITE(floordiv(c1 * x, x), c1); // Rules involving 2-operands. - TVM_TRY_REWRITE_IF(floordiv(x * c1 + y, c2), x * floordiv(c1, c2) + floordiv(y, c2), - c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); - - TVM_TRY_REWRITE_IF(floordiv(x * c1 + y, c2), floordiv(x, floordiv(c2, c1)), - c1.Eval()->value > 0 && c2.Eval()->value > 0 && - c2.Eval()->value % c1.Eval()->value == 0 && - CanProveEqual(floordiv(y.Eval(), c1.Eval()), 0)); - TVM_TRY_REWRITE_IF(floordiv(min(x * c1, y), c2), min(x * floordiv(c1, c2), floordiv(y, c2)), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); TVM_TRY_REWRITE_IF(floordiv(max(x * c1, y), c2), max(x * floordiv(c1, c2), floordiv(y, c2)), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); - TVM_TRY_REWRITE_IF(floordiv(y + x * c1, c2), floordiv(y, c2) + x * floordiv(c1, c2), - c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); - - TVM_TRY_REWRITE_IF(floordiv(y + x * c1, c2), floordiv(x, floordiv(c2, c1)), - c1.Eval()->value > 0 && c2.Eval()->value > 0 && - c2.Eval()->value % c1.Eval()->value == 0 && - CanProveEqual(floordiv(y.Eval(), c1.Eval()), 0)); - TVM_TRY_REWRITE_IF(floordiv(min(y, x * c1), c2), min(floordiv(y, c2), x * floordiv(c1, c2)), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); @@ -878,6 +868,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { CanProveGreaterEqual(z.Eval(), 0)); TVM_TRY_REWRITE_IF(floordiv(y + z * x, z), floordiv(y, z) + x, CanProveGreaterEqual(z.Eval(), 0)); + + TVM_TRY_REWRITE_IF(floordiv(x - floormod(x, c1), c1), floordiv(x, c1), c1.Eval()->value != 0); } return ret; } diff --git a/src/arith/rewrite_simplify.h b/src/arith/rewrite_simplify.h index 258f833a7b21..202b9209da6d 100644 --- a/src/arith/rewrite_simplify.h +++ b/src/arith/rewrite_simplify.h @@ -110,6 +110,8 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { bool CanProveGreaterEqual(const PrimExpr& x, int64_t val) { return analyzer_->CanProveGreaterEqual(x, val); } + // Whether x < val + bool CanProveLess(const PrimExpr& x, int64_t val) { return analyzer_->CanProveLess(x, val); } // Whether x == val bool CanProveEqual(const PrimExpr& x, int64_t val) { // TODO(tqchen) refer back to super-analyzer. diff --git a/src/tir/ir/index_map.cc b/src/tir/ir/index_map.cc index 77678d829a8e..757c67a17f7f 100644 --- a/src/tir/ir/index_map.cc +++ b/src/tir/ir/index_map.cc @@ -77,16 +77,16 @@ std::pair IndexMap::NonSurjectiveInverse(Array initia // indices. arith::Analyzer analyzer; auto padded_iter_map = - DetectPaddedIterMap((*this)->final_indices, input_iters, /* predicate = */ 1, - /* require_bijective = */ false, &analyzer, - /* simplify_trivial_iterators = */ false); - CHECK(padded_iter_map.errors.empty()) << "Could not parse mapping as sum of iterators. " - << "Error: " << padded_iter_map.errors[0]; + DetectIterMap((*this)->final_indices, input_iters, /* predicate = */ 1, + /* check_level = */ arith::IterMapLevel::Injective, &analyzer, + /* simplify_trivial_iterators = */ false); + CHECK(padded_iter_map->errors.empty()) << "Could not parse mapping as sum of iterators. " + << "Error: " << padded_iter_map->errors[0]; // Determine expressions for the input variables, in terms of the // output variables. Map inverse_exprs_map = InverseAffineIterMap( - padded_iter_map.indices, Array(output_vars.begin(), output_vars.end())); + padded_iter_map->indices, Array(output_vars.begin(), output_vars.end())); // Unpack the map to an array, maintaining the same parameter order. Array inverse_exprs; @@ -94,7 +94,7 @@ std::pair IndexMap::NonSurjectiveInverse(Array initia inverse_exprs.push_back(inverse_exprs_map.at(index)); } - PrimExpr padding_predicate = padded_iter_map.padding_predicate; + PrimExpr padding_predicate = padded_iter_map->padding_predicate; padding_predicate = arith::NormalizeIterMapToExpr(padding_predicate); padding_predicate = Substitute(padding_predicate, inverse_exprs_map); @@ -141,14 +141,14 @@ IndexMap IndexMap::Inverse(Array initial_ranges) const { // indices. arith::Analyzer analyzer; auto iter_map = DetectIterMap((*this)->final_indices, input_iters, /* predicate = */ 1, - /* require_bijective = */ true, &analyzer, + /* require_bijective = */ arith::IterMapLevel::Bijective, &analyzer, /* simplify_trivial_iterators = */ false); - CHECK(iter_map.size()) << "Index transformation was not bijective."; + CHECK(iter_map->indices.size()) << "Index transformation was not bijective."; // Determine expressions for the input variables, in terms of the // output variables. - Map inverse_exprs_map = - InverseAffineIterMap(iter_map, Array(output_vars.begin(), output_vars.end())); + Map inverse_exprs_map = InverseAffineIterMap( + iter_map->indices, Array(output_vars.begin(), output_vars.end())); // Unpack the map to an array, maintaining the same parameter order. Array inverse_exprs; diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index c4719015daa4..83ef6adae3b2 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -533,16 +533,16 @@ bool IsAffineBinding(const BlockRealize& realize, const Map& loop_va if (loop_var_ranges.empty()) { return true; } - Array results = arith::DetectIterMap( + auto res = arith::DetectIterMap( /*indices=*/realize->iter_values, /*input_iters=*/loop_var_ranges, /*predicate=*/realize->predicate, - /*require_bijective=*/false, + /*check_level=*/arith::IterMapLevel::Surjective, /*analyzer=*/analyzer); - if (results.empty()) { + if (res->indices.empty()) { return false; } - for (const arith::IterSumExpr& sum_expr : results) { + for (const arith::IterSumExpr& sum_expr : res->indices) { const Array& args = sum_expr->args; if (!args.empty() && !is_one(args[0]->scale)) { return false; diff --git a/src/tir/schedule/analysis/layout.cc b/src/tir/schedule/analysis/layout.cc index 993557f8be2f..8c6b9d7072f0 100644 --- a/src/tir/schedule/analysis/layout.cc +++ b/src/tir/schedule/analysis/layout.cc @@ -77,8 +77,11 @@ class SplitExprCollector { const PrimExpr& predicate, // bool require_bijective, // arith::Analyzer* analyzer) { - Array iter_sum_exprs = arith::DetectIterMap( - {analyzer->Simplify(index)}, input_iters, predicate, require_bijective, analyzer); + auto check_level = + require_bijective ? arith::IterMapLevel::Bijective : arith::IterMapLevel::Surjective; + arith::IterMapResult res = arith::DetectIterMap({analyzer->Simplify(index)}, input_iters, + predicate, check_level, analyzer); + const auto& iter_sum_exprs = res->indices; if (iter_sum_exprs.empty()) { return {}; } diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc index 452f72e7228f..ad15e06e285a 100644 --- a/src/tir/schedule/primitive/compute_inline.cc +++ b/src/tir/schedule/primitive/compute_inline.cc @@ -552,13 +552,14 @@ class ReverseComputeInliner : public BaseInliner { } } - buffer_load_iter_map_ = arith::DetectIterMap( + auto res = arith::DetectIterMap( /*indices=*/buffer_load_indices_, /*input_iters=*/consumer_iter_doms, /*predicate=*/true, - /*require_bijective=*/true, + /*check_level=*/arith::IterMapLevel::Bijective, /*analyzer=*/&analyzer, /*simplify_trivial_iterators=*/false); + buffer_load_iter_map_ = res->indices; if (buffer_load_iter_map_.empty()) { // Failure: indices of BufferLoad are not bijective affine return false; diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py index fe766b921806..f1dde933179e 100644 --- a/tests/python/unittest/test_arith_iter_affine_map.py +++ b/tests/python/unittest/test_arith_iter_affine_map.py @@ -16,7 +16,6 @@ # under the License. import tvm import tvm.testing -from tvm import te from tvm.tir import floormod, floordiv @@ -385,8 +384,15 @@ def test_predicate(): [(i * 32 + j) % 16], var_dom([(i, 5), (j, 32)]), tvm.tir.all(1 <= i * 32 + j, i * 32 + j <= 32), + require_bijective=True, ) assert len(res) == 0 + res = tvm.arith.detect_iter_map( + [(i * 32 + j) % 16], + var_dom([(i, 5), (j, 32)]), + tvm.tir.all(1 <= i * 32 + j, i * 32 + j <= 32), + ) + assert_iter_sum_pattern(res[0], 16, 0) res = tvm.arith.detect_iter_map( [(i * 32 + j - 1) % 16, (i * 32 + j - 1) // 16], var_dom([(i, 5), (j, 32)]), @@ -944,5 +950,86 @@ def test_free_variables(): assert_iter_sum_pattern(res[0], 9, z * z) +def test_padding(): + x = tvm.tir.Var("x", "int32") + y = tvm.tir.Var("y", "int32") + fld = tvm.tir.floordiv + flm = tvm.tir.floormod + + def assert_padding_pattern(expect_dict, dom_map, predicate=True, require_bijective=False): + keys = list(expect_dict.keys()) + res = tvm.arith.detect_iter_map( + keys, dom_map, predicate=predicate, require_bijective=require_bijective + ) + assert len(res) == len(keys) + print(res) + for i, iter_expr in enumerate(keys): + extent, base, scale = expect_dict[iter_expr] + assert_iter_sum_pattern(res[i], extent, base, scale) + + def assert_padding_pattern_failure(iters, dom_map, predicate=True, require_bijective=False): + res = tvm.arith.detect_iter_map( + list(iters), dom_map, predicate=predicate, require_bijective=False + ) + assert len(res) == 0 + + # left padding only, offset divisible + sum = 64 + y + dom_map = var_dom([(y, 192)]) + assert_padding_pattern( + {fld(sum, 32): (6, 2, 1), flm(sum, 32): (32, 0, 1)}, + dom_map, + require_bijective=True, + ) + + # left padding only, offset non-divisible + sum = 80 + y + dom_map = var_dom([(y, 176)]) + assert_padding_pattern( + {fld(sum, 32): (6, 2, 1)}, + dom_map, + ) + assert_padding_pattern( + {flm(fld(sum, 2), 16): (16, 0, 1), flm(sum, 2): (2, 0, 1)}, + dom_map, + ) + assert_padding_pattern_failure({fld(sum, 32), flm(sum, 32)}, dom_map) + + # right padding only, offset divisible + sum = x * 32 + y * 8 + dom_map = var_dom([(x, 5), (y, 4)]) + assert_padding_pattern( + {fld(sum, 16): (10, 0, 1), flm(sum, 16): (2, 0, 8)}, + dom_map, + ) + assert_padding_pattern_failure({fld(sum, 5)}, dom_map) + + # right padding only, offset non-divisible + dom_map = var_dom([(x, 26)]) + assert_padding_pattern( + {fld(x, 15): (2, 0, 1)}, + dom_map, + ) + assert_padding_pattern( + {flm(fld(x, 3), 5): (5, 0, 1), flm(x, 3): (3, 0, 1)}, + dom_map, + ) + + # padding constants on both side + sum = x + 71 + dom_map = var_dom([(x, 45)]) + assert_padding_pattern({fld(sum, 32): (2, 2, 1)}, dom_map) + assert_padding_pattern( + {flm(fld(x, 4), 8): (8, 0, 1), flm(x, 4): (4, 0, 1)}, + dom_map, + ) + + # padding for free iteration part + sum = x * 360 + y + dom_map = var_dom([(y, 360)]) + assert_padding_pattern({fld(sum, 16): (23, fld(x * 360 - flm(x, 2) * 8, 16), 1)}, dom_map) + assert_padding_pattern({flm(x * 360 + y, 16): (16, 0, 1)}, dom_map) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index 4627677cfd52..82e1372f991e 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -459,11 +459,13 @@ def test_div_index_simplify(): def test_floordiv_index_simplify(): # short name for floordiv fld = tvm.te.floordiv + flm = tvm.te.floormod ck = RewriteChecker() x, y, z = te.var("x"), te.var("y"), te.var("z") ck.verify(fld(fld(x, 2), 3), fld(x, 6)) ck.verify(fld(fld(x, 2) + 1, 3), fld(x + 2, 6)) + ck.verify(fld(x - flm(x, 21), 21), fld(x, 21)) ck.verify(fld(x * 2, 4), fld(x, 2)) ck.verify(fld(x * 4, 2), x * 2) @@ -472,11 +474,17 @@ def test_floordiv_index_simplify(): ck.verify(fld(x * 8 - 1, 16), fld(x * 8 + -1, 16)) ck.verify(fld(x * 8 - 9, 16), fld(x, 2) + -1) + ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1), override=True) + ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 7), override=True) + ck.verify(fld(x * 360 + y, 16), x * 22) + ck.verify(fld(x * 360 + y, 25), x * 14) + ck.verify(fld(x * 360 - 8, 25), fld(x * 360 + -8, 25)) + ck.verify(fld(x * 4 + y, 2), x * 2 + fld(y, 2)) ck.verify(fld(tvm.te.min(x * 6, y), 2), tvm.te.min(x * 3, fld(y, 2))) ck.verify(fld(tvm.te.max(x * 6, y), 2), tvm.te.max(x * 3, fld(y, 2))) - ck.verify(fld(y + x * 4, 2), fld(y, 2) + x * 2) + ck.verify(fld(y + x * 4, 2), x * 2 + fld(y, 2)) ck.verify(fld(tvm.te.min(y, x * 6), 2), tvm.te.min(fld(y, 2), x * 3)) ck.verify(fld(tvm.te.max(y, x * 6), 2), tvm.te.max(fld(y, 2), x * 3)) diff --git a/tests/python/unittest/test_tir_schedule_compute_at.py b/tests/python/unittest/test_tir_schedule_compute_at.py index 9204ee90d7bb..c7552942cf7e 100644 --- a/tests/python/unittest/test_tir_schedule_compute_at.py +++ b/tests/python/unittest/test_tir_schedule_compute_at.py @@ -1273,7 +1273,7 @@ def grouped_channel_bias_non_perfect_tiled( for c_o in range(2): for ax0 in range(23): with T.block("init"): - vi = T.axis.spatial(45, c_o * 360 // 16 + ax0) + vi = T.axis.spatial(45, c_o * 22 + ax0) B[vi] = vi for h, w, c_i in T.grid(8, 8, 360): with T.block("compute"): @@ -1287,10 +1287,7 @@ def check_sched(debug_mask): sch.compute_at(sch.get_block("init"), loop) tvm.ir.assert_structural_equal(sch.mod["main"], grouped_channel_bias_non_perfect_tiled) - check_sched("none") - with pytest.raises(tvm.TVMError, match="region_cover"): - # TODO: try fix region cover proof - check_sched("all") + check_sched("all") def test_fail_subtree_complete_block(): From 4344b68f8ca0aa37753eb06b85918fd5d5a68db1 Mon Sep 17 00:00:00 2001 From: wrongtest Date: Wed, 25 May 2022 15:14:15 +0800 Subject: [PATCH 4/6] adjust related interfaces for IterMapLevel --- include/tvm/arith/iter_affine_map.h | 16 +- python/tvm/arith/iter_affine_map.py | 51 +- src/arith/iter_affine_map.cc | 145 ++-- src/tir/ir/index_map.cc | 9 +- src/tir/schedule/analysis/layout.cc | 8 +- .../schedule/primitive/blockize_tensorize.cc | 7 +- .../schedule/primitive/loop_transformation.cc | 2 +- .../unittest/test_arith_iter_affine_map.py | 619 ++++++++---------- 8 files changed, 397 insertions(+), 460 deletions(-) diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h index 7881b7ca714f..aea67d54584c 100644 --- a/include/tvm/arith/iter_affine_map.h +++ b/include/tvm/arith/iter_affine_map.h @@ -263,10 +263,10 @@ class IterSumExpr : public IterMapExpr { enum IterMapLevel { // Require the mapping to be bijective. Bijective = 0, - // Require the mapping to be subjective. + // Require the mapping to be surjective. Surjective = 1, - // Require the mapping to be injective. - Injective = 2 + // No mapping safety check. + NoCheck = 3 }; /*! @@ -327,7 +327,7 @@ class IterMapResult : public ObjectRef { * \param indices The indices to detect pattern for. * \param input_iters Map from variable to iterator's range. * \param predicate The predicate constraints on the input iterators - * \param check_level The iter mapping check level. + * \param check_level The iter mapping checking level. * \param analyzer Analyzer used to get context information. * \param simplify_trivial_iterators If true, iterators with extent of * 1 will be replaced with a constant value. @@ -345,12 +345,12 @@ IterMapResult DetectIterMap(const Array& indices, const Map IterMapSimplify(const Array& indices, const Map& input_iters, - const PrimExpr& input_pred, bool require_bijective); + const PrimExpr& input_pred, IterMapLevel check_level); /*! * \brief Apply the inverse of the affine transformation to the outputs. @@ -390,7 +390,7 @@ Map InverseAffineIterMap(const Array& iter_map, * \param input_iters Map from variable to iterator's range. * \param sub_iters Iterators of subspace. * \param predicate The predicate constraints on the input iterators - * \param require_bijective A boolean flag that indicates whether the mapping should be bijective. + * \param check_level The iter mapping checking level. * \param analyzer Analyzer used to get context information. * * \return The result list has length len(bindings) + 1 @@ -403,7 +403,7 @@ Map InverseAffineIterMap(const Array& iter_map, Array> SubspaceDivide(const Array& bindings, const Map& input_iters, const Array& sub_iters, const PrimExpr& predicate, - bool require_bijective, arith::Analyzer* analyzer); + IterMapLevel check_level, arith::Analyzer* analyzer); /*! * \brief Given an expression that may contain IterMapExpr, transform it to normal PrimExpr. diff --git a/python/tvm/arith/iter_affine_map.py b/python/tvm/arith/iter_affine_map.py index 701474a92b6a..77d6f418b853 100644 --- a/python/tvm/arith/iter_affine_map.py +++ b/python/tvm/arith/iter_affine_map.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. """ Iterator (quasi)affine mapping patterns.""" +from enum import IntEnum import tvm._ffi from tvm.runtime import Object from tvm.ir import PrimExpr @@ -88,11 +89,35 @@ def __init__(self, args, base): self.__init_handle_by_constructor__(_ffi_api.IterSumExpr, args, base) +class IterMapLevel(IntEnum): + """Possible kinds of iter mapping check level.""" + + Bijective = 0 + Surjective = 1 + NoCheck = 3 + + @staticmethod + def from_str(name: str): + """Helper to create level enum from string""" + if name is None: + return IterMapLevel.NoCheck + name = name.lower() + if name == "bijective": + check_level = IterMapLevel.Bijective + elif name == "surjective": + check_level = IterMapLevel.Surjective + elif name == "nocheck": + check_level = IterMapLevel.NoCheck + else: + raise ValueError(f"Unknown check level {name}") + return check_level + + def detect_iter_map( indices, input_iters, predicate=True, - require_bijective=False, + check_level=IterMapLevel.Surjective, simplify_trivial_iterators=True, ): """Detect if indices can be written as mapped iters from input iters @@ -108,8 +133,8 @@ def detect_iter_map( predicate : PrimExpr The predicate constraints on the input iterators - require_bijective : bool - A boolean flag that indicates whether the mapping should be bijective + check_level : Union[str, IterMapLevel] + Checking level of iteration mapping simplify_trivial_iterators: bool If true, iterators with extent of 1 will be replaced with a @@ -122,9 +147,13 @@ def detect_iter_map( The result's .indices is empty array if no match can be found. """ + if isinstance(check_level, str): + check_level = IterMapLevel.from_str(check_level) + elif check_level is None: + check_level = IterMapLevel.NoCheck return _ffi_api.DetectIterMap( - indices, input_iters, predicate, require_bijective, simplify_trivial_iterators - ).indices + indices, input_iters, predicate, check_level, simplify_trivial_iterators + ) def normalize_iter_map_to_expr(expr): @@ -143,7 +172,9 @@ def normalize_iter_map_to_expr(expr): return _ffi_api.NormalizeIterMapToExpr(expr) -def subspace_divide(bindings, input_iters, sub_iters, predicate=True, require_bijective=False): +def subspace_divide( + bindings, input_iters, sub_iters, predicate=True, check_level=IterMapLevel.Surjective +): """Detect if bindings can be written as [a_0*e_0 + b_0 + c_0, a_1*e_1 + b_1, ..., a_n*e_n + b_n] where a = some-quasi-affine-iter-map(input_iters set_minus sub_iters) @@ -172,8 +203,8 @@ def subspace_divide(bindings, input_iters, sub_iters, predicate=True, require_bi predicate : PrimExpr The predicate constraints on the input iterators - require_bijective : bool - A boolean flag that indicates whether the bindings should be bijective + check_level : Union[str, IterMapLevel] + Checking level of iteration mapping Returns ------- @@ -185,7 +216,9 @@ def subspace_divide(bindings, input_iters, sub_iters, predicate=True, require_bi len(bindings): the predicate of outer space and inner space Empty array if no match can be found. """ - return _ffi_api.SubspaceDivide(bindings, input_iters, sub_iters, predicate, require_bijective) + if isinstance(check_level, str): + check_level = IterMapLevel.from_str(check_level) + return _ffi_api.SubspaceDivide(bindings, input_iters, sub_iters, predicate, check_level) def inverse_affine_iter_map(iter_map, outputs): diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index e29afc4c82d8..6b724ca1d6e2 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -178,10 +178,7 @@ class IterMapRewriter : public ExprMutator { explicit IterMapRewriter(Analyzer* analyzer, const Map& input_iters, bool simplify_trivial_iterators, Array* errors) - : analyzer_(analyzer), - errors_(*errors), - requires_padding_(const_false()), - padding_predicate_(const_false()) { + : analyzer_(analyzer), errors_(*errors), padding_predicate_(const_false()) { for (auto kv : input_iters) { const Var& var = kv.first; const Range& vrng = kv.second; @@ -202,7 +199,7 @@ class IterMapRewriter : public ExprMutator { } PrimExpr padding_predicate() const { return padding_predicate_; } - PrimExpr requires_padding() const { return requires_padding_; } + PrimExpr requires_padding() const { return !analyzer_->CanProveEqual(padding_predicate_, 0); } IterSumExpr Rewrite(const PrimExpr& expr) { return NormalizeToIterWithOffset(ToIterSumExpr(DirectMutate(expr))); @@ -222,7 +219,7 @@ class IterMapRewriter : public ExprMutator { } /*! - * \brief If require_bijective is true, this function checks two conditions: + * \brief If require bijective mapping, this function checks two conditions: * - C0: Each iter mark should be fully covered by non-overlapping splits. * - C1: All of the input iterators are used. * Example: given x in [0, 8) y in [0, 6) @@ -232,7 +229,7 @@ class IterMapRewriter : public ExprMutator { * contribute two non-overlapping splits that covers x. * - bindings = [x / 4, x % 4] won't pass because y is not used. * - * If require_bijective is false, this function checks one condition: + * If only require surjective mapping, this function checks one condition: * - C0: Each iter mark has a chance to be fully covered by non-overlapping splits. * Example: given x in [0, 8) y in [0, 6) * - bindings = [x / 4] will pass because x / 4 can be one split of x @@ -378,10 +375,11 @@ class IterMapRewriter : public ExprMutator { // GCD of padding factor collected during first pass PrimExpr padding_factor{1}; - // Defined on first encounter in second pass + PrimExpr left_pad{0}; + PrimExpr right_pad{0}; + + // Padded form of original iter mark IterMark padded; - PrimExpr left_pad; - PrimExpr right_pad; }; // temp hash for de-duplication purposes. @@ -446,20 +444,6 @@ class IterMapRewriter : public ExprMutator { */ bool update_iterator_padding_{false}; - /* A boolean expression that is true if any padding has been introduced - * by the transformation, and false otherwise. - * - * Example: [i//4, i%4], i in range [0,16) - * requires_padding_ will be false - * - * Example: [i//4, i%4], i in range [0,18) - * requires_padding_ will be true - * - * Example: [i//4, i%4], i in range [0,N) - * requires_padding_ will be the expression N%4==0 - */ - PrimExpr requires_padding_; - /* A boolean expression that is true for any padding that has been * introduced, and false otherwise. If allow_padding_ is false, * padding_predicate_ will always be false. @@ -614,12 +598,11 @@ class IterMapRewriter : public ExprMutator { // // Case 3. bijective is not required and there exists padding. We check either // (3.1) The extent we calculate is consistent with the extent of the padded mark and it is - // single split. + // the single split for the iter mark. // For example, padded iter p in [0, 24), [(p / 12)] is valid because it is surjective // according to how we pad the original iteration mark. // (3.2) The extent we calculate is a factor of the extent of the padded mark, and the extent - // before - // padding is greater or equal than the extent we calculate. + // before padding is greater or equal than the extent we calculate. // For example, padded iter p in [0, 24), the original extent is 14, [(p % 12)] is // valid. // @@ -1150,11 +1133,10 @@ IterMapResult DetectIterMap(const Array& indices, const Map& indices, const Map& input_iters, - const PrimExpr& input_pred, bool is_bijective, + const PrimExpr& input_pred, int check_level, bool simplify_trivial_iterators) { arith::Analyzer ana; - auto check_level = is_bijective ? IterMapLevel::Bijective : IterMapLevel::Surjective; - return DetectIterMap(indices, input_iters, input_pred, check_level, &ana, + return DetectIterMap(indices, input_iters, input_pred, IterMapLevel(check_level), &ana, simplify_trivial_iterators); }); @@ -1303,7 +1285,8 @@ IterSumExpr IterMapRewriter::PreprocessDividend(IterMapExpr dividend, PrimExpr o } } -PrimExpr NearLeastCommonMultiple(const PrimExpr& a, const PrimExpr& b, Analyzer* analyzer) { +/*! \brief Find approximate least common multiplier. */ +PrimExpr ApproxLeastCommonMultiple(const PrimExpr& a, const PrimExpr& b, Analyzer* analyzer) { auto fsplit = [](const PrimExpr& e) -> std::pair { if (const IntImmNode* imm = e.as()) { return {1, imm->value}; @@ -1316,7 +1299,6 @@ PrimExpr NearLeastCommonMultiple(const PrimExpr& a, const PrimExpr& b, Analyzer* return {e, 1}; } }; - auto p1 = fsplit(a); auto p2 = fsplit(b); auto const_lcm = Integer(LeastCommonMultiple(p1.second, p2.second)); @@ -1333,99 +1315,102 @@ std::pair IterMapRewriter::PadDividendToDivisor(IterSpl // If FloorDiv: (((source//lower_factor) % extent) + base) // divisor // If FloorMod: (((source//lower_factor) % extent) + base) % divisor - // Update current iteration split's padding. // First, adding any padding that is on the lower side of a - // FloorDiv/FloorMod, such that floormod(iter-left_pad,divisor) == 0 - // when iter==0. + // FloorDiv/FloorMod, such that floormod(split - left_pad, divisor) == 0 + // when iter == 0. PrimExpr left_pad = analyzer_->Simplify(floormod(base, divisor)); // Next, adding any padding that is on the upper side of a - // FloorDiv/FloorMod, such that floormod(left_pad + iter + right_pad, divisor) == 0 - // when iter==extent. + // FloorDiv/FloorMod, such that floormod(left_pad + split + right_pad, divisor) == 0 + // when iter == extent. PrimExpr right_edge = left_pad + split->extent; PrimExpr right_pad; if (CanProveDivisible(right_edge, divisor)) { right_pad = 0; } else { - right_pad = analyzer_->Simplify(floormod(-right_edge, divisor), 9); + right_pad = analyzer_->Simplify(floormod(-right_edge, divisor)); } + const IterMark& mark = split->source; if (update_iterator_padding_) { - IterMark mark = split->source; + // In the first pass, the primary goal is to collect all the divisors + // that may be used for padding. These will impact the divisor used + // to determine padding in the second pass. We try add padding to + // split's source iteraton mark thus all splits under the same mark will + // share the same padded source iteration. auto& info = padded_iter_map_[mark]; info.padding_factor = - NearLeastCommonMultiple(info.padding_factor, divisor * split->lower_factor, analyzer_); + ApproxLeastCommonMultiple(info.padding_factor, divisor * split->lower_factor, analyzer_); + // If the split itself require no padding, return directly. if (is_zero(left_pad) && is_zero(right_pad)) { return {split, 0}; } - // In the first pass, the primary goal is to collect all the divisors - // that may be used for padding. These will impact the divisor used - // to determine padding in the second pass. - PrimExpr padded_extent = analyzer_->Simplify(left_pad + split->extent + right_pad); - + // Update padding requirement on the lower side of the source iter mark. PrimExpr mark_left_pad = left_pad * split->lower_factor; - if (!is_zero(left_pad)) { - if (info.left_pad.defined()) { - info.left_pad = max(info.left_pad, mark_left_pad); - } else { - info.left_pad = mark_left_pad; - } - } + info.left_pad = max(info.left_pad, mark_left_pad); + + // Since we only care the extent in the first pass's result + // we just create result of compatible padded extent, ignoring + // possible relations between different padded iters. + PrimExpr padded_extent = analyzer_->Simplify(left_pad + split->extent + right_pad); split.CopyOnWrite()->extent = padded_extent; return {split, left_pad}; } - // In the second pass, update iteration mark's to padded - const IterMark& mark = split->source; + // In the second pass, update iteration mark's to padded form auto it = padded_iter_map_.find(mark); if (it == padded_iter_map_.end()) { return {split, left_pad}; } auto& info = it->second; - if (is_zero(info.left_pad.defined() ? info.left_pad : 0) && - CanProveDivisible(mark->extent, info.padding_factor)) { + if (is_zero(info.left_pad) && CanProveDivisible(mark->extent, info.padding_factor)) { + // the iter mark requires no padding return {split, left_pad}; } + // check that padding factor is compatible with current split and divisor + ICHECK(CanProveDivisible(info.padding_factor, split->lower_factor)) + << "The padding factor " << info.padding_factor << " is not divisible by " + << split->lower_factor << " for the split " << split; + ICHECK(CanProveDivisible(info.padding_factor, divisor)) + << "The padding factor " << info.padding_factor << " is not divisible by " << divisor + << " for the split " << split; + if (!info.padded.defined()) { - PrimExpr mark_left_pad = info.left_pad.defined() ? info.left_pad : 0; + // the first time encounter the iter mark to pad, update the padded mark. + PrimExpr mark_left_pad = info.left_pad; + PrimExpr right_edge = mark->extent + mark_left_pad; PrimExpr mark_right_pad; - if (CanProveDivisible(mark->extent + mark_left_pad, info.padding_factor)) { + if (CanProveDivisible(right_edge, info.padding_factor)) { mark_right_pad = 0; } else { - mark_right_pad = floormod(-(mark->extent + mark_left_pad), info.padding_factor); + mark_right_pad = floormod(-right_edge, info.padding_factor); } - PrimExpr padded_extent = analyzer_->Simplify(mark_left_pad + mark->extent + mark_right_pad); + PrimExpr padded_extent = analyzer_->Simplify(right_edge + mark_right_pad); info.right_pad = mark_right_pad; info.padded = IterMark(IterSumExpr({IterSplitExpr(mark)}, mark_left_pad), padded_extent); auto left_padding_introduced = (mark_left_pad != 0); - PrimExpr divisor = info.padding_factor; - PrimExpr right_edge = mark_left_pad + mark->extent; // Equivalent to (0 <= split < left_pad), but easier to simplify in // terms of the transformed variables. auto left_padding_predicate = - left_padding_introduced && (floordiv(info.padded->source, divisor) == 0 && - floormod(info.padded->source, divisor) < mark_left_pad); - + left_padding_introduced && + (floordiv(info.padded->source, info.padding_factor) == 0 && + floormod(info.padded->source, info.padding_factor) < mark_left_pad); auto right_padding_introduced = (mark_right_pad != 0); - // Equivalent to (right_edge <= split < right_edge+right_pad), but + // Equivalent to (right_edge <= split < right_edge + right_pad), but // easier to simplify in terms of the transformed variables. auto right_padding_predicate = - right_padding_introduced && - (floordiv(info.padded->source, divisor) == floordiv(right_edge, divisor) && - floormod(info.padded->source, divisor) >= floormod(right_edge, divisor)); - - requires_padding_ = requires_padding_ || (left_padding_introduced || right_padding_introduced); + right_padding_introduced && (floordiv(info.padded->source, info.padding_factor) == + floordiv(right_edge, info.padding_factor) && + floormod(info.padded->source, info.padding_factor) >= + floormod(right_edge, info.padding_factor)); padding_predicate_ = padding_predicate_ || (left_padding_predicate || right_padding_predicate); } - // ICHECK(CanProveDivisible(info.padded->extent, split->lower_factor)); - // ICHECK(CanProveDivisible(info.padded->extent, divisor)) << info.padded->extent << " " << - // divisor; split.CopyOnWrite()->source = info.padded; split.CopyOnWrite()->extent = floordiv(info.padded->extent, split->lower_factor); return {split, left_pad}; @@ -1703,10 +1688,9 @@ PrimExpr NormalizeIterMapToExpr(const PrimExpr& expr) { TVM_REGISTER_GLOBAL("arith.NormalizeIterMapToExpr").set_body_typed(NormalizeIterMapToExpr); Array IterMapSimplify(const Array& indices, const Map& input_iters, - const PrimExpr& input_pred, bool require_bijective) { + const PrimExpr& input_pred, IterMapLevel check_level) { if (!IterRangeSanityCheck(input_iters)) return indices; Analyzer analyzer; - auto check_level = require_bijective ? IterMapLevel::Bijective : IterMapLevel::Surjective; auto res = DetectIterMap(indices, input_iters, input_pred, check_level, &analyzer); Array rewrite = res->indices; @@ -1997,9 +1981,8 @@ class SubspaceDivider { Array> SubspaceDivide(const Array& bindings, const Map& input_iters, const Array& sub_iters, const PrimExpr& predicate, - bool require_bijective, arith::Analyzer* analyzer) { + IterMapLevel check_level, arith::Analyzer* analyzer) { if (!IterRangeSanityCheck(input_iters)) return Array>(); - auto check_level = require_bijective ? IterMapLevel::Bijective : IterMapLevel::Surjective; auto res = DetectIterMap(bindings, input_iters, predicate, check_level, analyzer); const Array& maps = res->indices; if (maps.empty()) return {}; @@ -2028,10 +2011,10 @@ Array> SubspaceDivide(const Array& bindings, TVM_REGISTER_GLOBAL("arith.SubspaceDivide") .set_body_typed([](const Array& bindings, const Map& root_iters, - const Array& sub_iters, const PrimExpr& predicate, - bool require_bijective) { + const Array& sub_iters, const PrimExpr& predicate, int check_level) { arith::Analyzer ana; - return SubspaceDivide(bindings, root_iters, sub_iters, predicate, require_bijective, &ana); + return SubspaceDivide(bindings, root_iters, sub_iters, predicate, IterMapLevel(check_level), + &ana); }); class InverseAffineIterMapTransformer { diff --git a/src/tir/ir/index_map.cc b/src/tir/ir/index_map.cc index 757c67a17f7f..ba329676b1c3 100644 --- a/src/tir/ir/index_map.cc +++ b/src/tir/ir/index_map.cc @@ -76,10 +76,9 @@ std::pair IndexMap::NonSurjectiveInverse(Array initia // Unpack the output indices into linear combinations of the initial // indices. arith::Analyzer analyzer; - auto padded_iter_map = - DetectIterMap((*this)->final_indices, input_iters, /* predicate = */ 1, - /* check_level = */ arith::IterMapLevel::Injective, &analyzer, - /* simplify_trivial_iterators = */ false); + auto padded_iter_map = DetectIterMap((*this)->final_indices, input_iters, /* predicate = */ 1, + /*check_level=*/arith::IterMapLevel::NoCheck, &analyzer, + /*simplify_trivial_iterators=*/false); CHECK(padded_iter_map->errors.empty()) << "Could not parse mapping as sum of iterators. " << "Error: " << padded_iter_map->errors[0]; @@ -141,7 +140,7 @@ IndexMap IndexMap::Inverse(Array initial_ranges) const { // indices. arith::Analyzer analyzer; auto iter_map = DetectIterMap((*this)->final_indices, input_iters, /* predicate = */ 1, - /* require_bijective = */ arith::IterMapLevel::Bijective, &analyzer, + /* check_level = */ arith::IterMapLevel::Bijective, &analyzer, /* simplify_trivial_iterators = */ false); CHECK(iter_map->indices.size()) << "Index transformation was not bijective."; diff --git a/src/tir/schedule/analysis/layout.cc b/src/tir/schedule/analysis/layout.cc index 8c6b9d7072f0..b0cafac3151f 100644 --- a/src/tir/schedule/analysis/layout.cc +++ b/src/tir/schedule/analysis/layout.cc @@ -68,17 +68,15 @@ class SplitExprCollector { * \param index The indexing pattern * \param input_iters The input iterators' domain * \param predicate The predicate of the affine map - * \param require_bijective Whether the affine map is required to be bijective + * \param check_level The iter mapping checking level * \param analyzer The analyzer * \return The collected split expressions */ static std::vector Collect(const PrimExpr& index, const Map& input_iters, // const PrimExpr& predicate, // - bool require_bijective, // + arith::IterMapLevel check_level, // arith::Analyzer* analyzer) { - auto check_level = - require_bijective ? arith::IterMapLevel::Bijective : arith::IterMapLevel::Surjective; arith::IterMapResult res = arith::DetectIterMap({analyzer->Simplify(index)}, input_iters, predicate, check_level, analyzer); const auto& iter_sum_exprs = res->indices; @@ -152,7 +150,7 @@ Optional SuggestIndexMap(const Buffer& buffer, const Array& // Step 3. Detect the IterSplitExpr of the indexing pattern std::vector split_exprs = SplitExprCollector::Collect( /*index=*/f_flatten_index(indices), input_iters, predicate, - /*require_bijective=*/false, analyzer); + /*check_level=*/arith::IterMapLevel::Surjective, analyzer); if (split_exprs.empty()) { return NullOpt; } diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc b/src/tir/schedule/primitive/blockize_tensorize.cc index 7ed80a1c5b8f..4ede2dd90da8 100644 --- a/src/tir/schedule/primitive/blockize_tensorize.cc +++ b/src/tir/schedule/primitive/blockize_tensorize.cc @@ -258,10 +258,9 @@ Array> CheckSubspaceDivisible(const IRModule& mod, arith::Analyzer* analyzer) { const Block& block = block_realize->block; - Array> division = - arith::SubspaceDivide(block_realize->iter_values, collector.loop_var_domain, - collector.inner_loop_vars, block_realize->predicate, - /*require_bijective=*/false, analyzer); + Array> division = arith::SubspaceDivide( + block_realize->iter_values, collector.loop_var_domain, collector.inner_loop_vars, + block_realize->predicate, arith::IterMapLevel::Surjective, analyzer); if (division.empty()) { // If we can't do perfect subspace division, check if it is a trivial case of subspace division. diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index dbe6a3bbc0c5..5315b139f0f6 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -115,7 +115,7 @@ class IterMapSimplifyBlockBinding : public StmtExprMutator { Array v = arith::IterMapSimplify(/*indices=*/op->iter_values, /*input_iters=*/loop_var2extent_, /*input_pred=*/op->predicate, - /*require_bijective=*/false); + /*check_level=*/arith::IterMapLevel::Surjective); if (v.same_as(op->iter_values)) { return GetRef(op); } else { diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py index f1dde933179e..ac3ae00b8a0c 100644 --- a/tests/python/unittest/test_arith_iter_affine_map.py +++ b/tests/python/unittest/test_arith_iter_affine_map.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from xml import dom import tvm import tvm.testing from tvm.tir import floormod, floordiv @@ -47,56 +48,67 @@ def convert_iter_expr(expr): return tvm.arith.normalize_iter_map_to_expr(expr) -def assert_iter_sum_pattern(sum_expr, extent, base, scale=1): - """Check the sum expr have the right pattern.""" - assert isinstance(sum_expr, tvm.arith.IterSumExpr) - if extent == 1: - assert len(sum_expr.args) == 0 - else: - assert len(sum_expr.args) == 1 - tvm.testing.assert_prim_expr_equal(sum_expr.args[0].extent, extent) - tvm.testing.assert_prim_expr_equal(sum_expr.args[0].scale, scale) - tvm.testing.assert_prim_expr_equal(sum_expr.base, base) +def assert_iter_sum_pattern( + expect_dict, dom_map, predicate=True, check_level="surjective", simplify_trivial_iterators=True +): + keys = list(expect_dict.keys()) + res = tvm.arith.detect_iter_map( + keys, + dom_map, + predicate=predicate, + check_level=check_level, + simplify_trivial_iterators=simplify_trivial_iterators, + ).indices + assert len(res) == len(keys) + for i, input_iter in enumerate(keys): + spec = expect_dict[input_iter] + ( + extent, + base, + ) = spec[0:2] + scale = spec[2] if len(spec) > 2 else 1 + expect_iter = spec[3] if len(spec) > 3 else None + sum_expr = res[i] + assert isinstance(sum_expr, tvm.arith.IterSumExpr) + if extent == 1: + assert len(sum_expr.args) == 0 + else: + assert len(sum_expr.args) == 1 + tvm.testing.assert_prim_expr_equal(sum_expr.args[0].extent, extent) + tvm.testing.assert_prim_expr_equal(sum_expr.args[0].scale, scale) + tvm.testing.assert_prim_expr_equal(sum_expr.base, base) + if expect_iter is not None: + if not isinstance(expect_iter, tvm.arith.IterMapExpr): + sum_expr = convert_iter_expr(sum_expr) + tvm.ir.assert_structural_equal(sum_expr, expect_iter) + + +def assert_iter_sum_failure(iters, dom_map, predicate=True, check_level="surjective"): + res = tvm.arith.detect_iter_map( + list(iters), dom_map, predicate=predicate, check_level=check_level + ).indices + assert len(res) == 0 def test_trivial(): - x = tvm.tir.Var("x", "int32"), 3 - y = tvm.tir.Var("y", "int32"), 4 - z = tvm.tir.Var("z", "int32"), 1 - - res = tvm.arith.detect_iter_map([x[0], y[0], 3], var_dom([x, y])) - - assert len(res) == 3 - assert_iter_sum_pattern(res[0], 3, 0) - assert_iter_sum_pattern(res[1], 4, 0) - assert_iter_sum_pattern(res[2], 1, 3) + x = tvm.tir.Var("x", "int32") + y = tvm.tir.Var("y", "int32") + z = tvm.tir.Var("z", "int32") + dom_map = var_dom([(x, 3), (y, 4), (z, 1)]) - res = tvm.arith.detect_iter_map([x[0], 3], var_dom([x, y])) - assert len(res) == 2 - assert_iter_sum_pattern(res[0], 3, 0) - assert_iter_sum_pattern(res[1], 1, 3) + assert_iter_sum_pattern({x: (3, 0), y: (4, 0), 3: (1, 3)}, dom_map) + assert_iter_sum_pattern({x: (3, 0), 3: (1, 3)}, dom_map) # not independent - res = tvm.arith.detect_iter_map([x[0], x[0], 3], var_dom([x, y])) - assert len(res) == 0 + assert_iter_sum_failure([x, x, 3], dom_map) - res = tvm.arith.detect_iter_map( - [x[0], y[0]], var_dom([x, y, z]), require_bijective=True, simplify_trivial_iterators=True + assert_iter_sum_pattern( + {x: (3, 0), y: (4, 0)}, dom_map, check_level="bijective", simplify_trivial_iterators=True ) - assert len(res) == 2 - assert_iter_sum_pattern(res[0], 3, 0) - assert_iter_sum_pattern(res[1], 4, 0) - - res = tvm.arith.detect_iter_map( - [x[0], y[0]], var_dom([x, y, z]), require_bijective=True, simplify_trivial_iterators=False + assert_iter_sum_pattern( + {x: (3, 0), y: (4, 0)}, dom_map, check_level="bijective", simplify_trivial_iterators=False ) - assert len(res) == 2 - assert_iter_sum_pattern(res[0], 3, 0) - assert_iter_sum_pattern(res[1], 4, 0) - - # not bijective - res = tvm.arith.detect_iter_map([x[0], z[0]], var_dom([x, y, z]), require_bijective=True) - assert len(res) == 0 + assert_iter_sum_failure([x, z], dom_map, check_level="bijective") def test_fuse(): @@ -105,42 +117,27 @@ def test_fuse(): c = tvm.tir.SizeVar("c", "int32") c0 = tvm.tir.SizeVar("c0", "int32") - res = tvm.arith.detect_iter_map([y * 3 + 1 + c + x], var_dom([(x, 3), (y, 4)])) - assert len(res) == 1 - assert_iter_sum_pattern(res[0], 12, 1 + c) + assert_iter_sum_pattern({y * 3 + 1 + c + x: (12, 1 + c)}, var_dom([(x, 3), (y, 4)])) - res = tvm.arith.detect_iter_map([ifuse([(x, 3), (y, 4)])[0]], var_dom([(x, 3), (y, 4)])) - assert len(res) == 1 - assert_iter_sum_pattern(res[0], 12, 0) + assert_iter_sum_pattern({ifuse([(x, 3), (y, 4)])[0]: (12, 0)}, var_dom([(x, 3), (y, 4)])) # fuse with symbolic factor - res = tvm.arith.detect_iter_map([(y + 1) * c + x], var_dom([(x, c), (y, 4)])) - assert len(res) == 1 - assert_iter_sum_pattern(res[0], 4 * c, c) + assert_iter_sum_pattern({(y + 1) * c + x: (4 * c, c)}, var_dom([(x, c), (y, 4)])) # duplication - res = tvm.arith.detect_iter_map([y * 3 + x, y], var_dom([(x, 3), (y, 4)])) - assert len(res) == 0 - - # duplication 2 - res = tvm.arith.detect_iter_map([y, x + 1, y], var_dom([(x, 3), (y, 4)])) - assert len(res) == 0 + assert_iter_sum_failure([y * 3 + x, y], var_dom([(x, 3), (y, 4)])) + assert_iter_sum_failure([y, x + 1, y], var_dom([(x, 3), (y, 4)])) # factor mismatch - res = tvm.arith.detect_iter_map([y * 4 + x], var_dom([(x, 3), (y, 4)])) - assert len(res) == 0 + assert_iter_sum_failure([y * 4 + x], var_dom([(x, 3), (y, 4)])) # simple stride pattern - res = tvm.arith.detect_iter_map([x * 4 + y * 2], var_dom([(x, 3), (y, 2)])) - assert len(res) == 1 - assert_iter_sum_pattern(res[0], 6, 0, scale=2) - tvm.ir.assert_structural_equal(convert_iter_expr(res[0]), (x * 2 + y) * 2) + assert_iter_sum_pattern({x * 4 + y * 2: (6, 0, 2, (x * 2 + y) * 2)}, var_dom([(x, 3), (y, 2)])) # simple stride pattern with symbolic - res = tvm.arith.detect_iter_map([x * 2 * c0 + y * 2], var_dom([(x, 3), (y, c0)])) - assert len(res) == 1 - assert_iter_sum_pattern(res[0], 3 * c0, 0, scale=2) - tvm.ir.assert_structural_equal(convert_iter_expr(res[0]), (x * c0 + y) * 2) + assert_iter_sum_pattern( + {x * 2 * c0 + y * 2: (3 * c0, 0, 2, (x * c0 + y) * 2)}, var_dom([(x, 3), (y, c0)]) + ) def test_split(): @@ -151,171 +148,138 @@ def test_split(): fld = tvm.tir.floordiv flm = tvm.tir.floormod - res = tvm.arith.detect_iter_map([fld(x, 3), flm(x, 3) * 2 + c1], var_dom([(x, 24)])) - - assert len(res) == 2 - assert_iter_sum_pattern(res[0], 8, 0) - assert_iter_sum_pattern(res[1], 3, c1, 2) - - res = tvm.arith.detect_iter_map([fld(x, 6), fld(flm(x, 6), 2), flm(x, 2)], var_dom([(x, 24)])) + assert_iter_sum_pattern({fld(x, 3): (8, 0), flm(x, 3) * 2 + c1: (3, c1, 2)}, var_dom([(x, 24)])) - assert len(res) == 3 - assert_iter_sum_pattern(res[0], 4, 0) - assert_iter_sum_pattern(res[1], 3, 0) - assert_iter_sum_pattern(res[2], 2, 0) + assert_iter_sum_pattern( + {fld(x, 6): (4, 0), fld(flm(x, 6), 2): (3, 0), flm(x, 2): (2, 0)}, var_dom([(x, 24)]) + ) # simple symbolic bound # TODO(tvm-team) improve symbolic divisible check to enable # more complicated symbolic bound - res = tvm.arith.detect_iter_map([fld(x, c0), flm(x, c0)], var_dom([(x, c1 * c0)])) - - assert len(res) == 2 - assert_iter_sum_pattern(res[0], c1, 0) - assert_iter_sum_pattern(res[1], c0, 0) + assert_iter_sum_pattern({fld(x, c0): (c1, 0), flm(x, c0): (c0, 0)}, var_dom([(x, c1 * c0)])) - res = tvm.arith.detect_iter_map([fld(x * 2, 4), flm(x * 2, 4)], var_dom([(x, 8)])) + assert_iter_sum_pattern({fld(x * 2, 4): (4, 0, 1), flm(x * 2, 4): (2, 0, 2)}, var_dom([(x, 8)])) - assert len(res) == 2 - assert_iter_sum_pattern(res[0], 4, 0, scale=1) - assert_iter_sum_pattern(res[1], 2, 0, scale=2) - - res = tvm.arith.detect_iter_map([fld(x * 2, 4) * 4 + flm(x * 2, 4)], var_dom([(x, 8)])) - - assert len(res) == 1 - assert_iter_sum_pattern(res[0], 8, 0, scale=2) + assert_iter_sum_pattern( + { + fld(x * 2, 4) * 4 + flm(x * 2, 4): (8, 0, 2), + }, + var_dom([(x, 8)]), + ) - res = tvm.arith.detect_iter_map([fld(x, flm(flm(y, 8), 6))], var_dom([(x, 24), (y, 8)])) - assert len(res) == 0 + assert_iter_sum_failure([fld(x, flm(flm(y, 8), 6))], var_dom([(x, 24), (y, 8)])) def test_compound(): - x = tvm.tir.Var("x", "int32"), 10 - y = tvm.tir.Var("y", "int32"), 9 + x = tvm.tir.Var("x", "int32") + y = tvm.tir.Var("y", "int32") - xo, xi = isplit(x, 5) - yo, yi = isplit(y, 3) + xo, xi = isplit((x, 10), 5) + yo, yi = isplit((y, 9), 3) z = ifuse([yo, xo, yi]) - res = tvm.arith.detect_iter_map([z[0], xi[0]], var_dom([x, y])) - - assert len(res) == 2 - assert_iter_sum_pattern(res[0], 18, 0) - assert_iter_sum_pattern(res[1], 5, 0) # reconstruct the pattern manually - mx = tvm.arith.IterMark(x[0], 10) - my = tvm.arith.IterMark(y[0], 9) - + mx = tvm.arith.IterMark(x, 10) + my = tvm.arith.IterMark(y, 9) xoscale = 3 - xiscale = 1 yoscale = 6 yiscale = 1 mxo = tvm.arith.IterSplitExpr(mx, 5, 2, xoscale) - mxi = tvm.arith.IterSplitExpr(mx, 1, 5, xiscale) myo = tvm.arith.IterSplitExpr(my, 3, 3, yoscale) myi = tvm.arith.IterSplitExpr(my, 1, 3, yiscale) - mz = tvm.arith.IterMark(tvm.arith.IterSumExpr([myo, mxo, myi], 0), 18) sz = tvm.arith.IterSumExpr([tvm.arith.IterSplitExpr(mz, 1, 18, 1)], 0) - tvm.ir.assert_structural_equal(sz, res[0]) + assert_iter_sum_pattern({z[0]: (18, 0, 1, sz), xi[0]: (5, 0)}, var_dom([(x, 10), (y, 9)])) def test_predicate(): - x = tvm.tir.Var("x", "int32"), 13 - y = tvm.tir.Var("y", "int32"), 10 + x = tvm.tir.Var("x", "int32") + y = tvm.tir.Var("y", "int32") # available contraints # upper bound only - res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 + y[0] < 128) - assert len(res) == 1 - assert_iter_sum_pattern(res[0], 128, 0) - res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 + y[0] <= 127) - assert len(res) == 1 - assert_iter_sum_pattern(res[0], 128, 0) + assert_iter_sum_pattern( + {x * 10 + y: (128, 0)}, var_dom([(x, 13), (y, 10)]), predicate=x * 10 + y < 128 + ) + + assert_iter_sum_pattern( + {x * 10 + y: (128, 0)}, var_dom([(x, 13), (y, 10)]), predicate=x * 10 + y <= 127 + ) # lower bound only - res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 + y[0] > 5) - assert len(res) == 1 - assert_iter_sum_pattern(res[0], 124, 6) - res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 + y[0] >= 6) - assert len(res) == 1 - assert_iter_sum_pattern(res[0], 124, 6) + assert_iter_sum_pattern( + {x * 10 + y: (124, 6)}, var_dom([(x, 13), (y, 10)]), predicate=x * 10 + y > 5 + ) + + assert_iter_sum_pattern( + {x * 10 + y: (124, 6)}, var_dom([(x, 13), (y, 10)]), predicate=x * 10 + y >= 6 + ) # lower bound + upper bound - res = tvm.arith.detect_iter_map( - [x[0] * 10 + y[0]], - var_dom([x, y]), - tvm.tir.And(x[0] * 10 + y[0] > 5, x[0] * 10 + y[0] < 128), + assert_iter_sum_pattern( + {x * 10 + y: (122, 6)}, + var_dom([(x, 13), (y, 10)]), + predicate=tvm.tir.And(x * 10 + y > 5, x * 10 + y < 128), ) - assert len(res) == 1 - assert_iter_sum_pattern(res[0], 122, 6) - res = tvm.arith.detect_iter_map( - [x[0] * 10 + y[0]], - var_dom([x, y]), - tvm.tir.And(x[0] * 10 + y[0] >= 6, x[0] * 10 + y[0] <= 127), + + assert_iter_sum_pattern( + {x * 10 + y: (122, 6)}, + var_dom([(x, 13), (y, 10)]), + predicate=tvm.tir.And(x * 10 + y >= 6, x * 10 + y <= 127), ) - assert len(res) == 1 - assert_iter_sum_pattern(res[0], 122, 6) # constraint on one fused iter i = tvm.tir.Var("i", "int32") j = tvm.tir.Var("j", "int32") k = tvm.tir.Var("k", "int32") - res = tvm.arith.detect_iter_map( - [i * 8 + j * 2 + k], + assert_iter_sum_pattern( + {i * 8 + j * 2 + k: (88, 1)}, var_dom([(i, 11), (j, 5), (k, 2)]), - tvm.tir.all(1 <= j * 2 + k, j * 2 + k < 9), + predicate=tvm.tir.all(1 <= j * 2 + k, j * 2 + k < 9), ) - assert_iter_sum_pattern(res[0], 88, 1) # constraint on single var - res = tvm.arith.detect_iter_map([i], var_dom([(i, 48)]), tvm.tir.all(i < 10)) - assert_iter_sum_pattern(res[0], 10, 0) + assert_iter_sum_pattern({i: (10, 0)}, var_dom([(i, 48)]), predicate=i < 10) - # iterations are subparts of constraint, invalid, case 1 - res = tvm.arith.detect_iter_map( + # iterations are subparts of constraint, invalid case 1 + assert_iter_sum_failure( [i, j, k], var_dom([(i, 128), (j, 128), (k, 128)]), - tvm.tir.all(i * 16384 + j * 128 + k < 100), + predicate=tvm.tir.all(i * 16384 + j * 128 + k < 100), ) - assert len(res) == 0 - # iterations are subparts of constraint, invalid, case 2 - res = tvm.arith.detect_iter_map( + # iterations are subparts of constraint, invalid case 2 + assert_iter_sum_failure( [i * 128 + j, k], var_dom([(i, 128), (j, 128), (k, 128)]), - tvm.tir.all(i * 16384 + j * 128 + k < 100), + predicate=i * 16384 + j * 128 + k < 100, ) - assert len(res) == 0 # irrelavant predicate - res = tvm.arith.detect_iter_map( - [i + j], - var_dom([(i, 1)]), - j <= 24, - ) - assert_iter_sum_pattern(res[0], 1, j) + assert_iter_sum_pattern({i + j: (1, j)}, var_dom([(i, 1)]), predicate=j <= 24) # constraint on nested fused iters - res = tvm.arith.detect_iter_map( - [i * 8 + j * 2 + k], + assert_iter_sum_pattern( + {i * 8 + j * 2 + k: (22, 3)}, var_dom([(i, 11), (j, 5), (k, 2)]), - tvm.tir.all(1 <= j * 2 + k, j * 2 + k < 9, 3 <= i * 8 + j * 2 + k, i * 8 + j * 2 + k < 25), + predicate=tvm.tir.all( + 1 <= j * 2 + k, j * 2 + k < 9, 3 <= i * 8 + j * 2 + k, i * 8 + j * 2 + k < 25 + ), ) - assert_iter_sum_pattern(res[0], 22, 3) # duplicate constraint on one fused iter - res = tvm.arith.detect_iter_map( - [i * 6 + j * 2 + k], + assert_iter_sum_pattern( + {i * 6 + j * 2 + k: (66, 2)}, var_dom([(i, 11), (j, 5), (k, 2)]), - tvm.tir.all(1 <= j * 2 + k, 2 <= j * 2 + k, j * 2 + k < 8, j * 2 + k < 9), + predicate=tvm.tir.all(1 <= j * 2 + k, 2 <= j * 2 + k, j * 2 + k < 8, j * 2 + k < 9), ) - assert_iter_sum_pattern(res[0], 66, 2) # duplicate constraint on nested fused iters - res = tvm.arith.detect_iter_map( - [i * 6 + j * 2 + k], + assert_iter_sum_pattern( + {i * 6 + j * 2 + k: (15, 3)}, var_dom([(i, 11), (j, 5), (k, 2)]), - tvm.tir.all( + predicate=tvm.tir.all( 1 <= j * 2 + k, 2 <= j * 2 + k, j * 2 + k < 8, @@ -326,15 +290,13 @@ def test_predicate(): i * 6 + j * 2 + k < 18, ), ) - assert_iter_sum_pattern(res[0], 15, 3) # constraint on non-disjoint fused iters should fail - res = tvm.arith.detect_iter_map( + assert_iter_sum_failure( [i * 8 + j * 2 + k], var_dom([(i, 11), (j, 5), (k, 2)]), - tvm.tir.all(2 <= j * 2 + k, 0 <= i * 4 + j), + predicate=tvm.tir.all(2 <= j * 2 + k, 0 <= i * 4 + j), ) - assert len(res) == 0 # constraint on many disjoint fused iters, case 1 # i4 * 6 + i5 in [3, 9), extent=6 (= scale of i2) @@ -346,154 +308,135 @@ def test_predicate(): i3 = tvm.tir.Var("i3", "int32") i4 = tvm.tir.Var("i4", "int32") i5 = tvm.tir.Var("i5", "int32") - res = tvm.arith.detect_iter_map( - [i0 * 180 + i1 * 60 + i2 * 30 + i3 * 15 + i4 * 6 + i5], + assert_iter_sum_pattern( + {i0 * 180 + i1 * 60 + i2 * 30 + i3 * 15 + i4 * 6 + i5: (540, 93)}, var_dom([(i0, 3), (i1, 4), (i2, 3), (i3, 2), (i4, 3), (i5, 6)]), - tvm.tir.all(1 <= i1, 2 <= i2 * 2 + i3, 3 <= i4 * 6 + i5), + predicate=tvm.tir.all(1 <= i1, 2 <= i2 * 2 + i3, 3 <= i4 * 6 + i5), ) - assert_iter_sum_pattern(res[0], 540, 93) # constraint on many disjoint fused iters, case 2 - res = tvm.arith.detect_iter_map( - [i0 * 45 + i1 * 45 + i2 * 9 + i3 * 4 + i4], + assert_iter_sum_pattern( + {i0 * 45 + i1 * 45 + i2 * 9 + i3 * 4 + i4: (135, 28)}, var_dom([(i0, 3), (i1, 2), (i2, 5), (i3, 3), (i4, 4)]), - tvm.tir.all(3 <= i1 * 5 + i2, i1 * 5 + i2 < 8, 1 <= i3 * 4 + i4, i3 * 4 + i4 < 10), + predicate=tvm.tir.all( + 3 <= i1 * 5 + i2, i1 * 5 + i2 < 8, 1 <= i3 * 4 + i4, i3 * 4 + i4 < 10 + ), ) - assert_iter_sum_pattern(res[0], 135, 28) # constraint on split iters - res = tvm.arith.detect_iter_map( - [i % 16, i // 16], + assert_iter_sum_pattern( + {i % 16: (7, 3), i // 16: (8, 4)}, var_dom([(i, 1024)]), - tvm.tir.all(3 <= i % 16, i % 16 < 10, 4 <= i // 16, i // 16 < 12), - require_bijective=True, + predicate=tvm.tir.all(3 <= i % 16, i % 16 < 10, 4 <= i // 16, i // 16 < 12), + check_level="bijective", ) - assert_iter_sum_pattern(res[0], 7, 3) - assert_iter_sum_pattern(res[1], 8, 4) # constraint on split iters, nested case 1 - res = tvm.arith.detect_iter_map( - [(i * 32 + j) % 16], + assert_iter_sum_pattern( + {(i * 32 + j) % 16: (7, 3)}, var_dom([(i, 5), (j, 32)]), - tvm.tir.all(3 <= (i * 32 + j) % 16, (i * 32 + j) % 16 < 10), + predicate=tvm.tir.all(3 <= (i * 32 + j) % 16, (i * 32 + j) % 16 < 10), ) - assert_iter_sum_pattern(res[0], 7, 3) # constraint on split iters, nested case 2 - res = tvm.arith.detect_iter_map( - [(i * 32 + j) % 16], + assert_iter_sum_failure( + [ + (i * 32 + j) % 16, + ], var_dom([(i, 5), (j, 32)]), - tvm.tir.all(1 <= i * 32 + j, i * 32 + j <= 32), - require_bijective=True, + predicate=tvm.tir.all(1 <= i * 32 + j, i * 32 + j <= 32), + check_level="bijective", ) - assert len(res) == 0 - res = tvm.arith.detect_iter_map( - [(i * 32 + j) % 16], + assert_iter_sum_pattern( + {(i * 32 + j) % 16: (16, 0)}, var_dom([(i, 5), (j, 32)]), - tvm.tir.all(1 <= i * 32 + j, i * 32 + j <= 32), + predicate=tvm.tir.all(1 <= i * 32 + j, i * 32 + j <= 32), ) - assert_iter_sum_pattern(res[0], 16, 0) - res = tvm.arith.detect_iter_map( - [(i * 32 + j - 1) % 16, (i * 32 + j - 1) // 16], + assert_iter_sum_pattern( + {(i * 32 + j - 1) % 16: (16, 0), (i * 32 + j - 1) // 16: (4, 0)}, var_dom([(i, 5), (j, 32)]), - tvm.tir.all(1 <= i * 32 + j, i * 32 + j <= 64), + predicate=tvm.tir.all(1 <= i * 32 + j, i * 32 + j <= 64), ) - assert_iter_sum_pattern(res[0], 16, 0) - assert_iter_sum_pattern(res[1], 4, 0) # non-standard form of predicate - res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 < 128 - y[0]) - assert len(res) == 1 - assert_iter_sum_pattern(res[0], 128, 0) + assert_iter_sum_pattern( + {x * 10 + y: (128, 0)}, var_dom([(x, 13), (y, 10)]), predicate=x * 10 < 128 - y + ) # duplicate constraint - res = tvm.arith.detect_iter_map( - [x[0] * 10 + y[0]], - var_dom([x, y]), - tvm.tir.all(x[0] * 10 + y[0] < 128, x[0] * 10 + y[0] < 64), + assert_iter_sum_pattern( + {x * 10 + y: (64, 0)}, + var_dom([(x, 13), (y, 10)]), + predicate=tvm.tir.all(x * 10 + y < 128, x * 10 + y < 64), ) - assert len(res) == 1 - assert_iter_sum_pattern(res[0], 64, 0) - # useless constraint - res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 + y[0] < 140) - - assert len(res) == 1 - assert_iter_sum_pattern(res[0], 130, 0) + assert_iter_sum_pattern( + {x * 10 + y: (130, 0)}, var_dom([(x, 13), (y, 10)]), predicate=x * 10 + y < 140 + ) - i1 = tvm.tir.Var("i1", "int32"), 7 - i2 = tvm.tir.Var("i2", "int32"), 2 - i3 = tvm.tir.Var("i3", "int32"), 4 - i4 = tvm.tir.Var("i4", "int32"), 3 - res = tvm.arith.detect_iter_map( - [i1[0] * 20 + i2[0] * 10 + i3[0] * 3 + i4[0]], - var_dom([i1, i2, i3, i4]), - ( + i1 = tvm.tir.Var("i1", "int32") + i2 = tvm.tir.Var("i2", "int32") + i3 = tvm.tir.Var("i3", "int32") + i4 = tvm.tir.Var("i4", "int32") + assert_iter_sum_pattern( + {i1 * 20 + i2 * 10 + i3 * 3 + i4: (128, 0)}, + var_dom([(i1, 7), (i2, 2), (i3, 4), (i4, 3)]), + predicate=( tvm.tir.all( - i1[0] * 2 + i2[0] < 13, - i1[0] * 20 + i2[0] * 10 + i3[0] * 3 + i4[0] < 128, - i3[0] * 3 + i4[0] < 10, + i1 * 2 + i2 < 13, + i1 * 20 + i2 * 10 + i3 * 3 + i4 < 128, + i3 * 3 + i4 < 10, ) ), ) - assert len(res) == 1 - assert_iter_sum_pattern(res[0], 128, 0) - - i1 = tvm.tir.Var("i1", "int32"), 7 - i2 = tvm.tir.Var("i2", "int32"), 2 - i3 = tvm.tir.Var("i3", "int32"), 4 - i4 = tvm.tir.Var("i4", "int32"), 3 # wrong constraint - res = tvm.arith.detect_iter_map( - [i1[0] * 20 + i2[0] * 10 + i3[0] * 3 + i4[0]], - var_dom([i1, i2, i3, i4]), - ( + assert_iter_sum_failure( + [i1 * 20 + i2 * 10 + i3 * 3 + i4], + var_dom([(i1, 7), (i2, 2), (i3, 4), (i4, 3)]), + predicate=( tvm.tir.all( - i1[0] * 2 + i2[0] < 13, - i1[0] * 20 + i2[0] * 10 + i3[0] * 3 + i4[0] < 128, - i3[0] * 3 + i4[0] < 7, + i1 * 2 + i2 < 13, + i1 * 20 + i2 * 10 + i3 * 3 + i4 < 128, + i3 * 3 + i4 < 7, ) ), ) - assert len(res) == 0 # incompatible constraint - res = tvm.arith.detect_iter_map( - [i1[0] * 20 + i2[0] * 10 + i3[0] * 3 + i4[0]], - var_dom([i1, i2, i3, i4]), - ( + assert_iter_sum_failure( + [i1 * 20 + i2 * 10 + i3 * 3 + i4], + var_dom([(i1, 7), (i2, 2), (i3, 4), (i4, 3)]), + predicate=( tvm.tir.all( - i1[0] * 2 + i2[0] < 13, - i1[0] * 20 + i2[0] * 10 + i3[0] * 3 + i4[0] < 128, - i3[0] * 3 + i4[0] < 10, - i1[0] * 4 + i3[0] < 20, + i1 * 2 + i2 < 13, + i1 * 20 + i2 * 10 + i3 * 3 + i4 < 128, + i3 * 3 + i4 < 10, + i1 * 4 + i3 < 20, ) ), ) - assert len(res) == 0 - - res = tvm.arith.detect_iter_map( - [i1[0] * 20 + i2[0] * 10 + i3[0] * 3 + i4[0]], - var_dom([i1, i2, i3, i4]), - ( + assert_iter_sum_failure( + [i1 * 20 + i2 * 10 + i3 * 3 + i4], + var_dom([(i1, 7), (i2, 2), (i3, 4), (i4, 3)]), + predicate=( tvm.tir.all( - i1[0] * 2 + i2[0] < 13, - i1[0] * 20 + i2[0] * 10 + i3[0] * 3 + i4[0] < 128, - i1[0] * 4 + i3[0] < 20, + i1 * 2 + i2 < 13, + i1 * 20 + i2 * 10 + i3 * 3 + i4 < 128, + i1 * 4 + i3 < 20, ) ), ) - assert len(res) == 0 # zero iter - xo = tvm.tir.Var("xo", "int32"), 1 - xi = tvm.tir.Var("xi", "int32"), 129 - y = tvm.tir.Var("y", "int32"), 128 - - res = tvm.arith.detect_iter_map( - [xo[0] * 129 + xi[0], y[0]], var_dom([xo, xi, y]), xo[0] * 129 + xi[0] < 128 + xo = tvm.tir.Var("xo", "int32") + xi = tvm.tir.Var("xi", "int32") + y = tvm.tir.Var("y", "int32") + assert_iter_sum_pattern( + {xo * 129 + xi: (128, 0), y: (128, 0)}, + var_dom([(xo, 1), (xi, 129), (y, 128)]), + predicate=xo * 129 + xi < 128, ) @@ -560,9 +503,10 @@ def test_subspace_division(): tvm.ir.assert_structural_equal(res[1][0], floormod(j0[0], 4)) tvm.ir.assert_structural_equal(res[1][1], i3[0]) - res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([i3])) + assert_iter_sum_pattern + res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([i3])).indices assert len(res1) == 2 - res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([i0, j0])) + res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([i0, j0])).indices assert len(res2) == 2 # compound 1.2 @@ -574,9 +518,9 @@ def test_subspace_division(): tvm.ir.assert_structural_equal(res[1][0], 0) tvm.ir.assert_structural_equal(res[1][1], (floormod(j0[0], 4) * 2) + i3[0]) - res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([j0, i3])) + res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([j0, i3])).indices assert len(res1) == 2 - res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([i0])) + res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([i0])).indices assert len(res2) == 2 # compound 1.3 @@ -595,9 +539,9 @@ def test_subspace_division(): tvm.ir.assert_structural_equal(res[2][0], (i0[0] * 2) + floordiv(j0[0], 4) < 7) tvm.ir.assert_structural_equal(res[2][1], True) - res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([i3])) + res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([i3])).indices assert len(res1) == 2 - res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([i0, j0])) + res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([i0, j0])).indices assert len(res2) == 2 # compound 1.5 @@ -613,9 +557,9 @@ def test_subspace_division(): tvm.ir.assert_structural_equal(res[2][0], True) tvm.ir.assert_structural_equal(res[2][1], (floormod(j0[0], 4) * 2) + i3[0] < 7) - res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([j0, i3])) + res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([j0, i3])).indices assert len(res1) == 2 - res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([i0])) + res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([i0])).indices assert len(res2) == 2 # compound 1.6 @@ -650,9 +594,9 @@ def test_subspace_division(): tvm.ir.assert_structural_equal(res[2][0], 0) tvm.ir.assert_structural_equal(res[2][1], (floormod(l1[0], 3) * 3) + j3[0]) - res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1], res[2][1]], var_dom([l1, j3])) + res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1], res[2][1]], var_dom([l1, j3])).indices assert len(res1) == 3 - res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0], res[2][0]], var_dom([j0, l0])) + res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0], res[2][0]], var_dom([j0, l0])).indices assert len(res2) == 3 # compound 2.2 @@ -668,9 +612,11 @@ def test_subspace_division(): tvm.ir.assert_structural_equal(res[2][0], 0) tvm.ir.assert_structural_equal(res[2][1], (floormod(l0[0] * 6 + l1[0], 3) * 3) + j3[0]) - res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1], res[2][1]], var_dom([l0, l1, j3])) + res1 = tvm.arith.detect_iter_map( + [res[0][1], res[1][1], res[2][1]], var_dom([l0, l1, j3]) + ).indices assert len(res1) == 3 - res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0], res[2][0]], var_dom([j0])) + res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0], res[2][0]], var_dom([j0])).indices assert len(res2) == 3 # compound 2.3 @@ -698,9 +644,9 @@ def test_subspace_division(): tvm.ir.assert_structural_equal(res[3][0], (j0[0] * 2) + l0[0] < 7) tvm.ir.assert_structural_equal(res[3][1], (floormod(l1[0], 3) * 3) + j3[0] < 8) - res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1], res[2][1]], var_dom([l1, j3])) + res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1], res[2][1]], var_dom([l1, j3])).indices assert len(res1) == 3 - res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0], res[2][0]], var_dom([j0, l0])) + res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0], res[2][0]], var_dom([j0, l0])).indices assert len(res2) == 3 # compound 2.5 @@ -736,13 +682,6 @@ def test_complex(): i0 = ifuse([j0, j1], 200) i1 = ifuse([j2, j3], 50) - res = tvm.arith.detect_iter_map( - [i0[0], i1[0]], - var_dom([l0, l1, n0, n1, m1, l3]), - tvm.tir.all(i0[0] < 200, i1[0] < 50, m0[0] < 6, l2[0] < 16, j0[0] < 7, j3[0] < 15), - ) - assert len(res) == 2 - n0_mark = tvm.arith.IterMark(n0[0], n0[1]) n1_mark = tvm.arith.IterMark(n1[0], n1[1]) l0_mark = tvm.arith.IterMark(l0[0], l0[1]) @@ -790,16 +729,20 @@ def test_complex(): i0_final = tvm.arith.IterSumExpr([tvm.arith.IterSplitExpr(i0_mark, 1, i0[1], 1)], 0) i1_final = tvm.arith.IterSumExpr([tvm.arith.IterSplitExpr(i1_mark, 1, i1[1], 1)], 0) - tvm.ir.assert_structural_equal(i0_final, res[0]) - tvm.ir.assert_structural_equal(i1_final, res[1]) + assert_iter_sum_pattern( + {i0[0]: (200, 0, 1, i0_final), i1[0]: (50, 0, 1, i1_final)}, + var_dom([l0, l1, n0, n1, m1, l3]), + predicate=tvm.tir.all( + i0[0] < 200, i1[0] < 50, m0[0] < 6, l2[0] < 16, j0[0] < 7, j3[0] < 15 + ), + ) # wrong constraint - res = tvm.arith.detect_iter_map( + assert_iter_sum_failure( [i0[0], i1[0]], var_dom([l0, l1, n0, n1, m1, l3]), tvm.tir.all(i0[0] < 200, i1[0] < 50, m0[0] < 9, l2[0] < 16, j0[0] < 7, j3[0] < 14), ) - assert len(res) == 0 # subspace_division res = tvm.arith.subspace_divide( @@ -828,34 +771,33 @@ def test_complex(): ), ) - res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([n0, n1, m1, l3]), res[2][1]) - assert len(res1) == 2 - res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([l0, l1])) - assert len(res2) == 2 + assert_iter_sum_pattern( + {res[0][1]: (32, 0), res[1][1]: (15, 0)}, var_dom([n0, n1, m1, l3]), res[2][1] + ) + assert_iter_sum_pattern({res[0][0]: (8, 0), res[1][0]: (4, 0)}, var_dom([l0, l1])) def test_normalize_iter_map_to_expr(): fld = tvm.tir.floordiv flm = tvm.tir.floormod - x = tvm.tir.Var("x", "int32"), 10 - y = tvm.tir.Var("y", "int32"), 9 + x = tvm.tir.Var("x", "int32") + y = tvm.tir.Var("y", "int32") - xo, xi = isplit(x, 5) - yo, yi = isplit(y, 3) + xo, xi = isplit((x, 10), 5) + yo, yi = isplit((y, 9), 3) z = ifuse([yo, xo, yi]) - - res = tvm.arith.detect_iter_map([z[0], xi[0]], var_dom([x, y])) + res = tvm.arith.detect_iter_map([z[0], xi[0]], var_dom([(x, 10), (y, 9)])) tvm.ir.assert_structural_equal( - tvm.arith.normalize_iter_map_to_expr(res[0]), - fld(y[0], 3) * 6 + fld(x[0], 5) * 3 + flm(y[0], 3), + tvm.arith.normalize_iter_map_to_expr(res.indices[0]), + fld(y, 3) * 6 + fld(x, 5) * 3 + flm(y, 3), ) - tvm.ir.assert_structural_equal(tvm.arith.normalize_iter_map_to_expr(res[1]), flm(x[0], 5)) + tvm.ir.assert_structural_equal(tvm.arith.normalize_iter_map_to_expr(res.indices[1]), flm(x, 5)) # iter mark wrap a complex expr - split = tvm.arith.IterSplitExpr(tvm.arith.IterMark(x[0] * y[0] + 1, 1024), 1, 1024, 1) - tvm.ir.assert_structural_equal(tvm.arith.normalize_iter_map_to_expr(split), x[0] * y[0] + 1) + split = tvm.arith.IterSplitExpr(tvm.arith.IterMark(x * y + 1, 1024), 1, 1024, 1) + tvm.ir.assert_structural_equal(tvm.arith.normalize_iter_map_to_expr(split), x * y + 1) def test_inverse_affine_iter_map(): @@ -869,7 +811,9 @@ def test_inverse_affine_iter_map(): l1_0, l1_1 = isplit(l1, 4) l0_1_l1_1_fused = ifuse([l0_1, l1_1]) - iter_map = tvm.arith.detect_iter_map([l0_1_l1_1_fused[0], l0_0[0], l1_0[0]], var_dom([l0, l1])) + iter_map = tvm.arith.detect_iter_map( + [l0_1_l1_1_fused[0], l0_0[0], l1_0[0]], var_dom([l0, l1]) + ).indices outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in range(len(iter_map))] res = tvm.arith.inverse_affine_iter_map(iter_map, outputs) assert len(res) == 2 @@ -888,7 +832,7 @@ def test_inverse_affine_iter_map(): iter_map = tvm.arith.detect_iter_map( [l0_1_l2_1_l1_1_l2_0_fused[0], l0_0[0], l2_2[0], l1_0[0]], var_dom([l0, l1, l2]) - ) + ).indices outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in range(len(iter_map))] res = tvm.arith.inverse_affine_iter_map(iter_map, outputs) assert len(res) == 3 @@ -908,7 +852,7 @@ def test_inverse_affine_iter_map(): l1_0, l1_1 = isplit(l1, 8) l2 = ifuse([l1_1, l1_0]) - iter_map = tvm.arith.detect_iter_map([l2[0]], var_dom([l0])) + iter_map = tvm.arith.detect_iter_map([l2[0]], var_dom([l0])).indices outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in range(len(iter_map))] res = tvm.arith.inverse_affine_iter_map(iter_map, outputs) assert len(res) == 1 @@ -924,12 +868,11 @@ def test_free_variables(): z = tvm.tir.Var("z", "int32") # illegal iter if z is within dom - res = tvm.arith.detect_iter_map([z * 19 + y * 3 + x], var_dom([(x, 3), (y, 3), (z, 3)])) - assert len(res) == 0 + assert_iter_sum_failure([z * 19 + y * 3 + x], var_dom([(x, 3), (y, 3), (z, 3)])) # iter is valid if z is free, even there are linear forms of z - res = tvm.arith.detect_iter_map( - [z * 19 + y * 3 + x], + assert_iter_sum_pattern( + {z * 19 + y * 3 + x: (9, z * 19)}, var_dom( [ (x, 3), @@ -937,9 +880,8 @@ def test_free_variables(): ] ), ) - assert_iter_sum_pattern(res[0], 9, z * 19) - res = tvm.arith.detect_iter_map( - [z * z + y * 3 + x], + assert_iter_sum_pattern( + {z * z + y * 3 + x: (9, z * z)}, var_dom( [ (x, 3), @@ -947,7 +889,6 @@ def test_free_variables(): ] ), ) - assert_iter_sum_pattern(res[0], 9, z * z) def test_padding(): @@ -956,61 +897,45 @@ def test_padding(): fld = tvm.tir.floordiv flm = tvm.tir.floormod - def assert_padding_pattern(expect_dict, dom_map, predicate=True, require_bijective=False): - keys = list(expect_dict.keys()) - res = tvm.arith.detect_iter_map( - keys, dom_map, predicate=predicate, require_bijective=require_bijective - ) - assert len(res) == len(keys) - print(res) - for i, iter_expr in enumerate(keys): - extent, base, scale = expect_dict[iter_expr] - assert_iter_sum_pattern(res[i], extent, base, scale) - - def assert_padding_pattern_failure(iters, dom_map, predicate=True, require_bijective=False): - res = tvm.arith.detect_iter_map( - list(iters), dom_map, predicate=predicate, require_bijective=False - ) - assert len(res) == 0 - # left padding only, offset divisible sum = 64 + y dom_map = var_dom([(y, 192)]) - assert_padding_pattern( + assert_iter_sum_pattern( {fld(sum, 32): (6, 2, 1), flm(sum, 32): (32, 0, 1)}, dom_map, - require_bijective=True, + check_level="bijective", ) # left padding only, offset non-divisible sum = 80 + y dom_map = var_dom([(y, 176)]) - assert_padding_pattern( + assert_iter_sum_pattern( {fld(sum, 32): (6, 2, 1)}, dom_map, ) - assert_padding_pattern( + assert_iter_sum_pattern( {flm(fld(sum, 2), 16): (16, 0, 1), flm(sum, 2): (2, 0, 1)}, dom_map, ) - assert_padding_pattern_failure({fld(sum, 32), flm(sum, 32)}, dom_map) + assert_iter_sum_failure({fld(sum, 32), flm(sum, 32)}, dom_map) + assert_iter_sum_failure({fld(sum, 32), fld(sum, 4)}, dom_map) # right padding only, offset divisible sum = x * 32 + y * 8 dom_map = var_dom([(x, 5), (y, 4)]) - assert_padding_pattern( + assert_iter_sum_pattern( {fld(sum, 16): (10, 0, 1), flm(sum, 16): (2, 0, 8)}, dom_map, ) - assert_padding_pattern_failure({fld(sum, 5)}, dom_map) + assert_iter_sum_failure({fld(sum, 5)}, dom_map) # right padding only, offset non-divisible dom_map = var_dom([(x, 26)]) - assert_padding_pattern( + assert_iter_sum_pattern( {fld(x, 15): (2, 0, 1)}, dom_map, ) - assert_padding_pattern( + assert_iter_sum_pattern( {flm(fld(x, 3), 5): (5, 0, 1), flm(x, 3): (3, 0, 1)}, dom_map, ) @@ -1018,8 +943,8 @@ def assert_padding_pattern_failure(iters, dom_map, predicate=True, require_bijec # padding constants on both side sum = x + 71 dom_map = var_dom([(x, 45)]) - assert_padding_pattern({fld(sum, 32): (2, 2, 1)}, dom_map) - assert_padding_pattern( + assert_iter_sum_pattern({fld(sum, 32): (2, 2, 1)}, dom_map) + assert_iter_sum_pattern( {flm(fld(x, 4), 8): (8, 0, 1), flm(x, 4): (4, 0, 1)}, dom_map, ) @@ -1027,8 +952,8 @@ def assert_padding_pattern_failure(iters, dom_map, predicate=True, require_bijec # padding for free iteration part sum = x * 360 + y dom_map = var_dom([(y, 360)]) - assert_padding_pattern({fld(sum, 16): (23, fld(x * 360 - flm(x, 2) * 8, 16), 1)}, dom_map) - assert_padding_pattern({flm(x * 360 + y, 16): (16, 0, 1)}, dom_map) + assert_iter_sum_pattern({fld(sum, 16): (23, fld(x * 360 - flm(x, 2) * 8, 16), 1)}, dom_map) + assert_iter_sum_pattern({flm(x * 360 + y, 16): (16, 0, 1)}, dom_map) if __name__ == "__main__": From 8d46bb5a23ded3baa91a6cd3f021247041c1a612 Mon Sep 17 00:00:00 2001 From: wrongtest Date: Sat, 28 May 2022 14:29:58 +0800 Subject: [PATCH 5/6] - check incompatible left paddings - determine case like x % 16, x in [0, 5) to be non-surjective, since usages may treat the region extent as 16 by mistake. - skip second round of rewrite when there is no padding - fix some typo in comments --- include/tvm/arith/iter_affine_map.h | 7 +- src/arith/iter_affine_map.cc | 152 +++++++++++------- .../unittest/test_arith_iter_affine_map.py | 42 ++++- 3 files changed, 136 insertions(+), 65 deletions(-) diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h index aea67d54584c..2c0e5e92997a 100644 --- a/include/tvm/arith/iter_affine_map.h +++ b/include/tvm/arith/iter_affine_map.h @@ -308,8 +308,11 @@ class IterMapResultNode : public Object { */ class IterMapResult : public ObjectRef { public: - TVM_DEFINE_OBJECT_REF_METHODS(IterMapResult, ObjectRef, IterMapResultNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(IterMapResultNode); + // constructor + IterMapResult() { data_ = make_object(); } + + /*! \return mutable pointers to the node. */ + IterMapResultNode* operator->() const { return static_cast(get_mutable()); } }; /*! diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 6b724ca1d6e2..cce826fedca6 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -199,16 +199,17 @@ class IterMapRewriter : public ExprMutator { } PrimExpr padding_predicate() const { return padding_predicate_; } - PrimExpr requires_padding() const { return !analyzer_->CanProveEqual(padding_predicate_, 0); } + bool requires_padding() const { return requires_padding_; } IterSumExpr Rewrite(const PrimExpr& expr) { return NormalizeToIterWithOffset(ToIterSumExpr(DirectMutate(expr))); } - void UpdatePadding(const PrimExpr& expr) { + IterSumExpr RewriteAndUpdatePadding(const PrimExpr& expr) { update_iterator_padding_ = true; - DirectMutate(expr); + auto res = Rewrite(expr); update_iterator_padding_ = false; + return res; } IterSumExpr RewriteIterConstraint(const PrimExpr& expr, @@ -425,27 +426,30 @@ class IterMapRewriter : public ExprMutator { // input iter marks std::vector input_marks_; - // Map from a split iter to the padded iterator information for + // Map from an iter mark to the padded iterator information for // it. This is necessary for introducing the same padding in all // usage of an input iterator. (e.g. (i-1) occurring in the // expressions [(i-1)%8, ((i-1)//8)%4, (i-1)//32] should be // left-padded by 31 for each occurrence.) std::unordered_map padded_iter_map_; - /* If allow_padding_ is true, allow the extents of the IterMap to be + // Map from padded iter mark to it's origin mark + std::unordered_map padded_origin_map_; + + /* If update_iterator_padding_ is true, allow the extents of the IterMap to be * padded beyond the original iterators. * - * For example, if allow_padding_ is true, the expressions i//4 and + * For example, if update_iterator_padding_ is true, the expressions i//4 and * i%4, where i is on the range [0,18), would be represented as * IterSplit(i, lower_factor=4, extent=5) and IterSplit(i, extent=4). - * This representation would be forbidden if allow_padding_ is false, + * This representation would be forbidden if update_iterator_padding_ is false, * because lower_factor=4 does not evenly divide the original extent of * 18. */ bool update_iterator_padding_{false}; /* A boolean expression that is true for any padding that has been - * introduced, and false otherwise. If allow_padding_ is false, + * introduced, and false otherwise. If update_iterator_padding_ is false, * padding_predicate_ will always be false. * * Example: [i//4, i%4], i in range [0,16) @@ -459,6 +463,11 @@ class IterMapRewriter : public ExprMutator { */ PrimExpr padding_predicate_; + /* A boolean flag denotes there are padding iterations detected + * in the first round of indices rewriting. + */ + bool requires_padding_{false}; + // The map for sum that maps flattened form to IterMark with normal form and extent (and possibly // an extra offset) // Example(1): expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2) @@ -488,25 +497,6 @@ class IterMapRewriter : public ExprMutator { // The flattened forms of constrained iters std::vector constrained_iters_flattened_; - /*! - * \brief Extract original iteration mark's extent before padding, return NullOpt is - * there is no extra padding. - */ - Optional ExtractExtentBeforePadding(const IterMark& mark, Analyzer* analyzer) { - const IterSumExprNode* sum = mark->source.as(); - if (!sum || sum->args.size() != 1) { - return NullOpt; - } - IterSplitExpr split = sum->args[0]; - if (!analyzer->CanProveEqual(split->extent, mark->extent) && - analyzer->CanProveEqual(split->scale, 1) && - analyzer->CanProveEqual(split->lower_factor, 1) && - analyzer->CanProveEqual(split->source->extent, split->extent)) { - return sum->args[0]->extent; - } - return NullOpt; - } - /*! * \brief Look for a split in splits that is not used such that its lower_factor is smallest. * Note that here we use division to compare lower_factor. @@ -580,9 +570,9 @@ class IterMapRewriter : public ExprMutator { expected_lower_factor = splits[j]->lower_factor * splits[j]->extent; } - // Extract padding info of the iteration mark, extent before padding - // is only defined when padding exists. - Optional extent_before_padding = ExtractExtentBeforePadding(mark, analyzer_); + // Extract iteration mark info before padding + auto pad_mark_it = padded_origin_map_.find(mark); + bool has_padding = pad_mark_it != padded_origin_map_.end(); bool match_full_iter = analyzer_->CanProveEqual(expected_lower_factor, mark->extent); bool match_iter_divisor = @@ -598,33 +588,46 @@ class IterMapRewriter : public ExprMutator { // // Case 3. bijective is not required and there exists padding. We check either // (3.1) The extent we calculate is consistent with the extent of the padded mark and it is - // the single split for the iter mark. + // the single split for the iter mark. // For example, padded iter p in [0, 24), [(p / 12)] is valid because it is surjective // according to how we pad the original iteration mark. // (3.2) The extent we calculate is a factor of the extent of the padded mark, and the extent - // before padding is greater or equal than the extent we calculate. - // For example, padded iter p in [0, 24), the original extent is 14, [(p % 12)] is - // valid. + // before padding is greater or equal than the extent we calculate. + // For example, the original extent is 14, [(p % 12)] is valid, with p padded to 24. // if (check_level == IterMapLevel::Bijective) { - if (extent_before_padding.defined() || !match_full_iter) { - return Array(); + if (has_padding) { + ErrorLogger(this) << "Bijectvie mapping should not take iter paddings"; + return {}; + } else if (!match_full_iter) { + ErrorLogger(this) << "The iterations do not traverse full iter space"; + return {}; } - } else if (!extent_before_padding.defined()) { + } else if (!has_padding) { if (!match_iter_divisor) { - return Array(); + ErrorLogger(this) << "The lower factor is not divisible by the full iter space extent"; + return {}; } } else if (check_level == IterMapLevel::Surjective) { + PrimExpr extent_before_padding = pad_mark_it->second->extent; if (match_full_iter) { if (splits.size() != 1) { + ErrorLogger(this) << "Dependent iterations on padding iter space"; + return Array(); + } else if (analyzer_->CanProveEqual(splits[0]->extent, expected_lower_factor) && + !analyzer_->CanProve(extent_before_padding >= expected_lower_factor)) { + ErrorLogger(this) << "Split on padding iteration is not surjective " + << "if the split extent equals to the full iter space extent"; return Array(); } } else if (match_iter_divisor) { - if (!analyzer_->CanProve(extent_before_padding.value() >= expected_lower_factor)) { + if (!analyzer_->CanProve(extent_before_padding >= expected_lower_factor)) { + ErrorLogger(this) << "The extent before padding is less than lower factor"; return Array(); } } else { - return Array(); + ErrorLogger(this) << "The lower factor is not divisible by the full iter space extent"; + return {}; } } return Array(iters.rbegin(), iters.rend()); @@ -1056,22 +1059,21 @@ bool IterRangeSanityCheck(const Map& iter_ranges) { IterMapResult DetectIterMap(const Array& indices, const Map& input_iters, const PrimExpr& predicate, IterMapLevel check_level, arith::Analyzer* analyzer, bool simplify_trivial_iterators) { - IterMapResult result_obj = IterMapResult(make_object()); - auto result = result_obj.CopyOnWrite(); + IterMapResult result; // Overall detection algorithm is divided into two steps: // - Step0: IterMapRewriter rewrites the expression to use IterMapExpr patterns. // - Step1: IterIndependenceChecker checks if the iterator are independent. if (!IterRangeSanityCheck(input_iters)) { result->errors.push_back("Invalid iterators. Iterators may not be expressions of each other."); - return result_obj; + return result; } Map constrained_input_iters = input_iters; std::vector constraints; if (!is_one(predicate) && !MatchBoundConstraints(predicate, &constrained_input_iters, &constraints)) { result->errors.push_back("Could not parse predicate as constraints on the input iterators."); - return result_obj; + return result; } // We have to make sure when we visit an iterator, all the constraints related with its successors // in the iter var graph has been visited, where the expression of this iterator will contain the @@ -1090,32 +1092,39 @@ IterMapResult DetectIterMap(const Array& indices, const Maperrors.size()) { - return result_obj; + if (result->errors.size() > 0) { + return result; } } if (!rewriter.CheckConstraints()) { result->errors.push_back("Invalid constraints."); - return result_obj; + return result; } - // Step0.1: Check each index to determine required padding + // Step0.1: Rewrite indicies and determine required padding, + // if there is no padding, it should be the final result. + Array rewrite_indices; + rewrite_indices.reserve(indices.size()); bool allow_padding = check_level != IterMapLevel::Bijective; if (allow_padding) { for (PrimExpr value : indices) { - rewriter.UpdatePadding(value); + rewrite_indices.push_back(rewriter.RewriteAndUpdatePadding(value)); + if (result->errors.size() > 0) { + return result; + } } } - // Step0.2: rewrite indices - Array rewrite_indices; - for (PrimExpr value : indices) { - rewrite_indices.push_back(rewriter.Rewrite(value)); - if (result->errors.size()) { - return result_obj; + // Step0.2: Rewrite indices in the second round. + if (!allow_padding || rewriter.requires_padding()) { + rewrite_indices.clear(); + for (PrimExpr value : indices) { + rewrite_indices.push_back(rewriter.Rewrite(value)); + if (result->errors.size() > 0) { + return result; + } } } - result->padding_predicate = rewriter.padding_predicate(); // Step1: IterIndependenceChecker checks if the iterator are independent. @@ -1125,10 +1134,10 @@ IterMapResult DetectIterMap(const Array& indices, const Maperrors.push_back("Mapped indices are not independent."); } - return result_obj; + return result; } result->indices = rewrite_indices; - return result_obj; + return result; } TVM_REGISTER_GLOBAL("arith.DetectIterMap") @@ -1304,6 +1313,10 @@ PrimExpr ApproxLeastCommonMultiple(const PrimExpr& a, const PrimExpr& b, Analyze auto const_lcm = Integer(LeastCommonMultiple(p1.second, p2.second)); if (analyzer->CanProveEqual(p1.first, p2.first)) { return p1.first * const_lcm; + } else if (analyzer->CanProveEqual(floormod(p1.first, p2.first), 0)) { + return p1.first * const_lcm; + } else if (analyzer->CanProveEqual(floormod(p2.first, p1.first), 0)) { + return p2.first * const_lcm; } else { return (p1.first * p2.first) * const_lcm; } @@ -1348,6 +1361,9 @@ std::pair IterMapRewriter::PadDividendToDivisor(IterSpl } // Update padding requirement on the lower side of the source iter mark. + // In the second pass, all splits would check whether the maximum left pading + // on the iter mark is compatible with it's own left padding. + requires_padding_ = true; PrimExpr mark_left_pad = left_pad * split->lower_factor; info.left_pad = max(info.left_pad, mark_left_pad); @@ -1381,6 +1397,22 @@ std::pair IterMapRewriter::PadDividendToDivisor(IterSpl if (!info.padded.defined()) { // the first time encounter the iter mark to pad, update the padded mark. PrimExpr mark_left_pad = info.left_pad; + if (CanProveDivisible(mark_left_pad, split->lower_factor)) { + // correct current split's left padding + // (mark_left_pad + iter) // lower_factor % extent => + // (left_pad * lower_factor + mark) // lower_factor % extent => + // (left_pad + mark // lower_factor) % extent => + // left_pad + (mark // lower_factor % extent) => + // left_pad + split + // since the extent covers the full padding range. + left_pad = floordiv(mark_left_pad, split->lower_factor); + } else { + ErrorLogger(this) << "Detect incompatible left padding on " + << tvm::PrettyPrint(NormalizeIterMapToExpr(split)) + << ", the iter mark is left padded with " << mark_left_pad; + return {IterSplitExpr(), PrimExpr()}; + } + PrimExpr right_edge = mark->extent + mark_left_pad; PrimExpr mark_right_pad; if (CanProveDivisible(right_edge, info.padding_factor)) { @@ -1391,6 +1423,7 @@ std::pair IterMapRewriter::PadDividendToDivisor(IterSpl PrimExpr padded_extent = analyzer_->Simplify(right_edge + mark_right_pad); info.right_pad = mark_right_pad; info.padded = IterMark(IterSumExpr({IterSplitExpr(mark)}, mark_left_pad), padded_extent); + padded_origin_map_[info.padded] = mark; auto left_padding_introduced = (mark_left_pad != 0); @@ -1557,7 +1590,6 @@ PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr base, P // We handle scale!=1 in above code, hence we only consider floormod(x, rhs) below // where x=floormod(floordiv(iter, lower_factor), extent) + base - auto pair = PadDividendToDivisor(lhs, base, rhs); IterSplitExpr padded = pair.first; if (!padded.defined()) { @@ -1676,7 +1708,7 @@ bool IterMapRewriter::CanProveDivisible(const PrimExpr& lhs, const PrimExpr& rhs PrimExpr divisor = normalizer.Convert(rhs); return analyzer_->CanProveEqual(dividend, divisor) || - analyzer_->CanProve(analyzer_->Simplify(floormod(dividend, divisor), 8) == 0); + analyzer_->CanProve(floormod(dividend, divisor) == 0); } PrimExpr NormalizeIterMapToExpr(const PrimExpr& expr) { diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py index ac3ae00b8a0c..d7bfa1c91947 100644 --- a/tests/python/unittest/test_arith_iter_affine_map.py +++ b/tests/python/unittest/test_arith_iter_affine_map.py @@ -58,8 +58,10 @@ def assert_iter_sum_pattern( predicate=predicate, check_level=check_level, simplify_trivial_iterators=simplify_trivial_iterators, - ).indices - assert len(res) == len(keys) + ) + indices = res.indices + assert len(indices) == len(keys), res.errors + print(indices) for i, input_iter in enumerate(keys): spec = expect_dict[input_iter] ( @@ -68,7 +70,7 @@ def assert_iter_sum_pattern( ) = spec[0:2] scale = spec[2] if len(spec) > 2 else 1 expect_iter = spec[3] if len(spec) > 3 else None - sum_expr = res[i] + sum_expr = indices[i] assert isinstance(sum_expr, tvm.arith.IterSumExpr) if extent == 1: assert len(sum_expr.args) == 0 @@ -955,6 +957,40 @@ def test_padding(): assert_iter_sum_pattern({fld(sum, 16): (23, fld(x * 360 - flm(x, 2) * 8, 16), 1)}, dom_map) assert_iter_sum_pattern({flm(x * 360 + y, 16): (16, 0, 1)}, dom_map) + # multiple split with same mark offset, could + # be surjective on missing (padded // LCM) + assert_iter_sum_pattern( + { + flm(x + 10, 3): (3, 0), + flm(fld(x + 10, 3), 4): (4, 0), + flm(fld(fld(x + 10, 3), 4), 5): (5, 0), + }, + var_dom([(x, 240)]), + ) + assert_iter_sum_failure( + { + flm(x + 10, 3), + flm(fld(x + 10, 3), 4), + flm(fld(fld(x + 10, 3), 4), 5), + fld(fld(fld(x + 10, 3), 4), 5), + }, + var_dom([(x, 240)]), + ) + + # different offsets on splits + assert_iter_sum_pattern( + { + flm(x + 1, 3): (3, 0), + flm(fld(x + 10, 3) + 2, 4): (4, 0), + flm(fld(fld(x + 10, 3), 4) + 3, 5): (5, 0), + }, + var_dom([(x, 240)]), + ) + + # original extent is smaller than the divident + # it is not surjective wrt to the region [0, 16) + assert_iter_sum_failure({flm(x, 16)}, var_dom([(x, 3)])) + if __name__ == "__main__": tvm.testing.main() From 4d1239a0d5332cec1df577f2598fcea1d5459f42 Mon Sep 17 00:00:00 2001 From: wrongtest Date: Tue, 31 May 2022 13:27:52 +0800 Subject: [PATCH 6/6] rebase upstream --- src/tir/schedule/primitive/layout_transformation.cc | 7 ++++--- tests/python/unittest/test_tir_schedule_compute_at.py | 11 ++++------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc index 6da796fc955f..692f68a600ae 100644 --- a/src/tir/schedule/primitive/layout_transformation.cc +++ b/src/tir/schedule/primitive/layout_transformation.cc @@ -392,8 +392,9 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref, auto iter_map = arith::DetectIterMap( /*indices=*/transformed_block_iters, /*input_iters=*/block_iter_dom, /*predicate=*/Bool(true), - /*require_bijective=*/true, &analyzer, /*simplify_trivial_iterators=*/true); - if (iter_map.empty()) { + /*check_level=*/arith::IterMapLevel::Bijective, &analyzer, + /*simplify_trivial_iterators=*/true); + if (iter_map->indices.empty()) { throw NotBijectiveAffineIndexMapError(self->mod, index_map); } @@ -417,7 +418,7 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref, // Step 5.2: Update the block body. Use the inverse map f^{-1} to replace the original block iters // in the body. - auto inverse_map = arith::InverseAffineIterMap(iter_map, new_block_vars); + auto inverse_map = arith::InverseAffineIterMap(iter_map->indices, new_block_vars); // Trivial block iters will be simplified in DetectIterMap, they should be mapped to constant // zero. for (const auto& iter_var : block_ptr->iter_vars) { diff --git a/tests/python/unittest/test_tir_schedule_compute_at.py b/tests/python/unittest/test_tir_schedule_compute_at.py index c7552942cf7e..f477367adfad 100644 --- a/tests/python/unittest/test_tir_schedule_compute_at.py +++ b/tests/python/unittest/test_tir_schedule_compute_at.py @@ -1281,13 +1281,10 @@ def grouped_channel_bias_non_perfect_tiled( cc = T.axis.spatial(720, c_o * 360 + c_i) Y[cc, hh, ww] = X[cc, hh, ww] + B[cc // 16] - def check_sched(debug_mask): - sch = tir.Schedule(grouped_channel_bias, debug_mask=debug_mask) - loop = sch.get_loops(sch.get_block("compute"))[0] - sch.compute_at(sch.get_block("init"), loop) - tvm.ir.assert_structural_equal(sch.mod["main"], grouped_channel_bias_non_perfect_tiled) - - check_sched("all") + sch = tir.Schedule(grouped_channel_bias, debug_mask="all") + loop = sch.get_loops(sch.get_block("compute"))[0] + sch.compute_at(sch.get_block("init"), loop) + tvm.ir.assert_structural_equal(sch.mod["main"], grouped_channel_bias_non_perfect_tiled) def test_fail_subtree_complete_block():