Skip to content

Commit

Permalink
[Codegen] Add support for memref.expand_shape to propagation util (#1…
Browse files Browse the repository at this point in the history
…8202)

Similar to `memref.subview`, `memref.expand_shape` needs to have its
type updated when propagating type changes. This adds support for expand
shape to the propagation util so that passes like GPUReduceBankConflicts
can handle `memref.expand_shape`.
  • Loading branch information
qedawkins authored Aug 13, 2024
1 parent 9c951ca commit 2ea9b14
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,30 @@ func.func @pad_alloc(%a: memref<1024x1024xf32>) {

// -----

// CHECK-LABEL: func.func @pad_alloc_expand_shape
// CHECK: %[[A:.*]] = memref.alloc() : memref<4x32x66xf32, #gpu.address_space<workgroup>>
// CHECK: %[[S1:.*]] = memref.subview %[[A]][0, 0, 0] [4, 32, 64] [1, 1, 1] :
// CHECK-SAME: memref<4x32x66xf32, #gpu.address_space<workgroup>> to memref<4x32x64xf32, strided<[2112, 66, 1]>, #gpu.address_space<workgroup>>
// CHECK: %[[E:.*]] = memref.expand_shape %[[S1]] {{\[}}[0], [1, 2], [3, 4]] output_shape [4, 2, 16, 8, 8]
// CHECK-SAME: memref<4x32x64xf32, strided<[2112, 66, 1]>, #gpu.address_space<workgroup>> into
// CHECK-SAME: memref<4x2x16x8x8xf32, strided<[2112, 1056, 66, 8, 1]>, #gpu.address_space<workgroup>>
// CHECK: vector.transfer_write %{{.*}}, %[[E]][%{{.*}}, %{{.*}}, %{{.*}}] {in_bounds = [true]} :
// CHECK-SAME: vector<4xf32>, memref<4x2x16x8x8xf32, strided<[2112, 1056, 66, 8, 1]>, #gpu.address_space<workgroup>
func.func @pad_alloc_expand_shape(%a: memref<1024x1024xf32>) {
%0 = memref.alloc() : memref<4x32x64xf32, #gpu.address_space<workgroup>>
%1 = memref.expand_shape %0 [[0], [1, 2], [3, 4]] output_shape [4, 2, 16, 8, 8]
: memref<4x32x64xf32, #gpu.address_space<workgroup>> into memref<4x2x16x8x8xf32, #gpu.address_space<workgroup>>
%c0 = arith.constant 0 : index
%cst_0 = arith.constant 0.000000e+00 : f32
%3 = vector.transfer_read %a[%c0, %c0], %cst_0 {in_bounds = [true]} :
memref<1024x1024xf32>, vector<4xf32>
vector.transfer_write %3, %1[%c0, %c0, %c0, %c0, %c0] {in_bounds = [true]} :
vector<4xf32>, memref<4x2x16x8x8xf32, #gpu.address_space<workgroup>>
return
}

// -----

// CHECK-LABEL: func.func @pad_alloc_negative
// CHECK: memref.alloc(%{{.*}}) : memref<?x32x64xf32, #gpu.address_space<workgroup>
func.func @pad_alloc_negative(%a: memref<1024x1024xf32>, %i: index, %v: vector<4xf32>) {
Expand Down
26 changes: 24 additions & 2 deletions compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -988,8 +988,30 @@ replaceNonTrivialUse(RewriterBase &rewriter, Location loc, OpOperand &use,
newSubviewOp->print(llvm::dbgs(), OpPrintingFlags().assumeVerified());
llvm::dbgs() << "\n";
});
return SmallVector<Value>(newSubviewOp->result_begin(),
newSubviewOp->result_end());
return llvm::to_vector_of<Value>(newSubviewOp->getResults());
}
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(user)) {
auto currResultType =
llvm::cast<MemRefType>(expandOp.getResult().getType());
auto newSourceType = llvm::cast<MemRefType>(replacement.getType());

FailureOr<MemRefType> newResultType =
memref::ExpandShapeOp::computeExpandedType(
newSourceType, currResultType.getShape(),
expandOp.getReassociationIndices());
if (failed(newResultType)) {
return std::nullopt;
}

auto newExpandOp = rewriter.create<memref::ExpandShapeOp>(
loc, *newResultType, replacement, expandOp.getReassociation(),
expandOp.getOutputShape(), expandOp.getStaticOutputShape());
LLVM_DEBUG({
llvm::dbgs() << "\t\tNew user : ";
newExpandOp->print(llvm::dbgs(), OpPrintingFlags().assumeVerified());
llvm::dbgs() << "\n";
});
return llvm::to_vector_of<Value>(newExpandOp->getResults());
}
return std::nullopt;
}
Expand Down

0 comments on commit 2ea9b14

Please sign in to comment.