Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes for creating nested SDFGs with structs and inlining them #1888

Merged
merged 1 commit into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions dace/frontend/python/newast.py
Original file line number Diff line number Diff line change
Expand Up @@ -3735,18 +3735,20 @@ def _is_inputnode(self, sdfg: SDFG, name: str):
for state in sdfg.states():
visited_state_data = set()
for node in state.nodes():
if isinstance(node, nodes.AccessNode) and node.data == name:
visited_state_data.add(node.data)
if (node.data not in visited_data and state.in_degree(node) == 0):
return True
if isinstance(node, nodes.AccessNode):
if node.data == name or ('.' in node.data and node.data.split('.')[0] == name):
visited_state_data.add(node.data)
if (node.data not in visited_data and state.in_degree(node) == 0):
return True
visited_data = visited_data.union(visited_state_data)

def _is_outputnode(self, sdfg: SDFG, name: str):
for state in sdfg.states():
for node in state.nodes():
if isinstance(node, nodes.AccessNode) and node.data == name:
if state.in_degree(node) > 0:
return True
if isinstance(node, nodes.AccessNode):
if node.data == name or ('.' in node.data and node.data.split('.')[0] == name):
if state.in_degree(node) > 0:
return True

def _get_sdfg(self, value: Any, args: Tuple[Any], kwargs: Dict[str, Any]) -> SDFG:
if isinstance(value, SDFG): # Already an SDFG
Expand Down
50 changes: 40 additions & 10 deletions dace/transformation/interstate/sdfg_nesting.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,8 +380,11 @@ def apply(self, state: SDFGState, sdfg: SDFG):
pass
for node in nstate.sink_nodes():
if (isinstance(node, nodes.AccessNode) and node.data not in transients and node.data not in reshapes):
new_outgoing_edges[node] = outputs[node.data]
sink_accesses.add(node)
try:
new_outgoing_edges[node] = outputs[node.data]
sink_accesses.add(node)
except KeyError:
pass

# All constants (and associated transients) become constants of the parent
for cstname, (csttype, cstval) in nsdfg.constants_prop.items():
Expand Down Expand Up @@ -427,11 +430,26 @@ def apply(self, state: SDFGState, sdfg: SDFG):

orig_data: Dict[Union[nodes.AccessNode, MultiConnectorEdge], str] = {}
for node in nstate.nodes():
if isinstance(node, nodes.AccessNode) and node.data in repldict:
orig_data[node] = node.data
node.data = repldict[node.data]
if isinstance(node, nodes.AccessNode):
if '.' in node.data:
parts = node.data.split('.')
root_container = parts[0]
if root_container in repldict:
orig_data[node] = node.data
full_data = [repldict[root_container]] + parts[1:]
node.data = '.'.join(full_data)
elif node.data in repldict:
orig_data[node] = node.data
node.data = repldict[node.data]
for edge in nstate.edges():
if edge.data.data in repldict:
if edge.data.data is not None and '.' in edge.data.data:
parts = edge.data.data.split('.')
root_container = parts[0]
if root_container in repldict:
orig_data[edge] = edge.data.data
full_data = [repldict[root_container]] + parts[1:]
edge.data.data = '.'.join(full_data)
elif edge.data.data in repldict:
orig_data[edge] = edge.data.data
edge.data.data = repldict[edge.data.data]

Expand Down Expand Up @@ -557,27 +575,39 @@ def apply(self, state: SDFGState, sdfg: SDFG):
for edge in removed_in_edges:
# Find first access node that refers to this edge
try:
node = next(n for n in order if n.data == edge.data.data)
node = next(n for n in order
if n.data == edge.data.data or ('.' in n.data and n.data.split('.')[0] == edge.data.data))
except StopIteration:
continue
# raise NameError(f'Access node with data "{edge.data.data}" not found in'
# f' nested SDFG "{nsdfg.name}" while inlining '
# '(reconnecting inputs)')
state.add_edge(edge.src, edge.src_conn, node, edge.dst_conn, edge.data)
if node.data != edge.data.data:
anode = state.add_access(edge.data.data)
state.add_edge(edge.src, edge.src_conn, anode, edge.dst_conn, edge.data)
state.add_edge(anode, None, node, None, Memlet())
else:
state.add_edge(edge.src, edge.src_conn, node, edge.dst_conn, edge.data)
# Fission state if necessary
cc = utils.weakly_connected_component(state, node)
if not any(n in cc for n in subgraph.nodes()):
helpers.state_fission(cc)
for edge in removed_out_edges:
# Find last access node that refers to this edge
try:
node = next(n for n in reversed(order) if n.data == edge.data.data)
node = next(n for n in reversed(order)
if n.data == edge.data.data or ('.' in n.data and n.data.split('.')[0] == edge.data.data))
except StopIteration:
continue
# raise NameError(f'Access node with data "{edge.data.data}" not found in'
# f' nested SDFG "{nsdfg.name}" while inlining '
# '(reconnecting outputs)')
state.add_edge(node, edge.src_conn, edge.dst, edge.dst_conn, edge.data)
if node.data != edge.data.data:
anode = state.add_access(edge.data.data)
state.add_edge(node, None, anode, None, Memlet())
state.add_edge(anode, edge.src_conn, edge.dst, edge.dst_conn, edge.data)
else:
state.add_edge(node, edge.src_conn, edge.dst, edge.dst_conn, edge.data)
# Fission state if necessary
cc = utils.weakly_connected_component(state, node)
if not any(n in cc for n in subgraph.nodes()):
Expand Down
Loading