[GPU] Use affine.delinearize_index for MMA tiles and vector distribution #19228

Merged · 8 commits · Jan 28, 2025
Changes from 3 commits
@@ -32,42 +32,6 @@ namespace mlir::iree_compiler {
using namespace mlir::iree_compiler::IREE::VectorExt;
using VectorValue = TypedValue<VectorType>;

-/// Helper to linearize the given |ids| with maximum values given as |sizes|.
-/// Gets the element ID in terms of |elementCount| and adds the element
-/// |offset|. For example,
-///
-/// IDs = [d0, d1, d2, d3]
-/// sizes = [s0, s1, s2, s3]
-/// linear_index = d0 * (s1 * s2 * s3)
-///              + d1 * (s2 * s3)
-///              + d2 * (s3)
-///              + d3
-/// return element_index = linear_index * |elementCount| + |offset|;
-static Value linearizeIndex(OpBuilder &builder, Value offset,
-                            ArrayRef<OpFoldResult> ids, ArrayRef<int64_t> sizes,
-                            int64_t elementCount) {
-  SmallVector<AffineExpr> exprs(ids.size() + 1);
-  bindSymbolsList(builder.getContext(), MutableArrayRef{exprs});
-  AffineExpr idExpr = builder.getAffineConstantExpr(0);
-
-  for (int i = 0, e = ids.size(); i < e; ++i) {
-    if (sizes[i] > 1) {
-      // Multiply by the residual threads along this dimension (which must be
-      // faster changing than all previous dimensions) and add the id for this
-      // dimension.
-      idExpr = idExpr * builder.getAffineConstantExpr(sizes[i]) + exprs[i];
-    }
-  }
-  idExpr = idExpr * builder.getAffineConstantExpr(elementCount);
-  idExpr = idExpr + exprs.back();
-  SmallVector<OpFoldResult> mapArgs(ids);
-  mapArgs.push_back(offset);
-  return affine::makeComposedAffineApply(
-             builder, offset.getLoc(),
-             AffineMap::get(0, mapArgs.size(), idExpr), mapArgs)
-      .getResult();
-}
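
The removed helper above folded this arithmetic into one composed affine.apply. For intuition, here is a minimal standalone sketch of the same row-major linearization (the `linearize` function name and all values are made up for illustration, not taken from the patch):

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Row-major linearization as described in the removed comment:
//   linear_index  = d0*(s1*s2*s3) + d1*(s2*s3) + d2*s3 + d3
//   element_index = linear_index * elementCount + offset
int64_t linearize(const std::vector<int64_t> &ids,
                  const std::vector<int64_t> &sizes, int64_t elementCount,
                  int64_t offset) {
  int64_t linear = 0;
  for (size_t i = 0; i < ids.size(); ++i)
    linear = linear * sizes[i] + ids[i];  // Horner form of the sum above.
  return linear * elementCount + offset;
}

int main() {
  // Made-up values: ids = [1, 0, 2, 3], sizes = [2, 4, 3, 8],
  // elementCount = 4, offset = 1.
  // linear = ((1*4 + 0)*3 + 2)*8 + 3 = 115; element index = 115*4 + 1 = 461.
  std::printf("%lld\n", (long long)linearize({1, 0, 2, 3}, {2, 4, 3, 8}, 4, 1));
}
```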

/// Given a set of base transfer |indices|, |offsets| for the batch/outer
/// dimensions, and distributed warp and thread indices, computes the indices
/// of the distributed transfer operation based on the |vectorLayout|.
@@ -94,16 +58,28 @@ static SmallVector<Value> getTransferIndicesFromNestedLayout(
continue;
}
unsigned pos = cast<AffineDimExpr>(dim).getPosition();
-    SmallVector<OpFoldResult> ids = {
-        warpIndices[i], b.getIndexAttr(batchOffsets[i]),
-        b.getIndexAttr(outerVectorOffsets[i]), threadIndices[i]};
+    Value offset = indices[pos];
+    int64_t elementCount = vectorLayout.getElementTile()[i];
+    Location loc = offset.getLoc();
+    SmallVector<Value> ids = {
+        warpIndices[i], b.create<arith::ConstantIndexOp>(loc, batchOffsets[i]),
+        b.create<arith::ConstantIndexOp>(loc, outerVectorOffsets[i]),
+        threadIndices[i], offset};
// The order in which a vector dimension is "tiled" is
// subgroups -> batches -> outer vectors -> threads -> elements
SmallVector<int64_t> sizes = {
vectorLayout.getSubgroupTile()[i], vectorLayout.getBatchTile()[i],
-        vectorLayout.getOuterTile()[i], vectorLayout.getThreadTile()[i]};
-    slicedIndices[pos] = linearizeIndex(b, indices[pos], ids, sizes,
-                                        vectorLayout.getElementTile()[i]);
+        vectorLayout.getOuterTile()[i], vectorLayout.getThreadTile()[i],
+        elementCount};
+    // The offset is often not an offset within `elementCount`, so, in general,
+    // we can't mark this `disjoint`. However, if `offset` is known to be
+    // a constant less than `elementCount`, we can do this, unlocking
+    // potential optimizations.
+    bool disjoint = false;
+    if (std::optional<int64_t> offsetConst = getConstantIntValue(offset))
+      disjoint = *offsetConst < elementCount;
+    slicedIndices[pos] =
+        b.create<affine::AffineLinearizeIndexOp>(loc, ids, sizes, disjoint);
}
return slicedIndices;
}
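
In the new lowering, the same linearization is delegated to a single affine.linearize_index op, with the base `offset` appended as the innermost operand against a basis entry of `elementCount`. The op may only be marked `disjoint` when every operand is known to stay within its basis entry, which is what the constant check in the hunk above guards. A rough standalone model of that decision (the `canMarkDisjoint` helper is hypothetical and uses plain std::optional instead of MLIR types):

```cpp
#include <cstdint>
#include <optional>

// Models the `disjoint` decision from the patch: the subgroup, batch, outer,
// and thread ids produced by the layout are already within their tile sizes,
// so the only operand in doubt is the base transfer offset. It qualifies only
// when it is a constant smaller than the element tile size.
bool canMarkDisjoint(std::optional<int64_t> constantOffset,
                     int64_t elementCount) {
  return constantOffset.has_value() && *constantOffset < elementCount;
}

// canMarkDisjoint(2, 4)            -> true  (constant 2 lies in [0, 4))
// canMarkDisjoint(7, 4)            -> false (constant, but out of range)
// canMarkDisjoint(std::nullopt, 4) -> false (dynamic base offset)
```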
@@ -123,19 +99,21 @@ getElementVectorTileShape(NestedLayoutAttr vectorLayout) {

/// Computes the warp and thread indices for the given vector layout from a
/// single linearized thread ID.
-static void populateWarpAndThreadIndices(RewriterBase &rewriter, Value threadId,
-                                         int64_t subgroupSize,
-                                         NestedLayoutAttr vectorLayout,
-                                         SmallVector<Value> &warpIndices,
-                                         SmallVector<Value> &threadIndices) {
+static LogicalResult populateWarpAndThreadIndices(
+    RewriterBase &rewriter, Value threadId, int64_t subgroupSize,
+    NestedLayoutAttr vectorLayout, SmallVector<Value> &warpIndices,
+    SmallVector<Value> &threadIndices) {
  // The delinearized thread IDs are returned from outermost to innermost,
  // i.e. before the dimension ordering described by the layout is applied.
int64_t rank = vectorLayout.getRank();
SmallVector<Value> threadIds =
vectorLayout.computeThreadIds(threadId, subgroupSize, rewriter);
+  if (threadIds.empty() && rank != 0)
+    return failure();
warpIndices = SmallVector<Value>(threadIds.begin(), threadIds.begin() + rank);
threadIndices = SmallVector<Value>(threadIds.begin() + rank,
threadIds.begin() + 2 * rank);
+  return success();
}
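
populateWarpAndThreadIndices splits the ids returned by computeThreadIds into warp ids (the first `rank` entries) and thread ids (the next `rank` entries), and now reports failure when computeThreadIds returns an empty vector for a non-zero-rank layout. As a generic illustration of the "outermost to innermost" convention mentioned in the comment, this is ordinary mixed-radix delinearization (the `delinearize` helper and the basis values below are illustrative only, not the actual computeThreadIds implementation):

```cpp
#include <cstdint>
#include <vector>

// Decompose a flat id into digits for the given basis, most-significant
// (outermost) digit first.
std::vector<int64_t> delinearize(int64_t flatId,
                                 const std::vector<int64_t> &basis) {
  std::vector<int64_t> digits(basis.size());
  for (int64_t i = static_cast<int64_t>(basis.size()) - 1; i >= 0; --i) {
    digits[i] = flatId % basis[i];
    flatId /= basis[i];
  }
  return digits;
}

// Example with a made-up basis (2, 2, 16): flatId 37 gives [1, 0, 5],
// since 37 % 16 = 5, 37 / 16 = 2, 2 % 2 = 0, 2 / 2 = 1, 1 % 2 = 1,
// and linearizing back yields ((1*2 + 0)*16 + 5) = 37.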

namespace {
@@ -189,8 +167,12 @@ struct DistributeTransferRead final
VectorValue acc = cast<VectorValue>(zero);

SmallVector<Value> warpIndices, threadIndices;
-    populateWarpAndThreadIndices(rewriter, threadId, subgroupSize, vectorLayout,
-                                 warpIndices, threadIndices);
+    if (failed(populateWarpAndThreadIndices(rewriter, threadId, subgroupSize,
+                                            vectorLayout, warpIndices,
+                                            threadIndices))) {
+      return rewriter.notifyMatchFailure(
+          readOp, "warp or thread tiles have overlapping strides");
+    }

ValueRange indices = readOp.getIndices();
SmallVector<int64_t> strides(rank, 1);
Expand Down Expand Up @@ -259,8 +241,12 @@ struct DistributeTransferWrite final
int64_t rank = vectorLayout.getRank();

SmallVector<Value> warpIndices, threadIndices;
-    populateWarpAndThreadIndices(rewriter, threadId, subgroupSize, vectorLayout,
-                                 warpIndices, threadIndices);
+    if (failed(populateWarpAndThreadIndices(rewriter, threadId, subgroupSize,
+                                            vectorLayout, warpIndices,
+                                            threadIndices))) {
+      return rewriter.notifyMatchFailure(
+          writeOp, "warp or thread tiles have overlapping strides");
+    }

Value distributedVector =
getDistributed(rewriter, writeOp.getVector(), vectorLayout);
Expand Down Expand Up @@ -1282,8 +1268,12 @@ struct DistributeStep final : OpDistributionPattern<vector::StepOp> {
stepOp, "missing nested layout for step op result");
}
SmallVector<Value> subgroupIndices, threadIndices;
-    populateWarpAndThreadIndices(rewriter, threadId, subgroupSize, resultLayout,
-                                 subgroupIndices, threadIndices);
+    if (failed(populateWarpAndThreadIndices(rewriter, threadId, subgroupSize,
+                                            resultLayout, subgroupIndices,
+                                            threadIndices))) {
+      return rewriter.notifyMatchFailure(
+          stepOp, "warp or thread tiles have overlapping strides");
+    }

SmallVector<int64_t> undistributedShape =
resultLayout.getUndistributedPackedShape();