From b2f3398664de2232796a21556b7e45ac471caf13 Mon Sep 17 00:00:00 2001 From: Ian Wood Date: Fri, 15 Nov 2024 12:40:24 -0800 Subject: [PATCH] Revert "[Util][NFC] OptimizeIntArithmetic: reduce calls to `eraseState` (#19130)" This reverts commit 81dd4e629539facd3d57723c455d7922b427c000. Signed-off-by: Ian Wood --- .../Util/Transforms/OptimizeIntArithmetic.cpp | 43 ++++++++++++++++--- 1 file changed, 38 insertions(+), 5 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/OptimizeIntArithmetic.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/OptimizeIntArithmetic.cpp index d4b3a14b43f5..1049f3950bc3 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/OptimizeIntArithmetic.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/OptimizeIntArithmetic.cpp @@ -23,7 +23,7 @@ #include "mlir/Pass/PassRegistry.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#define DEBUG_TYPE "iree-util-optimize-int-arithmetic" +#define DEBUG_TYPE "iree-util-optimize-arithmetic" using llvm::dbgs; using namespace mlir::dataflow; @@ -289,7 +289,43 @@ class DataFlowListener : public RewriterBase::Listener { void notifyOperationErased(Operation *op) override { s.eraseState(s.getProgramPointAfter(op)); for (Value res : op->getResults()) - s.eraseState(res); + flushValue(res); + } + void notifyOperationModified(Operation *op) override { + for (Value res : op->getResults()) + flushValue(res); + } + void notifyOperationReplaced(Operation *op, Operation *replacement) override { + for (Value res : op->getResults()) + flushValue(res); + } + + void notifyOperationReplaced(Operation *op, ValueRange replacement) override { + for (Value res : op->getResults()) + flushValue(res); + } + + void flushValue(Value value) { + SmallVector worklist; + SmallVector process; + worklist.push_back(value); + + while (!worklist.empty()) { + process.clear(); + process.swap(worklist); + for (Value childValue : process) { + auto *state = s.lookupState(childValue); + if (!state) { + continue; + } + s.eraseState(childValue); + for (auto user : childValue.getUsers()) { + for (Value result : user->getResults()) { + worklist.push_back(result); + } + } + } + } } DataFlowSolver &s; @@ -350,14 +386,11 @@ class OptimizeIntArithmeticPass FrozenRewritePatternSet frozenPatterns(std::move(patterns)); for (int i = 0;; ++i) { - LLVM_DEBUG(dbgs() << " * Starting iteration: " << i << "\n"); if (failed(solver.initializeAndRun(op))) { emitError(op->getLoc()) << "failed to perform int range analysis"; return signalPassFailure(); } - LLVM_DEBUG( - dbgs() << " * Finished Running Solver -- Applying Patterns\n"); bool changed = false; if (failed(applyPatternsAndFoldGreedily(op, frozenPatterns, config, &changed))) {