diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index 4b7cc45adc..83cc922dc6 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -310,6 +310,7 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs): node = self.generic_visit(node, **kwargs) if cpm.is_call_to(node, "make_tuple"): + # TODO(tehrengruber): x, y = alpha * y, x is not fused as_fieldop_args = [arg for arg in node.args if cpm.is_applied_as_fieldop(arg)] distinct_domains = set(arg.fun.args[1] for arg in as_fieldop_args) # type: ignore[attr-defined] # ensured by cpm.is_applied_as_fieldop if len(distinct_domains) != len(as_fieldop_args): diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py index e9cb016313..1b95183422 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py @@ -169,6 +169,45 @@ def test_make_tuple_fusion_trivial(): ) assert actual_simplified == expected +def test_make_tuple_fusion_symref(): + d = im.domain("cartesian_domain", {IDim: (0, 1)}) + testee = im.make_tuple( + im.as_fieldop("deref", d)(im.ref("a", field_type)), + im.ref("b", field_type), + ) + expected = im.as_fieldop( + im.lambda_("a", "b")(im.make_tuple(im.deref("a"), im.deref("b"))), + d, + )(im.ref("a", field_type), im.ref("b", field_type)) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + # simplify to remove unnecessary make_tuple call + actual_simplified = collapse_tuple.CollapseTuple.apply( + actual, within_stencil=False, allow_undeclared_symbols=True + ) + assert actual_simplified == expected + + +def test_make_tuple_fusion_symref2(): + d = im.domain("cartesian_domain", {IDim: (0, 1)}) + testee = im.make_tuple( + im.as_fieldop("deref", d)(im.ref("a", field_type)), + im.ref("a", field_type), + ) + expected = im.as_fieldop( + im.lambda_("a")(im.make_tuple(im.deref("a"), im.deref("a"))), + d, + )(im.ref("a", field_type)) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + # simplify to remove unnecessary make_tuple call + actual_simplified = collapse_tuple.CollapseTuple.apply( + actual, within_stencil=False, allow_undeclared_symbols=True + ) + assert actual_simplified == expected + def test_make_tuple_fusion_different_domains(): d1 = im.domain("cartesian_domain", {IDim: (0, 1)})