Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Arith] Merge surjective/non-surjective iter mapping detections #11287

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
- check incompatible left paddings
- determine case like x % 16, x in [0, 5) to be non-surjective, since usages may treat the region extent as 16 by mistake.
- skip second round of rewrite when there is no padding
- fix some typos in comments
  • Loading branch information
wrongtest-intellif committed May 31, 2022
commit 8d46bb5a23ded3baa91a6cd3f021247041c1a612
7 changes: 5 additions & 2 deletions include/tvm/arith/iter_affine_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -308,8 +308,11 @@ class IterMapResultNode : public Object {
*/
class IterMapResult : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(IterMapResult, ObjectRef, IterMapResultNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(IterMapResultNode);
// constructor
IterMapResult() { data_ = make_object<IterMapResultNode>(); }

/*! \return mutable pointers to the node. */
IterMapResultNode* operator->() const { return static_cast<IterMapResultNode*>(get_mutable()); }
};

/*!
Expand Down
152 changes: 92 additions & 60 deletions src/arith/iter_affine_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -199,16 +199,17 @@ class IterMapRewriter : public ExprMutator {
}

PrimExpr padding_predicate() const { return padding_predicate_; }
PrimExpr requires_padding() const { return !analyzer_->CanProveEqual(padding_predicate_, 0); }
bool requires_padding() const { return requires_padding_; }

IterSumExpr Rewrite(const PrimExpr& expr) {
return NormalizeToIterWithOffset(ToIterSumExpr(DirectMutate(expr)));
}

void UpdatePadding(const PrimExpr& expr) {
IterSumExpr RewriteAndUpdatePadding(const PrimExpr& expr) {
update_iterator_padding_ = true;
DirectMutate(expr);
auto res = Rewrite(expr);
update_iterator_padding_ = false;
return res;
}

IterSumExpr RewriteIterConstraint(const PrimExpr& expr,
Expand Down Expand Up @@ -425,27 +426,30 @@ class IterMapRewriter : public ExprMutator {
// input iter marks
std::vector<IterMark> input_marks_;

// Map from a split iter to the padded iterator information for
// Map from an iter mark to the padded iterator information for
// it. This is necessary for introducing the same padding in all
// usage of an input iterator. (e.g. (i-1) occurring in the
// expressions [(i-1)%8, ((i-1)//8)%4, (i-1)//32] should be
// left-padded by 31 for each occurrence.)
std::unordered_map<IterMark, IterPaddingInfo, StructuralHash, StructuralEqual> padded_iter_map_;

/* If allow_padding_ is true, allow the extents of the IterMap to be
// Map from a padded iter mark to its original mark
std::unordered_map<IterMark, IterMark, StructuralHash, StructuralEqual> padded_origin_map_;

/* If update_iterator_padding_ is true, allow the extents of the IterMap to be
* padded beyond the original iterators.
*
* For example, if allow_padding_ is true, the expressions i//4 and
* For example, if update_iterator_padding_ is true, the expressions i//4 and
* i%4, where i is on the range [0,18), would be represented as
* IterSplit(i, lower_factor=4, extent=5) and IterSplit(i, extent=4).
* This representation would be forbidden if allow_padding_ is false,
* This representation would be forbidden if update_iterator_padding_ is false,
* because lower_factor=4 does not evenly divide the original extent of
* 18.
*/
bool update_iterator_padding_{false};

/* A boolean expression that is true for any padding that has been
* introduced, and false otherwise. If allow_padding_ is false,
* introduced, and false otherwise. If update_iterator_padding_ is false,
* padding_predicate_ will always be false.
*
* Example: [i//4, i%4], i in range [0,16)
Expand All @@ -459,6 +463,11 @@ class IterMapRewriter : public ExprMutator {
*/
PrimExpr padding_predicate_;

/* A boolean flag denoting whether padding iterations were detected
* in the first round of index rewriting.
*/
bool requires_padding_{false};

// The map for sum that maps flattened form to IterMark with normal form and extent (and possibly
// an extra offset)
// Example(1): expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2)
Expand Down Expand Up @@ -488,25 +497,6 @@ class IterMapRewriter : public ExprMutator {
// The flattened forms of constrained iters
std::vector<IterSumExpr> constrained_iters_flattened_;

/*!
* \brief Extract the original iteration mark's extent before padding, return NullOpt if
* there is no extra padding.
*
* A mark is reported as padded only when its source is an IterSumExpr with exactly one
* split, that split has scale 1 and lower_factor 1, the split's extent equals the extent
* of the split's own source mark, and the split's extent cannot be proven equal to the
* extent of \p mark.
*
* \param mark The iteration mark to inspect.
* \param analyzer The analyzer used to prove the equalities above.
* \return The extent of the single inner split when the pattern matches, otherwise NullOpt.
*/
Optional<PrimExpr> ExtractExtentBeforePadding(const IterMark& mark, Analyzer* analyzer) {
const IterSumExprNode* sum = mark->source.as<IterSumExprNode>();
// Only a source that is a sum of exactly one split can match the pattern.
if (!sum || sum->args.size() != 1) {
return NullOpt;
}
IterSplitExpr split = sum->args[0];
// The split must be a plain (scale 1, lower_factor 1) view over its whole source mark,
// and its extent must not be provably equal to the extent of the outer mark.
if (!analyzer->CanProveEqual(split->extent, mark->extent) &&
analyzer->CanProveEqual(split->scale, 1) &&
analyzer->CanProveEqual(split->lower_factor, 1) &&
analyzer->CanProveEqual(split->source->extent, split->extent)) {
return sum->args[0]->extent;
}
return NullOpt;
}

/*!
* \brief Look for a split in splits that is not used such that its lower_factor is smallest.
* Note that here we use division to compare lower_factor.
Expand Down Expand Up @@ -580,9 +570,9 @@ class IterMapRewriter : public ExprMutator {
expected_lower_factor = splits[j]->lower_factor * splits[j]->extent;
}

// Extract padding info of the iteration mark, extent before padding
// is only defined when padding exists.
Optional<PrimExpr> extent_before_padding = ExtractExtentBeforePadding(mark, analyzer_);
// Extract iteration mark info before padding
auto pad_mark_it = padded_origin_map_.find(mark);
bool has_padding = pad_mark_it != padded_origin_map_.end();

bool match_full_iter = analyzer_->CanProveEqual(expected_lower_factor, mark->extent);
bool match_iter_divisor =
Expand All @@ -598,33 +588,46 @@ class IterMapRewriter : public ExprMutator {
//
// Case 3. bijective is not required and there exists padding. We check either
// (3.1) The extent we calculate is consistent with the extent of the padded mark and it is
// the single split for the iter mark.
// the single split for the iter mark.
// For example, padded iter p in [0, 24), [(p / 12)] is valid because it is surjective
// according to how we pad the original iteration mark.
// (3.2) The extent we calculate is a factor of the extent of the padded mark, and the extent
// before padding is greater or equal than the extent we calculate.
// For example, padded iter p in [0, 24), the original extent is 14, [(p % 12)] is
// valid.
// before padding is greater than or equal to the extent we calculate.
// For example, the original extent is 14, [(p % 12)] is valid, with p padded to 24.
//
if (check_level == IterMapLevel::Bijective) {
if (extent_before_padding.defined() || !match_full_iter) {
return Array<IterSplitExpr>();
if (has_padding) {
ErrorLogger(this) << "Bijectvie mapping should not take iter paddings";
return {};
} else if (!match_full_iter) {
ErrorLogger(this) << "The iterations do not traverse full iter space";
return {};
}
} else if (!extent_before_padding.defined()) {
} else if (!has_padding) {
if (!match_iter_divisor) {
return Array<IterSplitExpr>();
ErrorLogger(this) << "The lower factor is not divisible by the full iter space extent";
return {};
}
} else if (check_level == IterMapLevel::Surjective) {
PrimExpr extent_before_padding = pad_mark_it->second->extent;
if (match_full_iter) {
if (splits.size() != 1) {
ErrorLogger(this) << "Dependent iterations on padding iter space";
return Array<IterSplitExpr>();
} else if (analyzer_->CanProveEqual(splits[0]->extent, expected_lower_factor) &&
!analyzer_->CanProve(extent_before_padding >= expected_lower_factor)) {
ErrorLogger(this) << "Split on padding iteration is not surjective "
<< "if the split extent equals to the full iter space extent";
return Array<IterSplitExpr>();
}
} else if (match_iter_divisor) {
if (!analyzer_->CanProve(extent_before_padding.value() >= expected_lower_factor)) {
if (!analyzer_->CanProve(extent_before_padding >= expected_lower_factor)) {
ErrorLogger(this) << "The extent before padding is less than lower factor";
return Array<IterSplitExpr>();
}
} else {
return Array<IterSplitExpr>();
ErrorLogger(this) << "The lower factor is not divisible by the full iter space extent";
return {};
}
}
return Array<IterSplitExpr>(iters.rbegin(), iters.rend());
Expand Down Expand Up @@ -1056,22 +1059,21 @@ bool IterRangeSanityCheck(const Map<Var, Range>& iter_ranges) {
IterMapResult DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
const PrimExpr& predicate, IterMapLevel check_level,
arith::Analyzer* analyzer, bool simplify_trivial_iterators) {
IterMapResult result_obj = IterMapResult(make_object<IterMapResultNode>());
auto result = result_obj.CopyOnWrite();
IterMapResult result;

// Overall detection algorithm is divided into two steps:
// - Step0: IterMapRewriter rewrites the expression to use IterMapExpr patterns.
// - Step1: IterIndependenceChecker checks if the iterators are independent.
if (!IterRangeSanityCheck(input_iters)) {
result->errors.push_back("Invalid iterators. Iterators may not be expressions of each other.");
return result_obj;
return result;
}
Map<Var, Range> constrained_input_iters = input_iters;
std::vector<IterConstraint> constraints;
if (!is_one(predicate) &&
!MatchBoundConstraints(predicate, &constrained_input_iters, &constraints)) {
result->errors.push_back("Could not parse predicate as constraints on the input iterators.");
return result_obj;
return result;
}
// We have to make sure when we visit an iterator, all the constraints related to its successors
// in the iter var graph have been visited, where the expression of this iterator will contain the
Expand All @@ -1090,32 +1092,39 @@ IterMapResult DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range
for (const IterConstraint& constraint : constraints) {
auto res = rewriter.RewriteIterConstraint(constraint.iter, constraint.lower_bound,
constraint.upper_bound);
if (result->errors.size()) {
return result_obj;
if (result->errors.size() > 0) {
return result;
}
}
if (!rewriter.CheckConstraints()) {
result->errors.push_back("Invalid constraints.");
return result_obj;
return result;
}

// Step0.1: Check each index to determine required padding
// Step0.1: Rewrite indices and determine required padding,
// if there is no padding, it should be the final result.
Array<IterSumExpr> rewrite_indices;
rewrite_indices.reserve(indices.size());
bool allow_padding = check_level != IterMapLevel::Bijective;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would enable padding for IterMapLevel::Surjective, which I don't think is correct. Since padding is any output value for which no input value exists, any introduction of padding wouldn't be surjective.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is the claim~ I try to change padding to iter mark itself.

For example, (x + 7), x in [0, 8) => IterMark(IterSplit(IterSum({x}, 7), lower_factor=1, extent=16, scale=1), extent=16), with left_pad=7, right_pad=1

Then (x + 7) // 8 is mapped to range [0, extent//8) == [0, 2). Though we have padding in the iter mark, the IterSplit's range can be achieved when we only iterate x in its original domain: (0 + 7) // 8 = 0, (7 + 7) // 8 = 1

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, and that does maintain surjectivity for a single index. I'm not entirely sure for the case of two indices, though. For the same x ∈ [0,8), the indices [(x+7)//8, (x+7)%8] would have the same padding left_pad=7 and right_pad=1. Even though each individual index can take any value in the output ((x+7)//8 ∈[0,2) and (x+7)%8 ∈ [0,8)), there are some coordinate pairs that cannot be generated for any value of x (e.g. [0,0] and [1,7]).

Copy link
Contributor Author

@wrongtest-intellif wrongtest-intellif May 25, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree! This is where we should be careful. In CheckMapping with surjective mode, when padding exists, we check that padded // LCM and padded % LCM (or its sub-splits) do not both exist. The case below depicts this check:

sum = 80 + y
dom_map = var_dom([(y, 176)])

# (80 + y) // 32 itself could be surjective
assert_iter_sum_pattern(
    {fld(sum, 32): (6, 2, 1)},
    dom_map,
 )

# (80 + y) % 2, ((80 + y) // 2) % 16 could be surjective,
# since they can be seen as sub-splits of (80 + y) % 32
assert_iter_sum_pattern(
    {flm(fld(sum, 2), 16): (16, 0, 1), flm(sum, 2): (2, 0, 1)},
    dom_map,
)

# but (80 + y) // 32, (80 + y) % 32 are not surjective
assert_iter_sum_failure({fld(sum, 32), flm(sum, 32)}, dom_map)

Other kinds of negatives like (80 + y) // 32, (80 + y) // 4 would be banned by existing checking rule.

if (allow_padding) {
for (PrimExpr value : indices) {
rewriter.UpdatePadding(value);
rewrite_indices.push_back(rewriter.RewriteAndUpdatePadding(value));
if (result->errors.size() > 0) {
return result;
}
}
}

// Step0.2: rewrite indices
Array<IterSumExpr> rewrite_indices;
for (PrimExpr value : indices) {
rewrite_indices.push_back(rewriter.Rewrite(value));
if (result->errors.size()) {
return result_obj;
// Step0.2: Rewrite indices in the second round.
if (!allow_padding || rewriter.requires_padding()) {
rewrite_indices.clear();
for (PrimExpr value : indices) {
rewrite_indices.push_back(rewriter.Rewrite(value));
if (result->errors.size() > 0) {
return result;
}
}
}

result->padding_predicate = rewriter.padding_predicate();

// Step1: IterIndependenceChecker checks if the iterators are independent.
Expand All @@ -1125,10 +1134,10 @@ IterMapResult DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range
} else {
result->errors.push_back("Mapped indices are not independent.");
}
return result_obj;
return result;
}
result->indices = rewrite_indices;
return result_obj;
return result;
}

TVM_REGISTER_GLOBAL("arith.DetectIterMap")
Expand Down Expand Up @@ -1304,6 +1313,10 @@ PrimExpr ApproxLeastCommonMultiple(const PrimExpr& a, const PrimExpr& b, Analyze
auto const_lcm = Integer(LeastCommonMultiple(p1.second, p2.second));
if (analyzer->CanProveEqual(p1.first, p2.first)) {
return p1.first * const_lcm;
} else if (analyzer->CanProveEqual(floormod(p1.first, p2.first), 0)) {
return p1.first * const_lcm;
} else if (analyzer->CanProveEqual(floormod(p2.first, p1.first), 0)) {
return p2.first * const_lcm;
} else {
return (p1.first * p2.first) * const_lcm;
}
Expand Down Expand Up @@ -1348,6 +1361,9 @@ std::pair<IterSplitExpr, PrimExpr> IterMapRewriter::PadDividendToDivisor(IterSpl
}

// Update padding requirement on the lower side of the source iter mark.
// In the second pass, all splits would check whether the maximum left padding
// on the iter mark is compatible with its own left padding.
requires_padding_ = true;
PrimExpr mark_left_pad = left_pad * split->lower_factor;
info.left_pad = max(info.left_pad, mark_left_pad);
wrongtest-intellif marked this conversation as resolved.
Show resolved Hide resolved

Expand Down Expand Up @@ -1381,6 +1397,22 @@ std::pair<IterSplitExpr, PrimExpr> IterMapRewriter::PadDividendToDivisor(IterSpl
if (!info.padded.defined()) {
// The first time we encounter an iter mark to pad, update the padded mark.
PrimExpr mark_left_pad = info.left_pad;
if (CanProveDivisible(mark_left_pad, split->lower_factor)) {
// correct current split's left padding
// (mark_left_pad + iter) // lower_factor % extent =>
// (left_pad * lower_factor + mark) // lower_factor % extent =>
// (left_pad + mark // lower_factor) % extent =>
// left_pad + (mark // lower_factor % extent) =>
// left_pad + split
// since the extent covers the full padding range.
left_pad = floordiv(mark_left_pad, split->lower_factor);
} else {
ErrorLogger(this) << "Detect incompatible left padding on "
<< tvm::PrettyPrint(NormalizeIterMapToExpr(split))
<< ", the iter mark is left padded with " << mark_left_pad;
return {IterSplitExpr(), PrimExpr()};
}

PrimExpr right_edge = mark->extent + mark_left_pad;
PrimExpr mark_right_pad;
if (CanProveDivisible(right_edge, info.padding_factor)) {
Expand All @@ -1391,6 +1423,7 @@ std::pair<IterSplitExpr, PrimExpr> IterMapRewriter::PadDividendToDivisor(IterSpl
PrimExpr padded_extent = analyzer_->Simplify(right_edge + mark_right_pad);
info.right_pad = mark_right_pad;
info.padded = IterMark(IterSumExpr({IterSplitExpr(mark)}, mark_left_pad), padded_extent);
padded_origin_map_[info.padded] = mark;

auto left_padding_introduced = (mark_left_pad != 0);

Expand Down Expand Up @@ -1557,7 +1590,6 @@ PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr base, P

// We handle scale!=1 in above code, hence we only consider floormod(x, rhs) below
// where x=floormod(floordiv(iter, lower_factor), extent) + base

auto pair = PadDividendToDivisor(lhs, base, rhs);
IterSplitExpr padded = pair.first;
if (!padded.defined()) {
Expand Down Expand Up @@ -1676,7 +1708,7 @@ bool IterMapRewriter::CanProveDivisible(const PrimExpr& lhs, const PrimExpr& rhs
PrimExpr divisor = normalizer.Convert(rhs);

return analyzer_->CanProveEqual(dividend, divisor) ||
analyzer_->CanProve(analyzer_->Simplify(floormod(dividend, divisor), 8) == 0);
analyzer_->CanProve(floormod(dividend, divisor) == 0);
}

PrimExpr NormalizeIterMapToExpr(const PrimExpr& expr) {
Expand Down
Loading