Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bug[next]: fix lowering of tuples of neighbors in conditionals #1710

Merged
merged 4 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 51 additions & 27 deletions src/gt4py/next/ffront/foast_to_gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ def adapted_foast_to_gtir_factory(**kwargs: Any) -> workflow.Workflow[AOT_FOP, i
return toolchain.StripArgsAdapter(foast_to_gtir_factory(**kwargs))


def promote_to_list(node: foast.Symbol | foast.Expr) -> Callable[[itir.Expr], itir.Expr]:
if not type_info.contains_local_field(node.type):
def promote_to_list(node_type: ts.TypeSpec) -> Callable[[itir.Expr], itir.Expr]:
if not type_info.contains_local_field(node_type):
return lambda x: im.op_as_fieldop("make_const_list")(x)
return lambda x: x

Expand Down Expand Up @@ -215,16 +215,16 @@ def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> itir.Expr:
if node.op in [dialect_ast_enums.UnaryOperator.NOT, dialect_ast_enums.UnaryOperator.INVERT]:
if dtype.kind != ts.ScalarKind.BOOL:
raise NotImplementedError(f"'{node.op}' is only supported on 'bool' arguments.")
return self._map("not_", node.operand)
return self._lower_and_map("not_", node.operand)

return self._map(
return self._lower_and_map(
node.op.value,
foast.Constant(value="0", type=dtype, location=node.location),
node.operand,
)

def visit_BinOp(self, node: foast.BinOp, **kwargs: Any) -> itir.FunCall:
return self._map(node.op.value, node.left, node.right)
return self._lower_and_map(node.op.value, node.left, node.right)

def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs: Any) -> itir.FunCall:
assert (
Expand All @@ -236,7 +236,7 @@ def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs: Any) -> itir.FunC
)

def visit_Compare(self, node: foast.Compare, **kwargs: Any) -> itir.FunCall:
return self._map(node.op.value, node.left, node.right)
return self._lower_and_map(node.op.value, node.left, node.right)

def _visit_shift(self, node: foast.Call, **kwargs: Any) -> itir.Expr:
current_expr = self.visit(node.func, **kwargs)
Expand Down Expand Up @@ -338,34 +338,43 @@ def _visit_astype(self, node: foast.Call, **kwargs: Any) -> itir.Expr:
assert len(node.args) == 2 and isinstance(node.args[1], foast.Name)
obj, new_type = self.visit(node.args[0], **kwargs), node.args[1].id

def create_cast(expr: itir.Expr, t: ts.TypeSpec) -> itir.FunCall:
if isinstance(t, ts.FieldType):
def create_cast(expr: itir.Expr, t: tuple[ts.TypeSpec]) -> itir.FunCall:
if isinstance(t[0], ts.FieldType):
return im.as_fieldop(
im.lambda_("__val")(im.call("cast_")(im.deref("__val"), str(new_type)))
)(expr)
else:
assert isinstance(t, ts.ScalarType)
assert isinstance(t[0], ts.ScalarType)
return im.call("cast_")(expr, str(new_type))

if not isinstance(node.type, ts.TupleType): # to keep the IR simpler
return create_cast(obj, node.type)
return create_cast(obj, (node.args[0].type,))

return lowering_utils.process_elements(create_cast, obj, node.type, with_type=True)
return lowering_utils.process_elements(
create_cast, obj, node.type, arg_types=(node.args[0].type,)
)

def _visit_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall:
if not isinstance(node.type, ts.TupleType): # to keep the IR simpler
return im.op_as_fieldop("if_")(*self.visit(node.args))
return self._lower_and_map("if_", *node.args)

cond_ = self.visit(node.args[0])
cond_symref_name = f"__cond_{eve_utils.content_hash(cond_)}"

def create_if(true_: itir.Expr, false_: itir.Expr) -> itir.FunCall:
return im.op_as_fieldop("if_")(im.ref(cond_symref_name), true_, false_)
def create_if(
true_: itir.Expr, false_: itir.Expr, arg_types: tuple[ts.TypeSpec, ts.TypeSpec]
) -> itir.FunCall:
return _map(
"if_",
(im.ref(cond_symref_name), true_, false_),
(node.args[0].type, *arg_types),
)

