-
Notifications
You must be signed in to change notification settings - Fork 12.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][vector] Add SwapShapeCastOfTranspose
canonicalizer pattern
#100933
Conversation
A pattern to swap shape_cast(tranpose) with transpose(shape_cast) if the shape_cast only drops unit dimensions. This simplifies the transpose making it more likely to be matched by further patterns. Example: BEFORE: ```mlir %0 = vector.transpose %vector, [3, 0, 1, 2] : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32> %1 = vector.shape_cast %0 : vector<[4]x1x1x4xf32> to vector<[4]x4xf32> ``` AFTER: ```mlir %0 = vector.shape_cast %arg0 : vector<1x1x4x[4]xf32> to vector<4x[4]xf32> %1 = vector.transpose %0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32> ``` Note: This moves this pattern from the ArmSME dialect to a general vector pattern as it is useful for lowerings outside of ArmSME.
@llvm/pr-subscribers-mlir-sme @llvm/pr-subscribers-mlir-vector Author: Benjamin Maxwell (MacDue) ChangesA pattern to swap shape_cast(tranpose) with transpose(shape_cast) if the shape_cast only drops unit dimensions. This simplifies the transpose making it more likely to be matched by further patterns. Example: BEFORE: %0 = vector.transpose %vector, [3, 0, 1, 2]
: vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
%1 = vector.shape_cast %0 : vector<[4]x1x1x4xf32> to vector<[4]x4xf32> AFTER: %0 = vector.shape_cast %arg0 : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
%1 = vector.transpose %0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32> Note: This moves this pattern from the ArmSME dialect to a general vector pattern as it is useful for lowerings outside of ArmSME. Full diff: https://github.com/llvm/llvm-project/pull/100933.diff 4 Files Affected:
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 53df7af00aee8..ed6cd3d0cdbbc 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -774,94 +774,6 @@ struct ConvertIllegalShapeCastOpsToTransposes
}
};
-/// Returns an iterator over the dims (inc scalability) of a VectorType.
-static auto getDims(VectorType vType) {
- return llvm::zip_equal(vType.getShape(), vType.getScalableDims());
-}
-
-/// Helper to drop (fixed-size) unit dims from a VectorType.
-static VectorType dropUnitDims(VectorType vType) {
- SmallVector<bool> scalableFlags;
- SmallVector<int64_t> dimSizes;
- for (auto dim : getDims(vType)) {
- if (dim == std::make_tuple(1, false))
- continue;
- auto [size, scalableFlag] = dim;
- dimSizes.push_back(size);
- scalableFlags.push_back(scalableFlag);
- }
- return VectorType::get(dimSizes, vType.getElementType(), scalableFlags);
-}
-
-/// A pattern to swap shape_cast(tranpose) with transpose(shape_cast) if the
-/// shape_cast only drops unit dimensions.
-///
-/// This simplifies the transpose making it possible for other legalization
-/// rewrites to handle it.
-///
-/// Example:
-///
-/// BEFORE:
-/// ```mlir
-/// %0 = vector.transpose %vector, [3, 0, 1, 2]
-/// : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
-/// %1 = vector.shape_cast %0 : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>
-/// ```
-///
-/// AFTER:
-/// ```mlir
-/// %0 = vector.shape_cast %arg0 : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
-/// %1 = vector.transpose %0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
-/// ```
-struct SwapShapeCastOfTranspose : public OpRewritePattern<vector::ShapeCastOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
- PatternRewriter &rewriter) const override {
- auto transposeOp =
- shapeCastOp.getSource().getDefiningOp<vector::TransposeOp>();
- if (!transposeOp)
- return rewriter.notifyMatchFailure(shapeCastOp, "not TransposeOp");
-
- auto resultType = shapeCastOp.getResultVectorType();
- if (resultType.getRank() <= 1)
- return rewriter.notifyMatchFailure(shapeCastOp, "result rank too low");
-
- if (resultType != dropUnitDims(shapeCastOp.getSourceVectorType()))
- return rewriter.notifyMatchFailure(
- shapeCastOp, "ShapeCastOp changes non-unit dimension(s)");
-
- auto transposeSourceVectorType = transposeOp.getSourceVectorType();
- auto transposeSourceDims =
- llvm::to_vector(getDims(transposeSourceVectorType));
-
- // Construct a map from dimIdx -> number of dims dropped before dimIdx.
- SmallVector<int64_t> droppedDimsBefore(transposeSourceVectorType.getRank());
- int64_t droppedDims = 0;
- for (auto [i, dim] : llvm::enumerate(transposeSourceDims)) {
- droppedDimsBefore[i] = droppedDims;
- if (dim == std::make_tuple(1, false))
- ++droppedDims;
- }
-
- // Drop unit dims from transpose permutation.
- auto perm = transposeOp.getPermutation();
- SmallVector<int64_t> newPerm;
- for (int64_t idx : perm) {
- if (transposeSourceDims[idx] == std::make_tuple(1, false))
- continue;
- newPerm.push_back(idx - droppedDimsBefore[idx]);
- }
-
- auto loc = shapeCastOp.getLoc();
- auto newShapeCastOp = rewriter.create<vector::ShapeCastOp>(
- loc, dropUnitDims(transposeSourceVectorType), transposeOp.getVector());
- rewriter.replaceOpWithNewOp<vector::TransposeOp>(shapeCastOp,
- newShapeCastOp, newPerm);
- return success();
- }
-};
-
/// Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead use
/// the ZA state. This workaround rewrite to support these transposes when ZA is
/// available.
@@ -1027,8 +939,9 @@ struct VectorLegalizationPass
patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
LiftIllegalVectorTransposeToMemory,
ConvertIllegalShapeCastOpsToTransposes,
- SwapShapeCastOfTranspose, LowerIllegalTransposeStoreViaZA>(
- context);
+ LowerIllegalTransposeStoreViaZA>(context);
+ vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context);
+
// Note: These two patterns are added with a high benefit to ensure:
// - Masked outer products are handled before unmasked ones
// - Multi-tile writes are lowered as a store loop (if possible)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d297c40760cd8..2da46aa86d74b 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5480,12 +5480,100 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
}
};
+/// Returns an iterator over the dims (inc scalability) of a VectorType.
+static auto getDims(VectorType vType) {
+ return llvm::zip_equal(vType.getShape(), vType.getScalableDims());
+}
+
+/// Helper to drop (fixed-size) unit dims from a VectorType.
+static VectorType dropUnitDims(VectorType vType) {
+ SmallVector<bool> scalableFlags;
+ SmallVector<int64_t> dimSizes;
+ for (auto dim : getDims(vType)) {
+ if (dim == std::make_tuple(1, false))
+ continue;
+ auto [size, scalableFlag] = dim;
+ dimSizes.push_back(size);
+ scalableFlags.push_back(scalableFlag);
+ }
+ return VectorType::get(dimSizes, vType.getElementType(), scalableFlags);
+}
+
+/// A pattern to swap shape_cast(tranpose) with transpose(shape_cast) if the
+/// shape_cast only drops unit dimensions.
+///
+/// This simplifies the transpose making it more likely to be matched by further
+/// patterns.
+///
+/// Example:
+///
+/// BEFORE:
+/// ```mlir
+/// %0 = vector.transpose %vector, [3, 0, 1, 2]
+/// : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
+/// %1 = vector.shape_cast %0 : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// %0 = vector.shape_cast %arg0 : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
+/// %1 = vector.transpose %0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+/// ```
+struct SwapShapeCastOfTranspose : public OpRewritePattern<vector::ShapeCastOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
+ PatternRewriter &rewriter) const override {
+ auto transposeOp =
+ shapeCastOp.getSource().getDefiningOp<vector::TransposeOp>();
+ if (!transposeOp)
+ return rewriter.notifyMatchFailure(shapeCastOp, "not TransposeOp");
+
+ auto resultType = shapeCastOp.getResultVectorType();
+ if (resultType.getRank() <= 1)
+ return rewriter.notifyMatchFailure(shapeCastOp, "result rank too low");
+
+ if (resultType != dropUnitDims(shapeCastOp.getSourceVectorType()))
+ return rewriter.notifyMatchFailure(
+ shapeCastOp, "ShapeCastOp changes non-unit dimension(s)");
+
+ auto transposeSourceVectorType = transposeOp.getSourceVectorType();
+ auto transposeSourceDims =
+ llvm::to_vector(getDims(transposeSourceVectorType));
+
+ // Construct a map from dimIdx -> number of dims dropped before dimIdx.
+ SmallVector<int64_t> droppedDimsBefore(transposeSourceVectorType.getRank());
+ int64_t droppedDims = 0;
+ for (auto [i, dim] : llvm::enumerate(transposeSourceDims)) {
+ droppedDimsBefore[i] = droppedDims;
+ if (dim == std::make_tuple(1, false))
+ ++droppedDims;
+ }
+
+ // Drop unit dims from transpose permutation.
+ auto perm = transposeOp.getPermutation();
+ SmallVector<int64_t> newPerm;
+ for (int64_t idx : perm) {
+ if (transposeSourceDims[idx] == std::make_tuple(1, false))
+ continue;
+ newPerm.push_back(idx - droppedDimsBefore[idx]);
+ }
+
+ auto loc = shapeCastOp.getLoc();
+ auto newShapeCastOp = rewriter.create<vector::ShapeCastOp>(
+ loc, dropUnitDims(transposeSourceVectorType), transposeOp.getVector());
+ rewriter.replaceOpWithNewOp<vector::TransposeOp>(shapeCastOp,
+ newShapeCastOp, newPerm);
+ return success();
+ }
+};
+
} // namespace
void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ShapeCastConstantFolder, ShapeCastCreateMaskFolderTrailingOneDim,
- ShapeCastBroadcastFolder>(context);
+ ShapeCastBroadcastFolder, SwapShapeCastOfTranspose>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index adc02adb6e974..458906a187982 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -646,29 +646,3 @@ func.func @negative_transpose_store_scalable_via_za__bad_source_shape(%vec: vect
vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]} : vector<[7]x2xf32>, memref<?x?xf32>
return
}
-
-// -----
-
-// CHECK-LABEL: @swap_shape_cast_of_transpose(
-// CHECK-SAME: %[[VEC:.*]]: vector<1x1x4x[4]xf32>)
-func.func @swap_shape_cast_of_transpose(%vector: vector<1x1x4x[4]xf32>) -> vector<[4]x4xf32> {
- // CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
- // CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPE_CAST]], [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
- // CHECK: return %[[TRANSPOSE]]
- %0 = vector.transpose %vector, [3, 0, 1, 2] : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
- %1 = vector.shape_cast %0 : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>
- return %1 : vector<[4]x4xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @swap_shape_cast_of_transpose_units_dims_before_and_after(
-// CHECK-SAME: %[[VEC:.*]]: vector<1x1x1x4x[4]x1xf32>)
-func.func @swap_shape_cast_of_transpose_units_dims_before_and_after(%vector: vector<1x1x1x4x[4]x1xf32>) -> vector<[4]x4xf32> {
- // CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x1x4x[4]x1xf32> to vector<4x[4]xf32>
- // CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPE_CAST]], [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
- // CHECK: return %[[TRANSPOSE]]
- %0 = vector.transpose %vector, [4, 1, 0, 2, 3, 5] : vector<1x1x1x4x[4]x1xf32> to vector<[4]x1x1x1x4x1xf32>
- %1 = vector.shape_cast %0 : vector<[4]x1x1x1x4x1xf32> to vector<[4]x4xf32>
- return %1 : vector<[4]x4xf32>
-}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index e71a6eb02ea46..f1a1120bf874e 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -867,6 +867,32 @@ func.func @canonicalize_broadcast_shapecast_to_shapecast(%arg0: vector<3x4xf32>)
// -----
+// CHECK-LABEL: @swap_shape_cast_of_transpose(
+// CHECK-SAME: %[[VEC:.*]]: vector<1x1x4x[4]xf32>)
+func.func @swap_shape_cast_of_transpose(%vector: vector<1x1x4x[4]xf32>) -> vector<[4]x4xf32> {
+ // CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
+ // CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPE_CAST]], [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+ // CHECK: return %[[TRANSPOSE]]
+ %0 = vector.transpose %vector, [3, 0, 1, 2] : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
+ %1 = vector.shape_cast %0 : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>
+ return %1 : vector<[4]x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @swap_shape_cast_of_transpose_units_dims_before_and_after(
+// CHECK-SAME: %[[VEC:.*]]: vector<1x1x1x4x[4]x1xf32>)
+func.func @swap_shape_cast_of_transpose_units_dims_before_and_after(%vector: vector<1x1x1x4x[4]x1xf32>) -> vector<[4]x4xf32> {
+ // CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x1x4x[4]x1xf32> to vector<4x[4]xf32>
+ // CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPE_CAST]], [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+ // CHECK: return %[[TRANSPOSE]]
+ %0 = vector.transpose %vector, [4, 1, 0, 2, 3, 5] : vector<1x1x1x4x[4]x1xf32> to vector<[4]x1x1x1x4x1xf32>
+ %1 = vector.shape_cast %0 : vector<[4]x1x1x1x4x1xf32> to vector<[4]x4xf32>
+ return %1 : vector<[4]x4xf32>
+}
+
+// -----
+
// CHECK-LABEL: fold_vector_transfer_masks
func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
// CHECK: %[[C0:.+]] = arith.constant 0 : index
|
@llvm/pr-subscribers-mlir Author: Benjamin Maxwell (MacDue) ChangesA pattern to swap shape_cast(tranpose) with transpose(shape_cast) if the shape_cast only drops unit dimensions. This simplifies the transpose making it more likely to be matched by further patterns. Example: BEFORE: %0 = vector.transpose %vector, [3, 0, 1, 2]
: vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
%1 = vector.shape_cast %0 : vector<[4]x1x1x4xf32> to vector<[4]x4xf32> AFTER: %0 = vector.shape_cast %arg0 : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
%1 = vector.transpose %0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32> Note: This moves this pattern from the ArmSME dialect to a general vector pattern as it is useful for lowerings outside of ArmSME. Full diff: https://github.com/llvm/llvm-project/pull/100933.diff 4 Files Affected:
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 53df7af00aee8..ed6cd3d0cdbbc 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -774,94 +774,6 @@ struct ConvertIllegalShapeCastOpsToTransposes
}
};
-/// Returns an iterator over the dims (inc scalability) of a VectorType.
-static auto getDims(VectorType vType) {
- return llvm::zip_equal(vType.getShape(), vType.getScalableDims());
-}
-
-/// Helper to drop (fixed-size) unit dims from a VectorType.
-static VectorType dropUnitDims(VectorType vType) {
- SmallVector<bool> scalableFlags;
- SmallVector<int64_t> dimSizes;
- for (auto dim : getDims(vType)) {
- if (dim == std::make_tuple(1, false))
- continue;
- auto [size, scalableFlag] = dim;
- dimSizes.push_back(size);
- scalableFlags.push_back(scalableFlag);
- }
- return VectorType::get(dimSizes, vType.getElementType(), scalableFlags);
-}
-
-/// A pattern to swap shape_cast(tranpose) with transpose(shape_cast) if the
-/// shape_cast only drops unit dimensions.
-///
-/// This simplifies the transpose making it possible for other legalization
-/// rewrites to handle it.
-///
-/// Example:
-///
-/// BEFORE:
-/// ```mlir
-/// %0 = vector.transpose %vector, [3, 0, 1, 2]
-/// : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
-/// %1 = vector.shape_cast %0 : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>
-/// ```
-///
-/// AFTER:
-/// ```mlir
-/// %0 = vector.shape_cast %arg0 : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
-/// %1 = vector.transpose %0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
-/// ```
-struct SwapShapeCastOfTranspose : public OpRewritePattern<vector::ShapeCastOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
- PatternRewriter &rewriter) const override {
- auto transposeOp =
- shapeCastOp.getSource().getDefiningOp<vector::TransposeOp>();
- if (!transposeOp)
- return rewriter.notifyMatchFailure(shapeCastOp, "not TransposeOp");
-
- auto resultType = shapeCastOp.getResultVectorType();
- if (resultType.getRank() <= 1)
- return rewriter.notifyMatchFailure(shapeCastOp, "result rank too low");
-
- if (resultType != dropUnitDims(shapeCastOp.getSourceVectorType()))
- return rewriter.notifyMatchFailure(
- shapeCastOp, "ShapeCastOp changes non-unit dimension(s)");
-
- auto transposeSourceVectorType = transposeOp.getSourceVectorType();
- auto transposeSourceDims =
- llvm::to_vector(getDims(transposeSourceVectorType));
-
- // Construct a map from dimIdx -> number of dims dropped before dimIdx.
- SmallVector<int64_t> droppedDimsBefore(transposeSourceVectorType.getRank());
- int64_t droppedDims = 0;
- for (auto [i, dim] : llvm::enumerate(transposeSourceDims)) {
- droppedDimsBefore[i] = droppedDims;
- if (dim == std::make_tuple(1, false))
- ++droppedDims;
- }
-
- // Drop unit dims from transpose permutation.
- auto perm = transposeOp.getPermutation();
- SmallVector<int64_t> newPerm;
- for (int64_t idx : perm) {
- if (transposeSourceDims[idx] == std::make_tuple(1, false))
- continue;
- newPerm.push_back(idx - droppedDimsBefore[idx]);
- }
-
- auto loc = shapeCastOp.getLoc();
- auto newShapeCastOp = rewriter.create<vector::ShapeCastOp>(
- loc, dropUnitDims(transposeSourceVectorType), transposeOp.getVector());
- rewriter.replaceOpWithNewOp<vector::TransposeOp>(shapeCastOp,
- newShapeCastOp, newPerm);
- return success();
- }
-};
-
/// Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead use
/// the ZA state. This workaround rewrite to support these transposes when ZA is
/// available.
@@ -1027,8 +939,9 @@ struct VectorLegalizationPass
patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
LiftIllegalVectorTransposeToMemory,
ConvertIllegalShapeCastOpsToTransposes,
- SwapShapeCastOfTranspose, LowerIllegalTransposeStoreViaZA>(
- context);
+ LowerIllegalTransposeStoreViaZA>(context);
+ vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context);
+
// Note: These two patterns are added with a high benefit to ensure:
// - Masked outer products are handled before unmasked ones
// - Multi-tile writes are lowered as a store loop (if possible)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d297c40760cd8..2da46aa86d74b 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5480,12 +5480,100 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
}
};
+/// Returns an iterator over the dims (inc scalability) of a VectorType.
+static auto getDims(VectorType vType) {
+ return llvm::zip_equal(vType.getShape(), vType.getScalableDims());
+}
+
+/// Helper to drop (fixed-size) unit dims from a VectorType.
+static VectorType dropUnitDims(VectorType vType) {
+ SmallVector<bool> scalableFlags;
+ SmallVector<int64_t> dimSizes;
+ for (auto dim : getDims(vType)) {
+ if (dim == std::make_tuple(1, false))
+ continue;
+ auto [size, scalableFlag] = dim;
+ dimSizes.push_back(size);
+ scalableFlags.push_back(scalableFlag);
+ }
+ return VectorType::get(dimSizes, vType.getElementType(), scalableFlags);
+}
+
+/// A pattern to swap shape_cast(tranpose) with transpose(shape_cast) if the
+/// shape_cast only drops unit dimensions.
+///
+/// This simplifies the transpose making it more likely to be matched by further
+/// patterns.
+///
+/// Example:
+///
+/// BEFORE:
+/// ```mlir
+/// %0 = vector.transpose %vector, [3, 0, 1, 2]
+/// : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
+/// %1 = vector.shape_cast %0 : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// %0 = vector.shape_cast %arg0 : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
+/// %1 = vector.transpose %0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+/// ```
+struct SwapShapeCastOfTranspose : public OpRewritePattern<vector::ShapeCastOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
+ PatternRewriter &rewriter) const override {
+ auto transposeOp =
+ shapeCastOp.getSource().getDefiningOp<vector::TransposeOp>();
+ if (!transposeOp)
+ return rewriter.notifyMatchFailure(shapeCastOp, "not TransposeOp");
+
+ auto resultType = shapeCastOp.getResultVectorType();
+ if (resultType.getRank() <= 1)
+ return rewriter.notifyMatchFailure(shapeCastOp, "result rank too low");
+
+ if (resultType != dropUnitDims(shapeCastOp.getSourceVectorType()))
+ return rewriter.notifyMatchFailure(
+ shapeCastOp, "ShapeCastOp changes non-unit dimension(s)");
+
+ auto transposeSourceVectorType = transposeOp.getSourceVectorType();
+ auto transposeSourceDims =
+ llvm::to_vector(getDims(transposeSourceVectorType));
+
+ // Construct a map from dimIdx -> number of dims dropped before dimIdx.
+ SmallVector<int64_t> droppedDimsBefore(transposeSourceVectorType.getRank());
+ int64_t droppedDims = 0;
+ for (auto [i, dim] : llvm::enumerate(transposeSourceDims)) {
+ droppedDimsBefore[i] = droppedDims;
+ if (dim == std::make_tuple(1, false))
+ ++droppedDims;
+ }
+
+ // Drop unit dims from transpose permutation.
+ auto perm = transposeOp.getPermutation();
+ SmallVector<int64_t> newPerm;
+ for (int64_t idx : perm) {
+ if (transposeSourceDims[idx] == std::make_tuple(1, false))
+ continue;
+ newPerm.push_back(idx - droppedDimsBefore[idx]);
+ }
+
+ auto loc = shapeCastOp.getLoc();
+ auto newShapeCastOp = rewriter.create<vector::ShapeCastOp>(
+ loc, dropUnitDims(transposeSourceVectorType), transposeOp.getVector());
+ rewriter.replaceOpWithNewOp<vector::TransposeOp>(shapeCastOp,
+ newShapeCastOp, newPerm);
+ return success();
+ }
+};
+
} // namespace
void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ShapeCastConstantFolder, ShapeCastCreateMaskFolderTrailingOneDim,
- ShapeCastBroadcastFolder>(context);
+ ShapeCastBroadcastFolder, SwapShapeCastOfTranspose>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index adc02adb6e974..458906a187982 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -646,29 +646,3 @@ func.func @negative_transpose_store_scalable_via_za__bad_source_shape(%vec: vect
vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]} : vector<[7]x2xf32>, memref<?x?xf32>
return
}
-
-// -----
-
-// CHECK-LABEL: @swap_shape_cast_of_transpose(
-// CHECK-SAME: %[[VEC:.*]]: vector<1x1x4x[4]xf32>)
-func.func @swap_shape_cast_of_transpose(%vector: vector<1x1x4x[4]xf32>) -> vector<[4]x4xf32> {
- // CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
- // CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPE_CAST]], [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
- // CHECK: return %[[TRANSPOSE]]
- %0 = vector.transpose %vector, [3, 0, 1, 2] : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
- %1 = vector.shape_cast %0 : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>
- return %1 : vector<[4]x4xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @swap_shape_cast_of_transpose_units_dims_before_and_after(
-// CHECK-SAME: %[[VEC:.*]]: vector<1x1x1x4x[4]x1xf32>)
-func.func @swap_shape_cast_of_transpose_units_dims_before_and_after(%vector: vector<1x1x1x4x[4]x1xf32>) -> vector<[4]x4xf32> {
- // CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x1x4x[4]x1xf32> to vector<4x[4]xf32>
- // CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPE_CAST]], [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
- // CHECK: return %[[TRANSPOSE]]
- %0 = vector.transpose %vector, [4, 1, 0, 2, 3, 5] : vector<1x1x1x4x[4]x1xf32> to vector<[4]x1x1x1x4x1xf32>
- %1 = vector.shape_cast %0 : vector<[4]x1x1x1x4x1xf32> to vector<[4]x4xf32>
- return %1 : vector<[4]x4xf32>
-}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index e71a6eb02ea46..f1a1120bf874e 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -867,6 +867,32 @@ func.func @canonicalize_broadcast_shapecast_to_shapecast(%arg0: vector<3x4xf32>)
// -----
+// CHECK-LABEL: @swap_shape_cast_of_transpose(
+// CHECK-SAME: %[[VEC:.*]]: vector<1x1x4x[4]xf32>)
+func.func @swap_shape_cast_of_transpose(%vector: vector<1x1x4x[4]xf32>) -> vector<[4]x4xf32> {
+ // CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
+ // CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPE_CAST]], [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+ // CHECK: return %[[TRANSPOSE]]
+ %0 = vector.transpose %vector, [3, 0, 1, 2] : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
+ %1 = vector.shape_cast %0 : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>
+ return %1 : vector<[4]x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @swap_shape_cast_of_transpose_units_dims_before_and_after(
+// CHECK-SAME: %[[VEC:.*]]: vector<1x1x1x4x[4]x1xf32>)
+func.func @swap_shape_cast_of_transpose_units_dims_before_and_after(%vector: vector<1x1x1x4x[4]x1xf32>) -> vector<[4]x4xf32> {
+ // CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x1x4x[4]x1xf32> to vector<4x[4]xf32>
+ // CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPE_CAST]], [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+ // CHECK: return %[[TRANSPOSE]]
+ %0 = vector.transpose %vector, [4, 1, 0, 2, 3, 5] : vector<1x1x1x4x[4]x1xf32> to vector<[4]x1x1x1x4x1xf32>
+ %1 = vector.shape_cast %0 : vector<[4]x1x1x1x4x1xf32> to vector<[4]x4xf32>
+ return %1 : vector<[4]x4xf32>
+}
+
+// -----
+
// CHECK-LABEL: fold_vector_transfer_masks
func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
// CHECK: %[[C0:.+]] = arith.constant 0 : index
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense, but we should make sure that this is also fine for SPIR-V - @kuhar ? Also, given that this is a canonicalisation, lets wait a few days so that people have a chance to take a look.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for checking on the SPIR-V side. Given this is conditioned on unit dims, I think it should be fine for us, but it's hard for me to give a definite answer without trying this.
FYI, we are working on better upstream testing for the SPIR-V lowering so that there's less guessing in the future: #95942
@kuhar Shall we land this and just revert if something breaks on the SPIR-V side? |
Yes, this SGTM. |
Isn't this pointing at an implementation gap in the drop unit dimension transformations? If we apply drop unit dimensions here, we should get a shape_cast before the transpose, which should cancel the shape_cast after the transpose? |
I think this alone is a reasonable canonicalization (as in it takes some IR and rewrites in a more canonical form, without introducing new operations). I'm happy to look at more unit dim dropping patterns, but I think that is a separate thing. |
Probably good to get feedback from @MaheshRavishankar as this topic is sensitive. I don't have a strong opinion but I think we left unit dims out of the canonicalization scope because we may lose information in some cases and for that reason we have the drop unit dim passes/patterns taking that responsibility. I would advocate for consistency if there isn't a strong case to go otherwise so that we have all the unit dims in a single place. |
My intiution here is that such a pattern can be a canonicalization. If I had to give a reasoning here, it is effectively bubbling up the shape cast operation, and that could be applied repeatedly as a fixed point to move the shape casts up as much as possible which I think is good for the program as a whole. But for consistency, I would also see if there are a group of patterns that do the same thing and create a fixed point outside of canonicalizations to apply the fixed point where necessary (basically everything in the compiler stack should be deliberate and whoever is putting things together must know what they are doing when they add something to a pass pipeline. The kitchen sink approach to canonicalizer is a problem in general). But I wont really push back against this being a canonicalizer pattern. |
Thanks for checking @dcaballe |
Also wondering... isn't this |
Yes, I tested it and this #102017 solves the same issue in IREE. |
A pattern to swap shape_cast(tranpose) with transpose(shape_cast) if the shape_cast only drops unit dimensions.
This simplifies the transpose making it more likely to be matched by further patterns.
Example:
BEFORE:
AFTER:
Note: This moves this pattern from the ArmSME dialect to a general vector pattern as it is useful for lowerings outside of ArmSME.