feat[next][dace]: iterator-view support to DaCe backend #1790

Merged
merged 143 commits into from
Jan 15, 2025
Changes from 6 commits
df1847a
scan - working draft
edopao Nov 29, 2024
89ca8f7
Merge remote-tracking branch 'origin/main' into dace-gtir-scan
edopao Dec 3, 2024
f22eb64
Merge remote-tracking branch 'origin/main' into dace-gtir-scan
edopao Dec 4, 2024
c26d906
Improve utility functions for tuples
edopao Dec 4, 2024
ba0a9ba
Fix for empty field domain
edopao Dec 4, 2024
ac7acf8
Merge remote-tracking branch 'origin/main' into dace-gtir-scan
edopao Dec 4, 2024
877d81e
Merge remote-tracking branch 'origin/main' into dace-gtir-scan
edopao Dec 4, 2024
8baf6d1
Merge remote-tracking branch 'origin/main' into dace-gtir-scan
edopao Dec 4, 2024
784b573
Add exclusive if_ in dataflow
edopao Dec 5, 2024
de9c9de
Better handling of isolated nodes
edopao Dec 5, 2024
14e66e8
Fix field offset in nested SDFG context
edopao Dec 6, 2024
fcfaf72
fix problem with dereferencing of 1D vertical fields inside scan
edopao Dec 6, 2024
79204ee
generalize previous fix to all scan input fields
edopao Dec 6, 2024
5fe461a
minor edit
edopao Dec 6, 2024
a4bde3a
fix out-of-bound access
edopao Dec 6, 2024
c75a8e4
Better handling of isolated nodes
edopao Dec 6, 2024
6f72cac
Merge remote-tracking branch 'origin/main' into dace-gtir-scan
edopao Dec 6, 2024
397acae
exclude scan tests on dace backend with optimizations
edopao Dec 6, 2024
acf5ac0
Merge remote-tracking branch 'origin/main' into dace-gtir-scan
edopao Dec 6, 2024
a706b27
fix pre-commit
edopao Dec 6, 2024
c22cfc8
fix doctest
edopao Dec 6, 2024
59e0ed5
Merge remote-tracking branch 'origin/main' into dace-gtir-scan
edopao Dec 6, 2024
792a8eb
temporarily disable one optimize transformation
edopao Dec 9, 2024
61985f7
Revert "temporarily disable one optimize transformation"
edopao Dec 9, 2024
aa236a2
fix for scan output stride
edopao Dec 10, 2024
9bdc75b
fix previous commit
edopao Dec 10, 2024
746f9d8
convert scalar to array on nsdfg output
edopao Dec 10, 2024
0d894ff
Revert "convert scalar to array on nsdfg output"
edopao Dec 11, 2024
440a474
Split handling of let-statement lambdas from stencil body
edopao Dec 11, 2024
500590b
minor edit
edopao Dec 11, 2024
c56e062
Merge remote-tracking branch 'origin/dace-refact-lambda' into dace-gt…
edopao Dec 12, 2024
5d5992a
use dace auto-optimize on gpu
edopao Dec 12, 2024
c167def
Merge remote-tracking branch 'origin/dace-gtir-scan' into dace-gtir-scan
edopao Dec 12, 2024
eb17345
Revert "use dace auto-optimize on gpu"
edopao Dec 12, 2024
8b163da
make map_strides recursive
edopao Dec 12, 2024
d15213a
rename module alias
edopao Dec 13, 2024
55811dc
review comments
edopao Dec 13, 2024
8f0e515
Merge remote-tracking branch 'origin/dace-refact-lambda' into dace-gt…
edopao Dec 13, 2024
f01d291
add test case for sdfg transformation
edopao Dec 13, 2024
62e1648
review comments (1)
edopao Dec 16, 2024
72e8830
review comments (2)
edopao Dec 16, 2024
39aeb20
Merge branch 'dace-refact-lambda' into dace-gtir-scan
edopao Dec 16, 2024
de4a80e
review comments (2)
edopao Dec 16, 2024
45f9927
Merge remote-tracking branch 'origin/main' into dace-refact-lambda
edopao Dec 16, 2024
3fe538b
Merge remote-tracking branch 'origin/dace-refact-lambda' into dace-gt…
edopao Dec 16, 2024
ee62266
Merge remote-tracking branch 'origin/main' into dace-gtir-scan
edopao Dec 16, 2024
4b0ac60
Propagate strides to nested SDFG when changing transient strides
edopao Dec 16, 2024
f701605
rename function
edopao Dec 16, 2024
a19019f
fix bug
edopao Dec 16, 2024
c03492c
fix previous commit
edopao Dec 16, 2024
310fcce
Test commit
edopao Dec 16, 2024
4b487ea
propagate strides also to destination nested SDFG
edopao Dec 16, 2024
4cf66e7
fix previous commit (skip scalar inner nodes)
edopao Dec 16, 2024
ab7ee5f
fix - do not call free_symbols on int stride
edopao Dec 17, 2024
82cf491
run simplify before gpu transformations
edopao Dec 17, 2024
a0dbea5
undo renaming graph -> state
edopao Dec 17, 2024
9128ffb
increase slurm timeout to 20 minutes
edopao Dec 17, 2024
f940c4e
increase slurm timeout to 30 minutes
edopao Dec 17, 2024
cc0777b
minor edit
edopao Dec 17, 2024
462f3c5
exclude test_ternary_scan from gpu tests
edopao Dec 17, 2024
d9218b6
These are the changes Edoardo implemented to fix some issues in the op…
edopao Dec 17, 2024
9d7e722
First rework.
philip-paul-mueller Dec 18, 2024
1ddd6fe
Updated some comments.
philip-paul-mueller Dec 18, 2024
95e0007
I want to ignore register, not only consider them.
philip-paul-mueller Dec 18, 2024
f1b7a3f
There was a missing `not` in the check.
philip-paul-mueller Dec 18, 2024
50ad620
Had to update the propagation, to also handle aliasing.
philip-paul-mueller Dec 18, 2024
983022c
In the function for looking for top level accesses the `only_transien…
philip-paul-mueller Dec 18, 2024
e7b1afb
Small reminder of the future.
philip-paul-mueller Dec 18, 2024
df7bd0c
Forgot to export the new SDFG stuff.
philip-paul-mueller Dec 18, 2024
363ab59
Had to update the function for the actual renaming of the strides.
philip-paul-mueller Dec 18, 2024
9c19d32
Added a todo to the replacement function.
philip-paul-mueller Dec 18, 2024
9cad1f7
Added a first test to the propagation function.
philip-paul-mueller Dec 18, 2024
2700f53
Modified the function that performs the actual modification of the s…
philip-paul-mueller Dec 19, 2024
a20d3c0
Updated some tests, but more are missing.
philip-paul-mueller Dec 19, 2024
b5ff462
Subset caching strikes again.
philip-paul-mueller Dec 19, 2024
d326d3b
It seems that the explicit handling of one dimensions is not working.
philip-paul-mueller Dec 19, 2024
252f348
The test must be moved below.
philip-paul-mueller Dec 19, 2024
49f8172
The symbol is also needed to be present in the nested SDFG.
philip-paul-mueller Dec 19, 2024
2d6dfc0
Fixed a bug in determining the free symbols that we need.
philip-paul-mueller Dec 19, 2024
6124c6d
Updated the propagation code for the symbols.
philip-paul-mueller Dec 19, 2024
45bcf97
Addressed Edoardo's changes.
philip-paul-mueller Dec 19, 2024
23b0baa
Updated how we get the type of symbols.
philip-paul-mueller Dec 19, 2024
ff05880
New restriction on the update of the symbol mapping.
philip-paul-mueller Dec 19, 2024
43ec33c
Updated the tests, now also made one that has tests for the symbol ma…
philip-paul-mueller Dec 19, 2024
d43153a
Fixed two bugs in the stride propagation function.
philip-paul-mueller Dec 19, 2024
2e82bd5
Added a test that ensures that the dependent adding works.
philip-paul-mueller Dec 19, 2024
07e6a5c
Changed the default of `ignore_symbol_mapping` to `True`.
philip-paul-mueller Dec 19, 2024
4bf145b
Added Edoardo's comments.
philip-paul-mueller Dec 19, 2024
2b03bb4
Removed the creation of aliasing if symbol tables are ignored.
philip-paul-mueller Dec 20, 2024
40c225d
Added a test that shows that `ignore_symbol_mapping=False` does produ…
philip-paul-mueller Dec 20, 2024
419a386
Updated the description.
philip-paul-mueller Dec 20, 2024
cc9801b
Applied Edoardo's comment.
philip-paul-mueller Dec 20, 2024
360baae
Added a todo from Edoardo's suggestions.
philip-paul-mueller Dec 20, 2024
f2396c4
Merge remote-tracking branch 'philip/dace-gtir-better-strides' into d…
edopao Dec 20, 2024
a0c37cb
minor edit
edopao Dec 20, 2024
45c69ec
Merge branch 'main' into dace-gtir-scan
edopao Dec 20, 2024
0f9043b
fix for missing symbols in nested sdfg
edopao Dec 20, 2024
059a448
wip - fix iterator tests
edopao Dec 20, 2024
b8fe277
disable tests with sparse fields
edopao Jan 7, 2025
0dd4b4e
disable unsupported features
edopao Jan 8, 2025
2d3238c
fix for if_ lowering
edopao Jan 8, 2025
aec47c8
lowering of tuple_deref
edopao Jan 8, 2025
1f68857
lowering of tuple iterators
edopao Jan 8, 2025
c20728d
allow tuple fields with different size
edopao Jan 8, 2025
d9691a8
Merge remote-tracking branch 'origin/main' into dace-gtir-iterator_view
edopao Jan 8, 2025
312e69c
add scan test marker
edopao Jan 8, 2025
65b4dd2
undo lowering of scan
edopao Jan 8, 2025
61b06b3
Minor edit based on review comments
edopao Jan 8, 2025
62a2a80
ignore atlas tests
edopao Jan 8, 2025
aa9f999
undo scan-related change
edopao Jan 8, 2025
7508e03
Minor edit based on review comments
edopao Jan 8, 2025
43f2e40
fix
edopao Jan 8, 2025
de203f9
Revert "undo scan-related change"
edopao Jan 8, 2025
ab11d77
fix previous commits
edopao Jan 8, 2025
7329c4b
update test skip list
edopao Jan 9, 2025
fab8288
fix gtir dace tests (add tuple symbols)
edopao Jan 9, 2025
46322ac
undo extra change
edopao Jan 9, 2025
2c1156b
remove support for tuple iterator
edopao Jan 9, 2025
ac24404
fix test marker
edopao Jan 9, 2025
a414eba
move 2 nested function definitions to separate helper functions
edopao Jan 9, 2025
015f69c
edit test markers
edopao Jan 9, 2025
f05a730
edit test markers
edopao Jan 9, 2025
2363b62
Revert "edit test markers"
edopao Jan 9, 2025
fd1462d
Merge remote-tracking branch 'origin/main' into dace-gtir-iterator_view
edopao Jan 10, 2025
db94493
remove wrong assert
edopao Jan 10, 2025
f7b18b3
edit code comments
edopao Jan 10, 2025
87b5bd5
add tuple_get
edopao Jan 10, 2025
d93a387
better symbol mapping for lambda nested SDFG
edopao Jan 13, 2025
678b782
Revert "add tuple_get"
edopao Jan 13, 2025
eaaee4e
Merge remote-tracking branch 'origin/main' into dace-gtir-iterator_view
edopao Jan 13, 2025
d4599d2
address review comments
edopao Jan 13, 2025
4a82810
fix
edopao Jan 13, 2025
81b5fd3
fix subset num_elements
edopao Jan 13, 2025
ea44598
address review comments (1)
edopao Jan 14, 2025
981f2c7
fix test markers
edopao Jan 14, 2025
6035ccc
address review comments (2)
edopao Jan 14, 2025
bed7e0d
fix previous commit
edopao Jan 14, 2025
3529964
better tuple symbol tree
edopao Jan 15, 2025
974d643
rename sym_tree to symbol_tree
edopao Jan 15, 2025
6f4ff65
helper function add_temp_array
edopao Jan 15, 2025
46879b7
make _visit_if_branch_result separate function
edopao Jan 15, 2025
4437108
fix doc test
edopao Jan 15, 2025
d7671a7
address review comment (1)
edopao Jan 15, 2025
@@ -114,6 +114,7 @@ def decorated_program(
args = (*args, out)
flat_args: Sequence[Any] = gtx_utils.flatten_nested_tuple(tuple(args))
if len(sdfg.arg_names) > len(flat_args):
# The Ahead-of-Time (AOT) workflow for FieldView programs requires domain size arguments.
flat_args = (*flat_args, *arguments.iter_size_args(args))

if sdfg_program._lastargs:
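
The hunk above flattens the (possibly nested) program arguments and, when the SDFG expects more arguments than were passed, appends the domain size arguments. A minimal sketch of that pattern follows. Both helpers are hypothetical stand-ins written for illustration and are not the actual `gtx_utils.flatten_nested_tuple` or `arguments.iter_size_args`, whose behavior may differ.

```python
from typing import Any, Iterable, Iterator, Sequence


def flatten_nested_tuple(value: tuple[Any, ...]) -> tuple[Any, ...]:
    """Illustrative stand-in: flatten arbitrarily nested tuples, depth first."""

    def _walk(item: Any) -> Iterator[Any]:
        if isinstance(item, tuple):
            for sub_item in item:
                yield from _walk(sub_item)
        else:
            yield item

    return tuple(_walk(value))


def pad_with_size_args(
    arg_names: Sequence[str],
    flat_args: Sequence[Any],
    size_args: Iterable[Any],
) -> tuple[Any, ...]:
    """Append size arguments only if more names are expected than values given.

    `size_args` is assumed to yield the missing values in the expected order.
    """
    if len(arg_names) > len(flat_args):
        return (*flat_args, *size_args)
    return tuple(flat_args)


flat = flatten_nested_tuple((1, (2, (3, 4)), 5))
padded = pad_with_size_args(["a", "b", "size_a"], (1, 2), iter([10]))
```

The size arguments are only consumed when the length check fails, so a lazily evaluated iterator is never touched on the common path.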
@@ -264,6 +264,72 @@ def _get_field_layout(
return list(domain_dims), list(domain_lbs), domain_sizes


def _create_field_operator_impl(
sdfg: dace.SDFG,
state: dace.SDFGState,
domain: FieldopDomain,
output_edge: gtir_dataflow.DataflowOutputEdge,
output_type: ts.FieldType,
map_exit: dace.nodes.MapExit,
) -> FieldopData:
"""
Helper method to allocate a temporary array that stores one field computed by a field operator.

This method is called by `_create_field_operator()`.

Args:
sdfg: The SDFG that represents the scope of the field data.
state: The SDFG state where to create an access node to the field data.
domain: The domain of the field operator that computes the field.
output_edge: The dataflow write edge representing the output data.
output_type: The GT4Py field type descriptor.
map_exit: The `MapExit` node of the field operator map scope.

Returns:
The field data descriptor, which includes the field access node in the given `state`
and the field domain offset.
"""
dataflow_output_desc = output_edge.result.dc_node.desc(sdfg)

domain_dims, domain_offset, domain_shape = _get_field_layout(domain)
domain_indices = _get_domain_indices(domain_dims, domain_offset)
domain_subset = dace_subsets.Range.from_indices(domain_indices)

if isinstance(output_edge.result.gt_dtype, ts.ScalarType):
assert output_edge.result.gt_dtype == output_type.dtype
field_dtype = output_edge.result.gt_dtype
field_dims, field_shape, field_offset = (domain_dims, domain_shape, domain_offset)
assert isinstance(dataflow_output_desc, dace.data.Scalar)
field_subset = domain_subset
else:
assert isinstance(output_type.dtype, ts.ListType)
assert isinstance(output_edge.result.gt_dtype.element_type, ts.ScalarType)
assert output_edge.result.gt_dtype.element_type == output_type.dtype.element_type
field_dtype = output_edge.result.gt_dtype.element_type
assert isinstance(dataflow_output_desc, dace.data.Array)
assert len(dataflow_output_desc.shape) == 1
# extend the array with the local dimensions added by the field operator (e.g. `neighbors`)
assert output_edge.result.gt_dtype.offset_type is not None
field_dims = [*domain_dims, output_edge.result.gt_dtype.offset_type]
field_shape = [*domain_shape, dataflow_output_desc.shape[0]]
field_offset = [*domain_offset, dataflow_output_desc.offset[0]]
field_subset = domain_subset + dace_subsets.Range.from_array(dataflow_output_desc)

# allocate local temporary storage
assert dataflow_output_desc.dtype == dace_utils.as_dace_type(field_dtype)
field_name, _ = sdfg.add_temp_transient(field_shape, dataflow_output_desc.dtype)
field_node = state.add_access(field_name)

# and here the edge writing the dataflow result data through the map exit node
output_edge.connect(map_exit, field_node, field_subset)

return FieldopData(
field_node,
ts.FieldType(field_dims, field_dtype),
offset=(field_offset if set(field_offset) != {0} else None),
)
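
For readers following the layout logic in `_create_field_operator_impl`, here is a small sketch in plain Python of how the field layout is derived: a scalar result reuses the domain layout, while a list result appends one local dimension. It also mirrors the `set(field_offset) != {0}` check that collapses an all-zero offset to `None`. The function name `extend_layout` and its parameters are assumptions made for this illustration, and plain strings and integers replace the GT4Py and DaCe types.

```python
from typing import Optional


def extend_layout(
    domain_dims: list[str],
    domain_shape: list[int],
    domain_offset: list[int],
    local_dim: Optional[str] = None,
    local_size: int = 0,
    local_offset: int = 0,
) -> tuple[list[str], list[int], Optional[list[int]]]:
    """Return (dims, shape, offset) of the temporary field.

    Hypothetical helper: when `local_dim` is given, the layout is extended
    with one local dimension (e.g. a neighbor dimension).
    """
    dims = list(domain_dims)
    shape = list(domain_shape)
    offset = list(domain_offset)
    if local_dim is not None:
        dims.append(local_dim)
        shape.append(local_size)
        offset.append(local_offset)
    # Same check as in the diff: an all-zero offset is reported as None.
    return dims, shape, (offset if set(offset) != {0} else None)


scalar_layout = extend_layout(["I", "K"], [10, 20], [0, 0])
list_layout = extend_layout(["I", "K"], [10, 20], [1, 0], "V2E", 4, 0)
```

In this sketch, as in the diff, a single non-zero entry is enough to keep the whole offset list.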


def _create_field_operator(
sdfg: dace.SDFG,
state: dace.SDFGState,
@@ -275,7 +341,10 @@ def _create_field_operator(
| tuple[gtir_dataflow.DataflowOutputEdge | tuple[Any, ...], ...],
) -> FieldopResult:
"""
Helper method to allocate a temporary field to store the output of a field operator.
Helper method to build the output of a field operator, which can consist of
a single field or, in case of scan, a tuple of fields.
The scan field operator typically computes multiple fields for each K-level:
for each field, this method will call `_create_field_operator_impl()`.

Args:
sdfg: The SDFG that represents the scope of the field data.
@@ -287,55 +356,12 @@ def _create_field_operator(
output_edges: Single edge or tuple of edges representing the dataflow output data.

Returns:
The field data descriptor, which includes the field access node in the given `state`
and the field domain offset.
The descriptor of the field operator result, which can be either a single field
or a tuple of fields.
"""

def _create_field_operator_impl(
output_edge: gtir_dataflow.DataflowOutputEdge, mx: dace.nodes.MapExit, sym: gtir.Sym
) -> FieldopData:
assert isinstance(sym.type, ts.FieldType)
dataflow_output_desc = output_edge.result.dc_node.desc(sdfg)
if isinstance(output_edge.result.gt_dtype, ts.ScalarType):
assert output_edge.result.gt_dtype == sym.type.dtype
field_dtype = output_edge.result.gt_dtype
field_dims, field_shape, field_offset = (domain_dims, domain_shape, domain_offset)
assert isinstance(dataflow_output_desc, dace.data.Scalar)
field_subset = domain_subset
else:
assert isinstance(sym.type.dtype, ts.ListType)
assert isinstance(output_edge.result.gt_dtype.element_type, ts.ScalarType)
assert output_edge.result.gt_dtype.element_type == sym.type.dtype.element_type
field_dtype = output_edge.result.gt_dtype.element_type
assert isinstance(dataflow_output_desc, dace.data.Array)
assert len(dataflow_output_desc.shape) == 1
# extend the array with the local dimensions added by the field operator (e.g. `neighbors`)
assert output_edge.result.gt_dtype.offset_type is not None
field_dims = [*domain_dims, output_edge.result.gt_dtype.offset_type]
field_shape = [*domain_shape, dataflow_output_desc.shape[0]]
field_offset = [*domain_offset, dataflow_output_desc.offset[0]]
field_subset = domain_subset + dace_subsets.Range.from_array(dataflow_output_desc)

# allocate local temporary storage
assert dataflow_output_desc.dtype == dace_utils.as_dace_type(field_dtype)
field_name, _ = sdfg.add_temp_transient(field_shape, dataflow_output_desc.dtype)
field_node = state.add_access(field_name)

# and here the edge writing the dataflow result data through the map exit node
output_edge.connect(mx, field_node, field_subset)

return FieldopData(
field_node,
ts.FieldType(field_dims, field_dtype),
offset=(field_offset if set(field_offset) != {0} else None),
)

domain_dims, domain_offset, domain_shape = _get_field_layout(domain)
domain_indices = _get_domain_indices(domain_dims, domain_offset)
domain_subset = dace_subsets.Range.from_indices(domain_indices)

# create map range corresponding to the field operator domain
me, mx = sdfg_builder.add_map(
map_entry, map_exit = sdfg_builder.add_map(
"fieldop",
state,
ndrange={
@@ -346,17 +372,19 @@ def _create_field_operator_impl(

# here we setup the edges passing through the map entry node
for edge in input_edges:
edge.connect(me)
edge.connect(map_entry)

if isinstance(output_edges, gtir_dataflow.DataflowOutputEdge):
assert isinstance(node_type, ts.FieldType)
return _create_field_operator_impl(output_edges, mx, im.sym("x", node_type))
if isinstance(node_type, ts.FieldType):
assert isinstance(output_edges, gtir_dataflow.DataflowOutputEdge)
return _create_field_operator_impl(sdfg, state, domain, output_edges, node_type, map_exit)
else:
# handle tuples of fields
assert isinstance(node_type, ts.TupleType)
tuple_syms = dace_gtir_utils.make_symbol_tuple("x", node_type)
return gtx_utils.tree_map(
lambda output_edge, sym: _create_field_operator_impl(output_edge, mx, sym)
)(output_edges, dace_gtir_utils.make_symbol_tuple("x", node_type))
lambda output_edge, output_sym: _create_field_operator_impl(
sdfg, state, domain, output_edge, output_sym.type, map_exit
)
)(output_edges, tuple_syms)
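
The tuple branch above applies `_create_field_operator_impl` pairwise over two trees of the same shape: the output edges and the matching symbols. The sketch below shows one way such a curried `tree_map` could work. It is an assumption-laden illustration, not the actual `gtx_utils.tree_map`, and strings stand in for edges and symbol types.

```python
from typing import Any, Callable


def tree_map(fn: Callable[..., Any]) -> Callable[..., Any]:
    """Illustrative stand-in: apply `fn` leaf-wise over nested tuples of equal shape."""

    def impl(*trees: Any) -> Any:
        if isinstance(trees[0], tuple):
            # All trees must have the same structure at this level.
            assert all(isinstance(t, tuple) and len(t) == len(trees[0]) for t in trees)
            return tuple(impl(*items) for items in zip(*trees))
        return fn(*trees)

    return impl


# Hypothetical leaves: edge names paired with the type of the matching symbol.
output_edges = ("edge_a", ("edge_b", "edge_c"))
output_types = ("f32", ("i32", "f64"))
result = tree_map(lambda edge, dtype: f"{edge}:{dtype}")(output_edges, output_types)
```

The result keeps the nesting of the inputs, which is why the field operator can return a tuple of field descriptors shaped like `node_type`.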


def extract_domain(node: gtir.Node) -> FieldopDomain:
@@ -444,7 +472,7 @@ def translate_as_fieldop(
fieldop_args = [_parse_fieldop_arg(arg, sdfg, state, sdfg_builder, domain) for arg in node.args]

# represent the field operator as a mapped tasklet graph, which will range over the field domain
input_edges, output_edges = gtir_dataflow.apply(
input_edges, output_edges = gtir_dataflow.translate_lambda_to_dataflow(
sdfg, state, sdfg_builder, stencil_expr, fieldop_args
)