result = lowering_utils.process_elements(
create_if,
(self.visit(node.args[1]), self.visit(node.args[2])),
node.type,
arg_types=(node.args[1].type, node.args[2].type),
)

return im.let(cond_symref_name, cond_)(result)
Expand All @@ -377,7 +386,7 @@ def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> itir.FunCall:
return im.as_fieldop(im.ref("deref"))(expr)

def _visit_math_built_in(self, node: foast.Call, **kwargs: Any) -> itir.FunCall:
return self._map(self.visit(node.func, **kwargs), *node.args)
return self._lower_and_map(self.visit(node.func, **kwargs), *node.args)

def _make_reduction_expr(
self, node: foast.Call, op: str | itir.SymRef, init_expr: itir.Expr, **kwargs: Any
Expand Down Expand Up @@ -436,19 +445,34 @@ def _make_literal(self, val: Any, type_: ts.TypeSpec) -> itir.Expr:
def visit_Constant(self, node: foast.Constant, **kwargs: Any) -> itir.Expr:
return self._make_literal(node.value, node.type)

def _map(self, op: itir.Expr | str, *args: Any, **kwargs: Any) -> itir.FunCall:
lowered_args = [self.visit(arg, **kwargs) for arg in args]
if all(
isinstance(t, ts.ScalarType)
for arg in args
for t in type_info.primitive_constituents(arg.type)
):
return im.call(op)(*lowered_args) # scalar operation
if any(type_info.contains_local_field(arg.type) for arg in args):
lowered_args = [promote_to_list(arg)(larg) for arg, larg in zip(args, lowered_args)]
op = im.call("map_")(op)
def _lower_and_map(self, op: itir.Expr | str, *args: Any, **kwargs: Any) -> itir.FunCall:
return _map(
op, tuple(self.visit(arg, **kwargs) for arg in args), tuple(arg.type for arg in args)
)


def _map(
op: itir.Expr | str,
lowered_args: tuple,
original_arg_types: tuple[ts.TypeSpec, ...],
) -> itir.FunCall:
"""
Mapping includes making the operation an `as_fieldop` (first kind of mapping), but also `itir.map_`ing lists.
"""
if all(
isinstance(t, ts.ScalarType)
for arg_type in original_arg_types
for t in type_info.primitive_constituents(arg_type)
):
return im.call(op)(*lowered_args) # scalar operation
if any(type_info.contains_local_field(arg_type) for arg_type in original_arg_types):
lowered_args = tuple(
promote_to_list(arg_type)(larg)
for arg_type, larg in zip(original_arg_types, lowered_args)
)
op = im.call("map_")(op)

return im.op_as_fieldop(im.call(op))(*lowered_args)
return im.op_as_fieldop(im.call(op))(*lowered_args)


class FieldOperatorLoweringError(Exception): ...
50 changes: 35 additions & 15 deletions src/gt4py/next/ffront/foast_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ def adapted_foast_to_itir_factory(**kwargs: Any) -> workflow.Workflow[AOT_FOP, i
return toolchain.StripArgsAdapter(foast_to_itir_factory(**kwargs))


def promote_to_list(node: foast.Symbol | foast.Expr) -> Callable[[itir.Expr], itir.Expr]:
if not type_info.contains_local_field(node.type):
def promote_to_list(node_type: ts.TypeSpec) -> Callable[[itir.Expr], itir.Expr]:
if not type_info.contains_local_field(node_type):
return lambda x: im.promote_to_lifted_stencil("make_const_list")(x)
return lambda x: x

Expand Down Expand Up @@ -267,16 +267,16 @@ def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> itir.Expr:
if node.op in [dialect_ast_enums.UnaryOperator.NOT, dialect_ast_enums.UnaryOperator.INVERT]:
if dtype.kind != ts.ScalarKind.BOOL:
raise NotImplementedError(f"'{node.op}' is only supported on 'bool' arguments.")
return self._map("not_", node.operand)
return self._lower_and_map("not_", node.operand)

return self._map(
return self._lower_and_map(
node.op.value,
foast.Constant(value="0", type=dtype, location=node.location),
node.operand,
)

def visit_BinOp(self, node: foast.BinOp, **kwargs: Any) -> itir.FunCall:
return self._map(node.op.value, node.left, node.right)
return self._lower_and_map(node.op.value, node.left, node.right)

def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs: Any) -> itir.FunCall:
op = "if_"
Expand All @@ -286,15 +286,17 @@ def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs: Any) -> itir.FunC
for arg in args
]
if any(type_info.contains_local_field(arg.type) for arg in args):
lowered_args = [promote_to_list(arg)(larg) for arg, larg in zip(args, lowered_args)]
lowered_args = [
promote_to_list(arg.type)(larg) for arg, larg in zip(args, lowered_args)
]
op = im.call("map_")(op)

