Skip to content

Commit

Permalink
[GPU] Add chained reshape support for scf.forall expand destination pattern
Browse files Browse the repository at this point in the history

Signed-off-by: Nirvedh Meshram <nirvedh@gmail.com>
  • Loading branch information
nirvedhmeshram committed Jan 6, 2025
1 parent b245e6b commit d0b0364
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,15 @@ static LogicalResult verifyAndCollectExpandableUsers(
return failure();
if (extractSliceOp.getMixedOffsets() != parallelInsertOp.getMixedOffsets())
return failure();
auto expandShapeOp =
dyn_cast<tensor::ExpandShapeOp>(*extractSliceOp->getUsers().begin());
if (!expandShapeOp)
return failure();
SmallVector<ReassociationIndices> expandReIndices =
expandShapeOp.getReassociationIndices();
if (reIndices != expandReIndices)
return failure();
for (Operation *user : extractSliceOp->getUsers()) {
auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(user);
if (!expandShapeOp)
return failure();
SmallVector<ReassociationIndices> expandReIndices =
expandShapeOp.getReassociationIndices();
if (reIndices != expandReIndices)
return failure();
}
expandableUsers.push_back(extractSliceOp);
}
return success();
Expand Down Expand Up @@ -155,9 +156,14 @@ expandVerifiedUsers(PatternRewriter &rewriter, Location loc, MLIRContext *ctx,
expandedOffsets, expandedSizes, expandedStrides);
for (tensor::ExtractSliceOp extractSliceOp : expandableUsers) {
rewriter.setInsertionPoint(extractSliceOp);
rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
extractSliceOp, resultType, extractSliceOp.getSource(), expandedOffsets,
expandedSizes, expandedStrides);
auto newExtractSliceOp =
rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
extractSliceOp, resultType, extractSliceOp.getSource(),
expandedOffsets, expandedSizes, expandedStrides);
for (Operation *user : newExtractSliceOp->getUsers()) {
auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(user);
expandShapeOp->replaceAllUsesWith(newExtractSliceOp);
}
}
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -337,3 +337,54 @@ func.func @noexpand_dest_forall_notfullslicestore() {
// CHECK: flow.dispatch.tensor.store %[[SCFFORALL]], %[[SUBSPAN]]
// CHECK-SAME: offsets = [1], sizes = [32], strides = [1] : tensor<32xf32>
// CHECK-SAME: !flow.dispatch.tensor<writeonly:tensor<34xf32>>

// -----
// Test input: the slice of the scf.forall shared output is expanded by two
// chained tensor.expand_shape ops and later restored by two matching
// tensor.collapse_shape ops before being inserted back in parallel.
#pipeline_layout = #hal.pipeline.layout<constants = 1, bindings = [
#hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>
func.func @expand_dest_forall_chained() {
// NOTE(review): %cst is not referenced anywhere else in this function.
%cst = arith.constant 0.000000e+00 : f16
%c0 = arith.constant 0 : index
// Value loaded from constant ordinal 0; used below as the dynamic leading
// dimension of the binding, the empty tensor, and the final store.
%index = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0)
flags(Indirect) : !flow.dispatch.tensor<writeonly:tensor<?x64x32xf32>>{%index}
// Destination tensor that becomes the shared output of the scf.forall below.
%1 = tensor.empty(%index) : tensor<?x64x32xf32>
// NOTE(review): %extra is not referenced anywhere else in this function.
%extra = tensor.empty() : tensor<32x32xf32>
// Iterates from (0, 0) to (64, 32) in steps of (16, 16) with %1 as shared output.
%2 = scf.forall (%arg0, %arg1) = (0, 0) to (64, 32) step (16, 16)
shared_outs(%arg2 = %1) -> (tensor<?x64x32xf32>) {
// Take a 1x16x16 tile of the shared output at offsets [%c0, %arg0, %arg1].
%extracted_slice = tensor.extract_slice %arg2[%c0, %arg0, %arg1] [1, 16, 16] [1, 1, 1]
: tensor<?x64x32xf32> to tensor<1x16x16xf32>
// First expansion: the last dimension (16) is split into 2x4x2.
%expanded = tensor.expand_shape %extracted_slice [[0], [1], [2, 3, 4]]
output_shape [1, 16, 2, 4, 2] : tensor<1x16x16xf32> into tensor<1x16x2x4x2xf32>
// Second expansion, chained on the first: dimension 1 (16) is split into 8x2.
%expanded2 = tensor.expand_shape %expanded [[0], [1, 2], [3], [4], [5]]
output_shape [1, 8, 2, 2, 4, 2] : tensor<1x16x2x4x2xf32> into tensor<1x8x2x2x4x2xf32>
// NOTE(review): the barrier presumably keeps the expand and collapse chains
// from folding into each other before the pattern runs; confirm.
%expanded_barrier = util.optimization_barrier %expanded2 : tensor<1x8x2x2x4x2xf32>
// The two collapses use the same reassociations as the expansions, in reverse order.
%collapsed = tensor.collapse_shape %expanded_barrier [[0], [1, 2], [3], [4], [5]] : tensor<1x8x2x2x4x2xf32> into tensor<1x16x2x4x2xf32>
%collapsed2 = tensor.collapse_shape %collapsed [[0], [1], [2, 3, 4]] : tensor<1x16x2x4x2xf32> into tensor<1x16x16xf32>
scf.forall.in_parallel {
// Written back with the same offsets, sizes and strides as the extract_slice above.
tensor.parallel_insert_slice %collapsed2 into %arg2[%c0, %arg0, %arg1] [1, 16, 16] [1, 1, 1]
: tensor<1x16x16xf32> into tensor<?x64x32xf32>
}
} {mapping = [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
// The full scf.forall result is stored to the writeonly binding %0.
flow.dispatch.tensor.store %2, %0, offsets = [0, 0, 0], sizes = [%index, 64, 32], strides = [1, 1, 1]
: tensor<?x64x32xf32> -> !flow.dispatch.tensor<writeonly:tensor<?x64x32xf32>>{%index}
return
}

// CHECK-LABEL: func @expand_dest_forall_chained(
// CHECK: %[[LOAD_CONST:.+]] = hal.interface.constant.load
// CHECK: %[[SUBSPAN:.+]] = hal.interface.binding.subspan
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[LOAD_CONST]]) : tensor<?x32x2x4x4x2xf32>
// CHECK: %[[SCFFORALL:.+]] = scf.forall (%[[ARG0:.+]], %[[ARG1:.+]]) = (0, 0)
// CHECK-SAME: shared_outs(%[[ARG2:.+]] = %[[EMPTY]]) -> (tensor<?x32x2x4x4x2xf32>) {
// CHECK-DAG: %[[OFFSET0:.+]] = affine.apply affine_map<()[s0] -> (s0 floordiv 8)>()[%[[ARG1]]]
// CHECK-DAG: %[[OFFSET1:.+]] = affine.apply affine_map<()[s0] -> (s0 floordiv 2)>()[%[[ARG0]]]
// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[ARG2]]
// CHECK-SAME: [0, %[[OFFSET1]], 0, %[[OFFSET0]], 0, 0] [1, 8, 2, 2, 4, 2] [1, 1, 1, 1, 1, 1]
// CHECK-SAME: tensor<?x32x2x4x4x2xf32> to tensor<1x8x2x2x4x2xf32>
// CHECK: %[[BARRIER:.+]] = util.optimization_barrier %[[EXTRACT]] : tensor<1x8x2x2x4x2xf32>
// CHECK: tensor.parallel_insert_slice %[[BARRIER]] into %[[ARG2]]
// CHECK-SAME: [0, %[[OFFSET1]], 0, %[[OFFSET0]], 0, 0] [1, 8, 2, 2, 4, 2] [1, 1, 1, 1, 1, 1]
// CHECK-SAME: tensor<1x8x2x2x4x2xf32> into tensor<?x32x2x4x4x2xf32>
// CHECK: flow.dispatch.tensor.store %[[SCFFORALL]], %[[SUBSPAN]]
// CHECK-SAME: offsets = [0, 0, 0, 0, 0, 0], sizes = [%[[LOAD_CONST]], 32, 2, 4, 4, 2], strides = [1, 1, 1, 1, 1, 1]
// CHECK-SAME: !flow.dispatch.tensor<writeonly:tensor<?x32x2x4x4x2xf32>>{%[[LOAD_CONST]]}

0 comments on commit d0b0364

Please sign in to comment.