[GPU] Add pattern to fuse tensor.extract_slice into forall producer (#19296)

This PR adds a pattern to fuse a consumer tensor.extract_slice into a
producer scf.forall op. The pattern is added to
FuseAndHoistParallelLoops, where it helps fuse tensor.unpack ops with
extract_slice semantics into producer loops. This is needed when
targeting MFMA intrinsics for unaligned shapes, and also when generating
code for unset encoding ops on GPU. It is a follow-up to
#19295, which adds the complementary
pattern for collapse_shape.

The PR also adds a transform op so that the long lit tests can be kept
separate from the FuseAndHoistParallelLoops tests.
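
For illustration, a minimal sketch of the kind of IR the new pattern targets
(the function, value names, and shapes are hypothetical, not taken from this
PR's tests), with comments describing the expected rewrite:

// Fusable case: the consumer slice has all-zero offsets, and its dynamic
// size %n is a function argument, so it dominates the scf.forall.
func.func @fuse_example(%src: tensor<4x8xf32>, %n: index) -> tensor<4x?xf32> {
  %empty = tensor.empty() : tensor<4x8xf32>
  %res = scf.forall (%i) in (4) shared_outs(%out = %empty) -> (tensor<4x8xf32>) {
    %tile = tensor.extract_slice %src[%i, 0] [1, 8] [1, 1] : tensor<4x8xf32> to tensor<1x8xf32>
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %tile into %out[%i, 0] [1, 8] [1, 1] : tensor<1x8xf32> into tensor<4x8xf32>
    }
  } {mapping = [#gpu.thread<x>]}
  // After fusion, the shared_outs init becomes a [4, %n] slice of %empty, the
  // loop yields tensor<4x?xf32>, and the parallel_insert_slice source is
  // clamped so each thread's tile fits within the smaller destination.
  %slice = tensor.extract_slice %res[0, 0] [4, %n] [1, 1] : tensor<4x8xf32> to tensor<4x?xf32>
  return %slice : tensor<4x?xf32>
}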

---------

Signed-off-by: Max Dawkins <maxdawkins19@gmail.com>
Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
Co-authored-by: Max Dawkins <maxdawkins19@gmail.com>
Max191 and Max Dawkins authored Jan 28, 2025
1 parent ecd67d9 commit 6a5c12e
Showing 9 changed files with 741 additions and 3 deletions.
@@ -346,6 +346,27 @@ struct FuseCollapseShapeConsumers final
  }
};

struct FuseExtractSliceConsumers final
    : OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractSliceOp,
                                PatternRewriter &rewriter) const override {
    // Find the scf::ForallOp producer of the extract_slice source.
    auto forallOp = extractSliceOp.getSource().getDefiningOp<scf::ForallOp>();
    if (!forallOp) {
      return rewriter.notifyMatchFailure(extractSliceOp,
                                         "No forall op producer");
    }

    if (failed(fuseExtractSliceIntoProducerForall(rewriter, forallOp,
                                                  extractSliceOp))) {
      return failure();
    }
    return success();
  }
};

void GPUFuseAndHoistParallelLoopsPass::runOnOperation() {
  MLIRContext *context = &getContext();

@@ -391,6 +412,7 @@ void GPUFuseAndHoistParallelLoopsPass::runOnOperation() {
  patterns.add<FuseUnitLoopDestination>(context);
  patterns.add<FuseTilableForallConsumers>(context);
  patterns.add<FuseCollapseShapeConsumers>(context);
  patterns.add<FuseExtractSliceConsumers>(context);
  populateSwapExtractWithExpandPattern(patterns);
  tensor::populateFoldTensorEmptyPatterns(patterns);
  scf::ForallOp::getCanonicalizationPatterns(patterns, context);
@@ -600,3 +600,30 @@ func.func @no_fuse_collapse_shape_rank_reduced(%arg0: tensor<8x8xf32>) -> tensor
// CHECK: } {mapping = [#gpu.thread<x>]}
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[FORALL_RESULT]]
// CHECK: return %[[COLLAPSE]]

// -----

#map = affine_map<(d0) -> (d0 * 2)>
func.func @no_fuse_extract_slice_rank_reduced(%arg0: tensor<4x8xf32>, %size1: index) -> tensor<?xf32> {
  %0 = tensor.empty() : tensor<4x8xf32>
  %1 = scf.forall (%arg2) in (4) shared_outs(%arg3 = %0) -> (tensor<4x8xf32>) {
    %2 = affine.apply #map(%arg2)
    %extracted_slice_0 = tensor.extract_slice %arg0[0, %2] [1, 2] [1, 1] : tensor<4x8xf32> to tensor<2xf32>
    %extracted_slice_1 = tensor.extract_slice %arg3[0, %2] [1, 2] [1, 1] : tensor<4x8xf32> to tensor<2xf32>
    %3 = linalg.copy ins(%extracted_slice_0 : tensor<2xf32>) outs(%extracted_slice_1 : tensor<2xf32>) -> tensor<2xf32>
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %3 into %arg3[0, %2] [1, 2] [1, 1] : tensor<2xf32> into tensor<4x8xf32>
    }
  } {mapping = [#gpu.thread<x>]}
  %extracted_slice = tensor.extract_slice %1[0, 0] [1, %size1] [1, 1] : tensor<4x8xf32> to tensor<?xf32>
  return %extracted_slice : tensor<?xf32>
}

// CHECK-LABEL: func @no_fuse_extract_slice_rank_reduced
// CHECK: %[[FORALL_RESULT:.+]] = scf.forall {{.*}} -> (tensor<4x8xf32>) {
// CHECK: scf.forall.in_parallel {
// CHECK-DAG: tensor.parallel_insert_slice {{.*}} : tensor<2xf32> into tensor<4x8xf32>
// CHECK: }
// CHECK: } {mapping = [#gpu.thread<x>]}
// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[FORALL_RESULT]]
// CHECK: return %[[EXTRACT]]
@@ -266,6 +266,54 @@ void transform_dialect::FuseCollapseShapeIntoForallOp::getEffects(
  transform::modifiesPayload(effects);
}

//===---------------------------------------------------------------------===//
// FuseExtractSliceIntoForallOp
//===---------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform_dialect::FuseExtractSliceIntoForallOp::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &results, transform::TransformState &state) {
  auto producers = state.getPayloadOps(getProducer());
  auto consumers = state.getPayloadOps(getConsumer());

  int64_t numProducers = llvm::range_size(producers);
  int64_t numConsumers = llvm::range_size(consumers);
  if (numProducers != 1 || numConsumers != 1) {
    return mlir::emitDefiniteFailure(
        state.getTopLevel(), "Expected exactly one producer and one consumer");
  }

  auto producer = dyn_cast<scf::ForallOp>(*producers.begin());
  if (!producer) {
    return mlir::emitDefiniteFailure(state.getTopLevel(),
                                     "Non-forall producer");
  }
  auto consumer = dyn_cast<tensor::ExtractSliceOp>(*consumers.begin());
  if (!consumer) {
    return mlir::emitDefiniteFailure(state.getTopLevel(),
                                     "Non-extract_slice consumer");
  }

  FailureOr<scf::ForallOp> fusedForallOp =
      GPU::fuseExtractSliceIntoProducerForall(rewriter, producer, consumer);
  if (failed(fusedForallOp)) {
    return mlir::emitSilenceableFailure(*this,
                                        "failed to fuse extract_slice op");
  }

  results.set(getOperation()->getOpResult(0), {fusedForallOp.value()});
  return DiagnosedSilenceableFailure::success();
}

void transform_dialect::FuseExtractSliceIntoForallOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  transform::consumesHandle(getProducerMutable(), effects);
  transform::consumesHandle(getConsumerMutable(), effects);
  transform::producesHandle(getOperation()->getOpResults(), effects);
  transform::modifiesPayload(effects);
}

} // namespace mlir::iree_compiler::IREE

void mlir::iree_compiler::registerTransformDialectIREEGPUExtension(
@@ -262,4 +262,38 @@ def FuseCollapseShapeIntoForallOp : Op<Transform_Dialect, "iree.fuse_collapse_sh
  let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
}

def FuseExtractSliceIntoForallOp : Op<Transform_Dialect, "iree.fuse_extract_slice_into_forall",
    [FunctionalStyleTransformOpTrait,
     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
     DeclareOpInterfaceMethods<TransformOpInterface>,
     ReportTrackingListenerFailuresOpTrait]> {
  let description = [{
    Fuses a consumer tensor.extract_slice op into a producer scf.forall op.
    This transform is supported if the extract_slice op has all zero offsets,
    and if all the offsets, sizes, and strides dominate the scf.forall op.
    After the transformation, the forall loop output argument corresponding
    to the sliced result will be replaced with a slice of it with the same
    offsets, sizes, and strides as the original extract_slice. The source of
    the corresponding tensor.parallel_insert_slice of the scf.forall will also
    become a slice of the original parallel insert source, clamped to fit within
    the new sliced result tensor.

    #### Return modes
    Emits a definite failure if the producer is not an scf.forall op or the
    consumer is not a tensor.extract_slice op.
  }];

  let arguments = (
    ins TransformHandleTypeInterface:$producer,
        TransformHandleTypeInterface:$consumer
  );
  let results = (outs TransformHandleTypeInterface:$result);

  let assemblyFormat = [{
    $consumer `into` $producer attr-dict
    `:` functional-type(operands, results)
  }];
  let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
}

#endif // IREE_COMPILER_CODEGEN_DIALECT_GPU_TRANSFORMEXTENSIONS_IREEGPUEXTENSIONS
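
For reference, a minimal sketch of how the new transform op might be driven
from a transform script, modeled on the existing
fuse_collapse_shape_into_forall lit tests; the matcher ops and handle names
here are assumptions, not taken from this PR:

module attributes { transform.with_named_sequence } {
  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
    %loop = transform.structured.match ops{["scf.forall"]} in %root : (!transform.any_op) -> !transform.any_op
    %slice = transform.structured.match ops{["tensor.extract_slice"]} in %root : (!transform.any_op) -> !transform.any_op
    // Both handles are consumed; a handle to the fused scf.forall is returned.
    %fused = transform.iree.fuse_extract_slice_into_forall %slice into %loop : (!transform.any_op, !transform.any_op) -> !transform.any_op
    transform.yield
  }
}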
@@ -25,6 +25,7 @@ iree_lit_test_suite(
"lower_multi_mma.mlir",
"lower_vector_barrier.mlir",
"transform_fuse_collapse_shape_with_forall.mlir",
"transform_fuse_extract_slice_with_forall.mlir",
"transform_fuse_forall.mlir",
"transform_lower_barrier_region.mlir",
"vectorize_iree_gpu_ops.mlir",
@@ -21,6 +21,7 @@ iree_lit_test_suite(
"lower_multi_mma.mlir"
"lower_vector_barrier.mlir"
"transform_fuse_collapse_shape_with_forall.mlir"
"transform_fuse_extract_slice_with_forall.mlir"
"transform_fuse_forall.mlir"
"transform_lower_barrier_region.mlir"
"unroll_multi_mma.mlir"