return lowering_utils.to_tuples_of_iterator(
im.promote_to_lifted_stencil(im.call(op))(*lowered_args), node.type
)

def visit_Compare(self, node: foast.Compare, **kwargs: Any) -> itir.FunCall:
return self._map(node.op.value, node.left, node.right)
return self._lower_and_map(node.op.value, node.left, node.right)

def _visit_shift(self, node: foast.Call, **kwargs: Any) -> itir.Expr:
current_expr = self.visit(node.func, **kwargs)
Expand Down Expand Up @@ -408,9 +410,12 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> itir.Expr:

lowered_condition = self.visit(condition, **kwargs)
return lowering_utils.process_elements(
lambda tv, fv: im.promote_to_lifted_stencil("if_")(lowered_condition, tv, fv),
lambda tv, fv, types: _map(
"if_", (lowered_condition, tv, fv), (condition.type, *types)
),
[self.visit(true_value, **kwargs), self.visit(false_value, **kwargs)],
node.type,
(node.args[1].type, node.args[2].type),
)

_visit_concat_where = _visit_where
Expand All @@ -419,7 +424,7 @@ def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> itir.FunCall:
return self.visit(node.args[0], **kwargs)

def _visit_math_built_in(self, node: foast.Call, **kwargs: Any) -> itir.FunCall:
return self._map(self.visit(node.func, **kwargs), *node.args)
return self._lower_and_map(self.visit(node.func, **kwargs), *node.args)

