Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][vector] Add SwapShapeCastOfTranspose canonicalizer pattern #100933

Closed
wants to merge 2 commits into from

Conversation

MacDue
Copy link
Member

@MacDue MacDue commented Jul 28, 2024

A pattern to swap shape_cast(tranpose) with transpose(shape_cast) if the shape_cast only drops unit dimensions.

This simplifies the transpose making it more likely to be matched by further patterns.

Example:

BEFORE:

%0 = vector.transpose %vector, [3, 0, 1, 2]
       : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
%1 = vector.shape_cast %0 : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>

AFTER:

%0 = vector.shape_cast %vector : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
%1 = vector.transpose %0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>

Note: This moves this pattern from the ArmSME dialect to a general vector pattern as it is useful for lowerings outside of ArmSME.

A pattern to swap shape_cast(tranpose) with transpose(shape_cast) if the
shape_cast only drops unit dimensions.

This simplifies the transpose making it more likely to be matched by
further patterns.

Example:

BEFORE:
```mlir
%0 = vector.transpose %vector, [3, 0, 1, 2]
       : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
%1 = vector.shape_cast %0 : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>
```

AFTER:
```mlir
%0 = vector.shape_cast %arg0 : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
%1 = vector.transpose %0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
```

Note: This moves this pattern from the ArmSME dialect to a general
vector pattern as it is useful for lowerings outside of ArmSME.
@llvmbot
Copy link
Member

llvmbot commented Jul 28, 2024

@llvm/pr-subscribers-mlir-sme

@llvm/pr-subscribers-mlir-vector

Author: Benjamin Maxwell (MacDue)

Changes

A pattern to swap shape_cast(tranpose) with transpose(shape_cast) if the shape_cast only drops unit dimensions.

This simplifies the transpose making it more likely to be matched by further patterns.

Example:

BEFORE:

%0 = vector.transpose %vector, [3, 0, 1, 2]
       : vector&lt;1x1x4x[4]xf32&gt; to vector&lt;[4]x1x1x4xf32&gt;
%1 = vector.shape_cast %0 : vector&lt;[4]x1x1x4xf32&gt; to vector&lt;[4]x4xf32&gt;

AFTER:

%0 = vector.shape_cast %arg0 : vector&lt;1x1x4x[4]xf32&gt; to vector&lt;4x[4]xf32&gt;
%1 = vector.transpose %0, [1, 0] : vector&lt;4x[4]xf32&gt; to vector&lt;[4]x4xf32&gt;

Note: This moves this pattern from the ArmSME dialect to a general vector pattern as it is useful for lowerings outside of ArmSME.


Full diff: https://github.com/llvm/llvm-project/pull/100933.diff

4 Files Affected:

  • (modified) mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp (+3-90)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+89-1)
  • (modified) mlir/test/Dialect/ArmSME/vector-legalization.mlir (-26)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+26)
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 53df7af00aee8..ed6cd3d0cdbbc 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -774,94 +774,6 @@ struct ConvertIllegalShapeCastOpsToTransposes
   }
 };
 
