diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionPreprocessing.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionPreprocessing.cpp index 25f1685035c6..708b6a3e4075 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionPreprocessing.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionPreprocessing.cpp @@ -13,6 +13,7 @@ #include "iree/compiler/Dialect/Flow/Transforms/Passes.h" #include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" @@ -149,58 +150,58 @@ struct FoldSuccessiveTensorInsertSliceOps // cannot be fused because it there is no producer-consumer // relationship between the two generics. This is because the indexing // is not affine (index values come from a tensor). -struct GatherFusionPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(linalg::GenericOp consumerOp, +struct GatherFusionPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(tensor::ExtractOp extractOp, PatternRewriter &rewriter) const override { - auto extractOps = consumerOp.getOps(); - if (extractOps.empty()) { - return failure(); + // Check if extractOp is inside a generic op + auto consumerOp = + dyn_cast_or_null(extractOp->getParentOp()); + if (!consumerOp) { + return rewriter.notifyMatchFailure( + extractOp, "expected extract op to be inside a generic op"); } - for (tensor::ExtractOp extractOp : extractOps) { - auto producerOp = dyn_cast_or_null( - extractOp.getOperand(0).getDefiningOp()); - if (!producerOp) { - return rewriter.notifyMatchFailure( - consumerOp, "expected extract operand to be a generic op"); - } - - // Check if the producerOp is fusible - if (producerOp.getNumDpsInputs() != 1 || !isElementwise(producerOp) || - !isDequantizationLikeOp(producerOp)) { - return rewriter.notifyMatchFailure(producerOp, - "producer op is not fusible"); - } + auto producerOp = dyn_cast_or_null( + extractOp.getOperand(0).getDefiningOp()); + if (!producerOp) { + return rewriter.notifyMatchFailure( + consumerOp, "expected extract operand to be a generic op"); + } - // fuse by performing the dequantization-like operation after - // tensor.extract - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(extractOp); - - // Create a new extract op that extracts from the original tensor - // (after the original extract). Clone the producerOp's body into the - // consumerOp, inline the cloned block (erases the block) after the new - // extract, and clean up. - auto newExtractOp = rewriter.create( - extractOp->getLoc(), producerOp->getOperand(0), - extractOp.getIndices()); - rewriter.cloneRegionBefore(producerOp.getRegion(), consumerOp.getRegion(), - consumerOp.getRegion().begin()); - Block &clonedBlock = consumerOp.getRegion().front(); - auto producerTermOp = clonedBlock.getTerminator(); - - rewriter.inlineBlockBefore( - &clonedBlock, extractOp->getNextNode(), - {newExtractOp.getResult(), newExtractOp->getResult(0)}); - - // Replace the the all references to the original extract result with the - // result from the inlined producerOp. - extractOp->getResult(0).replaceAllUsesWith(producerTermOp->getOperand(0)); - rewriter.eraseOp(producerTermOp); - rewriter.eraseOp(extractOp); + // Check if the producerOp is fusible + if (producerOp.getNumDpsInputs() != 1 || !isElementwise(producerOp) || + !isDequantizationLikeOp(producerOp)) { + return rewriter.notifyMatchFailure(producerOp, + "producer op is not fusible"); } + // fuse by performing the dequantization-like operation after + // tensor.extract + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(extractOp); + + // Create a new extract op that extracts from the original tensor + // (after the original extract). Clone the producerOp's body into the + // consumerOp, inline the cloned block (erases the block) after the new + // extract, and clean up. + auto newExtractOp = rewriter.create( + extractOp->getLoc(), producerOp->getOperand(0), extractOp.getIndices()); + rewriter.cloneRegionBefore(producerOp.getRegion(), consumerOp.getRegion(), + consumerOp.getRegion().begin()); + Block &clonedBlock = consumerOp.getRegion().front(); + auto producerTermOp = clonedBlock.getTerminator(); + + rewriter.inlineBlockBefore( + &clonedBlock, extractOp->getNextNode(), + {newExtractOp.getResult(), newExtractOp->getResult(0)}); + + // Replace the the all references to the original extract result with the + // result from the inlined producerOp. + extractOp->getResult(0).replaceAllUsesWith(producerTermOp->getOperand(0)); + rewriter.eraseOp(producerTermOp); + rewriter.eraseOp(extractOp); + return success(); } };