Skip to content

Commit

Permalink
Add additional make_tuple test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
tehrengruber committed Jan 6, 2025
1 parent ce2261d commit d04c4dc
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/gt4py/next/iterator/transforms/fuse_as_fieldop.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,7 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs):
node = self.generic_visit(node, **kwargs)

if cpm.is_call_to(node, "make_tuple"):
# TODO(tehrengruber): x, y = alpha * y, x is not fused
as_fieldop_args = [arg for arg in node.args if cpm.is_applied_as_fieldop(arg)]
distinct_domains = set(arg.fun.args[1] for arg in as_fieldop_args) # type: ignore[attr-defined] # ensured by cpm.is_applied_as_fieldop
if len(distinct_domains) != len(as_fieldop_args):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,45 @@ def test_make_tuple_fusion_trivial():
)
assert actual_simplified == expected

def test_make_tuple_fusion_symref():
d = im.domain("cartesian_domain", {IDim: (0, 1)})
testee = im.make_tuple(
im.as_fieldop("deref", d)(im.ref("a", field_type)),
im.ref("b", field_type),
)
expected = im.as_fieldop(
im.lambda_("a", "b")(im.make_tuple(im.deref("a"), im.deref("b"))),
d,
)(im.ref("a", field_type), im.ref("b", field_type))
actual = fuse_as_fieldop.FuseAsFieldOp.apply(
testee, offset_provider_type={}, allow_undeclared_symbols=True
)
# simplify to remove unnecessary make_tuple call
actual_simplified = collapse_tuple.CollapseTuple.apply(
actual, within_stencil=False, allow_undeclared_symbols=True
)
assert actual_simplified == expected


def test_make_tuple_fusion_symref2():
d = im.domain("cartesian_domain", {IDim: (0, 1)})
testee = im.make_tuple(
im.as_fieldop("deref", d)(im.ref("a", field_type)),
im.ref("a", field_type),
)
expected = im.as_fieldop(
im.lambda_("a")(im.make_tuple(im.deref("a"), im.deref("a"))),
d,
)(im.ref("a", field_type))
actual = fuse_as_fieldop.FuseAsFieldOp.apply(
testee, offset_provider_type={}, allow_undeclared_symbols=True
)
# simplify to remove unnecessary make_tuple call
actual_simplified = collapse_tuple.CollapseTuple.apply(
actual, within_stencil=False, allow_undeclared_symbols=True
)
assert actual_simplified == expected


def test_make_tuple_fusion_different_domains():
d1 = im.domain("cartesian_domain", {IDim: (0, 1)})
Expand Down

0 comments on commit d04c4dc

Please sign in to comment.