Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor tree_map and replace apply_to_primitive_constituents #1570

Open
wants to merge 76 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
76 commits
Select commit Hold shift + click to select a range
8add029
refactor[next]: itir embedded: cleaner closure run
havogt Apr 4, 2024
853d3e1
cleanup
havogt Apr 4, 2024
f661cd3
fix test
havogt Apr 4, 2024
09e568d
without temporaries
havogt Apr 8, 2024
12b8696
temporaries
havogt Apr 8, 2024
540a2d8
cleanup
havogt Apr 9, 2024
23ddef1
move to SetAt
havogt Apr 10, 2024
e64b986
Merge branch 'refactor_itir_embedded' into itir_program_embedded2
havogt Apr 10, 2024
c99f44d
embedded
havogt Apr 10, 2024
1a6f885
roundtrip+double_roundtrip with shortcuts
havogt Apr 11, 2024
39d6d7c
changes
havogt Apr 11, 2024
ab44009
fencil2program only for gtfn
havogt Apr 11, 2024
12f1663
fix import
havogt Apr 11, 2024
aa80949
Merge remote-tracking branch 'upstream/main' into itir_program
havogt Apr 11, 2024
5037493
fix builtins list
havogt Apr 11, 2024
751581e
add comment
havogt Apr 11, 2024
3d2f33e
fix type checker
havogt Apr 11, 2024
53bad75
Merge branch 'itir_program' into itir_program_embedded2
havogt Apr 11, 2024
4cbce7e
Apply suggestions from code review
havogt Apr 12, 2024
c955645
format
havogt Apr 12, 2024
6effe10
pretty printing/parsing
havogt Apr 12, 2024
66de3ec
Apply suggestions from code review
havogt Apr 15, 2024
e63da77
address more review comments
havogt Apr 15, 2024
45fba85
move tmp to pretty_printer
havogt Apr 15, 2024
1a70218
pparse for temporaries
havogt Apr 15, 2024
c39c603
rename gtfn.FencilDefinition -> Program
havogt Apr 15, 2024
705cfcf
remove TODO
havogt Apr 15, 2024
c5e78c4
Apply suggestions from code review
havogt Apr 15, 2024
a336bf5
rename as_field_operator -> as_fieldop
havogt Apr 15, 2024
af16f40
missed a file
havogt Apr 15, 2024
2d6bfbf
Merge branch 'itir_program' into itir_program_embedded2
havogt Apr 15, 2024
df1146b
Merge remote-tracking branch 'upstream/main' into itir_program_embedded2
havogt Apr 15, 2024
bc2c2d3
add fencil2program to roundtrip
havogt Apr 16, 2024
c7ccd6a
Merge remote-tracking branch 'upstream/main' into itir_program_embedded2
havogt Apr 16, 2024
f45b460
pre-allocate result buffer
havogt Apr 17, 2024
5d4fc3d
fix tracer context
havogt Apr 17, 2024
1d93192
first (almost) complete embedded version
havogt Apr 17, 2024
5882c28
add dim kind to print/parse
havogt Apr 17, 2024
b7cbf16
fix tests
havogt Apr 17, 2024
e97ca25
cleanup test_program
havogt Apr 17, 2024
8c2bd8f
re-enable lift mode in roundtrip
havogt Apr 18, 2024
21b230b
replace lift_mode fixture by backend in program_processor
havogt Apr 18, 2024
0946783
fix doctests
havogt Apr 18, 2024
f93da09
fix tests
havogt Apr 18, 2024
97663bd
undo quickstart changes
havogt Apr 18, 2024
3297f7b
undo delete cpp_backend_tests
havogt Apr 18, 2024
f37b372
fix quickstart guide again
havogt Apr 18, 2024
e242ab6
remove runtime lift
havogt Apr 19, 2024
a07d8ea
Merge remote-tracking branch 'upstream/main' into itir_program_embedded2
havogt Apr 19, 2024
3c2b9a5
Merge remote-tracking branch 'upstream/main' into test_lift_mode_to_p…
havogt Apr 19, 2024
a9f1043
Update docs/user/next/QuickstartGuide.md
havogt Apr 19, 2024
e7195a5
cleanup out field construction
havogt Apr 19, 2024
3f67746
Update src/gt4py/next/program_processors/runners/double_roundtrip.py
havogt Apr 22, 2024
369eae7
read config.DEBUG at execution
havogt Apr 22, 2024
35f2132
remove LiftMode.SIMPLE_HEURISTIC
havogt Apr 22, 2024
4a4f9b1
fix formatting
havogt Apr 23, 2024
2825588
Merge branch 'test_lift_mode_to_processor' into itir_program_embedded2
havogt Apr 23, 2024
65abde7
Merge remote-tracking branch 'upstream/main' into itir_program_embedded2
havogt Apr 23, 2024
7373eb6
Merge remote-tracking branch 'upstream/main' into itir_program_embedded2
havogt May 6, 2024
bfcc118
move ordering of unstructured domain to gtfn
havogt May 15, 2024
54f44cc
fix problem in column dtype if contains None
havogt May 16, 2024
b8b26e6
address more review comments
havogt May 16, 2024
b7b489e
Merge remote-tracking branch 'upstream/main' into itir_program_embedded2
havogt May 16, 2024
9ee02e4
fix tuples in columns
havogt May 16, 2024
d104633
fix preserve axis kind in global tmps
havogt May 17, 2024
866c3d6
Merge remote-tracking branch 'upstream/main' into itir_program_embedded2
havogt May 17, 2024
66e2464
fix follow up issue
havogt May 17, 2024
56e0086
Start using tree_map instead of apply_to_primitive_constituents
SF-N Jun 10, 2024
8ab97c4
Add functionality to call also tree_map(lambda x: x + 1, ((1, 2), 3))…
SF-N Jul 3, 2024
7fbd083
Merge main
SF-N Dec 31, 2024
827a4d3
Run pre-commit
SF-N Dec 31, 2024
0a5ac37
Minor
SF-N Dec 31, 2024
70283e6
Replace more apply_to_primitive_constituents by tree_map
SF-N Dec 31, 2024
14619d2
Minor fix
SF-N Dec 31, 2024
fac65de
Revert replacing when tuple_constructor is present
SF-N Dec 31, 2024
de34ef6
Try to use result_collection_constructor
SF-N Jan 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
temporaries
  • Loading branch information
