diff --git a/src/gt4py/next/program_processors/runners/dace_common/workflow.py b/src/gt4py/next/program_processors/runners/dace_common/workflow.py index 52096730db..f0577ffaf2 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_common/workflow.py @@ -114,6 +114,7 @@ def decorated_program( args = (*args, out) flat_args: Sequence[Any] = gtx_utils.flatten_nested_tuple(tuple(args)) if len(sdfg.arg_names) > len(flat_args): + # The Ahead-of-Time (AOT) workflow for FieldView programs requires domain size arguments. flat_args = (*flat_args, *arguments.iter_size_args(args)) if sdfg_program._lastargs: diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 4f67ec2764..cff0e79dd3 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -264,6 +264,72 @@ def _get_field_layout( return list(domain_dims), list(domain_lbs), domain_sizes +def _create_field_operator_impl( + sdfg: dace.SDFG, + state: dace.SDFGState, + domain: FieldopDomain, + output_edge: gtir_dataflow.DataflowOutputEdge, + output_type: ts.FieldType, + map_exit: dace.nodes.MapExit, +) -> FieldopData: + """ + Helper method to allocate a temporary array that stores one field computed by a field operator. + + This method is called by `_create_field_operator()`. + + Args: + sdfg: The SDFG that represents the scope of the field data. + state: The SDFG state where to create an access node to the field data. + domain: The domain of the field operator that computes the field. + output_edge: The dataflow write edge representing the output data. + output_type: The GT4Py field type descriptor. + map_exit: The `MapExit` node of the field operator map scope. + + Returns: + The field data descriptor, which includes the field access node in the given `state` + and the field domain offset. + """ + dataflow_output_desc = output_edge.result.dc_node.desc(sdfg) + + domain_dims, domain_offset, domain_shape = _get_field_layout(domain) + domain_indices = _get_domain_indices(domain_dims, domain_offset) + domain_subset = dace_subsets.Range.from_indices(domain_indices) + + if isinstance(output_edge.result.gt_dtype, ts.ScalarType): + assert output_edge.result.gt_dtype == output_type.dtype + field_dtype = output_edge.result.gt_dtype + field_dims, field_shape, field_offset = (domain_dims, domain_shape, domain_offset) + assert isinstance(dataflow_output_desc, dace.data.Scalar) + field_subset = domain_subset + else: + assert isinstance(output_type.dtype, ts.ListType) + assert isinstance(output_edge.result.gt_dtype.element_type, ts.ScalarType) + assert output_edge.result.gt_dtype.element_type == output_type.dtype.element_type + field_dtype = output_edge.result.gt_dtype.element_type + assert isinstance(dataflow_output_desc, dace.data.Array) + assert len(dataflow_output_desc.shape) == 1 + # extend the array with the local dimensions added by the field operator (e.g. 
`neighbors`) + assert output_edge.result.gt_dtype.offset_type is not None + field_dims = [*domain_dims, output_edge.result.gt_dtype.offset_type] + field_shape = [*domain_shape, dataflow_output_desc.shape[0]] + field_offset = [*domain_offset, dataflow_output_desc.offset[0]] + field_subset = domain_subset + dace_subsets.Range.from_array(dataflow_output_desc) + + # allocate local temporary storage + assert dataflow_output_desc.dtype == dace_utils.as_dace_type(field_dtype) + field_name, _ = sdfg.add_temp_transient(field_shape, dataflow_output_desc.dtype) + field_node = state.add_access(field_name) + + # and here the edge writing the dataflow result data through the map exit node + output_edge.connect(map_exit, field_node, field_subset) + + return FieldopData( + field_node, + ts.FieldType(field_dims, field_dtype), + offset=(field_offset if set(field_offset) != {0} else None), + ) + + def _create_field_operator( sdfg: dace.SDFG, state: dace.SDFGState, @@ -275,7 +341,10 @@ def _create_field_operator( | tuple[gtir_dataflow.DataflowOutputEdge | tuple[Any, ...], ...], ) -> FieldopResult: """ - Helper method to allocate a temporary field to store the output of a field operator. + Helper method to build the output of a field operator, which can consist of + a single field or, in case of scan, a tuple of fields. + The scan field operator typically computes multiple fields for each K-level: + for each field, this method will call `_create_field_operator_impl()`. Args: sdfg: The SDFG that represents the scope of the field data. @@ -287,55 +356,12 @@ def _create_field_operator( output_edges: Single edge or tuple of edges representing the dataflow output data. Returns: - The field data descriptor, which includes the field access node in the given `state` - and the field domain offset. + The descriptor of the field operator result, which can be either a single field + or a tuple fields. """ - def _create_field_operator_impl( - output_edge: gtir_dataflow.DataflowOutputEdge, mx: dace.nodes.MapExit, sym: gtir.Sym - ) -> FieldopData: - assert isinstance(sym.type, ts.FieldType) - dataflow_output_desc = output_edge.result.dc_node.desc(sdfg) - if isinstance(output_edge.result.gt_dtype, ts.ScalarType): - assert output_edge.result.gt_dtype == sym.type.dtype - field_dtype = output_edge.result.gt_dtype - field_dims, field_shape, field_offset = (domain_dims, domain_shape, domain_offset) - assert isinstance(dataflow_output_desc, dace.data.Scalar) - field_subset = domain_subset - else: - assert isinstance(sym.type.dtype, ts.ListType) - assert isinstance(output_edge.result.gt_dtype.element_type, ts.ScalarType) - assert output_edge.result.gt_dtype.element_type == sym.type.dtype.element_type - field_dtype = output_edge.result.gt_dtype.element_type - assert isinstance(dataflow_output_desc, dace.data.Array) - assert len(dataflow_output_desc.shape) == 1 - # extend the array with the local dimensions added by the field operator (e.g. 
`neighbors`) - assert output_edge.result.gt_dtype.offset_type is not None - field_dims = [*domain_dims, output_edge.result.gt_dtype.offset_type] - field_shape = [*domain_shape, dataflow_output_desc.shape[0]] - field_offset = [*domain_offset, dataflow_output_desc.offset[0]] - field_subset = domain_subset + dace_subsets.Range.from_array(dataflow_output_desc) - - # allocate local temporary storage - assert dataflow_output_desc.dtype == dace_utils.as_dace_type(field_dtype) - field_name, _ = sdfg.add_temp_transient(field_shape, dataflow_output_desc.dtype) - field_node = state.add_access(field_name) - - # and here the edge writing the dataflow result data through the map exit node - output_edge.connect(mx, field_node, field_subset) - - return FieldopData( - field_node, - ts.FieldType(field_dims, field_dtype), - offset=(field_offset if set(field_offset) != {0} else None), - ) - - domain_dims, domain_offset, domain_shape = _get_field_layout(domain) - domain_indices = _get_domain_indices(domain_dims, domain_offset) - domain_subset = dace_subsets.Range.from_indices(domain_indices) - # create map range corresponding to the field operator domain - me, mx = sdfg_builder.add_map( + map_entry, map_exit = sdfg_builder.add_map( "fieldop", state, ndrange={ @@ -346,17 +372,19 @@ def _create_field_operator_impl( # here we setup the edges passing through the map entry node for edge in input_edges: - edge.connect(me) + edge.connect(map_entry) - if isinstance(output_edges, gtir_dataflow.DataflowOutputEdge): - assert isinstance(node_type, ts.FieldType) - return _create_field_operator_impl(output_edges, mx, im.sym("x", node_type)) + if isinstance(node_type, ts.FieldType): + assert isinstance(output_edges, gtir_dataflow.DataflowOutputEdge) + return _create_field_operator_impl(sdfg, state, domain, output_edges, node_type, map_exit) else: # handle tuples of fields - assert isinstance(node_type, ts.TupleType) + tuple_syms = dace_gtir_utils.make_symbol_tuple("x", node_type) return gtx_utils.tree_map( - lambda output_edge, sym: _create_field_operator_impl(output_edge, mx, sym) - )(output_edges, dace_gtir_utils.make_symbol_tuple("x", node_type)) + lambda output_edge, output_sym: _create_field_operator_impl( + sdfg, state, domain, output_edge, output_sym.type, map_exit + ) + )(output_edges, tuple_syms) def extract_domain(node: gtir.Node) -> FieldopDomain: @@ -444,7 +472,7 @@ def translate_as_fieldop( fieldop_args = [_parse_fieldop_arg(arg, sdfg, state, sdfg_builder, domain) for arg in node.args] # represent the field operator as a mapped tasklet graph, which will range over the field domain - input_edges, output_edges = gtir_dataflow.apply( + input_edges, output_edges = gtir_dataflow.translate_lambda_to_dataflow( sdfg, state, sdfg_builder, stencil_expr, fieldop_args ) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index 53afedaee6..1eb898530b 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -141,7 +141,8 @@ def get_memlet_subset(self, sdfg: dace.SDFG) -> dace_subsets.Range: class DataflowInputEdge(Protocol): """ - This protocol represents an open connection into the dataflow. + This protocol describes how to concretize a data edge to read data from a source node + into the dataflow. It provides the `connect` method to setup an input edge from an external data source. 
The most common case is that the dataflow represents a stencil, which is instantied @@ -197,9 +198,11 @@ class EmptyInputEdge(DataflowInputEdge): node: dace.nodes.Tasklet def connect(self, map_entry: Optional[dace.nodes.MapEntry]) -> None: - # the empty edge is not created if the dataflow is instatiated without a map - if map_entry is not None: - self.state.add_nedge(map_entry, self.node, dace.Memlet()) + if map_entry is None: + # outside of a map scope it is possible to instantiate a tasklet node + # without input connectors + return + self.state.add_nedge(map_entry, self.node, dace.Memlet()) @dataclasses.dataclass(frozen=True) @@ -311,7 +314,7 @@ class LambdaToDataflow(eve.NodeVisitor): symbol_map: dict[ str, IteratorExpr | DataExpr | tuple[IteratorExpr | DataExpr | tuple[Any, ...], ...], - ] = dataclasses.field(default_factory=lambda: {}) + ] = dataclasses.field(default_factory=dict) def _add_input_data_edge( self, @@ -579,7 +582,9 @@ def _visit_if_branch( DataflowOutputEdge | tuple[DataflowOutputEdge | tuple[Any, ...], ...], ]: """ - Helper method to visit an if-branch expression and lower it to a dtaflow inside the given nested SDFG and state. + Helper method to visit an if-branch expression and lower it to a dataflow inside the given nested SDFG and state. + + This function is called by `_visit_if()` for each if-branch. Args: if_sdfg: The nested SDFG where the if expression is lowered. @@ -600,22 +605,21 @@ def visit_arg(arg: IteratorExpr | DataExpr) -> IteratorExpr | ValueExpr: arg_node = arg.dc_node arg_desc = arg_node.desc(self.sdfg) if isinstance(arg, MemletExpr): - assert set(arg.subset.size()) == {1} + assert arg.subset.sum_elements() == 1 arg_desc = dace.data.Scalar(arg_desc.dtype) - else: - assert isinstance(arg, IteratorExpr) + elif isinstance(arg, IteratorExpr): arg_node = arg.field arg_desc = arg_node.desc(self.sdfg) arg_expr = MemletExpr( arg_node, arg.gt_dtype, dace_subsets.Range.from_array(arg_desc) ) + else: + raise TypeError(f"Unexpected {arg} as input argument.") arg_data = arg_node.data # SDFG data containers with name prefix '__tmp' are expected to be transients inner_data = ( - arg_data.replace("__tmp", "__input", 1) - if arg_data.startswith("__tmp") - else arg_data + arg_data.replace("__tmp", "__gtir", 1) if arg_data.startswith("__tmp") else arg_data ) try: @@ -634,13 +638,13 @@ def visit_arg(arg: IteratorExpr | DataExpr) -> IteratorExpr | ValueExpr: assert isinstance(inner_desc, dace.data.Scalar) return ValueExpr(inner_node, arg.gt_dtype) - lambda_params = [] lambda_args: list[ IteratorExpr | MemletExpr | ValueExpr | tuple[IteratorExpr | MemletExpr | ValueExpr | tuple[Any, ...], ...] 
] = [] + lambda_params: list[gtir.Sym] = [] for p in symbol_ref_utils.collect_symbol_refs(expr, self.symbol_map.keys()): arg = self.symbol_map[p] if isinstance(arg, tuple): @@ -652,7 +656,9 @@ def visit_arg(arg: IteratorExpr | DataExpr) -> IteratorExpr | ValueExpr: # visit each branch of the if-statement as if it was a Lambda node lambda_node = gtir.Lambda(params=lambda_params, expr=expr) - return apply(if_sdfg, if_branch_state, self.subgraph_builder, lambda_node, lambda_args) + return translate_lambda_to_dataflow( + if_sdfg, if_branch_state, self.subgraph_builder, lambda_node, lambda_args + ) def _visit_if(self, node: gtir.FunCall) -> ValueExpr | tuple[ValueExpr | tuple[Any, ...], ...]: """ @@ -660,6 +666,48 @@ def _visit_if(self, node: gtir.FunCall) -> ValueExpr | tuple[ValueExpr | tuple[A each branch is lowered into a dataflow in a separate state and the if-condition is represented as the inter-state edge condtion. """ + + def visit_if_branch_result( + state: dace.SDFGState, edge: DataflowOutputEdge, sym: gtir.Sym + ) -> ValueExpr: + # Inner function to create an output connector on the nested SDFG that + # will write the result to the parent SDFG. The result data node + # inside the nested SDFG must have the same name as the output connector. + output_data = str(sym.id) + try: + output_desc = nsdfg.data(output_data) + assert not output_desc.transient + except KeyError: + # if the result is currently written to a transient node, inside the nested SDFG, + # we need to allocate a non-transient data node + result_desc = edge.result.dc_node.desc(nsdfg) + output_desc = result_desc.clone() + output_desc.transient = False + output_data = nsdfg.add_datadesc(output_data, output_desc, find_new_name=True) + output_node = state.add_access(output_data) + state.add_nedge( + edge.result.dc_node, + output_node, + dace.Memlet.from_array(output_data, output_desc), + ) + return ValueExpr(output_node, edge.result.gt_dtype) + + def write_output_of_nested_sdfg_to_temporary(inner_value: ValueExpr) -> ValueExpr: + # Each output connector of the nested SDFG writes to a transient node in the parent SDFG + inner_data = inner_value.dc_node.data + inner_desc = inner_value.dc_node.desc(nsdfg) + assert not inner_desc.transient + output, output_desc = self.sdfg.add_temp_transient_like(inner_desc) + output_node = self.state.add_access(output) + self.state.add_edge( + nsdfg_node, + inner_data, + output_node, + None, + dace.Memlet.from_array(output, output_desc), + ) + return ValueExpr(output_node, inner_value.gt_dtype) + assert len(node.args) == 3 # TODO(edopao): enable once supported in next DaCe release @@ -673,7 +721,7 @@ def _visit_if(self, node: gtir.FunCall) -> ValueExpr | tuple[ValueExpr | tuple[A and condition_value.gt_dtype.kind == ts.ScalarKind.BOOL ) if isinstance(condition_value, (MemletExpr, ValueExpr)) - else (condition_value.dc_dtype == dace.dtypes.bool) + else (condition_value.dc_dtype == dace.dtypes.bool_) ) nsdfg = dace.SDFG(self.unique_nsdfg_name(prefix="if_stmt")) @@ -716,39 +764,15 @@ def _visit_if(self, node: gtir.FunCall) -> ValueExpr | tuple[ValueExpr | tuple[A for edge in in_edges: edge.connect(map_entry=None) - # the result of each branch needs to be moved to the parent SDFG - def construct_output( - output_state: dace.SDFGState, edge: DataflowOutputEdge, sym: gtir.Sym - ) -> ValueExpr: - # the output data node has the same name as the nested SDFG output connector - output_data = str(sym.id) - try: - output_desc = nsdfg.data(output_data) - assert not output_desc.transient - except KeyError: - # 
if the result is currently written to a transient node, inside the nested SDFG, - # we need to allocate a non-transient data node - result_desc = edge.result.dc_node.desc(nsdfg) - output_desc = result_desc.clone() - output_desc.transient = False - output_data = nsdfg.add_datadesc(output_data, output_desc, find_new_name=True) - output_node = output_state.add_access(output_data) - output_state.add_nedge( - edge.result.dc_node, - output_node, - dace.Memlet.from_array(output_data, output_desc), - ) - return ValueExpr(output_node, edge.result.gt_dtype) - if isinstance(out_edge, tuple): assert isinstance(node.type, ts.TupleType) out_symbol = dace_gtir_utils.make_symbol_tuple("__output", node.type) outer_value = gtx_utils.tree_map( - lambda x, y, output_state=if_branch_state: construct_output(output_state, x, y) + lambda x, y, state=if_branch_state: visit_if_branch_result(state, x, y) )(out_edge, out_symbol) else: assert isinstance(node.type, ts.FieldType | ts.ScalarType) - outer_value = construct_output( + outer_value = visit_if_branch_result( if_branch_state, out_edge, im.sym("__output", node.type) ) @@ -790,26 +814,10 @@ def construct_output( self.sdfg.make_array_memlet(input_expr.dc_node.data), ) - def connect_output(inner_value: ValueExpr) -> ValueExpr: - # each output connector of the nested SDFG writes to a transient node in the parent SDFG - inner_data = inner_value.dc_node.data - inner_desc = inner_value.dc_node.desc(nsdfg) - assert not inner_desc.transient - output, output_desc = self.sdfg.add_temp_transient_like(inner_desc) - output_node = self.state.add_access(output) - self.state.add_edge( - nsdfg_node, - inner_data, - output_node, - None, - dace.Memlet.from_array(output, output_desc), - ) - return ValueExpr(output_node, inner_value.gt_dtype) - return ( - gtx_utils.tree_map(connect_output)(result) + gtx_utils.tree_map(write_output_of_nested_sdfg_to_temporary)(result) if isinstance(result, tuple) - else connect_output(result) + else write_output_of_nested_sdfg_to_temporary(result) ) def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: @@ -943,8 +951,7 @@ def _visit_list_get(self, node: gtir.FunCall) -> ValueExpr: None, dace.Memlet(data=list_arg.dc_node.data, subset=index_arg.value), ) - else: - assert isinstance(index_arg, ValueExpr) + elif isinstance(index_arg, ValueExpr): tasklet_node = self._add_tasklet( "list_get", inputs={"index", "list"}, outputs={"value"}, code="value = list[index]" ) @@ -965,6 +972,8 @@ def _visit_list_get(self, node: gtir.FunCall) -> ValueExpr: self._add_edge( tasklet_node, "value", result_node, None, dace.Memlet(data=result, subset="0") ) + else: + raise TypeError(f"Unexpected value {index_arg} as index argument.") return ValueExpr(dc_node=result_node, gt_dtype=list_arg.gt_dtype.element_type) @@ -1627,8 +1636,6 @@ def visit_FunCall( def visit_Lambda( self, node: gtir.Lambda ) -> DataflowOutputEdge | tuple[DataflowOutputEdge | tuple[Any, ...], ...]: - result = self.visit(node.expr) - def _visit_Lambda_impl( output_expr: DataflowOutputEdge | ValueExpr | MemletExpr | SymbolExpr, ) -> DataflowOutputEdge: @@ -1657,9 +1664,11 @@ def _visit_Lambda_impl( output_expr = self._construct_tasklet_result(output_dtype, tasklet_node, "__out") return DataflowOutputEdge(self.state, output_expr) + result = self.visit(node.expr) + return ( gtx_utils.tree_map(_visit_Lambda_impl)(result) - if isinstance(result, tuple) + if isinstance(node.type, ts.TupleType) else _visit_Lambda_impl(result) ) @@ -1716,7 +1725,7 @@ def visit_let( return self.visit(node) -def apply( +def 
translate_lambda_to_dataflow( sdfg: dace.SDFG, state: dace.SDFGState, sdfg_builder: gtir_sdfg.DataflowBuilder, diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index e89bc5f513..b51458ad6f 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -164,9 +164,9 @@ class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): """ offset_provider_type: gtx_common.OffsetProviderType - global_symbols: dict[str, ts.DataType] = dataclasses.field(default_factory=lambda: {}) + global_symbols: dict[str, ts.DataType] = dataclasses.field(default_factory=dict) field_offsets: dict[str, Optional[list[dace.symbolic.SymExpr]]] = dataclasses.field( - default_factory=lambda: {} + default_factory=dict ) map_uids: eve.utils.UIDGenerator = dataclasses.field( init=False, repr=False, default_factory=lambda: eve.utils.UIDGenerator(prefix="map")
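
The workflow.py hunk adds a comment explaining why extra arguments are appended when the SDFG declares more parameters than the flattened GT4Py arguments provide: the ahead-of-time (AOT) FieldView workflow compiles the SDFG with trailing domain-size parameters. A minimal sketch of that padding step, using hypothetical stand-ins for `gtx_utils.flatten_nested_tuple` and `arguments.iter_size_args`:

from typing import Any, Iterator, Sequence


def flatten_nested_tuple(values: tuple[Any, ...]) -> tuple[Any, ...]:
    # Hypothetical stand-in for gtx_utils.flatten_nested_tuple: depth-first
    # flattening of arbitrarily nested tuples into a flat tuple of leaves.
    return sum(
        (flatten_nested_tuple(v) if isinstance(v, tuple) else (v,) for v in values),
        (),
    )


def iter_size_args(args: tuple[Any, ...]) -> Iterator[int]:
    # Hypothetical stand-in for arguments.iter_size_args: yield the size of each
    # array-like argument, which the AOT-compiled SDFG expects as trailing
    # domain-size parameters.
    for arg in flatten_nested_tuple(args):
        if hasattr(arg, "shape"):
            yield from arg.shape


def build_sdfg_call_args(sdfg_arg_names: Sequence[str], args: tuple[Any, ...]) -> tuple[Any, ...]:
    # Mirrors the structure of decorated_program(): flatten the arguments and,
    # only if the SDFG expects more positional parameters, append the sizes.
    flat_args = flatten_nested_tuple(args)
    if len(sdfg_arg_names) > len(flat_args):
        flat_args = (*flat_args, *iter_size_args(args))
    return flat_args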
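
In the refactored `_create_field_operator()`, the tuple case pairs each dataflow output edge with the matching symbol from `dace_gtir_utils.make_symbol_tuple()` and applies `_create_field_operator_impl()` leaf by leaf. A minimal sketch of that leaf-wise mapping idiom, with a simplified `tree_map` assumed to follow the same call convention as `gtx_utils.tree_map`:

from typing import Any, Callable


def tree_map(fn: Callable[..., Any]) -> Callable[..., Any]:
    # Simplified stand-in for gtx_utils.tree_map: apply `fn` leaf-wise over
    # structurally congruent nested tuples, preserving the tuple structure.
    def apply(*trees: Any) -> Any:
        if isinstance(trees[0], tuple):
            return tuple(apply(*leaves) for leaves in zip(*trees))
        return fn(*trees)

    return apply


output_edges = (("edge_a", "edge_b"), "edge_c")
tuple_syms = (("sym_a", "sym_b"), "sym_c")
result = tree_map(lambda edge, sym: f"{edge}:{sym}")(output_edges, tuple_syms)
assert result == (("edge_a:sym_a", "edge_b:sym_b"), "edge_c:sym_c")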
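
The `default_factory=lambda: {}` to `default_factory=dict` changes in `LambdaToDataflow.symbol_map` and in `GTIRToSDFG` are behavior-preserving: `dict` is itself a zero-argument callable that returns a fresh mapping, so the extra lambda adds nothing. A standalone sketch of the idiom (not the GT4Py classes):

import dataclasses
from typing import Any


@dataclasses.dataclass(frozen=True)
class SymbolTable:
    # `dict` is a zero-argument callable, so it can be passed directly as the
    # factory; `lambda: {}` builds the same empty dict through an extra lambda.
    symbol_map: dict[str, Any] = dataclasses.field(default_factory=dict)


a = SymbolTable()
b = SymbolTable()
# Each instance still gets its own, independent dictionary.
assert a.symbol_map == {} and a.symbol_map is not b.symbol_map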