Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

interactive: add functions that make xdsl-gui add rewrite functionality to tool (4/4) #2131

Closed
wants to merge 45 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
8061a3e
zcvx
dshaaban01 Feb 7, 2024
d08c732
asdfasdfadsf
dshaaban01 Feb 7, 2024
e56aa32
asdf
dshaaban01 Feb 7, 2024
c27dcf4
Merge branch 'dalia/interactive/oopsie' into dalia/interactive/oopsie2
dshaaban01 Feb 7, 2024
278d6dd
asdfasdfasdf
dshaaban01 Feb 7, 2024
a1aa5cf
hii
dshaaban01 Feb 7, 2024
03a488c
asdfasdf
dshaaban01 Feb 7, 2024
3512695
sasha
dshaaban01 Feb 7, 2024
d192850
asdlkfja;dsklfjadslkfjjaks
dshaaban01 Feb 7, 2024
9eec6da
Merge branch 'dalia/interactive/oopsie' into dalia/interactive/oopsie2
dshaaban01 Feb 7, 2024
e126345
adsfff
dshaaban01 Feb 7, 2024
9137c78
asdfasdfasdf
dshaaban01 Feb 7, 2024
e069c3b
fix pytest
dshaaban01 Feb 8, 2024
7acf915
add test for rewrite
dshaaban01 Feb 8, 2024
e6545b7
sasha
dshaaban01 Feb 10, 2024
3fe3786
Merge branch 'dalia/interactive/oopsie' into dalia/interactive/oopsie2
dshaaban01 Feb 10, 2024
5963489
asdf
dshaaban01 Feb 10, 2024
f05d18b
asdf
dshaaban01 Feb 10, 2024
369cc25
adf
dshaaban01 Feb 10, 2024
d164b55
hii
dshaaban01 Feb 10, 2024
943aa27
Merge branch 'dalia/interactive/oopsie' into dalia/interactive/oopsie2
dshaaban01 Feb 10, 2024
1a0feb5
hii
dshaaban01 Feb 10, 2024
ca45748
asdf
dshaaban01 Feb 10, 2024
37c05b3
asdf
dshaaban01 Feb 10, 2024
6e950b8
dd
dshaaban01 Feb 10, 2024
809addc
Merge branch 'dalia/interactive/oopsie' into dalia/interactive/oopsie2
dshaaban01 Feb 10, 2024
9a7e22f
helloooo
dshaaban01 Feb 10, 2024
cb5e3cb
hello?
dshaaban01 Feb 10, 2024
3beb741
asdf
dshaaban01 Feb 10, 2024
6fd410e
Merge branch 'dalia/interactive/oopsie' into dalia/interactive/oopsie2
dshaaban01 Feb 10, 2024
b42ac71
hello
dshaaban01 Feb 11, 2024
a50f41f
hi
dshaaban01 Feb 11, 2024
ca63427
bla
dshaaban01 Feb 11, 2024
607b8c5
hello
dshaaban01 Feb 11, 2024
e9c639b
hi
dshaaban01 Feb 11, 2024
210ade6
asdf
dshaaban01 Feb 11, 2024
004d04b
Merge branch 'dalia/interactive/oopsie' into dalia/interactive/oopsie2
dshaaban01 Feb 11, 2024
dfd5817
Merge branch 'main' into dalia/interactive/oopsie
dshaaban01 Feb 11, 2024
81b45ea
Merge branch 'dalia/interactive/oopsie' into dalia/interactive/oopsie2
dshaaban01 Feb 11, 2024
4e8d948
df
dshaaban01 Feb 11, 2024
0e9cdb6
Merge branch 'main' into dalia/interactive/oopsie
dshaaban01 Feb 12, 2024
2cf45d6
erd
dshaaban01 Feb 12, 2024
037ff6a
Merge branch 'main' into dalia/interactive/oopsie
dshaaban01 Feb 12, 2024
7972c6b
hiiii back in buizznezz
dshaaban01 Feb 12, 2024
a65424c
Merge branch 'dalia/interactive/oopsie' into dalia/interactive/oopsie2
dshaaban01 Feb 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 173 additions & 11 deletions tests/interactive/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,23 @@
ModuleOp,
UnrealizedConversionCastOp,
)
from xdsl.interactive.app import InputApp
from xdsl.interactive.app import AvailablePass, InputApp
from xdsl.ir import Block, Region
from xdsl.transforms import (
canonicalize,
individual_rewrite,
mlir_opt,
printf_to_llvm,
scf_parallel_loop_tiling,
stencil_unroll,
test_lower_linalg_to_snitch,
)
from xdsl.transforms.experimental import (
hls_convert_stencil_to_ll_mlir,
)
from xdsl.transforms.experimental.dmp import stencil_global_to_local
from xdsl.utils.exceptions import ParseError
from xdsl.utils.parse_pipeline import PipelinePassSpec
from xdsl.utils.parse_pipeline import PipelinePassSpec, parse_pipeline


