Skip to content

Commit

Permalink
[Codegen] Allow memref type propagation through collapse_shape
Browse files Browse the repository at this point in the history
Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
  • Loading branch information
Max191 committed Dec 6, 2024
1 parent 1c73358 commit 6055171
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,22 @@ func.func @dynamic_bound_alloc(%id : index) {
}
// CHECK-LABEL: func @dynamic_bound_alloc(
// CHECK: %alloc = memref.alloc() : memref<4088xf32, 3>

// -----

func.func @dynamic_alloc_collapse_consumer(%id : index) {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%0 = util.assume.int %id<umin = 0, umax = 32> : index
%1 = memref.alloc(%0, %0) : memref<?x?xf32, 3>
%2 = memref.collapse_shape %1 [[0, 1]] : memref<?x?xf32, 3> into memref<?xf32, 3>
memref.store %cst, %2[%c0] : memref<?xf32, 3>
return
}
// CHECK-LABEL: func @dynamic_alloc_collapse_consumer(
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<32x32xf32, 3>
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]]
// CHECK-SAME: [0, 0] [{{.*}}] [1, 1] : memref<32x32xf32, 3> to memref<?x?xf32, strided<[32, 1]>, 3>
// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SUBVIEW]] {{\[}}[0, 1]]
// CHECK-SAME: : memref<?x?xf32, strided<[32, 1]>, 3> into memref<?xf32, strided<[?]>, 3>
// CHECK: memref.store {{.*}} %[[COLLAPSE]]{{.*}} : memref<?xf32, strided<[?]>, 3>
18 changes: 18 additions & 0 deletions compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1305,6 +1305,24 @@ replaceNonTrivialUse(RewriterBase &rewriter, Location loc, OpOperand &use,
});
return llvm::to_vector_of<Value>(newExpandOp->getResults());
}
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(user)) {
auto newSourceType = llvm::cast<MemRefType>(replacement.getType());
FailureOr<MemRefType> newResultType =
memref::CollapseShapeOp::computeCollapsedType(
newSourceType, collapseOp.getReassociationIndices());
if (failed(newResultType)) {
return std::nullopt;
}

auto newCollapseOp = rewriter.create<memref::CollapseShapeOp>(
loc, *newResultType, replacement, collapseOp.getReassociation());
LLVM_DEBUG({
llvm::dbgs() << "\t\tNew user : ";
newCollapseOp->print(llvm::dbgs(), OpPrintingFlags().assumeVerified());
llvm::dbgs() << "\n";
});
return llvm::to_vector_of<Value>(newCollapseOp->getResults());
}
return std::nullopt;
}

Expand Down

0 comments on commit 6055171

Please sign in to comment.