Skip to content

Commit

Permalink
core: Simplify the pattern rewriter (#1910)
Browse files Browse the repository at this point in the history
Now that the pattern rewriter use the worklist, it is not necessary to
keep track of the inserted operations before/after
the matched operation.
We can thus remove all these checks.
  • Loading branch information
math-fehr authored Jan 22, 2024
1 parent e6ca26f commit 5248ab0
Showing 1 changed file with 20 additions and 179 deletions.
199 changes: 20 additions & 179 deletions xdsl/pattern_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,97 +89,19 @@ class PatternRewriter(PatternRewriterListener):
current_operation: Operation
"""The matched operation."""

has_erased_matched_operation: bool = field(default=False, init=False)
"""Was the matched operation erased."""

added_operations_before: list[Operation] = field(default_factory=list, init=False)
"""The operations added directly before the matched operation."""

added_operations_after: list[Operation] = field(default_factory=list, init=False)
"""The operations added directly after the matched operation."""

has_done_action: bool = field(default=False, init=False)
"""Has the rewriter done any action during the current match."""

def _can_modify_op(self, op: Operation) -> bool:
"""Check if the operation and its children can be modified by this rewriter."""
if op == self.current_operation:
return True
if op.parent is None:
return self.current_operation.get_toplevel_object() is not op
return self._can_modify_op_in_block(op.parent)

def _can_modify_block(self, block: Block) -> bool:
"""Check if the block can be modified by this rewriter."""
if block is self.current_operation.parent:
return True
return self._can_modify_op_in_block(block)

def _can_modify_op_in_block(self, block: Block) -> bool:
"""Check if the block and its children can be modified by this rewriter."""
if block.parent is None:
return True # Toplevel operation of current_operation is always a ModuleOp
return self._can_modify_region(block.parent)

def _can_modify_region(self, region: Region) -> bool:
"""Check if the region and its children can be modified by this rewriter."""
if region.parent is None:
return True # Toplevel operation of current_operation is always a ModuleOp
if region is self.current_operation.parent_region():
return True
return self._can_modify_op(region.parent)

def _assert_can_modify_op(self, op: Operation) -> None:
"""Asssert the operation and its children can be modified by this rewriter."""
if not self._can_modify_op(op):
raise Exception("Cannot modify the operation or its children")

def _assert_can_modify_block(self, block: Block) -> None:
"""Assert the block can be modified by this rewriter."""
if not self._can_modify_block(block):
raise Exception("Cannot modify the block")

def _assert_can_modify_op_in_block(self, block: Block) -> None:
"""Assert the block and its children can be modified by this rewriter."""
if not self._can_modify_op_in_block(block):
raise Exception("Cannot modify the block or its children")

def _assert_can_modify_region(self, region: Region) -> None:
"""Assert the region and its children can be modified by this rewriter."""
if not self._can_modify_region(region):
raise Exception("Cannot modify the region or its children")

def insert_op_before_matched_op(self, op: (Operation | Sequence[Operation])):
"""Insert operations before the matched operation."""
if self.current_operation.parent is None:
raise Exception("Cannot insert an operation before a toplevel operation.")
self.has_done_action = True
block = self.current_operation.parent
op = [op] if isinstance(op, Operation) else op
if len(op) == 0:
return
block.insert_ops_before(op, self.current_operation)
self.added_operations_before.extend(op)
for op_ in op:
self.handle_operation_insertion(op_)
self.insert_op_before(op, self.current_operation)

def insert_op_after_matched_op(self, op: (Operation | Sequence[Operation])):
"""Insert operations after the matched operation."""
if self.current_operation.parent is None:
raise Exception("Cannot insert an operation after a toplevel operation.")
self.has_done_action = True
block = self.current_operation.parent
op = [op] if isinstance(op, Operation) else op
if len(op) == 0:
return
block.insert_ops_after(op, self.current_operation)
self.added_operations_after.extend(op)
for op_ in op:
self.handle_operation_insertion(op_)
self.insert_op_after(op, self.current_operation)

def insert_op_at_end(self, op: Operation | Sequence[Operation], block: Block):
"""Insert operations in a block contained in the matched operation."""
self._assert_can_modify_block(block)
"""Insert operations at the end of a block."""
self.has_done_action = True
op = [op] if isinstance(op, Operation) else op
if len(op) == 0:
Expand All @@ -189,71 +111,57 @@ def insert_op_at_end(self, op: Operation | Sequence[Operation], block: Block):
self.handle_operation_insertion(op_)

def insert_op_at_start(self, op: Operation | Sequence[Operation], block: Block):
"""Insert operations in a block contained in the matched operation."""
self._assert_can_modify_block(block)
first_op = block.first_op
if first_op is None:
self.insert_op_at_end(op, block)
else:
"""Insert operations at the start of a block."""
if (first_op := block.first_op) is not None:
self.insert_op_before(op, first_op)
else:
self.insert_op_at_end(op, block)

def insert_op_before(
self, op: Operation | Sequence[Operation], target_op: Operation
):
"""Insert operations before an operation contained in the matched operation."""
"""Insert operations before an operation."""
if target_op.parent is None:
raise Exception("Cannot insert operations before toplevel operation.")
target_block = target_op.parent
self._assert_can_modify_block(target_block)
self.has_done_action = True
op = [op] if isinstance(op, Operation) else op
if len(op) == 0:
return
target_block.insert_ops_before(op, target_op)
for op_ in op:
self.handle_operation_insertion(op_)
if target_op is self.current_operation:
self.added_operations_before.extend(op)

def insert_op_after(
self, op: Operation | Sequence[Operation], target_op: Operation
):
"""Insert operations after an operation contained in the matched operation."""
"""Insert operations after an operation."""
if target_op.parent is None:
raise Exception("Cannot insert operations after toplevel operation.")
target_block = target_op.parent
self._assert_can_modify_block(target_block)
self.has_done_action = True
ops = [op] if isinstance(op, Operation) else op
if len(ops) == 0:
return
target_block.insert_ops_after(ops, target_op)
for op_ in ops:
self.handle_operation_insertion(op_)
if target_op is self.current_operation:
self.added_operations_after.extend(ops)

def erase_matched_op(self, safe_erase: bool = True):
"""
Erase the operation that was matched to.
If safe_erase is True, check that the operation has no uses.
Otherwise, replace its uses with ErasedSSAValue.
"""
self.has_done_action = True
self.has_erased_matched_operation = True
self.handle_operation_removal(self.current_operation)
Rewriter.erase_op(self.current_operation, safe_erase=safe_erase)
self.erase_op(self.current_operation, safe_erase=safe_erase)

def erase_op(self, op: Operation, safe_erase: bool = True):
"""
Erase an operation contained in the matched operation children.
Erase an operation.
If safe_erase is True, check that the operation has no uses.
Otherwise, replace its uses with ErasedSSAValue.
"""
self.has_done_action = True
if op == self.current_operation:
return self.erase_matched_op(safe_erase)
self._assert_can_modify_op(op)
self.handle_operation_removal(op)
Rewriter.erase_op(op, safe_erase=safe_erase)

Expand Down Expand Up @@ -293,13 +201,11 @@ def replace_op(
):
"""
Replace an operation with new operations.
The operation should be a child of the matched operation.
Also, optionally specify SSA values to replace the operation results.
If safe_erase is True, check that the operation has no uses.
Otherwise, replace its uses with ErasedSSAValue.
"""
self.has_done_action = True
self._assert_can_modify_op(op)
if isinstance(new_ops, Operation):
new_ops = [new_ops]

Expand Down Expand Up @@ -330,11 +236,7 @@ def replace_op(
self.erase_op(op, safe_erase=safe_erase)

def modify_block_argument_type(self, arg: BlockArgument, new_type: Attribute):
"""
Modify the type of a block argument.
The block should be contained in the matched operation.
"""
self._assert_can_modify_block(arg.block)
"""Modify the type of a block argument."""
self.has_done_action = True
arg.type = new_type

Expand All @@ -344,22 +246,16 @@ def modify_block_argument_type(self, arg: BlockArgument, new_type: Attribute):
def insert_block_argument(
self, block: Block, index: int, arg_type: Attribute
) -> BlockArgument:
"""
Insert a new block argument.
The block should be contained in the matched operation.
"""
self._assert_can_modify_block(block)
"""Insert a new block argument."""
self.has_done_action = True
return block.insert_arg(arg_type, index)

def erase_block_argument(self, arg: BlockArgument, safe_erase: bool = True) -> None:
"""
Erase a new block argument.
The block should be contained in the matched operation.
If safe_erase is true, then raise an exception if the block argument has still
uses, otherwise, replace it with an ErasedSSAValue.
"""
self._assert_can_modify_block(arg.block)
self.has_done_action = True
self._replace_all_uses_with(arg, None, safe_erase=safe_erase)
arg.block.erase_arg(arg, safe_erase)
Expand All @@ -370,8 +266,6 @@ def inline_block_at_end(self, block: Block, target_block: Block):
This block should not be a parent of the block to move to.
"""
self.has_done_action = True
self._assert_can_modify_block(target_block)
self._assert_can_modify_block(block)
Rewriter.inline_block_at_end(block, target_block)

def inline_block_at_start(self, block: Block, target_block: Block):
Expand All @@ -380,116 +274,63 @@ def inline_block_at_start(self, block: Block, target_block: Block):
This block should not be a parent of the block to move to.
"""
self.has_done_action = True
self._assert_can_modify_block(target_block)
self._assert_can_modify_block(block)
Rewriter.inline_block_at_start(block, target_block)

def inline_block_before_matched_op(self, block: Block):
"""
Move the block operations before the matched operation.
The block should not be a parent of the operation, and should be a child of the
matched operation.
The block should not be a parent of the operation.
"""
self.has_done_action = True
self._assert_can_modify_block(block)
self.added_operations_before.extend(block.ops)
Rewriter.inline_block_before(block, self.current_operation)
self.inline_block_before(block, self.current_operation)

def inline_block_before(self, block: Block, op: Operation):
"""
Move the block operations before the given operation.
The block should not be a parent of the operation, and should be a child of the
matched operation.
The operation should also be a child of the matched operation.
The block should not be a parent of the operation.
"""
self.has_done_action = True
if op is self.current_operation:
return self.inline_block_before_matched_op(block)
self._assert_can_modify_block(block)
self._assert_can_modify_op(op)
Rewriter.inline_block_before(block, op)

def inline_block_after_matched_op(self, block: Block):
"""
Move the block operations after the matched operation.
The block should not be a parent of the operation, and should be a child of the
matched operation.
The block should not be a parent of the operation.
"""
self.has_done_action = True
self._assert_can_modify_block(block)
self.added_operations_after.extend(block.ops)
Rewriter.inline_block_after(block, self.current_operation)
self.inline_block_after(block, self.current_operation)

def inline_block_after(self, block: Block, op: Operation):
"""
Move the block operations after the given operation.
The block should not be a parent of the operation, and should be a child of the
matched operation.
The operation should also be a child of the matched operation.
The block should not be a parent of the operation.
"""
self.has_done_action = True
if op is self.current_operation:
return self.inline_block_after_matched_op(block)
self._assert_can_modify_block(block)
if op.parent is not None:
self._assert_can_modify_block(op.parent)
Rewriter.inline_block_after(block, op)

def move_region_contents_to_new_regions(self, region: Region) -> Region:
"""
Move the region blocks to a new region.
The region should be a child of the matched operation.
"""
"""Move the region blocks to a new region."""
self.has_done_action = True
self._assert_can_modify_region(region)
return Rewriter.move_region_contents_to_new_regions(region)

def inline_region_before(self, region: Region, target: Block) -> None:
"""Move the region blocks to an existing region."""
self.has_done_action = True
self._assert_can_modify_region(region)
self._assert_can_modify_block(target)
Rewriter.inline_region_before(region, target)

def inline_region_after(self, region: Region, target: Block) -> None:
"""Move the region blocks to an existing region."""
self.has_done_action = True
self._assert_can_modify_region(region)
self._assert_can_modify_block(target)
Rewriter.inline_region_after(region, target)

def inline_region_at_start(self, region: Region, target: Region) -> None:
"""Move the region blocks to an existing region."""
self.has_done_action = True
self._assert_can_modify_region(region)
self._assert_can_modify_region(target)
Rewriter.inline_region_at_start(region, target)

def inline_region_at_end(self, region: Region, target: Region) -> None:
"""Move the region blocks to an existing region."""
self.has_done_action = True
self._assert_can_modify_region(region)
self._assert_can_modify_region(target)
Rewriter.inline_region_at_end(region, target)

def iter_affected_ops(self) -> Iterable[Operation]:
"""
Iterate newly added operations, in the order that they are in the module.
"""
yield from self.added_operations_before
if not self.has_erased_matched_operation:
yield self.current_operation
yield from self.added_operations_after

def iter_affected_ops_reversed(self) -> Iterable[Operation]:
"""
Iterate newly added operations, in reverse order from that in the module.
"""
yield from reversed(self.added_operations_after)
if not self.has_erased_matched_operation:
yield self.current_operation
yield from reversed(self.added_operations_before)


class RewritePattern(ABC):
"""
Expand Down

0 comments on commit 5248ab0

Please sign in to comment.