def loop_range(self) -> Optional[Tuple[symbolic.SymbolicType, symbolic.SymbolicType, symbolic.SymbolicType]]:
    """
    Returns the iteration range of this loop, if it is well-formed.

    The guard's outgoing edge whose condition equals this loop's condition is located first; the loop
    is then analyzed with ``find_for_loop`` starting from that edge's destination.

    :return: A tuple of (start, end, stride) for well-formed loops, or None if the condition edge is not
             found or the loop is not detected as a for-loop.
    """
    from dace.transformation.interstate.loop_detection import find_for_loop  # Local import, as in the original

    graph = self.guard.parent

    # First outgoing edge of the guard that carries the loop condition (None if there is no such edge)
    condition_edge = next((edge for edge in graph.out_edges(self.guard) if edge.data.condition == self.condition),
                          None)
    if condition_edge is None:
        return None

    detected = find_for_loop(graph, self.guard, condition_edge.dst, self.itervar)
    return None if detected is None else detected[1]
def create_unified_descriptor_repository(sdfg: SDFG, stree: tn.ScheduleTreeRoot):
    """
    Creates a single descriptor repository from an SDFG and all nested SDFGs. This includes
    data containers, symbols, constants, etc.

    The repositories stored on the tree are shallow copies: the dictionaries of ``sdfg`` itself are
    not modified, but the contained descriptors are shared with the SDFG.

    :param sdfg: The top-level SDFG to create the repository from.
    :param stree: The tree root in which to make the unified descriptor repository.
    """
    # Copy before merging. Binding the SDFG's own dictionaries and updating them in place would
    # leak nested-SDFG transients, symbols, and constants into the top-level SDFG.
    # ``copy.copy`` (rather than ``dict(...)``) keeps the concrete mapping type of the source.
    stree.containers = copy.copy(sdfg.arrays)
    stree.symbols = copy.copy(sdfg.symbols)
    stree.constants = copy.copy(sdfg.constants_prop)

    # Since the SDFG is assumed to be de-aliased and contain unique names, we union the contents of
    # the nested SDFGs' descriptor repositories
    for nsdfg in sdfg.all_sdfgs_recursive():
        # Only transients are taken from nested SDFGs; an existing entry with the same name is overwritten
        stree.containers.update({name: desc for name, desc in nsdfg.arrays.items() if desc.transient})

        # Symbols and constants that are already registered take precedence
        for name, stype in nsdfg.symbols.items():
            if name not in stree.symbols:
                stree.symbols[name] = stype
        for name, constant in nsdfg.constants_prop.items():
            if name not in stree.constants:
                stree.constants[name] = constant
@@ -648,7 +671,6 @@ def as_schedule_tree(sdfg: SDFG, in_place: bool = False, toplevel: bool = True) dealias_sdfg(sdfg) # Handle name collisions (in arrays, state labels, symbols) remove_name_collisions(sdfg) - ############################# # Create initial tree from CFG @@ -754,7 +776,18 @@ def totree(node: cf.ControlFlow, parent: cf.GeneralBlock = None) -> List[tn.Sche return result # Recursive traversal of the control flow tree - result = tn.ScheduleTreeScope(children=totree(cfg)) + children = totree(cfg) + + # Create the scope object + if toplevel: + # Create the root with the elements of the descriptor repository + result = tn.ScheduleTreeRoot(name=sdfg.name, + children=children, + arg_names=sdfg.arg_names, + callback_mapping=sdfg.callback_mapping) + create_unified_descriptor_repository(sdfg, result) + else: + result = tn.ScheduleTreeScope(children=children) # Clean up tree stpasses.remove_unused_and_duplicate_labels(result) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py new file mode 100644 index 0000000000..9a7c181209 --- /dev/null +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -0,0 +1,175 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. 
def from_schedule_tree(stree: tn.ScheduleTreeRoot,
                       state_boundary_behavior: StateBoundaryBehavior = StateBoundaryBehavior.STATE_TRANSITION) -> SDFG:
    """
    Converts a schedule tree into an SDFG.

    The descriptor repository of the tree root (argument names, data containers, constants, symbols, and
    callback mapping) is deep-copied into the resulting SDFG. Note that state boundary nodes are inserted
    into the given tree by ``insert_state_boundaries_to_tree``, which operates in-place.

    :param stree: The schedule tree root to convert.
    :param state_boundary_behavior: Sets the behavior upon encountering a state boundary (e.g., write-after-write).
                                    See the ``StateBoundaryBehavior`` enumeration for more details.
    :return: An SDFG representing the schedule tree.
    """
    # Set SDFG descriptor repository
    result = SDFG(stree.name, propagate=False)
    result.arg_names = copy.deepcopy(stree.arg_names)
    result._arrays = copy.deepcopy(stree.containers)
    result.constants_prop = copy.deepcopy(stree.constants)
    result.symbols = copy.deepcopy(stree.symbols)
    # The root also records the callback mapping of the original SDFG; restore it so that
    # converting to a tree and back does not drop it
    result.callback_mapping = copy.deepcopy(stree.callback_mapping)

    # TODO: Fill SDFG contents
    stree = insert_state_boundaries_to_tree(stree)  # after WAW, before label, etc.

    # TODO: create_state_boundary
    # TODO: When creating a state boundary, include all inter-state assignments that precede it.
    # TODO: create_loop_block
    # TODO: create_conditional_block
    # TODO: create_dataflow_scope

    return result
