Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Arith] DetectIterMap support overlapped iteration sum #12039

Merged
merged 1 commit into from
Jul 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 66 additions & 25 deletions src/arith/iter_affine_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,12 @@ class IterMapRewriter : public ExprMutator {
using Parent = ExprMutator;

explicit IterMapRewriter(Analyzer* analyzer, const Map<Var, Range>& input_iters,
bool simplify_trivial_iterators, Array<String>* errors)
: analyzer_(analyzer), errors_(*errors), padding_predicate_(const_false()) {
IterMapLevel check_level, bool simplify_trivial_iterators,
Array<String>* errors)
: analyzer_(analyzer),
check_level_(check_level),
errors_(*errors),
padding_predicate_(const_false()) {
for (auto kv : input_iters) {
const Var& var = kv.first;
const Range& vrng = kv.second;
Expand Down Expand Up @@ -419,6 +423,8 @@ class IterMapRewriter : public ExprMutator {

// Internal analyzer
Analyzer* analyzer_;
// Iter map check level
IterMapLevel check_level_;
// Error messages for each unresolved expression.
Array<String>& errors_;
// The var map
Expand Down Expand Up @@ -651,7 +657,7 @@ class IterMapRewriter : public ExprMutator {
if (predicate_induced_max.defined())
predicate_induced_max = predicate_induced_max.value() - base;
}
Optional<IterSumExpr> opt = TryFuseIters(expr);
Optional<IterSumExpr> opt = TryFuseIters(expr, check_level_);
ICHECK(!opt.defined() || opt.value()->args.size() == 1);
// scale should be 1
if (opt.defined() && is_one(opt.value()->args[0]->scale)) {
Expand Down Expand Up @@ -702,7 +708,7 @@ class IterMapRewriter : public ExprMutator {
IterSumExpr NormalizeToIterWithOffset(IterSumExpr expr) {
// We are normalizing a regular iter
if (expr->args.size() < 1) return expr;
Optional<IterSumExpr> opt = TryFuseIters(expr);
Optional<IterSumExpr> opt = TryFuseIters(expr, check_level_);
if (opt.defined()) {
return opt.value();
} else {
Expand Down Expand Up @@ -735,9 +741,10 @@ class IterMapRewriter : public ExprMutator {
* return a corresponding IterSumExpr with extra offset if needed.
* Try to normalize IterSum into a fused IterMark
* \param expr The input sum.
* \param check_level The check level of the iter mapping.
* \return The sum with the fused IterMark and extra offset if succeed.
*/
Optional<IterSumExpr> TryFuseIters(IterSumExpr expr) {
Optional<IterSumExpr> TryFuseIters(IterSumExpr expr, IterMapLevel check_level) {
// select the iterators in order
std::vector<bool> visited(expr->args.size(), false);
std::vector<IterSplitExpr> flattened_iters, grouped_iters;
Expand All @@ -758,14 +765,42 @@ class IterMapRewriter : public ExprMutator {
}
// check if it can be remapped into a fused pattern.
PrimExpr expected_extra_base = 0;
PrimExpr tail_extent = 0;
PrimExpr expected_scale = base_scale.value();
for (size_t i = 0; i < expr->args.size();) {
// find j such that expr->args[j] has expected scale
size_t j = i == 0 ? base_index : 0;
for (; j < expr->args.size(); ++j) {
if (!visited[j] && analyzer_->CanProveEqual(expr->args[j]->scale, expected_scale)) break;
// find position j such that expr->args[j] matches the expected scale
int j = i == 0 ? base_index : expr->args.size() - 1;

size_t matched_pos = expr->args.size();
PrimExpr matched_scale{nullptr};
bool is_exact_match{false};

for (; j >= 0; --j) {
if (visited[j]) {
continue;
}
const PrimExpr& cur_scale = expr->args[j]->scale;

// for bijective mapping, the matched scale must equal to expected scale
if (analyzer_->CanProveEqual(cur_scale, expected_scale)) {
matched_pos = j;
matched_scale = cur_scale;
is_exact_match = true;
break;
}
if (check_level != IterMapLevel::Bijective && base_scale.value()->value == 1) {
// find the closest scale which is less than or equal to the expected scale
if (analyzer_->CanProveGreaterEqual(expected_scale - cur_scale, 0) &&
analyzer_->CanProveGreaterEqual(cur_scale, 0)) {
if (matched_pos == expr->args.size() ||
analyzer_->CanProveLess(matched_scale - cur_scale, 0)) {
matched_pos = j;
matched_scale = cur_scale;
}
}
}
}
if (j == expr->args.size()) {
if (matched_pos == expr->args.size()) {
return NullOpt;
}
// look for the longest constrained iter started from expr->args[j]
Expand All @@ -775,8 +810,8 @@ class IterMapRewriter : public ExprMutator {
// otherwise we expect the scale of i to be 2*5=10
Optional<IterSumExpr> constraint_to_match;
for (const IterSumExpr& iter : constrained_iters_flattened_) {
if (IterSplitEqual(expr->args[j], iter->args.back(), false)) {
// find a predicate started from expr->args[j]
if (IterSplitEqual(expr->args[matched_pos], iter->args.back(), false)) {
// find a predicate started from match position
if (!constraint_to_match ||
constraint_to_match.value()->args.size() < iter->args.size()) {
constraint_to_match = iter;
Expand All @@ -793,7 +828,7 @@ class IterMapRewriter : public ExprMutator {
size_t k = 0;
for (; k < expr->args.size(); ++k) {
if (!visited[k] && IterSplitEqual(expr->args[k], *it, false)) {
if (analyzer_->CanProveEqual((*it)->scale * expected_scale, expr->args[k]->scale))
if (analyzer_->CanProveEqual((*it)->scale * matched_scale, expr->args[k]->scale))
break;
}
}
Expand All @@ -806,20 +841,25 @@ class IterMapRewriter : public ExprMutator {
auto iter = sum_fuse_map_.find(constraint_to_match.value());
ICHECK(iter != sum_fuse_map_.end());
const IterMarkWithOffset& iter_matched = iter->second;
grouped_iters.emplace_back(iter_matched.mark, expected_scale);
expected_extra_base += iter_matched.offset * expected_scale;
expected_scale *= iter_matched.mark->extent;
grouped_iters.emplace_back(iter_matched.mark, div(matched_scale, base_scale.value()));
expected_extra_base += iter_matched.offset * matched_scale;
if (!is_exact_match) {
tail_extent += expected_scale - matched_scale;
}
expected_scale = matched_scale * iter_matched.mark->extent;
// move forward
i += constraint_to_match.value()->args.size();
} else {
// constraint_to_match not found, skip this iterator
visited[j] = true;
IterSplitExpr arg = expr->args[j];
arg.CopyOnWrite()->scale =
analyzer_->Simplify(div(expr->args[j]->scale, base_scale.value()));
visited[matched_pos] = true;
IterSplitExpr arg = expr->args[matched_pos];
arg.CopyOnWrite()->scale = analyzer_->Simplify(div(arg->scale, base_scale.value()));
flattened_iters.push_back(arg);
grouped_iters.push_back(arg);
expected_scale *= expr->args[j]->extent;
if (!is_exact_match) {
tail_extent += expected_scale - matched_scale;
}
expected_scale = matched_scale * expr->args[matched_pos]->extent;
++i;
}
}
Expand All @@ -843,7 +883,8 @@ class IterMapRewriter : public ExprMutator {
expr->base + expected_extra_base);
} else {
// new iter, form a new mark
IterMark mark = IterMark(structured_form, div(expected_scale, base_scale.value()));
IterMark mark =
IterMark(structured_form, div(expected_scale, base_scale.value()) + tail_extent);
sum_fuse_map_[flattened_form] = IterMarkWithOffset(mark, 0);
flattened_map_[structured_form] = flattened_form;
return IterSumExpr({IterSplitExpr(mark, base_scale.value())},
Expand Down Expand Up @@ -1086,8 +1127,8 @@ IterMapResult DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range
constraints.begin(), constraints.end(),
[](const IterConstraint& a, const IterConstraint& b) { return a.expr_size < b.expr_size; });

IterMapRewriter rewriter(analyzer, constrained_input_iters, simplify_trivial_iterators,
&result->errors);
IterMapRewriter rewriter(analyzer, constrained_input_iters, check_level,
simplify_trivial_iterators, &result->errors);
// Step0.0: rewrite constraints in the order from size-small ones to size-big ones
for (const IterConstraint& constraint : constraints) {
auto res = rewriter.RewriteIterConstraint(constraint.iter, constraint.lower_bound,
Expand Down Expand Up @@ -1281,7 +1322,7 @@ IterSumExpr IterMapRewriter::PreprocessDividend(IterMapExpr dividend, PrimExpr o
} else if (sum->args.size() == 1) {
return sum;
}
auto opt_fused = TryFuseIters(sum);
auto opt_fused = TryFuseIters(sum, check_level_);
if (!opt_fused) {
ErrorLogger(this) << "Dividend " << tvm::PrettyPrint(original_dividend)
<< ", can't be written as a single fused IterSum";
Expand Down
7 changes: 2 additions & 5 deletions tests/python/unittest/test_arith_intset.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,10 +323,6 @@ def do_test_point_access(point, predicates, var_dom, expect):


def test_region_lower_bound_unfusable():
# This test is designed to trigger an error in DetectIterMap,
# resulting from a numerator which required multiple input
# variables. The bug resulted in an exception being thrown,
# rather than a return value of None.
var_dom = {
tvm.tir.Var("i", "int32"): tvm.ir.Range(8),
tvm.tir.Var("j", "int32"): tvm.ir.Range(4),
Expand All @@ -336,7 +332,8 @@ def test_region_lower_bound_unfusable():
tvm.ir.Range.from_min_extent((i + j) // 2, 1),
]
result = tvm.arith.estimate_region_lower_bound(region, var_dom, predicate=True)
assert result is None
assert result[0].min_value == 0
assert result[0].max_value == 5


def test_union_lower_bound():
Expand Down
58 changes: 57 additions & 1 deletion tests/python/unittest/test_arith_iter_affine_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ def assert_iter_sum_pattern(
)
indices = res.indices
assert len(indices) == len(keys), res.errors
print(indices)
for i, input_iter in enumerate(keys):
spec = expect_dict[input_iter]
(
Expand Down Expand Up @@ -446,6 +445,13 @@ def test_predicate():
predicate=xo * 129 + xi < 128,
)

# strided iteration predicate
assert_iter_sum_pattern(
{xo * 16 + xi * 4: (10, 0, 4)},
var_dom([(xo, 3), (xi, 4)]),
predicate=xo * 4 + xi < 10,
)


def convert_division(divisions):
if divisions is None or len(divisions) == 0:
Expand Down Expand Up @@ -1010,5 +1016,55 @@ def test_padding():
assert_iter_sum_failure({flm(x, 16)}, var_dom([(x, 3)]))


def test_overlapped_fuse():
    """Check detection of fused iterators whose strided sum overlaps.

    Overlapped sums (e.g. ``x * 7 + y`` with ``y`` of extent 8) are surjective
    onto their range but not bijective, so they must be accepted only at the
    "surjective" check level and rejected at "bijective".
    """
    x = tvm.tir.Var("x", "int32")
    y = tvm.tir.Var("y", "int32")
    z = tvm.tir.Var("z", "int32")
    # Name hints match the Python identifiers (previously reused "x"/"y",
    # which made error messages and printed expressions misleading).
    a = tvm.tir.Var("a", "int32")
    b = tvm.tir.Var("b", "int32")

    # non-bijective fuse of two
    assert_iter_sum_pattern(
        {
            x * 7 + y: (22, 0, 1),
        },
        var_dom([(x, 3), (y, 8)]),
        check_level="surjective",
    )
    assert_iter_sum_failure([x * 7 + y], var_dom([(x, 3), (y, 8)]), check_level="bijective")

    # non-bijective fuse of three
    assert_iter_sum_pattern(
        {
            x * 18 + y * 7 + z: (40, 0, 1),
        },
        var_dom([(x, 2), (y, 3), (z, 8)]),
        check_level="surjective",
    )
    # Use the three-iterator expression for the bijective failure check as well;
    # reusing the two-iterator form would fail trivially because z is unused.
    assert_iter_sum_failure(
        [x * 18 + y * 7 + z], var_dom([(x, 2), (y, 3), (z, 8)]), check_level="bijective"
    )

    # negative scale fusion is not allowed
    assert_iter_sum_failure([x * -7 + y], var_dom([(x, 3), (y, 8)]), check_level="surjective")
    assert_iter_sum_failure([x * 7 - y], var_dom([(x, 3), (y, 8)]), check_level="surjective")

    # with predicate
    assert_iter_sum_pattern(
        {
            a * 40 + b * 20 + x * 18 + y * 3 + z: (125, 6, 1),
        },
        var_dom([(a, 3), (b, 2), (x, 2), (y, 6), (z, 8)]),
        predicate=tvm.tir.all(z < 4, 1 < x * 6 + y, x * 6 + y < 10),
        check_level="surjective",
    )

    # stride=1 kernel
    assert_iter_sum_pattern(
        {x + a: (230, 0, 1)}, var_dom([(x, 224), (a, 7)]), check_level="surjective"
    )

    # do not allow both strided and overlapped
    assert_iter_sum_failure([5 * x + 2 * y], var_dom([(x, 4), (y, 3)]), check_level="surjective")


if __name__ == "__main__":
    # Run all tests in this file when executed directly.
    tvm.testing.main()
26 changes: 13 additions & 13 deletions tests/python/unittest/test_meta_schedule_space_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,11 @@ def c1d_0(inputs: T.Buffer[(1, 256, 64), "float32"], weight: T.Buffer[(3, 64, 12
for i0_0, i1_0, i2_0, i0_1_1, i1_1_1, i2_1_1 in T.grid(1, 1, 2, 1, 1, 8):
for i3_0, i4_0, i0_2, i1_2, i2_2, i3_1, i4_1, i0_3, i1_3, i2_3 in T.grid(1, 64, 1, 64, 8, 3, 1, 1, 2, 1):
with T.block("conv1d_nlc"):
n = T.axis.spatial(1, i0_0 + i0_1_1 + i0_2 + i0_3)
l = T.axis.spatial(128, i1_1_1 * 128 + i1_0 * 128 + i1_2 * 2 + i1_3)
co = T.axis.spatial(128, (i2_0 * 8 + i2_1_1) * 8 + i2_2 + i2_3)
n = T.axis.spatial(1, i0_1_1 + i0_2 + i0_3 + i0_0)
l = T.axis.spatial(128, i1_0 * 128 + i1_1_1 * 128 + i1_2 * 2 + i1_3)
co = T.axis.spatial(128, i2_3 + i2_0 * 64 + i2_1_1 * 8 + i2_2)
rl = T.axis.reduce(3, i3_0 * 3 + i3_1)
rc = T.axis.reduce(64, i4_0 + i4_1)
rc = T.axis.reduce(64, i4_1 + i4_0)
T.reads(PadInput[n, l * 2 + rl, co // 128 * 64 + rc], weight[rl, rc, co])
T.writes(conv1d_nlc_global[n, l, co])
T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
Expand Down Expand Up @@ -89,11 +89,11 @@ def c1d_1(inputs: T.Buffer[(1, 256, 64), "float32"], weight: T.Buffer[(3, 64, 12
PadInput[i0, i1, i2] = T.if_then_else(1 <= i1 and i1 < 257, inputs[i0, i1 - 1, i2], T.float32(0), dtype="float32")
for i3_0, i4_0, i0_2, i1_2, i2_2, i3_1, i4_1, i0_3, i1_3, i2_3 in T.grid(1, 64, 1, 64, 8, 3, 1, 1, 2, 1):
with T.block("conv1d_nlc"):
n = T.axis.spatial(1, i0_0 + i0_1 + i0_2 + i0_3)
l = T.axis.spatial(128, i1_1 * 128 + i1_0 * 128 + i1_2 * 2 + i1_3)
co = T.axis.spatial(128, (i2_0 * 8 + i2_1) * 8 + i2_2 + i2_3)
n = T.axis.spatial(1, i0_1 + i0_2 + i0_3 + i0_0)
l = T.axis.spatial(128, i1_0 * 128 + i1_1 * 128 + i1_2 * 2 + i1_3)
co = T.axis.spatial(128, i2_3 + i2_0 * 64 + i2_1 * 8 + i2_2)
rl = T.axis.reduce(3, i3_0 * 3 + i3_1)
rc = T.axis.reduce(64, i4_0 + i4_1)
rc = T.axis.reduce(64, i4_1 + i4_0)
T.reads(PadInput[n, l * 2 + rl, co // 128 * 64 + rc], weight[rl, rc, co])
T.writes(conv1d_nlc_global[n, l, co])
T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
Expand All @@ -107,7 +107,7 @@ def c1d_1(inputs: T.Buffer[(1, 256, 64), "float32"], weight: T.Buffer[(3, 64, 12
T.reads(conv1d_nlc_global[v0, v1, v2])
T.writes(conv1d_nlc[v0, v1, v2])
conv1d_nlc[v0, v1, v2] = conv1d_nlc_global[v0, v1, v2]

@T.prim_func
def c1d_2(inputs: T.Buffer[(1, 256, 64), "float32"], weight: T.Buffer[(3, 64, 128), "float32"], conv1d_nlc: T.Buffer[(1, 128, 128), "float32"]) -> None:
# function attr dict
Expand All @@ -119,11 +119,11 @@ def c1d_2(inputs: T.Buffer[(1, 256, 64), "float32"], weight: T.Buffer[(3, 64, 12
T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":16, "meta_schedule.vectorize":64})
for i0_0, i1_0, i2_0, i0_1, i1_1, i2_1, i3_0, i4_0, i0_2, i1_2, i2_2, i3_1, i4_1, i0_3, i1_3, i2_3 in T.grid(1, 1, 2, 1, 1, 8, 1, 64, 1, 64, 8, 3, 1, 1, 2, 1):
with T.block("conv1d_nlc"):
n = T.axis.spatial(1, i0_0 + i0_1 + i0_2 + i0_3)
l = T.axis.spatial(128, i1_1 * 128 + i1_0 * 128 + i1_2 * 2 + i1_3)
co = T.axis.spatial(128, (i2_0 * 8 + i2_1) * 8 + i2_2 + i2_3)
n = T.axis.spatial(1, i0_1 + i0_2 + i0_3 + i0_0)
l = T.axis.spatial(128, i1_0 * 128 + i1_1 * 128 + i1_2 * 2 + i1_3)
co = T.axis.spatial(128, i2_3 + i2_0 * 64 + i2_1 * 8 + i2_2)
rl = T.axis.reduce(3, i3_0 * 3 + i3_1)
rc = T.axis.reduce(64, i4_0 + i4_1)
rc = T.axis.reduce(64, i4_1 + i4_0)
T.reads(inputs[n, l * 2 + rl - 1, co // 128 * 64 + rc], weight[rl, rc, co])
T.writes(conv1d_nlc[n, l, co])
T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
Expand Down
12 changes: 6 additions & 6 deletions tests/python/unittest/test_meta_schedule_space_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def c1d_0(inputs: T.Buffer[(1, 256, 64), "float32"], weight: T.Buffer[(3, 64, 12
for ax0_ax1_ax2_fused in T.serial(260):
with T.block("PadInput_shared"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial(258, i0_0_i1_0_i2_0_fused * 64 + ax0_ax1_ax2_fused % 260 // 4)
v1 = T.axis.spatial(258, i0_0_i1_0_i2_0_fused * 64 + ax0_ax1_ax2_fused // 4)
v2 = T.axis.spatial(64, i4_0 * 4 + ax0_ax1_ax2_fused % 4)
T.reads(inputs[v0, v1 - 1, v2])
T.writes(PadInput_shared[v0, v1, v2])
Expand All @@ -64,11 +64,11 @@ def c1d_0(inputs: T.Buffer[(1, 256, 64), "float32"], weight: T.Buffer[(3, 64, 12
weight_shared[v0, v1, v2] = weight[v0, v1, v2]
for i3_1, i4_1, i0_3, i1_3, i2_3, i3_2, i4_2, i0_4, i1_4, i2_4 in T.grid(1, 2, 1, 1, 2, 3, 2, 1, 4, 8):
with T.block("conv1d_nlc"):
n = T.axis.spatial(1, i0_4 + i0_3 + 0 + 0 + 0)
l = T.axis.spatial(128, (i0_0_i1_0_i2_0_fused % 4 * 8 + i0_1_i1_1_i2_1_fused % 16 // 2 + 0 + i1_3) * 4 + i1_4)
co = T.axis.spatial(128, (((0 * 2 + i0_1_i1_1_i2_1_fused % 2) * 4 + i0_2_i1_2_i2_2_fused % 4) * 2 + i2_3) * 8 + i2_4)
rl = T.axis.reduce(3, (i3_0 + i3_1) * 3 + i3_2)
rc = T.axis.reduce(64, (i4_0 * 2 + i4_1) * 2 + i4_2)
n = T.axis.spatial(1, i0_4 + i0_3)
l = T.axis.spatial(128, i0_0_i1_0_i2_0_fused * 32 + i0_1_i1_1_i2_1_fused // 2 * 4 + i1_3 * 4 + i1_4)
co = T.axis.spatial(128, i0_1_i1_1_i2_1_fused % 2 * 64 + i0_2_i1_2_i2_2_fused * 16 + i2_3 * 8 + i2_4)
rl = T.axis.reduce(3, i3_0 * 3 + i3_1 * 3 + i3_2)
rc = T.axis.reduce(64, i4_0 * 4 + i4_1 * 2 + i4_2)
T.reads(PadInput_shared[n, l * 2 + rl, co // 128 * 64 + rc], weight_shared[rl, rc, co])
T.writes(conv1d_nlc_local[n, l, co])
T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"})
Expand Down
Loading