diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index c9cfb698b4..ae8a4558cd 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -3744,18 +3744,20 @@ def _is_inputnode(self, sdfg: SDFG, name: str): for state in sdfg.states(): visited_state_data = set() for node in state.nodes(): - if isinstance(node, nodes.AccessNode) and node.data == name: - visited_state_data.add(node.data) - if (node.data not in visited_data and state.in_degree(node) == 0): - return True + if isinstance(node, nodes.AccessNode): + if node.data == name or ('.' in node.data and node.data.split('.')[0] == name): + visited_state_data.add(node.data) + if (node.data not in visited_data and state.in_degree(node) == 0): + return True visited_data = visited_data.union(visited_state_data) def _is_outputnode(self, sdfg: SDFG, name: str): for state in sdfg.states(): for node in state.nodes(): - if isinstance(node, nodes.AccessNode) and node.data == name: - if state.in_degree(node) > 0: - return True + if isinstance(node, nodes.AccessNode): + if node.data == name or ('.' in node.data and node.data.split('.')[0] == name): + if state.in_degree(node) > 0: + return True def _get_sdfg(self, value: Any, args: Tuple[Any], kwargs: Dict[str, Any]) -> SDFG: if isinstance(value, SDFG): # Already an SDFG diff --git a/dace/transformation/interstate/sdfg_nesting.py b/dace/transformation/interstate/sdfg_nesting.py index 31e751bb6a..d62a9f0bf9 100644 --- a/dace/transformation/interstate/sdfg_nesting.py +++ b/dace/transformation/interstate/sdfg_nesting.py @@ -380,8 +380,11 @@ def apply(self, state: SDFGState, sdfg: SDFG): pass for node in nstate.sink_nodes(): if (isinstance(node, nodes.AccessNode) and node.data not in transients and node.data not in reshapes): - new_outgoing_edges[node] = outputs[node.data] - sink_accesses.add(node) + try: + new_outgoing_edges[node] = outputs[node.data] + sink_accesses.add(node) + except KeyError: + pass # All constants (and associated transients) become constants of the parent for cstname, (csttype, cstval) in nsdfg.constants_prop.items(): @@ -427,11 +430,26 @@ def apply(self, state: SDFGState, sdfg: SDFG): orig_data: Dict[Union[nodes.AccessNode, MultiConnectorEdge], str] = {} for node in nstate.nodes(): - if isinstance(node, nodes.AccessNode) and node.data in repldict: - orig_data[node] = node.data - node.data = repldict[node.data] + if isinstance(node, nodes.AccessNode): + if '.' in node.data: + parts = node.data.split('.') + root_container = parts[0] + if root_container in repldict: + orig_data[node] = node.data + full_data = [repldict[root_container]] + parts[1:] + node.data = '.'.join(full_data) + elif node.data in repldict: + orig_data[node] = node.data + node.data = repldict[node.data] for edge in nstate.edges(): - if edge.data.data in repldict: + if edge.data.data is not None and '.' in edge.data.data: + parts = edge.data.data.split('.') + root_container = parts[0] + if root_container in repldict: + orig_data[edge] = edge.data.data + full_data = [repldict[root_container]] + parts[1:] + edge.data.data = '.'.join(full_data) + elif edge.data.data in repldict: orig_data[edge] = edge.data.data edge.data.data = repldict[edge.data.data] @@ -557,13 +575,19 @@ def apply(self, state: SDFGState, sdfg: SDFG): for edge in removed_in_edges: # Find first access node that refers to this edge try: - node = next(n for n in order if n.data == edge.data.data) + node = next(n for n in order + if n.data == edge.data.data or ('.' in n.data and n.data.split('.')[0] == edge.data.data)) except StopIteration: continue # raise NameError(f'Access node with data "{edge.data.data}" not found in' # f' nested SDFG "{nsdfg.name}" while inlining ' # '(reconnecting inputs)') - state.add_edge(edge.src, edge.src_conn, node, edge.dst_conn, edge.data) + if node.data != edge.data.data: + anode = state.add_access(edge.data.data) + state.add_edge(edge.src, edge.src_conn, anode, edge.dst_conn, edge.data) + state.add_edge(anode, None, node, None, Memlet()) + else: + state.add_edge(edge.src, edge.src_conn, node, edge.dst_conn, edge.data) # Fission state if necessary cc = utils.weakly_connected_component(state, node) if not any(n in cc for n in subgraph.nodes()): @@ -571,13 +595,19 @@ def apply(self, state: SDFGState, sdfg: SDFG): for edge in removed_out_edges: # Find last access node that refers to this edge try: - node = next(n for n in reversed(order) if n.data == edge.data.data) + node = next(n for n in reversed(order) + if n.data == edge.data.data or ('.' in n.data and n.data.split('.')[0] == edge.data.data)) except StopIteration: continue # raise NameError(f'Access node with data "{edge.data.data}" not found in' # f' nested SDFG "{nsdfg.name}" while inlining ' # '(reconnecting outputs)') - state.add_edge(node, edge.src_conn, edge.dst, edge.dst_conn, edge.data) + if node.data != edge.data.data: + anode = state.add_access(edge.data.data) + state.add_edge(node, None, anode, None, Memlet()) + state.add_edge(anode, edge.src_conn, edge.dst, edge.dst_conn, edge.data) + else: + state.add_edge(node, edge.src_conn, edge.dst, edge.dst_conn, edge.data) # Fission state if necessary cc = utils.weakly_connected_component(state, node) if not any(n in cc for n in subgraph.nodes()):