Skip to content

Commit

Permalink
core: Simplify the pattern rewriter
Browse files Browse the repository at this point in the history
  • Loading branch information
math-fehr committed Jan 22, 2024
1 parent e6ca26f commit 21a2af8
Showing 1 changed file with 20 additions and 179 deletions.
199 changes: 20 additions & 179 deletions xdsl/pattern_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,97 +89,19 @@ class PatternRewriter(PatternRewriterListener):
current_operation: Operation
"""The matched operation."""

has_erased_matched_operation: bool = field(default=False, init=False)
"""Was the matched operation erased."""

added_operations_before: list[Operation] = field(default_factory=list, init=False)
"""The operations added directly before the matched operation."""

added_operations_after: list[Operation] = field(default_factory=list, init=False)
"""The operations added directly after the matched operation."""

has_done_action: bool = field(default=False, init=False)
"""Has the rewriter done any action during the current match."""

def _can_modify_op(self, op: Operation) -> bool:
"""Check if the operation and its children can be modified by this rewriter."""
if op == self.current_operation:
return True
if op.parent is None:
return self.current_operation.get_toplevel_object() is not op
return self._can_modify_op_in_block(op.parent)

def _can_modify_block(self, block: Block) -> bool:
"""Check if the block can be modified by this rewriter."""
if block is self.current_operation.parent:
return True
return self._can_modify_op_in_block(block)

def _can_modify_op_in_block(self, block: Block) -> bool:
"""Check if the block and its children can be modified by this rewriter."""
if block.parent is None:
return True # Toplevel operation of current_operation is always a ModuleOp
return self._can_modify_region(block.parent)

def _can_modify_region(self, region: Region) -> bool:
"""Check if the region and its children can be modified by this rewriter."""
if region.parent is None:
return True # Toplevel operation of current_operation is always a ModuleOp
if region is self.current_operation.parent_region():
return True
return self._can_modify_op(region.parent)

def _assert_can_modify_op(self, op: Operation) -> None:
"""Asssert the operation and its children can be modified by this rewriter."""
if not self._can_modify_op(op):
raise Exception("Cannot modify the operation or its children")

def _assert_can_modify_block(self, block: Block) -> None:
"""Assert the block can be modified by this rewriter."""
if not self._can_modify_block(block):
raise Exception("Cannot modify the block")

def _assert_can_modify_op_in_block(self, block: Block) -> None:
"""Assert the block and its children can be modified by this rewriter."""
if not self._can_modify_op_in_block(block):
raise Exception("Cannot modify the block or its children")

def _assert_can_modify_region(self, region: Region) -> None:
"""Assert the region and its children can be modified by this rewriter."""
if not self._can_modify_region(region):
raise Exception("Cannot modify the region or its children")

def insert_op_before_matched_op(self, op: (Operation | Sequence[Operation])):
"""Insert operations before the matched operation."""
if self.current_operation.parent is None:
raise Exception("Cannot insert an operation before a toplevel operation.")
self.has_done_action = True
block = self.current_operation.parent
op = [op] if isinstance(op, Operation) else op
if len(op) == 0:
return
block.insert_ops_before(op, self.current_operation)
self.added_operations_before.extend(op)
for op_ in op:
self.handle_operation_insertion(op_)
self.insert_op_before(op, self.current_operation)

def insert_op_after_matched_op(self, op: (Operation | Sequence[Operation])):
"""Insert operations after the matched operation."""
if self.current_operation.parent is None:
raise Exception("Cannot insert an operation after a toplevel operation.")
self.has_done_action = True
block = self.current_operation.parent
op = [op] if isinstance(op, Operation) else op
if len(op) == 0:
return
block.insert_ops_after(op, self.current_operation)
self.added_operations_after.extend(op)
for op_ in op:
self.handle_operation_insertion(op_)
self.insert_op_after(op, self.current_operation)

def insert_op_at_end(self, op: Operation | Sequence[Operation], block: Block):
"""Insert operations in a block contained in the matched operation."""
self._assert_can_modify_block(block)
"""Insert operations at the end of a block."""
self.has_done_action = True
op = [op] if isinstance(op, Operation) else op
if len(op) == 0:
Expand All @@ -189,71 +111,57 @@ def insert_op_at_end(self, op: Operation | Sequence[Operation], block: Block):
self.handle_operation_insertion(op_)

