From 0cb024b357aff294b1ba0f9d3de8f48ab684962b Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Mon, 15 Jan 2024 09:00:43 +0100 Subject: [PATCH] [mlir][Mesh] Fix invalid IR in rewrite pattern (#78094) This commit fixes `test/Dialect/Mesh/folding.mlir` when running with `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS`. ``` /usr/local/google/home/springerm/mlir_public/llvm-project/mlir/test/Dialect/Mesh/folding.mlir:19:10: error: Unexpected number of results 0. Expected 2. %0:2 = mesh.cluster_shape @mesh1 : index, index ^ /usr/local/google/home/springerm/mlir_public/llvm-project/mlir/test/Dialect/Mesh/folding.mlir:19:10: note: see current operation: "mesh.cluster_shape"() <{axes = array, mesh = @mesh1}> : () -> () mlir-asm-printer: Verifying operation: builtin.module Unexpected number of results 0. Expected 2. mlir-asm-printer: 'builtin.module' failed to verify and will be printed in generic form "builtin.module"() ({ "mesh.cluster"() <{dim_sizes = array, rank = 2 : i64, sym_name = "mesh1"}> : () -> () "func.func"() <{function_type = () -> (index, index), sym_name = "cluster_shape_op_folding_all_axes_static_mesh"}> ({ %0 = "arith.constant"() <{value = 2 : index}> : () -> index %1 = "arith.constant"() <{value = 3 : index}> : () -> index "mesh.cluster_shape"() <{axes = array, mesh = @mesh1}> : () -> () %2:2 = "mesh.cluster_shape"() <{axes = array, mesh = @mesh1}> : () -> (index, index) "func.return"(%0, %1) : (index, index) -> () }) : () -> () }) : () -> () LLVM ERROR: IR failed to verify after pattern application ``` If `axes` is empty, the op verifier assumes that all dimensions are queried. (Expected 2 results.) --- .../lib/Dialect/Mesh/Transforms/Simplifications.cpp | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp index c9275ad5ad4551..67e1bf6320dbf3 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp +++ b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp @@ -103,13 +103,14 @@ struct ClusterShapeFolder : OpRewritePattern { } // Leave only the dynamic mesh axes to be queried. - ClusterShapeOp newShapeOp = - builder.create(mesh.getSymName(), newShapeOpMeshAxes); - for (size_t i = 0; i < newShapeOp->getResults().size(); ++i) { - newResults[newToOldResultsIndexMap[i]] = newShapeOp->getResults()[i]; + if (!newShapeOpMeshAxes.empty()) { + ClusterShapeOp newShapeOp = + builder.create(mesh.getSymName(), newShapeOpMeshAxes); + for (size_t i = 0; i < newShapeOp->getResults().size(); ++i) { + newResults[newToOldResultsIndexMap[i]] = newShapeOp->getResults()[i]; + } } - - rewriter.replaceAllUsesWith(op.getResults(), newResults); + rewriter.replaceOp(op, newResults); return success(); }