diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py
index 40bfc0ab75..96b585b1f4 100644
--- a/src/gt4py/next/iterator/ir_utils/ir_makers.py
+++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py
@@ -13,7 +13,7 @@
# SPDX-License-Identifier: GPL-3.0-or-later
import typing
-from typing import Callable, Iterable, Union
+from typing import Callable, Iterable, Optional, Union
from gt4py._core import definitions as core_defs
from gt4py.next.iterator import ir as itir
@@ -401,3 +401,26 @@ def _impl(*its: itir.Expr) -> itir.FunCall:
def map_(op):
"""Create a `map_` call."""
return call(call("map_")(op))
+
+
+def as_fieldop(expr: itir.Expr, domain: Optional[itir.FunCall] = None) -> call:
+ """
+ Create an `as_fieldop` call.
+
+ Examples
+ --------
+ >>> str(as_fieldop(lambda_("it1", "it2")(plus(deref("it1"), deref("it2"))))("field1", "field2"))
+ '(⇑(λ(it1, it2) → ·it1 + ·it2))(field1, field2)'
+ """
+ return call(
+ call("as_fieldop")(
+ *(
+ (
+ expr,
+ domain,
+ )
+ if domain
+ else (expr,)
+ )
+ )
+ )
diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py
index 32714232a6..acd896d753 100644
--- a/src/gt4py/next/iterator/transforms/cse.py
+++ b/src/gt4py/next/iterator/transforms/cse.py
@@ -11,11 +11,14 @@
# distribution for a copy of the license or check .
#
# SPDX-License-Identifier: GPL-3.0-or-later
+
+from __future__ import annotations
+
import dataclasses
import functools
import math
import operator
-import typing
+from typing import Callable, Iterable, TypeVar, Union, cast
from gt4py.eve import (
NodeTranslator,
@@ -25,32 +28,36 @@
VisitorWithSymbolTableTrait,
)
from gt4py.eve.utils import UIDGenerator
-from gt4py.next.iterator import ir
+from gt4py.next import common
+from gt4py.next.iterator import ir as itir
+from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm
from gt4py.next.iterator.transforms.inline_lambdas import inline_lambda
+from gt4py.next.iterator.type_system import inference as itir_type_inference
+from gt4py.next.type_system import type_info, type_specifications as ts
@dataclasses.dataclass
class _NodeReplacer(PreserveLocationVisitor, NodeTranslator):
PRESERVED_ANNEX_ATTRS = ("type",)
- expr_map: dict[int, ir.SymRef]
+ expr_map: dict[int, itir.SymRef]
- def visit_Expr(self, node: ir.Node) -> ir.Node:
+ def visit_Expr(self, node: itir.Node) -> itir.Node:
if id(node) in self.expr_map:
return self.expr_map[id(node)]
return self.generic_visit(node)
- def visit_FunCall(self, node: ir.FunCall) -> ir.Node:
- node = typing.cast(ir.FunCall, self.visit_Expr(node))
+ def visit_FunCall(self, node: itir.FunCall) -> itir.Node:
+ node = cast(itir.FunCall, self.visit_Expr(node))
# If we encounter an expression like:
# (λ(_cs_1) → (λ(a) → a+a)(_cs_1))(outer_expr)
# (non-recursively) inline the lambda to obtain:
# (λ(_cs_1) → _cs_1+_cs_1)(outer_expr)
# This allows identifying more common subexpressions later on
- if isinstance(node, ir.FunCall) and isinstance(node.fun, ir.Lambda):
+ if isinstance(node, itir.FunCall) and isinstance(node.fun, itir.Lambda):
eligible_params = []
for arg in node.args:
- eligible_params.append(isinstance(arg, ir.SymRef) and arg.id.startswith("_cs"))
+ eligible_params.append(isinstance(arg, itir.SymRef) and arg.id.startswith("_cs"))
if any(eligible_params):
# note: the inline is opcount preserving anyway so avoid the additional
# effort in the inliner by disabling opcount preservation.
@@ -60,18 +67,18 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.Node:
return node
-def _is_collectable_expr(node: ir.Node) -> bool:
- if isinstance(node, ir.FunCall):
+def _is_collectable_expr(node: itir.Node) -> bool:
+ if isinstance(node, itir.FunCall):
# do not collect (and thus deduplicate in CSE) shift(offsets…) calls. Node must still be
# visited, to ensure symbol dependencies are recognized correctly.
# do also not collect reduce nodes if they are left in the it at this point, this may lead to
# conceptual problems (other parts of the tool chain rely on the arguments being present directly
# on the reduce FunCall node (connectivity deduction)), as well as problems with the imperative backend
# backend (single pass eager depth first visit approach)
- if isinstance(node.fun, ir.SymRef) and node.fun.id in ["lift", "shift", "reduce"]:
+ if isinstance(node.fun, itir.SymRef) and node.fun.id in ["lift", "shift", "reduce"]:
return False
return True
- elif isinstance(node, ir.Lambda):
+ elif isinstance(node, itir.Lambda):
return True
return False
@@ -92,7 +99,7 @@ class SubexpressionData:
class State:
#: A dictionary mapping a node to a list of node ids with equal hash and some additional
#: information. See `SubexpressionData` for more information.
- subexprs: dict[ir.Node, "CollectSubexpressions.SubexpressionData"] = dataclasses.field(
+ subexprs: dict[itir.Node, CollectSubexpressions.SubexpressionData] = dataclasses.field(
default_factory=dict
)
# TODO(tehrengruber): Revisit if this makes sense or if we can just recompute the collected
@@ -102,7 +109,7 @@ class State:
#: The ids of all nodes declaring a symbol which are referenced (using a `SymRef`)
used_symbol_ids: set[int] = dataclasses.field(default_factory=set)
- def remove_subexprs(self, nodes: typing.Iterable[ir.Node]) -> None:
+ def remove_subexprs(self, nodes: Iterable[itir.Node]) -> None:
node_ids_to_remove: set[int] = set()
for node in nodes:
subexpr_data = self.subexprs.pop(node, None)
@@ -114,22 +121,22 @@ def remove_subexprs(self, nodes: typing.Iterable[ir.Node]) -> None:
collected_child_node_ids -= node_ids_to_remove
@classmethod
- def apply(cls, node: ir.Node) -> dict[ir.Node, list[tuple[int, set[int]]]]:
+ def apply(cls, node: itir.Node) -> dict[itir.Node, list[tuple[int, set[int]]]]:
state = cls.State()
obj = cls()
obj.visit(node, state=state, depth=-1)
# Return subexpression such that the nodes closer to the root come first and skip the root
# node itself.
- subexprs_sorted: list[tuple[ir.Node, "CollectSubexpressions.SubexpressionData"]] = sorted(
+ subexprs_sorted: list[tuple[itir.Node, CollectSubexpressions.SubexpressionData]] = sorted(
state.subexprs.items(), key=lambda el: el[1].max_depth
)
return {k: v.subexprs for k, v in subexprs_sorted if k is not node}
- def generic_visit(self, *args, **kwargs):
+ def generic_visit(self, node, **kwargs):
depth = kwargs.pop("depth")
- return super().generic_visit(*args, depth=depth + 1, **kwargs)
+ return super().generic_visit(node, depth=depth + 1, **kwargs)
- def visit(self, node: ir.Node, **kwargs) -> None: # type: ignore[override] # supertype accepts any node, but we want to be more specific here.
+ def visit(self, node: itir.Node, **kwargs) -> None: # type: ignore[override] # supertype accepts any node, but we want to be more specific here.
if not isinstance(node, SymbolTableTrait) and not _is_collectable_expr(node):
return super().visit(node, **kwargs)
@@ -141,7 +148,7 @@ def visit(self, node: ir.Node, **kwargs) -> None: # type: ignore[override] # su
# Special handling of `if_(condition, true_branch, false_branch)` like expressions that
# avoids extracting subexpressions unless they are used in either the condition or both
# branches.
- if isinstance(node, ir.FunCall) and node.fun == ir.SymRef(id="if_"):
+ if isinstance(node, itir.FunCall) and node.fun == itir.SymRef(id="if_"):
assert len(node.args) == 3
# collect subexpressions for all arguments to the `if_`
arg_states = [self.State() for _ in node.args]
@@ -157,7 +164,7 @@ def visit(self, node: ir.Node, **kwargs) -> None: # type: ignore[override] # su
arg_state.remove_subexprs(arg_state.subexprs.keys() - eligible_subexprs)
# merge the states of the three arguments
- subexprs: dict[ir.Node, CollectSubexpressions.SubexpressionData] = {}
+ subexprs: dict[itir.Node, CollectSubexpressions.SubexpressionData] = {}
for state in arg_states:
for subexpr, data in state.subexprs.items():
merged_data = subexprs.setdefault(subexpr, self.SubexpressionData())
@@ -203,19 +210,19 @@ def visit(self, node: ir.Node, **kwargs) -> None: # type: ignore[override] # su
parent_state.collected_child_node_ids.update(collected_child_node_ids)
def visit_SymRef(
- self, node: ir.SymRef, *, symtable: dict[str, ir.Node], state: State, **kwargs
+ self, node: itir.SymRef, *, symtable: dict[str, itir.Node], state: State, **kwargs
) -> None:
if node.id in symtable: # root symbol otherwise
state.used_symbol_ids.add(id(symtable[node.id]))
def extract_subexpression(
- node: ir.Expr,
- predicate: typing.Callable[[ir.Expr, int], bool],
+ node: itir.Expr,
+ predicate: Callable[[itir.Expr, int], bool],
uid_generator: UIDGenerator,
once_only: bool = False,
deepest_expr_first: bool = False,
-) -> tuple[ir.Expr, typing.Union[dict[ir.Sym, ir.Expr], None], bool]:
+) -> tuple[itir.Expr, Union[dict[itir.Sym, itir.Expr], None], bool]:
"""
Given an expression extract all subexprs and return a new expr with the subexprs replaced.
@@ -312,20 +319,20 @@ def extract_subexpression(
)
ignored_children = False
- extracted = dict[ir.Sym, ir.Expr]()
+ extracted = dict[itir.Sym, itir.Expr]()
# collect expressions
subexprs = CollectSubexpressions.apply(node)
# collect multiple occurrences and map them to fresh symbols
- expr_map = dict[int, ir.SymRef]()
+ expr_map = dict[int, itir.SymRef]()
ignored_ids = set()
for expr, subexpr_entry in (
subexprs.items() if not deepest_expr_first else reversed(subexprs.items())
):
# just to make mypy happy when calling the predicate. Every subnode and hence subexpression
# is an expr anyway.
- assert isinstance(expr, ir.Expr)
+ assert isinstance(expr, itir.Expr)
if not predicate(expr, len(subexpr_entry)):
continue
@@ -345,8 +352,8 @@ def extract_subexpression(
continue
expr_id = uid_generator.sequential_id()
- extracted[ir.Sym(id=expr_id)] = expr
- expr_ref = ir.SymRef(id=expr_id)
+ extracted[itir.Sym(id=expr_id)] = expr
+ expr_ref = itir.SymRef(id=expr_id)
for id_ in eligible_ids:
expr_map[id_] = expr_ref
@@ -359,17 +366,26 @@ def extract_subexpression(
return _NodeReplacer(expr_map).visit(node), extracted, ignored_children
+ProgramOrExpr = TypeVar("ProgramOrExpr", bound=itir.Program | itir.FencilDefinition | itir.Expr)
+
+
@dataclasses.dataclass(frozen=True)
class CommonSubexpressionElimination(PreserveLocationVisitor, NodeTranslator):
"""
Perform common subexpression elimination.
Examples:
- >>> x = ir.SymRef(id="x")
- >>> plus = lambda a, b: ir.FunCall(fun=ir.SymRef(id=("plus")), args=[a, b])
+ >>> x = itir.SymRef(id="x")
+ >>> plus = lambda a, b: itir.FunCall(fun=itir.SymRef(id=("plus")), args=[a, b])
>>> expr = plus(plus(x, x), plus(x, x))
- >>> print(CommonSubexpressionElimination().visit(expr))
+ >>> print(CommonSubexpressionElimination.apply(expr, is_local_view=True))
(λ(_cs_1) → _cs_1 + _cs_1)(x + x)
+
+ The pass visits the tree top-down starting from the root node, e.g. an itir.Program.
+ For each node we extract (eligible) subexpressions occuring more than once using
+ :ref:`extract_subexpression`. In field-view context we only extract expression when they are
+ fields (or composites thereof), in local view everything is eligible. Since the visit is
+ top-down, extracted expressions always land up in the outermost scope they can appear in.
"""
# we use one UID generator per instance such that the generated ids are
@@ -380,23 +396,68 @@ class CommonSubexpressionElimination(PreserveLocationVisitor, NodeTranslator):
collect_all: bool = dataclasses.field(default=False)
- def visit_FunCall(self, node: ir.FunCall):
- if isinstance(node.fun, ir.SymRef) and node.fun.id in [
- "cartesian_domain",
- "unstructured_domain",
- ]:
- return node
+ @classmethod
+ def apply(
+ cls,
+ node: ProgramOrExpr,
+ is_local_view: bool | None = None,
+ offset_provider: common.OffsetProvider | None = None,
+ ) -> ProgramOrExpr:
+ is_program = isinstance(node, (itir.Program, itir.FencilDefinition))
+ if is_program:
+ assert is_local_view is None
+ is_local_view = False
+ else:
+ assert (
+ is_local_view is not None
+ ), "The expression's context must be specified using `is_local_view`."
- new_expr, extracted, ignored_children = extract_subexpression(
- node, lambda subexpr, num_occurences: num_occurences > 1, self.uids
+ offset_provider = offset_provider or {}
+ node = itir_type_inference.infer(
+ node, offset_provider=offset_provider, allow_undeclared_symbols=not is_program
)
+ return cls().visit(node, is_local_view=is_local_view)
+
+ def generic_visit(self, node, **kwargs):
+ if cpm.is_call_to("as_fieldop", node):
+ assert not kwargs.get("is_local_view")
+ is_local_view = cpm.is_call_to("as_fieldop", node) or kwargs.get("is_local_view")
+
+ return super().generic_visit(node, **(kwargs | {"is_local_view": is_local_view}))
+
+ def visit_FunCall(self, node: itir.FunCall, **kwargs):
+ is_local_view = kwargs["is_local_view"]
+
+ if cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain")):
+ return node
+
+ def predicate(subexpr: itir.Expr, num_occurences: int):
+ # note: be careful here with the syntatic context: the expression might be in local
+ # view, even though the syntactic context `node` is in field view.
+ # note: what is extracted is sketched in the docstring above. keep it updated.
+ if num_occurences > 1:
+ if is_local_view:
+ return True
+ else:
+ # only extract fields outside of `as_fieldop`
+ # `as_fieldop(...)(field_expr, field_expr)`
+ # -> `(λ(_cs_1) → as_fieldop(...)(_cs_1, _cs_1))(field_expr)`
+ assert isinstance(subexpr.type, ts.TypeSpec)
+ if all(
+ isinstance(stype, ts.FieldType)
+ for stype in type_info.primitive_constituents(subexpr.type)
+ ):
+ return True
+ return False
+
+ new_expr, extracted, ignored_children = extract_subexpression(node, predicate, self.uids)
if not extracted:
- return self.generic_visit(node)
+ return self.generic_visit(node, **kwargs)
# apply remapping
- result = ir.FunCall(
- fun=ir.Lambda(params=list(extracted.keys()), expr=new_expr),
+ result = itir.FunCall(
+ fun=itir.Lambda(params=list(extracted.keys()), expr=new_expr),
args=list(extracted.values()),
)
@@ -406,6 +467,6 @@ def visit_FunCall(self, node: ir.FunCall):
# inside of subexpressions directly. This would require a different order of replacement
# (from lower to higher level).
if ignored_children:
- return self.visit(result)
+ return self.visit(result, **kwargs)
- return self.generic_visit(result)
+ return self.generic_visit(result, **kwargs)
diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py
index 655fa2f4f5..acff274b92 100644
--- a/src/gt4py/next/iterator/transforms/pass_manager.py
+++ b/src/gt4py/next/iterator/transforms/pass_manager.py
@@ -17,6 +17,7 @@
from gt4py.eve import utils as eve_utils
from gt4py.next.iterator import ir as itir
+from gt4py.next.iterator.transforms import fencil_to_program
from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet
from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple
from gt4py.next.iterator.transforms.constant_folding import ConstantFolding
@@ -84,7 +85,7 @@ def apply_common_transforms(
Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]]
] = None,
symbolic_domain_sizes: Optional[dict[str, str]] = None,
-) -> itir.FencilDefinition | FencilWithTemporaries | itir.Program:
+) -> itir.Program:
if isinstance(ir, itir.Program):
# TODO(havogt): during refactoring to GTIR, we bypass transformations in case we already translated to itir.Program
# (currently the case when using the roundtrip backend)
@@ -180,6 +181,11 @@ def apply_common_transforms(
ir = FuseMaps().visit(ir)
ir = CollapseListGet().visit(ir)
+ assert isinstance(ir, (itir.FencilDefinition, FencilWithTemporaries))
+ ir = fencil_to_program.FencilToProgram().apply(
+ ir
+ ) # FIXME[#1582](havogt): should be removed after refactoring to combined IR
+
if unroll_reduce:
for _ in range(10):
unrolled = UnrollReduce.apply(ir, offset_provider=offset_provider)
@@ -197,12 +203,12 @@ def apply_common_transforms(
ir = ScanEtaReduction().visit(ir)
if common_subexpression_elimination:
- ir = CommonSubexpressionElimination().visit(ir)
+ ir = CommonSubexpressionElimination.apply(ir, offset_provider=offset_provider) # type: ignore[type-var] # always an itir.Program
ir = MergeLet().visit(ir)
ir = InlineLambdas.apply(
ir, opcount_preserving=True, force_inline_lambda_args=force_inline_lambda_args
)
- assert isinstance(ir, (itir.FencilDefinition, FencilWithTemporaries))
+ assert isinstance(ir, itir.Program)
return ir
diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py
index e6bee51c95..a58169858a 100644
--- a/src/gt4py/next/iterator/type_system/inference.py
+++ b/src/gt4py/next/iterator/type_system/inference.py
@@ -246,7 +246,9 @@ def __call__(
return return_type_or_synthesizer
-def _get_dimensions_from_offset_provider(offset_provider) -> dict[str, common.Dimension]:
+def _get_dimensions_from_offset_provider(
+ offset_provider: common.OffsetProvider,
+) -> dict[str, common.Dimension]:
dimensions: dict[str, common.Dimension] = {}
for offset_name, provider in offset_provider.items():
dimensions[offset_name] = common.Dimension(
@@ -285,10 +287,15 @@ def type_synthesizer(*args, **kwargs):
return type_synthesizer
-class RemoveTypes(eve.NodeTranslator):
- def visit_Node(self, node: itir.Node):
+class SanitizeTypes(eve.NodeTranslator, eve.VisitorWithSymbolTableTrait):
+ def visit_Node(self, node: itir.Node, *, symtable: dict[str, itir.Node]) -> itir.Node:
node = self.generic_visit(node)
- if not isinstance(node, (itir.Literal, itir.Sym)):
+ # We only want to sanitize types that have been inferred previously such that we don't run
+ # into errors because a node has been reused in a pass, but has changed type. Undeclared
+ # symbols however only occur when visiting a subtree (e.g. in testing). Their types
+ # can be injected by populating their type attribute, which we want to preserve here.
+ is_undeclared_symbol = isinstance(node, itir.SymRef) and node.id not in symtable
+ if not is_undeclared_symbol and not isinstance(node, (itir.Literal, itir.Sym)):
node.type = None
return node
@@ -380,8 +387,7 @@ def apply(
# becomes invalid (e.g. the shift part of ``shift(...)(it)`` has a different type when used
# on a different iterator). For now we just delete all types in case we are working an
# parts of a program.
- if not allow_undeclared_symbols:
- node = RemoveTypes().visit(node)
+ node = SanitizeTypes().visit(node)
if isinstance(node, (itir.FencilDefinition, itir.Program)):
assert all(isinstance(param.type, ts.DataType) for param in node.params), (
diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py
index eff6b2f42a..2aa064fa15 100644
--- a/src/gt4py/next/iterator/type_system/type_synthesizer.py
+++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py
@@ -116,15 +116,17 @@ def _(lhs: ts.ScalarType, rhs: ts.ScalarType) -> ts.ScalarType | ts.TupleType:
@_register_builtin_type_synthesizer
-def deref(it: it_ts.IteratorType) -> ts.DataType:
+def deref(it: it_ts.IteratorType | ts.DeferredType) -> ts.DataType | ts.DeferredType:
+ if isinstance(it, ts.DeferredType):
+ return ts.DeferredType(constraint=None)
assert isinstance(it, it_ts.IteratorType)
assert _is_derefable_iterator_type(it)
return it.element_type
@_register_builtin_type_synthesizer
-def can_deref(it: it_ts.IteratorType) -> ts.ScalarType:
- assert isinstance(it, it_ts.IteratorType)
+def can_deref(it: it_ts.IteratorType | ts.DeferredType) -> ts.ScalarType:
+ assert isinstance(it, ts.DeferredType) or isinstance(it, it_ts.IteratorType)
# note: We don't check if the iterator is derefable here as the iterator only needs to
# to have a valid position. Consider a nested reduction, e.g.
# `reduce(plus, 0)(neighbors(V2Eₒ, (↑(λ(a) → reduce(plus, 0)(neighbors(E2Vₒ, a))))(it))`
@@ -324,7 +326,11 @@ def applied_reduce(*args: it_ts.ListType, offset_provider: common.OffsetProvider
@_register_builtin_type_synthesizer
def shift(*offset_literals, offset_provider) -> TypeSynthesizer:
@TypeSynthesizer
- def apply_shift(it: it_ts.IteratorType) -> it_ts.IteratorType:
+ def apply_shift(
+ it: it_ts.IteratorType | ts.DeferredType,
+ ) -> it_ts.IteratorType | ts.DeferredType:
+ if isinstance(it, ts.DeferredType):
+ return ts.DeferredType(constraint=None)
assert isinstance(it, it_ts.IteratorType)
if it.position_dims == "unknown": # nothing to do here
return it
diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py
index 30ef08a04a..edcc180f83 100644
--- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py
+++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py
@@ -264,25 +264,23 @@ def visit_FunCall(self, node: gtfn_ir.FunCall, **kwargs: Any) -> gtfn_ir_common.
self.imp_list_ir.append(InitStmt(lhs=gtfn_ir_common.Sym(id=f"{lam_idx}"), rhs=node))
return gtfn_ir_common.SymRef(id=f"{lam_idx}")
if isinstance(node.fun, gtfn_ir.Lambda):
+ localized_symbols = {**kwargs["localized_symbols"]} # create a new scope
+
lam_idx = self.uids.sequential_id(prefix="lam")
params = [self.visit(param, **kwargs) for param in node.fun.params]
args = [self.visit(arg, **kwargs) for arg in node.args]
for param, arg in zip(params, args):
- if param.id in self.sym_table:
- kwargs["localized_symbols"][param.id] = (
- f"{param.id}_{self.uids.sequential_id()}_local"
- )
- self.imp_list_ir.append(
- InitStmt(
- lhs=gtfn_ir_common.Sym(id=kwargs["localized_symbols"][param.id]),
- rhs=arg,
- )
- )
- else:
- self.imp_list_ir.append(
- InitStmt(lhs=gtfn_ir_common.Sym(id=f"{param.id}"), rhs=arg)
+ localized_symbols[param.id] = new_symbol = (
+ f"{param.id}_{self.uids.sequential_id()}_local"
+ )
+ self.imp_list_ir.append(
+ InitStmt(
+ lhs=gtfn_ir_common.Sym(id=new_symbol),
+ rhs=arg,
)
- expr = self.visit(node.fun.expr, **kwargs)
+ )
+ new_kwargs = {**kwargs, "localized_symbols": localized_symbols}
+ expr = self.visit(node.fun.expr, **new_kwargs)
self.imp_list_ir.append(InitStmt(lhs=gtfn_ir_common.Sym(id=f"{lam_idx}"), rhs=expr))
return gtfn_ir_common.SymRef(id=f"{lam_idx}")
if _is_reduce(node):
diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py
index 384d74a6c2..b3bec1e1f7 100644
--- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py
+++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py
@@ -27,7 +27,7 @@
from gt4py.next.common import Connectivity, Dimension
from gt4py.next.ffront import fbuiltins
from gt4py.next.iterator import ir as itir
-from gt4py.next.iterator.transforms import LiftMode, fencil_to_program, global_tmps, pass_manager
+from gt4py.next.iterator.transforms import LiftMode, fencil_to_program, pass_manager
from gt4py.next.otf import languages, stages, step_types, workflow
from gt4py.next.otf.binding import cpp_interface, interface
from gt4py.next.program_processors.codegens.gtfn.codegen import GTFNCodegen, GTFNIMCodegen
@@ -181,9 +181,11 @@ def _preprocess_program(
self,
program: itir.FencilDefinition,
offset_provider: dict[str, Connectivity | Dimension],
- ) -> itir.FencilDefinition | global_tmps.FencilWithTemporaries | itir.Program:
+ ) -> itir.Program:
if not self.enable_itir_transforms:
- return program
+ return fencil_to_program.FencilToProgram().apply(
+ program
+ ) # FIXME[#1582](tehrengruber): should be removed after refactoring to combined IR
apply_common_transforms = functools.partial(
pass_manager.apply_common_transforms,
@@ -216,11 +218,8 @@ def generate_stencil_source(
column_axis: Optional[common.Dimension],
) -> str:
new_program = self._preprocess_program(program, offset_provider)
- program_itir = fencil_to_program.FencilToProgram().apply(
- new_program
- ) # TODO(havogt): should be removed after refactoring to combined IR
gtfn_ir = GTFN_lowering.apply(
- program_itir, offset_provider=offset_provider, column_axis=column_axis
+ new_program, offset_provider=offset_provider, column_axis=column_axis
)
if self.use_imperative_backend:
diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py
index 7147182fe8..ba479c1e63 100644
--- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py
+++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py
@@ -87,18 +87,12 @@ def preprocess_program(
node = itir_type_inference.infer(node, offset_provider=offset_provider)
- if isinstance(node, itir_transforms.global_tmps.FencilWithTemporaries):
- fencil_definition = node.fencil
- tmps = node.tmps
-
- elif isinstance(node, itir.FencilDefinition):
+ if isinstance(node, itir.Program):
fencil_definition = node
- tmps = []
-
+ tmps = node.declarations
+ assert all(isinstance(tmp, itir.Temporary) for tmp in tmps)
else:
- raise TypeError(
- f"Expected 'FencilDefinition' or 'FencilWithTemporaries', got '{type(program).__name__}'."
- )
+ raise TypeError(f"Expected 'Program', got '{type(node).__name__}'.")
return fencil_definition, tmps
diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py
index 50ff92e4c1..13d008c617 100644
--- a/src/gt4py/next/program_processors/runners/roundtrip.py
+++ b/src/gt4py/next/program_processors/runners/roundtrip.py
@@ -28,7 +28,6 @@
from gt4py.eve.codegen import FormatTemplate as as_fmt, MakoTemplate as as_mako
from gt4py.next import allocators as next_allocators, backend as next_backend, common, config
from gt4py.next.iterator import embedded, ir as itir, transforms as itir_transforms
-from gt4py.next.iterator.transforms import fencil_to_program
from gt4py.next.otf import stages, workflow
from gt4py.next.program_processors import modular_executor, processor_interface as ppi
from gt4py.next.type_system import type_specifications as ts
@@ -124,8 +123,6 @@ def fencil_generator(
ir, lift_mode=lift_mode, offset_provider=offset_provider
)
- ir = fencil_to_program.FencilToProgram.apply(ir)
-
program = EmbeddedDSL.apply(ir)
# format output in debug mode for better debuggability
diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py
index 9807982797..39875a3169 100644
--- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py
+++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py
@@ -113,4 +113,4 @@ def test_temporary_symbols(testee, mesh_descriptor):
params = ["num_vertices", "num_edges", "num_cells"]
for param in params:
- assert any([param == str(p) for p in itir_with_tmp.fencil.params])
+ assert any([param == str(p) for p in itir_with_tmp.params])
diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py
index a2d0a170a0..a124ff086f 100644
--- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py
+++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py
@@ -11,17 +11,25 @@
# distribution for a copy of the license or check .
#
# SPDX-License-Identifier: GPL-3.0-or-later
+import pytest
import textwrap
from gt4py.eve.utils import UIDGenerator
+from gt4py.next import common
from gt4py.next.iterator import ir
from gt4py.next.iterator.ir_utils import ir_makers as im
+from gt4py.next.type_system import type_specifications as ts
from gt4py.next.iterator.transforms.cse import (
CommonSubexpressionElimination as CSE,
extract_subexpression,
)
+@pytest.fixture
+def offset_provider(request):
+ return {"I": common.Dimension("I", kind=common.DimensionKind.HORIZONTAL)}
+
+
def test_trivial():
common = ir.FunCall(fun=ir.SymRef(id="plus"), args=[ir.SymRef(id="x"), ir.SymRef(id="y")])
testee = ir.FunCall(fun=ir.SymRef(id="plus"), args=[common, common])
@@ -34,7 +42,7 @@ def test_trivial():
),
args=[common],
)
- actual = CSE().visit(testee)
+ actual = CSE.apply(testee, is_local_view=True)
assert actual == expected
@@ -42,7 +50,7 @@ def test_lambda_capture():
common = ir.FunCall(fun=ir.SymRef(id="plus"), args=[ir.SymRef(id="x"), ir.SymRef(id="y")])
testee = ir.FunCall(fun=ir.Lambda(params=[ir.Sym(id="x")], expr=common), args=[common])
expected = testee
- actual = CSE().visit(testee)
+ actual = CSE.apply(testee, is_local_view=True)
assert actual == expected
@@ -50,7 +58,7 @@ def test_lambda_no_capture():
common = im.plus("x", "y")
testee = im.call(im.lambda_("z")(im.plus("x", "y")))(im.plus("x", "y"))
expected = im.let("_cs_1", common)("_cs_1")
- actual = CSE().visit(testee)
+ actual = CSE.apply(testee, is_local_view=True)
assert actual == expected
@@ -62,7 +70,7 @@ def common_expr():
testee = im.call(im.lambda_("x", "y")(common_expr()))(common_expr(), common_expr())
# (λ(_cs_1) → _cs_1 + _cs_1)(x + y)
expected = im.let("_cs_1", common_expr())(im.plus("_cs_1", "_cs_1"))
- actual = CSE().visit(testee)
+ actual = CSE.apply(testee, is_local_view=True)
assert actual == expected
@@ -82,7 +90,7 @@ def common_expr():
)
)(common_expr())
)
- actual = CSE().visit(testee)
+ actual = CSE.apply(testee, is_local_view=True)
assert actual == expected
@@ -96,7 +104,7 @@ def common_expr():
)
# (λ(_cs_1) → _cs_1(2) + _cs_1(3))(λ(a) → a + 1)
expected = im.let("_cs_1", common_expr())(im.plus(im.call("_cs_1")(2), im.call("_cs_1")(3)))
- actual = CSE().visit(testee)
+ actual = CSE.apply(testee, is_local_view=True)
assert actual == expected
@@ -112,7 +120,7 @@ def common_expr():
expected = im.let("_cs_1", common_expr())(
im.let("_cs_2", im.call("_cs_1")(2))(im.plus("_cs_2", "_cs_2"))
)
- actual = CSE().visit(testee)
+ actual = CSE.apply(testee, is_local_view=True)
assert actual == expected
@@ -136,11 +144,11 @@ def common_expr():
)
)
)
- actual = CSE().visit(testee)
+ actual = CSE.apply(testee, is_local_view=True)
assert actual == expected
-def test_if_can_deref_no_extraction():
+def test_if_can_deref_no_extraction(offset_provider):
# Test that a subexpression only occurring in one branch of an `if_` is not moved outside the
# if statement. A case using `can_deref` is used here as it is common.
@@ -160,11 +168,11 @@ def test_if_can_deref_no_extraction():
)
)
- actual = CSE().visit(testee)
+ actual = CSE.apply(testee, offset_provider=offset_provider, is_local_view=True)
assert actual == expected
-def test_if_can_deref_eligible_extraction():
+def test_if_can_deref_eligible_extraction(offset_provider):
# Test that a subexpression only occurring in both branches of an `if_` is moved outside the
# if statement. A case using `can_deref` is used here as it is common.
@@ -181,11 +189,11 @@ def test_if_can_deref_eligible_extraction():
)
)
- actual = CSE().visit(testee)
+ actual = CSE.apply(testee, offset_provider=offset_provider, is_local_view=True)
assert actual == expected
-def test_if_eligible_extraction():
+def test_if_eligible_extraction(offset_provider):
# Test that a subexpression only occurring in the condition of an `if_` is moved outside the
# if statement.
@@ -194,7 +202,7 @@ def test_if_eligible_extraction():
# (λ(_cs_1) → if _cs_1 ∧ _cs_1 then c else d)(a ∧ b)
expected = im.let("_cs_1", im.and_("a", "b"))(im.if_(im.and_("_cs_1", "_cs_1"), "c", "d"))
- actual = CSE().visit(testee)
+ actual = CSE.apply(testee, offset_provider=offset_provider, is_local_view=True)
assert actual == expected
@@ -255,3 +263,42 @@ def render_stmt_form(assignments: list[tuple[str, ir.Expr]], return_expr: ir.Exp
actual = render_stmt_form(*convert_to_assignment_stmt_form(testee))
assert actual == expected
+
+
+def test_no_extraction_outside_asfieldop():
+ plus_fieldop = im.as_fieldop(
+ im.lambda_("x", "y")(im.plus(im.deref("x"), im.deref("y"))), im.call("cartesian_domain")()
+ )
+ identity_fieldop = im.as_fieldop(im.lambda_("x")(im.deref("x")), im.call("cartesian_domain")())
+
+ field_type = ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32))
+ # as_fieldop(λ(x, y) → ·x + ·y, cartesian_domain())(
+ # as_fieldop(λ(x) → ·x, cartesian_domain())(a), as_fieldop(λ(x) → ·x, cartesian_domain())(b)
+ # )
+ testee = plus_fieldop(
+ identity_fieldop(im.ref("a", field_type)), identity_fieldop(im.ref("b", field_type))
+ )
+
+ actual = CSE.apply(testee, is_local_view=False)
+ assert actual == testee
+
+
+def test_field_extraction_outside_asfieldop():
+ plus_fieldop = im.as_fieldop(
+ im.lambda_("x", "y")(im.plus(im.deref("x"), im.deref("y"))), im.call("cartesian_domain")()
+ )
+ identity_fieldop = im.as_fieldop(im.lambda_("x")(im.deref("x")), im.call("cartesian_domain")())
+
+ # as_fieldop(λ(x, y) → ·x + ·y, cartesian_domain())(
+ # as_fieldop(λ(x) → ·x, cartesian_domain())(a), as_fieldop(λ(x) → ·x, cartesian_domain())(a)
+ # )
+ field = im.ref("a", ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32)))
+ testee = plus_fieldop(identity_fieldop(field), identity_fieldop(field))
+
+ # (λ(_cs_1) → as_fieldop(λ(x, y) → ·x + ·y, cartesian_domain())(_cs_1, _cs_1))(
+ # as_fieldop(λ(x) → ·x, cartesian_domain())(a)
+ # )
+ expected = im.let("_cs_1", identity_fieldop(field))(plus_fieldop("_cs_1", "_cs_1"))
+
+ actual = CSE.apply(testee, is_local_view=False)
+ assert actual == expected