From 605517173018d7cca901752b3aceb7b98da67c62 Mon Sep 17 00:00:00 2001 From: Max Dawkins Date: Fri, 6 Dec 2024 11:45:28 -0600 Subject: [PATCH] [Codegen] Allow memref type propagation through collapse_shape Signed-off-by: Max Dawkins --- .../Common/test/pad_dynamic_alloc.mlir | 19 +++++++++++++++++++ .../src/iree/compiler/Codegen/Utils/Utils.cpp | 18 ++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/compiler/src/iree/compiler/Codegen/Common/test/pad_dynamic_alloc.mlir b/compiler/src/iree/compiler/Codegen/Common/test/pad_dynamic_alloc.mlir index 0b56bd2cb963..18a15c6a521b 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/pad_dynamic_alloc.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/pad_dynamic_alloc.mlir @@ -38,3 +38,22 @@ func.func @dynamic_bound_alloc(%id : index) { } // CHECK-LABEL: func @dynamic_bound_alloc( // CHECK: %alloc = memref.alloc() : memref<4088xf32, 3> + +// ----- + +func.func @dynamic_alloc_collapse_consumer(%id : index) { + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f32 + %0 = util.assume.int %id : index + %1 = memref.alloc(%0, %0) : memref + %2 = memref.collapse_shape %1 [[0, 1]] : memref into memref + memref.store %cst, %2[%c0] : memref + return +} +// CHECK-LABEL: func @dynamic_alloc_collapse_consumer( +// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<32x32xf32, 3> +// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]] +// CHECK-SAME: [0, 0] [{{.*}}] [1, 1] : memref<32x32xf32, 3> to memref, 3> +// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SUBVIEW]] {{\[}}[0, 1]] +// CHECK-SAME: : memref, 3> into memref, 3> +// CHECK: memref.store {{.*}} %[[COLLAPSE]]{{.*}} : memref, 3> diff --git a/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp b/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp index f86f447c49dc..7cf6cb0e4d30 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp +++ b/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp @@ -1305,6 +1305,24 @@ replaceNonTrivialUse(RewriterBase &rewriter, Location loc, OpOperand &use, }); return llvm::to_vector_of(newExpandOp->getResults()); } + if (auto collapseOp = dyn_cast(user)) { + auto newSourceType = llvm::cast(replacement.getType()); + FailureOr newResultType = + memref::CollapseShapeOp::computeCollapsedType( + newSourceType, collapseOp.getReassociationIndices()); + if (failed(newResultType)) { + return std::nullopt; + } + + auto newCollapseOp = rewriter.create( + loc, *newResultType, replacement, collapseOp.getReassociation()); + LLVM_DEBUG({ + llvm::dbgs() << "\t\tNew user : "; + newCollapseOp->print(llvm::dbgs(), OpPrintingFlags().assumeVerified()); + llvm::dbgs() << "\n"; + }); + return llvm::to_vector_of(newCollapseOp->getResults()); + } return std::nullopt; }