From 938f6f581def197897b6039e05c0310b7d19e507 Mon Sep 17 00:00:00 2001 From: Alex Rice Date: Mon, 9 Dec 2024 10:40:19 +0000 Subject: [PATCH 1/2] interactive: small refactor --- tests/interactive/test_rewrites.py | 4 ++-- xdsl/interactive/rewrites.py | 33 +++++++++++------------------- 2 files changed, 14 insertions(+), 23 deletions(-) diff --git a/tests/interactive/test_rewrites.py b/tests/interactive/test_rewrites.py index b9010b093f..04a650f80a 100644 --- a/tests/interactive/test_rewrites.py +++ b/tests/interactive/test_rewrites.py @@ -45,7 +45,7 @@ def test_get_all_possible_rewrite(): parser = Parser(ctx, prog) module = parser.parse_module() - expected_res = ( + expected_res = [ ( IndexedIndividualRewrite( 1, IndividualRewrite(operation="test.op", pattern="TestRewrite") @@ -57,7 +57,7 @@ def test_get_all_possible_rewrite(): rewrite=IndividualRewrite(operation="test.op", pattern="TestRewrite"), ) ), - ) + ] res = get_all_possible_rewrites(module, {"test.op": {"TestRewrite": Rewrite()}}) assert res == expected_res diff --git a/xdsl/interactive/rewrites.py b/xdsl/interactive/rewrites.py index 421df1eace..59ba48da14 100644 --- a/xdsl/interactive/rewrites.py +++ b/xdsl/interactive/rewrites.py @@ -1,3 +1,4 @@ +from collections.abc import Sequence from typing import NamedTuple from xdsl.dialects.builtin import ModuleOp @@ -51,40 +52,30 @@ def convert_indexed_individual_rewrites_to_available_pass( def get_all_possible_rewrites( - op: ModuleOp, + module: ModuleOp, rewrite_by_name: dict[str, dict[str, RewritePattern]], -) -> tuple[IndexedIndividualRewrite, ...]: +) -> Sequence[IndexedIndividualRewrite]: """ Function that takes a sequence of IndividualRewrite Patterns and a ModuleOp, and returns the possible rewrites. Issue filed: https://github.com/xdslproject/xdsl/issues/2162 """ - old_module = op.clone() - num_ops = len(list(old_module.walk())) - current_module = old_module.clone() + res: list[IndexedIndividualRewrite] = [] - res: tuple[IndexedIndividualRewrite, ...] = () - - for op_idx in range(num_ops): - matched_op = list(current_module.walk())[op_idx] + for op_idx, matched_op in enumerate(module.walk()): if matched_op.name not in rewrite_by_name: continue pattern_by_name = rewrite_by_name[matched_op.name] - for pattern_name, pattern in pattern_by_name.items(): - rewriter = PatternRewriter(matched_op) - pattern.match_and_rewrite(matched_op, rewriter) + cloned_op = tuple(module.clone().walk())[op_idx] + rewriter = PatternRewriter(cloned_op) + pattern.match_and_rewrite(cloned_op, rewriter) if rewriter.has_done_action: - res = ( - *res, - ( - IndexedIndividualRewrite( - op_idx, IndividualRewrite(matched_op.name, pattern_name) - ) - ), + res.append( + IndexedIndividualRewrite( + op_idx, IndividualRewrite(cloned_op.name, pattern_name) + ) ) - current_module = old_module.clone() - matched_op = list(current_module.walk())[op_idx] return res From a9ae013ac7f0941ff645d8e868de68000c562272 Mon Sep 17 00:00:00 2001 From: Alex Rice Date: Mon, 9 Dec 2024 10:45:25 +0000 Subject: [PATCH 2/2] Fix pyright issue --- xdsl/interactive/rewrites.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xdsl/interactive/rewrites.py b/xdsl/interactive/rewrites.py index 59ba48da14..985cdbcb18 100644 --- a/xdsl/interactive/rewrites.py +++ b/xdsl/interactive/rewrites.py @@ -27,7 +27,7 @@ class IndexedIndividualRewrite(NamedTuple): def convert_indexed_individual_rewrites_to_available_pass( - rewrites: tuple[IndexedIndividualRewrite, ...], current_module: ModuleOp + rewrites: Sequence[IndexedIndividualRewrite], current_module: ModuleOp ) -> tuple[AvailablePass, ...]: """ Function that takes a tuple of rewrites, converts each rewrite into an IndividualRewrite pass and returns the tuple of AvailablePass.