-/// Returns an iterator over the dims (inc scalability) of a VectorType.
-static auto getDims(VectorType vType) {
-  return llvm::zip_equal(vType.getShape(), vType.getScalableDims());
-}
-
-/// Helper to drop (fixed-size) unit dims from a VectorType.
-static VectorType dropUnitDims(VectorType vType) {
-  SmallVector<bool> scalableFlags;
-  SmallVector<int64_t> dimSizes;
-  for (auto dim : getDims(vType)) {
-    if (dim == std::make_tuple(1, false))
-      continue;
-    auto [size, scalableFlag] = dim;
-    dimSizes.push_back(size);
-    scalableFlags.push_back(scalableFlag);
-  }
-  return VectorType::get(dimSizes, vType.getElementType(), scalableFlags);
-}
-
-/// A pattern to swap shape_cast(tranpose) with transpose(shape_cast) if the
-/// shape_cast only drops unit dimensions.
-///
-/// This simplifies the transpose making it possible for other legalization
-/// rewrites to handle it.
-///
-/// Example:
-///
-///  BEFORE:
-///  ```mlir
-///  %0 = vector.transpose %vector, [3, 0, 1, 2]
-///         : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
-///  %1 = vector.shape_cast %0 : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>
-///  ```
-///
-///  AFTER:
-///  ```mlir
-///  %0 = vector.shape_cast %arg0 : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
-///  %1 = vector.transpose %0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
-///  ```
-struct SwapShapeCastOfTranspose : public OpRewritePattern<vector::ShapeCastOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
-                                PatternRewriter &rewriter) const override {
-    auto transposeOp =
-        shapeCastOp.getSource().getDefiningOp<vector::TransposeOp>();
-    if (!transposeOp)
-      return rewriter.notifyMatchFailure(shapeCastOp, "not TransposeOp");
-
-    auto resultType = shapeCastOp.getResultVectorType();
-    if (resultType.getRank() <= 1)
-      return rewriter.notifyMatchFailure(shapeCastOp, "result rank too low");
-
-    if (resultType != dropUnitDims(shapeCastOp.getSourceVectorType()))
-      return rewriter.notifyMatchFailure(
-          shapeCastOp, "ShapeCastOp changes non-unit dimension(s)");
-
-    auto transposeSourceVectorType = transposeOp.getSourceVectorType();
-    auto transposeSourceDims =
-        llvm::to_vector(getDims(transposeSourceVectorType));
-
-    // Construct a map from dimIdx -> number of dims dropped before dimIdx.
-    SmallVector<int64_t> droppedDimsBefore(transposeSourceVectorType.getRank());
-    int64_t droppedDims = 0;
-    for (auto [i, dim] : llvm::enumerate(transposeSourceDims)) {
-      droppedDimsBefore[i] = droppedDims;
-      if (dim == std::make_tuple(1, false))
-        ++droppedDims;
-    }
-
-    // Drop unit dims from transpose permutation.
-    auto perm = transposeOp.getPermutation();
-    SmallVector<int64_t> newPerm;
-    for (int64_t idx : perm) {
-      if (transposeSourceDims[idx] == std::make_tuple(1, false))
-        continue;
-      newPerm.push_back(idx - droppedDimsBefore[idx]);
-    }
-
-    auto loc = shapeCastOp.getLoc();
-    auto newShapeCastOp = rewriter.create<vector::ShapeCastOp>(
-        loc, dropUnitDims(transposeSourceVectorType), transposeOp.getVector());
-    rewriter.replaceOpWithNewOp<vector::TransposeOp>(shapeCastOp,
-                                                     newShapeCastOp, newPerm);
-    return success();
-  }
-};
-
 /// Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead use
 /// the ZA state. This workaround rewrite to support these transposes when ZA is
 /// available.
@@ -1027,8 +939,9 @@ struct VectorLegalizationPass
     patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
                  LiftIllegalVectorTransposeToMemory,
                  ConvertIllegalShapeCastOpsToTransposes,
-                 SwapShapeCastOfTranspose, LowerIllegalTransposeStoreViaZA>(
-        context);
+                 LowerIllegalTransposeStoreViaZA>(context);
+    vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context);
+
     // Note: These two patterns are added with a high benefit to ensure:
     //  - Masked outer products are handled before unmasked ones
     //  - Multi-tile writes are lowered as a store loop (if possible)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d297c40760cd8..2da46aa86d74b 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5480,12 +5480,100 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
   }
 };
 