def insert_op_at_start(self, op: Operation | Sequence[Operation], block: Block):
"""Insert operations in a block contained in the matched operation."""
self._assert_can_modify_block(block)
first_op = block.first_op
if first_op is None:
self.insert_op_at_end(op, block)
else:
"""Insert operations at the start of a block."""
if (first_op := block.first_op) is not None:
self.insert_op_before(op, first_op)
else:
self.insert_op_at_end(op, block)

def insert_op_before(
self, op: Operation | Sequence[Operation], target_op: Operation
):
"""Insert operations before an operation contained in the matched operation."""
"""Insert operations before an operation."""
if target_op.parent is None:
raise Exception("Cannot insert operations before toplevel operation.")
target_block = target_op.parent
self._assert_can_modify_block(target_block)
self.has_done_action = True
op = [op] if isinstance(op, Operation) else op
if len(op) == 0:
return
target_block.insert_ops_before(op, target_op)
for op_ in op:
self.handle_operation_insertion(op_)
if target_op is self.current_operation:
self.added_operations_before.extend(op)

def insert_op_after(
self, op: Operation | Sequence[Operation], target_op: Operation
):
"""Insert operations after an operation contained in the matched operation."""
"""Insert operations after an operation."""
if target_op.parent is None:
raise Exception("Cannot insert operations after toplevel operation.")
target_block = target_op.parent
self._assert_can_modify_block(target_block)
self.has_done_action = True
ops = [op] if isinstance(op, Operation) else op
if len(ops) == 0:
return
target_block.insert_ops_after(ops, target_op)
for op_ in ops:
self.handle_operation_insertion(op_)
if target_op is self.current_operation:
self.added_operations_after.extend(ops)

def erase_matched_op(self, safe_erase: bool = True):
"""
Erase the operation that was matched to.
If safe_erase is True, check that the operation has no uses.
Otherwise, replace its uses with ErasedSSAValue.
"""
self.has_done_action = True
self.has_erased_matched_operation = True
self.handle_operation_removal(self.current_operation)
Rewriter.erase_op(self.current_operation, safe_erase=safe_erase)
self.erase_op(self.current_operation, safe_erase=safe_erase)

def erase_op(self, op: Operation, safe_erase: bool = True):
"""
Erase an operation contained in the matched operation children.
Erase an operation.
If safe_erase is True, check that the operation has no uses.
Otherwise, replace its uses with ErasedSSAValue.
"""
self.has_done_action = True
if op == self.current_operation:
return self.erase_matched_op(safe_erase)
self._assert_can_modify_op(op)
self.handle_operation_removal(op)
Rewriter.erase_op(op, safe_erase=safe_erase)

Expand Down Expand Up @@ -293,13 +201,11 @@ def replace_op(
):
"""
Replace an operation with new operations.
The operation should be a child of the matched operation.
Also, optionally specify SSA values to replace the operation results.
If safe_erase is True, check that the operation has no uses.
Otherwise, replace its uses with ErasedSSAValue.
"""
self.has_done_action = True
self._assert_can_modify_op(op)
if isinstance(new_ops, Operation):
new_ops = [new_ops]

