From 596a4b8a404c999ccff717376a83e46d347c4962 Mon Sep 17 00:00:00 2001 From: Max Dawkins Date: Fri, 6 Dec 2024 11:45:28 -0600 Subject: [PATCH] [Codegen] Allow memref type propagation through collapse_shape Signed-off-by: Max Dawkins --- .../Common/test/pad_dynamic_alloc.mlir | 19 +++++++++++++++++++ .../src/iree/compiler/Codegen/Utils/Utils.cpp | 18 ++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/compiler/src/iree/compiler/Codegen/Common/test/pad_dynamic_alloc.mlir b/compiler/src/iree/compiler/Codegen/Common/test/pad_dynamic_alloc.mlir index e9d4d7b82181..4c08ebdb676c 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/pad_dynamic_alloc.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/pad_dynamic_alloc.mlir @@ -48,3 +48,22 @@ func.func @dynamic_bound_alloca(%id : index) { } // CHECK-LABEL: func @dynamic_bound_alloca( // CHECK: memref.alloca() : memref<4088xf32, 3> + +// ----- + +func.func @dynamic_alloc_collapse_consumer(%id : index) { + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f32 + %0 = util.assume.int %id : index + %1 = memref.alloc(%0, %0) : memref + %2 = memref.collapse_shape %1 [[0, 1]] : memref into memref + memref.store %cst, %2[%c0] : memref + return +} +// CHECK-LABEL: func @dynamic_alloc_collapse_consumer( +// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<32x32xf32, 3> +// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]] +// CHECK-SAME: [0, 0] [{{.*}}] [1, 1] : memref<32x32xf32, 3> to memref, 3> +// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SUBVIEW]] {{\[}}[0, 1]] +// CHECK-SAME: : memref, 3> into memref, 3> +// CHECK: memref.store {{.*}} %[[COLLAPSE]]{{.*}} : memref, 3> diff --git a/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp b/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp index 37d061a00c2d..ddc0b9a52070 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp +++ b/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp @@ -1342,6 +1342,24 @@ replaceNonTrivialUse(RewriterBase &rewriter, Location loc, OpOperand &use, }); return llvm::to_vector_of(newExpandOp->getResults()); } + if (auto collapseOp = dyn_cast(user)) { + auto newSourceType = llvm::cast(replacement.getType()); + FailureOr newResultType = + memref::CollapseShapeOp::computeCollapsedType( + newSourceType, collapseOp.getReassociationIndices()); + if (failed(newResultType)) { + return std::nullopt; + } + + auto newCollapseOp = rewriter.create( + loc, *newResultType, replacement, collapseOp.getReassociation()); + LLVM_DEBUG({ + llvm::dbgs() << "\t\tNew user : "; + newCollapseOp->print(llvm::dbgs(), OpPrintingFlags().assumeVerified()); + llvm::dbgs() << "\n"; + }); + return llvm::to_vector_of(newCollapseOp->getResults()); + } return std::nullopt; }