def _insert_memory_dependency_state_boundaries(scope: tn.ScheduleTreeScope):
    """
    Helper function that inserts boundaries after unmet memory dependencies.

    Walks over the children of ``scope`` in order while tracking, per memlet, which nodes read and wrote it
    since the last state boundary, as well as the (transitive) writers each node depends on. A
    ``StateBoundaryNode`` is inserted before a node if:

      * it is an inter-state assignment that reads data written since the last boundary;
      * it writes data that was written since the last boundary without being read in between;
      * it writes data that was read since the last boundary by a node it does not depend on (the read
        and the write might otherwise be performed in parallel).

    Control flow scopes are processed recursively and reset the tracked state; dataflow scopes are processed
    recursively and additionally treated as a node of this scope. Operates in-place on ``scope.children``.

    :param scope: The schedule tree scope to operate on.
    """
    reads: mmu.MemletDict[List[tn.ScheduleTreeNode]] = mmu.MemletDict()
    writes: mmu.MemletDict[List[tn.ScheduleTreeNode]] = mmu.MemletDict()
    parents: Dict[int, Set[int]] = defaultdict(set)  # id(node) -> ids of writers it transitively depends on
    boundaries_to_insert: List[int] = []

    def _reset():
        # Forget all accesses and dependencies, i.e., start tracking a new state
        reads.clear()
        writes.clear()
        parents.clear()

    def _register(table, memlets, node):
        # Record ``node`` as an accessor of every given memlet
        for memlet in memlets:
            if memlet not in table:
                table[memlet] = [node]
            else:
                table[memlet].append(node)

    for i, n in enumerate(scope.children):
        if isinstance(n, (tn.StateBoundaryNode, tn.ControlFlowScope)):  # Clear state
            _reset()
            if isinstance(n, tn.ControlFlowScope):  # Insert memory boundaries recursively
                _insert_memory_dependency_state_boundaries(n)
            continue

        # If dataflow scope, insert state boundaries recursively and as a node
        if isinstance(n, tn.DataflowScope):
            _insert_memory_dependency_state_boundaries(n)

        inputs = n.input_memlets()
        outputs = n.output_memlets()

        # Register reads
        _register(reads, inputs, n)

        # Transitively add parents: every writer of an input, and everything that writer depends on
        ancestors = parents[id(n)]
        for inp in inputs:
            if inp in writes:
                for writer in writes[inp]:
                    ancestors.add(id(writer))
                    ancestors.update(parents[id(writer)])

        def _has_unordered_reader(memlet) -> bool:
            # True if some other node read ``memlet`` and the current node does not depend on it.
            # The check is made against the parents of *this* node; the node itself does not count,
            # since a node that reads and writes the same data is ordered with itself.
            return any(r is not n and id(r) not in ancestors for r in reads[memlet])

        # Inter-state assignment nodes with reads necessitate a state transition if they were written to.
        hazard = isinstance(n, tn.AssignNode) and any(inp in writes for inp in inputs)

        # Write after write (no read in between), or potential write/write or read/write data race:
        # if any read is not in the parents of this node, it might be performed in parallel
        if not hazard:
            hazard = any((o in writes and o not in reads) or (o in reads and _has_unordered_reader(o))
                         for o in outputs)

        if hazard:
            boundaries_to_insert.append(i)
            _reset()
            # The current node opens the new state, so its reads must be tracked there as well.
            # Otherwise, hazards between this node and the following ones would go unnoticed.
            _register(reads, inputs, n)

        # Register writes after all hazards have been tested for
        _register(writes, outputs, n)

    # Insert memory dependency state boundaries in reverse in order to keep indices intact
    for i in reversed(boundaries_to_insert):
        scope.children.insert(i, tn.StateBoundaryNode())
+ scope: tn.ControlFlowScope = bnode.parent + assert scope is not None + pass diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index dabd436b56..eec66b0524 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -1,13 +1,16 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +import ast from dataclasses import dataclass, field from dace import nodes, data, subsets from dace.codegen import control_flow as cf from dace.properties import CodeBlock -from dace.sdfg import InterstateEdge +from dace.sdfg.memlet_utils import MemletSet +from dace.sdfg.propagation import propagate_subset +from dace.sdfg.sdfg import InterstateEdge, SDFG, memlets_in_ast from dace.sdfg.state import SDFGState from dace.symbolic import symbol from dace.memlet import Memlet -from typing import Dict, Iterator, List, Optional, Set, Union +from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Union INDENTATION = ' ' @@ -29,12 +32,35 @@ def preorder_traversal(self) -> Iterator['ScheduleTreeNode']: """ yield self + def get_root(self) -> 'ScheduleTreeRoot': + if self.parent is None: + raise ValueError('Non-root schedule tree node has no parent') + return self.parent.get_root() + + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + """ + Returns a set of inputs for this node. For scopes, returns the union of its contents. + + :param root: An optional argument specifying the schedule tree's root. If not given, + the value is computed from the current tree node. + :return: A set of memlets representing the inputs of this node. + """ + raise NotImplementedError + + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + """ + Returns a set of outputs for this node. 
def _gather_memlets_in_scope(self, inputs: bool, root: Optional['ScheduleTreeRoot'], keep_locals: bool,
                             propagate: Dict[str, subsets.Range], disallow_propagation: Set[str],
                             **kwargs) -> MemletSet:
    """
    Collects the input or output memlets of all children of this scope.

    :param inputs: If True, gathers input memlets; otherwise gathers output memlets.
    :param root: The schedule tree's root. If None and needed, it is computed from the current tree node.
    :param keep_locals: If True, returns the union of the children's memlets as-is.
    :param propagate: A dictionary mapping symbols to their ranges outside of this scope. If non-empty,
                      every gathered memlet is propagated over these ranges with ``propagate_subset``.
    :param disallow_propagation: Additional symbol names to treat as locals of this scope.
    :param kwargs: Extra arguments forwarded to the children's ``input_memlets`` / ``output_memlets``.
    :return: A set of memlets representing the inputs (or outputs) of this scope. If ``keep_locals`` is
             False and no ranges are given, memlets whose subsets use symbols defined in this scope
             are widened to the entire container, and all other memlets are kept unchanged.
    """

    def gather(node: 'ScheduleTreeNode', tree_root: Optional['ScheduleTreeRoot']) -> MemletSet:
        if inputs:
            return node.input_memlets(tree_root, **kwargs)
        return node.output_memlets(tree_root, **kwargs)

    result = MemletSet()

    # Fast path, no propagation necessary
    if keep_locals:
        for child in self.children:
            result.update(gather(child, root))
        return result

    root = root if root is not None else self.get_root()

    if propagate:
        propagate_keys = list(propagate.keys())
        propagate_values = subsets.Range(list(propagate.values()))

    current_locals = set(disallow_propagation)

    # Loop over children in order, if any new symbol is defined within this scope (e.g., symbol assignment,
    # dynamic map range), consider it as a new local
    for child in self.children:
        # Add new locals
        if isinstance(child, AssignNode):
            current_locals.add(child.name)
        elif isinstance(child, DynScopeCopyNode):
            current_locals.add(child.target)

        for memlet in gather(child, root):
            if propagate:
                result.add(
                    propagate_subset([memlet],
                                     root.containers[memlet.data],
                                     propagate_keys,
                                     propagate_values,
                                     undefined_variables=current_locals,
                                     use_dst=not inputs))
            elif (memlet.subset is not None
                  and any(str(sym) in current_locals for sym in memlet.subset.free_symbols)):
                # The subset depends on a symbol that is undefined outside of this scope and there is
                # nothing to propagate over: assume the entire container is used
                result.add(Memlet.from_array(memlet.data, root.containers[memlet.data], memlet.wcr))
            else:
                # Nothing to propagate and no locals involved: the memlet is valid outside the scope
                result.add(memlet)

    return result
