Skip to content

Commit

Permalink
feat[next]: Refactor CSE pass to support ITIR.Program (GTIR branch) (#…
Browse files Browse the repository at this point in the history
…1579)

Extends the common subexpression elimination to support the new
`itir.Program` node and moves the intermediate `Fencil` -> `Program`
conversion further up in the pass manager. The CSE pass now uses type
inference such that only field expressions or composites thereof are
collected in field-view context (i.e. outside of `as_fieldop`).
  • Loading branch information
tehrengruber authored Jul 24, 2024
1 parent 0416829 commit e48d8c3
Show file tree
Hide file tree
Showing 11 changed files with 246 additions and 109 deletions.
25 changes: 24 additions & 1 deletion src/gt4py/next/iterator/ir_utils/ir_makers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# SPDX-License-Identifier: GPL-3.0-or-later

import typing
from typing import Callable, Iterable, Union
from typing import Callable, Iterable, Optional, Union

from gt4py._core import definitions as core_defs
from gt4py.next.iterator import ir as itir
Expand Down Expand Up @@ -401,3 +401,26 @@ def _impl(*its: itir.Expr) -> itir.FunCall:
def map_(op):
"""Create a `map_` call."""
return call(call("map_")(op))


def as_fieldop(expr: itir.Expr, domain: Optional[itir.FunCall] = None) -> call:
    """
    Create an `as_fieldop` call.

    The builtin is called with `expr` and, only when one is given, `domain`;
    the result is wrapped in `call` so that it can be applied to the field
    arguments afterwards (see the example).

    Examples
    --------
    >>> str(as_fieldop(lambda_("it1", "it2")(plus(deref("it1"), deref("it2"))))("field1", "field2"))
    '(⇑(λ(it1, it2) → ·it1 + ·it2))(field1, field2)'
    """
    # Compare against `None` explicitly: whether a domain was passed must not depend on the
    # truthiness of the node object.
    args = (expr,) if domain is None else (expr, domain)
    return call(call("as_fieldop")(*args))
153 changes: 107 additions & 46 deletions src/gt4py/next/iterator/transforms/cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,14 @@
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later

from __future__ import annotations

import dataclasses
import functools
import math
import operator
import typing
from typing import Callable, Iterable, TypeVar, Union, cast

from gt4py.eve import (
NodeTranslator,
Expand All @@ -25,32 +28,36 @@
VisitorWithSymbolTableTrait,
)
from gt4py.eve.utils import UIDGenerator
from gt4py.next.iterator import ir
from gt4py.next import common
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm
from gt4py.next.iterator.transforms.inline_lambdas import inline_lambda
from gt4py.next.iterator.type_system import inference as itir_type_inference
from gt4py.next.type_system import type_info, type_specifications as ts


@dataclasses.dataclass
class _NodeReplacer(PreserveLocationVisitor, NodeTranslator):
PRESERVED_ANNEX_ATTRS = ("type",)

expr_map: dict[int, ir.SymRef]
expr_map: dict[int, itir.SymRef]

def visit_Expr(self, node: ir.Node) -> ir.Node:
def visit_Expr(self, node: itir.Node) -> itir.Node:
if id(node) in self.expr_map:
return self.expr_map[id(node)]
return self.generic_visit(node)

def visit_FunCall(self, node: ir.FunCall) -> ir.Node:
node = typing.cast(ir.FunCall, self.visit_Expr(node))
def visit_FunCall(self, node: itir.FunCall) -> itir.Node:
node = cast(itir.FunCall, self.visit_Expr(node))
# If we encounter an expression like:
# (λ(_cs_1) → (λ(a) → a+a)(_cs_1))(outer_expr)
# (non-recursively) inline the lambda to obtain:
# (λ(_cs_1) → _cs_1+_cs_1)(outer_expr)
# This allows identifying more common subexpressions later on
if isinstance(node, ir.FunCall) and isinstance(node.fun, ir.Lambda):
if isinstance(node, itir.FunCall) and isinstance(node.fun, itir.Lambda):
eligible_params = []
for arg in node.args:
eligible_params.append(isinstance(arg, ir.SymRef) and arg.id.startswith("_cs"))
eligible_params.append(isinstance(arg, itir.SymRef) and arg.id.startswith("_cs"))
if any(eligible_params):
# note: the inline is opcount preserving anyway so avoid the additional
# effort in the inliner by disabling opcount preservation.
Expand All @@ -60,18 +67,18 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.Node:
return node