+/// Returns an iterator over the dims (inc scalability) of a VectorType.
+static auto getDims(VectorType vType) {
+  return llvm::zip_equal(vType.getShape(), vType.getScalableDims());
+}
+
+/// Helper to drop (fixed-size) unit dims from a VectorType.
+static VectorType dropUnitDims(VectorType vType) {
+  SmallVector<bool> scalableFlags;
+  SmallVector<int64_t> dimSizes;
+  for (auto dim : getDims(vType)) {
+    if (dim == std::make_tuple(1, false))
+      continue;
+    auto [size, scalableFlag] = dim;
+    dimSizes.push_back(size);
+    scalableFlags.push_back(scalableFlag);
+  }
+  return VectorType::get(dimSizes, vType.getElementType(), scalableFlags);
+}
+
+/// A pattern to swap shape_cast(tranpose) with transpose(shape_cast) if the
+/// shape_cast only drops unit dimensions.
+///
+/// This simplifies the transpose making it more likely to be matched by further
+/// patterns.
+///
+/// Example:
+///
+///  BEFORE:
+///  ```mlir
+///  %0 = vector.transpose %vector, [3, 0, 1, 2]
+///         : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
+///  %1 = vector.shape_cast %0 : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>
+///  ```
+///
+///  AFTER:
+///  ```mlir
+///  %0 = vector.shape_cast %arg0 : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
+///  %1 = vector.transpose %0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+///  ```
+struct SwapShapeCastOfTranspose : public OpRewritePattern<vector::ShapeCastOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
+                                PatternRewriter &rewriter) const override {
+    auto transposeOp =
+        shapeCastOp.getSource().getDefiningOp<vector::TransposeOp>();
+    if (!transposeOp)
+      return rewriter.notifyMatchFailure(shapeCastOp, "not TransposeOp");
+
+    auto resultType = shapeCastOp.getResultVectorType();
+    if (resultType.getRank() <= 1)
+      return rewriter.notifyMatchFailure(shapeCastOp, "result rank too low");
+
+    if (resultType != dropUnitDims(shapeCastOp.getSourceVectorType()))
+      return rewriter.notifyMatchFailure(
+          shapeCastOp, "ShapeCastOp changes non-unit dimension(s)");
+
+    auto transposeSourceVectorType = transposeOp.getSourceVectorType();
+    auto transposeSourceDims =
+        llvm::to_vector(getDims(transposeSourceVectorType));
+
+    // Construct a map from dimIdx -> number of dims dropped before dimIdx.
+    SmallVector<int64_t> droppedDimsBefore(transposeSourceVectorType.getRank());
+    int64_t droppedDims = 0;
+    for (auto [i, dim] : llvm::enumerate(transposeSourceDims)) {
+      droppedDimsBefore[i] = droppedDims;
+      if (dim == std::make_tuple(1, false))
+        ++droppedDims;
+    }
+
+    // Drop unit dims from transpose permutation.
+    auto perm = transposeOp.getPermutation();
+    SmallVector<int64_t> newPerm;
+    for (int64_t idx : perm) {
+      if (transposeSourceDims[idx] == std::make_tuple(1, false))
+        continue;
+      newPerm.push_back(idx - droppedDimsBefore[idx]);
+    }
+
+    auto loc = shapeCastOp.getLoc();
+    auto newShapeCastOp = rewriter.create<vector::ShapeCastOp>(
+        loc, dropUnitDims(transposeSourceVectorType), transposeOp.getVector());
+    rewriter.replaceOpWithNewOp<vector::TransposeOp>(shapeCastOp,
+                                                     newShapeCastOp, newPerm);
+    return success();
+  }
+};
+
 } // namespace
 
 void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
   results.add<ShapeCastConstantFolder, ShapeCastCreateMaskFolderTrailingOneDim,
-              ShapeCastBroadcastFolder>(context);
+              ShapeCastBroadcastFolder, SwapShapeCastOfTranspose>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index adc02adb6e974..458906a187982 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -646,29 +646,3 @@ func.func @negative_transpose_store_scalable_via_za__bad_source_shape(%vec: vect
   vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]} : vector<[7]x2xf32>,  memref<?x?xf32>
   return
 }
