Skip to content

Commit

Permalink
feat[next]: GTIR embedded and GTFN temporaries with new lowering (#1648)
Browse files Browse the repository at this point in the history
Use new lowering for GTIR embedded, and GTFN. Only the dace iterator
backend continues to use the old lowering.

Changes:
- Use GTIR lowering for all backends except for dace
- Old lowering and transformations only used in dace backend
- workflows defined in
[`gt4py.next.backend.LEGACY_TRANSFORMS`](https://github.com/GridTools/gt4py/pull/1648/files#diff-cf4385d02cbeacc310d4326350903b4cb6f9a61c7cd36dda162a5077ab8b8e86).
Variable can be removed in a cleanup PR.
- old `apply_common_transforms` in
[pass_manager_legacy.py](https://github.com/GridTools/gt4py/pull/1648/files#diff-db17bff48ac16ee75ff974a1b9af98e3cf0c850971ce9898aa55b635bb046b72).
Just a straight copy of the old function. No need to review, this is
just to avoid deleting until gtir based dace backend is ready.
- Re-add `symbolic_sizes` param. Was in temporary extraction, is now
part of the domain inference. In preparation of icon-exclaim tests

---------

Co-authored-by: Hannes Vogt <hannes@havogt.de>
  • Loading branch information
tehrengruber and havogt authored Nov 15, 2024
1 parent c51bdd1 commit 998f279
Show file tree
Hide file tree
Showing 52 changed files with 945 additions and 586 deletions.
1 change: 0 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ repos:
- id: check-merge-conflict
- id: check-toml
- id: check-yaml
- id: debug-statements

- repo: https://github.com/astral-sh/ruff-pre-commit
##[[[cog
Expand Down
14 changes: 13 additions & 1 deletion src/gt4py/next/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from gt4py._core import definitions as core_defs
from gt4py.next import allocators as next_allocators
from gt4py.next.ffront import (
foast_to_gtir,
foast_to_itir,
foast_to_past,
func_to_foast,
Expand Down Expand Up @@ -76,7 +77,7 @@ class Transforms(workflow.MultiWorkflow[INPUT_PAIR, stages.CompilableProgram]):
)

foast_to_itir: workflow.Workflow[AOT_FOP, itir.Expr] = dataclasses.field(
default_factory=foast_to_itir.adapted_foast_to_itir_factory
default_factory=foast_to_gtir.adapted_foast_to_gtir_factory
)

field_view_op_to_prog: workflow.Workflow[AOT_FOP, AOT_PRG] = dataclasses.field(
Expand Down Expand Up @@ -134,6 +135,17 @@ def step_order(self, inp: INPUT_PAIR) -> list[str]:

DEFAULT_TRANSFORMS: Transforms = Transforms()

# FIXME[#1582](havogt): remove after refactoring to GTIR
# note: this step is deliberately placed here, such that the cache is shared
_foast_to_itir_step = foast_to_itir.adapted_foast_to_itir_factory(cached=True)
LEGACY_TRANSFORMS: Transforms = Transforms(
past_to_itir=past_to_itir.past_to_itir_factory(to_gtir=False),
foast_to_itir=_foast_to_itir_step,
field_view_op_to_prog=foast_to_past.operator_to_program_factory(
foast_to_itir_step=_foast_to_itir_step
),
)


# TODO(tehrengruber): Rename class and `executor` & `transforms` attribute. Maybe:
# `Backend` -> `Toolchain`
Expand Down
6 changes: 3 additions & 3 deletions src/gt4py/next/ffront/foast_to_past.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from gt4py.eve import utils as eve_utils
from gt4py.next.ffront import (
dialect_ast_enums,
foast_to_itir,
foast_to_gtir,
program_ast as past,
stages as ffront_stages,
type_specifications as ts_ffront,
Expand Down Expand Up @@ -68,7 +68,7 @@ class OperatorToProgram(workflow.Workflow[AOT_FOP, AOT_PRG]):
... def copy(a: gtx.Field[[IDim], gtx.float32]) -> gtx.Field[[IDim], gtx.float32]:
... return a
>>> op_to_prog = OperatorToProgram(foast_to_itir.adapted_foast_to_itir_factory())
>>> op_to_prog = OperatorToProgram(foast_to_gtir.adapted_foast_to_gtir_factory())
>>> compile_time_args = arguments.CompileTimeArgs(
... args=tuple(param.type for param in copy.foast_stage.foast_node.definition.params),
Expand Down Expand Up @@ -169,7 +169,7 @@ def operator_to_program_factory(
) -> workflow.Workflow[AOT_FOP, AOT_PRG]:
"""Optionally wrap `OperatorToProgram` in a `CachedStep`."""
wf: workflow.Workflow[AOT_FOP, AOT_PRG] = OperatorToProgram(
foast_to_itir_step or foast_to_itir.adapted_foast_to_itir_factory()
foast_to_itir_step or foast_to_gtir.adapted_foast_to_gtir_factory()
)
if cached:
wf = workflow.CachedStep(wf, hash_function=ffront_stages.fingerprint_stage)
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/ffront/past_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def past_to_itir(inp: AOT_PRG, to_gtir: bool = False) -> stages.CompilableProgra

# FIXME[#1582](havogt): remove `to_gtir` arg after refactoring to GTIR
def past_to_itir_factory(
cached: bool = True, to_gtir: bool = False
cached: bool = True, to_gtir: bool = True
) -> workflow.Workflow[AOT_PRG, stages.CompilableProgram]:
wf = workflow.make_step(functools.partial(past_to_itir, to_gtir=to_gtir))
if cached:
Expand Down
23 changes: 17 additions & 6 deletions src/gt4py/next/iterator/ir_utils/domain_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import dataclasses
import functools
from typing import Any, Literal, Mapping
from typing import Any, Literal, Mapping, Optional

import gt4py.next as gtx
from gt4py.next import common
Expand Down Expand Up @@ -93,6 +93,9 @@ def translate(
...,
],
offset_provider: common.OffsetProvider,
#: A dictionary mapping axes names to their length. See
#: func:`gt4py.next.iterator.transforms.infer_domain.infer_expr` for more details.
symbolic_domain_sizes: Optional[dict[str, str]] = None,
) -> SymbolicDomain:
dims = list(self.ranges.keys())
new_ranges = {dim: self.ranges[dim] for dim in dims}
Expand All @@ -119,18 +122,24 @@ def translate(
trace_shifts.Sentinel.ALL_NEIGHBORS,
trace_shifts.Sentinel.VALUE,
]
# note: ugly but cheap re-computation, but should disappear
horizontal_sizes = _max_domain_sizes_by_location_type(offset_provider)
horizontal_sizes: dict[str, itir.Expr]
if symbolic_domain_sizes is not None:
horizontal_sizes = {k: im.ref(v) for k, v in symbolic_domain_sizes.items()}
else:
# note: ugly but cheap re-computation, but should disappear
horizontal_sizes = {
k: im.literal(str(v), itir.INTEGER_INDEX_BUILTIN)
for k, v in _max_domain_sizes_by_location_type(offset_provider).items()
}

old_dim = nbt_provider.origin_axis
new_dim = nbt_provider.neighbor_axis

assert new_dim not in new_ranges or old_dim == new_dim

# TODO(tehrengruber): Do we need symbolic sizes, e.g., for ICON?
new_range = SymbolicRange(
im.literal("0", itir.INTEGER_INDEX_BUILTIN),
im.literal(str(horizontal_sizes[new_dim.value]), itir.INTEGER_INDEX_BUILTIN),
horizontal_sizes[new_dim.value],
)
new_ranges = dict(
(dim, range_) if dim != old_dim else (new_dim, new_range)
Expand All @@ -140,7 +149,9 @@ def translate(
raise AssertionError()
return SymbolicDomain(self.grid_type, new_ranges)
elif len(shift) > 2:
return self.translate(shift[0:2], offset_provider).translate(shift[2:], offset_provider)
return self.translate(shift[0:2], offset_provider, symbolic_domain_sizes).translate(
shift[2:], offset_provider, symbolic_domain_sizes
)
else:
raise AssertionError("Number of shifts must be a multiple of 2.")

Expand Down
5 changes: 2 additions & 3 deletions src/gt4py/next/iterator/ir_utils/ir_makers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from typing import Callable, Optional, Union

from gt4py._core import definitions as core_defs
from gt4py.eve.extended_typing import Dict, Tuple
from gt4py.next import common
from gt4py.next.iterator import ir as itir
from gt4py.next.type_system import type_specifications as ts, type_translation
Expand Down Expand Up @@ -412,7 +411,7 @@ def _impl(*its: itir.Expr) -> itir.FunCall:

def domain(
grid_type: Union[common.GridType, str],
ranges: Dict[Union[common.Dimension, str], Tuple[itir.Expr, itir.Expr]],
ranges: dict[Union[common.Dimension, str], tuple[itir.Expr, itir.Expr]],
) -> itir.FunCall:
"""
>>> str(
Expand Down Expand Up @@ -446,7 +445,7 @@ def domain(
)


def as_fieldop(expr: itir.Expr, domain: Optional[itir.Expr] = None) -> call:
def as_fieldop(expr: itir.Expr | str, domain: Optional[itir.Expr] = None) -> call:
"""
Create an `as_fieldop` call.
Expand Down
3 changes: 1 addition & 2 deletions src/gt4py/next/iterator/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@

from gt4py.next.iterator.transforms.pass_manager import (
ITIRTransform,
LiftMode,
apply_common_transforms,
apply_fieldview_transforms,
)


__all__ = ["apply_common_transforms", "apply_fieldview_transforms", "LiftMode", "ITIRTransform"]
__all__ = ["apply_common_transforms", "apply_fieldview_transforms", "ITIRTransform"]
52 changes: 25 additions & 27 deletions src/gt4py/next/iterator/transforms/collapse_list_get.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
# SPDX-License-Identifier: BSD-3-Clause

from gt4py import eve
from gt4py.next.iterator import ir
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im


class CollapseListGet(eve.PreserveLocationVisitor, eve.NodeTranslator):
Expand All @@ -18,32 +19,29 @@ class CollapseListGet(eve.PreserveLocationVisitor, eve.NodeTranslator):
- `list_get(i, make_const_list(e))` -> `e`
"""

def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node:
def visit_FunCall(self, node: itir.FunCall, **kwargs) -> itir.Node:
node = self.generic_visit(node)
if node.fun == ir.SymRef(id="list_get"):
if isinstance(node.args[1], ir.FunCall):
if node.args[1].fun == ir.SymRef(id="neighbors"):
offset_tag = node.args[1].args[0]
offset_index = (
ir.OffsetLiteral(value=int(node.args[0].value))
if isinstance(node.args[0], ir.Literal)
else node.args[
0
] # else-branch: e.g. SymRef from unroll_reduce, TODO(havogt): remove when we replace unroll_reduce by list support in gtfn
)
it = node.args[1].args[1]
return ir.FunCall(
fun=ir.SymRef(id="deref"),
args=[
ir.FunCall(
fun=ir.FunCall(
fun=ir.SymRef(id="shift"), args=[offset_tag, offset_index]
),
args=[it],
)
],
)
if node.args[1].fun == ir.SymRef(id="make_const_list"):
return node.args[1].args[0]
if cpm.is_call_to(node, "list_get"):
if cpm.is_call_to(node.args[1], "if_"):
list_idx = node.args[0]
cond, true_val, false_val = node.args[1].args
return im.if_(
cond,
self.visit(im.call("list_get")(list_idx, true_val)),
self.visit(im.call("list_get")(list_idx, false_val)),
)
if cpm.is_call_to(node.args[1], "neighbors"):
offset_tag = node.args[1].args[0]
offset_index = (
itir.OffsetLiteral(value=int(node.args[0].value))
if isinstance(node.args[0], itir.Literal)
else node.args[
0
] # else-branch: e.g. SymRef from unroll_reduce, TODO(havogt): remove when we replace unroll_reduce by list support in gtfn
)
it = node.args[1].args[1]
return im.deref(im.shift(offset_tag, offset_index)(it))
if cpm.is_call_to(node.args[1], "make_const_list"):
return node.args[1].args[0]

return node
Loading

0 comments on commit 998f279

Please sign in to comment.