Skip to content

Commit

Permalink
[Codegen] Allow memref type propagation through collapse_shape (#19400)
Browse files Browse the repository at this point in the history
This PR adds support for propagating memref type changes through
memref.collapse_shape ops in the `replaceMemrefUsesAndPropagateType`
util function. This propagation is used in allocation padding, since the
strides of the memref type change after padding.

Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
  • Loading branch information
Max191 authored Jan 17, 2025
1 parent 36c2353 commit dde5992
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,22 @@ func.func @dynamic_bound_alloca(%id : index) {
}
// CHECK-LABEL: func @dynamic_bound_alloca(
// CHECK: memref.alloca() : memref<4088xf32, 3>

// -----

func.func @dynamic_alloc_collapse_consumer(%id : index) {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%0 = util.assume.int %id<umin = 0, umax = 32> : index
%1 = memref.alloc(%0, %0) : memref<?x?xf32, 3>
%2 = memref.collapse_shape %1 [[0, 1]] : memref<?x?xf32, 3> into memref<?xf32, 3>
memref.store %cst, %2[%c0] : memref<?xf32, 3>
return
}
// CHECK-LABEL: func @dynamic_alloc_collapse_consumer(
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<32x32xf32, 3>
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]]
// CHECK-SAME: [0, 0] [{{.*}}] [1, 1] : memref<32x32xf32, 3> to memref<?x?xf32, strided<[32, 1]>, 3>
// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SUBVIEW]] {{\[}}[0, 1]]
// CHECK-SAME: : memref<?x?xf32, strided<[32, 1]>, 3> into memref<?xf32, strided<[?]>, 3>
// CHECK: memref.store {{.*}} %[[COLLAPSE]]{{.*}} : memref<?xf32, strided<[?]>, 3>
18 changes: 18 additions & 0 deletions compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1342,6 +1342,24 @@ replaceNonTrivialUse(RewriterBase &rewriter, Location loc, OpOperand &use,
});
return llvm::to_vector_of<Value>(newExpandOp->getResults());
}
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(user)) {
auto newSourceType = llvm::cast<MemRefType>(replacement.getType());
FailureOr<MemRefType> newResultType =
memref::CollapseShapeOp::computeCollapsedType(
newSourceType, collapseOp.getReassociationIndices());
if (failed(newResultType)) {
return std::nullopt;
}

auto newCollapseOp = rewriter.create<memref::CollapseShapeOp>(
loc, *newResultType, replacement, collapseOp.getReassociation());
LLVM_DEBUG({
llvm::dbgs() << "\t\tNew user : ";
newCollapseOp->print(llvm::dbgs(), OpPrintingFlags().assumeVerified());
llvm::dbgs() << "\n";
});
return llvm::to_vector_of<Value>(newCollapseOp->getResults());
}
return std::nullopt;
}

Expand Down

0 comments on commit dde5992

Please sign in to comment.