Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

interactive: add rewrite functionality to gui tool 3/3 #2155

Merged
merged 18 commits into from
Feb 19, 2024
128 changes: 127 additions & 1 deletion tests/interactive/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,20 @@
from xdsl.interactive.get_condensed_passes import AvailablePass
from xdsl.ir import Block, Region
from xdsl.transforms import (
canonicalize,
individual_rewrite,
mlir_opt,
printf_to_llvm,
scf_parallel_loop_tiling,
stencil_unroll,
test_lower_linalg_to_snitch,
)
from xdsl.transforms.experimental import (
hls_convert_stencil_to_ll_mlir,
)
from xdsl.transforms.experimental.dmp import stencil_global_to_local
from xdsl.utils.exceptions import ParseError
from xdsl.utils.parse_pipeline import PipelinePassSpec
from xdsl.utils.parse_pipeline import PipelinePassSpec, parse_pipeline


@pytest.mark.asyncio()
Expand Down Expand Up @@ -329,6 +331,130 @@ async def test_buttons():
assert app.condense_mode is False


@pytest.mark.asyncio()
async def test_rewrites():
"""Test rewrite application has the desired result."""
async with InputApp().run_test() as pilot:
app = cast(InputApp, pilot.app)
# clear preloaded code and unselect preselected pass
app.input_text_area.clear()

await pilot.pause()
# Testing a pass
app.input_text_area.insert(
"""
func.func @hello(%n : i32) -> i32 {
%two = arith.constant 0 : i32
%res = arith.addi %two, %n : i32
func.return %res : i32
}
"""
)

# press "Condense" button
await pilot.click("#condense_button")

condensed_list = tuple(
(
AvailablePass(
display_name="apply-individual-rewrite",
module_pass=individual_rewrite.IndividualRewrite,
pass_spec=None,
),
AvailablePass(
display_name="canonicalize",
module_pass=canonicalize.CanonicalizePass,
pass_spec=None,
),
AvailablePass(
display_name="convert-arith-to-riscv",
module_pass=convert_arith_to_riscv.ConvertArithToRiscvPass,
pass_spec=None,
),
AvailablePass(
display_name="convert-func-to-riscv-func",
module_pass=convert_func_to_riscv_func.ConvertFuncToRiscvFuncPass,
pass_spec=None,
),
AvailablePass(
display_name="distribute-stencil",
module_pass=stencil_global_to_local.DistributeStencilPass,
pass_spec=None,
),
AvailablePass(
display_name="hls-convert-stencil-to-ll-mlir",
module_pass=hls_convert_stencil_to_ll_mlir.HLSConvertStencilToLLMLIRPass,
pass_spec=None,
),
AvailablePass(
display_name="mlir-opt",
module_pass=mlir_opt.MLIROptPass,
pass_spec=None,
),
AvailablePass(
display_name="printf-to-llvm",
module_pass=printf_to_llvm.PrintfToLLVM,
pass_spec=None,
),
AvailablePass(
display_name="scf-parallel-loop-tiling",
module_pass=scf_parallel_loop_tiling.ScfParallelLoopTilingPass,
pass_spec=None,
),
AvailablePass(
display_name="stencil-unroll",
module_pass=stencil_unroll.StencilUnrollPass,
pass_spec=None,
),
AvailablePass(
display_name="test-lower-linalg-to-snitch",
module_pass=test_lower_linalg_to_snitch.TestLowerLinalgToSnitchPass,
pass_spec=None,
),
AvailablePass(
display_name="Addi(%res = arith.addi %two, %n : i32):arith.addi:AddImmediateZero",
module_pass=individual_rewrite.IndividualRewrite,
pass_spec=list(
parse_pipeline(
'apply-individual-rewrite{matched_operation_index=3 operation_name="arith.addi" pattern_name="AddImmediateZero"}'
)
)[0],
),
)
)

await pilot.pause()
# assert after "Condense Button" is clicked that the state and get_condensed_pass list change accordingly
assert app.condense_mode is True
assert app.available_pass_list == condensed_list

# Select a rewrite
app.pass_pipeline = (
*app.pass_pipeline,
(
individual_rewrite.IndividualRewrite,
list(
parse_pipeline(
'apply-individual-rewrite{matched_operation_index=3 operation_name="arith.addi" pattern_name="AddImmediateZero"}'
)
)[0],
),
)

# assert that pass selection affected Output Text Area
await pilot.pause()
assert (
app.output_text_area.text
== """builtin.module {
func.func @hello(%n : i32) -> i32 {
%two = arith.constant 0 : i32
func.return %n : i32
}
}
"""
)


@pytest.mark.asyncio()
async def test_passes():
"""Test pass application has the desired result."""
Expand Down
51 changes: 45 additions & 6 deletions xdsl/interactive/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