-
-// -----
-
-// CHECK-LABEL: @swap_shape_cast_of_transpose(
-// CHECK-SAME:                                %[[VEC:.*]]: vector<1x1x4x[4]xf32>)
-func.func @swap_shape_cast_of_transpose(%vector: vector<1x1x4x[4]xf32>) -> vector<[4]x4xf32> {
-  // CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
-  // CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPE_CAST]], [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
-  // CHECK: return %[[TRANSPOSE]]
-  %0 = vector.transpose %vector, [3, 0, 1, 2] : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
-  %1 = vector.shape_cast %0 : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>
-  return %1 : vector<[4]x4xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @swap_shape_cast_of_transpose_units_dims_before_and_after(
-// CHECK-SAME:                                %[[VEC:.*]]: vector<1x1x1x4x[4]x1xf32>)
-func.func @swap_shape_cast_of_transpose_units_dims_before_and_after(%vector: vector<1x1x1x4x[4]x1xf32>) -> vector<[4]x4xf32> {
-  // CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x1x4x[4]x1xf32> to vector<4x[4]xf32>
-  // CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPE_CAST]], [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
-  // CHECK: return %[[TRANSPOSE]]
-  %0 = vector.transpose %vector, [4, 1, 0, 2, 3, 5] : vector<1x1x1x4x[4]x1xf32> to vector<[4]x1x1x1x4x1xf32>
-  %1 = vector.shape_cast %0 : vector<[4]x1x1x1x4x1xf32> to vector<[4]x4xf32>
-  return %1 : vector<[4]x4xf32>
-}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index e71a6eb02ea46..f1a1120bf874e 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -867,6 +867,32 @@ func.func @canonicalize_broadcast_shapecast_to_shapecast(%arg0: vector<3x4xf32>)
 
 // -----
 
+// CHECK-LABEL: @swap_shape_cast_of_transpose(
+// CHECK-SAME:                                %[[VEC:.*]]: vector<1x1x4x[4]xf32>)
+func.func @swap_shape_cast_of_transpose(%vector: vector<1x1x4x[4]xf32>) -> vector<[4]x4xf32> {
+  // CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
+  // CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPE_CAST]], [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+  // CHECK: return %[[TRANSPOSE]]
+  %0 = vector.transpose %vector, [3, 0, 1, 2] : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
+  %1 = vector.shape_cast %0 : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>
+  return %1 : vector<[4]x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @swap_shape_cast_of_transpose_units_dims_before_and_after(
+// CHECK-SAME:                                %[[VEC:.*]]: vector<1x1x1x4x[4]x1xf32>)
+func.func @swap_shape_cast_of_transpose_units_dims_before_and_after(%vector: vector<1x1x1x4x[4]x1xf32>) -> vector<[4]x4xf32> {
+  // CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x1x4x[4]x1xf32> to vector<4x[4]xf32>
+  // CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPE_CAST]], [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+  // CHECK: return %[[TRANSPOSE]]
+  %0 = vector.transpose %vector, [4, 1, 0, 2, 3, 5] : vector<1x1x1x4x[4]x1xf32> to vector<[4]x1x1x1x4x1xf32>
+  %1 = vector.shape_cast %0 : vector<[4]x1x1x1x4x1xf32> to vector<[4]x4xf32>
+  return %1 : vector<[4]x4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: fold_vector_transfer_masks
 func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
   // CHECK: %[[C0:.+]] = arith.constant 0 : index

@llvmbot
Copy link
Member

llvmbot commented Jul 28, 2024

@llvm/pr-subscribers-mlir

Author: Benjamin Maxwell (MacDue)

Changes

A pattern to swap shape_cast(tranpose) with transpose(shape_cast) if the shape_cast only drops unit dimensions.

This simplifies the transpose making it more likely to be matched by further patterns.

Example:

BEFORE:

%0 = vector.transpose %vector, [3, 0, 1, 2]
       : vector&lt;1x1x4x[4]xf32&gt; to vector&lt;[4]x1x1x4xf32&gt;
%1 = vector.shape_cast %0 : vector&lt;[4]x1x1x4xf32&gt; to vector&lt;[4]x4xf32&gt;

AFTER:

%0 = vector.shape_cast %arg0 : vector&lt;1x1x4x[4]xf32&gt; to vector&lt;4x[4]xf32&gt;
%1 = vector.transpose %0, [1, 0] : vector&lt;4x[4]xf32&gt; to vector&lt;[4]x4xf32&gt;

Note: This moves this pattern from the ArmSME dialect to a general vector pattern as it is useful for lowerings outside of ArmSME.


Full diff: https://github.com/llvm/llvm-project/pull/100933.diff

4 Files Affected:

  • (modified) mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp (+3-90)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+89-1)
  • (modified) mlir/test/Dialect/ArmSME/vector-legalization.mlir (-26)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+26)
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 53df7af00aee8..ed6cd3d0cdbbc 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -774,94 +774,6 @@ struct ConvertIllegalShapeCastOpsToTransposes
   }
 };
 
-/// Returns an iterator over the dims (inc scalability) of a VectorType.
-static auto getDims(VectorType vType) {
-  return llvm::zip_equal(vType.getShape(), vType.getScalableDims());
-}
-
-/// Helper to drop (fixed-size) unit dims from a VectorType.
-static VectorType dropUnitDims(VectorType vType) {
-  SmallVector<bool> scalableFlags;
-  SmallVector<int64_t> dimSizes;
-  for (auto dim : getDims(vType)) {
-    if (dim == std::make_tuple(1, false))
-      continue;
-    auto [size, scalableFlag] = dim;
-    dimSizes.push_back(size);
-    scalableFlags.push_back(scalableFlag);
-  }
-  return VectorType::get(dimSizes, vType.getElementType(), scalableFlags);
-}
-
-/// A pattern to swap shape_cast(tranpose) with transpose(shape_cast) if the
-/// shape_cast only drops unit dimensions.
-///
-/// This simplifies the transpose making it possible for other legalization
-/// rewrites to handle it.
-///
-/// Example:
-///
-///  BEFORE:
-///  ```mlir
-///  %0 = vector.transpose %vector, [3, 0, 1, 2]
-///         : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
-///  %1 = vector.shape_cast %0 : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>
-///  ```
-///
-///  AFTER:
-///  ```mlir
-///  %0 = vector.shape_cast %arg0 : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
-///  %1 = vector.transpose %0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
-///  ```
-struct SwapShapeCastOfTranspose : public OpRewritePattern<vector::ShapeCastOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
-                                PatternRewriter &rewriter) const override {
-    auto transposeOp =
-        shapeCastOp.getSource().getDefiningOp<vector::TransposeOp>();
-    if (!transposeOp)
-      return rewriter.notifyMatchFailure(shapeCastOp, "not TransposeOp");
-
-    auto resultType = shapeCastOp.getResultVectorType();
-    if (resultType.getRank() <= 1)
-      return rewriter.notifyMatchFailure(shapeCastOp, "result rank too low");
-
-    if (resultType != dropUnitDims(shapeCastOp.getSourceVectorType()))
-      return rewriter.notifyMatchFailure(
-          shapeCastOp, "ShapeCastOp changes non-unit dimension(s)");
-
-    auto transposeSourceVectorType = transposeOp.getSourceVectorType();
-    auto transposeSourceDims =
-        llvm::to_vector(getDims(transposeSourceVectorType));
-
-    // Construct a map from dimIdx -> number of dims dropped before dimIdx.
-    SmallVector<int64_t> droppedDimsBefore(transposeSourceVectorType.getRank());
-    int64_t droppedDims = 0;
-    for (auto [i, dim] : llvm::enumerate(transposeSourceDims)) {
-      droppedDimsBefore[i] = droppedDims;
-      if (dim == std::make_tuple(1, false))
-        ++droppedDims;
-    }
-
-    // Drop unit dims from transpose permutation.
-    auto perm = transposeOp.getPermutation();
-    SmallVector<int64_t> newPerm;
-    for (int64_t idx : perm) {
-      if (transposeSourceDims[idx] == std::make_tuple(1, false))
-        continue;
-      newPerm.push_back(idx - droppedDimsBefore[idx]);
-    }
-
-    auto loc = shapeCastOp.getLoc();
-    auto newShapeCastOp = rewriter.create<vector::ShapeCastOp>(
-        loc, dropUnitDims(transposeSourceVectorType), transposeOp.getVector());
-    rewriter.replaceOpWithNewOp<vector::TransposeOp>(shapeCastOp,
-                                                     newShapeCastOp, newPerm);
-    return success();
-  }
-};
-
 /// Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead use
 /// the ZA state. This workaround rewrite to support these transposes when ZA is
 /// available.
