From 52ea3ca731b4b0c5d86a5f5c44e49baeb098f743 Mon Sep 17 00:00:00 2001 From: wrongtest Date: Wed, 6 Jul 2022 23:41:10 +0800 Subject: [PATCH] support overlapped itersum --- src/arith/iter_affine_map.cc | 91 ++++++++++++++----- tests/python/unittest/test_arith_intset.py | 7 +- .../unittest/test_arith_iter_affine_map.py | 58 +++++++++++- .../unittest/test_meta_schedule_space_cpu.py | 26 +++--- .../unittest/test_meta_schedule_space_cuda.py | 12 +-- .../unittest/test_tir_schedule_reorder.py | 30 +++++- .../unittest/test_tir_schedule_split_fuse.py | 8 +- .../test_tir_schedule_state_cached_flags.py | 2 +- 8 files changed, 176 insertions(+), 58 deletions(-) diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index d2aa16ded1f6..83e2821c9800 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -177,8 +177,12 @@ class IterMapRewriter : public ExprMutator { using Parent = ExprMutator; explicit IterMapRewriter(Analyzer* analyzer, const Map& input_iters, - bool simplify_trivial_iterators, Array* errors) - : analyzer_(analyzer), errors_(*errors), padding_predicate_(const_false()) { + IterMapLevel check_level, bool simplify_trivial_iterators, + Array* errors) + : analyzer_(analyzer), + check_level_(check_level), + errors_(*errors), + padding_predicate_(const_false()) { for (auto kv : input_iters) { const Var& var = kv.first; const Range& vrng = kv.second; @@ -419,6 +423,8 @@ class IterMapRewriter : public ExprMutator { // Internal analyzer Analyzer* analyzer_; + // Iter map check level + IterMapLevel check_level_; // Error messages for each unresolved expression. 
Array& errors_; // The var map @@ -651,7 +657,7 @@ class IterMapRewriter : public ExprMutator { if (predicate_induced_max.defined()) predicate_induced_max = predicate_induced_max.value() - base; } - Optional opt = TryFuseIters(expr); + Optional opt = TryFuseIters(expr, check_level_); ICHECK(!opt.defined() || opt.value()->args.size() == 1); // scale should be 1 if (opt.defined() && is_one(opt.value()->args[0]->scale)) { @@ -702,7 +708,7 @@ class IterMapRewriter : public ExprMutator { IterSumExpr NormalizeToIterWithOffset(IterSumExpr expr) { // We are normalizing a regular iter if (expr->args.size() < 1) return expr; - Optional opt = TryFuseIters(expr); + Optional opt = TryFuseIters(expr, check_level_); if (opt.defined()) { return opt.value(); } else { @@ -735,9 +741,10 @@ class IterMapRewriter : public ExprMutator { * return a corresponding IterSumExpr with extra offset if needed. * Try to normalize IterSum into a fused IterMark * \param expr The input sum. + * \param check_level The check level of iter mapping. * \return The sum with the fused IterMark and extra offset if succeed. */ - Optional TryFuseIters(IterSumExpr expr) { + Optional TryFuseIters(IterSumExpr expr, IterMapLevel check_level) { // select the iterators in order std::vector visited(expr->args.size(), false); std::vector flattened_iters, grouped_iters; @@ -758,14 +765,42 @@ class IterMapRewriter : public ExprMutator { } // check if it can be remapped into a fused pattern. PrimExpr expected_extra_base = 0; + PrimExpr tail_extent = 0; PrimExpr expected_scale = base_scale.value(); for (size_t i = 0; i < expr->args.size();) { - // find j such that expr->args[j] has expected scale - size_t j = i == 0 ? base_index : 0; - for (; j < expr->args.size(); ++j) { - if (!visited[j] && analyzer_->CanProveEqual(expr->args[j]->scale, expected_scale)) break; + // find position such that expr->args[j] matches expected scale + int j = i == 0 ?
base_index : expr->args.size() - 1; + + size_t matched_pos = expr->args.size(); + PrimExpr matched_scale{nullptr}; + bool is_exact_match{false}; + + for (; j >= 0; --j) { + if (visited[j]) { + continue; + } + const PrimExpr& cur_scale = expr->args[j]->scale; + + // for bijective mapping, the matched scale must equal to expected scale + if (analyzer_->CanProveEqual(cur_scale, expected_scale)) { + matched_pos = j; + matched_scale = cur_scale; + is_exact_match = true; + break; + } + if (check_level != IterMapLevel::Bijective && base_scale.value()->value == 1) { + // find the closest scale which is less or equal to expected scale + if (analyzer_->CanProveGreaterEqual(expected_scale - cur_scale, 0) && + analyzer_->CanProveGreaterEqual(cur_scale, 0)) { + if (matched_pos == expr->args.size() || + analyzer_->CanProveLess(matched_scale - cur_scale, 0)) { + matched_pos = j; + matched_scale = cur_scale; + } + } + } } - if (j == expr->args.size()) { + if (matched_pos == expr->args.size()) { return NullOpt; } // look for the longest constrained iter started from expr->args[j] @@ -775,8 +810,8 @@ class IterMapRewriter : public ExprMutator { // otherwise we expect the scale of i to be 2*5=10 Optional constraint_to_match; for (const IterSumExpr& iter : constrained_iters_flattened_) { - if (IterSplitEqual(expr->args[j], iter->args.back(), false)) { - // find a predicate started from expr->args[j] + if (IterSplitEqual(expr->args[matched_pos], iter->args.back(), false)) { + // find a predicate started from match position if (!constraint_to_match || constraint_to_match.value()->args.size() < iter->args.size()) { constraint_to_match = iter; @@ -793,7 +828,7 @@ class IterMapRewriter : public ExprMutator { size_t k = 0; for (; k < expr->args.size(); ++k) { if (!visited[k] && IterSplitEqual(expr->args[k], *it, false)) { - if (analyzer_->CanProveEqual((*it)->scale * expected_scale, expr->args[k]->scale)) + if (analyzer_->CanProveEqual((*it)->scale * matched_scale, expr->args[k]->scale)) 
break; } } @@ -806,20 +841,25 @@ class IterMapRewriter : public ExprMutator { auto iter = sum_fuse_map_.find(constraint_to_match.value()); ICHECK(iter != sum_fuse_map_.end()); const IterMarkWithOffset& iter_matched = iter->second; - grouped_iters.emplace_back(iter_matched.mark, expected_scale); - expected_extra_base += iter_matched.offset * expected_scale; - expected_scale *= iter_matched.mark->extent; + grouped_iters.emplace_back(iter_matched.mark, div(matched_scale, base_scale.value())); + expected_extra_base += iter_matched.offset * matched_scale; + if (!is_exact_match) { + tail_extent += expected_scale - matched_scale; + } + expected_scale = matched_scale * iter_matched.mark->extent; // move forward i += constraint_to_match.value()->args.size(); } else { // constraint_to_match not found, skip this iterator - visited[j] = true; - IterSplitExpr arg = expr->args[j]; - arg.CopyOnWrite()->scale = - analyzer_->Simplify(div(expr->args[j]->scale, base_scale.value())); + visited[matched_pos] = true; + IterSplitExpr arg = expr->args[matched_pos]; + arg.CopyOnWrite()->scale = analyzer_->Simplify(div(arg->scale, base_scale.value())); flattened_iters.push_back(arg); grouped_iters.push_back(arg); - expected_scale *= expr->args[j]->extent; + if (!is_exact_match) { + tail_extent += expected_scale - matched_scale; + } + expected_scale = matched_scale * expr->args[matched_pos]->extent; ++i; } } @@ -843,7 +883,8 @@ class IterMapRewriter : public ExprMutator { expr->base + expected_extra_base); } else { // new iter, form a new mark - IterMark mark = IterMark(structured_form, div(expected_scale, base_scale.value())); + IterMark mark = + IterMark(structured_form, div(expected_scale, base_scale.value()) + tail_extent); sum_fuse_map_[flattened_form] = IterMarkWithOffset(mark, 0); flattened_map_[structured_form] = flattened_form; return IterSumExpr({IterSplitExpr(mark, base_scale.value())}, @@ -1086,8 +1127,8 @@ IterMapResult DetectIterMap(const Array& indices, const Maperrors); + 
IterMapRewriter rewriter(analyzer, constrained_input_iters, check_level, + simplify_trivial_iterators, &result->errors); // Step0.0: rewrite constraints in the order from size-small ones to size-big ones for (const IterConstraint& constraint : constraints) { auto res = rewriter.RewriteIterConstraint(constraint.iter, constraint.lower_bound, @@ -1281,7 +1322,7 @@ IterSumExpr IterMapRewriter::PreprocessDividend(IterMapExpr dividend, PrimExpr o } else if (sum->args.size() == 1) { return sum; } - auto opt_fused = TryFuseIters(sum); + auto opt_fused = TryFuseIters(sum, check_level_); if (!opt_fused) { ErrorLogger(this) << "Dividend " << tvm::PrettyPrint(original_dividend) << ", can't be written as a single fused IterSum"; diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py index ca9d1077feb2..74b53442ec27 100644 --- a/tests/python/unittest/test_arith_intset.py +++ b/tests/python/unittest/test_arith_intset.py @@ -323,10 +323,6 @@ def do_test_point_access(point, predicates, var_dom, expect): def test_region_lower_bound_unfusable(): - # This test is designed to trigger an error in DetectIterMap, - # resulting from a numerator which required multiple input - # variables. The bug resulted in an exception being thrown, - # rather than a return value of None. 
var_dom = { tvm.tir.Var("i", "int32"): tvm.ir.Range(8), tvm.tir.Var("j", "int32"): tvm.ir.Range(4), @@ -336,7 +332,8 @@ def test_region_lower_bound_unfusable(): tvm.ir.Range.from_min_extent((i + j) // 2, 1), ] result = tvm.arith.estimate_region_lower_bound(region, var_dom, predicate=True) - assert result is None + assert result[0].min_value == 0 + assert result[0].max_value == 5 def test_union_lower_bound(): diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py index 7bc5ead2984a..6a2fdbbb3f1c 100644 --- a/tests/python/unittest/test_arith_iter_affine_map.py +++ b/tests/python/unittest/test_arith_iter_affine_map.py @@ -61,7 +61,6 @@ def assert_iter_sum_pattern( ) indices = res.indices assert len(indices) == len(keys), res.errors - print(indices) for i, input_iter in enumerate(keys): spec = expect_dict[input_iter] ( @@ -446,6 +445,13 @@ def test_predicate(): predicate=xo * 129 + xi < 128, ) + # strided iteration predicate + assert_iter_sum_pattern( + {xo * 16 + xi * 4: (10, 0, 4)}, + var_dom([(xo, 3), (xi, 4)]), + predicate=xo * 4 + xi < 10, + ) + def convert_division(divisions): if divisions is None or len(divisions) == 0: @@ -1010,5 +1016,55 @@ def test_padding(): assert_iter_sum_failure({flm(x, 16)}, var_dom([(x, 3)])) +def test_overlapped_fuse(): + x = tvm.tir.Var("x", "int32") + y = tvm.tir.Var("y", "int32") + z = tvm.tir.Var("z", "int32") + a = tvm.tir.Var("x", "int32") + b = tvm.tir.Var("y", "int32") + + # non-bijective fuse of two + assert_iter_sum_pattern( + { + x * 7 + y: (22, 0, 1), + }, + var_dom([(x, 3), (y, 8)]), + check_level="surjective", + ) + assert_iter_sum_failure([x * 7 + y], var_dom([(x, 3), (y, 8)]), check_level="bijective") + + # non-bijective fuse of three + assert_iter_sum_pattern( + { + x * 18 + y * 7 + z: (40, 0, 1), + }, + var_dom([(x, 2), (y, 3), (z, 8)]), + check_level="surjective", + ) + assert_iter_sum_failure([x * 7 + y], var_dom([(x, 2), (y, 3), (z, 8)]), 
check_level="bijective") + + # negative scale fusion is not allowed + assert_iter_sum_failure([x * -7 + y], var_dom([(x, 3), (y, 8)]), check_level="surjective") + assert_iter_sum_failure([x * 7 - y], var_dom([(x, 3), (y, 8)]), check_level="surjective") + + # with predicate + assert_iter_sum_pattern( + { + a * 40 + b * 20 + x * 18 + y * 3 + z: (125, 6, 1), + }, + var_dom([(a, 3), (b, 2), (x, 2), (y, 6), (z, 8)]), + predicate=tvm.tir.all(z < 4, 1 < x * 6 + y, x * 6 + y < 10), + check_level="surjective", + ) + + # stride=1 kernel + assert_iter_sum_pattern( + {x + a: (230, 0, 1)}, var_dom([(x, 224), (a, 7)]), check_level="surjective" + ) + + # do not allow both strided and overlapped + assert_iter_sum_failure([5 * x + 2 * y], var_dom([(x, 4), (y, 3)]), check_level="surjective") + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/unittest/test_meta_schedule_space_cpu.py b/tests/python/unittest/test_meta_schedule_space_cpu.py index c4cfc222e42d..f752c4e48f0e 100644 --- a/tests/python/unittest/test_meta_schedule_space_cpu.py +++ b/tests/python/unittest/test_meta_schedule_space_cpu.py @@ -48,11 +48,11 @@ def c1d_0(inputs: T.Buffer[(1, 256, 64), "float32"], weight: T.Buffer[(3, 64, 12 for i0_0, i1_0, i2_0, i0_1_1, i1_1_1, i2_1_1 in T.grid(1, 1, 2, 1, 1, 8): for i3_0, i4_0, i0_2, i1_2, i2_2, i3_1, i4_1, i0_3, i1_3, i2_3 in T.grid(1, 64, 1, 64, 8, 3, 1, 1, 2, 1): with T.block("conv1d_nlc"): - n = T.axis.spatial(1, i0_0 + i0_1_1 + i0_2 + i0_3) - l = T.axis.spatial(128, i1_1_1 * 128 + i1_0 * 128 + i1_2 * 2 + i1_3) - co = T.axis.spatial(128, (i2_0 * 8 + i2_1_1) * 8 + i2_2 + i2_3) + n = T.axis.spatial(1, i0_1_1 + i0_2 + i0_3 + i0_0) + l = T.axis.spatial(128, i1_0 * 128 + i1_1_1 * 128 + i1_2 * 2 + i1_3) + co = T.axis.spatial(128, i2_3 + i2_0 * 64 + i2_1_1 * 8 + i2_2) rl = T.axis.reduce(3, i3_0 * 3 + i3_1) - rc = T.axis.reduce(64, i4_0 + i4_1) + rc = T.axis.reduce(64, i4_1 + i4_0) T.reads(PadInput[n, l * 2 + rl, co // 128 * 64 + rc], weight[rl, rc, co]) 
T.writes(conv1d_nlc_global[n, l, co]) T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) @@ -89,11 +89,11 @@ def c1d_1(inputs: T.Buffer[(1, 256, 64), "float32"], weight: T.Buffer[(3, 64, 12 PadInput[i0, i1, i2] = T.if_then_else(1 <= i1 and i1 < 257, inputs[i0, i1 - 1, i2], T.float32(0), dtype="float32") for i3_0, i4_0, i0_2, i1_2, i2_2, i3_1, i4_1, i0_3, i1_3, i2_3 in T.grid(1, 64, 1, 64, 8, 3, 1, 1, 2, 1): with T.block("conv1d_nlc"): - n = T.axis.spatial(1, i0_0 + i0_1 + i0_2 + i0_3) - l = T.axis.spatial(128, i1_1 * 128 + i1_0 * 128 + i1_2 * 2 + i1_3) - co = T.axis.spatial(128, (i2_0 * 8 + i2_1) * 8 + i2_2 + i2_3) + n = T.axis.spatial(1, i0_1 + i0_2 + i0_3 + i0_0) + l = T.axis.spatial(128, i1_0 * 128 + i1_1 * 128 + i1_2 * 2 + i1_3) + co = T.axis.spatial(128, i2_3 + i2_0 * 64 + i2_1 * 8 + i2_2) rl = T.axis.reduce(3, i3_0 * 3 + i3_1) - rc = T.axis.reduce(64, i4_0 + i4_1) + rc = T.axis.reduce(64, i4_1 + i4_0) T.reads(PadInput[n, l * 2 + rl, co // 128 * 64 + rc], weight[rl, rc, co]) T.writes(conv1d_nlc_global[n, l, co]) T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) @@ -107,7 +107,7 @@ def c1d_1(inputs: T.Buffer[(1, 256, 64), "float32"], weight: T.Buffer[(3, 64, 12 T.reads(conv1d_nlc_global[v0, v1, v2]) T.writes(conv1d_nlc[v0, v1, v2]) conv1d_nlc[v0, v1, v2] = conv1d_nlc_global[v0, v1, v2] - + @T.prim_func def c1d_2(inputs: T.Buffer[(1, 256, 64), "float32"], weight: T.Buffer[(3, 64, 128), "float32"], conv1d_nlc: T.Buffer[(1, 128, 128), "float32"]) -> None: # function attr dict @@ -119,11 +119,11 @@ def c1d_2(inputs: T.Buffer[(1, 256, 64), "float32"], weight: T.Buffer[(3, 64, 12 T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":16, "meta_schedule.vectorize":64}) for i0_0, i1_0, i2_0, i0_1, i1_1, i2_1, i3_0, i4_0, i0_2, i1_2, i2_2, i3_1, i4_1, i0_3, i1_3, i2_3 in T.grid(1, 1, 2, 1, 1, 8, 1, 64, 1, 64, 8, 3, 1, 1, 2, 1): with T.block("conv1d_nlc"): - n = T.axis.spatial(1, i0_0 + i0_1 + i0_2 + i0_3) - l = T.axis.spatial(128, 
i1_1 * 128 + i1_0 * 128 + i1_2 * 2 + i1_3) - co = T.axis.spatial(128, (i2_0 * 8 + i2_1) * 8 + i2_2 + i2_3) + n = T.axis.spatial(1, i0_1 + i0_2 + i0_3 + i0_0) + l = T.axis.spatial(128, i1_0 * 128 + i1_1 * 128 + i1_2 * 2 + i1_3) + co = T.axis.spatial(128, i2_3 + i2_0 * 64 + i2_1 * 8 + i2_2) rl = T.axis.reduce(3, i3_0 * 3 + i3_1) - rc = T.axis.reduce(64, i4_0 + i4_1) + rc = T.axis.reduce(64, i4_1 + i4_0) T.reads(inputs[n, l * 2 + rl - 1, co // 128 * 64 + rc], weight[rl, rc, co]) T.writes(conv1d_nlc[n, l, co]) T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) diff --git a/tests/python/unittest/test_meta_schedule_space_cuda.py b/tests/python/unittest/test_meta_schedule_space_cuda.py index 1ead63b9c115..804c9e22523c 100644 --- a/tests/python/unittest/test_meta_schedule_space_cuda.py +++ b/tests/python/unittest/test_meta_schedule_space_cuda.py @@ -47,7 +47,7 @@ def c1d_0(inputs: T.Buffer[(1, 256, 64), "float32"], weight: T.Buffer[(3, 64, 12 for ax0_ax1_ax2_fused in T.serial(260): with T.block("PadInput_shared"): v0 = T.axis.spatial(1, 0) - v1 = T.axis.spatial(258, i0_0_i1_0_i2_0_fused * 64 + ax0_ax1_ax2_fused % 260 // 4) + v1 = T.axis.spatial(258, i0_0_i1_0_i2_0_fused * 64 + ax0_ax1_ax2_fused // 4) v2 = T.axis.spatial(64, i4_0 * 4 + ax0_ax1_ax2_fused % 4) T.reads(inputs[v0, v1 - 1, v2]) T.writes(PadInput_shared[v0, v1, v2]) @@ -64,11 +64,11 @@ def c1d_0(inputs: T.Buffer[(1, 256, 64), "float32"], weight: T.Buffer[(3, 64, 12 weight_shared[v0, v1, v2] = weight[v0, v1, v2] for i3_1, i4_1, i0_3, i1_3, i2_3, i3_2, i4_2, i0_4, i1_4, i2_4 in T.grid(1, 2, 1, 1, 2, 3, 2, 1, 4, 8): with T.block("conv1d_nlc"): - n = T.axis.spatial(1, i0_4 + i0_3 + 0 + 0 + 0) - l = T.axis.spatial(128, (i0_0_i1_0_i2_0_fused % 4 * 8 + i0_1_i1_1_i2_1_fused % 16 // 2 + 0 + i1_3) * 4 + i1_4) - co = T.axis.spatial(128, (((0 * 2 + i0_1_i1_1_i2_1_fused % 2) * 4 + i0_2_i1_2_i2_2_fused % 4) * 2 + i2_3) * 8 + i2_4) - rl = T.axis.reduce(3, (i3_0 + i3_1) * 3 + i3_2) - rc = T.axis.reduce(64, (i4_0 * 2 + 
i4_1) * 2 + i4_2) + n = T.axis.spatial(1, i0_4 + i0_3) + l = T.axis.spatial(128, i0_0_i1_0_i2_0_fused * 32 + i0_1_i1_1_i2_1_fused // 2 * 4 + i1_3 * 4 + i1_4) + co = T.axis.spatial(128, i0_1_i1_1_i2_1_fused % 2 * 64 + i0_2_i1_2_i2_2_fused * 16 + i2_3 * 8 + i2_4) + rl = T.axis.reduce(3, i3_0 * 3 + i3_1 * 3 + i3_2) + rc = T.axis.reduce(64, i4_0 * 4 + i4_1 * 2 + i4_2) T.reads(PadInput_shared[n, l * 2 + rl, co // 128 * 64 + rc], weight_shared[rl, rc, co]) T.writes(conv1d_nlc_local[n, l, co]) T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"}) diff --git a/tests/python/unittest/test_tir_schedule_reorder.py b/tests/python/unittest/test_tir_schedule_reorder.py index 4351fe5b6361..b859b655efc8 100644 --- a/tests/python/unittest/test_tir_schedule_reorder.py +++ b/tests/python/unittest/test_tir_schedule_reorder.py @@ -214,9 +214,9 @@ def test_reorder_with_opaque_access(): verify_trace_roundtrip(sch=sch, mod=opaque_access) -def test_reorder_with_partial_affineness(): +def test_reorder_overlapped_access(): @T.prim_func - def non_affine_func(A: T.Buffer[(14, 4), "float32"], B: T.Buffer[(14, 4), "float32"]): + def overlapped_access(A: T.Buffer[(14, 4), "float32"], B: T.Buffer[(14, 4), "float32"]): # example to write first axis multiple times for v0, v1, v2 in T.grid(6, 4, 4): with T.block("block"): @@ -225,7 +225,7 @@ def non_affine_func(A: T.Buffer[(14, 4), "float32"], B: T.Buffer[(14, 4), "float B[i, j] = A[i, j] + 1.0 @T.prim_func - def non_affine_func_reorder(A: T.Buffer[(14, 4), "float32"], B: T.Buffer[(14, 4), "float32"]): + def overlapped_access_reorder(A: T.Buffer[(14, 4), "float32"], B: T.Buffer[(14, 4), "float32"]): # example to write first axis multiple times for v0, v2, v1 in T.grid(6, 4, 4): with T.block("block"): @@ -233,6 +233,30 @@ def non_affine_func_reorder(A: T.Buffer[(14, 4), "float32"], B: T.Buffer[(14, 4) j = T.axis.spatial(4, v2) B[i, j] = A[i, j] + 
1.0 + sch = tir.Schedule(overlapped_access, debug_mask="all") + v0, v1, v2 = sch.get_loops(sch.get_block("block")) + sch.reorder(v0, v2, v1) + tvm.ir.assert_structural_equal(overlapped_access_reorder, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=overlapped_access) + + +def test_reorder_with_partial_affineness(): + @T.prim_func + def non_affine_func(A: T.Buffer[(14, 4), "float32"], B: T.Buffer[(14, 4), "float32"]): + for v0, v1, v2 in T.grid(6, 4, 4): + with T.block("block"): + i = T.axis.spatial(14, v0 * v0 + v1) + j = T.axis.spatial(4, v2) + B[i, j] = A[i, j] + 1.0 + + @T.prim_func + def non_affine_func_reorder(A: T.Buffer[(14, 4), "float32"], B: T.Buffer[(14, 4), "float32"]): + for v0, v2, v1 in T.grid(6, 4, 4): + with T.block("block"): + i = T.axis.spatial(14, v0 * v0 + v1) + j = T.axis.spatial(4, v2) + B[i, j] = A[i, j] + 1.0 + sch = tir.Schedule(non_affine_func, debug_mask="all") v0, v1, v2 = sch.get_loops(sch.get_block("block")) with pytest.raises(tvm.tir.ScheduleError): diff --git a/tests/python/unittest/test_tir_schedule_split_fuse.py b/tests/python/unittest/test_tir_schedule_split_fuse.py index 0bfac4e425b9..16bd18c680be 100644 --- a/tests/python/unittest/test_tir_schedule_split_fuse.py +++ b/tests/python/unittest/test_tir_schedule_split_fuse.py @@ -176,7 +176,7 @@ def elementwise_split_case0(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [128, 128, 128]) for i1, i2, i3, j1, j2, k1, k2 in T.grid(2, 1, 64, 4, 32, 16, 8): with T.block("B"): - vi = T.axis.S(128, (i1 + i2) * 64 + i3) + vi = T.axis.S(128, i1 * 64 + i2 * 64 + i3) vj = T.axis.S(128, j1 * 32 + j2) vk = T.axis.S(128, k1 * 8 + k2) T.reads([A[vi, vj, vk]]) @@ -190,9 +190,9 @@ def elementwise_split_case1(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [128, 128, 128]) for i1, i2, i3, j1, j2, j3, k1, k2, k3 in T.grid(2, 1, 64, 2, 1, 64, 2, 1, 64): with T.block("B"): - vi = T.axis.S(128, (i1 + i2) * 64 + i3) - vj = T.axis.S(128, (j1 + j2) * 64 + j3) - vk = T.axis.S(128, (k1 
+ k2) * 64 + k3) + vi = T.axis.S(128, i1 * 64 + i2 * 64 + i3) + vj = T.axis.S(128, j1 * 64 + j2 * 64 + j3) + vk = T.axis.S(128, k1 * 64 + k2 * 64 + k3) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 diff --git a/tests/python/unittest/test_tir_schedule_state_cached_flags.py b/tests/python/unittest/test_tir_schedule_state_cached_flags.py index 1b4c34973f6c..bbeb8d87600b 100644 --- a/tests/python/unittest/test_tir_schedule_state_cached_flags.py +++ b/tests/python/unittest/test_tir_schedule_state_cached_flags.py @@ -758,7 +758,7 @@ def test_non_perfect_tiling_cache(): s = tir.ScheduleState(non_perfect_tiling_cache, debug_mask="all") # pylint: disable=protected-access assert s._get_cached_flags(_get_block(s, "cache")) == CachedFlags( - affine_binding=False, + affine_binding=True, region_cover=True, stage_pipeline=True, )