@pytest.mark.asyncio()
Expand Down Expand Up @@ -267,15 +269,51 @@ async def test_buttons():

condensed_list = tuple(
(
individual_rewrite.IndividualRewrite,
convert_arith_to_riscv.ConvertArithToRiscvPass,
convert_func_to_riscv_func.ConvertFuncToRiscvFuncPass,
stencil_global_to_local.DistributeStencilPass,
hls_convert_stencil_to_ll_mlir.HLSConvertStencilToLLMLIRPass,
mlir_opt.MLIROptPass,
printf_to_llvm.PrintfToLLVM,
scf_parallel_loop_tiling.ScfParallelLoopTilingPass,
stencil_unroll.StencilUnrollPass,
AvailablePass(
display_name="apply-individual-rewrite",
module_pass=individual_rewrite.IndividualRewrite,
pass_spec=None,
),
AvailablePass(
display_name="convert-arith-to-riscv",
module_pass=convert_arith_to_riscv.ConvertArithToRiscvPass,
pass_spec=None,
),
AvailablePass(
display_name="convert-func-to-riscv-func",
module_pass=convert_func_to_riscv_func.ConvertFuncToRiscvFuncPass,
pass_spec=None,
),
AvailablePass(
display_name="distribute-stencil",
module_pass=stencil_global_to_local.DistributeStencilPass,
pass_spec=None,
),
AvailablePass(
display_name="hls-convert-stencil-to-ll-mlir",
module_pass=hls_convert_stencil_to_ll_mlir.HLSConvertStencilToLLMLIRPass,
pass_spec=None,
),
AvailablePass(
display_name="mlir-opt",
module_pass=mlir_opt.MLIROptPass,
pass_spec=None,
),
AvailablePass(
display_name="printf-to-llvm",
module_pass=printf_to_llvm.PrintfToLLVM,
pass_spec=None,
),
AvailablePass(
display_name="scf-parallel-loop-tiling",
module_pass=scf_parallel_loop_tiling.ScfParallelLoopTilingPass,
pass_spec=None,
),
AvailablePass(
display_name="stencil-unroll",
module_pass=stencil_unroll.StencilUnrollPass,
pass_spec=None,
),
)
)

Expand All @@ -292,6 +330,130 @@ async def test_buttons():
assert app.condense_mode is False


@pytest.mark.asyncio()
async def test_rewrites():
"""Test rewrite application has the desired result."""
async with InputApp().run_test() as pilot:
app = cast(InputApp, pilot.app)
# clear preloaded code and unselect preselected pass
app.input_text_area.clear()

await pilot.pause()
# Testing a pass
app.input_text_area.insert(
"""
func.func @hello(%n : i32) -> i32 {
%two = arith.constant 0 : i32
%res = arith.addi %two, %n : i32
func.return %res : i32
}
"""
)

# press "Condense" button
await pilot.click("#condense_button")

