diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionPreprocessing.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionPreprocessing.cpp index 38169f8191f9..bd5bf561e635 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionPreprocessing.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionPreprocessing.cpp @@ -11,11 +11,22 @@ //===----------------------------------------------------------------------===// #include "iree/compiler/Dialect/Flow/Transforms/Passes.h" +#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/MemRef/Transforms/Transforms.h" -#include "mlir/Dialect/Tensor/Transforms/Transforms.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir::iree_compiler::IREE::Flow { @@ -131,14 +142,76 @@ struct FoldSuccessiveTensorInsertSliceOps } }; +//===----------------------------------------------------------------------===// +// GatherFusionPattern +//===----------------------------------------------------------------------===// + +// Specific case. The linalg generic implementation of "gather" +// cannot be fused because there is no producer-consumer +// relationship between the two generics. This is because the indexing +// is not affine (index values come from a tensor). 
+struct GatherFusionPattern : public OpRewritePattern<tensor::ExtractOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    // Check if extractOp is inside a generic op.
+    auto consumerOp =
+        dyn_cast_or_null<linalg::GenericOp>(extractOp->getParentOp());
+    if (!consumerOp) {
+      return rewriter.notifyMatchFailure(
+          extractOp, "expected extract op to be inside a generic op");
+    }
+
+    auto producerOp = extractOp.getTensor().getDefiningOp<linalg::GenericOp>();
+    if (!producerOp) {
+      return rewriter.notifyMatchFailure(
+          consumerOp, "expected extract operand to be a generic op");
+    }
+
+    // Check if the producerOp is fusible: single-input, single-result,
+    // elementwise, dequantization-like body.
+    if (producerOp.getNumDpsInputs() != 1 || producerOp.getNumResults() != 1 ||
+        !isElementwise(producerOp) || !isDequantizationLikeOp(producerOp)) {
+      return rewriter.notifyMatchFailure(producerOp,
+                                         "producer op is not fusible");
+    }
+
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(extractOp);
+
+    // Create a new extract op that extracts from the original tensor
+    // (after the original extract). Clone the producerOp's body into the
+    // consumerOp, inline the cloned block (erases the block) after the new
+    // extract, and clean up.
+    auto newExtractOp = rewriter.create<tensor::ExtractOp>(
+        extractOp.getLoc(), producerOp.getDpsInputOperand(0)->get(),
+        extractOp.getIndices());
+    rewriter.cloneRegionBefore(producerOp.getRegion(), consumerOp.getRegion(),
+                               consumerOp.getRegion().begin());
+    Block &clonedBlock = consumerOp.getRegion().front();
+    auto producerTermOp = clonedBlock.getTerminator();
+
+    rewriter.inlineBlockBefore(
+        &clonedBlock, extractOp->getNextNode(),
+        {newExtractOp.getResult(), newExtractOp.getResult()});
+
+    // Replace all references to the original extract result with the
+    // result from the inlined producerOp.
+    extractOp.getResult().replaceAllUsesWith(producerTermOp->getOperand(0));
+    rewriter.eraseOp(producerTermOp);
+    rewriter.eraseOp(extractOp);
+
+    return success();
+  }
+};
+
 struct FusionPreprocessingPass
     : public IREE::Flow::impl::FusionPreprocessingPassBase<
           FusionPreprocessingPass> {
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
-    patterns
-        .add<ElementwiseOpInterchangePattern, FoldSuccessiveTensorInsertSliceOps>(
-            &getContext());
+    patterns.add<ElementwiseOpInterchangePattern, FoldSuccessiveTensorInsertSliceOps, GatherFusionPattern>(
+        &getContext());

     // Fold away `tensor.dim` operations that can be resolved in terms of its
     // operand shapes.
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fusion_preprocessing.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fusion_preprocessing.mlir
index 82f9fb9c0bfd..ba200f46a277 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fusion_preprocessing.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fusion_preprocessing.mlir
@@ -54,3 +54,87 @@ util.func public @fold_insert_slices(%source : tensor<?x?xf32>,
 // CHECK: %[[RETURN:.+]] = tensor.insert_slice %[[SOURCE]] into %[[FILL]]
 // CHECK-SAME: [%[[NEW_OFFSET0]], %[[NEW_OFFSET1]]] [%[[SIZE0]], %[[SIZE1]]]
 // CHECK: util.return %[[RETURN]]
+
+
+// -----
+
+util.func public @fuse_generic_gather(
+    %11 :tensor<128256x4096xf16>, %12 : tensor<4x?xi64>,
+    %13 : tensor<4x?x4096xf32>, %14 : tensor<128256x4096xf32>)
+    -> tensor<4x?x4096xf32>{
+
+  %15 = linalg.generic {
+      indexing_maps = [ affine_map<(d0, d1) -> (d0, d1)>,
+                        affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%11 : tensor<128256x4096xf16>)
+      outs(%14 : tensor<128256x4096xf32>) {
+  ^bb0(%in: f16, %out: f32):
+    %17 = arith.extf %in : f16 to f32
+    linalg.yield %17 : f32
+  } -> tensor<128256x4096xf32>
+  %16 = linalg.generic {
+      indexing_maps = [ affine_map<(d0, d1, d2) -> (d0, d1)>,
+                        affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+      iterator_types = ["parallel", "parallel", "parallel"]}
+      ins(%12 : tensor<4x?xi64>)
+      outs(%13 : tensor<4x?x4096xf32>) { 
^bb0(%in: i64, %out: f32): + %17 = arith.index_cast %in : i64 to index + %18 = linalg.index 2 : index + %extracted = tensor.extract %15[%17, %18] : tensor<128256x4096xf32> + linalg.yield %extracted : f32 + } -> tensor<4x?x4096xf32> + util.return %16 : tensor<4x?x4096xf32> +} + +// CHECK: %[[INDEX0:[a-zA-Z0-9]+]] = arith.index_cast %in : i64 to index +// CHECK: %[[INDEX1:[a-zA-Z0-9]+]] = linalg.index 2 : index +// CHECK-NEXT: %[[EXTRACTED:.*]] = tensor.extract %[[TENSOR0:.+]][%[[INDEX0]], %[[INDEX1]]] : tensor<128256x4096xf16> +// CHECK-NEXT: %[[RES:[a-zA-Z0-9]+]] = arith.extf %[[EXTRACTED]] : f16 to f32 +// CHECK-NEXT: linalg.yield %[[RES]] : f32 + + +// ----- + +util.func public @fuse_generic_gather2( + %11 :tensor<128256x4096xf16>, %12 : tensor<4x?xi64>, + %13 : tensor<4x?x4096xf32>, %14 : tensor<128256x4096xf32>) + -> tensor<4x?x4096xf32>{ + + %15 = linalg.generic { + indexing_maps = [ affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%11 : tensor<128256x4096xf16>) + outs(%14 : tensor<128256x4096xf32>) { + ^bb0(%in: f16, %out: f32): + %17 = arith.extf %in : f16 to f32 + linalg.yield %17 : f32 + } -> tensor<128256x4096xf32> + %16 = linalg.generic { + indexing_maps = [ affine_map<(d0, d1, d2) -> (d0, d1)>, + affine_map<(d0, d1, d2) -> (d0, d1, d2)>], + iterator_types = ["parallel", "parallel", "parallel"]} + ins(%12 : tensor<4x?xi64>) + outs(%13 : tensor<4x?x4096xf32>) { + ^bb0(%in: i64, %out: f32): + %17 = arith.index_cast %in : i64 to index + %18 = linalg.index 2 : index + %extracted = tensor.extract %15[%17, %18] : tensor<128256x4096xf32> + %result = arith.addf %extracted, %extracted : f32 + %result2 = arith.mulf %extracted, %extracted : f32 + %final = arith.addf %result, %result2 : f32 + linalg.yield %final: f32 + } -> tensor<4x?x4096xf32> + util.return %16 : tensor<4x?x4096xf32> +} + +// CHECK: %[[INDEX0:[a-zA-Z0-9]+]] = arith.index_cast %in : i64 to index +// CHECK: 
%[[INDEX1:[a-zA-Z0-9]+]] = linalg.index 2 : index +// CHECK-NEXT: %[[EXTRACTED:.*]] = tensor.extract %[[TENSOR0:.+]][%[[INDEX0]], %[[INDEX1]]] : tensor<128256x4096xf16> +// CHECK-NEXT: %[[RES:[a-zA-Z0-9]+]] = arith.extf %[[EXTRACTED]] : f16 to f32 +// CHECK-NEXT: %[[RES2:[a-zA-Z0-9]+]] = arith.addf %[[RES]], %[[RES]] : f32 +// CHECK-NEXT: %[[RES3:[a-zA-Z0-9]+]] = arith.mulf %[[RES]], %[[RES]] : f32 +// CHECK-NEXT: %[[RES4:[a-zA-Z0-9]+]] = arith.addf %[[RES2]], %[[RES3]] : f32 +// CHECK-NEXT: linalg.yield %[[RES4]] : f32