Skip to content

Commit

Permalink
Change rewriter to erase state less often
Browse files Browse the repository at this point in the history
Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
  • Loading branch information
IanWood1 committed Nov 13, 2024
1 parent 4aa08f2 commit 7d19205
Showing 1 changed file with 5 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "iree-util-optimize-arithmetic"
#define DEBUG_TYPE "iree-util-optimize-int-arithmetic"
using llvm::dbgs;

using namespace mlir::dataflow;
Expand Down Expand Up @@ -289,43 +289,7 @@ class DataFlowListener : public RewriterBase::Listener {
void notifyOperationErased(Operation *op) override {
s.eraseState(s.getProgramPointAfter(op));
for (Value res : op->getResults())
flushValue(res);
}
void notifyOperationModified(Operation *op) override {
for (Value res : op->getResults())
flushValue(res);
}
void notifyOperationReplaced(Operation *op, Operation *replacement) override {
for (Value res : op->getResults())
flushValue(res);
}

void notifyOperationReplaced(Operation *op, ValueRange replacement) override {
for (Value res : op->getResults())
flushValue(res);
}

void flushValue(Value value) {
SmallVector<Value> worklist;
SmallVector<Value> process;
worklist.push_back(value);

while (!worklist.empty()) {
process.clear();
process.swap(worklist);
for (Value childValue : process) {
auto *state = s.lookupState<IntegerValueRangeLattice>(childValue);
if (!state) {
continue;
}
s.eraseState(childValue);
for (auto user : childValue.getUsers()) {
for (Value result : user->getResults()) {
worklist.push_back(result);
}
}
}
}
s.eraseState(res);
}

DataFlowSolver &s;
Expand Down Expand Up @@ -386,11 +350,14 @@ class OptimizeIntArithmeticPass

FrozenRewritePatternSet frozenPatterns(std::move(patterns));
for (int i = 0;; ++i) {
LLVM_DEBUG(dbgs() << " * Starting iteration: " << i << "\n");
if (failed(solver.initializeAndRun(op))) {
emitError(op->getLoc()) << "failed to perform int range analysis";
return signalPassFailure();
}

LLVM_DEBUG(
dbgs() << " * Finished Running Solver -- Applying Patterns\n");
bool changed = false;
if (failed(applyPatternsAndFoldGreedily(op, frozenPatterns, config,
&changed))) {
Expand Down

0 comments on commit 7d19205

Please sign in to comment.