Skip to content

Commit

Permalink
bug[next]: fix lowering of tuples of neighbors in conditionals (#1710)
Browse files Browse the repository at this point in the history
Use the `_map` function in all cases where mapping of with
as_fieldop/lifted stencil *and* mapping of lists is required.
  • Loading branch information
havogt authored Nov 1, 2024
1 parent f1c9d83 commit e8f11fe
Show file tree
Hide file tree
Showing 5 changed files with 187 additions and 54 deletions.
78 changes: 51 additions & 27 deletions src/gt4py/next/ffront/foast_to_gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ def adapted_foast_to_gtir_factory(**kwargs: Any) -> workflow.Workflow[AOT_FOP, i
return toolchain.StripArgsAdapter(foast_to_gtir_factory(**kwargs))


def promote_to_list(node: foast.Symbol | foast.Expr) -> Callable[[itir.Expr], itir.Expr]:
if not type_info.contains_local_field(node.type):
def promote_to_list(node_type: ts.TypeSpec) -> Callable[[itir.Expr], itir.Expr]:
if not type_info.contains_local_field(node_type):
return lambda x: im.op_as_fieldop("make_const_list")(x)
return lambda x: x

Expand Down Expand Up @@ -215,16 +215,16 @@ def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> itir.Expr:
if node.op in [dialect_ast_enums.UnaryOperator.NOT, dialect_ast_enums.UnaryOperator.INVERT]:
if dtype.kind != ts.ScalarKind.BOOL:
raise NotImplementedError(f"'{node.op}' is only supported on 'bool' arguments.")
return self._map("not_", node.operand)
return self._lower_and_map("not_", node.operand)

return self._map(
return self._lower_and_map(
node.op.value,
foast.Constant(value="0", type=dtype, location=node.location),
node.operand,
)

def visit_BinOp(self, node: foast.BinOp, **kwargs: Any) -> itir.FunCall:
return self._map(node.op.value, node.left, node.right)
return self._lower_and_map(node.op.value, node.left, node.right)

def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs: Any) -> itir.FunCall:
assert (
Expand All @@ -236,7 +236,7 @@ def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs: Any) -> itir.FunC
)

def visit_Compare(self, node: foast.Compare, **kwargs: Any) -> itir.FunCall:
return self._map(node.op.value, node.left, node.right)
return self._lower_and_map(node.op.value, node.left, node.right)

def _visit_shift(self, node: foast.Call, **kwargs: Any) -> itir.Expr:
current_expr = self.visit(node.func, **kwargs)
Expand Down Expand Up @@ -338,34 +338,43 @@ def _visit_astype(self, node: foast.Call, **kwargs: Any) -> itir.Expr:
assert len(node.args) == 2 and isinstance(node.args[1], foast.Name)
obj, new_type = self.visit(node.args[0], **kwargs), node.args[1].id

def create_cast(expr: itir.Expr, t: ts.TypeSpec) -> itir.FunCall:
if isinstance(t, ts.FieldType):
def create_cast(expr: itir.Expr, t: tuple[ts.TypeSpec]) -> itir.FunCall:
if isinstance(t[0], ts.FieldType):
return im.as_fieldop(
im.lambda_("__val")(im.call("cast_")(im.deref("__val"), str(new_type)))
)(expr)
else:
assert isinstance(t, ts.ScalarType)
assert isinstance(t[0], ts.ScalarType)
return im.call("cast_")(expr, str(new_type))

if not isinstance(node.type, ts.TupleType): # to keep the IR simpler
return create_cast(obj, node.type)
return create_cast(obj, (node.args[0].type,))

return lowering_utils.process_elements(create_cast, obj, node.type, with_type=True)
return lowering_utils.process_elements(
create_cast, obj, node.type, arg_types=(node.args[0].type,)
)

def _visit_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall:
if not isinstance(node.type, ts.TupleType): # to keep the IR simpler
return im.op_as_fieldop("if_")(*self.visit(node.args))
return self._lower_and_map("if_", *node.args)

cond_ = self.visit(node.args[0])
cond_symref_name = f"__cond_{eve_utils.content_hash(cond_)}"

def create_if(true_: itir.Expr, false_: itir.Expr) -> itir.FunCall:
return im.op_as_fieldop("if_")(im.ref(cond_symref_name), true_, false_)
def create_if(
true_: itir.Expr, false_: itir.Expr, arg_types: tuple[ts.TypeSpec, ts.TypeSpec]
) -> itir.FunCall:
return _map(
"if_",
(im.ref(cond_symref_name), true_, false_),
(node.args[0].type, *arg_types),
)

result = lowering_utils.process_elements(
create_if,
(self.visit(node.args[1]), self.visit(node.args[2])),
node.type,
arg_types=(node.args[1].type, node.args[2].type),
)

return im.let(cond_symref_name, cond_)(result)
Expand All @@ -377,7 +386,7 @@ def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> itir.FunCall:
return im.as_fieldop(im.ref("deref"))(expr)

def _visit_math_built_in(self, node: foast.Call, **kwargs: Any) -> itir.FunCall:
return self._map(self.visit(node.func, **kwargs), *node.args)
return self._lower_and_map(self.visit(node.func, **kwargs), *node.args)

def _make_reduction_expr(
self, node: foast.Call, op: str | itir.SymRef, init_expr: itir.Expr, **kwargs: Any
Expand Down Expand Up @@ -436,19 +445,34 @@ def _make_literal(self, val: Any, type_: ts.TypeSpec) -> itir.Expr:
def visit_Constant(self, node: foast.Constant, **kwargs: Any) -> itir.Expr:
return self._make_literal(node.value, node.type)

def _map(self, op: itir.Expr | str, *args: Any, **kwargs: Any) -> itir.FunCall:
lowered_args = [self.visit(arg, **kwargs) for arg in args]
if all(
isinstance(t, ts.ScalarType)
for arg in args
for t in type_info.primitive_constituents(arg.type)
):
return im.call(op)(*lowered_args) # scalar operation
if any(type_info.contains_local_field(arg.type) for arg in args):
lowered_args = [promote_to_list(arg)(larg) for arg, larg in zip(args, lowered_args)]
op = im.call("map_")(op)
def _lower_and_map(self, op: itir.Expr | str, *args: Any, **kwargs: Any) -> itir.FunCall:
return _map(
op, tuple(self.visit(arg, **kwargs) for arg in args), tuple(arg.type for arg in args)
)


def _map(
op: itir.Expr | str,
lowered_args: tuple,
original_arg_types: tuple[ts.TypeSpec, ...],
) -> itir.FunCall:
"""
Mapping includes making the operation an `as_fieldop` (first kind of mapping), but also `itir.map_`ing lists.
"""
if all(
isinstance(t, ts.ScalarType)
for arg_type in original_arg_types
for t in type_info.primitive_constituents(arg_type)
):
return im.call(op)(*lowered_args) # scalar operation
if any(type_info.contains_local_field(arg_type) for arg_type in original_arg_types):
lowered_args = tuple(
promote_to_list(arg_type)(larg)
for arg_type, larg in zip(original_arg_types, lowered_args)
)
op = im.call("map_")(op)

return im.op_as_fieldop(im.call(op))(*lowered_args)
return im.op_as_fieldop(im.call(op))(*lowered_args)


class FieldOperatorLoweringError(Exception): ...
50 changes: 35 additions & 15 deletions src/gt4py/next/ffront/foast_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ def adapted_foast_to_itir_factory(**kwargs: Any) -> workflow.Workflow[AOT_FOP, i
return toolchain.StripArgsAdapter(foast_to_itir_factory(**kwargs))


def promote_to_list(node: foast.Symbol | foast.Expr) -> Callable[[itir.Expr], itir.Expr]:
if not type_info.contains_local_field(node.type):
def promote_to_list(node_type: ts.TypeSpec) -> Callable[[itir.Expr], itir.Expr]:
if not type_info.contains_local_field(node_type):
return lambda x: im.promote_to_lifted_stencil("make_const_list")(x)
return lambda x: x

Expand Down Expand Up @@ -267,16 +267,16 @@ def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> itir.Expr:
if node.op in [dialect_ast_enums.UnaryOperator.NOT, dialect_ast_enums.UnaryOperator.INVERT]:
if dtype.kind != ts.ScalarKind.BOOL:
raise NotImplementedError(f"'{node.op}' is only supported on 'bool' arguments.")
return self._map("not_", node.operand)
return self._lower_and_map("not_", node.operand)

return self._map(
return self._lower_and_map(
node.op.value,
foast.Constant(value="0", type=dtype, location=node.location),
node.operand,
)

def visit_BinOp(self, node: foast.BinOp, **kwargs: Any) -> itir.FunCall:
return self._map(node.op.value, node.left, node.right)
return self._lower_and_map(node.op.value, node.left, node.right)

def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs: Any) -> itir.FunCall:
op = "if_"
Expand All @@ -286,15 +286,17 @@ def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs: Any) -> itir.FunC
for arg in args
]
if any(type_info.contains_local_field(arg.type) for arg in args):
lowered_args = [promote_to_list(arg)(larg) for arg, larg in zip(args, lowered_args)]
lowered_args = [
promote_to_list(arg.type)(larg) for arg, larg in zip(args, lowered_args)
]
op = im.call("map_")(op)

return lowering_utils.to_tuples_of_iterator(
im.promote_to_lifted_stencil(im.call(op))(*lowered_args), node.type
)

def visit_Compare(self, node: foast.Compare, **kwargs: Any) -> itir.FunCall:
return self._map(node.op.value, node.left, node.right)
return self._lower_and_map(node.op.value, node.left, node.right)

def _visit_shift(self, node: foast.Call, **kwargs: Any) -> itir.Expr:
current_expr = self.visit(node.func, **kwargs)
Expand Down Expand Up @@ -408,9 +410,12 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> itir.Expr:

lowered_condition = self.visit(condition, **kwargs)
return lowering_utils.process_elements(
lambda tv, fv: im.promote_to_lifted_stencil("if_")(lowered_condition, tv, fv),
lambda tv, fv, types: _map(
"if_", (lowered_condition, tv, fv), (condition.type, *types)
),
[self.visit(true_value, **kwargs), self.visit(false_value, **kwargs)],
node.type,
(node.args[1].type, node.args[2].type),
)

_visit_concat_where = _visit_where
Expand All @@ -419,7 +424,7 @@ def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> itir.FunCall:
return self.visit(node.args[0], **kwargs)

def _visit_math_built_in(self, node: foast.Call, **kwargs: Any) -> itir.FunCall:
return self._map(self.visit(node.func, **kwargs), *node.args)
return self._lower_and_map(self.visit(node.func, **kwargs), *node.args)

def _make_reduction_expr(
self, node: foast.Call, op: str | itir.SymRef, init_expr: itir.Expr, **kwargs: Any
Expand Down Expand Up @@ -480,13 +485,28 @@ def _make_literal(self, val: Any, type_: ts.TypeSpec) -> itir.Expr:
def visit_Constant(self, node: foast.Constant, **kwargs: Any) -> itir.Expr:
return self._make_literal(node.value, node.type)

def _map(self, op: itir.Expr | str, *args: Any, **kwargs: Any) -> itir.FunCall:
lowered_args = [self.visit(arg, **kwargs) for arg in args]
if any(type_info.contains_local_field(arg.type) for arg in args):
lowered_args = [promote_to_list(arg)(larg) for arg, larg in zip(args, lowered_args)]
op = im.call("map_")(op)
def _lower_and_map(self, op: itir.Expr | str, *args: Any, **kwargs: Any) -> itir.FunCall:
return _map(
op, tuple(self.visit(arg, **kwargs) for arg in args), tuple(arg.type for arg in args)
)


def _map(
op: itir.Expr | str,
lowered_args: tuple,
original_arg_types: tuple[ts.TypeSpec, ...],
) -> itir.FunCall:
"""
Mapping includes making the operation an lifted stencil (first kind of mapping), but also `itir.map_`ing lists.
"""
if any(type_info.contains_local_field(arg_type) for arg_type in original_arg_types):
lowered_args = tuple(
promote_to_list(arg_type)(larg)
for arg_type, larg in zip(original_arg_types, lowered_args)
)
op = im.call("map_")(op)

return im.promote_to_lifted_stencil(im.call(op))(*lowered_args)
return im.promote_to_lifted_stencil(im.call(op))(*lowered_args)


class FieldOperatorLoweringError(Exception): ...
24 changes: 12 additions & 12 deletions src/gt4py/next/ffront/lowering_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# SPDX-License-Identifier: BSD-3-Clause

from collections.abc import Iterable
from typing import Any, Callable, TypeVar
from typing import Any, Callable, Optional, TypeVar

from gt4py.eve import utils as eve_utils
from gt4py.next.ffront import type_info as ti_ffront
Expand Down Expand Up @@ -102,7 +102,7 @@ def process_elements(
process_func: Callable[..., itir.Expr],
objs: itir.Expr | Iterable[itir.Expr],
current_el_type: ts.TypeSpec,
with_type: bool = False,
arg_types: Optional[Iterable[ts.TypeSpec]] = None,
) -> itir.FunCall:
"""
Recursively applies a processing function to all primitive constituents of a tuple.
Expand All @@ -113,9 +113,9 @@ def process_elements(
objs: The object whose elements are to be transformed.
current_el_type: A type with the same structure as the elements of `objs`. The leaf-types
are not used and thus not relevant.
current_el_type: A type with the same structure as the elements of `objs`. Unless `with_type=True`
the leaf-types are not used and thus not relevant.
with_type: If True, the last argument passed to `process_func` will be its type.
arg_types: If provided, a tuple of the type of each argument is passed to `process_func` as last argument.
Note, that `arg_types` might coincide with `(current_el_type,)*len(objs)`, but not necessarily,
in case of implicit broadcasts.
"""
if isinstance(objs, itir.Expr):
objs = (objs,)
Expand All @@ -125,7 +125,7 @@ def process_elements(
process_func,
tuple(im.ref(let_id) for let_id in let_ids),
current_el_type,
with_type=with_type,
arg_types=arg_types,
)

return im.let(*(zip(let_ids, objs, strict=True)))(body)
Expand All @@ -138,7 +138,7 @@ def _process_elements_impl(
process_func: Callable[..., itir.Expr],
_current_el_exprs: Iterable[T],
current_el_type: ts.TypeSpec,
with_type: bool,
arg_types: Optional[Iterable[ts.TypeSpec]],
) -> itir.Expr:
if isinstance(current_el_type, ts.TupleType):
result = im.make_tuple(
Expand All @@ -149,16 +149,16 @@ def _process_elements_impl(
im.tuple_get(i, current_el_expr) for current_el_expr in _current_el_exprs
),
current_el_type.types[i],
with_type=with_type,
arg_types=tuple(arg_t.types[i] for arg_t in arg_types) # type: ignore[attr-defined] # guaranteed by the requirement that `current_el_type` and each element of `arg_types` have the same tuple structure
if arg_types is not None
else None,
)
for i in range(len(current_el_type.types))
)
)
elif type_info.contains_local_field(current_el_type):
raise NotImplementedError("Processing fields with local dimension is not implemented.")
else:
if with_type:
result = process_func(*_current_el_exprs, current_el_type)
if arg_types is not None:
result = process_func(*_current_el_exprs, arg_types)
else:
result = process_func(*_current_el_exprs)

Expand Down
1 change: 1 addition & 0 deletions tests/next_tests/integration_tests/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
IJKField: TypeAlias = gtx.Field[[IDim, JDim, KDim], np.int32] # type: ignore [valid-type]
IJKFloatField: TypeAlias = gtx.Field[[IDim, JDim, KDim], np.float64] # type: ignore [valid-type]
VField: TypeAlias = gtx.Field[[Vertex], np.int32] # type: ignore [valid-type]
VBoolField: TypeAlias = gtx.Field[[Vertex], bool] # type: ignore [valid-type]
EField: TypeAlias = gtx.Field[[Edge], np.int32] # type: ignore [valid-type]
CField: TypeAlias = gtx.Field[[Cell], np.int32] # type: ignore [valid-type]
EmptyField: TypeAlias = gtx.Field[[], np.int32] # type: ignore [valid-type]
Expand Down
Loading

0 comments on commit e8f11fe

Please sign in to comment.