From 39c56de27df9802479d64d6a08e4d12d2662e37c Mon Sep 17 00:00:00 2001
From: Han-Chung Wang
Date: Sat, 7 Dec 2024 19:47:31 -0800
Subject: [PATCH] [Dispatch] Disable UnpackLikeOp+ExtractSlice fusion. (#19408)

The fusion is no longer needed because unset_encoding ops carry the
slicing semantics. Instead of adding complexity to the checks (i.e.,
whether the consumer is a rank-reducing slice or not), we can disable
the fusion altogether. The revision updates the test cases that were
created before the unset_encoding evolution and adds a negative test
for the issue.

Fixes https://github.com/iree-org/iree/issues/19386

Signed-off-by: hanhanW
---
 .../DispatchCreation/FormDispatchRegions.cpp  | 38 ++----------------
 .../test/dispatch_linalg_on_tensors.mlir      | 11 ++---
 .../test/form_dispatch_regions.mlir           | 40 ++++++++++++-------
 3 files changed, 34 insertions(+), 55 deletions(-)

diff --git a/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp b/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp
index e866022eb9a9..73d306bd7f6e 100644
--- a/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp
@@ -182,28 +182,9 @@ static bool isPackLikeOp(Operation *op) {
   return isa(op);
 }
 
-/// Returns true if the operation is an `unpack` op or an `unset_encoding` op,
-/// or an `extract_slice` op whose source operand matches those criteria,
-/// recursively.
-/// The idea is that we want to ensure that `extract_slice` ops can't prevent
-/// fusion between a `unset_encoding` producer and some linalg consumer. In
-///   %0 = unset_encoding ...
-///   %1 = extract_slice %0 ...
-///   %2 = linalg.generic ins(%1) ...
-/// we are not content to be fusing %1 into %0, we also want to be fusing %2,
-/// so we want to prevent %1 from acting as a consumer fusion barrier.
-static bool isUnpackLikeOpViaExtractSliceOps(Operation *op) {
-  if (isa(op)) {
-    return true;
-  }
-  if (isa(op)) {
-    Value source = op->getOperand(0);
-    Operation *producer = source.getDefiningOp();
-    if (isUnpackLikeOpViaExtractSliceOps(producer)) {
-      return true;
-    }
-  }
-  return false;
+/// Returns true if the operation is an `unpack` op or an `unset_encoding` op.
+static bool isUnpackLikeOp(Operation *op) {
+  return isa(op);
 }
 
 /// Since `iree_encoding.set_encoding` doesnt have padding semantics a
@@ -476,18 +457,7 @@ isFusableWithConsumer(OpOperand &fusedOperand,
 
   // Fuse unset_encoding operations with `tensor.extract_slice` and elementwise
   // generic ops.
-  if (isUnpackLikeOpViaExtractSliceOps(producer)) {
-    // Fuse `unset_encoding` -> `extract_slice` op since they get folded into
-    // `unpack` on materialization.
-    if (isa(consumer)) {
-      auto sliceOp = cast(consumer);
-      return llvm::all_of(
-                 sliceOp.getMixedOffsets(),
-                 [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }) &&
-             llvm::all_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
-               return isConstantIntValue(ofr, 1);
-             });
-    }
+  if (isUnpackLikeOp(producer)) {
     // Fuse `unset_encoding/unpack` -> elementwise operations. Fuse unpack with
     // non-overlapping reductions (i.e., the reduction dimension is not packed).
     if (auto consumerLinalgOp = dyn_cast(consumer)) {
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_on_tensors.mlir b/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_on_tensors.mlir
index 2b20e30f843a..acbeed9bff35 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_on_tensors.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_on_tensors.mlir
@@ -1958,17 +1958,15 @@ util.func public @pad_and_set_encoding_op(%arg0 : tensor)
 
 // -----
 
 #encoding = #iree_encoding.encoding
-util.func public @unset_encoding_and_slice(
+util.func public @unset_encoding_with_encoded_slice(
     %arg0: tensor, %arg1 : index, %arg2 : index) -> tensor {
   %0 = iree_encoding.unset_encoding %arg0
       : tensor -> tensor{%arg1, %arg2}
-  %1 = tensor.extract_slice %0[0, 0] [%arg1, %arg2] [1, 1]
-      : tensor to tensor
-  util.return %1 : tensor
+  util.return %0 : tensor
 }
 // CHECK: #[[ENCODING:.+]] = #iree_encoding.encoding
-// CHECK: util.func public @unset_encoding_and_slice
+// CHECK: util.func public @unset_encoding_with_encoded_slice
 // CHECK-SAME: %[[ARG0:.+]]: tensor
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
@@ -1991,8 +1989,7 @@ util.func public @unset_encoding_and_slice(
 // CHECK-SAME: sizes = [%[[D0_W]], %[[D1_W]]]
 // CHECK-SAME: !flow.dispatch.tensor>{%[[D0_W]], %[[D1_W]]}
 // CHECK: %[[UNSET_ENCODING:.+]] = iree_encoding.unset_encoding %[[LOAD]]
-// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[UNSET_ENCODING]][0, 0] [%[[ARG0_W]], %[[ARG1_W]]]
-// CHECK: flow.dispatch.tensor.store %[[SLICE]], %[[OUTARG]]
+// CHECK: flow.dispatch.tensor.store %[[UNSET_ENCODING]], %[[OUTARG]]
 // CHECK-SAME: sizes = [%[[ARG0_W]], %[[ARG1_W]]]
 // CHECK-SAME: !flow.dispatch.tensor>{%[[ARG0_W]], %[[ARG1_W]]}
 // CHECK: flow.return
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir b/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir
index b29f43ed47fa..196fc8795718 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir
@@ -304,38 +304,34 @@ util.func public @unset_encoding_elementwise_fusion(
 
 // -----
 
 #encoding = #iree_encoding.encoding
-util.func public @unset_encoding_slice_elementwise_fusion(
+util.func public @unset_encoding_elementwise_fusion(
     %arg0: tensor, %arg1: tensor,
     %arg2 : index, %arg3 : index) -> tensor {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %0 = iree_encoding.unset_encoding %arg0
       : tensor> -> tensor{%arg2, %arg3}
-  %1 = tensor.extract_slice %0[0, 0] [%arg2, %arg3] [1, 1] : tensor to tensor
-  %2 = tensor.dim %1, %c0 : tensor
-  %3 = tensor.dim %1, %c1 : tensor
-  %4 = tensor.empty(%2, %3) : tensor
-  %5 = linalg.generic {
+  %1 = tensor.empty(%arg2, %arg3) : tensor
+  %2 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d0)>,
                       affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
-     ins(%1, %arg1 : tensor, tensor)
-     outs(%4 : tensor) {
+     ins(%0, %arg1 : tensor, tensor)
+     outs(%1 : tensor) {
    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
-     %6 = arith.addf %b0, %b1 : f32
-     linalg.yield %6 : f32
+     %3 = arith.addf %b0, %b1 : f32
+     linalg.yield %3 : f32
    } -> tensor
-  util.return %5 : tensor
+  util.return %2 : tensor
 }
 // CHECK: #[[$ENCODING:.+]] = #iree_encoding.encoding
-// CHECK-LABEL: util.func public @unset_encoding_slice_elementwise_fusion(
+// CHECK-LABEL: util.func public @unset_encoding_elementwise_fusion(
 // CHECK-SAME: %[[ARG0:.+]]: tensor
 // CHECK-SAME: %[[ARG1:.+]]: tensor
 // CHECK: %[[RESULT0:.+]] = flow.dispatch.region
 // CHECK: %[[UNSET_ENCODING:.+]] = iree_encoding.unset_encoding %[[ARG0]]
-// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[UNSET_ENCODING]]
-// CHECK: %[[GENERIC:.+]] = linalg.generic {{.*}} ins(%[[SLICE]]
+// CHECK: %[[GENERIC:.+]] = linalg.generic {{.*}} ins(%[[UNSET_ENCODING]]
 // CHECK: flow.return %[[GENERIC]]
 // CHECK: util.return %[[RESULT0]]
@@ -382,6 +378,22 @@ util.func public @unpack_elementwise_fusion(
 
 // -----
 
+#encoding = #iree_encoding.encoding
+util.func public @unset_encoding_slice(%arg0: tensor<1x50x384xf32, #encoding>) -> tensor<384xf32> {
+  %0 = iree_encoding.unset_encoding %arg0 : tensor<1x50x384xf32, #encoding> -> tensor<1x50x384xf32>
+  %extracted_slice = tensor.extract_slice %0[0, 0, 0] [1, 1, 384] [1, 1, 1] : tensor<1x50x384xf32> to tensor<384xf32>
+  util.return %extracted_slice : tensor<384xf32>
+}
+// CHECK-LABEL: util.func public @unset_encoding_slice
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[RESULT:.+]] = flow.dispatch.region
+// CHECK: %[[UNSET_ENCODING:.+]] = iree_encoding.unset_encoding
+// CHECK: flow.return %[[UNSET_ENCODING]]
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[RESULT]]
+// CHECK: util.return %[[SLICE]]
+
+// -----
+
 util.func public @unpack_non_intersecting_reduction(
     %arg0: tensor, %arg1: tensor) -> tensor {