
feat[next]: Refactor CSE pass to support ITIR.Program (GTIR branch) #1579

Merged: 15 commits, Jul 24, 2024
25 changes: 24 additions & 1 deletion src/gt4py/next/iterator/ir_utils/ir_makers.py
@@ -13,7 +13,7 @@
# SPDX-License-Identifier: GPL-3.0-or-later

import typing
from typing import Callable, Iterable, Union
from typing import Callable, Iterable, Optional, Union

from gt4py._core import definitions as core_defs
from gt4py.next.iterator import ir as itir
@@ -401,3 +401,26 @@ def _impl(*its: itir.Expr) -> itir.FunCall:
def map_(op):
"""Create a `map_` call."""
return call(call("map_")(op))


def as_fieldop(expr: itir.Expr, domain: Optional[itir.FunCall] = None) -> call:
"""
Create an `as_fieldop` call.

Examples
--------
>>> str(as_fieldop(lambda_("it1", "it2")(plus(deref("it1"), deref("it2"))))("field1", "field2"))
'(⇑(λ(it1, it2) → ·it1 + ·it2))(field1, field2)'
"""
return call(
call("as_fieldop")(
*(
(
expr,
domain,
)
if domain
else (expr,)
)
)
)
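The new optional `domain` parameter is simply forwarded as a second argument to the inner `as_fieldop` call when given. A minimal sketch of how a caller might use it; the empty `cartesian_domain()` call, the `im` alias, and the field name `inp` are illustrative placeholders, not taken from this diff:

from gt4py.next.iterator.ir_utils import ir_makers as im

# Placeholder domain expression; real code would pass named ranges as arguments.
domain = im.call("cartesian_domain")()

# Builds an `as_fieldop(<stencil>, <domain>)` call applied to the field `inp`.
fieldop_call = im.as_fieldop(im.lambda_("it")(im.deref("it")), domain)("inp")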
153 changes: 107 additions & 46 deletions src/gt4py/next/iterator/transforms/cse.py
@@ -11,11 +11,14 @@
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later

from __future__ import annotations

import dataclasses
import functools
import math
import operator
import typing
from typing import Callable, Iterable, TypeVar, Union, cast

from gt4py.eve import (
NodeTranslator,
@@ -25,32 +28,36 @@
VisitorWithSymbolTableTrait,
)
from gt4py.eve.utils import UIDGenerator
from gt4py.next.iterator import ir
from gt4py.next import common
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm
from gt4py.next.iterator.transforms.inline_lambdas import inline_lambda
from gt4py.next.iterator.type_system import inference as itir_type_inference
from gt4py.next.type_system import type_info, type_specifications as ts


@dataclasses.dataclass
class _NodeReplacer(PreserveLocationVisitor, NodeTranslator):
PRESERVED_ANNEX_ATTRS = ("type",)

expr_map: dict[int, ir.SymRef]
expr_map: dict[int, itir.SymRef]

def visit_Expr(self, node: ir.Node) -> ir.Node:
def visit_Expr(self, node: itir.Node) -> itir.Node:
if id(node) in self.expr_map:
return self.expr_map[id(node)]
return self.generic_visit(node)

def visit_FunCall(self, node: ir.FunCall) -> ir.Node:
node = typing.cast(ir.FunCall, self.visit_Expr(node))
def visit_FunCall(self, node: itir.FunCall) -> itir.Node:
node = cast(itir.FunCall, self.visit_Expr(node))
# If we encounter an expression like:
# (λ(_cs_1) → (λ(a) → a+a)(_cs_1))(outer_expr)
# (non-recursively) inline the lambda to obtain:
# (λ(_cs_1) → _cs_1+_cs_1)(outer_expr)
# This allows identifying more common subexpressions later on
if isinstance(node, ir.FunCall) and isinstance(node.fun, ir.Lambda):
if isinstance(node, itir.FunCall) and isinstance(node.fun, itir.Lambda):
eligible_params = []
for arg in node.args:
eligible_params.append(isinstance(arg, ir.SymRef) and arg.id.startswith("_cs"))
eligible_params.append(isinstance(arg, itir.SymRef) and arg.id.startswith("_cs"))
if any(eligible_params):
# note: the inline is opcount preserving anyway so avoid the additional
# effort in the inliner by disabling opcount preservation.
@@ -60,18 +67,18 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.Node:
return node


def _is_collectable_expr(node: ir.Node) -> bool:
if isinstance(node, ir.FunCall):
def _is_collectable_expr(node: itir.Node) -> bool:
if isinstance(node, itir.FunCall):
# do not collect (and thus deduplicate in CSE) shift(offsets…) calls. Node must still be
# visited, to ensure symbol dependencies are recognized correctly.
# also do not collect reduce nodes if they are left in the IR at this point, as this may lead to
# conceptual problems (other parts of the tool chain rely on the arguments being present directly
# on the reduce FunCall node (connectivity deduction)), as well as problems with the imperative
# backend (single pass eager depth first visit approach)
if isinstance(node.fun, ir.SymRef) and node.fun.id in ["lift", "shift", "reduce"]:
if isinstance(node.fun, itir.SymRef) and node.fun.id in ["lift", "shift", "reduce"]:
return False
return True
elif isinstance(node, ir.Lambda):
elif isinstance(node, itir.Lambda):
return True

return False
@@ -92,7 +99,7 @@ class SubexpressionData:
class State:
#: A dictionary mapping a node to a list of node ids with equal hash and some additional
#: information. See `SubexpressionData` for more information.
subexprs: dict[ir.Node, "CollectSubexpressions.SubexpressionData"] = dataclasses.field(
subexprs: dict[itir.Node, CollectSubexpressions.SubexpressionData] = dataclasses.field(
default_factory=dict
)
# TODO(tehrengruber): Revisit if this makes sense or if we can just recompute the collected
@@ -102,7 +109,7 @@ class State:
#: The ids of all nodes declaring a symbol which are referenced (using a `SymRef`)
used_symbol_ids: set[int] = dataclasses.field(default_factory=set)

def remove_subexprs(self, nodes: typing.Iterable[ir.Node]) -> None:
def remove_subexprs(self, nodes: Iterable[itir.Node]) -> None:
node_ids_to_remove: set[int] = set()
for node in nodes:
subexpr_data = self.subexprs.pop(node, None)
@@ -114,22 +121,22 @@ def remove_subexprs(self, nodes: typing.Iterable[ir.Node]) -> None:
collected_child_node_ids -= node_ids_to_remove

@classmethod
def apply(cls, node: ir.Node) -> dict[ir.Node, list[tuple[int, set[int]]]]:
def apply(cls, node: itir.Node) -> dict[itir.Node, list[tuple[int, set[int]]]]:
state = cls.State()
obj = cls()
obj.visit(node, state=state, depth=-1)
# Return subexpressions such that the nodes closer to the root come first and skip the root
# node itself.
subexprs_sorted: list[tuple[ir.Node, "CollectSubexpressions.SubexpressionData"]] = sorted(
subexprs_sorted: list[tuple[itir.Node, CollectSubexpressions.SubexpressionData]] = sorted(
state.subexprs.items(), key=lambda el: el[1].max_depth
)
return {k: v.subexprs for k, v in subexprs_sorted if k is not node}

def generic_visit(self, *args, **kwargs):
def generic_visit(self, node, **kwargs):
depth = kwargs.pop("depth")
return super().generic_visit(*args, depth=depth + 1, **kwargs)
return super().generic_visit(node, depth=depth + 1, **kwargs)

def visit(self, node: ir.Node, **kwargs) -> None: # type: ignore[override] # supertype accepts any node, but we want to be more specific here.
def visit(self, node: itir.Node, **kwargs) -> None: # type: ignore[override] # supertype accepts any node, but we want to be more specific here.
if not isinstance(node, SymbolTableTrait) and not _is_collectable_expr(node):
return super().visit(node, **kwargs)

@@ -141,7 +148,7 @@ def visit(self, node: ir.Node, **kwargs) -> None: # type: ignore[override] # su
# Special handling of `if_(condition, true_branch, false_branch)` like expressions that
# avoids extracting subexpressions unless they are used in either the condition or both
# branches.
if isinstance(node, ir.FunCall) and node.fun == ir.SymRef(id="if_"):
if isinstance(node, itir.FunCall) and node.fun == itir.SymRef(id="if_"):
assert len(node.args) == 3
# collect subexpressions for all arguments to the `if_`
arg_states = [self.State() for _ in node.args]
@@ -157,7 +164,7 @@ def visit(self, node: ir.Node, **kwargs) -> None: # type: ignore[override] # su
arg_state.remove_subexprs(arg_state.subexprs.keys() - eligible_subexprs)

# merge the states of the three arguments
subexprs: dict[ir.Node, CollectSubexpressions.SubexpressionData] = {}
subexprs: dict[itir.Node, CollectSubexpressions.SubexpressionData] = {}
for state in arg_states:
for subexpr, data in state.subexprs.items():
merged_data = subexprs.setdefault(subexpr, self.SubexpressionData())
@@ -203,19 +210,19 @@ def visit(self, node: ir.Node, **kwargs) -> None: # type: ignore[override] # su
parent_state.collected_child_node_ids.update(collected_child_node_ids)

def visit_SymRef(
self, node: ir.SymRef, *, symtable: dict[str, ir.Node], state: State, **kwargs
self, node: itir.SymRef, *, symtable: dict[str, itir.Node], state: State, **kwargs
) -> None:
if node.id in symtable: # root symbol otherwise
state.used_symbol_ids.add(id(symtable[node.id]))


def extract_subexpression(
node: ir.Expr,
predicate: typing.Callable[[ir.Expr, int], bool],
node: itir.Expr,
predicate: Callable[[itir.Expr, int], bool],
uid_generator: UIDGenerator,
once_only: bool = False,
deepest_expr_first: bool = False,
) -> tuple[ir.Expr, typing.Union[dict[ir.Sym, ir.Expr], None], bool]:
) -> tuple[itir.Expr, Union[dict[itir.Sym, itir.Expr], None], bool]:
"""
Given an expression extract all subexprs and return a new expr with the subexprs replaced.

@@ -312,20 +319,20 @@ def extract_subexpression(
)

ignored_children = False
extracted = dict[ir.Sym, ir.Expr]()
extracted = dict[itir.Sym, itir.Expr]()

# collect expressions
subexprs = CollectSubexpressions.apply(node)

# collect multiple occurrences and map them to fresh symbols
expr_map = dict[int, ir.SymRef]()
expr_map = dict[int, itir.SymRef]()
ignored_ids = set()
for expr, subexpr_entry in (
subexprs.items() if not deepest_expr_first else reversed(subexprs.items())
):
# just to make mypy happy when calling the predicate. Every subnode and hence subexpression
# is an expr anyway.
assert isinstance(expr, ir.Expr)
assert isinstance(expr, itir.Expr)

if not predicate(expr, len(subexpr_entry)):
continue
@@ -345,8 +352,8 @@ def extract_subexpression(
continue

expr_id = uid_generator.sequential_id()
extracted[ir.Sym(id=expr_id)] = expr
expr_ref = ir.SymRef(id=expr_id)
extracted[itir.Sym(id=expr_id)] = expr
expr_ref = itir.SymRef(id=expr_id)
for id_ in eligible_ids:
expr_map[id_] = expr_ref

@@ -359,17 +366,26 @@ def extract_subexpression(
return _NodeReplacer(expr_map).visit(node), extracted, ignored_children
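For orientation, `extract_subexpression` can be exercised on its own as well. A hedged sketch: the `_cs` prefix and the "occurs more than once" predicate mirror how the CSE pass below uses the helper, but the exact `UIDGenerator` configuration is an assumption, not part of this diff:

from gt4py.eve.utils import UIDGenerator
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.transforms.cse import extract_subexpression

x = itir.SymRef(id="x")
plus = lambda a, b: itir.FunCall(fun=itir.SymRef(id="plus"), args=[a, b])
expr = plus(plus(x, x), plus(x, x))

# Extract every collectable subexpression that occurs more than once.
new_expr, extracted, ignored_children = extract_subexpression(
    expr,
    lambda subexpr, num_occurrences: num_occurrences > 1,
    UIDGenerator(prefix="_cs"),  # assumed prefix, matching the generated `_cs_*` names
)
# `new_expr` now references `_cs_1` where `x + x` appeared twice, and `extracted`
# maps the fresh `_cs_1` symbol back to the original `x + x` subexpression.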


ProgramOrExpr = TypeVar("ProgramOrExpr", bound=itir.Program | itir.FencilDefinition | itir.Expr)


@dataclasses.dataclass(frozen=True)
class CommonSubexpressionElimination(PreserveLocationVisitor, NodeTranslator):
"""
Perform common subexpression elimination.

Review comment (Contributor): Could you briefly outline here the algorithm and the used tools (e.g. other visitors, functions, ...)?

Examples:
>>> x = ir.SymRef(id="x")
>>> plus = lambda a, b: ir.FunCall(fun=ir.SymRef(id=("plus")), args=[a, b])
>>> x = itir.SymRef(id="x")
>>> plus = lambda a, b: itir.FunCall(fun=itir.SymRef(id=("plus")), args=[a, b])
>>> expr = plus(plus(x, x), plus(x, x))
>>> print(CommonSubexpressionElimination().visit(expr))
>>> print(CommonSubexpressionElimination.apply(expr, is_local_view=True))
(λ(_cs_1) → _cs_1 + _cs_1)(x + x)

The pass visits the tree top-down starting from the root node, e.g. an itir.Program.
For each node we extract (eligible) subexpressions occurring more than once using
:ref:`extract_subexpression`. In field-view context we only extract expressions when they are
fields (or composites thereof); in local view everything is eligible. Since the visit is
top-down, extracted expressions always end up in the outermost scope they can appear in.
"""

# we use one UID generator per instance such that the generated ids are
@@ -380,23 +396,68 @@ class CommonSubexpressionElimination(PreserveLocationVisitor, NodeTranslator):

collect_all: bool = dataclasses.field(default=False)

def visit_FunCall(self, node: ir.FunCall):
if isinstance(node.fun, ir.SymRef) and node.fun.id in [
"cartesian_domain",
"unstructured_domain",
]:
return node
@classmethod
def apply(
cls,
node: ProgramOrExpr,
is_local_view: bool | None = None,
offset_provider: common.OffsetProvider | None = None,
) -> ProgramOrExpr:
is_program = isinstance(node, (itir.Program, itir.FencilDefinition))
if is_program:
assert is_local_view is None
is_local_view = False
else:
assert (
is_local_view is not None
), "The expression's context must be specified using `is_local_view`."

new_expr, extracted, ignored_children = extract_subexpression(
node, lambda subexpr, num_occurences: num_occurences > 1, self.uids
offset_provider = offset_provider or {}
node = itir_type_inference.infer(
node, offset_provider=offset_provider, allow_undeclared_symbols=not is_program
)
return cls().visit(node, is_local_view=is_local_view)

def generic_visit(self, node, **kwargs):
if cpm.is_call_to("as_fieldop", node):
assert not kwargs.get("is_local_view")
is_local_view = cpm.is_call_to("as_fieldop", node) or kwargs.get("is_local_view")

return super().generic_visit(node, **(kwargs | {"is_local_view": is_local_view}))

def visit_FunCall(self, node: itir.FunCall, **kwargs):
is_local_view = kwargs["is_local_view"]

if cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain")):
return node

def predicate(subexpr: itir.Expr, num_occurences: int):
# note: be careful here with the syntactic context: the expression might be in local
# view, even though the syntactic context `node` is in field view.
# note: what is extracted is sketched in the docstring above. keep it updated.
if num_occurences > 1:
if is_local_view:
return True
else:
# only extract fields outside of `as_fieldop`
# `as_fieldop(...)(field_expr, field_expr)`
# -> `(λ(_cs_1) → as_fieldop(...)(_cs_1, _cs_1))(field_expr)`
assert isinstance(subexpr.type, ts.TypeSpec)
if all(
isinstance(stype, ts.FieldType)
for stype in type_info.primitive_constituents(subexpr.type)
):
return True
return False

new_expr, extracted, ignored_children = extract_subexpression(node, predicate, self.uids)

if not extracted:
return self.generic_visit(node)
return self.generic_visit(node, **kwargs)

# apply remapping
result = ir.FunCall(
fun=ir.Lambda(params=list(extracted.keys()), expr=new_expr),
result = itir.FunCall(
fun=itir.Lambda(params=list(extracted.keys()), expr=new_expr),
args=list(extracted.values()),
)

@@ -406,6 +467,6 @@ def visit_FunCall(self, node: ir.FunCall):
# inside of subexpressions directly. This would require a different order of replacement
# (from lower to higher level).
if ignored_children:
return self.visit(result)
return self.visit(result, **kwargs)

return self.generic_visit(result)
return self.generic_visit(result, **kwargs)
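Taken together, the refactor replaces direct `CommonSubexpressionElimination().visit(...)` calls with the `apply` classmethod, which first runs ITIR type inference and then threads `is_local_view` through the visit. A small usage sketch mirroring the docstring example; the `CSE` import alias and the commented-out program invocation are assumptions, not part of this diff:

from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.transforms.cse import CommonSubexpressionElimination as CSE

x = itir.SymRef(id="x")
plus = lambda a, b: itir.FunCall(fun=itir.SymRef(id="plus"), args=[a, b])
expr = plus(plus(x, x), plus(x, x))

# Bare expressions must state their context explicitly.
print(CSE.apply(expr, is_local_view=True))
# (λ(_cs_1) → _cs_1 + _cs_1)(x + x)

# For an itir.Program, `is_local_view` must be left unset: the pass starts in
# field view and switches to local view when descending into `as_fieldop` calls.
# program = CSE.apply(program, offset_provider=offset_provider)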