Expand Down Expand Up @@ -330,11 +236,7 @@ def replace_op(
self.erase_op(op, safe_erase=safe_erase)

def modify_block_argument_type(self, arg: BlockArgument, new_type: Attribute):
"""
Modify the type of a block argument.
The block should be contained in the matched operation.
"""
self._assert_can_modify_block(arg.block)
"""Modify the type of a block argument."""
self.has_done_action = True
arg.type = new_type

Expand All @@ -344,22 +246,16 @@ def modify_block_argument_type(self, arg: BlockArgument, new_type: Attribute):
def insert_block_argument(
self, block: Block, index: int, arg_type: Attribute
) -> BlockArgument:
"""
Insert a new block argument.
The block should be contained in the matched operation.
"""
self._assert_can_modify_block(block)
"""Insert a new block argument."""
self.has_done_action = True
return block.insert_arg(arg_type, index)

def erase_block_argument(self, arg: BlockArgument, safe_erase: bool = True) -> None:
"""
Erase a new block argument.
The block should be contained in the matched operation.
If safe_erase is true, then raise an exception if the block argument has still
uses, otherwise, replace it with an ErasedSSAValue.
"""
self._assert_can_modify_block(arg.block)
self.has_done_action = True
self._replace_all_uses_with(arg, None, safe_erase=safe_erase)
arg.block.erase_arg(arg, safe_erase)
Expand All @@ -370,8 +266,6 @@ def inline_block_at_end(self, block: Block, target_block: Block):
This block should not be a parent of the block to move to.
"""
self.has_done_action = True
self._assert_can_modify_block(target_block)
self._assert_can_modify_block(block)
Rewriter.inline_block_at_end(block, target_block)

def inline_block_at_start(self, block: Block, target_block: Block):
Expand All @@ -380,116 +274,63 @@ def inline_block_at_start(self, block: Block, target_block: Block):
This block should not be a parent of the block to move to.
"""
self.has_done_action = True
self._assert_can_modify_block(target_block)
self._assert_can_modify_block(block)
Rewriter.inline_block_at_start(block, target_block)

def inline_block_before_matched_op(self, block: Block):
"""
Move the block operations before the matched operation.
The block should not be a parent of the operation, and should be a child of the
matched operation.
The block should not be a parent of the operation.
"""
self.has_done_action = True
self._assert_can_modify_block(block)
self.added_operations_before.extend(block.ops)
Rewriter.inline_block_before(block, self.current_operation)
self.inline_block_before(block, self.current_operation)

def inline_block_before(self, block: Block, op: Operation):
"""
Move the block operations before the given operation.
The block should not be a parent of the operation, and should be a child of the
matched operation.
The operation should also be a child of the matched operation.
The block should not be a parent of the operation.
"""
self.has_done_action = True
if op is self.current_operation:
return self.inline_block_before_matched_op(block)
self._assert_can_modify_block(block)
self._assert_can_modify_op(op)
Rewriter.inline_block_before(block, op)

def inline_block_after_matched_op(self, block: Block):
"""
Move the block operations after the matched operation.
The block should not be a parent of the operation, and should be a child of the
matched operation.
The block should not be a parent of the operation.
"""
self.has_done_action = True
self._assert_can_modify_block(block)
self.added_operations_after.extend(block.ops)
Rewriter.inline_block_after(block, self.current_operation)
self.inline_block_after(block, self.current_operation)

def inline_block_after(self, block: Block, op: Operation):
"""
Move the block operations after the given operation.
The block should not be a parent of the operation, and should be a child of the
matched operation.
The operation should also be a child of the matched operation.
The block should not be a parent of the operation.
"""
self.has_done_action = True
if op is self.current_operation:
return self.inline_block_after_matched_op(block)
self._assert_can_modify_block(block)
if op.parent is not None:
self._assert_can_modify_block(op.parent)
Rewriter.inline_block_after(block, op)

def move_region_contents_to_new_regions(self, region: Region) -> Region:
"""
Move the region blocks to a new region.
The region should be a child of the matched operation.
"""
"""Move the region blocks to a new region."""
self.has_done_action = True
self._assert_can_modify_region(region)
return Rewriter.move_region_contents_to_new_regions(region)

def inline_region_before(self, region: Region, target: Block) -> None:
"""Move the region blocks to an existing region."""
self.has_done_action = True
self._assert_can_modify_region(region)
self._assert_can_modify_block(target)
Rewriter.inline_region_before(region, target)

def inline_region_after(self, region: Region, target: Block) -> None:
"""Move the region blocks to an existing region."""
self.has_done_action = True
self._assert_can_modify_region(region)
self._assert_can_modify_block(target)
Rewriter.inline_region_after(region, target)

def inline_region_at_start(self, region: Region, target: Region) -> None:
"""Move the region blocks to an existing region."""
self.has_done_action = True
self._assert_can_modify_region(region)
self._assert_can_modify_region(target)
Rewriter.inline_region_at_start(region, target)

def inline_region_at_end(self, region: Region, target: Region) -> None:
"""Move the region blocks to an existing region."""
self.has_done_action = True
self._assert_can_modify_region(region)
self._assert_can_modify_region(target)
Rewriter.inline_region_at_end(region, target)

def iter_affected_ops(self) -> Iterable[Operation]:
"""
Iterate newly added operations, in the order that they are in the module.
"""
yield from self.added_operations_before
if not self.has_erased_matched_operation:
yield self.current_operation
yield from self.added_operations_after

def iter_affected_ops_reversed(self) -> Iterable[Operation]:
"""
Iterate newly added operations, in reverse order from that in the module.
"""
yield from reversed(self.added_operations_after)
if not self.has_erased_matched_operation:
yield self.current_operation
yield from reversed(self.added_operations_before)


class RewritePattern(ABC):
"""
Expand Down

0 comments on commit 21a2af8

Please sign in to comment.