Skip to content

Commit

Permalink
[FuseConsumerIntoLoop] Add support for fusing iteratively (#1003)
Browse files Browse the repository at this point in the history
This PR adds support for iterative consumer fusion until no fusion
opportunity can be found.
  • Loading branch information
jtuyls authored Jan 2, 2025
1 parent 155723c commit 39ae70d
Show file tree
Hide file tree
Showing 4 changed files with 260 additions and 111 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/IR/Iterators.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "iree-amdaie-fuse-consumer-into-loop"
namespace mlir::iree_compiler::AMDAIE {
Expand All @@ -21,6 +22,9 @@ class AMDAIEFuseConsumerIntoLoopPass
public:
AMDAIEFuseConsumerIntoLoopPass() = default;
AMDAIEFuseConsumerIntoLoopPass(const AMDAIEFuseConsumerIntoLoopPass &pass) {}
AMDAIEFuseConsumerIntoLoopPass(
const AMDAIEFuseConsumerIntoLoopOptions &options)
: AMDAIEFuseConsumerIntoLoopBase(options) {}

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<scf::SCFDialect>();
Expand All @@ -33,6 +37,11 @@ void AMDAIEFuseConsumerIntoLoopPass::runOnOperation() {
mlir::FunctionOpInterface funcOp = getOperation();
IRRewriter rewriter(context);

RewritePatternSet patterns(context);
tensor::InsertSliceOp::getCanonicalizationPatterns(patterns, context);
tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, context);
FrozenRewritePatternSet canonicalizationPatterns(std::move(patterns));

// Step 1. The depth until which we would keep fusing consumer chain.
// TODO(avarma): This should also be part of KernelDispatch logic.
unsigned fuseDepth = 1;
Expand Down Expand Up @@ -74,43 +83,67 @@ void AMDAIEFuseConsumerIntoLoopPass::runOnOperation() {
return;
}

// Step 4. Based on the `fuseDepth`, we would greedily fuse the consumer ops.
for (unsigned depth = 1; depth <= fuseDepth; depth++) {
do {
Value::user_range users = computeOp->getResult(0).getUsers();
if (!llvm::hasSingleElement(users)) {
LLVM_DEBUG(llvm::dbgs()
<< "Expected only one user of the compute op\n");
return signalPassFailure();
}
// Step 4. Greedily fuse the consumer ops for a specified fusion depth and
// while fusion keeps occurring or until the maximum number of iterations is
// exceeded.
bool changed{true};
int64_t iter{0};
while (changed && iter < maxIterations) {
changed = false;
// Canonicalize before every iteration to enable more back-to-back fusion
// opportunities.
(void)applyPatternsAndFoldGreedily(funcOp, canonicalizationPatterns);
Operation *producerOp = computeOp;
// TODO(jornt): Refactor fuseDepth to avoid hardcoding and fuse greedily
// with any depth instead.
for (unsigned depth = 1; depth <= fuseDepth; depth++) {
do {
Value::user_range users = producerOp->getResult(0).getUsers();
if (!llvm::hasSingleElement(users)) {
LLVM_DEBUG(llvm::dbgs()
<< "Expected only one user of the compute op\n");
break;
}

Operation *candidateSliceOp = *(users.begin());
if (!(isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
candidateSliceOp))) {
computeOp = computeOp->getParentOfType<LoopLikeOpInterface>();
LLVM_DEBUG(llvm::dbgs()
<< "Going to reattempt fusion because didn't find "
"tensor.insert_slice/tensor.parallel_insert_slice as the "
"user of the compute op\n");
continue;
}
std::optional<scf::SCFFuseConsumerOfSliceResult> fusedConsumer =
scf::tileAndFuseConsumerOfSlice(rewriter, candidateSliceOp);
if (!fusedConsumer) {
candidateSliceOp->emitOpError(
"Failed to fuse any consumer op into the producer");
return signalPassFailure();
}
fusedConsumer->origConsumerOperand->getOwner()->erase();
computeOp = fusedConsumer->tiledAndFusedConsumerOperand->getOwner();
break;
} while (computeOp && computeOp->getParentOp() != funcOp);
Operation *candidateSliceOp = *(users.begin());
if (!(isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
candidateSliceOp))) {
producerOp = producerOp->getParentOfType<LoopLikeOpInterface>();
LLVM_DEBUG(
llvm::dbgs()
<< "Going to reattempt fusion because didn't find "
"tensor.insert_slice/tensor.parallel_insert_slice as the "
"user of the compute op\n");
continue;
}
std::optional<scf::SCFFuseConsumerOfSliceResult> fusedConsumer =
scf::tileAndFuseConsumerOfSlice(rewriter, candidateSliceOp);
if (!fusedConsumer) {
producerOp = producerOp->getParentOfType<LoopLikeOpInterface>();
LLVM_DEBUG(llvm::dbgs()
<< "Failed to fuse any consumer op into the producer. "
"Reattempt with loop-like parent operation.\n");
continue;
}
changed = true;
fusedConsumer->origConsumerOperand->getOwner()->erase();
producerOp = fusedConsumer->tiledAndFusedConsumerOperand->getOwner();
break;
} while (producerOp && producerOp->getParentOp() != funcOp);
}
iter++;
}
if (iter >= maxIterations) {
funcOp.emitOpError() << "Maximum number of iterations reached, consumer "
"fusion is likely stuck in an infinite loop.";
return signalPassFailure();
}
}

} // namespace

std::unique_ptr<Pass> createAMDAIEFuseConsumerIntoLoopPass() {
return std::make_unique<AMDAIEFuseConsumerIntoLoopPass>();
std::unique_ptr<Pass> createAMDAIEFuseConsumerIntoLoopPass(
AMDAIEFuseConsumerIntoLoopOptions options) {
return std::make_unique<AMDAIEFuseConsumerIntoLoopPass>(options);
}
} // namespace mlir::iree_compiler::AMDAIE
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,8 @@ std::unique_ptr<Pass> createAMDAIEFlattenLogicalObjectFifoPass();
std::unique_ptr<Pass> createAMDAIELinalgFunctionOutliningPass();

/// Create a pass to fuse the consumer op into the innermost last scf loop.
std::unique_ptr<Pass> createAMDAIEFuseConsumerIntoLoopPass();
std::unique_ptr<Pass> createAMDAIEFuseConsumerIntoLoopPass(
AMDAIEFuseConsumerIntoLoopOptions options = {});

/// Create a pass to fuse the linalg.fill into the forall loops.
std::unique_ptr<Pass> createAMDAIEFuseFillIntoForallPass();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,10 @@ def AMDAIEFuseConsumerIntoLoop :
InterfacePass<"iree-amdaie-fuse-consumer-into-loop", "mlir::FunctionOpInterface"> {
let summary = "Fuse the consumer operation into the innermost last scf loop.";
let constructor = "mlir::iree_compiler::AMDAIE::createAMDAIEFuseConsumerIntoLoopPass()";
let options = [
Option<"maxIterations", "max-iterations", "int64_t", /*default=*/"100",
"The maximum number of iterations the consumer fusion should be applied.">,
];
}

def AMDAIEFuseFillIntoForall :
Expand Down
Loading

0 comments on commit 39ae70d

Please sign in to comment.