From 6f43e5a8ded48874bc1f33c3460cf3ce22ce3ff7 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 10 Jan 2025 02:25:44 +0100 Subject: [PATCH] Only inline scalars outside of stencils --- .../next/iterator/transforms/inline_scalar.py | 6 +++ .../transforms_tests/test_inline_scalar.py | 47 +++++++++++++++++++ 2 files changed, 53 insertions(+) create mode 100644 tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_scalar.py diff --git a/src/gt4py/next/iterator/transforms/inline_scalar.py b/src/gt4py/next/iterator/transforms/inline_scalar.py index 87b576d14d..d8a6e14d8a 100644 --- a/src/gt4py/next/iterator/transforms/inline_scalar.py +++ b/src/gt4py/next/iterator/transforms/inline_scalar.py @@ -21,6 +21,12 @@ def apply(cls, program: itir.Program, offset_provider_type: common.OffsetProvide program = itir_inference.infer(program, offset_provider_type=offset_provider_type) return cls().visit(program) + def generic_visit(self, node, **kwargs): + if cpm.is_call_to(node, "as_fieldop"): + return node + + return super().generic_visit(node, **kwargs) + def visit_Expr(self, node: itir.Expr): node = self.generic_visit(node) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_scalar.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_scalar.py new file mode 100644 index 0000000000..3e655b71f4 --- /dev/null +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_scalar.py @@ -0,0 +1,47 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause +import pytest + +from gt4py.next import common +from gt4py.next.iterator import ir as itir +from gt4py.next.type_system import type_specifications as ts +from gt4py.next.iterator.transforms import inline_scalar +from gt4py.next.iterator.ir_utils import ir_makers as im + +TDim = common.Dimension(value="TDim") +int_type = ts.ScalarType(kind=ts.ScalarKind.INT32) + + +def program_factory(expr: itir.Expr) -> itir.Program: + return itir.Program( + id="testee", + function_definitions=[], + params=[im.sym("out", ts.FieldType(dims=[TDim], dtype=int_type))], + declarations=[], + body=[ + itir.SetAt( + expr=expr, + target=im.ref("out"), + domain=im.domain(common.GridType.CARTESIAN, {TDim: (0, 1)}), + ) + ], + ) + + +def test_simple(): + testee = program_factory(im.let("a", 1)(im.op_as_fieldop("plus")("a", "a"))) + expected = program_factory(im.op_as_fieldop("plus")(1, 1)) + actual = inline_scalar.InlineScalar.apply(testee, offset_provider_type={}) + assert actual == expected + + +def test_fo_inline_only(): + scalar_expr = im.let("a", 1)(im.plus("a", "a")) + testee = program_factory(im.as_fieldop(im.lambda_()(scalar_expr))()) + actual = inline_scalar.InlineScalar.apply(testee, offset_provider_type={}) + assert actual == testee