Skip to content

Commit

Permalink
Revert "[Util][NFC] OptimizeIntArithmetic: reduce calls to `eraseStat…
Browse files Browse the repository at this point in the history
…e` (iree-org#19130)"

This reverts commit 81dd4e6.

Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
  • Loading branch information
IanWood1 committed Nov 15, 2024
1 parent 60cf4ab commit b2f3398
Showing 1 changed file with 38 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "iree-util-optimize-int-arithmetic"
#define DEBUG_TYPE "iree-util-optimize-arithmetic"
using llvm::dbgs;

using namespace mlir::dataflow;
Expand Down Expand Up @@ -289,7 +289,43 @@ class DataFlowListener : public RewriterBase::Listener {
void notifyOperationErased(Operation *op) override {
s.eraseState(s.getProgramPointAfter(op));
for (Value res : op->getResults())
s.eraseState(res);
flushValue(res);
}
void notifyOperationModified(Operation *op) override {
for (Value res : op->getResults())
flushValue(res);
}
void notifyOperationReplaced(Operation *op, Operation *replacement) override {
for (Value res : op->getResults())
flushValue(res);
}

void notifyOperationReplaced(Operation *op, ValueRange replacement) override {
for (Value res : op->getResults())
flushValue(res);
}

void flushValue(Value value) {
SmallVector<Value> worklist;
SmallVector<Value> process;
worklist.push_back(value);

while (!worklist.empty()) {
process.clear();
process.swap(worklist);
for (Value childValue : process) {
auto *state = s.lookupState<IntegerValueRangeLattice>(childValue);
if (!state) {
continue;
}
s.eraseState(childValue);
for (auto user : childValue.getUsers()) {
for (Value result : user->getResults()) {
worklist.push_back(result);
}
}
}
}
}

DataFlowSolver &s;
Expand Down Expand Up @@ -350,14 +386,11 @@ class OptimizeIntArithmeticPass

FrozenRewritePatternSet frozenPatterns(std::move(patterns));
for (int i = 0;; ++i) {
LLVM_DEBUG(dbgs() << " * Starting iteration: " << i << "\n");
if (failed(solver.initializeAndRun(op))) {
emitError(op->getLoc()) << "failed to perform int range analysis";
return signalPassFailure();
}

LLVM_DEBUG(
dbgs() << " * Finished Running Solver -- Applying Patterns\n");
bool changed = false;
if (failed(applyPatternsAndFoldGreedily(op, frozenPatterns, config,
&changed))) {
Expand Down

0 comments on commit b2f3398

Please sign in to comment.