Skip to content

Commit

Permalink
transforms: Allow to pass a pattern rewriter in CSE (#3539)
Browse files Browse the repository at this point in the history
Stacked PRs:
 * #3540
 * __->__#3539
 * #3538
 * #3537


--- --- ---

### transforms: Allow to pass a pattern rewriter in CSE


Without passing the pattern rewriter, CSE couldn't be called inside
a pattern rewriter walker, as it would not notify the operations that
were deleted or replaced.
  • Loading branch information
math-fehr authored Dec 18, 2024
1 parent b8611e0 commit b394263
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 12 deletions.
2 changes: 1 addition & 1 deletion xdsl/transforms/canonicalization_patterns/stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def match_and_rewrite(self, op: stencil.ApplyOp, rewriter: PatternRewriter) -> N
continue
a.replace_by(bbargs[rbargs[i]])

cse(op.region.block)
cse(op.region.block, rewriter)


class ApplyUnusedOperands(RewritePattern):
Expand Down
21 changes: 12 additions & 9 deletions xdsl/transforms/common_subexpression_elimination.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from xdsl.dialects.builtin import ModuleOp, UnregisteredOp
from xdsl.ir import Block, Operation, Region, Use
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import PatternRewriter
from xdsl.rewriter import Rewriter
from xdsl.traits import (
IsolatedFromAbove,
Expand Down Expand Up @@ -115,19 +116,15 @@ def has_other_side_effecting_op_in_between(
return False


@dataclass
class CSEDriver:
"""
Boilerplate class to handle and carry the state for CSE.
"""

_rewriter: Rewriter
_rewriter: Rewriter | PatternRewriter = field(default_factory=Rewriter)
_to_erase: set[Operation] = field(default_factory=set)
_known_ops: KnownOps = KnownOps()

def __init__(self):
self._rewriter = Rewriter()
self._to_erase = set()
self._known_ops = KnownOps()
_known_ops: KnownOps = field(default_factory=KnownOps)

def _mark_erasure(self, op: Operation):
self._to_erase.add(op)
Expand Down Expand Up @@ -250,8 +247,14 @@ def simplify(self, thing: Operation | Block | Region):
self._commit_erasures()


def cse(thing: Operation | Block | Region):
CSEDriver().simplify(thing)
def cse(
thing: Operation | Block | Region,
rewriter: Rewriter | PatternRewriter | None = None,
):
if rewriter is not None:
CSEDriver(_rewriter=rewriter).simplify(thing)
else:
CSEDriver().simplify(thing)


class CommonSubexpressionElimination(ModulePass):
Expand Down
4 changes: 2 additions & 2 deletions xdsl/transforms/control_flow_hoist.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def match_and_rewrite(self, op: affine.IfOp, rewriter: PatternRewriter):
return
block = op.parent
if block:
cse(block)
cse(block, rewriter)


class SCFIfHoistPattern(RewritePattern):
Expand All @@ -84,7 +84,7 @@ def match_and_rewrite(self, op: scf.IfOp, rewriter: PatternRewriter):
block = op.parent
if block:
# If we hoisted some ops, run CSE on that block to not keep pushing duplicates upward.
cse(block)
cse(block, rewriter)


class ControlFlowHoistPass(ModulePass):
Expand Down

0 comments on commit b394263

Please sign in to comment.