diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 40bfc0ab75..96b585b1f4 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -13,7 +13,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later import typing -from typing import Callable, Iterable, Union +from typing import Callable, Iterable, Optional, Union from gt4py._core import definitions as core_defs from gt4py.next.iterator import ir as itir @@ -401,3 +401,26 @@ def _impl(*its: itir.Expr) -> itir.FunCall: def map_(op): """Create a `map_` call.""" return call(call("map_")(op)) + + +def as_fieldop(expr: itir.Expr, domain: Optional[itir.FunCall] = None) -> call: + """ + Create an `as_fieldop` call. + + Examples + -------- + >>> str(as_fieldop(lambda_("it1", "it2")(plus(deref("it1"), deref("it2"))))("field1", "field2")) + '(⇑(λ(it1, it2) → ·it1 + ·it2))(field1, field2)' + """ + return call( + call("as_fieldop")( + *( + ( + expr, + domain, + ) + if domain + else (expr,) + ) + ) + ) diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index 32714232a6..acd896d753 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -11,11 +11,14 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later + +from __future__ import annotations + import dataclasses import functools import math import operator -import typing +from typing import Callable, Iterable, TypeVar, Union, cast from gt4py.eve import ( NodeTranslator, @@ -25,32 +28,36 @@ VisitorWithSymbolTableTrait, ) from gt4py.eve.utils import UIDGenerator -from gt4py.next.iterator import ir +from gt4py.next import common +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.iterator.transforms.inline_lambdas import inline_lambda +from gt4py.next.iterator.type_system import inference as itir_type_inference +from gt4py.next.type_system import type_info, type_specifications as ts @dataclasses.dataclass class _NodeReplacer(PreserveLocationVisitor, NodeTranslator): PRESERVED_ANNEX_ATTRS = ("type",) - expr_map: dict[int, ir.SymRef] + expr_map: dict[int, itir.SymRef] - def visit_Expr(self, node: ir.Node) -> ir.Node: + def visit_Expr(self, node: itir.Node) -> itir.Node: if id(node) in self.expr_map: return self.expr_map[id(node)] return self.generic_visit(node) - def visit_FunCall(self, node: ir.FunCall) -> ir.Node: - node = typing.cast(ir.FunCall, self.visit_Expr(node)) + def visit_FunCall(self, node: itir.FunCall) -> itir.Node: + node = cast(itir.FunCall, self.visit_Expr(node)) # If we encounter an expression like: # (λ(_cs_1) → (λ(a) → a+a)(_cs_1))(outer_expr) # (non-recursively) inline the lambda to obtain: # (λ(_cs_1) → _cs_1+_cs_1)(outer_expr) # This allows identifying more common subexpressions later on - if isinstance(node, ir.FunCall) and isinstance(node.fun, ir.Lambda): + if isinstance(node, itir.FunCall) and isinstance(node.fun, itir.Lambda): eligible_params = [] for arg in node.args: - eligible_params.append(isinstance(arg, ir.SymRef) and arg.id.startswith("_cs")) + eligible_params.append(isinstance(arg, itir.SymRef) and arg.id.startswith("_cs")) if any(eligible_params): # note: the inline is opcount preserving anyway so avoid the additional # effort in the inliner by disabling opcount preservation. @@ -60,18 +67,18 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.Node: return node -def _is_collectable_expr(node: ir.Node) -> bool: - if isinstance(node, ir.FunCall): +def _is_collectable_expr(node: itir.Node) -> bool: + if isinstance(node, itir.FunCall): # do not collect (and thus deduplicate in CSE) shift(offsets…) calls. Node must still be # visited, to ensure symbol dependencies are recognized correctly. # do also not collect reduce nodes if they are left in the it at this point, this may lead to # conceptual problems (other parts of the tool chain rely on the arguments being present directly # on the reduce FunCall node (connectivity deduction)), as well as problems with the imperative backend # backend (single pass eager depth first visit approach) - if isinstance(node.fun, ir.SymRef) and node.fun.id in ["lift", "shift", "reduce"]: + if isinstance(node.fun, itir.SymRef) and node.fun.id in ["lift", "shift", "reduce"]: return False return True - elif isinstance(node, ir.Lambda): + elif isinstance(node, itir.Lambda): return True return False @@ -92,7 +99,7 @@ class SubexpressionData: class State: #: A dictionary mapping a node to a list of node ids with equal hash and some additional #: information. See `SubexpressionData` for more information. - subexprs: dict[ir.Node, "CollectSubexpressions.SubexpressionData"] = dataclasses.field( + subexprs: dict[itir.Node, CollectSubexpressions.SubexpressionData] = dataclasses.field( default_factory=dict ) # TODO(tehrengruber): Revisit if this makes sense or if we can just recompute the collected @@ -102,7 +109,7 @@ class State: #: The ids of all nodes declaring a symbol which are referenced (using a `SymRef`) used_symbol_ids: set[int] = dataclasses.field(default_factory=set) - def remove_subexprs(self, nodes: typing.Iterable[ir.Node]) -> None: + def remove_subexprs(self, nodes: Iterable[itir.Node]) -> None: node_ids_to_remove: set[int] = set() for node in nodes: subexpr_data = self.subexprs.pop(node, None) @@ -114,22 +121,22 @@ def remove_subexprs(self, nodes: typing.Iterable[ir.Node]) -> None: collected_child_node_ids -= node_ids_to_remove @classmethod - def apply(cls, node: ir.Node) -> dict[ir.Node, list[tuple[int, set[int]]]]: + def apply(cls, node: itir.Node) -> dict[itir.Node, list[tuple[int, set[int]]]]: state = cls.State() obj = cls() obj.visit(node, state=state, depth=-1) # Return subexpression such that the nodes closer to the root come first and skip the root # node itself. - subexprs_sorted: list[tuple[ir.Node, "CollectSubexpressions.SubexpressionData"]] = sorted( + subexprs_sorted: list[tuple[itir.Node, CollectSubexpressions.SubexpressionData]] = sorted( state.subexprs.items(), key=lambda el: el[1].max_depth ) return {k: v.subexprs for k, v in subexprs_sorted if k is not node} - def generic_visit(self, *args, **kwargs): + def generic_visit(self, node, **kwargs): depth = kwargs.pop("depth") - return super().generic_visit(*args, depth=depth + 1, **kwargs) + return super().generic_visit(node, depth=depth + 1, **kwargs) - def visit(self, node: ir.Node, **kwargs) -> None: # type: ignore[override] # supertype accepts any node, but we want to be more specific here. + def visit(self, node: itir.Node, **kwargs) -> None: # type: ignore[override] # supertype accepts any node, but we want to be more specific here. if not isinstance(node, SymbolTableTrait) and not _is_collectable_expr(node): return super().visit(node, **kwargs) @@ -141,7 +148,7 @@ def visit(self, node: ir.Node, **kwargs) -> None: # type: ignore[override] # su # Special handling of `if_(condition, true_branch, false_branch)` like expressions that # avoids extracting subexpressions unless they are used in either the condition or both # branches. - if isinstance(node, ir.FunCall) and node.fun == ir.SymRef(id="if_"): + if isinstance(node, itir.FunCall) and node.fun == itir.SymRef(id="if_"): assert len(node.args) == 3 # collect subexpressions for all arguments to the `if_` arg_states = [self.State() for _ in node.args] @@ -157,7 +164,7 @@ def visit(self, node: ir.Node, **kwargs) -> None: # type: ignore[override] # su arg_state.remove_subexprs(arg_state.subexprs.keys() - eligible_subexprs) # merge the states of the three arguments - subexprs: dict[ir.Node, CollectSubexpressions.SubexpressionData] = {} + subexprs: dict[itir.Node, CollectSubexpressions.SubexpressionData] = {} for state in arg_states: for subexpr, data in state.subexprs.items(): merged_data = subexprs.setdefault(subexpr, self.SubexpressionData()) @@ -203,19 +210,19 @@ def visit(self, node: ir.Node, **kwargs) -> None: # type: ignore[override] # su parent_state.collected_child_node_ids.update(collected_child_node_ids) def visit_SymRef( - self, node: ir.SymRef, *, symtable: dict[str, ir.Node], state: State, **kwargs + self, node: itir.SymRef, *, symtable: dict[str, itir.Node], state: State, **kwargs ) -> None: if node.id in symtable: # root symbol otherwise state.used_symbol_ids.add(id(symtable[node.id])) def extract_subexpression( - node: ir.Expr, - predicate: typing.Callable[[ir.Expr, int], bool], + node: itir.Expr, + predicate: Callable[[itir.Expr, int], bool], uid_generator: UIDGenerator, once_only: bool = False, deepest_expr_first: bool = False, -) -> tuple[ir.Expr, typing.Union[dict[ir.Sym, ir.Expr], None], bool]: +) -> tuple[itir.Expr, Union[dict[itir.Sym, itir.Expr], None], bool]: """ Given an expression extract all subexprs and return a new expr with the subexprs replaced. @@ -312,20 +319,20 @@ def extract_subexpression( ) ignored_children = False - extracted = dict[ir.Sym, ir.Expr]() + extracted = dict[itir.Sym, itir.Expr]() # collect expressions subexprs = CollectSubexpressions.apply(node) # collect multiple occurrences and map them to fresh symbols - expr_map = dict[int, ir.SymRef]() + expr_map = dict[int, itir.SymRef]() ignored_ids = set() for expr, subexpr_entry in ( subexprs.items() if not deepest_expr_first else reversed(subexprs.items()) ): # just to make mypy happy when calling the predicate. Every subnode and hence subexpression # is an expr anyway. - assert isinstance(expr, ir.Expr) + assert isinstance(expr, itir.Expr) if not predicate(expr, len(subexpr_entry)): continue @@ -345,8 +352,8 @@ def extract_subexpression( continue expr_id = uid_generator.sequential_id() - extracted[ir.Sym(id=expr_id)] = expr - expr_ref = ir.SymRef(id=expr_id) + extracted[itir.Sym(id=expr_id)] = expr + expr_ref = itir.SymRef(id=expr_id) for id_ in eligible_ids: expr_map[id_] = expr_ref @@ -359,17 +366,26 @@ def extract_subexpression( return _NodeReplacer(expr_map).visit(node), extracted, ignored_children +ProgramOrExpr = TypeVar("ProgramOrExpr", bound=itir.Program | itir.FencilDefinition | itir.Expr) + + @dataclasses.dataclass(frozen=True) class CommonSubexpressionElimination(PreserveLocationVisitor, NodeTranslator): """ Perform common subexpression elimination. Examples: - >>> x = ir.SymRef(id="x") - >>> plus = lambda a, b: ir.FunCall(fun=ir.SymRef(id=("plus")), args=[a, b]) + >>> x = itir.SymRef(id="x") + >>> plus = lambda a, b: itir.FunCall(fun=itir.SymRef(id=("plus")), args=[a, b]) >>> expr = plus(plus(x, x), plus(x, x)) - >>> print(CommonSubexpressionElimination().visit(expr)) + >>> print(CommonSubexpressionElimination.apply(expr, is_local_view=True)) (λ(_cs_1) → _cs_1 + _cs_1)(x + x) + + The pass visits the tree top-down starting from the root node, e.g. an itir.Program. + For each node we extract (eligible) subexpressions occuring more than once using + :ref:`extract_subexpression`. In field-view context we only extract expression when they are + fields (or composites thereof), in local view everything is eligible. Since the visit is + top-down, extracted expressions always land up in the outermost scope they can appear in. """ # we use one UID generator per instance such that the generated ids are @@ -380,23 +396,68 @@ class CommonSubexpressionElimination(PreserveLocationVisitor, NodeTranslator): collect_all: bool = dataclasses.field(default=False) - def visit_FunCall(self, node: ir.FunCall): - if isinstance(node.fun, ir.SymRef) and node.fun.id in [ - "cartesian_domain", - "unstructured_domain", - ]: - return node + @classmethod + def apply( + cls, + node: ProgramOrExpr, + is_local_view: bool | None = None, + offset_provider: common.OffsetProvider | None = None, + ) -> ProgramOrExpr: + is_program = isinstance(node, (itir.Program, itir.FencilDefinition)) + if is_program: + assert is_local_view is None + is_local_view = False + else: + assert ( + is_local_view is not None + ), "The expression's context must be specified using `is_local_view`." - new_expr, extracted, ignored_children = extract_subexpression( - node, lambda subexpr, num_occurences: num_occurences > 1, self.uids + offset_provider = offset_provider or {} + node = itir_type_inference.infer( + node, offset_provider=offset_provider, allow_undeclared_symbols=not is_program ) + return cls().visit(node, is_local_view=is_local_view) + + def generic_visit(self, node, **kwargs): + if cpm.is_call_to("as_fieldop", node): + assert not kwargs.get("is_local_view") + is_local_view = cpm.is_call_to("as_fieldop", node) or kwargs.get("is_local_view") + + return super().generic_visit(node, **(kwargs | {"is_local_view": is_local_view})) + + def visit_FunCall(self, node: itir.FunCall, **kwargs): + is_local_view = kwargs["is_local_view"] + + if cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain")): + return node + + def predicate(subexpr: itir.Expr, num_occurences: int): + # note: be careful here with the syntatic context: the expression might be in local + # view, even though the syntactic context `node` is in field view. + # note: what is extracted is sketched in the docstring above. keep it updated. + if num_occurences > 1: + if is_local_view: + return True + else: + # only extract fields outside of `as_fieldop` + # `as_fieldop(...)(field_expr, field_expr)` + # -> `(λ(_cs_1) → as_fieldop(...)(_cs_1, _cs_1))(field_expr)` + assert isinstance(subexpr.type, ts.TypeSpec) + if all( + isinstance(stype, ts.FieldType) + for stype in type_info.primitive_constituents(subexpr.type) + ): + return True + return False + + new_expr, extracted, ignored_children = extract_subexpression(node, predicate, self.uids) if not extracted: - return self.generic_visit(node) + return self.generic_visit(node, **kwargs) # apply remapping - result = ir.FunCall( - fun=ir.Lambda(params=list(extracted.keys()), expr=new_expr), + result = itir.FunCall( + fun=itir.Lambda(params=list(extracted.keys()), expr=new_expr), args=list(extracted.values()), ) @@ -406,6 +467,6 @@ def visit_FunCall(self, node: ir.FunCall): # inside of subexpressions directly. This would require a different order of replacement # (from lower to higher level). if ignored_children: - return self.visit(result) + return self.visit(result, **kwargs) - return self.generic_visit(result) + return self.generic_visit(result, **kwargs) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 655fa2f4f5..acff274b92 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -17,6 +17,7 @@ from gt4py.eve import utils as eve_utils from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.transforms import fencil_to_program from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple from gt4py.next.iterator.transforms.constant_folding import ConstantFolding @@ -84,7 +85,7 @@ def apply_common_transforms( Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]] ] = None, symbolic_domain_sizes: Optional[dict[str, str]] = None, -) -> itir.FencilDefinition | FencilWithTemporaries | itir.Program: +) -> itir.Program: if isinstance(ir, itir.Program): # TODO(havogt): during refactoring to GTIR, we bypass transformations in case we already translated to itir.Program # (currently the case when using the roundtrip backend) @@ -180,6 +181,11 @@ def apply_common_transforms( ir = FuseMaps().visit(ir) ir = CollapseListGet().visit(ir) + assert isinstance(ir, (itir.FencilDefinition, FencilWithTemporaries)) + ir = fencil_to_program.FencilToProgram().apply( + ir + ) # FIXME[#1582](havogt): should be removed after refactoring to combined IR + if unroll_reduce: for _ in range(10): unrolled = UnrollReduce.apply(ir, offset_provider=offset_provider) @@ -197,12 +203,12 @@ def apply_common_transforms( ir = ScanEtaReduction().visit(ir) if common_subexpression_elimination: - ir = CommonSubexpressionElimination().visit(ir) + ir = CommonSubexpressionElimination.apply(ir, offset_provider=offset_provider) # type: ignore[type-var] # always an itir.Program ir = MergeLet().visit(ir) ir = InlineLambdas.apply( ir, opcount_preserving=True, force_inline_lambda_args=force_inline_lambda_args ) - assert isinstance(ir, (itir.FencilDefinition, FencilWithTemporaries)) + assert isinstance(ir, itir.Program) return ir diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index e6bee51c95..a58169858a 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -246,7 +246,9 @@ def __call__( return return_type_or_synthesizer -def _get_dimensions_from_offset_provider(offset_provider) -> dict[str, common.Dimension]: +def _get_dimensions_from_offset_provider( + offset_provider: common.OffsetProvider, +) -> dict[str, common.Dimension]: dimensions: dict[str, common.Dimension] = {} for offset_name, provider in offset_provider.items(): dimensions[offset_name] = common.Dimension( @@ -285,10 +287,15 @@ def type_synthesizer(*args, **kwargs): return type_synthesizer -class RemoveTypes(eve.NodeTranslator): - def visit_Node(self, node: itir.Node): +class SanitizeTypes(eve.NodeTranslator, eve.VisitorWithSymbolTableTrait): + def visit_Node(self, node: itir.Node, *, symtable: dict[str, itir.Node]) -> itir.Node: node = self.generic_visit(node) - if not isinstance(node, (itir.Literal, itir.Sym)): + # We only want to sanitize types that have been inferred previously such that we don't run + # into errors because a node has been reused in a pass, but has changed type. Undeclared + # symbols however only occur when visiting a subtree (e.g. in testing). Their types + # can be injected by populating their type attribute, which we want to preserve here. + is_undeclared_symbol = isinstance(node, itir.SymRef) and node.id not in symtable + if not is_undeclared_symbol and not isinstance(node, (itir.Literal, itir.Sym)): node.type = None return node @@ -380,8 +387,7 @@ def apply( # becomes invalid (e.g. the shift part of ``shift(...)(it)`` has a different type when used # on a different iterator). For now we just delete all types in case we are working an # parts of a program. - if not allow_undeclared_symbols: - node = RemoveTypes().visit(node) + node = SanitizeTypes().visit(node) if isinstance(node, (itir.FencilDefinition, itir.Program)): assert all(isinstance(param.type, ts.DataType) for param in node.params), ( diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index eff6b2f42a..2aa064fa15 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -116,15 +116,17 @@ def _(lhs: ts.ScalarType, rhs: ts.ScalarType) -> ts.ScalarType | ts.TupleType: @_register_builtin_type_synthesizer -def deref(it: it_ts.IteratorType) -> ts.DataType: +def deref(it: it_ts.IteratorType | ts.DeferredType) -> ts.DataType | ts.DeferredType: + if isinstance(it, ts.DeferredType): + return ts.DeferredType(constraint=None) assert isinstance(it, it_ts.IteratorType) assert _is_derefable_iterator_type(it) return it.element_type @_register_builtin_type_synthesizer -def can_deref(it: it_ts.IteratorType) -> ts.ScalarType: - assert isinstance(it, it_ts.IteratorType) +def can_deref(it: it_ts.IteratorType | ts.DeferredType) -> ts.ScalarType: + assert isinstance(it, ts.DeferredType) or isinstance(it, it_ts.IteratorType) # note: We don't check if the iterator is derefable here as the iterator only needs to # to have a valid position. Consider a nested reduction, e.g. # `reduce(plus, 0)(neighbors(V2Eₒ, (↑(λ(a) → reduce(plus, 0)(neighbors(E2Vₒ, a))))(it))` @@ -324,7 +326,11 @@ def applied_reduce(*args: it_ts.ListType, offset_provider: common.OffsetProvider @_register_builtin_type_synthesizer def shift(*offset_literals, offset_provider) -> TypeSynthesizer: @TypeSynthesizer - def apply_shift(it: it_ts.IteratorType) -> it_ts.IteratorType: + def apply_shift( + it: it_ts.IteratorType | ts.DeferredType, + ) -> it_ts.IteratorType | ts.DeferredType: + if isinstance(it, ts.DeferredType): + return ts.DeferredType(constraint=None) assert isinstance(it, it_ts.IteratorType) if it.position_dims == "unknown": # nothing to do here return it diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py index 30ef08a04a..edcc180f83 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py @@ -264,25 +264,23 @@ def visit_FunCall(self, node: gtfn_ir.FunCall, **kwargs: Any) -> gtfn_ir_common. self.imp_list_ir.append(InitStmt(lhs=gtfn_ir_common.Sym(id=f"{lam_idx}"), rhs=node)) return gtfn_ir_common.SymRef(id=f"{lam_idx}") if isinstance(node.fun, gtfn_ir.Lambda): + localized_symbols = {**kwargs["localized_symbols"]} # create a new scope + lam_idx = self.uids.sequential_id(prefix="lam") params = [self.visit(param, **kwargs) for param in node.fun.params] args = [self.visit(arg, **kwargs) for arg in node.args] for param, arg in zip(params, args): - if param.id in self.sym_table: - kwargs["localized_symbols"][param.id] = ( - f"{param.id}_{self.uids.sequential_id()}_local" - ) - self.imp_list_ir.append( - InitStmt( - lhs=gtfn_ir_common.Sym(id=kwargs["localized_symbols"][param.id]), - rhs=arg, - ) - ) - else: - self.imp_list_ir.append( - InitStmt(lhs=gtfn_ir_common.Sym(id=f"{param.id}"), rhs=arg) + localized_symbols[param.id] = new_symbol = ( + f"{param.id}_{self.uids.sequential_id()}_local" + ) + self.imp_list_ir.append( + InitStmt( + lhs=gtfn_ir_common.Sym(id=new_symbol), + rhs=arg, ) - expr = self.visit(node.fun.expr, **kwargs) + ) + new_kwargs = {**kwargs, "localized_symbols": localized_symbols} + expr = self.visit(node.fun.expr, **new_kwargs) self.imp_list_ir.append(InitStmt(lhs=gtfn_ir_common.Sym(id=f"{lam_idx}"), rhs=expr)) return gtfn_ir_common.SymRef(id=f"{lam_idx}") if _is_reduce(node): diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index 384d74a6c2..b3bec1e1f7 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -27,7 +27,7 @@ from gt4py.next.common import Connectivity, Dimension from gt4py.next.ffront import fbuiltins from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.transforms import LiftMode, fencil_to_program, global_tmps, pass_manager +from gt4py.next.iterator.transforms import LiftMode, fencil_to_program, pass_manager from gt4py.next.otf import languages, stages, step_types, workflow from gt4py.next.otf.binding import cpp_interface, interface from gt4py.next.program_processors.codegens.gtfn.codegen import GTFNCodegen, GTFNIMCodegen @@ -181,9 +181,11 @@ def _preprocess_program( self, program: itir.FencilDefinition, offset_provider: dict[str, Connectivity | Dimension], - ) -> itir.FencilDefinition | global_tmps.FencilWithTemporaries | itir.Program: + ) -> itir.Program: if not self.enable_itir_transforms: - return program + return fencil_to_program.FencilToProgram().apply( + program + ) # FIXME[#1582](tehrengruber): should be removed after refactoring to combined IR apply_common_transforms = functools.partial( pass_manager.apply_common_transforms, @@ -216,11 +218,8 @@ def generate_stencil_source( column_axis: Optional[common.Dimension], ) -> str: new_program = self._preprocess_program(program, offset_provider) - program_itir = fencil_to_program.FencilToProgram().apply( - new_program - ) # TODO(havogt): should be removed after refactoring to combined IR gtfn_ir = GTFN_lowering.apply( - program_itir, offset_provider=offset_provider, column_axis=column_axis + new_program, offset_provider=offset_provider, column_axis=column_axis ) if self.use_imperative_backend: diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 7147182fe8..ba479c1e63 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -87,18 +87,12 @@ def preprocess_program( node = itir_type_inference.infer(node, offset_provider=offset_provider) - if isinstance(node, itir_transforms.global_tmps.FencilWithTemporaries): - fencil_definition = node.fencil - tmps = node.tmps - - elif isinstance(node, itir.FencilDefinition): + if isinstance(node, itir.Program): fencil_definition = node - tmps = [] - + tmps = node.declarations + assert all(isinstance(tmp, itir.Temporary) for tmp in tmps) else: - raise TypeError( - f"Expected 'FencilDefinition' or 'FencilWithTemporaries', got '{type(program).__name__}'." - ) + raise TypeError(f"Expected 'Program', got '{type(node).__name__}'.") return fencil_definition, tmps diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 50ff92e4c1..13d008c617 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -28,7 +28,6 @@ from gt4py.eve.codegen import FormatTemplate as as_fmt, MakoTemplate as as_mako from gt4py.next import allocators as next_allocators, backend as next_backend, common, config from gt4py.next.iterator import embedded, ir as itir, transforms as itir_transforms -from gt4py.next.iterator.transforms import fencil_to_program from gt4py.next.otf import stages, workflow from gt4py.next.program_processors import modular_executor, processor_interface as ppi from gt4py.next.type_system import type_specifications as ts @@ -124,8 +123,6 @@ def fencil_generator( ir, lift_mode=lift_mode, offset_provider=offset_provider ) - ir = fencil_to_program.FencilToProgram.apply(ir) - program = EmbeddedDSL.apply(ir) # format output in debug mode for better debuggability diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py index 9807982797..39875a3169 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py @@ -113,4 +113,4 @@ def test_temporary_symbols(testee, mesh_descriptor): params = ["num_vertices", "num_edges", "num_cells"] for param in params: - assert any([param == str(p) for p in itir_with_tmp.fencil.params]) + assert any([param == str(p) for p in itir_with_tmp.params]) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py index a2d0a170a0..a124ff086f 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py @@ -11,17 +11,25 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later +import pytest import textwrap from gt4py.eve.utils import UIDGenerator +from gt4py.next import common from gt4py.next.iterator import ir from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.type_system import type_specifications as ts from gt4py.next.iterator.transforms.cse import ( CommonSubexpressionElimination as CSE, extract_subexpression, ) +@pytest.fixture +def offset_provider(request): + return {"I": common.Dimension("I", kind=common.DimensionKind.HORIZONTAL)} + + def test_trivial(): common = ir.FunCall(fun=ir.SymRef(id="plus"), args=[ir.SymRef(id="x"), ir.SymRef(id="y")]) testee = ir.FunCall(fun=ir.SymRef(id="plus"), args=[common, common]) @@ -34,7 +42,7 @@ def test_trivial(): ), args=[common], ) - actual = CSE().visit(testee) + actual = CSE.apply(testee, is_local_view=True) assert actual == expected @@ -42,7 +50,7 @@ def test_lambda_capture(): common = ir.FunCall(fun=ir.SymRef(id="plus"), args=[ir.SymRef(id="x"), ir.SymRef(id="y")]) testee = ir.FunCall(fun=ir.Lambda(params=[ir.Sym(id="x")], expr=common), args=[common]) expected = testee - actual = CSE().visit(testee) + actual = CSE.apply(testee, is_local_view=True) assert actual == expected @@ -50,7 +58,7 @@ def test_lambda_no_capture(): common = im.plus("x", "y") testee = im.call(im.lambda_("z")(im.plus("x", "y")))(im.plus("x", "y")) expected = im.let("_cs_1", common)("_cs_1") - actual = CSE().visit(testee) + actual = CSE.apply(testee, is_local_view=True) assert actual == expected @@ -62,7 +70,7 @@ def common_expr(): testee = im.call(im.lambda_("x", "y")(common_expr()))(common_expr(), common_expr()) # (λ(_cs_1) → _cs_1 + _cs_1)(x + y) expected = im.let("_cs_1", common_expr())(im.plus("_cs_1", "_cs_1")) - actual = CSE().visit(testee) + actual = CSE.apply(testee, is_local_view=True) assert actual == expected @@ -82,7 +90,7 @@ def common_expr(): ) )(common_expr()) ) - actual = CSE().visit(testee) + actual = CSE.apply(testee, is_local_view=True) assert actual == expected @@ -96,7 +104,7 @@ def common_expr(): ) # (λ(_cs_1) → _cs_1(2) + _cs_1(3))(λ(a) → a + 1) expected = im.let("_cs_1", common_expr())(im.plus(im.call("_cs_1")(2), im.call("_cs_1")(3))) - actual = CSE().visit(testee) + actual = CSE.apply(testee, is_local_view=True) assert actual == expected @@ -112,7 +120,7 @@ def common_expr(): expected = im.let("_cs_1", common_expr())( im.let("_cs_2", im.call("_cs_1")(2))(im.plus("_cs_2", "_cs_2")) ) - actual = CSE().visit(testee) + actual = CSE.apply(testee, is_local_view=True) assert actual == expected @@ -136,11 +144,11 @@ def common_expr(): ) ) ) - actual = CSE().visit(testee) + actual = CSE.apply(testee, is_local_view=True) assert actual == expected -def test_if_can_deref_no_extraction(): +def test_if_can_deref_no_extraction(offset_provider): # Test that a subexpression only occurring in one branch of an `if_` is not moved outside the # if statement. A case using `can_deref` is used here as it is common. @@ -160,11 +168,11 @@ def test_if_can_deref_no_extraction(): ) ) - actual = CSE().visit(testee) + actual = CSE.apply(testee, offset_provider=offset_provider, is_local_view=True) assert actual == expected -def test_if_can_deref_eligible_extraction(): +def test_if_can_deref_eligible_extraction(offset_provider): # Test that a subexpression only occurring in both branches of an `if_` is moved outside the # if statement. A case using `can_deref` is used here as it is common. @@ -181,11 +189,11 @@ def test_if_can_deref_eligible_extraction(): ) ) - actual = CSE().visit(testee) + actual = CSE.apply(testee, offset_provider=offset_provider, is_local_view=True) assert actual == expected -def test_if_eligible_extraction(): +def test_if_eligible_extraction(offset_provider): # Test that a subexpression only occurring in the condition of an `if_` is moved outside the # if statement. @@ -194,7 +202,7 @@ def test_if_eligible_extraction(): # (λ(_cs_1) → if _cs_1 ∧ _cs_1 then c else d)(a ∧ b) expected = im.let("_cs_1", im.and_("a", "b"))(im.if_(im.and_("_cs_1", "_cs_1"), "c", "d")) - actual = CSE().visit(testee) + actual = CSE.apply(testee, offset_provider=offset_provider, is_local_view=True) assert actual == expected @@ -255,3 +263,42 @@ def render_stmt_form(assignments: list[tuple[str, ir.Expr]], return_expr: ir.Exp actual = render_stmt_form(*convert_to_assignment_stmt_form(testee)) assert actual == expected + + +def test_no_extraction_outside_asfieldop(): + plus_fieldop = im.as_fieldop( + im.lambda_("x", "y")(im.plus(im.deref("x"), im.deref("y"))), im.call("cartesian_domain")() + ) + identity_fieldop = im.as_fieldop(im.lambda_("x")(im.deref("x")), im.call("cartesian_domain")()) + + field_type = ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32)) + # as_fieldop(λ(x, y) → ·x + ·y, cartesian_domain())( + # as_fieldop(λ(x) → ·x, cartesian_domain())(a), as_fieldop(λ(x) → ·x, cartesian_domain())(b) + # ) + testee = plus_fieldop( + identity_fieldop(im.ref("a", field_type)), identity_fieldop(im.ref("b", field_type)) + ) + + actual = CSE.apply(testee, is_local_view=False) + assert actual == testee + + +def test_field_extraction_outside_asfieldop(): + plus_fieldop = im.as_fieldop( + im.lambda_("x", "y")(im.plus(im.deref("x"), im.deref("y"))), im.call("cartesian_domain")() + ) + identity_fieldop = im.as_fieldop(im.lambda_("x")(im.deref("x")), im.call("cartesian_domain")()) + + # as_fieldop(λ(x, y) → ·x + ·y, cartesian_domain())( + # as_fieldop(λ(x) → ·x, cartesian_domain())(a), as_fieldop(λ(x) → ·x, cartesian_domain())(a) + # ) + field = im.ref("a", ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32))) + testee = plus_fieldop(identity_fieldop(field), identity_fieldop(field)) + + # (λ(_cs_1) → as_fieldop(λ(x, y) → ·x + ·y, cartesian_domain())(_cs_1, _cs_1))( + # as_fieldop(λ(x) → ·x, cartesian_domain())(a) + # ) + expected = im.let("_cs_1", identity_fieldop(field))(plus_fieldop("_cs_1", "_cs_1")) + + actual = CSE.apply(testee, is_local_view=False) + assert actual == expected