def _make_reduction_expr(
self, node: foast.Call, op: str | itir.SymRef, init_expr: itir.Expr, **kwargs: Any
Expand Down Expand Up @@ -480,13 +485,28 @@ def _make_literal(self, val: Any, type_: ts.TypeSpec) -> itir.Expr:
def visit_Constant(self, node: foast.Constant, **kwargs: Any) -> itir.Expr:
return self._make_literal(node.value, node.type)

def _map(self, op: itir.Expr | str, *args: Any, **kwargs: Any) -> itir.FunCall:
lowered_args = [self.visit(arg, **kwargs) for arg in args]
if any(type_info.contains_local_field(arg.type) for arg in args):
lowered_args = [promote_to_list(arg)(larg) for arg, larg in zip(args, lowered_args)]
op = im.call("map_")(op)
def _lower_and_map(self, op: itir.Expr | str, *args: Any, **kwargs: Any) -> itir.FunCall:
return _map(
op, tuple(self.visit(arg, **kwargs) for arg in args), tuple(arg.type for arg in args)
)


def _map(
op: itir.Expr | str,
lowered_args: tuple,
original_arg_types: tuple[ts.TypeSpec, ...],
) -> itir.FunCall:
"""
Mapping includes making the operation an lifted stencil (first kind of mapping), but also `itir.map_`ing lists.
"""
if any(type_info.contains_local_field(arg_type) for arg_type in original_arg_types):
lowered_args = tuple(
promote_to_list(arg_type)(larg)
for arg_type, larg in zip(original_arg_types, lowered_args)
)
op = im.call("map_")(op)

return im.promote_to_lifted_stencil(im.call(op))(*lowered_args)
return im.promote_to_lifted_stencil(im.call(op))(*lowered_args)


class FieldOperatorLoweringError(Exception): ...
24 changes: 12 additions & 12 deletions src/gt4py/next/ffront/lowering_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# SPDX-License-Identifier: BSD-3-Clause

from collections.abc import Iterable
from typing import Any, Callable, TypeVar
from typing import Any, Callable, Optional, TypeVar

from gt4py.eve import utils as eve_utils
from gt4py.next.ffront import type_info as ti_ffront
Expand Down Expand Up @@ -102,7 +102,7 @@ def process_elements(
process_func: Callable[..., itir.Expr],
objs: itir.Expr | Iterable[itir.Expr],
current_el_type: ts.TypeSpec,
with_type: bool = False,
arg_types: Optional[Iterable[ts.TypeSpec]] = None,
) -> itir.FunCall:
"""
Recursively applies a processing function to all primitive constituents of a tuple.
Expand All @@ -113,9 +113,9 @@ def process_elements(
objs: The object whose elements are to be transformed.
current_el_type: A type with the same structure as the elements of `objs`. The leaf-types
are not used and thus not relevant.
current_el_type: A type with the same structure as the elements of `objs`. Unless `with_type=True`
the leaf-types are not used and thus not relevant.
with_type: If True, the last argument passed to `process_func` will be its type.
arg_types: If provided, a tuple of the type of each argument is passed to `process_func` as last argument.
Note, that `arg_types` might coincide with `(current_el_type,)*len(objs)`, but not necessarily,
in case of implicit broadcasts.
"""
if isinstance(objs, itir.Expr):
objs = (objs,)
Expand All @@ -125,7 +125,7 @@ def process_elements(
process_func,
tuple(im.ref(let_id) for let_id in let_ids),
current_el_type,
with_type=with_type,
arg_types=arg_types,
)

return im.let(*(zip(let_ids, objs, strict=True)))(body)
Expand All @@ -138,7 +138,7 @@ def _process_elements_impl(
process_func: Callable[..., itir.Expr],
_current_el_exprs: Iterable[T],
current_el_type: ts.TypeSpec,
with_type: bool,
arg_types: Optional[Iterable[ts.TypeSpec]],
) -> itir.Expr:
if isinstance(current_el_type, ts.TupleType):
result = im.make_tuple(
Expand All @@ -149,16 +149,16 @@ def _process_elements_impl(
im.tuple_get(i, current_el_expr) for current_el_expr in _current_el_exprs
),
current_el_type.types[i],
with_type=with_type,
arg_types=tuple(arg_t.types[i] for arg_t in arg_types) # type: ignore[attr-defined] # guaranteed by the requirement that `current_el_type` and each element of `arg_types` have the same tuple structure
if arg_types is not None
else None,
)
for i in range(len(current_el_type.types))
)
)
elif type_info.contains_local_field(current_el_type):
raise NotImplementedError("Processing fields with local dimension is not implemented.")
else:
if with_type:
result = process_func(*_current_el_exprs, current_el_type)
if arg_types is not None:
result = process_func(*_current_el_exprs, arg_types)
else:
result = process_func(*_current_el_exprs)

Expand Down
1 change: 1 addition & 0 deletions tests/next_tests/integration_tests/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
IJKField: TypeAlias = gtx.Field[[IDim, JDim, KDim], np.int32] # type: ignore [valid-type]
IJKFloatField: TypeAlias = gtx.Field[[IDim, JDim, KDim], np.float64] # type: ignore [valid-type]
VField: TypeAlias = gtx.Field[[Vertex], np.int32] # type: ignore [valid-type]
VBoolField: TypeAlias = gtx.Field[[Vertex], bool] # type: ignore [valid-type]
EField: TypeAlias = gtx.Field[[Edge], np.int32] # type: ignore [valid-type]
CField: TypeAlias = gtx.Field[[Cell], np.int32] # type: ignore [valid-type]
EmptyField: TypeAlias = gtx.Field[[], np.int32] # type: ignore [valid-type]
Expand Down
Loading
Loading