@@ -1027,8 +939,9 @@ struct VectorLegalizationPass
     patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
                  LiftIllegalVectorTransposeToMemory,
                  ConvertIllegalShapeCastOpsToTransposes,
-                 SwapShapeCastOfTranspose, LowerIllegalTransposeStoreViaZA>(
-        context);
+                 LowerIllegalTransposeStoreViaZA>(context);
+    vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context);
+
     // Note: These two patterns are added with a high benefit to ensure:
     //  - Masked outer products are handled before unmasked ones
     //  - Multi-tile writes are lowered as a store loop (if possible)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d297c40760cd8..2da46aa86d74b 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5480,12 +5480,100 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
   }
 };
 
+/// Returns an iterator over the dims (inc scalability) of a VectorType.
+static auto getDims(VectorType vType) {
+  return llvm::zip_equal(vType.getShape(), vType.getScalableDims());
+}
+
+/// Helper to drop (fixed-size) unit dims from a VectorType.
+static VectorType dropUnitDims(VectorType vType) {
+  SmallVector<bool> scalableFlags;
+  SmallVector<int64_t> dimSizes;
+  for (auto dim : getDims(vType)) {
+    if (dim == std::make_tuple(1, false))
+      continue;
+    auto [size, scalableFlag] = dim;
+    dimSizes.push_back(size);
+    scalableFlags.push_back(scalableFlag);
+  }
+  return VectorType::get(dimSizes, vType.getElementType(), scalableFlags);
+}
+
+/// A pattern to swap shape_cast(tranpose) with transpose(shape_cast) if the
+/// shape_cast only drops unit dimensions.
+///
+/// This simplifies the transpose making it more likely to be matched by further
+/// patterns.
+///
+/// Example:
+///
+///  BEFORE:
+///  ```mlir
+///  %0 = vector.transpose %vector, [3, 0, 1, 2]
+///         : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
+///  %1 = vector.shape_cast %0 : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>
+///  ```
+///
+///  AFTER:
+///  ```mlir
+///  %0 = vector.shape_cast %arg0 : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
+///  %1 = vector.transpose %0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+///  ```
+struct SwapShapeCastOfTranspose : public OpRewritePattern<vector::ShapeCastOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
+                                PatternRewriter &rewriter) const override {
+    auto transposeOp =
+        shapeCastOp.getSource().getDefiningOp<vector::TransposeOp>();
+    if (!transposeOp)
+      return rewriter.notifyMatchFailure(shapeCastOp, "not TransposeOp");
+
+    auto resultType = shapeCastOp.getResultVectorType();
+    if (resultType.getRank() <= 1)
+      return rewriter.notifyMatchFailure(shapeCastOp, "result rank too low");
+
+    if (resultType != dropUnitDims(shapeCastOp.getSourceVectorType()))
+      return rewriter.notifyMatchFailure(
+          shapeCastOp, "ShapeCastOp changes non-unit dimension(s)");
+
+    auto transposeSourceVectorType = transposeOp.getSourceVectorType();
+    auto transposeSourceDims =
+        llvm::to_vector(getDims(transposeSourceVectorType));
+
+    // Construct a map from dimIdx -> number of dims dropped before dimIdx.
+    SmallVector<int64_t> droppedDimsBefore(transposeSourceVectorType.getRank());
+    int64_t droppedDims = 0;
+    for (auto [i, dim] : llvm::enumerate(transposeSourceDims)) {
+      droppedDimsBefore[i] = droppedDims;
+      if (dim == std::make_tuple(1, false))
+        ++droppedDims;
+    }
+
+    // Drop unit dims from transpose permutation.
+    auto perm = transposeOp.getPermutation();
+    SmallVector<int64_t> newPerm;
+    for (int64_t idx : perm) {
+      if (transposeSourceDims[idx] == std::make_tuple(1, false))
+        continue;
+      newPerm.push_back(idx - droppedDimsBefore[idx]);
+    }
+
+    auto loc = shapeCastOp.getLoc();
+    auto newShapeCastOp = rewriter.create<vector::ShapeCastOp>(
+        loc, dropUnitDims(transposeSourceVectorType), transposeOp.getVector());
+    rewriter.replaceOpWithNewOp<vector::TransposeOp>(shapeCastOp,
+                                                     newShapeCastOp, newPerm);
+    return success();
+  }
+};
+
 } // namespace
 
 void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
   results.add<ShapeCastConstantFolder, ShapeCastCreateMaskFolderTrailingOneDim,
