fix[dace][next]: Fix for DistributedBufferRelocator (#1799)
This PR fixes an error that was reported by Edoardo (@edopao).
The bug occurred because the `DistributedBufferRelocator` transformation did
not check whether its insertion would create a read-write conflict. This
commit adds such a check; it is, however, not very sophisticated and
needs some improvement. Furthermore, the test where the error surfaced
(`model/atmosphere/dycore/tests/dycore_stencil_tests/test_compute_exner_from_rhotheta.py`)
holds more challenges. The main purpose of this PR is to unblock further
development in ICON4Py.

Link to ICON4Py PR: C2SM/icon4py#638
philip-paul-mueller authored Jan 16, 2025
1 parent 99b5042 commit 1b88276
Showing 2 changed files with 406 additions and 65 deletions.
@@ -374,7 +374,7 @@ def apply(
raise


AccessLocation: TypeAlias = tuple[dace.SDFGState, dace_nodes.AccessNode]
AccessLocation: TypeAlias = tuple[dace_nodes.AccessNode, dace.SDFGState]
"""Describes an access node and the state in which it is located.
"""

@@ -387,29 +387,38 @@ class DistributedBufferRelocator(dace_transformation.Pass):
in each branch and then in the join state written back. Thus there is some
additional storage needed.
The transformation will look for the following situation:
- A transient data container, called `src_cont`, is written into another
container, called `dst_cont`, which is not transient.
- The access node of `src_cont` has an in degree of zero and an out degree of one.
- The access node of `dst_cont` has an in degree of one and an
- A transient data container, called `temp_storage`, is written into another
container, called `dest_storage`, which is not transient.
- The access node of `temp_storage` has an in degree of zero and an out degree of one.
- The access node of `dest_storage` has an in degree of one and an
out degree of zero (this might be lifted).
- `src_cont` is not used afterwards.
- `dst_cont` is only used to implement the buffering.
- `temp_storage` is not used afterwards.
- `dest_storage` is only used to implement the buffering.
The function will relocate the writing of `dst_cont` to where `src_cont` is
The function will relocate the writing of `dest_storage` to where `temp_storage` is
written, which might be multiple locations.
It will also remove the writing back.
It is advised that after this transformation simplify is run again.
The relocation will not take place if it might create a data race. A necessary
but not sufficient condition for a data race is that `dest_storage` is present
in the state where `temp_storage` is defined. In addition, at least one of the
following conditions has to be met:
- There are accesses to `dest_storage` that are not predecessors of the node
where the data is stored inside `temp_storage`. This check ignores empty Memlets.
- There is a `dest_storage` access node that has an output degree larger
than one.
Note:
Essentially this transformation removes the double buffering of `dst_cont`.
Because we ensure that `dst_cont` is non-transient this is okay, as our
rule guarantees this.
- Essentially this transformation removes the double buffering of
`dest_storage`. Because we ensure that `dest_storage` is non-transient
this is okay, as our rule guarantees this.
Todo:
- Allow that `dst_cont` can also be transient.
- Allow that `dst_cont` does not need to be a sink node, this is most
- Allow that `dest_storage` can also be transient.
- Allow that `dest_storage` does not need to be a sink node, this is most
likely most relevant if it is transient.
- Check if `dst_cont` is used between where we want to place it and
- Check if `dest_storage` is used between where we want to place it and
where it is currently used.
"""

@@ -489,10 +498,10 @@ def _find_candidates(
where the temporary is defined.
"""
# All nodes that are used as distributed buffers.
candidate_src_cont: list[AccessLocation] = []
candidate_temp_storage: list[AccessLocation] = []

# Which `src_cont` access node is written back to which global memory.
src_cont_to_global: dict[dace_nodes.AccessNode, str] = {}
# Which `temp_storage` access node is written back to which global memory.
temp_storage_to_global: dict[dace_nodes.AccessNode, str] = {}

for state in sdfg.states():
# These are the possible targets we want to write into.
@@ -508,26 +517,26 @@
if len(candidate_dst_nodes) == 0:
continue

for src_cont in state.source_nodes():
if not isinstance(src_cont, dace_nodes.AccessNode):
for temp_storage in state.source_nodes():
if not isinstance(temp_storage, dace_nodes.AccessNode):
continue
if not src_cont.desc(sdfg).transient:
if not temp_storage.desc(sdfg).transient:
continue
if state.out_degree(src_cont) != 1:
if state.out_degree(temp_storage) != 1:
continue
dst_candidate: dace_nodes.AccessNode = next(
iter(edge.dst for edge in state.out_edges(src_cont))
iter(edge.dst for edge in state.out_edges(temp_storage))
)
if dst_candidate not in candidate_dst_nodes:
continue
candidate_src_cont.append((src_cont, state))
src_cont_to_global[src_cont] = dst_candidate.data
candidate_temp_storage.append((temp_storage, state))
temp_storage_to_global[temp_storage] = dst_candidate.data

if len(candidate_src_cont) == 0:
if len(candidate_temp_storage) == 0:
return []

# Now we have to find the places where the temporary sources are defined.
# I.e., this is also the location where the original value is defined.
# I.e., this is also the location where the temporary source was initialized.
result_candidates: list[tuple[AccessLocation, list[AccessLocation]]] = []

def find_upstream_states(dst_state: dace.SDFGState) -> set[dace.SDFGState]:
@@ -537,72 +546,199 @@ def find_upstream_states(dst_state: dace.SDFGState) -> set[dace.SDFGState]:
if dst_state in reachable[src_state] and dst_state is not src_state
}
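# The `reachable` mapping queried above is presumably the output of dace's
# `StateReachability` analysis pass; a sketch of obtaining it, assuming a
# recent dace version where `SDFG` exposes `cfg_id`:
from dace.transformation.passes.analysis import StateReachability

# Per control-flow graph id: maps every state to the set of states that
# can be reached from it.
reachable_per_cfg = StateReachability().apply_pass(sdfg, {})
reachable = reachable_per_cfg[sdfg.cfg_id]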

for src_cont in candidate_src_cont:
for temp_storage in candidate_temp_storage:
temp_storage_node, temp_storage_state = temp_storage
def_locations: list[AccessLocation] = []
for upstream_state in find_upstream_states(src_cont[1]):
if src_cont[0].data in access_sets[upstream_state][1]:
for upstream_state in find_upstream_states(temp_storage_state):
if temp_storage_node.data in access_sets[upstream_state][1]:
def_locations.extend(
(data_node, upstream_state)
for data_node in upstream_state.data_nodes()
if data_node.data == src_cont[0].data
if data_node.data == temp_storage_node.data
)
if len(def_locations) != 0:
result_candidates.append((src_cont, def_locations))
result_candidates.append((temp_storage, def_locations))

# This transformation removes `src_cont` by writing its content directly
# to `dst_cont`, at the point where it is defined.
# This transformation removes `temp_storage` by writing its content directly
# to `dest_storage`, at the point where it is defined.
# For this transformation to be valid the following conditions have to be met:
# - Between the definition of `src_cont` and the write back to `dst_cont`,
# `dst_cont` can not be accessed.
# - Between the definitions of `src_cont` and the point where it is written
# back, `src_cont` can only be accessed in the range that is written back.
# - After the write back point, `src_cont` shall not be accessed. This
# - Between the definition of `temp_storage` and the write back to `dest_storage`,
# `dest_storage` can not be accessed.
# - Between the definitions of `temp_storage` and the point where it is written
# back, `temp_storage` can only be accessed in the range that is written back.
# - After the write back point, `temp_storage` shall not be accessed. This
# restriction could be lifted.
#
# To keep the implementation simple, we use the conditions:
# - `src_cont` is only accessed where it is defined and at the write back
# - `temp_storage` is only accessed where it is defined and at the write back
# point.
# - Between the definitions of `src_cont` and the write back point,
# `dst_cont` is not used.
# - Between the definitions of `temp_storage` and the write back point,
# `dest_storage` is not used.

result: list[tuple[AccessLocation, list[AccessLocation]]] = []

for wb_localation, def_locations in result_candidates:
for wb_location, def_locations in result_candidates:
# Get the state and the location where the temporary is written back
# into the global data container.
wb_node, wb_state = wb_location

for def_node, def_state in def_locations:
# Test if `src_cont` is only accessed where it is defined and
# Test if `temp_storage` is only accessed where it is defined and
# where it is written back.
if gtx_transformations.util.is_accessed_downstream(
start_state=def_state,
sdfg=sdfg,
data_to_look=wb_localation[0].data,
nodes_to_ignore={def_node, wb_localation[0]},
data_to_look=wb_node.data,
nodes_to_ignore={def_node, wb_node},
):
break
# check if the global data is not used between the definition of
# `dst_cont` and where its written back. We allow one exception,
# if the global data is used in the state the distributed temporary
# is defined is used only for reading then it is ignored. This is
# allowed because of rule 3 of ADR0018.
glob_nodes_in_def_state = {
dnode
for dnode in def_state.data_nodes()
if dnode.data == src_cont_to_global[wb_localation[0]]
# `dest_storage` and where it is written back. However, we ignore
# the state where `temp_storage` is defined; these checks are
# performed by the `_check_read_write_dependency()` function.
global_data_name = temp_storage_to_global[wb_node]
global_nodes_in_def_state = {
dnode for dnode in def_state.data_nodes() if dnode.data == global_data_name
}
if any(def_state.in_degree(gdnode) != 0 for gdnode in glob_nodes_in_def_state):
break
if gtx_transformations.util.is_accessed_downstream(
start_state=def_state,
sdfg=sdfg,
data_to_look=src_cont_to_global[wb_localation[0]],
nodes_to_ignore=glob_nodes_in_def_state,
states_to_ignore={wb_localation[1]},
data_to_look=global_data_name,
nodes_to_ignore=global_nodes_in_def_state,
states_to_ignore={wb_state},
):
break
if self._check_read_write_dependency(sdfg, wb_location, def_locations):
break
else:
result.append((wb_localation, def_locations))
result.append((wb_location, def_locations))

return result
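# Schematically, each entry of the returned list pairs a write-back location
# with the definition sites it could be folded into; for a two-branch if the
# shape would roughly be (hypothetical node and state names):
#
#   result == [
#       (
#           (temp_storage_node, join_state),            # write back site
#           [
#               (temp_def_node_a, true_branch_state),   # definition sites
#               (temp_def_node_b, false_branch_state),
#           ],
#       ),
#   ]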

def _check_read_write_dependency(
self,
sdfg: dace.SDFG,
write_back_location: AccessLocation,
target_locations: list[AccessLocation],
) -> bool:
"""Tests if read-write conflicts would be created.
This function ensures that the substitution of `write_back_location` into
`target_locations` will not create a read-write conflict.
The rules that are used for this are outlined in the class description.
Args:
sdfg: The SDFG on which we operate.
write_back_location: Where currently the write back occurs.
target_locations: List of the locations where we would like to perform
the write back instead.
Returns:
`True` if a read-write dependency is detected, `False` otherwise.
"""
for target_location in target_locations:
if self._check_read_write_dependency_impl(sdfg, write_back_location, target_location):
return True
return False

def _check_read_write_dependency_impl(
self,
sdfg: dace.SDFG,
write_back_location: AccessLocation,
target_location: AccessLocation,
) -> bool:
"""Tests if read-write conflict would be created for a single location.
Args:
sdfg: The SDFG on which we operate.
write_back_location: Where currently the write back occurs.
target_location: Location where the new write back should be performed.
Todo:
Refine these checks later.
Returns:
`True` if a read-write dependency is detected, `False` otherwise.
"""
assert write_back_location[0].data == target_location[0].data

# Get the state and the location where the temporary is written back
# into the global data container. Because `write_back_node` refers to
# the temporary we must query the graph to find the global node.
write_back_node, write_back_state = write_back_location
write_back_edge = next(iter(write_back_state.out_edges(write_back_node)))
global_data_name = write_back_edge.dst.data
assert not sdfg.arrays[global_data_name].transient
assert write_back_state.out_degree(write_back_node) == 1
assert write_back_state.in_degree(write_back_node) == 0

# Get the location and the state where the temporary is originally defined.
def_location_of_intermediate, state_to_inspect = target_location
assert state_to_inspect.out_degree(def_location_of_intermediate) == 0

# These are all access nodes that refer to the global data inside the
# state `state_to_inspect`, i.e. the state into which we want to move the
# write back. We need them for the second test.
accesses_to_global_data: set[dace_nodes.AccessNode] = set()

# In the first check we look for an access node to the global data that
# has an output degree larger than one. For this we ignore all empty
# Memlets, because such Memlets are only used to induce a schedule or
# order on the dataflow graph.
# As a byproduct, we also collect all of these nodes for the second test.
for dnode in state_to_inspect.data_nodes():
if dnode.data != global_data_name:
continue
dnode_degree = sum(
(1 for oedge in state_to_inspect.out_edges(dnode) if not oedge.data.is_empty())
)
if dnode_degree > 1:
return True
# TODO(phimuell): Maybe AccessNodes with zero input degree should be ignored.
accesses_to_global_data.add(dnode)

# There is no reference to the global data, so no need to do more tests.
if len(accesses_to_global_data) == 0:
return False

# For the second test we will explore the dataflow graph, in reverse order,
# starting from the definition of the temporary node. If we find an access
# to the global data we remove it from the `accesses_to_global_data` set.
# If the set has not become empty, then we know that there is some kind of
# branch (or concurrent dataflow) in this state that accesses the global
# data, and we will have read-write conflicts.
# It is, however, important to realize that passing this check does not
# imply that there are no read-write conflicts. We assume here that all
# accesses to the global data made before the write back were constructed
# in a correct way.
to_process: list[dace_nodes.Node] = [def_location_of_intermediate]
seen: set[dace_nodes.Node] = set()
while len(to_process) != 0:
node = to_process.pop()
seen.add(node)

if isinstance(node, dace_nodes.AccessNode):
if node.data == global_data_name:
accesses_to_global_data.discard(node)
if len(accesses_to_global_data) == 0:
return False

# Note that we only explore the incoming edges, so we will not necessarily
# explore the whole graph. However, this is fine, because we will see the
# relevant parts. To see this, assume we also had to check the outgoing
# edges; that would mean there was some branching point, which is a
# serialization point, so the dataflow would have been invalid before.
to_process.extend(
iedge.src for iedge in state_to_inspect.in_edges(node) if iedge.src not in seen
)

assert len(accesses_to_global_data) > 0
return True
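# As a usage sketch: the pass can be run on its own through dace's standard
# pass-pipeline machinery (in GT4Py it normally runs as part of a larger
# simplify pipeline; `sdfg` is assumed to exist).
from dace.transformation.pass_pipeline import Pipeline

Pipeline([DistributedBufferRelocator()]).apply_pass(sdfg, {})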


@dace_properties.make_properties
class GT4PyMoveTaskletIntoMap(dace_transformation.SingleStateTransformation):