def _is_collectable_expr(node: ir.Node) -> bool:
if isinstance(node, ir.FunCall):
def _is_collectable_expr(node: itir.Node) -> bool:
if isinstance(node, itir.FunCall):
# do not collect (and thus deduplicate in CSE) shift(offsets…) calls. Node must still be
# visited, to ensure symbol dependencies are recognized correctly.
# do also not collect reduce nodes if they are left in the it at this point, this may lead to
# conceptual problems (other parts of the tool chain rely on the arguments being present directly
# on the reduce FunCall node (connectivity deduction)), as well as problems with the imperative
# backend (single pass eager depth first visit approach)
if isinstance(node.fun, ir.SymRef) and node.fun.id in ["lift", "shift", "reduce"]:
if isinstance(node.fun, itir.SymRef) and node.fun.id in ["lift", "shift", "reduce"]:
return False
return True
elif isinstance(node, ir.Lambda):
elif isinstance(node, itir.Lambda):
return True

return False
Expand All @@ -92,7 +99,7 @@ class SubexpressionData:
class State:
#: A dictionary mapping a node to a list of node ids with equal hash and some additional
#: information. See `SubexpressionData` for more information.
subexprs: dict[ir.Node, "CollectSubexpressions.SubexpressionData"] = dataclasses.field(
subexprs: dict[itir.Node, CollectSubexpressions.SubexpressionData] = dataclasses.field(
default_factory=dict
)
# TODO(tehrengruber): Revisit if this makes sense or if we can just recompute the collected
Expand All @@ -102,7 +109,7 @@ class State:
#: The ids of all nodes declaring a symbol which are referenced (using a `SymRef`)
used_symbol_ids: set[int] = dataclasses.field(default_factory=set)

def remove_subexprs(self, nodes: typing.Iterable[ir.Node]) -> None:
def remove_subexprs(self, nodes: Iterable[itir.Node]) -> None:
node_ids_to_remove: set[int] = set()
for node in nodes:
subexpr_data = self.subexprs.pop(node, None)
Expand All @@ -114,22 +121,22 @@ def remove_subexprs(self, nodes: typing.Iterable[ir.Node]) -> None:
collected_child_node_ids -= node_ids_to_remove

@classmethod
def apply(cls, node: ir.Node) -> dict[ir.Node, list[tuple[int, set[int]]]]:
def apply(cls, node: itir.Node) -> dict[itir.Node, list[tuple[int, set[int]]]]:
state = cls.State()
obj = cls()
obj.visit(node, state=state, depth=-1)
# Return subexpression such that the nodes closer to the root come first and skip the root
# node itself.
subexprs_sorted: list[tuple[ir.Node, "CollectSubexpressions.SubexpressionData"]] = sorted(
subexprs_sorted: list[tuple[itir.Node, CollectSubexpressions.SubexpressionData]] = sorted(
state.subexprs.items(), key=lambda el: el[1].max_depth
)
return {k: v.subexprs for k, v in subexprs_sorted if k is not node}

def generic_visit(self, *args, **kwargs):
def generic_visit(self, node, **kwargs):
depth = kwargs.pop("depth")
return super().generic_visit(*args, depth=depth + 1, **kwargs)
return super().generic_visit(node, depth=depth + 1, **kwargs)

def visit(self, node: ir.Node, **kwargs) -> None: # type: ignore[override] # supertype accepts any node, but we want to be more specific here.
def visit(self, node: itir.Node, **kwargs) -> None: # type: ignore[override] # supertype accepts any node, but we want to be more specific here.
if not isinstance(node, SymbolTableTrait) and not _is_collectable_expr(node):
return super().visit(node, **kwargs)

Expand All @@ -141,7 +148,7 @@ def visit(self, node: ir.Node, **kwargs) -> None: # type: ignore[override] # su
# Special handling of `if_(condition, true_branch, false_branch)` like expressions that
# avoids extracting subexpressions unless they are used in either the condition or both
# branches.
if isinstance(node, ir.FunCall) and node.fun == ir.SymRef(id="if_"):
if isinstance(node, itir.FunCall) and node.fun == itir.SymRef(id="if_"):
assert len(node.args) == 3
# collect subexpressions for all arguments to the `if_`
arg_states = [self.State() for _ in node.args]
Expand All @@ -157,7 +164,7 @@ def visit(self, node: ir.Node, **kwargs) -> None: # type: ignore[override] # su
arg_state.remove_subexprs(arg_state.subexprs.keys() - eligible_subexprs)

# merge the states of the three arguments
subexprs: dict[ir.Node, CollectSubexpressions.SubexpressionData] = {}
subexprs: dict[itir.Node, CollectSubexpressions.SubexpressionData] = {}
for state in arg_states:
for subexpr, data in state.subexprs.items():
merged_data = subexprs.setdefault(subexpr, self.SubexpressionData())
Expand Down Expand Up @@ -203,19 +210,19 @@ def visit(self, node: ir.Node, **kwargs) -> None: # type: ignore[override] # su
parent_state.collected_child_node_ids.update(collected_child_node_ids)

def visit_SymRef(
self, node: ir.SymRef, *, symtable: dict[str, ir.Node], state: State, **kwargs
self, node: itir.SymRef, *, symtable: dict[str, itir.Node], state: State, **kwargs
) -> None:
if node.id in symtable: # root symbol otherwise
state.used_symbol_ids.add(id(symtable[node.id]))


def extract_subexpression(
node: ir.Expr,
predicate: typing.Callable[[ir.Expr, int], bool],
node: itir.Expr,
predicate: Callable[[itir.Expr, int], bool],
uid_generator: UIDGenerator,
once_only: bool = False,
deepest_expr_first: bool = False,
) -> tuple[ir.Expr, typing.Union[dict[ir.Sym, ir.Expr], None], bool]:
) -> tuple[itir.Expr, Union[dict[itir.Sym, itir.Expr], None], bool]:
"""
Given an expression extract all subexprs and return a new expr with the subexprs replaced.
Expand Down Expand Up @@ -312,20 +319,20 @@ def extract_subexpression(
)

ignored_children = False
extracted = dict[ir.Sym, ir.Expr]()
extracted = dict[itir.Sym, itir.Expr]()

# collect expressions
subexprs = CollectSubexpressions.apply(node)

# collect multiple occurrences and map them to fresh symbols
expr_map = dict[int, ir.SymRef]()
expr_map = dict[int, itir.SymRef]()
ignored_ids = set()
for expr, subexpr_entry in (
subexprs.items() if not deepest_expr_first else reversed(subexprs.items())
):
# just to make mypy happy when calling the predicate. Every subnode and hence subexpression
# is an expr anyway.
assert isinstance(expr, ir.Expr)
assert isinstance(expr, itir.Expr)

if not predicate(expr, len(subexpr_entry)):
continue
Expand All @@ -345,8 +352,8 @@ def extract_subexpression(
continue

expr_id = uid_generator.sequential_id()
extracted[ir.Sym(id=expr_id)] = expr
expr_ref = ir.SymRef(id=expr_id)
extracted[itir.Sym(id=expr_id)] = expr
expr_ref = itir.SymRef(id=expr_id)
for id_ in eligible_ids:
expr_map[id_] = expr_ref

Expand All @@ -359,17 +366,26 @@ def extract_subexpression(
return _NodeReplacer(expr_map).visit(node), extracted, ignored_children


ProgramOrExpr = TypeVar("ProgramOrExpr", bound=itir.Program | itir.FencilDefinition | itir.Expr)


@dataclasses.dataclass(frozen=True)
class CommonSubexpressionElimination(PreserveLocationVisitor, NodeTranslator):
"""
Perform common subexpression elimination.
Examples:
>>> x = ir.SymRef(id="x")
>>> plus = lambda a, b: ir.FunCall(fun=ir.SymRef(id=("plus")), args=[a, b])
>>> x = itir.SymRef(id="x")
>>> plus = lambda a, b: itir.FunCall(fun=itir.SymRef(id=("plus")), args=[a, b])
>>> expr = plus(plus(x, x), plus(x, x))
>>> print(CommonSubexpressionElimination().visit(expr))
>>> print(CommonSubexpressionElimination.apply(expr, is_local_view=True))
(λ(_cs_1) → _cs_1 + _cs_1)(x + x)
The pass visits the tree top-down starting from the root node, e.g. an itir.Program.
For each node we extract (eligible) subexpressions occurring more than once using
:ref:`extract_subexpression`. In field-view context we only extract expressions when they are
fields (or composites thereof), in local view everything is eligible. Since the visit is
top-down, extracted expressions always end up in the outermost scope they can appear in.
"""

# we use one UID generator per instance such that the generated ids are
Expand All @@ -380,23 +396,68 @@ class CommonSubexpressionElimination(PreserveLocationVisitor, NodeTranslator):

collect_all: bool = dataclasses.field(default=False)

def visit_FunCall(self, node: ir.FunCall):
if isinstance(node.fun, ir.SymRef) and node.fun.id in [
"cartesian_domain",
"unstructured_domain",
]:
return node
@classmethod
def apply(
cls,
node: ProgramOrExpr,
is_local_view: bool | None = None,
offset_provider: common.OffsetProvider | None = None,
) -> ProgramOrExpr:
is_program = isinstance(node, (itir.Program, itir.FencilDefinition))
if is_program:
assert is_local_view is None
is_local_view = False
else:
assert (
is_local_view is not None
), "The expression's context must be specified using `is_local_view`."

new_expr, extracted, ignored_children = extract_subexpression(
node, lambda subexpr, num_occurences: num_occurences > 1, self.uids
offset_provider = offset_provider or {}
node = itir_type_inference.infer(
node, offset_provider=offset_provider, allow_undeclared_symbols=not is_program
)
return cls().visit(node, is_local_view=is_local_view)

def generic_visit(self, node, **kwargs):
    """
    Visit the children of `node`, propagating the `is_local_view` context.

    When `node` is a call to `as_fieldop` its children are visited with
    `is_local_view` set; otherwise the value received from the caller (if any)
    is forwarded unchanged.
    """
    # note: `is_call_to` takes the node first and the builtin name(s) second (same order as
    #  used in `visit_FunCall`). With the arguments swapped the check never matches and the
    #  local-view context is never entered.
    enters_local_view = cpm.is_call_to(node, "as_fieldop")
    if enters_local_view:
        # an `as_fieldop` is not expected while already in local view
        assert not kwargs.get("is_local_view")
    is_local_view = enters_local_view or kwargs.get("is_local_view")

    return super().generic_visit(node, **(kwargs | {"is_local_view": is_local_view}))

def visit_FunCall(self, node: itir.FunCall, **kwargs):
is_local_view = kwargs["is_local_view"]

if cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain")):
return node

def predicate(subexpr: itir.Expr, num_occurences: int):
# note: be careful here with the syntactic context: the expression might be in local
# view, even though the syntactic context `node` is in field view.
# note: what is extracted is sketched in the docstring above. keep it updated.
if num_occurences > 1:
if is_local_view:
return True
else:
# only extract fields outside of `as_fieldop`
# `as_fieldop(...)(field_expr, field_expr)`
# -> `(λ(_cs_1) → as_fieldop(...)(_cs_1, _cs_1))(field_expr)`
assert isinstance(subexpr.type, ts.TypeSpec)
if all(
isinstance(stype, ts.FieldType)
for stype in type_info.primitive_constituents(subexpr.type)
):
return True
return False

new_expr, extracted, ignored_children = extract_subexpression(node, predicate, self.uids)

if not extracted:
return self.generic_visit(node)
return self.generic_visit(node, **kwargs)

# apply remapping
result = ir.FunCall(
fun=ir.Lambda(params=list(extracted.keys()), expr=new_expr),
result = itir.FunCall(
fun=itir.Lambda(params=list(extracted.keys()), expr=new_expr),
args=list(extracted.values()),
)

Expand All @@ -406,6 +467,6 @@ def visit_FunCall(self, node: ir.FunCall):
# inside of subexpressions directly. This would require a different order of replacement
# (from lower to higher level).
if ignored_children:
return self.visit(result)
return self.visit(result, **kwargs)

return self.generic_visit(result)
return self.generic_visit(result, **kwargs)
Loading

0 comments on commit e48d8c3

Please sign in to comment.