Skip to content

Commit

Permalink
[mlir][Mesh] Fix invalid IR in rewrite pattern (#78094)
Browse files Browse the repository at this point in the history
This commit fixes `test/Dialect/Mesh/folding.mlir` when running with
`MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS`.

```
/usr/local/google/home/springerm/mlir_public/llvm-project/mlir/test/Dialect/Mesh/folding.mlir:19:10: error: Unexpected number of results 0. Expected 2.
  %0:2 = mesh.cluster_shape @mesh1 : index, index
         ^
/usr/local/google/home/springerm/mlir_public/llvm-project/mlir/test/Dialect/Mesh/folding.mlir:19:10: note: see current operation: "mesh.cluster_shape"() <{axes = array<i16>, mesh = @mesh1}> : () -> ()
mlir-asm-printer: Verifying operation: builtin.module
Unexpected number of results 0. Expected 2.
mlir-asm-printer: 'builtin.module' failed to verify and will be printed in generic form
"builtin.module"() ({
  "mesh.cluster"() <{dim_sizes = array<i64: 2, 3>, rank = 2 : i64, sym_name = "mesh1"}> : () -> ()
  "func.func"() <{function_type = () -> (index, index), sym_name = "cluster_shape_op_folding_all_axes_static_mesh"}> ({
    %0 = "arith.constant"() <{value = 2 : index}> : () -> index
    %1 = "arith.constant"() <{value = 3 : index}> : () -> index
    "mesh.cluster_shape"() <{axes = array<i16>, mesh = @mesh1}> : () -> ()
    %2:2 = "mesh.cluster_shape"() <{axes = array<i16>, mesh = @mesh1}> : () -> (index, index)
    "func.return"(%0, %1) : (index, index) -> ()
  }) : () -> ()
}) : () -> ()
LLVM ERROR: IR failed to verify after pattern application
```

If `axes` is empty, the op verifier assumes that all dimensions are
queried. (Expected 2 results.)
  • Loading branch information
matthias-springer authored Jan 15, 2024
1 parent 844f833 commit 0cb024b
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,14 @@ struct ClusterShapeFolder : OpRewritePattern<ClusterShapeOp> {
}

// Leave only the dynamic mesh axes to be queried.
ClusterShapeOp newShapeOp =
builder.create<ClusterShapeOp>(mesh.getSymName(), newShapeOpMeshAxes);
for (size_t i = 0; i < newShapeOp->getResults().size(); ++i) {
newResults[newToOldResultsIndexMap[i]] = newShapeOp->getResults()[i];
if (!newShapeOpMeshAxes.empty()) {
ClusterShapeOp newShapeOp =
builder.create<ClusterShapeOp>(mesh.getSymName(), newShapeOpMeshAxes);
for (size_t i = 0; i < newShapeOp->getResults().size(); ++i) {
newResults[newToOldResultsIndexMap[i]] = newShapeOp->getResults()[i];
}
}

rewriter.replaceAllUsesWith(op.getResults(), newResults);
rewriter.replaceOp(op, newResults);

return success();
}
Expand Down

0 comments on commit 0cb024b

Please sign in to comment.