diff --git a/src/gt4py/next/ffront/experimental.py b/src/gt4py/next/ffront/experimental.py index bd22aebe57..c9bea908a8 100644 --- a/src/gt4py/next/ffront/experimental.py +++ b/src/gt4py/next/ffront/experimental.py @@ -20,7 +20,7 @@ def as_offset(offset_: FieldOffset, field: common.Field, /) -> common.Connectivi @WhereBuiltinFunction def concat_where( - mask: common.Field, + mask: common.Domain, true_field: common.Field | core_defs.ScalarT | Tuple, false_field: common.Field | core_defs.ScalarT | Tuple, /, diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index ee14006b22..b50aaadc1e 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -66,6 +66,8 @@ def _type_conversion_helper(t: type) -> type[ts.TypeSpec] | tuple[type[ts.TypeSp return ts.OffsetType elif t is core_defs.ScalarT: return ts.ScalarType + elif t is common.Domain: + return ts.DomainType elif t is type: return ( ts.FunctionType @@ -135,7 +137,7 @@ def __gt_type__(self) -> ts.FunctionType: ) -MaskT = TypeVar("MaskT", bound=common.Field) +MaskT = TypeVar("MaskT", bound=Union[common.Field, common.Domain]) FieldT = TypeVar("FieldT", bound=Union[common.Field, core_defs.Scalar, Tuple]) diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 26bcadaef1..dc8d36af5e 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -11,7 +11,7 @@ import gt4py.next.ffront.field_operator_ast as foast from gt4py.eve import NodeTranslator, NodeVisitor, traits from gt4py.next import errors -from gt4py.next.common import DimensionKind +from gt4py.next.common import DimensionKind, promote_dims from gt4py.next.ffront import ( # noqa dialect_ast_enums, experimental, @@ -20,6 +20,7 @@ type_specifications as ts_ffront, ) from gt4py.next.ffront.foast_passes.utils import compute_assign_indices +from gt4py.next.iterator import builtins from gt4py.next.type_system import type_info, type_specifications as ts, type_translation @@ -570,6 +571,36 @@ def _deduce_compare_type( self, node: foast.Compare, *, left: foast.Expr, right: foast.Expr, **kwargs: Any ) -> Optional[ts.TypeSpec]: # check both types compatible + left_t, right_t = left.type, right.type + integer_kind = getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper()) + if ( + isinstance(left_t, ts.DimensionType) + and isinstance(right_t, ts.ScalarType) + and right_t.kind == integer_kind + ): + return ts.DomainType(dims=[left_t.dim]) + if ( + isinstance(right_t, ts.DimensionType) + and isinstance(left_t, ts.ScalarType) + and left_t.kind == integer_kind + ): + return ts.DomainType(dims=[right_t.dim]) + if ( + isinstance(left_t, ts.OffsetType) + and left.op == dialect_ast_enums.BinaryOperator.MOD + and isinstance(right_t, ts.ScalarType) + and right_t.kind == integer_kind + ) or ( + isinstance(right_t, ts.OffsetType) + and right.op == dialect_ast_enums.BinaryOperator.MOD + and isinstance(left_t, ts.ScalarType) + and left_t.kind == integer_kind + ): + raise errors.DSLError( + left.location, "Type 'ts.OffsetType' can not be used in operator 'mod'." + ) + + # TODO for arg in (left, right): if not type_info.is_arithmetic(arg.type): raise errors.DSLError( @@ -582,13 +613,13 @@ def _deduce_compare_type( # transform operands to have bool dtype and use regular promotion # mechanism to handle dimension promotion return type_info.promote( - with_altered_scalar_kind(left.type, ts.ScalarKind.BOOL), - with_altered_scalar_kind(right.type, ts.ScalarKind.BOOL), + with_altered_scalar_kind(left_t, ts.ScalarKind.BOOL), + with_altered_scalar_kind(right_t, ts.ScalarKind.BOOL), ) except ValueError as ex: raise errors.DSLError( node.location, - f"Could not promote '{left.type}' and '{right.type}' to common type" + f"Could not promote '{left_t}' and '{right_t}' to common type" f" in call to '{node.op}'.", ) from ex @@ -612,7 +643,11 @@ def _deduce_binop_type( dialect_ast_enums.BinaryOperator.BIT_OR, dialect_ast_enums.BinaryOperator.BIT_XOR, } - is_compatible = type_info.is_logical if node.op in logical_ops else type_info.is_arithmetic + + def is_logical_or_domain(arg: ts.TypeSpec) -> bool: + return type_info.is_logical(arg) or isinstance(arg, ts.DomainType) + + is_compatible = is_logical_or_domain if node.op in logical_ops else type_info.is_arithmetic # check both types compatible for arg in (left, right): @@ -620,29 +655,32 @@ def _deduce_binop_type( raise errors.DSLError( arg.location, f"Type '{arg.type}' can not be used in operator '{node.op}'." ) - - left_type = cast(ts.FieldType | ts.ScalarType, left.type) - right_type = cast(ts.FieldType | ts.ScalarType, right.type) - - if node.op == dialect_ast_enums.BinaryOperator.POW: - return left_type - - if node.op == dialect_ast_enums.BinaryOperator.MOD and not type_info.is_integral( - right_type + if isinstance(left.type, (ts.ScalarType, ts.FieldType)) and isinstance( + right.type, (ts.ScalarType, ts.FieldType) ): - raise errors.DSLError( - arg.location, - f"Type '{right_type}' can not be used in operator '{node.op}', it only accepts 'int'.", - ) + if node.op == dialect_ast_enums.BinaryOperator.POW: + return left.type - try: - return type_info.promote(left_type, right_type) - except ValueError as ex: - raise errors.DSLError( - node.location, - f"Could not promote '{left_type}' and '{right_type}' to common type" - f" in call to '{node.op}'.", - ) from ex + if node.op == dialect_ast_enums.BinaryOperator.MOD and not type_info.is_integral( + right.type + ): + raise errors.DSLError( + arg.location, + f"Type '{right.type}' can not be used in operator '{node.op}', it only accepts 'int'.", + ) + + try: + return type_info.promote(left.type, right.type) + except ValueError as ex: + raise errors.DSLError( + node.location, + f"Could not promote '{left.type}' and '{right.type}' to common type" + f" in call to '{node.op}'.", + ) from ex + elif isinstance(left.type, ts.DomainType) and isinstance(right.type, ts.DomainType): + return ts.DomainType(dims=promote_dims(left.type.dims, right.type.dims)) + else: + raise ValueError("TODO") def _check_operand_dtypes_match( self, node: foast.BinOp | foast.Compare, left: foast.Expr, right: foast.Expr @@ -908,6 +946,7 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> foast.Call: ) try: + # TODO(tehrengruber): the construct_tuple_type function doesn't look correct if isinstance(true_branch_type, ts.TupleType) and isinstance( false_branch_type, ts.TupleType ): @@ -943,7 +982,21 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> foast.Call: location=node.location, ) - _visit_concat_where = _visit_where + def _visit_concat_where(self, node: foast.Call, **kwargs: Any) -> foast.Call: + true_branch_type = node.args[1].type + false_branch_type = node.args[2].type + true_branch_fieldtype = cast(ts.FieldType, true_branch_type) + false_branch_fieldtype = cast(ts.FieldType, false_branch_type) + promoted_type = type_info.promote(true_branch_fieldtype, false_branch_fieldtype) + return_type = promoted_type + + return foast.Call( + func=node.func, + args=node.args, + kwargs=node.kwargs, + type=return_type, + location=node.location, + ) def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> foast.Call: arg_type = cast(ts.FieldType | ts.ScalarType, node.args[0].type) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index f884ec555d..dd936d7995 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -224,7 +224,9 @@ def visit_Assign( def visit_Symbol(self, node: foast.Symbol, **kwargs: Any) -> itir.Sym: return im.sym(node.id) - def visit_Name(self, node: foast.Name, **kwargs: Any) -> itir.SymRef: + def visit_Name(self, node: foast.Name, **kwargs: Any) -> itir.SymRef | itir.AxisLiteral: + if isinstance(node.type, ts.DimensionType): + return itir.AxisLiteral(value=node.type.dim.value, kind=node.type.dim.kind) return im.ref(node.id) def visit_Subscript(self, node: foast.Subscript, **kwargs: Any) -> itir.Expr: @@ -249,7 +251,28 @@ def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> itir.Expr: raise NotImplementedError(f"Unary operator '{node.op}' is not supported.") def visit_BinOp(self, node: foast.BinOp, **kwargs: Any) -> itir.FunCall: - return self._lower_and_map(node.op.value, node.left, node.right) + if ( + node.op == dialect_ast_enums.BinaryOperator.BIT_AND + and isinstance(node.left.type, ts.DomainType) + and isinstance(node.right.type, ts.DomainType) + ): + return im.and_(self.visit(node.left), self.visit(node.right)) + if ( + node.op == dialect_ast_enums.BinaryOperator.BIT_OR + and isinstance(node.left.type, ts.DomainType) + and isinstance(node.right.type, ts.DomainType) + ): + return im.or_(self.visit(node.left), self.visit(node.right)) + if ( + node.op == dialect_ast_enums.BinaryOperator.BIT_XOR + and isinstance(node.left.type, ts.DomainType) + and isinstance(node.right.type, ts.DomainType) + ): + raise NotImplementedError( + f"Binary operator '{node.op}' is not supported for '{node.right.type}' and '{node.right.type}'." + ) + else: + return self._lower_and_map(node.op.value, node.left, node.right) def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs: Any) -> itir.FunCall: assert ( @@ -261,6 +284,7 @@ def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs: Any) -> itir.FunC ) def visit_Compare(self, node: foast.Compare, **kwargs: Any) -> itir.FunCall: + # TODO: double-check if we need the changes in the original PR return self._lower_and_map(node.op.value, node.left, node.right) def _visit_shift(self, node: foast.Call, **kwargs: Any) -> itir.Expr: @@ -394,7 +418,13 @@ def create_if( return im.let(cond_symref_name, cond_)(result) - _visit_concat_where = _visit_where # TODO(havogt): upgrade concat_where + def _visit_concat_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: + if not isinstance(node.type, ts.TupleType): # to keep the IR simpler + return im.call("concat_where")(*self.visit(node.args)) + else: + raise NotImplementedError() + + # TODO: tuple case def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: expr = self.visit(node.args[0], **kwargs) @@ -476,8 +506,9 @@ def _map( """ Mapping includes making the operation an `as_fieldop` (first kind of mapping), but also `itir.map_`ing lists. """ + # TODO double-check that this code is consistent with the changes in the original PR if all( - isinstance(t, ts.ScalarType) + isinstance(t, (ts.ScalarType, ts.DimensionType)) for arg_type in original_arg_types for t in type_info.primitive_constituents(arg_type) ): diff --git a/src/gt4py/next/iterator/builtins.py b/src/gt4py/next/iterator/builtins.py index 8e5f7addca..4ebc9a388c 100644 --- a/src/gt4py/next/iterator/builtins.py +++ b/src/gt4py/next/iterator/builtins.py @@ -488,6 +488,8 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing] "scan", "tuple_get", "unstructured_domain", + "concat_where", + "in", *ARITHMETIC_BUILTINS, *TYPE_BUILTINS, } diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index ea5cf84d86..7ccd86faab 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -63,6 +63,14 @@ class NoneLiteral(Expr): _none_literal: int = 0 +class InfinityLiteral(Expr): + pass + + +class NegInfinityLiteral(Expr): + pass + + class OffsetLiteral(Expr): value: Union[int, str] @@ -142,3 +150,5 @@ class Program(Node, ValidatedSymbolTableTrait): Program.__hash__ = Node.__hash__ # type: ignore[method-assign] SetAt.__hash__ = Node.__hash__ # type: ignore[method-assign] IfStmt.__hash__ = Node.__hash__ # type: ignore[method-assign] +InfinityLiteral.__hash__ = Node.__hash__ # type: ignore[method-assign] +NegInfinityLiteral.__hash__ = Node.__hash__ # type: ignore[method-assign] diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index 27900b6db6..e3ab788033 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -174,3 +174,49 @@ def domain_union(*domains: SymbolicDomain) -> SymbolicDomain: new_domain_ranges[dim] = SymbolicRange(start, stop) return SymbolicDomain(domains[0].grid_type, new_domain_ranges) + + +def domain_intersection(*domains: SymbolicDomain) -> SymbolicDomain: + """Return the (set) intersection of a list of domains.""" + new_domain_ranges = {} + assert all(domain.grid_type == domains[0].grid_type for domain in domains) + for dim in domains[0].ranges.keys(): + start = functools.reduce( + lambda current_expr, el_expr: im.call("maximum")(current_expr, el_expr), + [domain.ranges[dim].start for domain in domains], + ) + stop = functools.reduce( + lambda current_expr, el_expr: im.call("minimum")(current_expr, el_expr), + [domain.ranges[dim].stop for domain in domains], + ) + new_domain_ranges[dim] = SymbolicRange(start, stop) + + return SymbolicDomain(domains[0].grid_type, new_domain_ranges) + + +def domain_complement(domain: SymbolicDomain) -> SymbolicDomain: + """Return the (set) complement of a domain.""" + dims_dict = {} + for dim in domain.ranges.keys(): + lb, ub = domain.ranges[dim].start, domain.ranges[dim].stop + if isinstance(lb, itir.NegInfinityLiteral): + dims_dict[dim] = SymbolicRange(start=ub, stop=itir.InfinityLiteral()) + elif isinstance(ub, itir.InfinityLiteral): + dims_dict[dim] = SymbolicRange(start=itir.NegInfinityLiteral(), stop=lb) + else: + raise ValueError("Invalid domain ranges") + return SymbolicDomain(domain.grid_type, dims_dict) + + +def promote_to_same_dimensions( + domain_small: SymbolicDomain, domain_large: SymbolicDomain +) -> SymbolicDomain: + """Return an extended domain based on a smaller input domain and a larger domain containing the target dimensions.""" + dims_dict = {} + for dim in domain_large.ranges.keys(): + if dim in domain_small.ranges.keys(): + lb, ub = domain_small.ranges[dim].start, domain_small.ranges[dim].stop + dims_dict[dim] = SymbolicRange(lb, ub) + else: + dims_dict[dim] = SymbolicRange(itir.NegInfinityLiteral(), itir.InfinityLiteral()) + return SymbolicDomain(domain_small.grid_type, dims_dict) # TODO: fix for unstructured diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 24842ad3be..85ee416aee 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -234,6 +234,11 @@ def if_(cond, true_val, false_val): return call("if_")(cond, true_val, false_val) +def concat_where(cond, true_field, false_field): + """Create a concat_where FunCall, shorthand for ``call("concat_where")(expr)``.""" + return call("concat_where")(cond, true_field, false_field) + + def lift(expr): """Create a lift FunCall, shorthand for ``call(call("lift")(expr))``.""" return call(call("lift")(expr)) diff --git a/src/gt4py/next/iterator/pretty_printer.py b/src/gt4py/next/iterator/pretty_printer.py index 7acbf5d23d..1d97878257 100644 --- a/src/gt4py/next/iterator/pretty_printer.py +++ b/src/gt4py/next/iterator/pretty_printer.py @@ -133,6 +133,12 @@ def visit_Sym(self, node: ir.Sym, *, prec: int) -> list[str]: def visit_Literal(self, node: ir.Literal, *, prec: int) -> list[str]: return [str(node.value)] + def visit_InfinityLiteral(self, node: ir.Literal, *, prec: int) -> list[str]: + return ["INF"] + + def visit_NegInfinityLiteral(self, node: ir.Literal, *, prec: int) -> list[str]: + return ["-INF"] + def visit_OffsetLiteral(self, node: ir.OffsetLiteral, *, prec: int) -> list[str]: return [str(node.value) + "ₒ"] diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index 7215d0787a..b3980e70ed 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -8,7 +8,7 @@ from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next.iterator import builtins, embedded, ir -from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im class ConstantFolding(PreserveLocationVisitor, NodeTranslator): @@ -16,17 +16,59 @@ class ConstantFolding(PreserveLocationVisitor, NodeTranslator): def apply(cls, node: ir.Node) -> ir.Node: return cls().visit(node) - def visit_FunCall(self, node: ir.FunCall): + def visit_FunCall(self, node: ir.FunCall) -> ir.Node: # visit depth-first such that nested constant expressions (e.g. `(1+2)+3`) are properly folded new_node = self.generic_visit(node) if ( - isinstance(new_node.fun, ir.SymRef) - and new_node.fun.id in ["minimum", "maximum"] + cpm.is_call_to(new_node, ("minimum", "maximum")) and new_node.args[0] == new_node.args[1] ): # `minimum(a, a)` -> `a` return new_node.args[0] + if cpm.is_call_to(new_node, "minimum"): + # `minimum(neg_inf, neg_inf)` -> `neg_inf` + if isinstance(new_node.args[0], ir.NegInfinityLiteral) or isinstance( + new_node.args[1], ir.NegInfinityLiteral + ): + return ir.NegInfinityLiteral() + # `minimum(inf, a)` -> `a` + elif isinstance(new_node.args[0], ir.InfinityLiteral): + return new_node.args[1] + # `minimum(a, inf)` -> `a` + elif isinstance(new_node.args[1], ir.InfinityLiteral): + return new_node.args[0] + + if cpm.is_call_to(new_node, "maximum"): + # `minimum(inf, inf)` -> `inf` + if isinstance(new_node.args[0], ir.InfinityLiteral) or isinstance( + new_node.args[1], ir.InfinityLiteral + ): + return ir.InfinityLiteral() + # `minimum(neg_inf, a)` -> `a` + elif isinstance(new_node.args[0], ir.NegInfinityLiteral): + return new_node.args[1] + # `minimum(a, neg_inf)` -> `a` + elif isinstance(new_node.args[1], ir.NegInfinityLiteral): + return new_node.args[0] + if cpm.is_call_to(new_node, ("less", "less_equal")): + if isinstance(new_node.args[0], ir.NegInfinityLiteral) or isinstance( + new_node.args[1], ir.InfinityLiteral + ): + return im.literal_from_value(True) + if isinstance(new_node.args[0], ir.InfinityLiteral) or isinstance( + new_node.args[1], ir.NegInfinityLiteral + ): + return im.literal_from_value(False) + if cpm.is_call_to(new_node, ("greater", "greater_equal")): + if isinstance(new_node.args[0], ir.NegInfinityLiteral) or isinstance( + new_node.args[1], ir.InfinityLiteral + ): + return im.literal_from_value(False) + if isinstance(new_node.args[0], ir.InfinityLiteral) or isinstance( + new_node.args[1], ir.NegInfinityLiteral + ): + return im.literal_from_value(True) if ( isinstance(new_node.fun, ir.SymRef) and new_node.fun.id == "if_" @@ -52,6 +94,6 @@ def visit_FunCall(self, node: ir.FunCall): ] new_node = im.literal_from_value(fun(*arg_values)) except ValueError: - pass # happens for inf and neginf + pass # happens for SymRefs which are not inf or neg_inf return new_node diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index ccaaf563f5..955f428fc4 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -86,7 +86,21 @@ def _is_collectable_expr(node: itir.Node) -> bool: # conceptual problems (other parts of the tool chain rely on the arguments being present directly # on the reduce FunCall node (connectivity deduction)), as well as problems with the imperative backend # backend (single pass eager depth first visit approach) - if isinstance(node.fun, itir.SymRef) and node.fun.id in ["lift", "shift", "reduce", "map_"]: + # do also not collect index nodes because otherwise the right hand side of SetAts becomes a let statement + # instead of an as_fieldop + if isinstance(node.fun, itir.SymRef) and node.fun.id in [ + "lift", + "shift", + "reduce", + "map_", + "index", + ]: + return False + # do also not collect make_tuple(index) nodes because otherwise the right hand side of SetAts becomes a let statement + # instead of an as_fieldop + if cpm.is_call_to(node, "make_tuple") and all( + cpm.is_call_to(arg, "index") for arg in node.args + ): return False return True elif isinstance(node, itir.Lambda): diff --git a/src/gt4py/next/iterator/transforms/expand_library_functions.py b/src/gt4py/next/iterator/transforms/expand_library_functions.py new file mode 100644 index 0000000000..9fab9e053f --- /dev/null +++ b/src/gt4py/next/iterator/transforms/expand_library_functions.py @@ -0,0 +1,39 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from functools import reduce + +from gt4py.eve import NodeTranslator, PreserveLocationVisitor +from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import ( + common_pattern_matcher as cpm, + domain_utils, + ir_makers as im, +) + + +class ExpandLibraryFunctions(PreserveLocationVisitor, NodeTranslator): + @classmethod + def apply(cls, node: ir.Node): + return cls().visit(node) + + def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: + if cpm.is_call_to(node, "in"): + ret = [] + pos, domain = node.args + for i, (_, v) in enumerate( + domain_utils.SymbolicDomain.from_expr(node.args[1]).ranges.items() + ): + ret.append( + im.and_( + im.less_equal(v.start, im.tuple_get(i, pos)), + im.less(im.tuple_get(i, pos), v.stop), + ) + ) # TODO: avoid pos duplication + return reduce(im.and_, ret) + return self.generic_visit(node) diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index f3c3185225..618fa15699 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -363,6 +363,36 @@ def _infer_if( return result_expr, actual_domains +def _infer_concat_where( + expr: itir.Expr, + domain: DomainAccess, + **kwargs: Unpack[InferenceOptions], +) -> tuple[itir.Expr, AccessedDomains]: + assert cpm.is_call_to(expr, "concat_where") + assert isinstance(domain, domain_utils.SymbolicDomain) + infered_args_expr = [] + actual_domains: AccessedDomains = {} + cond, true_field, false_field = expr.args + symbolic_cond = domain_utils.SymbolicDomain.from_expr(cond) + for arg in [true_field, false_field]: + if arg == true_field: + extended_cond = domain_utils.promote_to_same_dimensions(symbolic_cond, domain) + domain_ = domain_utils.domain_intersection(domain, extended_cond) + elif arg == false_field: + cond_complement = domain_utils.domain_complement(symbolic_cond) + extended_cond_complement = domain_utils.promote_to_same_dimensions( + cond_complement, domain + ) + domain_ = domain_utils.domain_intersection(domain, extended_cond_complement) + + infered_arg_expr, actual_domains_arg = infer_expr(arg, domain_, **kwargs) + infered_args_expr.append(infered_arg_expr) + actual_domains = _merge_domains(actual_domains, actual_domains_arg) + + result_expr = im.call(expr.fun)(cond, *infered_args_expr) + return result_expr, actual_domains + + def _infer_expr( expr: itir.Expr, domain: DomainAccess, @@ -382,6 +412,8 @@ def _infer_expr( return _infer_tuple_get(expr, domain, **kwargs) elif cpm.is_call_to(expr, "if_"): return _infer_if(expr, domain, **kwargs) + elif cpm.is_call_to(expr, "concat_where"): + return _infer_concat_where(expr, domain, **kwargs) elif ( cpm.is_call_to(expr, builtins.ARITHMETIC_BUILTINS) or cpm.is_call_to(expr, builtins.TYPE_BUILTINS) diff --git a/src/gt4py/next/iterator/transforms/infer_domain_ops.py b/src/gt4py/next/iterator/transforms/infer_domain_ops.py new file mode 100644 index 0000000000..a5da214ae3 --- /dev/null +++ b/src/gt4py/next/iterator/transforms/infer_domain_ops.py @@ -0,0 +1,111 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + + +from gt4py.eve import NodeTranslator, PreserveLocationVisitor +from gt4py.next import common +from gt4py.next.iterator import builtins, ir +from gt4py.next.iterator.ir_utils import ( + common_pattern_matcher as cpm, + domain_utils, + ir_makers as im, +) +from gt4py.next.iterator.transforms.constant_folding import ConstantFolding + + +class InferDomainOps(PreserveLocationVisitor, NodeTranslator): + @classmethod + def apply(cls, node: ir.Node): + return cls().visit(node) + + def visit_FunCall(self, node: ir.FunCall) -> ir.Node: + node = self.generic_visit(node) + if ( + cpm.is_call_to(node, builtins.BINARY_MATH_COMPARISON_BUILTINS) + and any(isinstance(arg, ir.AxisLiteral) for arg in node.args) + and any(isinstance(arg, ir.Literal) for arg in node.args) + ): # TODO: add tests + arg1, arg2 = node.args + fun = node.fun + if isinstance(arg1, ir.AxisLiteral) and isinstance(arg2, ir.Literal): + dim = common.Dimension(value=arg1.value, kind=arg1.kind) + value = int(arg2.value) + reverse = False + elif isinstance(arg1, ir.Literal) and isinstance(arg2, ir.AxisLiteral): + dim = common.Dimension(value=arg2.value, kind=arg2.kind) + value = int(arg1.value) + reverse = True + else: + raise ValueError(f"{node.args} need to be a 'ir.AxisLiteral' and an 'ir.Literal'.") + assert isinstance(fun, ir.SymRef) + min_: int | ir.NegInfinityLiteral + max_: int | ir.InfinityLiteral + match fun.id: + case ir.SymbolRef("less"): + if reverse: + min_ = value + 1 + max_ = ir.InfinityLiteral() + else: + min_ = ir.NegInfinityLiteral() + max_ = value - 1 + case ir.SymbolRef("less_equal"): + if reverse: + min_ = value + max_ = ir.InfinityLiteral() + else: + min_ = ir.NegInfinityLiteral() + max_ = value + case ir.SymbolRef("greater"): + if reverse: + min_ = ir.NegInfinityLiteral() + max_ = value - 1 + else: + min_ = value + 1 + max_ = ir.InfinityLiteral() + case ir.SymbolRef("greater_equal"): + if reverse: + min_ = ir.NegInfinityLiteral() + max_ = value + else: + min_ = value + max_ = ir.InfinityLiteral() + case ir.SymbolRef("eq"): + min_ = max_ = value + case ir.SymbolRef("not_eq"): + min1 = ir.NegInfinityLiteral() + max1 = value - 1 + min2 = value + 1 + max2 = ir.InfinityLiteral() + return im.call("and_")( + im.domain(common.GridType.CARTESIAN, {dim: (min1, max1)}), + im.domain(common.GridType.CARTESIAN, {dim: (min2, max2)}), + ) + case _: + raise NotImplementedError + + return im.domain( + common.GridType.CARTESIAN, + {dim: (min_, max_ + 1)} + if not isinstance(max_, ir.InfinityLiteral) + else {dim: (min_, max_)}, + ) + if cpm.is_call_to(node, builtins.BINARY_LOGICAL_BUILTINS) and all( + isinstance(arg, (ir.Literal, ir.FunCall)) for arg in node.args + ): + if cpm.is_call_to(node, "and_"): + # TODO: domain promotion + return ConstantFolding.apply( + domain_utils.domain_intersection( + *[domain_utils.SymbolicDomain.from_expr(arg) for arg in node.args] + ).as_expr() + ) + + else: + raise NotImplementedError + + return self.generic_visit(node) diff --git a/src/gt4py/next/iterator/transforms/nest_concat_wheres.py b/src/gt4py/next/iterator/transforms/nest_concat_wheres.py new file mode 100644 index 0000000000..258494e0c4 --- /dev/null +++ b/src/gt4py/next/iterator/transforms/nest_concat_wheres.py @@ -0,0 +1,38 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from gt4py.eve import NodeTranslator, PreserveLocationVisitor +from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im + + +class NestConcatWheres(PreserveLocationVisitor, NodeTranslator): + @classmethod + def apply(cls, node: ir.Node): + return cls().visit(node) + + def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: + node = self.generic_visit(node) + if cpm.is_call_to(node, "concat_where"): + cond_expr, field_a, field_b = node.args + if cpm.is_call_to(cond_expr, ("and_")): + conds = cond_expr.args + return im.concat_where( + conds[0], im.concat_where(conds[1], field_a, field_b), field_b + ) + if cpm.is_call_to(cond_expr, ("or_")): + conds = cond_expr.args + return im.concat_where( + conds[0], field_a, im.concat_where(conds[1], field_a, field_b) + ) + if cpm.is_call_to(cond_expr, ("eq")): + cond1 = im.less(cond_expr.args[0], cond_expr.args[1]) + cond2 = im.greater(cond_expr.args[0], cond_expr.args[1]) + return im.concat_where(cond1, field_b, im.concat_where(cond2, field_b, field_a)) + + return self.generic_visit(node) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 0a79848443..43a3e98f47 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -12,12 +12,16 @@ from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.transforms import ( + expand_library_functions, fuse_as_fieldop, global_tmps, infer_domain, + infer_domain_ops, inline_dynamic_shifts, inline_fundefs, inline_lifts, + nest_concat_wheres, + transform_concat_where, ) from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple @@ -83,11 +87,16 @@ def apply_common_transforms( ir = inline_dynamic_shifts.InlineDynamicShifts.apply( ir ) # domain inference does not support dynamic offsets yet + ir = nest_concat_wheres.NestConcatWheres.apply(ir) + ir = infer_domain_ops.InferDomainOps.apply(ir) + ir = infer_domain.infer_program( ir, offset_provider=offset_provider, symbolic_domain_sizes=symbolic_domain_sizes, ) + ir = transform_concat_where.TransformConcatWhere.apply(ir) + ir = expand_library_functions.ExpandLibraryFunctions.apply(ir) for _ in range(10): inlined = ir diff --git a/src/gt4py/next/iterator/transforms/transform_concat_where.py b/src/gt4py/next/iterator/transforms/transform_concat_where.py new file mode 100644 index 0000000000..a33cfcab5a --- /dev/null +++ b/src/gt4py/next/iterator/transforms/transform_concat_where.py @@ -0,0 +1,38 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from gt4py.eve import NodeTranslator, PreserveLocationVisitor +from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import ( + common_pattern_matcher as cpm, + domain_utils, + ir_makers as im, +) + + +class TransformConcatWhere(PreserveLocationVisitor, NodeTranslator): + PRESERVED_ANNEX_ATTRS = ("domain",) + + @classmethod + def apply(cls, node: ir.Node): + return cls().visit(node) + + def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: + node = self.generic_visit(node) + if cpm.is_call_to(node, "concat_where"): + cond_expr, field_a, field_b = node.args + cond = domain_utils.SymbolicDomain.from_expr(cond_expr).ranges.keys() + dims = [im.call("index")(ir.AxisLiteral(value=k.value, kind=k.kind)) for k in cond] + return im.as_fieldop( + im.lambda_("pos", "a", "b")( + im.if_(im.call("in")(im.deref("pos"), cond_expr), im.deref("a"), im.deref("b")) + ), + node.annex.domain.as_expr(), + )(im.make_tuple(*dims), field_a, field_b) + + return self.generic_visit(node) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 901cb103da..c2f25e2e89 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -509,7 +509,7 @@ def visit_Program(self, node: itir.Program, *, ctx) -> it_ts.ProgramType: def visit_Temporary(self, node: itir.Temporary, *, ctx) -> ts.FieldType | ts.TupleType: domain = self.visit(node.domain, ctx=ctx) - assert isinstance(domain, it_ts.DomainType) + assert isinstance(domain, ts.DomainType) assert domain.dims != "unknown" assert node.dtype return type_info.apply_to_primitive_constituents( @@ -579,6 +579,12 @@ def visit_Literal(self, node: itir.Literal, **kwargs) -> ts.ScalarType: assert isinstance(node.type, ts.ScalarType) return node.type + def visit_InfinityLiteral(self, node: itir.InfinityLiteral, **kwargs) -> ts.ScalarType: + return ts.ScalarType(kind=ts.ScalarKind.INT32) + + def visit_NegInfinityLiteral(self, node: itir.InfinityLiteral, **kwargs) -> ts.ScalarType: + return ts.ScalarType(kind=ts.ScalarKind.INT32) + def visit_SymRef( self, node: itir.SymRef, *, ctx: dict[str, ts.TypeSpec] ) -> ts.TypeSpec | type_synthesizer.TypeSynthesizer: diff --git a/src/gt4py/next/iterator/type_system/type_specifications.py b/src/gt4py/next/iterator/type_system/type_specifications.py index 7825bf1c98..30c79c7c94 100644 --- a/src/gt4py/next/iterator/type_system/type_specifications.py +++ b/src/gt4py/next/iterator/type_system/type_specifications.py @@ -16,10 +16,6 @@ class NamedRangeType(ts.TypeSpec): dim: common.Dimension -class DomainType(ts.DataType): - dims: list[common.Dimension] | Literal["unknown"] - - class OffsetLiteralType(ts.TypeSpec): value: ts.ScalarType | common.Dimension diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index f5aeac7943..c6f31d0a51 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -109,13 +109,30 @@ def _(arg: ts.ScalarType) -> ts.ScalarType: return ts.ScalarType(kind=ts.ScalarKind.BOOL) -@_register_builtin_type_synthesizer( - fun_names=builtins.BINARY_MATH_COMPARISON_BUILTINS | builtins.BINARY_LOGICAL_BUILTINS -) -def _(lhs: ts.ScalarType, rhs: ts.ScalarType) -> ts.ScalarType | ts.TupleType: +def synthesize_binary_math_comparison_builtins( + lhs, rhs +) -> ts.ScalarType | ts.TupleType | ts.DomainType: + if isinstance(lhs, ts.ScalarType) and isinstance(rhs, ts.DimensionType): + return ts.DomainType(dims=[rhs.dim]) + if isinstance(lhs, ts.DimensionType) and isinstance(rhs, ts.ScalarType): + return ts.DomainType(dims=[lhs.dim]) + assert isinstance(lhs, ts.ScalarType) and isinstance(rhs, ts.ScalarType) return ts.ScalarType(kind=ts.ScalarKind.BOOL) +@_register_builtin_type_synthesizer(fun_names=builtins.BINARY_MATH_COMPARISON_BUILTINS) +def _(lhs, rhs) -> ts.ScalarType | ts.TupleType | ts.DomainType: + return synthesize_binary_math_comparison_builtins(lhs, rhs) + + +@_register_builtin_type_synthesizer(fun_names=builtins.BINARY_LOGICAL_BUILTINS) +def _(lhs, rhs) -> ts.ScalarType | ts.TupleType | ts.DomainType: + if isinstance(lhs, ts.DomainType) and isinstance(rhs, ts.DomainType): + return ts.DomainType(dims=common.promote_dims(lhs.dims, rhs.dims)) + else: + return synthesize_binary_math_comparison_builtins(lhs, rhs) + + @_register_builtin_type_synthesizer def deref(it: it_ts.IteratorType | ts.DeferredType) -> ts.DataType | ts.DeferredType: if isinstance(it, ts.DeferredType): @@ -183,9 +200,9 @@ def named_range( @_register_builtin_type_synthesizer(fun_names=["cartesian_domain", "unstructured_domain"]) -def _(*args: it_ts.NamedRangeType) -> it_ts.DomainType: +def _(*args: it_ts.NamedRangeType) -> ts.DomainType: assert all(isinstance(arg, it_ts.NamedRangeType) for arg in args) - return it_ts.DomainType(dims=[arg.dim for arg in args]) + return ts.DomainType(dims=[arg.dim for arg in args]) @_register_builtin_type_synthesizer @@ -201,6 +218,15 @@ def index(arg: ts.DimensionType) -> ts.FieldType: ) +@_register_builtin_type_synthesizer +def concat_where( + domain: ts.DomainType, + true_field: ts.FieldType | ts.TupleType, + false_field: ts.FieldType | ts.TupleType, +) -> ts.FieldType | ts.TupleType: + return type_info.promote(true_field, false_field) + + @_register_builtin_type_synthesizer def neighbors(offset_literal: it_ts.OffsetLiteralType, it: it_ts.IteratorType) -> ts.ListType: assert ( @@ -244,7 +270,7 @@ def apply_lift( def _convert_as_fieldop_input_to_iterator( - domain: it_ts.DomainType, input_: ts.TypeSpec + domain: ts.DomainType, input_: ts.TypeSpec ) -> it_ts.IteratorType: # get the dimensions of all non-zero-dimensional field inputs and check they agree all_input_dims = ( @@ -284,7 +310,7 @@ def _convert_as_fieldop_input_to_iterator( @_register_builtin_type_synthesizer def as_fieldop( stencil: TypeSynthesizer, - domain: Optional[it_ts.DomainType] = None, + domain: Optional[ts.DomainType] = None, *, offset_provider_type: common.OffsetProviderType, ) -> TypeSynthesizer: @@ -299,7 +325,7 @@ def as_fieldop( # `as_fieldop(it1, it2 -> deref(it1) + deref(it2))(i_field, j_field)` # it is unclear if the result has dimension I, J or J, I. if domain is None: - domain = it_ts.DomainType(dims="unknown") + domain = ts.DomainType(dims="unknown") @TypeSynthesizer def applied_as_fieldop(*fields) -> ts.FieldType | ts.DeferredType: diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py index 831694791a..6ca9bde77f 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py @@ -124,7 +124,11 @@ def _values_validator( ) -> None: if not all( isinstance(el, (SidFromScalar, SidComposite)) - or _is_tuple_expr_of(lambda expr: isinstance(expr, (SymRef, Literal)), el) + or _is_tuple_expr_of( + lambda expr: isinstance(expr, (SymRef, Literal)) + or (isinstance(expr, FunCall) and expr.fun == SymRef(id="index")), + el, + ) for el in value ): raise ValueError( diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index 26373c647f..fa73748df6 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -304,7 +304,8 @@ def is_number(symbol_type: ts.TypeSpec) -> bool: def is_logical(symbol_type: ts.TypeSpec) -> bool: return ( - isinstance(dtype := extract_dtype(symbol_type), ts.ScalarType) + isinstance(symbol_type, (ts.FieldType, ts.ScalarType)) + and isinstance(dtype := extract_dtype(symbol_type), ts.ScalarType) and dtype.kind is ts.ScalarKind.BOOL ) diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py index 2fbd039d16..abc62885ae 100644 --- a/src/gt4py/next/type_system/type_specifications.py +++ b/src/gt4py/next/type_system/type_specifications.py @@ -6,7 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from typing import Iterator, Optional, Sequence, Union +from typing import Iterator, Literal, Optional, Sequence, Union from gt4py.eve import datamodels as eve_datamodels, type_definitions as eve_types from gt4py.next import common @@ -133,3 +133,7 @@ def __str__(self) -> str: kwarg_strs = [f"{key}: {value}" for key, value in self.pos_or_kw_args.items()] args_str = ", ".join((*arg_strs, *kwarg_strs)) return f"({args_str}) -> {self.returns}" + + +class DomainType(DataType): + dims: list[common.Dimension] | Literal["unknown"] diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 89ad556476..66330016ef 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -62,6 +62,7 @@ IField: TypeAlias = gtx.Field[[IDim], np.int32] # type: ignore [valid-type] IFloatField: TypeAlias = gtx.Field[[IDim], np.float64] # type: ignore [valid-type] IBoolField: TypeAlias = gtx.Field[[IDim], bool] # type: ignore [valid-type] +JField: TypeAlias = gtx.Field[[JDim], np.int32] # type: ignore [valid-type] KField: TypeAlias = gtx.Field[[KDim], np.int32] # type: ignore [valid-type] IJField: TypeAlias = gtx.Field[[IDim, JDim], np.int32] # type: ignore [valid-type] IKField: TypeAlias = gtx.Field[[IDim, KDim], np.int32] # type: ignore [valid-type] diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py index 364434029f..7db29bc088 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py @@ -9,8 +9,9 @@ import numpy as np from typing import Tuple import pytest -from next_tests.integration_tests.cases import KDim, cartesian_case +from next_tests.integration_tests.cases import IDim, JDim, KDim, cartesian_case from gt4py import next as gtx +from gt4py.next import errors from gt4py.next.ffront.experimental import concat_where from next_tests.integration_tests import cases from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( @@ -23,7 +24,7 @@ def test_boundary_same_size_fields(cartesian_case): def testee( k: cases.KField, interior: cases.IJKField, boundary: cases.IJKField ) -> cases.IJKField: - return concat_where(k == 0, boundary, interior) + return concat_where(KDim == 0, boundary, interior) k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() interior = cases.allocate(cartesian_case, testee, "interior")() @@ -37,12 +38,109 @@ def testee( cases.verify(cartesian_case, testee, k, interior, boundary, out=out, ref=ref) +def test_dimension(cartesian_case): + @gtx.field_operator + def testee( + k: cases.KField, interior: cases.IJKField, boundary: cases.IJKField + ) -> cases.IJKField: + return concat_where(KDim >= 2, boundary, interior) + + k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() + interior = cases.allocate(cartesian_case, testee, "interior")() + boundary = cases.allocate(cartesian_case, testee, "boundary")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + ref = np.where( + k.asnumpy()[np.newaxis, np.newaxis, :] >= 2, boundary.asnumpy(), interior.asnumpy() + ) + cases.verify(cartesian_case, testee, k, interior, boundary, out=out, ref=ref) + + +def test_dimension_different_dims(cartesian_case): + @gtx.field_operator + def testee(j: cases.JField, interior: cases.IJField, boundary: cases.JField) -> cases.IJField: + return concat_where(IDim >= 2, boundary, interior) + + j = cases.allocate(cartesian_case, testee, "j", strategy=cases.IndexInitializer())() + interior = cases.allocate(cartesian_case, testee, "interior")() + boundary = cases.allocate(cartesian_case, testee, "boundary")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + ref = np.where( + j.asnumpy()[:, np.newaxis] >= 2, boundary.asnumpy()[np.newaxis, :], interior.asnumpy() + ) + cases.verify(cartesian_case, testee, j, interior, boundary, out=out, ref=ref) + + +def test_dimension_two_nested_conditions(cartesian_case): + @gtx.field_operator + def testee( + k: cases.KField, interior: cases.IJKField, boundary: cases.IJKField + ) -> cases.IJKField: + return concat_where((KDim < 2), boundary, concat_where((KDim >= 5), boundary, interior)) + + k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() + interior = cases.allocate(cartesian_case, testee, "interior")() + boundary = cases.allocate(cartesian_case, testee, "boundary")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + ref = np.where( + (k.asnumpy()[np.newaxis, np.newaxis, :] < 2) + | (k.asnumpy()[np.newaxis, np.newaxis, :] >= 5), + boundary.asnumpy(), + interior.asnumpy(), + ) + cases.verify(cartesian_case, testee, k, interior, boundary, out=out, ref=ref) + + +def test_dimension_two_conditions_and(cartesian_case): + @gtx.field_operator + def testee(k: cases.KField, interior: cases.KField, boundary: cases.KField) -> cases.KField: + return concat_where(((KDim > 2) & (KDim <= 5)), interior, boundary) + + k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() + interior = cases.allocate(cartesian_case, testee, "interior")() + boundary = cases.allocate(cartesian_case, testee, "boundary")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + ref = np.where((k.asnumpy() > 2) & (k.asnumpy() <= 5), interior.asnumpy(), boundary.asnumpy()) + cases.verify(cartesian_case, testee, k, interior, boundary, out=out, ref=ref) + + +def test_dimension_two_conditions_eq(cartesian_case): + @gtx.field_operator + def testee(k: cases.KField, interior: cases.KField, boundary: cases.KField) -> cases.KField: + return concat_where((KDim == 2), interior, boundary) + + k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() + interior = cases.allocate(cartesian_case, testee, "interior")() + boundary = cases.allocate(cartesian_case, testee, "boundary")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + ref = np.where(k.asnumpy() == 2, interior.asnumpy(), boundary.asnumpy()) + cases.verify(cartesian_case, testee, k, interior, boundary, out=out, ref=ref) + + +def test_dimension_two_conditions_or(cartesian_case): + @gtx.field_operator + def testee(k: cases.KField, interior: cases.KField, boundary: cases.KField) -> cases.KField: + return concat_where(((KDim < 2) | (KDim >= 5)), boundary, interior) + + k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() + interior = cases.allocate(cartesian_case, testee, "interior")() + boundary = cases.allocate(cartesian_case, testee, "boundary")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + ref = np.where((k.asnumpy() < 2) | (k.asnumpy() >= 5), boundary.asnumpy(), interior.asnumpy()) + cases.verify(cartesian_case, testee, k, interior, boundary, out=out, ref=ref) + + def test_boundary_horizontal_slice(cartesian_case): @gtx.field_operator def testee( k: cases.KField, interior: cases.IJKField, boundary: cases.IJField ) -> cases.IJKField: - return concat_where(k == 0, boundary, interior) + return concat_where(KDim == 0, boundary, interior) k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() interior = cases.allocate(cartesian_case, testee, "interior")() @@ -63,7 +161,7 @@ def test_boundary_single_layer(cartesian_case): def testee( k: cases.KField, interior: cases.IJKField, boundary: cases.IJKField ) -> cases.IJKField: - return concat_where(k == 0, boundary, interior) + return concat_where(KDim == 0, boundary, interior) k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() interior = cases.allocate(cartesian_case, testee, "interior")() @@ -80,18 +178,22 @@ def testee( def test_alternating_mask(cartesian_case): - @gtx.field_operator - def testee(k: cases.KField, f0: cases.IJKField, f1: cases.IJKField) -> cases.IJKField: - return concat_where(k % 2 == 0, f1, f0) + with pytest.raises( + errors.DSLError, match=("Type 'ts.OffsetType' can not be used in operator 'mod'.") + ): - k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() - f0 = cases.allocate(cartesian_case, testee, "f0")() - f1 = cases.allocate(cartesian_case, testee, "f1")() - out = cases.allocate(cartesian_case, testee, cases.RETURN)() + @gtx.field_operator + def testee(k: cases.KField, f0: cases.IJKField, f1: cases.IJKField) -> cases.IJKField: + return concat_where(KDim % 2 == 0, f1, f0) + + k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() + f0 = cases.allocate(cartesian_case, testee, "f0")() + f1 = cases.allocate(cartesian_case, testee, "f1")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() - ref = np.where(k.asnumpy()[np.newaxis, np.newaxis, :] % 2 == 0, f1.asnumpy(), f0.asnumpy()) + ref = np.where(k.asnumpy()[np.newaxis, np.newaxis, :] % 2 == 0, f1.asnumpy(), f0.asnumpy()) - cases.verify(cartesian_case, testee, k, f0, f1, out=out, ref=ref) + cases.verify(cartesian_case, testee, k, f0, f1, out=out, ref=ref) @pytest.mark.uses_tuple_returns @@ -104,7 +206,7 @@ def testee( interior1: cases.IJKField, boundary1: cases.IJField, ) -> Tuple[cases.IJKField, cases.IJKField]: - return concat_where(k == 0, (boundary0, boundary1), (interior0, interior1)) + return concat_where(KDim == 0, (boundary0, boundary1), (interior0, interior1)) k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() interior0 = cases.allocate(cartesian_case, testee, "interior0")() diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index a39fe3c6d8..0f9b1acb8a 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -72,6 +72,7 @@ def expression_test_cases(): return ( # itir expr, type + # TODO: write test for IDim < 10, concat_where (im.call("abs")(1), int_type), (im.call("power")(2.0, 2), float64_type), (im.plus(1, 2), int_type), diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py index cf325c2daa..794a93090b 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py @@ -8,6 +8,7 @@ from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.constant_folding import ConstantFolding +from gt4py.next.iterator import ir def test_constant_folding_boolean(): @@ -60,3 +61,89 @@ def test_constant_folding_literal_maximum(): expected = im.literal_from_value(2) actual = ConstantFolding.apply(testee) assert actual == expected + + +def test_constant_folding_inf_maximum(): + testee = im.call("maximum")(im.literal_from_value(1), ir.InfinityLiteral()) + expected = ir.InfinityLiteral() + actual = ConstantFolding.apply(testee) + assert actual == expected + + testee = im.call("maximum")(ir.InfinityLiteral(), im.literal_from_value(1)) + expected = ir.InfinityLiteral() + actual = ConstantFolding.apply(testee) + assert actual == expected + + testee = im.call("maximum")(im.literal_from_value(1), ir.NegInfinityLiteral()) + expected = im.literal_from_value(1) + actual = ConstantFolding.apply(testee) + assert actual == expected + + testee = im.call("maximum")(ir.NegInfinityLiteral(), im.literal_from_value(1)) + expected = im.literal_from_value(1) + actual = ConstantFolding.apply(testee) + assert actual == expected + + +def test_constant_folding_inf_minimum(): + testee = im.call("minimum")(im.literal_from_value(1), ir.InfinityLiteral()) + expected = im.literal_from_value(1) + actual = ConstantFolding.apply(testee) + assert actual == expected + + testee = im.call("minimum")(ir.InfinityLiteral(), im.literal_from_value(1)) + expected = im.literal_from_value(1) + actual = ConstantFolding.apply(testee) + assert actual == expected + + testee = im.call("minimum")(im.literal_from_value(1), ir.NegInfinityLiteral()) + expected = ir.NegInfinityLiteral() + actual = ConstantFolding.apply(testee) + assert actual == expected + + testee = im.call("minimum")(ir.NegInfinityLiteral(), im.literal_from_value(1)) + expected = ir.NegInfinityLiteral() + actual = ConstantFolding.apply(testee) + assert actual == expected + + +def test_constant_greater_less(): + testee = im.call("greater")(im.literal_from_value(1), ir.InfinityLiteral()) + expected = im.literal_from_value(False) + actual = ConstantFolding.apply(testee) + assert actual == expected + + testee = im.call("greater")(im.literal_from_value(1), ir.NegInfinityLiteral()) + expected = im.literal_from_value(True) + actual = ConstantFolding.apply(testee) + assert actual == expected + + testee = im.call("less")(im.literal_from_value(1), ir.InfinityLiteral()) + expected = im.literal_from_value(True) + actual = ConstantFolding.apply(testee) + assert actual == expected + + testee = im.call("less")(im.literal_from_value(1), ir.NegInfinityLiteral()) + expected = im.literal_from_value(False) + actual = ConstantFolding.apply(testee) + assert actual == expected + + testee = im.call("greater")(ir.InfinityLiteral(), im.literal_from_value(1)) + expected = im.literal_from_value(True) + actual = ConstantFolding.apply(testee) + assert actual == expected + + testee = im.call("greater")(ir.NegInfinityLiteral(), im.literal_from_value(1)) + expected = im.literal_from_value(False) + actual = ConstantFolding.apply(testee) + assert actual == expected + + testee = im.call("less")(ir.InfinityLiteral(), im.literal_from_value(1)) + expected = im.literal_from_value(False) + actual = ConstantFolding.apply(testee) + assert actual == expected + + testee = im.call("less")(ir.NegInfinityLiteral(), im.literal_from_value(1)) + expected = im.literal_from_value(True) + actual = ConstantFolding.apply(testee) + assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py index 4a2a441510..2e014ffdb8 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py @@ -1093,3 +1093,117 @@ def test_never_accessed_domain_tuple(offset_provider): "in_field2": infer_domain.DomainAccessDescriptor.NEVER, } run_test_expr(testee, testee, domain, expected_domains, offset_provider) + + +def test_concat_where(offset_provider): + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + domain_cond = im.domain(common.GridType.CARTESIAN, {IDim: (itir.NegInfinityLiteral(), 4)}) + domain1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 4)}) + domain2 = im.domain(common.GridType.CARTESIAN, {IDim: (4, 11)}) + testee = im.concat_where( + domain_cond, im.as_fieldop("deref")("in_field1"), im.as_fieldop("deref")("in_field2") + ) + + expected = im.concat_where( + domain_cond, + im.as_fieldop("deref", domain1)("in_field1"), + im.as_fieldop("deref", domain2)("in_field2"), + ) + expected_domains = {"in_field1": domain1, "in_field2": domain2} + + actual_call, actual_domains = infer_domain.infer_expr( + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider + ) + + folded_call = constant_fold_domain_exprs(actual_call) + assert expected == folded_call + assert expected_domains == constant_fold_accessed_domains(actual_domains) + + +# Todo: 2 dimensional test with cond im.domain(common.GridType.CARTESIAN, {IDim: (itir.NegInfinityLiteral(), 4)}) +# Todo: nested concat wheres + + +def test_concat_where_two_dimensions(offset_provider): + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 20), JDim: (10, 30)}) + domain_cond = im.domain(common.GridType.CARTESIAN, {IDim: (itir.NegInfinityLiteral(), 10)}) + domain1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 10), JDim: (10, 30)}) + domain2 = im.domain(common.GridType.CARTESIAN, {IDim: (10, 20), JDim: (10, 30)}) + testee = im.concat_where( + domain_cond, im.as_fieldop("deref")("in_field1"), im.as_fieldop("deref")("in_field2") + ) + + expected = im.concat_where( + domain_cond, + im.as_fieldop("deref", domain1)("in_field1"), + im.as_fieldop("deref", domain2)("in_field2"), + ) + expected_domains = {"in_field1": domain1, "in_field2": domain2} + + actual_call, actual_domains = infer_domain.infer_expr( + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider + ) + + folded_call = constant_fold_domain_exprs(actual_call) + assert expected == folded_call + assert expected_domains == constant_fold_accessed_domains(actual_domains) + + +def test_concat_where_two_dimensions_J(offset_provider): + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 20), JDim: (10, 30)}) + domain_cond = im.domain(common.GridType.CARTESIAN, {JDim: (20, "inf")}) + domain1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 20), JDim: (20, 30)}) + domain2 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 20), JDim: (10, 20)}) + testee = im.concat_where( + domain_cond, im.as_fieldop("deref")("in_field1"), im.as_fieldop("deref")("in_field2") + ) + + expected = im.concat_where( + domain_cond, + im.as_fieldop("deref", domain1)("in_field1"), + im.as_fieldop("deref", domain2)("in_field2"), + ) + expected_domains = {"in_field1": domain1, "in_field2": domain2} + + actual_call, actual_domains = infer_domain.infer_expr( + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider + ) + + folded_call = constant_fold_domain_exprs(actual_call) + assert expected == folded_call + assert expected_domains == constant_fold_accessed_domains(actual_domains) + + +def test_nested_concat_where_two_dimensions(offset_provider): + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 30), JDim: (0, 20)}) + domain_cond1 = im.domain(common.GridType.CARTESIAN, {JDim: (10, "inf")}) + domain_cond2 = im.domain(common.GridType.CARTESIAN, {IDim: (itir.NegInfinityLiteral(), 20)}) + domain1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 20), JDim: (10, 20)}) + domain2 = im.domain(common.GridType.CARTESIAN, {IDim: (20, 30), JDim: (10, 20)}) + domain3 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 30), JDim: (0, 10)}) + testee = im.concat_where( + domain_cond1, + im.concat_where( + domain_cond2, im.as_fieldop("deref")("in_field1"), im.as_fieldop("deref")("in_field2") + ), + im.as_fieldop("deref")("in_field3"), + ) + + expected = im.concat_where( + domain_cond1, # 0, 30; 10,20 + im.concat_where( + domain_cond2, + im.as_fieldop("deref", domain1)("in_field1"), + im.as_fieldop("deref", domain2)("in_field2"), + ), + im.as_fieldop("deref", domain3)("in_field3"), + ) + expected_domains = {"in_field1": domain1, "in_field2": domain2, "in_field3": domain3} + + actual_call, actual_domains = infer_domain.infer_expr( + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider + ) + + folded_call = constant_fold_domain_exprs(actual_call) + assert expected == folded_call + assert expected_domains == constant_fold_accessed_domains(actual_domains)