condensed_list = tuple(
(
AvailablePass(
display_name="apply-individual-rewrite",
module_pass=individual_rewrite.IndividualRewrite,
pass_spec=None,
),
AvailablePass(
display_name="canonicalize",
module_pass=canonicalize.CanonicalizePass,
pass_spec=None,
),
AvailablePass(
display_name="convert-arith-to-riscv",
module_pass=convert_arith_to_riscv.ConvertArithToRiscvPass,
pass_spec=None,
),
AvailablePass(
display_name="convert-func-to-riscv-func",
module_pass=convert_func_to_riscv_func.ConvertFuncToRiscvFuncPass,
pass_spec=None,
),
AvailablePass(
display_name="distribute-stencil",
module_pass=stencil_global_to_local.DistributeStencilPass,
pass_spec=None,
),
AvailablePass(
display_name="hls-convert-stencil-to-ll-mlir",
module_pass=hls_convert_stencil_to_ll_mlir.HLSConvertStencilToLLMLIRPass,
pass_spec=None,
),
AvailablePass(
display_name="mlir-opt",
module_pass=mlir_opt.MLIROptPass,
pass_spec=None,
),
AvailablePass(
display_name="printf-to-llvm",
module_pass=printf_to_llvm.PrintfToLLVM,
pass_spec=None,
),
AvailablePass(
display_name="scf-parallel-loop-tiling",
module_pass=scf_parallel_loop_tiling.ScfParallelLoopTilingPass,
pass_spec=None,
),
AvailablePass(
display_name="stencil-unroll",
module_pass=stencil_unroll.StencilUnrollPass,
pass_spec=None,
),
AvailablePass(
display_name="test-lower-linalg-to-snitch",
module_pass=test_lower_linalg_to_snitch.TestLowerLinalgToSnitchPass,
pass_spec=None,
),
AvailablePass(
display_name="Addi(%res = arith.addi %two, %n : i32):arith.addi:AddImmediateZero",
module_pass=individual_rewrite.IndividualRewrite,
pass_spec=list(
parse_pipeline(
'apply-individual-rewrite{matched_operation_index=3 operation_name="arith.addi" pattern_name="AddImmediateZero"}'
)
)[0],
),
)
)

await pilot.pause()
# assert after "Condense Button" is clicked that the state and get_condensed_pass list change accordingly
assert app.condense_mode is True
assert app.available_pass_list == condensed_list

# Select a rewrite
app.pass_pipeline = (
*app.pass_pipeline,
(
individual_rewrite.IndividualRewrite,
list(
parse_pipeline(
'apply-individual-rewrite{matched_operation_index=3 operation_name="arith.addi" pattern_name="AddImmediateZero"}'
)
)[0],
),
)

# assert that pass selection affected Output Text Area
await pilot.pause()
assert (
app.output_text_area.text
== """builtin.module {
func.func @hello(%n : i32) -> i32 {
%two = arith.constant 0 : i32
func.return %n : i32
}
}
"""
)


@pytest.mark.asyncio()
async def test_passes():
"""Test pass application has the desired result."""
Expand Down
58 changes: 58 additions & 0 deletions tests/interactive/test_get_all_possible_rewrites.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from xdsl.dialects.builtin import (
StringAttr,
)
from xdsl.dialects.test import TestOp
from xdsl.interactive.get_all_possible_rewrites import (
IndexedIndividualRewrite,
IndividualRewrite,
get_all_possible_rewrites,
)
from xdsl.ir import MLContext
from xdsl.parser import Parser
from xdsl.pattern_rewriter import (
PatternRewriter,
RewritePattern,
op_type_rewrite_pattern,
)
from xdsl.tools.command_line_tool import get_all_dialects


def test_get_all_possible_rewrite():
# build module
prog = """
builtin.module {
"test.op"() {"label" = "a"} : () -> ()
"test.op"() {"label" = "a"} : () -> ()
"test.op"() {"label" = "b"} : () -> ()
}
"""

ctx = MLContext(True)
for dialect_name, dialect_factory in get_all_dialects().items():
ctx.register_dialect(dialect_name, dialect_factory)
parser = Parser(ctx, prog)
module = parser.parse_module()

class Rewrite(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: TestOp, rewriter: PatternRewriter):
if op.attributes["label"] != StringAttr("a"):
return
rewriter.replace_matched_op(TestOp(attributes={"label": StringAttr("c")}))

expected_res = (
(
IndexedIndividualRewrite(
1, IndividualRewrite(operation="test.op", pattern="TestRewrite")
)
),
(
IndexedIndividualRewrite(
operation_index=2,
rewrite=IndividualRewrite(operation="test.op", pattern="TestRewrite"),
)
),
)

res = get_all_possible_rewrites(module, {"test.op": {"TestRewrite": Rewrite()}})
assert res == expected_res
Loading
Loading