diff --git a/tests/pattern_rewriter/test_pattern_rewriter.py b/tests/pattern_rewriter/test_pattern_rewriter.py index f06b2aff2f..bb966234aa 100644 --- a/tests/pattern_rewriter/test_pattern_rewriter.py +++ b/tests/pattern_rewriter/test_pattern_rewriter.py @@ -1777,3 +1777,32 @@ def convert_type(self, typ: IntegerType) -> IndexType: op_replaced=1, op_modified=1, ) + + +def test_pattern_rewriter_erase_op_with_region(): + """Test that erasing an operation with a region works correctly.""" + prog = """ +"builtin.module"() ({ + "test.op"() ({ + "test.op"() {"error_if_matching"} : () -> () + }): () -> () +}) : () -> ()""" + expected = """ +"builtin.module"() ({ +^0: +}) : () -> ()""" + + class Rewrite(RewritePattern): + @op_type_rewrite_pattern + def match_and_rewrite(self, op: test.TestOp, rewriter: PatternRewriter): + if "error_if_matching" in op.attributes: + raise Exception("operation that is supposed to be deleted was matched") + assert not op.attributes + rewriter.erase_matched_op() + + rewrite_and_compare( + prog, + expected, + PatternRewriteWalker(Rewrite(), apply_recursively=False), + op_removed=1, + ) diff --git a/xdsl/pattern_rewriter.py b/xdsl/pattern_rewriter.py index a8282bc21a..06a6bd5239 100644 --- a/xdsl/pattern_rewriter.py +++ b/xdsl/pattern_rewriter.py @@ -726,7 +726,11 @@ def _handle_operation_removal(self, op: Operation) -> None: """Handle removal of an operation.""" if self.apply_recursively: self._add_operands_to_worklist(op.operands) - self._worklist.remove(op) + if op.regions: + for sub_op in op.walk(): + self._worklist.remove(sub_op) + else: + self._worklist.remove(op) def _handle_operation_modification(self, op: Operation) -> None: """Handle modification of an operation."""