-              ShapeCastBroadcastFolder>(context);
+              ShapeCastBroadcastFolder, SwapShapeCastOfTranspose>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index adc02adb6e974..458906a187982 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -646,29 +646,3 @@ func.func @negative_transpose_store_scalable_via_za__bad_source_shape(%vec: vect
   vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]} : vector<[7]x2xf32>,  memref<?x?xf32>
   return
 }
-
-// -----
-
-// CHECK-LABEL: @swap_shape_cast_of_transpose(
-// CHECK-SAME:                                %[[VEC:.*]]: vector<1x1x4x[4]xf32>)
-func.func @swap_shape_cast_of_transpose(%vector: vector<1x1x4x[4]xf32>) -> vector<[4]x4xf32> {
-  // CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
-  // CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPE_CAST]], [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
-  // CHECK: return %[[TRANSPOSE]]
-  %0 = vector.transpose %vector, [3, 0, 1, 2] : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
-  %1 = vector.shape_cast %0 : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>
-  return %1 : vector<[4]x4xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @swap_shape_cast_of_transpose_units_dims_before_and_after(
-// CHECK-SAME:                                %[[VEC:.*]]: vector<1x1x1x4x[4]x1xf32>)
-func.func @swap_shape_cast_of_transpose_units_dims_before_and_after(%vector: vector<1x1x1x4x[4]x1xf32>) -> vector<[4]x4xf32> {
-  // CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x1x4x[4]x1xf32> to vector<4x[4]xf32>
-  // CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPE_CAST]], [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
-  // CHECK: return %[[TRANSPOSE]]
-  %0 = vector.transpose %vector, [4, 1, 0, 2, 3, 5] : vector<1x1x1x4x[4]x1xf32> to vector<[4]x1x1x1x4x1xf32>
-  %1 = vector.shape_cast %0 : vector<[4]x1x1x1x4x1xf32> to vector<[4]x4xf32>
-  return %1 : vector<[4]x4xf32>
-}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index e71a6eb02ea46..f1a1120bf874e 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -867,6 +867,32 @@ func.func @canonicalize_broadcast_shapecast_to_shapecast(%arg0: vector<3x4xf32>)
 
 // -----
 
+// CHECK-LABEL: @swap_shape_cast_of_transpose(
+// CHECK-SAME:                                %[[VEC:.*]]: vector<1x1x4x[4]xf32>)
+func.func @swap_shape_cast_of_transpose(%vector: vector<1x1x4x[4]xf32>) -> vector<[4]x4xf32> {
+  // CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
+  // CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPE_CAST]], [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+  // CHECK: return %[[TRANSPOSE]]
+  %0 = vector.transpose %vector, [3, 0, 1, 2] : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
+  %1 = vector.shape_cast %0 : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>
+  return %1 : vector<[4]x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @swap_shape_cast_of_transpose_units_dims_before_and_after(
+// CHECK-SAME:                                %[[VEC:.*]]: vector<1x1x1x4x[4]x1xf32>)
+func.func @swap_shape_cast_of_transpose_units_dims_before_and_after(%vector: vector<1x1x1x4x[4]x1xf32>) -> vector<[4]x4xf32> {
+  // CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x1x4x[4]x1xf32> to vector<4x[4]xf32>
+  // CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPE_CAST]], [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+  // CHECK: return %[[TRANSPOSE]]
+  %0 = vector.transpose %vector, [4, 1, 0, 2, 3, 5] : vector<1x1x1x4x[4]x1xf32> to vector<[4]x1x1x1x4x1xf32>
+  %1 = vector.shape_cast %0 : vector<[4]x1x1x1x4x1xf32> to vector<[4]x4xf32>
+  return %1 : vector<[4]x4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: fold_vector_transfer_masks
 func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
   // CHECK: %[[C0:.+]] = arith.constant 0 : index

@banach-space banach-space requested a review from kuhar July 29, 2024 08:19
Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, but we should make sure that this is also fine for SPIR-V - @kuhar ? Also, given that this is a canonicalisation, lets wait a few days so that people have a chance to take a look.

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for checking on the SPIR-V side. Given this is conditioned on unit dims, I think it should be fine for us, but it's hard for me to give a definite answer without trying this.
FYI, we are working on better upstream testing for the SPIR-V lowering so that there's less guessing in the future: #95942

@banach-space
Copy link
Contributor

banach-space commented Jul 30, 2024

Thanks for checking on the SPIR-V side. Given this is conditioned on unit dims, I think it should be fine for us, but it's hard for me to give a definite answer without trying this. FYI, we are working on better upstream testing for the SPIR-V lowering so that there's less guessing in the future: #95942

@kuhar Shall we land this and just revert if something breaks on the SPIR-V side?

@kuhar
Copy link
Member

kuhar commented Jul 31, 2024

Yes, this SGTM.

@dcaballe
Copy link
Contributor

Isn't this pointing at an implementation gap in the drop unit dimension transformations? If we apply drop unit dimensions here, we should get a shape_cast before the transpose, which should cancel the shape_cast after the transpose?

@MacDue
Copy link
Member Author

MacDue commented Jul 31, 2024

I think this alone is a reasonable canonicalization (as in it takes some IR and rewrites in a more canonical form, without introducing new operations). I'm happy to look at more unit dim dropping patterns, but I think that is a separate thing.

@dcaballe
Copy link
Contributor

dcaballe commented Aug 1, 2024

Probably good to get feedback from @MaheshRavishankar as this topic is sensitive. I don't have a strong opinion but I think we left unit dims out of the canonicalization scope because we may lose information in some cases and for that reason we have the drop unit dim passes/patterns taking that responsibility. I would advocate for consistency if there isn't a strong case to go otherwise so that we have all the unit dims in a single place.

@MaheshRavishankar
Copy link
Contributor

My intiution here is that such a pattern can be a canonicalization. If I had to give a reasoning here, it is effectively bubbling up the shape cast operation, and that could be applied repeatedly as a fixed point to move the shape casts up as much as possible which I think is good for the program as a whole. But for consistency, I would also see if there are a group of patterns that do the same thing and create a fixed point outside of canonicalizations to apply the fixed point where necessary (basically everything in the compiler stack should be deliberate and whoever is putting things together must know what they are doing when they add something to a pass pipeline. The kitchen sink approach to canonicalizer is a problem in general). But I wont really push back against this being a canonicalizer pattern.

@MaheshRavishankar
Copy link
Contributor

Probably good to get feedback from @MaheshRavishankar as this topic is sensitive. I don't have a strong opinion but I think we left unit dims out of the canonicalization scope because we may lose information in some cases and for that reason we have the drop unit dim passes/patterns taking that responsibility. I would advocate for consistency if there isn't a strong case to go otherwise so that we have all the unit dims in a single place.

Thanks for checking @dcaballe

@dcaballe
Copy link
Contributor

dcaballe commented Aug 1, 2024

       : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
%1 = vector.shape_cast %0 : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>

Also wondering... isn't this vector.shape_cast introduced by the drop unit dimension pass itself?

@MacDue
Copy link
Member Author

MacDue commented Aug 5, 2024

@dcaballe I've posted a patch for a dropping unit dims from vector.transpose using the same logic as this patch here: #102017

@banach-space
Copy link
Contributor

@dcaballe I've posted a patch for a dropping unit dims from vector.transpose using the same logic as this patch here: #102017

IIUC, we can close this in favour of #102017, right?

@MacDue
Copy link
Member Author

MacDue commented Aug 7, 2024

Yes, I tested it and this #102017 solves the same issue in IREE.

@MacDue MacDue closed this Aug 7, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants