diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 0d0c3868f8..10583b90ff 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -53,8 +53,8 @@ def adapted_foast_to_gtir_factory(**kwargs: Any) -> workflow.Workflow[AOT_FOP, i return toolchain.StripArgsAdapter(foast_to_gtir_factory(**kwargs)) -def promote_to_list(node: foast.Symbol | foast.Expr) -> Callable[[itir.Expr], itir.Expr]: - if not type_info.contains_local_field(node.type): +def promote_to_list(node_type: ts.TypeSpec) -> Callable[[itir.Expr], itir.Expr]: + if not type_info.contains_local_field(node_type): return lambda x: im.op_as_fieldop("make_const_list")(x) return lambda x: x @@ -215,16 +215,16 @@ def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> itir.Expr: if node.op in [dialect_ast_enums.UnaryOperator.NOT, dialect_ast_enums.UnaryOperator.INVERT]: if dtype.kind != ts.ScalarKind.BOOL: raise NotImplementedError(f"'{node.op}' is only supported on 'bool' arguments.") - return self._map("not_", node.operand) + return self._lower_and_map("not_", node.operand) - return self._map( + return self._lower_and_map( node.op.value, foast.Constant(value="0", type=dtype, location=node.location), node.operand, ) def visit_BinOp(self, node: foast.BinOp, **kwargs: Any) -> itir.FunCall: - return self._map(node.op.value, node.left, node.right) + return self._lower_and_map(node.op.value, node.left, node.right) def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs: Any) -> itir.FunCall: assert ( @@ -236,7 +236,7 @@ def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs: Any) -> itir.FunC ) def visit_Compare(self, node: foast.Compare, **kwargs: Any) -> itir.FunCall: - return self._map(node.op.value, node.left, node.right) + return self._lower_and_map(node.op.value, node.left, node.right) def _visit_shift(self, node: foast.Call, **kwargs: Any) -> itir.Expr: current_expr = self.visit(node.func, **kwargs) @@ 
-338,34 +338,43 @@ def _visit_astype(self, node: foast.Call, **kwargs: Any) -> itir.Expr: assert len(node.args) == 2 and isinstance(node.args[1], foast.Name) obj, new_type = self.visit(node.args[0], **kwargs), node.args[1].id - def create_cast(expr: itir.Expr, t: ts.TypeSpec) -> itir.FunCall: - if isinstance(t, ts.FieldType): + def create_cast(expr: itir.Expr, t: tuple[ts.TypeSpec]) -> itir.FunCall: + if isinstance(t[0], ts.FieldType): return im.as_fieldop( im.lambda_("__val")(im.call("cast_")(im.deref("__val"), str(new_type))) )(expr) else: - assert isinstance(t, ts.ScalarType) + assert isinstance(t[0], ts.ScalarType) return im.call("cast_")(expr, str(new_type)) if not isinstance(node.type, ts.TupleType): # to keep the IR simpler - return create_cast(obj, node.type) + return create_cast(obj, (node.args[0].type,)) - return lowering_utils.process_elements(create_cast, obj, node.type, with_type=True) + return lowering_utils.process_elements( + create_cast, obj, node.type, arg_types=(node.args[0].type,) + ) def _visit_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: if not isinstance(node.type, ts.TupleType): # to keep the IR simpler - return im.op_as_fieldop("if_")(*self.visit(node.args)) + return self._lower_and_map("if_", *node.args) cond_ = self.visit(node.args[0]) cond_symref_name = f"__cond_{eve_utils.content_hash(cond_)}" - def create_if(true_: itir.Expr, false_: itir.Expr) -> itir.FunCall: - return im.op_as_fieldop("if_")(im.ref(cond_symref_name), true_, false_) + def create_if( + true_: itir.Expr, false_: itir.Expr, arg_types: tuple[ts.TypeSpec, ts.TypeSpec] + ) -> itir.FunCall: + return _map( + "if_", + (im.ref(cond_symref_name), true_, false_), + (node.args[0].type, *arg_types), + ) result = lowering_utils.process_elements( create_if, (self.visit(node.args[1]), self.visit(node.args[2])), node.type, + arg_types=(node.args[1].type, node.args[2].type), ) return im.let(cond_symref_name, cond_)(result) @@ -377,7 +386,7 @@ def _visit_broadcast(self, 
node: foast.Call, **kwargs: Any) -> itir.FunCall: return im.as_fieldop(im.ref("deref"))(expr) def _visit_math_built_in(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: - return self._map(self.visit(node.func, **kwargs), *node.args) + return self._lower_and_map(self.visit(node.func, **kwargs), *node.args) def _make_reduction_expr( self, node: foast.Call, op: str | itir.SymRef, init_expr: itir.Expr, **kwargs: Any @@ -436,19 +445,34 @@ def _make_literal(self, val: Any, type_: ts.TypeSpec) -> itir.Expr: def visit_Constant(self, node: foast.Constant, **kwargs: Any) -> itir.Expr: return self._make_literal(node.value, node.type) - def _map(self, op: itir.Expr | str, *args: Any, **kwargs: Any) -> itir.FunCall: - lowered_args = [self.visit(arg, **kwargs) for arg in args] - if all( - isinstance(t, ts.ScalarType) - for arg in args - for t in type_info.primitive_constituents(arg.type) - ): - return im.call(op)(*lowered_args) # scalar operation - if any(type_info.contains_local_field(arg.type) for arg in args): - lowered_args = [promote_to_list(arg)(larg) for arg, larg in zip(args, lowered_args)] - op = im.call("map_")(op) + def _lower_and_map(self, op: itir.Expr | str, *args: Any, **kwargs: Any) -> itir.FunCall: + return _map( + op, tuple(self.visit(arg, **kwargs) for arg in args), tuple(arg.type for arg in args) + ) + + +def _map( + op: itir.Expr | str, + lowered_args: tuple, + original_arg_types: tuple[ts.TypeSpec, ...], +) -> itir.FunCall: + """ + Mapping includes making the operation an `as_fieldop` (first kind of mapping), but also `itir.map_`ing lists. 
+ """ + if all( + isinstance(t, ts.ScalarType) + for arg_type in original_arg_types + for t in type_info.primitive_constituents(arg_type) + ): + return im.call(op)(*lowered_args) # scalar operation + if any(type_info.contains_local_field(arg_type) for arg_type in original_arg_types): + lowered_args = tuple( + promote_to_list(arg_type)(larg) + for arg_type, larg in zip(original_arg_types, lowered_args) + ) + op = im.call("map_")(op) - return im.op_as_fieldop(im.call(op))(*lowered_args) + return im.op_as_fieldop(im.call(op))(*lowered_args) class FieldOperatorLoweringError(Exception): ... diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index 7936eda1cf..538b0f3ddb 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -55,8 +55,8 @@ def adapted_foast_to_itir_factory(**kwargs: Any) -> workflow.Workflow[AOT_FOP, i return toolchain.StripArgsAdapter(foast_to_itir_factory(**kwargs)) -def promote_to_list(node: foast.Symbol | foast.Expr) -> Callable[[itir.Expr], itir.Expr]: - if not type_info.contains_local_field(node.type): +def promote_to_list(node_type: ts.TypeSpec) -> Callable[[itir.Expr], itir.Expr]: + if not type_info.contains_local_field(node_type): return lambda x: im.promote_to_lifted_stencil("make_const_list")(x) return lambda x: x @@ -267,16 +267,16 @@ def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> itir.Expr: if node.op in [dialect_ast_enums.UnaryOperator.NOT, dialect_ast_enums.UnaryOperator.INVERT]: if dtype.kind != ts.ScalarKind.BOOL: raise NotImplementedError(f"'{node.op}' is only supported on 'bool' arguments.") - return self._map("not_", node.operand) + return self._lower_and_map("not_", node.operand) - return self._map( + return self._lower_and_map( node.op.value, foast.Constant(value="0", type=dtype, location=node.location), node.operand, ) def visit_BinOp(self, node: foast.BinOp, **kwargs: Any) -> itir.FunCall: - return self._map(node.op.value, 
node.left, node.right) + return self._lower_and_map(node.op.value, node.left, node.right) def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs: Any) -> itir.FunCall: op = "if_" @@ -286,7 +286,9 @@ def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs: Any) -> itir.FunC for arg in args ] if any(type_info.contains_local_field(arg.type) for arg in args): - lowered_args = [promote_to_list(arg)(larg) for arg, larg in zip(args, lowered_args)] + lowered_args = [ + promote_to_list(arg.type)(larg) for arg, larg in zip(args, lowered_args) + ] op = im.call("map_")(op) return lowering_utils.to_tuples_of_iterator( @@ -294,7 +296,7 @@ def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs: Any) -> itir.FunC ) def visit_Compare(self, node: foast.Compare, **kwargs: Any) -> itir.FunCall: - return self._map(node.op.value, node.left, node.right) + return self._lower_and_map(node.op.value, node.left, node.right) def _visit_shift(self, node: foast.Call, **kwargs: Any) -> itir.Expr: current_expr = self.visit(node.func, **kwargs) @@ -408,9 +410,12 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> itir.Expr: lowered_condition = self.visit(condition, **kwargs) return lowering_utils.process_elements( - lambda tv, fv: im.promote_to_lifted_stencil("if_")(lowered_condition, tv, fv), + lambda tv, fv, types: _map( + "if_", (lowered_condition, tv, fv), (condition.type, *types) + ), [self.visit(true_value, **kwargs), self.visit(false_value, **kwargs)], node.type, + (node.args[1].type, node.args[2].type), ) _visit_concat_where = _visit_where @@ -419,7 +424,7 @@ def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: return self.visit(node.args[0], **kwargs) def _visit_math_built_in(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: - return self._map(self.visit(node.func, **kwargs), *node.args) + return self._lower_and_map(self.visit(node.func, **kwargs), *node.args) def _make_reduction_expr( self, node: foast.Call, op: str | 
itir.SymRef, init_expr: itir.Expr, **kwargs: Any @@ -480,13 +485,28 @@ def _make_literal(self, val: Any, type_: ts.TypeSpec) -> itir.Expr: def visit_Constant(self, node: foast.Constant, **kwargs: Any) -> itir.Expr: return self._make_literal(node.value, node.type) - def _map(self, op: itir.Expr | str, *args: Any, **kwargs: Any) -> itir.FunCall: - lowered_args = [self.visit(arg, **kwargs) for arg in args] - if any(type_info.contains_local_field(arg.type) for arg in args): - lowered_args = [promote_to_list(arg)(larg) for arg, larg in zip(args, lowered_args)] - op = im.call("map_")(op) + def _lower_and_map(self, op: itir.Expr | str, *args: Any, **kwargs: Any) -> itir.FunCall: + return _map( + op, tuple(self.visit(arg, **kwargs) for arg in args), tuple(arg.type for arg in args) + ) + + +def _map( + op: itir.Expr | str, + lowered_args: tuple, + original_arg_types: tuple[ts.TypeSpec, ...], +) -> itir.FunCall: + """ + Mapping includes making the operation a lifted stencil (first kind of mapping), but also `itir.map_`ing lists. + """ + if any(type_info.contains_local_field(arg_type) for arg_type in original_arg_types): + lowered_args = tuple( + promote_to_list(arg_type)(larg) + for arg_type, larg in zip(original_arg_types, lowered_args) + ) + op = im.call("map_")(op) - return im.promote_to_lifted_stencil(im.call(op))(*lowered_args) + return im.promote_to_lifted_stencil(im.call(op))(*lowered_args) class FieldOperatorLoweringError(Exception): ... 
diff --git a/src/gt4py/next/ffront/lowering_utils.py b/src/gt4py/next/ffront/lowering_utils.py index a52581edb0..7049f70021 100644 --- a/src/gt4py/next/ffront/lowering_utils.py +++ b/src/gt4py/next/ffront/lowering_utils.py @@ -7,7 +7,7 @@ # SPDX-License-Identifier: BSD-3-Clause from collections.abc import Iterable -from typing import Any, Callable, TypeVar +from typing import Any, Callable, Optional, TypeVar from gt4py.eve import utils as eve_utils from gt4py.next.ffront import type_info as ti_ffront @@ -102,7 +102,7 @@ def process_elements( process_func: Callable[..., itir.Expr], objs: itir.Expr | Iterable[itir.Expr], current_el_type: ts.TypeSpec, - with_type: bool = False, + arg_types: Optional[Iterable[ts.TypeSpec]] = None, ) -> itir.FunCall: """ Recursively applies a processing function to all primitive constituents of a tuple. @@ -113,9 +113,9 @@ def process_elements( objs: The object whose elements are to be transformed. current_el_type: A type with the same structure as the elements of `objs`. The leaf-types are not used and thus not relevant. - current_el_type: A type with the same structure as the elements of `objs`. Unless `with_type=True` - the leaf-types are not used and thus not relevant. - with_type: If True, the last argument passed to `process_func` will be its type. + arg_types: If provided, a tuple of the type of each argument is passed to `process_func` as the last argument. + Note that `arg_types` might coincide with `(current_el_type,)*len(objs)`, but not necessarily, + in case of implicit broadcasts. 
""" if isinstance(objs, itir.Expr): objs = (objs,) @@ -125,7 +125,7 @@ def process_elements( process_func, tuple(im.ref(let_id) for let_id in let_ids), current_el_type, - with_type=with_type, + arg_types=arg_types, ) return im.let(*(zip(let_ids, objs, strict=True)))(body) @@ -138,7 +138,7 @@ def _process_elements_impl( process_func: Callable[..., itir.Expr], _current_el_exprs: Iterable[T], current_el_type: ts.TypeSpec, - with_type: bool, + arg_types: Optional[Iterable[ts.TypeSpec]], ) -> itir.Expr: if isinstance(current_el_type, ts.TupleType): result = im.make_tuple( @@ -149,16 +149,16 @@ def _process_elements_impl( im.tuple_get(i, current_el_expr) for current_el_expr in _current_el_exprs ), current_el_type.types[i], - with_type=with_type, + arg_types=tuple(arg_t.types[i] for arg_t in arg_types) # type: ignore[attr-defined] # guaranteed by the requirement that `current_el_type` and each element of `arg_types` have the same tuple structure + if arg_types is not None + else None, ) for i in range(len(current_el_type.types)) ) ) - elif type_info.contains_local_field(current_el_type): - raise NotImplementedError("Processing fields with local dimension is not implemented.") else: - if with_type: - result = process_func(*_current_el_exprs, current_el_type) + if arg_types is not None: + result = process_func(*_current_el_exprs, arg_types) else: result = process_func(*_current_el_exprs) diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index d85cd5b3df..9fb7850666 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -69,6 +69,7 @@ IJKField: TypeAlias = gtx.Field[[IDim, JDim, KDim], np.int32] # type: ignore [valid-type] IJKFloatField: TypeAlias = gtx.Field[[IDim, JDim, KDim], np.float64] # type: ignore [valid-type] VField: TypeAlias = gtx.Field[[Vertex], np.int32] # type: ignore [valid-type] +VBoolField: TypeAlias = gtx.Field[[Vertex], bool] # type: ignore 
[valid-type] EField: TypeAlias = gtx.Field[[Edge], np.int32] # type: ignore [valid-type] CField: TypeAlias = gtx.Field[[Cell], np.int32] # type: ignore [valid-type] EmptyField: TypeAlias = gtx.Field[[], np.int32] # type: ignore [valid-type] diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index 3777de7843..29966c30ad 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -220,6 +220,94 @@ def testee(flux: cases.EField) -> cases.VField: ) +@pytest.mark.uses_unstructured_shift +def test_reduction_expression_with_where(unstructured_case): + @gtx.field_operator + def testee(mask: cases.VBoolField, inp: cases.EField) -> cases.VField: + return neighbor_sum(where(mask, inp(V2E), inp(V2E)), axis=V2EDim) + + v2e_table = unstructured_case.offset_provider["V2E"].table + + mask = unstructured_case.as_field( + [Vertex], np.random.choice(a=[False, True], size=unstructured_case.default_sizes[Vertex]) + ) + inp = cases.allocate(unstructured_case, testee, "inp")() + out = cases.allocate(unstructured_case, testee, cases.RETURN)() + + cases.verify( + unstructured_case, + testee, + mask, + inp, + out=out, + ref=np.sum( + inp.asnumpy()[v2e_table], + axis=1, + initial=0, + where=v2e_table != common._DEFAULT_SKIP_VALUE, + ), + ) + + +@pytest.mark.uses_unstructured_shift +def test_reduction_expression_with_where_and_tuples(unstructured_case): + @gtx.field_operator + def testee(mask: cases.VBoolField, inp: cases.EField) -> cases.VField: + return neighbor_sum(where(mask, (inp(V2E), inp(V2E)), (inp(V2E), inp(V2E)))[1], axis=V2EDim) + + v2e_table = unstructured_case.offset_provider["V2E"].table + + mask = unstructured_case.as_field( + [Vertex], np.random.choice(a=[False, True], size=unstructured_case.default_sizes[Vertex]) + 
) + inp = cases.allocate(unstructured_case, testee, "inp")() + out = cases.allocate(unstructured_case, testee, cases.RETURN)() + + cases.verify( + unstructured_case, + testee, + mask, + inp, + out=out, + ref=np.sum( + inp.asnumpy()[v2e_table], + axis=1, + initial=0, + where=v2e_table != common._DEFAULT_SKIP_VALUE, + ), + ) + + +@pytest.mark.uses_unstructured_shift +def test_reduction_expression_with_where_and_scalar(unstructured_case): + @gtx.field_operator + def testee(mask: cases.VBoolField, inp: cases.EField) -> cases.VField: + return neighbor_sum(inp(V2E) + where(mask, inp(V2E), 1), axis=V2EDim) + + v2e_table = unstructured_case.offset_provider["V2E"].table + + mask = unstructured_case.as_field( + [Vertex], np.random.choice(a=[False, True], size=unstructured_case.default_sizes[Vertex]) + ) + inp = cases.allocate(unstructured_case, testee, "inp")() + out = cases.allocate(unstructured_case, testee, cases.RETURN)() + + cases.verify( + unstructured_case, + testee, + mask, + inp, + out=out, + ref=np.sum( + inp.asnumpy()[v2e_table] + + np.where(np.expand_dims(mask.asnumpy(), 1), inp.asnumpy()[v2e_table], 1), + axis=1, + initial=0, + where=v2e_table != common._DEFAULT_SKIP_VALUE, + ), + ) + + @pytest.mark.uses_tuple_returns def test_conditional_nested_tuple(cartesian_case): @gtx.field_operator