from xdsl.dialects.builtin import ModuleOp
from xdsl.interactive.add_arguments_screen import AddArguments
from xdsl.interactive.get_all_possible_rewrites import get_all_possible_rewrites
from xdsl.interactive.get_condensed_passes import (
ALL_PASSES,
AvailablePass,
Expand All @@ -46,6 +47,7 @@
from xdsl.passes import ModulePass, PipelinePass, get_pass_argument_names_and_types
from xdsl.printer import Printer
from xdsl.tools.command_line_tool import get_all_dialects, get_all_passes
from xdsl.transforms import individual_rewrite
from xdsl.utils.exceptions import PassPipelineParseError
from xdsl.utils.parse_pipeline import PipelinePassSpec, parse_pipeline

Expand Down Expand Up @@ -249,10 +251,41 @@ def compute_available_pass_list(self) -> tuple[AvailablePass, ...]:
case Exception():
return ()
case ModuleOp():
# transform rewrites into passes
rewrites = get_all_possible_rewrites(
self.current_module,
individual_rewrite.REWRITE_BY_NAMES,
)
rewrites_as_pass_list: tuple[AvailablePass, ...] = ()
for op_idx, (op_name, pat_name) in rewrites:
rewrite_pass = individual_rewrite.IndividualRewrite
rewrite_spec = PipelinePassSpec(
name=rewrite_pass.name,
args={
"matched_operation_index": [op_idx],
"operation_name": [op_name],
"pattern_name": [pat_name],
},
)
op = list(self.current_module.walk())[op_idx]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it looks like we only use the op to get its description, maybe we'd be better off getting that as part of the available rewrites? It also feels wasteful to get the nth op for every op if we already had this information in the past

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I add the op to what "get_all_possible_rewrites" returns. what do u think? does this make sense?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I considered removing the operation index, but I think we should leave it for now as it could be useful for future plans we have for functionalities (line numbers etc.)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@superlopuh I realize now that maybe it is better to leave it the way that it is.

returning the str(Operation) from get_all_possible_rewrites():

When the rewrite is displayed in the gui, we now get

%0 = arith.addi %1, %2
instead of
Addi(%res = arith.addi %two, %n : i32)

returning the Operation from get_all_possible_rewrites() - (This wont work either way due to issues with testing and equality checking of an Operation):

When the rewrite is displayed in the gui, we now get

Addi(%0 = arith.addi %1, %2)
instead of
Addi(%res = arith.addi %two, %n : i32)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#2162
@superlopuh told me to file an issue. we will address this later. will merge as is.

rewrites_as_pass_list = (
*rewrites_as_pass_list,
(
AvailablePass(
f"{op}:{op_name}:{pat_name}", rewrite_pass, rewrite_spec
)
),
)

# merge rewrite passes with "other" pass list
if self.condense_mode:
return get_condensed_pass_list(self.current_module)
pass_list = get_condensed_pass_list(self.current_module)
return pass_list + rewrites_as_pass_list
else:
return tuple(AvailablePass(p.name, p, None) for _, p in ALL_PASSES)
pass_list = tuple(
AvailablePass(p.name, p, None) for _, p in ALL_PASSES
)
return pass_list + rewrites_as_pass_list

def watch_available_pass_list(
self,
Expand All @@ -275,7 +308,11 @@ def watch_available_pass_list(
)
)

def get_pass_arguments(self, selected_pass_value: type[ModulePass]) -> None:
def get_pass_arguments(
self,
selected_pass_value: type[ModulePass],
selected_pass_spec: PipelinePassSpec | None,
) -> None:
"""
This function facilitates user input of pass concatenated_arg_val by navigating
to the AddArguments screen, and subsequently parses the returned string upon
Expand Down Expand Up @@ -306,7 +343,7 @@ def add_pass_with_arguments_to_pass_pipeline(concatenated_arg_val: str) -> None:
self.push_screen(screen, add_pass_with_arguments_to_pass_pipeline)

# if selected_pass_value has arguments, push screen
if fields(selected_pass_value):
if fields(selected_pass_value) and selected_pass_spec is None:
# generates a string containing the concatenated_arg_val and types of the selected pass and initializes the AddArguments Screen to contain the string
self.push_screen(
AddArguments(
Expand All @@ -319,9 +356,11 @@ def add_pass_with_arguments_to_pass_pipeline(concatenated_arg_val: str) -> None:
)
else:
# add the selected pass to pass_pipeline
if selected_pass_spec is None:
selected_pass_spec = selected_pass_value().pipeline_pass_spec()
self.pass_pipeline = (
*self.pass_pipeline,
(selected_pass_value, selected_pass_value().pipeline_pass_spec()),
(selected_pass_value, selected_pass_spec),
)

@on(ListView.Selected)
Expand All @@ -332,7 +371,7 @@ def update_pass_pipeline(self, event: ListView.Selected) -> None:
"""
list_item = event.item
assert isinstance(list_item, PassListItem)
self.get_pass_arguments(list_item.module_pass)
self.get_pass_arguments(list_item.module_pass, list_item.pass_spec)

def watch_pass_pipeline(self) -> None:
"""
Expand Down
1 change: 1 addition & 0 deletions xdsl/interactive/get_all_possible_rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def get_all_possible_rewrites(
"""
Function that takes a sequence of IndividualRewrite Patterns and a ModuleOp, and
returns the possible rewrites.
Issue filed: https://github.com/xdslproject/xdsl/issues/2162
"""
old_module = op.clone()
num_ops = len(list(old_module.walk()))
Expand Down
Loading