havogt committed Apr 8, 2024
commit 12b8696f7b8ffdb5016c7b8d1315c78b15df5210
14 changes: 10 additions & 4 deletions src/gt4py/next/iterator/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# SPDX-License-Identifier: GPL-3.0-or-later

import typing
from typing import ClassVar, List, Optional, Union
from typing import Any, ClassVar, List, Optional, Union

import gt4py.eve as eve
from gt4py.eve import Coerced, SymbolName, SymbolRef, datamodels
Expand Down Expand Up @@ -217,15 +217,21 @@ class Stmt(Node): ...


class Assign(Stmt):
target: SymRef
expr: Expr # TODO Program expression
target: Expr # `make_tuple` or SymRef
expr: Expr # only `apply_stencil`


class Temporary(Node):
id: Coerced[eve.SymbolName]
domain: Optional[Expr] = None
dtype: Optional[Any] = None # TODO


class Program(Node, ValidatedSymbolTableTrait):
id: Coerced[SymbolName]
function_definitions: List[FunctionDefinition]
params: List[Sym]
declarations: List[Sym]
declarations: List[Temporary]
body: List[Stmt]

_NODE_SYMBOLS_: ClassVar[List[Sym]] = [Sym(id=name) for name in BUILTINS]
Expand Down
9 changes: 9 additions & 0 deletions src/gt4py/next/iterator/transforms/fencil_to_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,12 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition) -> itir.Program:
declarations=[],
body=self.visit(node.closures),
)

def visit_FencilWithTemporaries(self, node) -> itir.Program:
return itir.Program(
id=node.fencil.id,
function_definitions=node.fencil.function_definitions,
params=node.params,
declarations=node.tmps,
body=self.visit(node.fencil.closures),
)
20 changes: 6 additions & 14 deletions src/gt4py/next/iterator/transforms/global_tmps.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@
from collections.abc import Mapping
from typing import Any, Callable, Final, Iterable, Literal, Optional, Sequence

import gt4py.eve as eve
import gt4py.next as gtx
from gt4py.eve import Coerced, NodeTranslator, PreserveLocationVisitor
from gt4py.eve import NodeTranslator, PreserveLocationVisitor
from gt4py.eve.traits import SymbolTableTrait
from gt4py.eve.utils import UIDGenerator
from gt4py.next import common
Expand Down Expand Up @@ -54,26 +53,18 @@
# Iterator IR extension nodes


class Temporary(ir.Node):
"""Iterator IR extension: declaration of a temporary buffer."""

id: Coerced[eve.SymbolName]
domain: Optional[ir.Expr] = None
dtype: Optional[Any] = None


