From 81dd4e629539facd3d57723c455d7922b427c000 Mon Sep 17 00:00:00 2001 From: Ian Wood Date: Thu, 14 Nov 2024 14:57:28 -0800 Subject: [PATCH] [Util][NFC] OptimizeIntArithmetic: reduce calls to `eraseState` (#19130) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This pass is causing long compilation times for llama3 405b (even when cherry-picking https://github.com/llvm/llvm-project/pull/115399). The majority of the time is spent in this one pass. The compilation times improve when calling `eraseState` only when ops are deleted. This is similar to the upstream listeners in `UnsignedWhenEquivalent.cpp` and `IntRangeOptimizations.cpp`. It appears that `eraseState` loops over all `LatticeAnchors` on each invocation to find the one to delete, which makes it slow. My (non-rigorous) experiment showed a decrease from 18 min to 3 min compile time. My main concern is whether this affects correctness, as I don't know if it has unaccounted-for side effects. 
Signed-off-by: Ian Wood --- .../Util/Transforms/OptimizeIntArithmetic.cpp | 43 +++---------------- 1 file changed, 5 insertions(+), 38 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/OptimizeIntArithmetic.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/OptimizeIntArithmetic.cpp index 1049f3950bc3..d4b3a14b43f5 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/OptimizeIntArithmetic.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/OptimizeIntArithmetic.cpp @@ -23,7 +23,7 @@ #include "mlir/Pass/PassRegistry.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#define DEBUG_TYPE "iree-util-optimize-arithmetic" +#define DEBUG_TYPE "iree-util-optimize-int-arithmetic" using llvm::dbgs; using namespace mlir::dataflow; @@ -289,43 +289,7 @@ class DataFlowListener : public RewriterBase::Listener { void notifyOperationErased(Operation *op) override { s.eraseState(s.getProgramPointAfter(op)); for (Value res : op->getResults()) - flushValue(res); - } - void notifyOperationModified(Operation *op) override { - for (Value res : op->getResults()) - flushValue(res); - } - void notifyOperationReplaced(Operation *op, Operation *replacement) override { - for (Value res : op->getResults()) - flushValue(res); - } - - void notifyOperationReplaced(Operation *op, ValueRange replacement) override { - for (Value res : op->getResults()) - flushValue(res); - } - - void flushValue(Value value) { - SmallVector worklist; - SmallVector process; - worklist.push_back(value); - - while (!worklist.empty()) { - process.clear(); - process.swap(worklist); - for (Value childValue : process) { - auto *state = s.lookupState(childValue); - if (!state) { - continue; - } - s.eraseState(childValue); - for (auto user : childValue.getUsers()) { - for (Value result : user->getResults()) { - worklist.push_back(result); - } - } - } - } + s.eraseState(res); } DataFlowSolver &s; @@ -386,11 +350,14 @@ class OptimizeIntArithmeticPass 
FrozenRewritePatternSet frozenPatterns(std::move(patterns)); for (int i = 0;; ++i) { + LLVM_DEBUG(dbgs() << " * Starting iteration: " << i << "\n"); if (failed(solver.initializeAndRun(op))) { emitError(op->getLoc()) << "failed to perform int range analysis"; return signalPassFailure(); } + LLVM_DEBUG( + dbgs() << " * Finished Running Solver -- Applying Patterns\n"); bool changed = false; if (failed(applyPatternsAndFoldGreedily(op, frozenPatterns, config, &changed))) {