Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next]: Gtir concat where #1713

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/gt4py/next/ffront/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def as_offset(offset_: FieldOffset, field: common.Field, /) -> common.Connectivi

@WhereBuiltinFunction
def concat_where(
mask: common.Field,
mask: common.Domain,
true_field: common.Field | core_defs.ScalarT | Tuple,
false_field: common.Field | core_defs.ScalarT | Tuple,
/,
Expand Down
4 changes: 3 additions & 1 deletion src/gt4py/next/ffront/fbuiltins.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ def _type_conversion_helper(t: type) -> type[ts.TypeSpec] | tuple[type[ts.TypeSp
return ts.OffsetType
elif t is core_defs.ScalarT:
return ts.ScalarType
elif t is common.Domain:
return ts.DomainType
elif t is type:
return (
ts.FunctionType
Expand Down Expand Up @@ -135,7 +137,7 @@ def __gt_type__(self) -> ts.FunctionType:
)


MaskT = TypeVar("MaskT", bound=common.Field)
MaskT = TypeVar("MaskT", bound=Union[common.Field, common.Domain])
FieldT = TypeVar("FieldT", bound=Union[common.Field, core_defs.Scalar, Tuple])


Expand Down
107 changes: 80 additions & 27 deletions src/gt4py/next/ffront/foast_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import gt4py.next.ffront.field_operator_ast as foast
from gt4py.eve import NodeTranslator, NodeVisitor, traits
from gt4py.next import errors
from gt4py.next.common import DimensionKind
from gt4py.next.common import DimensionKind, promote_dims
from gt4py.next.ffront import ( # noqa
dialect_ast_enums,
experimental,
Expand All @@ -20,6 +20,7 @@
type_specifications as ts_ffront,
)
from gt4py.next.ffront.foast_passes.utils import compute_assign_indices
from gt4py.next.iterator import builtins
from gt4py.next.type_system import type_info, type_specifications as ts, type_translation


Expand Down Expand Up @@ -570,6 +571,36 @@ def _deduce_compare_type(
self, node: foast.Compare, *, left: foast.Expr, right: foast.Expr, **kwargs: Any
) -> Optional[ts.TypeSpec]:
# check both types compatible
left_t, right_t = left.type, right.type
integer_kind = getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper())
if (
isinstance(left_t, ts.DimensionType)
and isinstance(right_t, ts.ScalarType)
and right_t.kind == integer_kind
):
return ts.DomainType(dims=[left_t.dim])
if (
isinstance(right_t, ts.DimensionType)
and isinstance(left_t, ts.ScalarType)
and left_t.kind == integer_kind
):
return ts.DomainType(dims=[right_t.dim])
if (
isinstance(left_t, ts.OffsetType)
and left.op == dialect_ast_enums.BinaryOperator.MOD
and isinstance(right_t, ts.ScalarType)
and right_t.kind == integer_kind
) or (
isinstance(right_t, ts.OffsetType)
and right.op == dialect_ast_enums.BinaryOperator.MOD
and isinstance(left_t, ts.ScalarType)
and left_t.kind == integer_kind
):
raise errors.DSLError(
left.location, "Type 'ts.OffsetType' can not be used in operator 'mod'."
)

# TODO
for arg in (left, right):
if not type_info.is_arithmetic(arg.type):
raise errors.DSLError(
Expand All @@ -582,13 +613,13 @@ def _deduce_compare_type(
# transform operands to have bool dtype and use regular promotion
# mechanism to handle dimension promotion
return type_info.promote(
with_altered_scalar_kind(left.type, ts.ScalarKind.BOOL),
with_altered_scalar_kind(right.type, ts.ScalarKind.BOOL),
with_altered_scalar_kind(left_t, ts.ScalarKind.BOOL),
with_altered_scalar_kind(right_t, ts.ScalarKind.BOOL),
)
except ValueError as ex:
raise errors.DSLError(
node.location,
f"Could not promote '{left.type}' and '{right.type}' to common type"
f"Could not promote '{left_t}' and '{right_t}' to common type"
f" in call to '{node.op}'.",
) from ex

Expand All @@ -612,37 +643,44 @@ def _deduce_binop_type(
dialect_ast_enums.BinaryOperator.BIT_OR,
dialect_ast_enums.BinaryOperator.BIT_XOR,
}
is_compatible = type_info.is_logical if node.op in logical_ops else type_info.is_arithmetic

def is_logical_or_domain(arg: ts.TypeSpec) -> bool:
return type_info.is_logical(arg) or isinstance(arg, ts.DomainType)

is_compatible = is_logical_or_domain if node.op in logical_ops else type_info.is_arithmetic

# check both types compatible
for arg in (left, right):
if not is_compatible(arg.type):
raise errors.DSLError(
arg.location, f"Type '{arg.type}' can not be used in operator '{node.op}'."
)

left_type = cast(ts.FieldType | ts.ScalarType, left.type)
right_type = cast(ts.FieldType | ts.ScalarType, right.type)

if node.op == dialect_ast_enums.BinaryOperator.POW:
return left_type

if node.op == dialect_ast_enums.BinaryOperator.MOD and not type_info.is_integral(
right_type
if isinstance(left.type, (ts.ScalarType, ts.FieldType)) and isinstance(
right.type, (ts.ScalarType, ts.FieldType)
):
raise errors.DSLError(
arg.location,
f"Type '{right_type}' can not be used in operator '{node.op}', it only accepts 'int'.",
)
if node.op == dialect_ast_enums.BinaryOperator.POW:
return left.type

try:
return type_info.promote(left_type, right_type)
except ValueError as ex:
raise errors.DSLError(
node.location,
f"Could not promote '{left_type}' and '{right_type}' to common type"
f" in call to '{node.op}'.",
) from ex
if node.op == dialect_ast_enums.BinaryOperator.MOD and not type_info.is_integral(
right.type
):
raise errors.DSLError(
arg.location,
f"Type '{right.type}' can not be used in operator '{node.op}', it only accepts 'int'.",
)

try:
return type_info.promote(left.type, right.type)
except ValueError as ex:
raise errors.DSLError(
node.location,
f"Could not promote '{left.type}' and '{right.type}' to common type"
f" in call to '{node.op}'.",
) from ex
elif isinstance(left.type, ts.DomainType) and isinstance(right.type, ts.DomainType):
return ts.DomainType(dims=promote_dims(left.type.dims, right.type.dims))
else:
raise ValueError("TODO")

def _check_operand_dtypes_match(
self, node: foast.BinOp | foast.Compare, left: foast.Expr, right: foast.Expr
Expand Down Expand Up @@ -908,6 +946,7 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> foast.Call:
)

try:
# TODO(tehrengruber): the construct_tuple_type function doesn't look correct
if isinstance(true_branch_type, ts.TupleType) and isinstance(
false_branch_type, ts.TupleType
):
Expand Down Expand Up @@ -943,7 +982,21 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> foast.Call:
location=node.location,
)

_visit_concat_where = _visit_where
def _visit_concat_where(self, node: foast.Call, **kwargs: Any) -> foast.Call:
true_branch_type = node.args[1].type
false_branch_type = node.args[2].type
true_branch_fieldtype = cast(ts.FieldType, true_branch_type)
false_branch_fieldtype = cast(ts.FieldType, false_branch_type)
promoted_type = type_info.promote(true_branch_fieldtype, false_branch_fieldtype)
return_type = promoted_type

return foast.Call(
func=node.func,
args=node.args,
kwargs=node.kwargs,
type=return_type,
location=node.location,
)

def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> foast.Call:
arg_type = cast(ts.FieldType | ts.ScalarType, node.args[0].type)
Expand Down
39 changes: 35 additions & 4 deletions src/gt4py/next/ffront/foast_to_gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,9 @@ def visit_Assign(
def visit_Symbol(self, node: foast.Symbol, **kwargs: Any) -> itir.Sym:
return im.sym(node.id)

def visit_Name(self, node: foast.Name, **kwargs: Any) -> itir.SymRef:
def visit_Name(self, node: foast.Name, **kwargs: Any) -> itir.SymRef | itir.AxisLiteral:
if isinstance(node.type, ts.DimensionType):
return itir.AxisLiteral(value=node.type.dim.value, kind=node.type.dim.kind)
return im.ref(node.id)

def visit_Subscript(self, node: foast.Subscript, **kwargs: Any) -> itir.Expr:
Expand All @@ -249,7 +251,28 @@ def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> itir.Expr:
raise NotImplementedError(f"Unary operator '{node.op}' is not supported.")

def visit_BinOp(self, node: foast.BinOp, **kwargs: Any) -> itir.FunCall:
return self._lower_and_map(node.op.value, node.left, node.right)
if (
node.op == dialect_ast_enums.BinaryOperator.BIT_AND
and isinstance(node.left.type, ts.DomainType)
and isinstance(node.right.type, ts.DomainType)
):
return im.and_(self.visit(node.left), self.visit(node.right))
if (
node.op == dialect_ast_enums.BinaryOperator.BIT_OR
and isinstance(node.left.type, ts.DomainType)
and isinstance(node.right.type, ts.DomainType)
):
return im.or_(self.visit(node.left), self.visit(node.right))
if (
node.op == dialect_ast_enums.BinaryOperator.BIT_XOR
and isinstance(node.left.type, ts.DomainType)
and isinstance(node.right.type, ts.DomainType)
):
raise NotImplementedError(
f"Binary operator '{node.op}' is not supported for '{node.right.type}' and '{node.right.type}'."
)
else:
return self._lower_and_map(node.op.value, node.left, node.right)

def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs: Any) -> itir.FunCall:
assert (
Expand All @@ -261,6 +284,7 @@ def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs: Any) -> itir.FunC
)

def visit_Compare(self, node: foast.Compare, **kwargs: Any) -> itir.FunCall:
# TODO: double-check if we need the changes in the original PR
return self._lower_and_map(node.op.value, node.left, node.right)

def _visit_shift(self, node: foast.Call, **kwargs: Any) -> itir.Expr:
Expand Down Expand Up @@ -394,7 +418,13 @@ def create_if(

return im.let(cond_symref_name, cond_)(result)

_visit_concat_where = _visit_where # TODO(havogt): upgrade concat_where
def _visit_concat_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall:
if not isinstance(node.type, ts.TupleType): # to keep the IR simpler
return im.call("concat_where")(*self.visit(node.args))
else:
raise NotImplementedError()

# TODO: tuple case

def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> itir.FunCall:
expr = self.visit(node.args[0], **kwargs)
Expand Down Expand Up @@ -476,8 +506,9 @@ def _map(
"""
Mapping includes making the operation an `as_fieldop` (first kind of mapping), but also `itir.map_`ing lists.
"""
# TODO double-check that this code is consistent with the changes in the original PR
if all(
isinstance(t, ts.ScalarType)
isinstance(t, (ts.ScalarType, ts.DimensionType))
for arg_type in original_arg_types
for t in type_info.primitive_constituents(arg_type)
):
Expand Down
2 changes: 2 additions & 0 deletions src/gt4py/next/iterator/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,8 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing]
"scan",
"tuple_get",
"unstructured_domain",
"concat_where",
"in",
*ARITHMETIC_BUILTINS,
*TYPE_BUILTINS,
}
Expand Down
10 changes: 10 additions & 0 deletions src/gt4py/next/iterator/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,14 @@ class NoneLiteral(Expr):
_none_literal: int = 0


class InfinityLiteral(Expr):
pass


class NegInfinityLiteral(Expr):
pass


class OffsetLiteral(Expr):
value: Union[int, str]

Expand Down Expand Up @@ -142,3 +150,5 @@ class Program(Node, ValidatedSymbolTableTrait):
Program.__hash__ = Node.__hash__ # type: ignore[method-assign]
SetAt.__hash__ = Node.__hash__ # type: ignore[method-assign]
IfStmt.__hash__ = Node.__hash__ # type: ignore[method-assign]
InfinityLiteral.__hash__ = Node.__hash__ # type: ignore[method-assign]
NegInfinityLiteral.__hash__ = Node.__hash__ # type: ignore[method-assign]
46 changes: 46 additions & 0 deletions src/gt4py/next/iterator/ir_utils/domain_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,3 +174,49 @@ def domain_union(*domains: SymbolicDomain) -> SymbolicDomain:
new_domain_ranges[dim] = SymbolicRange(start, stop)

return SymbolicDomain(domains[0].grid_type, new_domain_ranges)


def domain_intersection(*domains: SymbolicDomain) -> SymbolicDomain:
"""Return the (set) intersection of a list of domains."""
new_domain_ranges = {}
assert all(domain.grid_type == domains[0].grid_type for domain in domains)
for dim in domains[0].ranges.keys():
start = functools.reduce(
lambda current_expr, el_expr: im.call("maximum")(current_expr, el_expr),
[domain.ranges[dim].start for domain in domains],
)
stop = functools.reduce(
lambda current_expr, el_expr: im.call("minimum")(current_expr, el_expr),
[domain.ranges[dim].stop for domain in domains],
)
new_domain_ranges[dim] = SymbolicRange(start, stop)

return SymbolicDomain(domains[0].grid_type, new_domain_ranges)


def domain_complement(domain: SymbolicDomain) -> SymbolicDomain:
"""Return the (set) complement of a domain."""
dims_dict = {}
for dim in domain.ranges.keys():
lb, ub = domain.ranges[dim].start, domain.ranges[dim].stop
if isinstance(lb, itir.NegInfinityLiteral):
dims_dict[dim] = SymbolicRange(start=ub, stop=itir.InfinityLiteral())
elif isinstance(ub, itir.InfinityLiteral):
dims_dict[dim] = SymbolicRange(start=itir.NegInfinityLiteral(), stop=lb)
else:
raise ValueError("Invalid domain ranges")
return SymbolicDomain(domain.grid_type, dims_dict)


def promote_to_same_dimensions(
domain_small: SymbolicDomain, domain_large: SymbolicDomain
) -> SymbolicDomain:
"""Return an extended domain based on a smaller input domain and a larger domain containing the target dimensions."""
dims_dict = {}
for dim in domain_large.ranges.keys():
if dim in domain_small.ranges.keys():
lb, ub = domain_small.ranges[dim].start, domain_small.ranges[dim].stop
dims_dict[dim] = SymbolicRange(lb, ub)
else:
dims_dict[dim] = SymbolicRange(itir.NegInfinityLiteral(), itir.InfinityLiteral())
return SymbolicDomain(domain_small.grid_type, dims_dict) # TODO: fix for unstructured
5 changes: 5 additions & 0 deletions src/gt4py/next/iterator/ir_utils/ir_makers.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,11 @@ def if_(cond, true_val, false_val):
return call("if_")(cond, true_val, false_val)


def concat_where(cond, true_field, false_field):
"""Create a concat_where FunCall, shorthand for ``call("concat_where")(expr)``."""
return call("concat_where")(cond, true_field, false_field)


def lift(expr):
"""Create a lift FunCall, shorthand for ``call(call("lift")(expr))``."""
return call(call("lift")(expr))
Expand Down
6 changes: 6 additions & 0 deletions src/gt4py/next/iterator/pretty_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,12 @@ def visit_Sym(self, node: ir.Sym, *, prec: int) -> list[str]:
def visit_Literal(self, node: ir.Literal, *, prec: int) -> list[str]:
return [str(node.value)]

def visit_InfinityLiteral(self, node: ir.Literal, *, prec: int) -> list[str]:
return ["INF"]

def visit_NegInfinityLiteral(self, node: ir.Literal, *, prec: int) -> list[str]:
return ["-INF"]

def visit_OffsetLiteral(self, node: ir.OffsetLiteral, *, prec: int) -> list[str]:
return [str(node.value) + "ₒ"]

Expand Down
Loading
Loading