class FencilWithTemporaries(ir.Node, SymbolTableTrait):
"""Iterator IR extension: declaration of a fencil with temporary buffers."""

fencil: ir.FencilDefinition
params: list[ir.Sym]
tmps: list[Temporary]
tmps: list[ir.Temporary]


# Extensions for `PrettyPrinter` for easier debugging


def pformat_Temporary(printer: PrettyPrinter, node: Temporary, *, prec: int) -> list[str]:
def pformat_Temporary(printer: PrettyPrinter, node: ir.Temporary, *, prec: int) -> list[str]:
start, end = [node.id + " = temporary("], [");"]
args = []
if node.domain is not None:
Expand Down Expand Up @@ -367,7 +358,7 @@ def always_extract_heuristics(_):
location=node.location,
),
params=node.params,
tmps=[Temporary(id=tmp.id) for tmp in tmps],
tmps=[ir.Temporary(id=tmp.id) for tmp in tmps],
)


Expand Down Expand Up @@ -638,7 +629,8 @@ def convert_type(dtype):
fencil=node.fencil,
params=node.params,
tmps=[
Temporary(id=tmp.id, domain=domains[tmp.id], dtype=types[tmp.id]) for tmp in node.tmps
ir.Temporary(id=tmp.id, domain=domains[tmp.id], dtype=types[tmp.id])
for tmp in node.tmps
],
)

Expand Down
3 changes: 1 addition & 2 deletions src/gt4py/next/iterator/transforms/pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,6 @@ def apply_common_transforms(
ir, opcount_preserving=True, force_inline_lambda_args=force_inline_lambda_args
)

assert isinstance(ir, itir.FencilDefinition)
prog = FencilToProgram.apply(ir)
prog = FencilToProgram.apply(ir) # type: ignore[arg-type] # TODO: remove after refactoring

return prog
17 changes: 5 additions & 12 deletions src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,11 +575,11 @@ def visit_Program(self, node: itir.Program, **kwargs: Any) -> FencilDefinition:
grid_type=self.grid_type,
offset_definitions=list(offset_definitions.values()),
function_definitions=function_definitions,
temporaries=[],
temporaries=self.visit(node.declarations, params=[p.id for p in node.params]),
)

def visit_Temporary(
self, node: global_tmps.Temporary, *, params: list, **kwargs: Any
self, node: itir.Temporary, *, params: list, **kwargs: Any
) -> TemporaryAllocation:
def dtype_to_cpp(x: int | tuple | str) -> str:
if isinstance(x, int):
Expand All @@ -601,13 +601,6 @@ def dtype_to_cpp(x: int | tuple | str) -> str:
def visit_FencilWithTemporaries(
self, node: global_tmps.FencilWithTemporaries, **kwargs: Any
) -> FencilDefinition:
fencil = self.visit(node.fencil, **kwargs)
return FencilDefinition(
id=fencil.id,
params=self.visit(node.params),
executions=fencil.executions,
grid_type=fencil.grid_type,
offset_definitions=fencil.offset_definitions,
function_definitions=fencil.function_definitions,
temporaries=self.visit(node.tmps, params=[p.id for p in node.params]),
)
raise AssertionError(
"Internal error: Fencils are no longer supported."
) # TODO remove after refactoring is complete
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,7 @@
import gt4py.eve as eve
from gt4py.next import Dimension, DimensionKind, type_inference as next_typing
from gt4py.next.common import NeighborTable
from gt4py.next.iterator import (
ir as itir,
transforms as itir_transforms,
type_inference as itir_typing,
)
from gt4py.next.iterator import ir as itir, type_inference as itir_typing
from gt4py.next.iterator.ir import Expr, FunCall, Literal, Sym, SymRef
from gt4py.next.type_system import type_specifications as ts, type_translation

Expand Down Expand Up @@ -164,7 +160,7 @@ def __init__(
self,
param_types: list[ts.TypeSpec],
offset_provider: dict[str, NeighborTable],
tmps: list[itir_transforms.global_tmps.Temporary],
tmps: list[itir.Temporary],
use_field_canonical_representation: bool,
column_axis: Optional[Dimension] = None,
):
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/program_processors/runners/roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def visit_FencilWithTemporaries(
+ f"\n {node.fencil.id}({args}, **kwargs)\n"
)

def visit_Temporary(self, node: gtmps_transform.Temporary, **kwargs: Any) -> str:
def visit_Temporary(self, node: itir.Temporary, **kwargs: Any) -> str:
assert (
isinstance(node.domain, itir.FunCall)
and isinstance(node.domain.fun, itir.SymRef)
Expand Down