-
Notifications
You must be signed in to change notification settings - Fork 49
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat[next]: GTIR embedded and GTFN temporaries with new lowering #1648
Changes from all commits
b0d688a
65a22fe
8a27d16
3edc130
031e459
dd63122
2d840cb
7172442
3313191
4a20973
f3cbc18
d4f066a
9a69dea
c1038fa
31cd94e
160aeaf
67682ed
881babf
2bf9231
19a2c5e
a6a010c
a57cb99
b50b85a
e2e3fa1
bb2f2b1
729968e
cdcaac0
9472502
01dd86e
e078b36
f0331cb
21d2fb5
3bd3ad6
61e9665
f9dff50
71f512f
6fb424f
7d809fb
c7b79c0
f16a961
3e3f9a1
edffd97
3196a11
6a4d227
183139b
04f59dd
feab647
5869769
376153f
b61be89
52a1b90
a1b4448
f93c8b4
9684507
5b86f19
7784896
0e50214
068ff06
aaba729
6044d76
1dc9ebb
685bedb
378b3b3
83e5ce2
b3ae17b
4637b7a
dbd71a9
0990858
da4a63c
c4b1ed8
f545984
0904d88
6f6c65b
606f662
b917011
6c9e8ab
320c7f8
b589346
309d58f
4d2b3da
cfb59d7
1fe44e0
ab5a6a2
6072d53
821af59
70c0dff
108af05
7380d6e
c268ee1
c5b0171
bba6aa4
840c004
f9fc5c5
c241bc4
e120849
45f41be
fccd43b
1874eba
8685731
dd5bfa7
8f1e84a
d53d3bb
4bfef54
78b7a98
88a7660
1b79b3a
b6b603e
5beccf0
deca907
af9d776
8cb36da
ee0b94a
270e173
8e2ba0c
229d3ac
580cb79
c67d355
79ab838
d53fda9
24c3e87
04f110f
c399f65
0faa7ef
8e8a0a1
e03dd38
7a4c692
405cbb0
4c27279
a666eef
2d6464f
4b67a99
af2ed5f
52df315
262ffdd
16f143f
75695d9
f3b1c6c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,7 +10,7 @@ | |
|
||
import dataclasses | ||
import functools | ||
from typing import Any, Literal, Mapping | ||
from typing import Any, Literal, Mapping, Optional | ||
|
||
import gt4py.next as gtx | ||
from gt4py.next import common | ||
|
@@ -93,6 +93,9 @@ def translate( | |
..., | ||
], | ||
offset_provider: common.OffsetProvider, | ||
#: A dictionary mapping axes names to their length. See | ||
#: func:`gt4py.next.iterator.transforms.infer_domain.infer_expr` for more details. | ||
symbolic_domain_sizes: Optional[dict[str, str]] = None, | ||
) -> SymbolicDomain: | ||
dims = list(self.ranges.keys()) | ||
new_ranges = {dim: self.ranges[dim] for dim in dims} | ||
|
@@ -119,18 +122,24 @@ def translate( | |
trace_shifts.Sentinel.ALL_NEIGHBORS, | ||
trace_shifts.Sentinel.VALUE, | ||
] | ||
# note: ugly but cheap re-computation, but should disappear | ||
horizontal_sizes = _max_domain_sizes_by_location_type(offset_provider) | ||
horizontal_sizes: dict[str, itir.Expr] | ||
if symbolic_domain_sizes is not None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is used in the blue-line to inject the size of the temporaries. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add a comment to explain what's going on. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added a docstring to infer_expr and a back reference to that docstring here. |
||
horizontal_sizes = {k: im.ref(v) for k, v in symbolic_domain_sizes.items()} | ||
else: | ||
# note: ugly but cheap re-computation, but should disappear | ||
horizontal_sizes = { | ||
k: im.literal(str(v), itir.INTEGER_INDEX_BUILTIN) | ||
for k, v in _max_domain_sizes_by_location_type(offset_provider).items() | ||
} | ||
|
||
old_dim = nbt_provider.origin_axis | ||
new_dim = nbt_provider.neighbor_axis | ||
|
||
assert new_dim not in new_ranges or old_dim == new_dim | ||
|
||
# TODO(tehrengruber): Do we need symbolic sizes, e.g., for ICON? | ||
new_range = SymbolicRange( | ||
im.literal("0", itir.INTEGER_INDEX_BUILTIN), | ||
im.literal(str(horizontal_sizes[new_dim.value]), itir.INTEGER_INDEX_BUILTIN), | ||
horizontal_sizes[new_dim.value], | ||
) | ||
new_ranges = dict( | ||
(dim, range_) if dim != old_dim else (new_dim, new_range) | ||
|
@@ -140,7 +149,9 @@ def translate( | |
raise AssertionError() | ||
return SymbolicDomain(self.grid_type, new_ranges) | ||
elif len(shift) > 2: | ||
return self.translate(shift[0:2], offset_provider).translate(shift[2:], offset_provider) | ||
return self.translate(shift[0:2], offset_provider, symbolic_domain_sizes).translate( | ||
shift[2:], offset_provider, symbolic_domain_sizes | ||
) | ||
else: | ||
raise AssertionError("Number of shifts must be a multiple of 2.") | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is already handled by ruff where it is additionally configurable on a line-by-line basis.