+ """ + return self._gather_memlets_in_scope(True, root, keep_locals, propagate or {}, disallow_propagation or set(), + **kwargs) + + def output_memlets(self, + root: Optional['ScheduleTreeRoot'] = None, + keep_locals: bool = False, + propagate: Optional[Dict[str, subsets.Range]] = None, + disallow_propagation: Optional[Set[str]] = None, + **kwargs) -> MemletSet: + """ + Returns a union of the set of outputs for this scope. Propagates the memlets used in the scope if + ``keep_locals`` is set to False. + + :param root: An optional argument specifying the schedule tree's root. If not given, + the value is computed from the current tree node. + :param keep_locals: If True, keeps the local symbols defined within the scope as part of the resulting memlets. + Otherwise, performs memlet propagation (see ``propagate`` and ``disallow_propagation``) or + assumes the entire container is used. + :param propagate: An optional dictionary mapping symbols to their corresponding ranges outside of this scope. + For example, the range of values a for-loop may take. + If ``keep_locals`` is False, this dictionary will be used to create projection memlets over + the ranges. See :ref:`memprop` in the documentation for more information. + :param disallow_propagation: If ``keep_locals`` is False, this optional set of strings will be considered + as additional locals. + :return: A set of memlets representing the inputs of this scope. + """ + return self._gather_memlets_in_scope(False, root, keep_locals, propagate or {}, disallow_propagation or set(), + **kwargs) + + +@dataclass +class ScheduleTreeRoot(ScheduleTreeScope): + """ + A root of an SDFG schedule tree. This is a schedule tree scope with additional information on + the available descriptors, symbol types, and constants of the tree, aka the descriptor repository. 
+ """ + name: str + containers: Dict[str, data.Data] = field(default_factory=dict) + symbols: Dict[str, symbol] = field(default_factory=dict) + constants: Dict[str, Tuple[data.Data, Any]] = field(default_factory=dict) + callback_mapping: Dict[str, str] = field(default_factory=dict) + arg_names: List[str] = field(default_factory=list) + + def as_sdfg(self) -> SDFG: + from dace.sdfg.analysis.schedule_tree import tree_to_sdfg as t2s # Avoid import loop + return t2s.from_schedule_tree(self) + + def get_root(self) -> 'ScheduleTreeRoot': + return self @dataclass @@ -90,6 +230,12 @@ class StateLabel(ScheduleTreeNode): def as_string(self, indent: int = 0): return indent * INDENTATION + f'label {self.state.name}:' + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + return MemletSet() + + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + return MemletSet() + @dataclass class GotoNode(ScheduleTreeNode): @@ -99,6 +245,12 @@ def as_string(self, indent: int = 0): name = self.target or 'exit' return indent * INDENTATION + f'goto {name}' + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + return MemletSet() + + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + return MemletSet() + @dataclass class AssignNode(ScheduleTreeNode): @@ -112,6 +264,13 @@ class AssignNode(ScheduleTreeNode): def as_string(self, indent: int = 0): return indent * INDENTATION + f'assign {self.name} = {self.value.as_string}' + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + root = root if root is not None else self.get_root() + return MemletSet(self.edge.get_read_memlets(root.containers)) + + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + return MemletSet() + @dataclass class ForScope(ControlFlowScope): @@ -127,6 +286,33 @@ def as_string(self, 
indent: int = 0): f'{node.itervar} = {node.update}:\n') return result + super().as_string(indent) + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + root = root if root is not None else self.get_root() + result = MemletSet() + result.update(memlets_in_ast(ast.parse(self.header.init), root.containers)) + result.update(memlets_in_ast(self.header.condition.code[0], root.containers)) + result.update(memlets_in_ast(ast.parse(self.header.update), root.containers)) + + # If loop range is well-formed, use it in propagation + rng = self.header.loop_range() + if rng is not None: + propagate = {self.header.itervar: rng} + else: + propagate = None + + result.update(super().input_memlets(root, propagate=propagate, **kwargs)) + return result + + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + # If loop range is well-formed, use it in propagation + rng = self.header.loop_range() + if rng is not None: + propagate = {self.header.itervar: rng} + else: + propagate = None + + return super().output_memlets(root, propagate=propagate, **kwargs) + @dataclass class WhileScope(ControlFlowScope): @@ -139,6 +325,13 @@ def as_string(self, indent: int = 0): result = indent * INDENTATION + f'while {self.header.test.as_string}:\n' return result + super().as_string(indent) + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + root = root if root is not None else self.get_root() + result = MemletSet() + result.update(memlets_in_ast(self.header.test.code[0], root.containers)) + result.update(super().input_memlets(root, **kwargs)) + return result + @dataclass class DoWhileScope(ControlFlowScope): @@ -152,6 +345,13 @@ def as_string(self, indent: int = 0): footer = indent * INDENTATION + f'while {self.header.test.as_string}\n' return header + super().as_string(indent) + footer + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + 
root = root if root is not None else self.get_root() + result = MemletSet() + result.update(memlets_in_ast(self.header.test.code[0], root.containers)) + result.update(super().input_memlets(root, **kwargs)) + return result + @dataclass class GeneralLoopScope(ControlFlowScope): @@ -203,6 +403,13 @@ def as_string(self, indent: int = 0): result = indent * INDENTATION + f'if {self.condition.as_string}:\n' return result + super().as_string(indent) + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + root = root if root is not None else self.get_root() + result = MemletSet() + result.update(memlets_in_ast(self.condition.code[0], root.containers)) + result.update(super().input_memlets(root, **kwargs)) + return result + @dataclass class StateIfScope(IfScope): @@ -224,6 +431,12 @@ class BreakNode(ScheduleTreeNode): def as_string(self, indent: int = 0): return indent * INDENTATION + 'break' + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + return MemletSet() + + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + return MemletSet() + @dataclass class ContinueNode(ScheduleTreeNode): @@ -234,6 +447,12 @@ class ContinueNode(ScheduleTreeNode): def as_string(self, indent: int = 0): return indent * INDENTATION + 'continue' + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + return MemletSet() + + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + return MemletSet() + @dataclass class ElifScope(ControlFlowScope): @@ -246,6 +465,13 @@ def as_string(self, indent: int = 0): result = indent * INDENTATION + f'elif {self.condition.as_string}:\n' return result + super().as_string(indent) + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + root = root if root is not None else self.get_root() + result = MemletSet() + 
result.update(memlets_in_ast(self.condition.code[0], root.containers)) + result.update(super().input_memlets(root, **kwargs)) + return result + @dataclass class ElseScope(ControlFlowScope): @@ -269,6 +495,18 @@ def as_string(self, indent: int = 0): result = indent * INDENTATION + f'map {", ".join(self.node.map.params)} in [{rangestr}]:\n' return result + super().as_string(indent) + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + return super().input_memlets(root, + propagate={k: v + for k, v in zip(self.node.map.params, self.node.map.range)}, + **kwargs) + + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + return super().output_memlets(root, + propagate={k: v + for k, v in zip(self.node.map.params, self.node.map.range)}, + **kwargs) + @dataclass class ConsumeScope(DataflowScope): @@ -284,7 +522,7 @@ def as_string(self, indent: int = 0): @dataclass -class PipelineScope(DataflowScope): +class PipelineScope(MapScope): """ Pipeline scope. 
""" @@ -308,12 +546,18 @@ def as_string(self, indent: int = 0): return indent * INDENTATION + f'tasklet({in_memlets})' return indent * INDENTATION + f'{out_memlets} = tasklet({in_memlets})' + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + return MemletSet(self.in_memlets.values()) + + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + return MemletSet(self.out_memlets.values()) + @dataclass class LibraryCall(ScheduleTreeNode): node: nodes.LibraryNode - in_memlets: Union[Dict[str, Memlet], Set[Memlet]] - out_memlets: Union[Dict[str, Memlet], Set[Memlet]] + in_memlets: Union[Dict[str, Memlet], MemletSet] + out_memlets: Union[Dict[str, Memlet], MemletSet] def as_string(self, indent: int = 0): if isinstance(self.in_memlets, set): @@ -330,6 +574,16 @@ def as_string(self, indent: int = 0): if v.owner not in {nodes.Node, nodes.CodeNode, nodes.LibraryNode}) return indent * INDENTATION + f'{out_memlets} = library {libname}[{own_properties}]({in_memlets})' + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + if isinstance(self.in_memlets, set): + return MemletSet(self.in_memlets) + return MemletSet(self.in_memlets.values()) + + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + if isinstance(self.out_memlets, set): + return MemletSet(self.out_memlets) + return MemletSet(self.out_memlets.values()) + @dataclass class CopyNode(ScheduleTreeNode): @@ -348,6 +602,16 @@ def as_string(self, indent: int = 0): return indent * INDENTATION + f'{self.target}{offset} = copy {self.memlet.data}[{self.memlet.subset}]{wcr}' + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + return MemletSet({self.memlet}) + + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + root = root if root is not None else self.get_root() + if 
self.memlet.other_subset is not None: + return MemletSet({Memlet(data=self.target, subset=self.memlet.other_subset, wcr=self.memlet.wcr)}) + + return MemletSet({Memlet.from_array(self.target, root.containers[self.target], self.memlet.wcr)}) + @dataclass class DynScopeCopyNode(ScheduleTreeNode): @@ -360,6 +624,12 @@ class DynScopeCopyNode(ScheduleTreeNode): def as_string(self, indent: int = 0): return indent * INDENTATION + f'{self.target} = dscopy {self.memlet.data}[{self.memlet.subset}]' + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + return MemletSet({self.memlet}) + + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + return MemletSet() + @dataclass class ViewNode(ScheduleTreeNode): @@ -372,6 +642,12 @@ class ViewNode(ScheduleTreeNode): def as_string(self, indent: int = 0): return indent * INDENTATION + f'{self.target} = view {self.memlet} as {self.view_desc.shape}' + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + return MemletSet({self.memlet}) + + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + return MemletSet({Memlet.from_array(self.target, self.view_desc)}) + @dataclass class NView(ViewNode): @@ -398,6 +674,30 @@ def as_string(self, indent: int = 0): return indent * INDENTATION + f'{self.target} = refset from {type(self.src_desc).__name__.lower()}' return indent * INDENTATION + f'{self.target} = refset to {self.memlet}' + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + return MemletSet({self.memlet}) + + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + return MemletSet({Memlet.from_array(self.target, self.ref_desc)}) + + +@dataclass +class StateBoundaryNode(ScheduleTreeNode): + """ + A node that represents a state boundary (e.g., when a write-after-write is encountered). 
This node + is used only during conversion from a schedule tree to an SDFG. + """ + due_to_control_flow: bool = False + + def as_string(self, indent: int = 0): + return indent * INDENTATION + 'state boundary' + + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + return MemletSet() + + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + return MemletSet() + # Classes based on Python's AST NodeVisitor/NodeTransformer for schedule tree nodes class ScheduleNodeVisitor: diff --git a/dace/sdfg/memlet_utils.py b/dace/sdfg/memlet_utils.py index 59a2c178d2..65b34db6f4 100644 --- a/dace/sdfg/memlet_utils.py +++ b/dace/sdfg/memlet_utils.py @@ -1,9 +1,11 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. import ast +from collections import defaultdict +import copy from dace.frontend.python import memlet_parser -from dace import data, Memlet -from typing import Callable, Dict, Optional, Set, Union +from dace import data, Memlet, subsets +from typing import Callable, Dict, Iterable, Optional, Set, TypeVar, Tuple, Union class MemletReplacer(ast.NodeTransformer): @@ -77,3 +79,168 @@ def visit_Subscript(self, node: ast.Subscript): if isinstance(node.value, ast.Name) and node.value.id in self.array_filter: return self._replace(node) return self.generic_visit(node) + + +class MemletSet(Set[Memlet]): + """ + Implements a set of memlets that considers subsets that intersect or are covered by its other memlets. + Set updates and unions also perform unions on the contained memlet subsets. + """ + + def __init__(self, iterable: Optional[Iterable[Memlet]] = None, intersection_is_contained: bool = True): + """ + Initializes a memlet set. + + :param iterable: An optional iterable of memlets to initialize the set with. 
+ :param intersection_is_contained: Whether the check ``m in memlet_set`` should return True if the memlet + only intersects with the contents of the set. If False, only completely + covered subsets would return True. + """ + self.internal_set: Dict[str, Set[Memlet]] = {} + self.intersection_is_contained = intersection_is_contained + if iterable is not None: + self.update(iterable) + + def __iter__(self): + for subset in self.internal_set.values(): + yield from subset + + def update(self, *iterable: Iterable[Memlet]): + """ + Updates set of memlets via union of existing ranges. + """ + if len(iterable) == 0: + return + if len(iterable) > 1: + for i in iterable: + self.update(i) + return + + to_update, = iterable + for elem in to_update: + self.add(elem) + + def add(self, elem: Memlet): + """ + Adds a memlet to the set, potentially performing a union of existing ranges. + """ + if elem.data not in self.internal_set: + self.internal_set[elem.data] = {elem} + return + + # Memlet is in set, either perform a union (if possible) or add to internal set + # TODO(later): Consider other_subset as well + for existing_memlet in self.internal_set[elem.data]: + try: + if existing_memlet.subset.intersects(elem.subset) == True: # Definitely intersects + if existing_memlet.subset.covers(elem.subset): + break # Nothing to do + + # Create a new union memlet + self.internal_set[elem.data].remove(existing_memlet) + new_memlet = copy.deepcopy(existing_memlet) + new_memlet.subset = subsets.union(existing_memlet.subset, elem.subset) + self.internal_set[elem.data].add(new_memlet) + break + except TypeError: # Indeterminate + pass + else: # all intersections were False or indeterminate (may or does not intersect with existing memlets) + self.internal_set[elem.data].add(elem) + + def __contains__(self, elem: Memlet) -> bool: + """ + Returns True if a memlet in this set covers ``elem`` or, with ``intersection_is_contained``, may intersect it. 
+ """ + if elem.data not in self.internal_set: + return False + for existing_memlet in self.internal_set[elem.data]: + if existing_memlet.subset.covers(elem.subset): + return True + if self.intersection_is_contained: + try: + if existing_memlet.subset.intersects(elem.subset) == False: + continue + else: # May intersect or indeterminate + return True + except TypeError: + return True + + return False + + def union(self, *s: Iterable[Memlet]) -> 'MemletSet': + """ + Performs a set-union (with memlet union) over the given sets of memlets. + + :return: New memlet set containing the union of this set and the inputs. + """ + newset = MemletSet(self) + newset.update(*s) # Unpack so that each input set is a separate argument to update() + return newset + + +T = TypeVar('T') + + +class MemletDict(Dict[Memlet, T]): + """ + Implements a dictionary with memlet keys that considers subsets that intersect or are covered by its other memlets. + """ + + def __init__(self, **kwargs): + self.internal_dict: Dict[str, Dict[Memlet, T]] = defaultdict(dict) + if kwargs: + self.update(kwargs) + + def _getkey(self, elem: Memlet) -> Optional[Memlet]: + """ + Returns the corresponding key (exact, covered, intersecting, or indeterminately intersecting memlet) if + exists in the dictionary, or None if it does not. 
+ """ + if elem.data not in self.internal_dict: + return None + for existing_memlet in self.internal_dict[elem.data]: + if existing_memlet.subset.covers(elem.subset): + return existing_memlet + try: + if existing_memlet.subset.intersects(elem.subset) == False: # Definitely does not intersect + continue + except TypeError: + pass + + # May or will intersect + return existing_memlet + + return None + + def _setkey(self, key: Memlet, value: T) -> None: + self.internal_dict[key.data][key] = value + + def clear(self): + self.internal_dict.clear() + + def update(self, mapping: Dict[Memlet, T]): + for k, v in mapping.items(): + ak = self._getkey(k) + if ak is None: + self._setkey(k, v) + else: + self._setkey(ak, v) + + def __contains__(self, elem: Memlet) -> bool: + """ + Returns True iff the memlet or a range superset thereof exists in this dictionary. + """ + return self._getkey(elem) is not None + + def __getitem__(self, key: Memlet) -> T: + actual_key = self._getkey(key) + if actual_key is None: + raise KeyError(key) + return self.internal_dict[key.data][actual_key] + + def __setitem__(self, key: Memlet, value: T) -> None: + actual_key = self._getkey(key) + if actual_key is None: + self._setkey(key, value) + else: + self._setkey(actual_key, value) diff --git a/dace/sdfg/propagation.py b/dace/sdfg/propagation.py index f048389421..cc279479ff 100644 --- a/dace/sdfg/propagation.py +++ b/dace/sdfg/propagation.py @@ -419,6 +419,10 @@ def can_be_applied(self, dim_exprs, variable_context, node_range, orig_edges, di if symbolic.issymbolic(dim): used_symbols.update(dim.free_symbols) + if any(s not in defined_vars for s in (used_symbols - set(self.params))): + # Cannot propagate symbols that are undefined outside scope (e.g., internal symbols) + return False + if (used_symbols & set(self.params) and any(symbolic.pystr_to_symbolic(s) not in defined_vars for s in node_range.free_symbols)): # Cannot propagate symbols that are undefined in the outer range @@ -682,6 +686,7 @@ def 
_annotate_loop_ranges(sdfg, unannotated_cycle_states): return condition_edges + def propagate_states(sdfg, concretize_dynamic_unbounded=False) -> None: """ Annotate the states of an SDFG with the number of executions. @@ -1393,6 +1398,7 @@ def propagate_subset(memlets: List[Memlet], params: List[str], rng: subsets.Subset, defined_variables: Set[symbolic.SymbolicType] = None, + undefined_variables: Set[symbolic.SymbolicType] = None, use_dst: bool = False) -> Memlet: """ Tries to propagate a list of memlets through a range (computes the image of the memlet function applied on an integer set of, e.g., a @@ -1405,8 +1411,12 @@ def propagate_subset(memlets: List[Memlet], range to propagate with. :param defined_variables: A set of symbols defined that will remain the same throughout propagation. If None, assumes - that all symbols outside of `params` have been - defined. + that all symbols outside of ``params``, except + for ``undefined_variables``, have been defined. + :param undefined_variables: A set of symbols that are explicitly considered + as not defined throughout propagation, such as + locals. Their existence will trigger propagating + the entire memlet. :param use_dst: Whether to propagate the memlets' dst subset or use the src instead, depending on propagation direction. :return: Memlet with propagated subset and volume. 
@@ -1420,6 +1430,11 @@ def propagate_subset(memlets: List[Memlet], defined_variables |= memlet.free_symbols defined_variables -= set(params) defined_variables = set(symbolic.pystr_to_symbolic(p) for p in defined_variables) + else: + defined_variables = set(defined_variables) + + if undefined_variables: + defined_variables = defined_variables - set(symbolic.pystr_to_symbolic(p) for p in undefined_variables) # Propagate subset variable_context = [defined_variables, [symbolic.pystr_to_symbolic(p) for p in params]] @@ -1450,18 +1465,25 @@ def propagate_subset(memlets: List[Memlet], tmp_subset = pattern.propagate(arr, [subset], rng) break else: - # No patterns found. Emit a warning and propagate the entire - # array whenever symbols are used - warnings.warn('Cannot find appropriate memlet pattern to ' - 'propagate %s through %s' % (str(subset), str(rng))) + # No patterns found. Propagate the entire array whenever symbols are used entire_array = subsets.Range.from_array(arr) paramset = set(map(str, params)) # Fill in the entire array only if one of the parameters appears in the - # free symbols list of the subset dimension - tmp_subset = subsets.Range([ - ea if any(set(map(str, _freesyms(sd))) & paramset for sd in s) else s - for s, ea in zip(subset, entire_array) - ]) + # free symbols list of the subset dimension or is undefined outside + tmp_subset_rng = [] + for s, ea in zip(subset, entire_array): + contains_params = False + contains_undefs = False + for sdim in s: + fsyms = _freesyms(sdim) + fsyms_str = set(map(str, fsyms)) + contains_params |= len(fsyms_str & paramset) != 0 + contains_undefs |= len(fsyms - defined_variables) != 0 + if contains_params or contains_undefs: + tmp_subset_rng.append(ea) + else: + tmp_subset_rng.append(s) + tmp_subset = subsets.Range(tmp_subset_rng) # Union edges as necessary if new_subset is None: diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index 927f033584..5db88370d0 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -38,7 
+38,7 @@ from dace.codegen.instrumentation.report import InstrumentationReport from dace.codegen.instrumentation.data.data_report import InstrumentedDataReport from dace.codegen.compiled_sdfg import CompiledSDFG - from dace.sdfg.analysis.schedule_tree.treenodes import ScheduleTreeScope + from dace.sdfg.analysis.schedule_tree.treenodes import ScheduleTreeRoot class NestedDict(dict): @@ -1079,7 +1079,7 @@ def call_with_instrumented_data(self, dreport: 'InstrumentedDataReport', *args, ########################################## - def as_schedule_tree(self, in_place: bool = False) -> 'ScheduleTreeScope': + def as_schedule_tree(self, in_place: bool = False) -> 'ScheduleTreeRoot': """ Creates a schedule tree from this SDFG and all nested SDFGs. The schedule tree is a tree of nodes that represent the execution order of the SDFG. @@ -1087,7 +1087,8 @@ def as_schedule_tree(self, in_place: bool = False) -> 'ScheduleTreeScope': etc.) or a ``ScheduleTreeScope`` block (map, for-loop, pipeline, etc.) that contains other nodes. It can be used to generate code from an SDFG, or to perform schedule transformations on the SDFG. For example, - erasing an empty if branch, or merging two consecutive for-loops. + erasing an empty if branch, or merging two consecutive for-loops. The SDFG can then be reconstructed via the + ``as_sdfg`` method or the ``from_schedule_tree`` function in ``dace.sdfg.analysis.schedule_tree.tree_to_sdfg``. :param in_place: If True, the SDFG is modified in-place. Otherwise, a copy is made. Note that the SDFG might not be usable after the conversion if ``in_place`` is True! diff --git a/tests/schedule_tree/propagation_test.py b/tests/schedule_tree/propagation_test.py new file mode 100644 index 0000000000..507a3d7226 --- /dev/null +++ b/tests/schedule_tree/propagation_test.py @@ -0,0 +1,105 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. 
+""" +Tests schedule tree input/output memlet computation +""" +import dace +from dace.sdfg import nodes +from dace.sdfg.analysis.schedule_tree import tree_to_sdfg as t2s, treenodes as tn +from dace.properties import CodeBlock +import numpy as np + + +def test_stree_propagation_forloop(): + N = dace.symbol('N') + + @dace.program + def tester(a: dace.float64[20]): + for i in range(1, N): + a[i] = 2 + a[1] = 1 + + stree = tester.to_sdfg().as_schedule_tree() + stree = t2s.insert_state_boundaries_to_tree(stree) + + node_types = [n for n in stree.preorder_traversal()] + assert isinstance(node_types[2], tn.ForScope) + memlet = dace.Memlet('a[1:N]') + memlet._is_data_src = False + assert list(node_types[2].output_memlets()) == [memlet] + + +def test_stree_propagation_symassign(): + # Manually create a schedule tree + N = dace.symbol('N') + stree = tn.ScheduleTreeRoot( + name='tester', + containers={ + 'A': dace.data.Array(dace.float64, [20]), + }, + symbols={ + 'N': N, + }, + children=[ + tn.MapScope(node=dace.nodes.MapEntry(dace.nodes.Map('map', ['i'], dace.subsets.Range([(1, N - 1, 1)]))), + children=[ + tn.AssignNode('j', CodeBlock('N + i'), dace.InterstateEdge(assignments=dict(j='N + i'))), + tn.TaskletNode(nodes.Tasklet('inner', {}, {'out'}, 'out = inp + 2'), + {'inp': dace.Memlet('A[j]')}, {'out': dace.Memlet('A[j]')}), + ]), + ], + ) + stree.children[0].parent = stree + for c in stree.children[0].children: + c.parent = stree.children[0] + + assert list(stree.children[0].input_memlets()) == [dace.Memlet('A[0:20]', volume=N - 1)] + + +def test_stree_propagation_dynset(): + H = dace.symbol('H') + nnz = dace.symbol('nnz') + W = dace.symbol('W') + + @dace.program + def spmv(A_row: dace.uint32[H + 1], A_col: dace.uint32[nnz], A_val: dace.float32[nnz], x: dace.float32[W]): + b = np.zeros([H], dtype=np.float32) + + for i in dace.map[0:H]: + for j in dace.map[A_row[i]:A_row[i + 1]]: + b[i] += A_val[j] * x[A_col[j]] + + return b + + sdfg = spmv.to_sdfg() + stree = 
sdfg.as_schedule_tree() + assert len(stree.children) == 2 + assert all(isinstance(c, tn.MapScope) for c in stree.children) + mapscope = stree.children[1] + _, _, dynrangemap = mapscope.children + assert isinstance(dynrangemap, tn.MapScope) + + # Check dynamic range map memlets + internal_memlets = list(dynrangemap.input_memlets()) + internal_memlet_data = [m.data for m in internal_memlets] + assert 'x' in internal_memlet_data + assert 'A_val' in internal_memlet_data + assert 'A_row' not in internal_memlet_data + for m in internal_memlets: + if m.data == 'A_val': + assert m.subset != dace.subsets.Range([(0, nnz - 1, 1)]) # Not propagated + + # Check top-level scope memlets + external_memlets = list(mapscope.input_memlets()) + assert dace.Memlet('A_row[0:H]') in external_memlets + assert dace.Memlet('A_row[1:H+1]') in external_memlets + assert dace.Memlet('x[0:W]', volume=0, dynamic=True) in external_memlets + assert dace.Memlet('A_val[0:nnz]', volume=0, dynamic=True) in external_memlets + for m in external_memlets: + if m.data == 'A_val': + assert m.subset == dace.subsets.Range([(0, nnz - 1, 1)]) # Propagated + + +if __name__ == '__main__': + test_stree_propagation_forloop() + test_stree_propagation_symassign() + test_stree_propagation_dynset() diff --git a/tests/schedule_tree/roundtrip_test.py b/tests/schedule_tree/roundtrip_test.py new file mode 100644 index 0000000000..e4aea2a56a --- /dev/null +++ b/tests/schedule_tree/roundtrip_test.py @@ -0,0 +1,46 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +""" +Tests conversion of schedule trees to SDFGs. +""" +import dace +import numpy as np + + +def test_implicit_inline_and_constants(): + """ + Tests implicit inlining upon roundtrip conversion, as well as constants with conflicting names. 
+ """ + + @dace + def nester(A: dace.float64[20]): + A[:] = 12 + + @dace.program + def tester(A: dace.float64[20, 20]): + for i in dace.map[0:20]: + nester(A[:, i]) + + sdfg = tester.to_sdfg(simplify=False) + + # Inject constant into nested SDFG + assert len(list(sdfg.all_sdfgs_recursive())) > 1 + sdfg.add_constant('cst', 13) # Add an unused constant + sdfg.sdfg_list[-1].add_constant('cst', 1, dace.data.Scalar(dace.float64)) + tasklet = next(n for n, _ in sdfg.all_nodes_recursive() if isinstance(n, dace.nodes.Tasklet)) + tasklet.code.as_string = tasklet.code.as_string.replace('12', 'cst') + + # Perform a roundtrip conversion + stree = sdfg.as_schedule_tree() + new_sdfg = stree.as_sdfg() + + assert len(list(new_sdfg.all_sdfgs_recursive())) == 1 + assert new_sdfg.constants['cst_0'].dtype == np.float64 + + # Test SDFG + a = np.random.rand(20, 20) + new_sdfg(a) # Tests arg_names + assert np.allclose(a, 1) + + +if __name__ == '__main__': + test_implicit_inline_and_constants() diff --git a/tests/schedule_tree/to_sdfg_test.py b/tests/schedule_tree/to_sdfg_test.py new file mode 100644 index 0000000000..5422f94472 --- /dev/null +++ b/tests/schedule_tree/to_sdfg_test.py @@ -0,0 +1,222 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +""" +Tests components in conversion of schedule trees to SDFGs. 
+""" +import dace +from dace.codegen import control_flow as cf +from dace.properties import CodeBlock +from dace.sdfg import nodes +from dace.sdfg.analysis.schedule_tree import tree_to_sdfg as t2s, treenodes as tn +import pytest + + +def test_state_boundaries_none(): + # Manually create a schedule tree + stree = tn.ScheduleTreeRoot( + name='tester', + containers={ + 'A': dace.data.Array(dace.float64, [20]), + }, + children=[ + tn.TaskletNode(nodes.Tasklet('bla', {}, {'out'}, 'out = 1'), {}, {'out': dace.Memlet('A[1]')}), + tn.TaskletNode(nodes.Tasklet('bla2', {'inp'}, {'out'}, 'out = inp + 1'), {'inp': dace.Memlet('A[1]')}, + {'out': dace.Memlet('A[1]')}), + ], + ) + + stree = t2s.insert_state_boundaries_to_tree(stree) + assert tn.StateBoundaryNode not in [type(n) for n in stree.children] + + +def test_state_boundaries_waw(): + # Manually create a schedule tree + stree = tn.ScheduleTreeRoot( + name='tester', + containers={ + 'A': dace.data.Array(dace.float64, [20]), + }, + children=[ + tn.TaskletNode(nodes.Tasklet('bla', {}, {'out'}, 'out = 1'), {}, {'out': dace.Memlet('A[1]')}), + tn.TaskletNode(nodes.Tasklet('bla2', {}, {'out'}, 'out = 2'), {}, {'out': dace.Memlet('A[1]')}), + ], + ) + + stree = t2s.insert_state_boundaries_to_tree(stree) + assert [tn.TaskletNode, tn.StateBoundaryNode, tn.TaskletNode] == [type(n) for n in stree.children] + + +@pytest.mark.parametrize('overlap', (False, True)) +def test_state_boundaries_waw_ranges(overlap): + # Manually create a schedule tree + N = dace.symbol('N') + stree = tn.ScheduleTreeRoot( + name='tester', + containers={ + 'A': dace.data.Array(dace.float64, [20]), + }, + symbols={'N': N}, + children=[ + tn.TaskletNode(nodes.Tasklet('bla', {}, {'out'}, 'pass'), {}, {'out': dace.Memlet('A[0:N/2]')}), + tn.TaskletNode(nodes.Tasklet('bla2', {}, {'out'}, 'pass'), {}, + {'out': dace.Memlet('A[1:N]' if overlap else 'A[N/2+1:N]')}), + ], + ) + + stree = t2s.insert_state_boundaries_to_tree(stree) + if overlap: + assert 
[tn.TaskletNode, tn.StateBoundaryNode, tn.TaskletNode] == [type(n) for n in stree.children] + else: + assert [tn.TaskletNode, tn.TaskletNode] == [type(n) for n in stree.children] + + +def test_state_boundaries_war(): + # Manually create a schedule tree + stree = tn.ScheduleTreeRoot( + name='tester', + containers={ + 'A': dace.data.Array(dace.float64, [20]), + 'B': dace.data.Array(dace.float64, [20]), + }, + children=[ + tn.TaskletNode(nodes.Tasklet('bla', {'inp'}, {'out'}, 'out = inp + 1'), {'inp': dace.Memlet('A[1]')}, + {'out': dace.Memlet('B[0]')}), + tn.TaskletNode(nodes.Tasklet('bla2', {}, {'out'}, 'out = 2'), {}, {'out': dace.Memlet('A[1]')}), + ], + ) + + stree = t2s.insert_state_boundaries_to_tree(stree) + assert [tn.TaskletNode, tn.StateBoundaryNode, tn.TaskletNode] == [type(n) for n in stree.children] + + +def test_state_boundaries_read_write_chain(): + # Manually create a schedule tree + stree = tn.ScheduleTreeRoot( + name='tester', + containers={ + 'A': dace.data.Array(dace.float64, [20]), + 'B': dace.data.Array(dace.float64, [20]), + }, + children=[ + tn.TaskletNode(nodes.Tasklet('bla1', {'inp'}, {'out'}, 'out = inp + 1'), {'inp': dace.Memlet('A[1]')}, + {'out': dace.Memlet('B[0]')}), + tn.TaskletNode(nodes.Tasklet('bla2', {'inp'}, {'out'}, 'out = inp + 1'), {'inp': dace.Memlet('B[0]')}, + {'out': dace.Memlet('A[1]')}), + tn.TaskletNode(nodes.Tasklet('bla3', {'inp'}, {'out'}, 'out = inp + 1'), {'inp': dace.Memlet('A[1]')}, + {'out': dace.Memlet('B[0]')}), + ], + ) + + stree = t2s.insert_state_boundaries_to_tree(stree) + assert [tn.TaskletNode, tn.TaskletNode, tn.TaskletNode] == [type(n) for n in stree.children] + + +def test_state_boundaries_data_race(): + # Manually create a schedule tree + stree = tn.ScheduleTreeRoot( + name='tester', + containers={ + 'A': dace.data.Array(dace.float64, [20]), + 'B': dace.data.Array(dace.float64, [20]), + }, + children=[ + tn.TaskletNode(nodes.Tasklet('bla1', {'inp'}, {'out'}, 'out = inp + 1'), {'inp': 
dace.Memlet('A[1]')}, + {'out': dace.Memlet('B[0]')}), + tn.TaskletNode(nodes.Tasklet('bla11', {'inp'}, {'out'}, 'out = inp + 1'), {'inp': dace.Memlet('A[1]')}, + {'out': dace.Memlet('B[1]')}), + tn.TaskletNode(nodes.Tasklet('bla2', {'inp'}, {'out'}, 'out = inp + 1'), {'inp': dace.Memlet('B[0]')}, + {'out': dace.Memlet('A[1]')}), + tn.TaskletNode(nodes.Tasklet('bla3', {'inp'}, {'out'}, 'out = inp + 1'), {'inp': dace.Memlet('A[1]')}, + {'out': dace.Memlet('B[0]')}), + ], + ) + + stree = t2s.insert_state_boundaries_to_tree(stree) + assert [tn.TaskletNode, tn.TaskletNode, tn.StateBoundaryNode, tn.TaskletNode, + tn.TaskletNode] == [type(n) for n in stree.children] + + +def test_state_boundaries_cfg(): + # Manually create a schedule tree + stree = tn.ScheduleTreeRoot( + name='tester', + containers={ + 'A': dace.data.Array(dace.float64, [20]), + }, + children=[ + tn.TaskletNode(nodes.Tasklet('bla1', {}, {'out'}, 'out = 2'), {}, {'out': dace.Memlet('A[1]')}), + tn.ForScope([ + tn.TaskletNode(nodes.Tasklet('bla2', {}, {'out'}, 'out = i'), {}, {'out': dace.Memlet('A[1]')}), + ], cf.ForScope(None, None, True, 'i', None, '0', CodeBlock('i < 20'), 'i + 1', None, [])), + ], + ) + + stree = t2s.insert_state_boundaries_to_tree(stree) + assert [tn.TaskletNode, tn.StateBoundaryNode, tn.ForScope] == [type(n) for n in stree.children] + + +def test_state_boundaries_state_transition(): + # Manually create a schedule tree + stree = tn.ScheduleTreeRoot( + name='tester', + containers={ + 'A': dace.data.Array(dace.float64, [20]), + }, + symbols={ + 'N': dace.symbol('N'), + }, + children=[ + tn.AssignNode('irrelevant', CodeBlock('N + 1'), dace.InterstateEdge(assignments=dict(irrelevant='N + 1'))), + tn.TaskletNode(nodes.Tasklet('bla', {}, {'out'}, 'out = 2'), {}, {'out': dace.Memlet('A[1]')}), + tn.AssignNode('relevant', CodeBlock('A[1] + 2'), + dace.InterstateEdge(assignments=dict(relevant='A[1] + 2'))), + ], + ) + + stree = t2s.insert_state_boundaries_to_tree(stree) + assert 
[tn.AssignNode, tn.TaskletNode, tn.StateBoundaryNode, tn.AssignNode] == [type(n) for n in stree.children] + + +@pytest.mark.parametrize('boundary', (False, True)) +def test_state_boundaries_propagation(boundary): + # Manually create a schedule tree + N = dace.symbol('N') + stree = tn.ScheduleTreeRoot( + name='tester', + containers={ + 'A': dace.data.Array(dace.float64, [20]), + }, + symbols={ + 'N': N, + }, + children=[ + tn.MapScope(node=dace.nodes.MapEntry(dace.nodes.Map('map', ['i'], dace.subsets.Range([(1, N - 1, 1)]))), + children=[ + tn.TaskletNode(nodes.Tasklet('inner', {}, {'out'}, 'out = 2'), {}, + {'out': dace.Memlet('A[i]')}), + ]), + tn.TaskletNode(nodes.Tasklet('bla', {}, {'out'}, 'out = 2'), {}, + {'out': dace.Memlet('A[1]' if boundary else 'A[0]')}), + ], + ) + + stree = t2s.insert_state_boundaries_to_tree(stree) + + node_types = [type(n) for n in stree.preorder_traversal()] + if boundary: + assert [tn.MapScope, tn.TaskletNode, tn.StateBoundaryNode, tn.TaskletNode] == node_types[1:] + else: + assert [tn.MapScope, tn.TaskletNode, tn.TaskletNode] == node_types[1:] + + +if __name__ == '__main__': + test_state_boundaries_none() + test_state_boundaries_waw() + test_state_boundaries_waw_ranges(overlap=False) + test_state_boundaries_waw_ranges(overlap=True) + test_state_boundaries_war() + test_state_boundaries_read_write_chain() + test_state_boundaries_data_race() + test_state_boundaries_cfg() + test_state_boundaries_state_transition() + test_state_boundaries_propagation(boundary=False) + test_state_boundaries_propagation(boundary=True)