From a0c37cb5ddb177c5103c36d25d943fde5e1091c6 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 20 Dec 2024 09:57:10 +0100 Subject: [PATCH] minor edit --- .../gtir_builtin_translators.py | 61 ++++++++++--------- .../runners/dace_fieldview/gtir_dataflow.py | 4 +- .../runners/dace_fieldview/utility.py | 4 +- 3 files changed, 37 insertions(+), 32 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index f59755649b..131321f77e 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -161,7 +161,7 @@ def get_local_view( def get_tuple_type(data: tuple[FieldopResult, ...]) -> ts.TupleType: """ - Compute the `ts.TupleType` corresponding to the structure of a tuple of data nodes. + Compute the `ts.TupleType` corresponding to the structure of a tuple of `FieldopResult`. """ return ts.TupleType( types=[get_tuple_type(d) if isinstance(d, tuple) else d.gt_type for d in data] @@ -169,6 +169,11 @@ def get_tuple_type(data: tuple[FieldopResult, ...]) -> ts.TupleType: def flatten_tuples(name: str, arg: FieldopResult) -> list[tuple[str, FieldopData]]: + """ + Visit a `FieldopResult`, potentially containing nested tuples, and construct a list + of pairs `(str, FieldopData)` containing the symbol name of each tuple field and + the corresponding `FieldopData`. + """ if isinstance(arg, tuple): tuple_type = get_tuple_type(arg) tuple_symbols = dace_gtir_utils.flatten_tuple_fields(name, tuple_type) @@ -337,9 +342,6 @@ def _create_field_operator( # here we setup the edges passing through the map entry node for edge in input_edges: - if isinstance(edge, gtir_dataflow.EmptyInputEdge) and me is None: - # cannot create empty edge from MapEntry node, if this is not present - continue edge.connect(me) def create_field(output_edge: gtir_dataflow.DataflowOutputEdge, sym: gtir.Sym) -> FieldopData: @@ -402,6 +404,7 @@ def create_field(output_edge: gtir_dataflow.DataflowOutputEdge, sym: gtir.Sym) - assert isinstance(node_type, ts.FieldType) return create_field(output_edges, im.sym("x", node_type)) else: + # handle tuples of fields assert isinstance(node_type, ts.TupleType) return gtx_utils.tree_map(create_field)( output_edges, dace_gtir_utils.make_symbol_tuple("x", node_type) @@ -888,8 +891,8 @@ def translate_scan( stencil_expr = scan_expr.args[0] assert isinstance(stencil_expr, gtir.Lambda) - # params[0]: the lambda parameter to propagate the scan state on the vertical dimension - scan_state = str(stencil_expr.params[0].id) + # params[0]: the lambda parameter to propagate the scan carry on the vertical dimension + scan_carry = str(stencil_expr.params[0].id) # params[1]: boolean flag for forward/backward scan assert isinstance(scan_expr.args[1], gtir.Literal) and ti.is_logical(scan_expr.args[1].type) @@ -908,13 +911,13 @@ def scan_output_name(input_name: str) -> str: # visit the initialization value of the scan expression init_data = sdfg_builder.visit(init_value, sdfg=sdfg, head_state=state) - # extract type definition of the scan state - scan_state_type = ( + # extract type definition of the scan carry + scan_carry_type = ( init_data.gt_type if isinstance(init_data, FieldopData) else get_tuple_type(init_data) ) # create list of params to the lambda function with associated node type - lambda_symbols = {scan_state: scan_state_type} | { + lambda_symbols = {scan_carry: scan_carry_type} | { str(p.id): arg.type for p, arg in zip(stencil_expr.params[1:], node.args, strict=True) if isinstance(arg.type, ts.DataType) @@ -925,7 +928,7 @@ def scan_output_name(input_name: str) -> str: # the data descriptor with the correct field domain offsets for field arguments lambda_args = [sdfg_builder.visit(arg, sdfg=sdfg, head_state=state) for arg in node.args] lambda_args_mapping = { - scan_input_name(scan_state): init_data, + scan_input_name(scan_carry): init_data, } | { str(param.id): arg for param, arg in zip(stencil_expr.params[1:], lambda_args, strict=True) } @@ -941,11 +944,11 @@ def scan_output_name(input_name: str) -> str: { str(sym.id): sym.type for sym in dace_gtir_utils.flatten_tuple_fields( - scan_output_name(scan_state), scan_state_type + scan_output_name(scan_carry), scan_carry_type ) } - if isinstance(scan_state_type, ts.TupleType) - else {scan_output_name(scan_state): scan_state_type} + if isinstance(scan_carry_type, ts.TupleType) + else {scan_output_name(scan_carry): scan_carry_type} ) # the scan operator is implemented as an nested SDFG implementing the lambda expression @@ -998,14 +1001,14 @@ def scan_output_name(input_name: str) -> str: nsdfg, compute_state, stencil_builder, stencil_expr, args=stencil_args ) - # now initialize the scan state - scan_state_input = ( - dace_gtir_utils.make_symbol_tuple(scan_state, scan_state_type) - if isinstance(scan_state_type, ts.TupleType) - else im.sym(scan_state, scan_state_type) + # now initialize the scan carry + scan_carry_input = ( + dace_gtir_utils.make_symbol_tuple(scan_carry, scan_carry_type) + if isinstance(scan_carry_type, ts.TupleType) + else im.sym(scan_carry, scan_carry_type) ) - def init_scan_state(sym: gtir.Sym) -> None: + def init_scan_carry(sym: gtir.Sym) -> None: scan_state = str(sym.id) scan_state_desc = nsdfg.data(scan_state) input_state = scan_input_name(scan_state) @@ -1018,17 +1021,17 @@ def init_scan_state(sym: gtir.Sym) -> None: nsdfg.make_array_memlet(input_state), ) - if isinstance(scan_state_input, tuple): - gtx_utils.tree_map(init_scan_state)(scan_state_input) + if isinstance(scan_carry_input, tuple): + gtx_utils.tree_map(init_scan_carry)(scan_carry_input) else: - init_scan_state(scan_state_input) + init_scan_carry(scan_carry_input) # connect the dataflow input directly to the source data nodes, without passing through a map node; # the reason is that the map for horizontal domain is outside the scan loop region for edge in input_edges: edge.connect(map_entry=None) - # connect the dataflow result nodes to the variables that carry the scan state along the column axis + # connect the dataflow result nodes to the carry variables def connect_scan_output( scan_output_edge: gtir_dataflow.DataflowOutputEdge, sym: gtir.Sym ) -> FieldopData: @@ -1057,12 +1060,12 @@ def connect_scan_output( return FieldopData(output_node, output_type, scan_output_offset) lambda_output = ( - gtx_utils.tree_map(connect_scan_output)(result, scan_state_input) - if (isinstance(result, tuple) and isinstance(scan_state_input, tuple)) - else connect_scan_output(result, scan_state_input) + gtx_utils.tree_map(connect_scan_output)(result, scan_carry_input) + if (isinstance(result, tuple) and isinstance(scan_carry_input, tuple)) + else connect_scan_output(result, scan_carry_input) if ( isinstance(result, gtir_dataflow.DataflowOutputEdge) - and isinstance(scan_state_input, gtir.Sym) + and isinstance(scan_carry_input, gtir.Sym) ) else None ) @@ -1075,8 +1078,8 @@ def connect_scan_output( if (compute_state.degree(data_node) == 0) and ( (not data_desc.transient) or data_node.data.startswith( - scan_state - ) # exceptional case where the state is not used, not a scan indeed + scan_carry + ) # exceptional case where the carry variable is not used, not a scan indeed ): # isolated node, connect it to a transient to avoid SDFG validation errors temp, temp_desc = nsdfg.add_temp_transient_like(data_desc) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index 5d7159c987..ee22c4cd13 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -196,6 +196,7 @@ class EmptyInputEdge(DataflowInputEdge): node: dace.nodes.Tasklet def connect(self, map_entry: Optional[dace.nodes.MapEntry]) -> None: + # cannot create empty edge from MapEntry node, if this is not present if map_entry is not None: self.state.add_nedge(map_entry, self.node, dace.Memlet()) @@ -564,7 +565,7 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr: def _visit_if(self, node: gtir.FunCall) -> ValueExpr | tuple[ValueExpr | tuple[Any, ...], ...]: assert len(node.args) == 3 - # TODO(edopao): enable once DaCe supports it in next release + # TODO(edopao): enable once supported in next DaCe release use_conditional_block: Final[bool] = False condition_value = self.visit(node.args[0]) @@ -690,6 +691,7 @@ def visit_arg(arg: IteratorExpr | DataExpr) -> IteratorExpr | ValueExpr: lambda_args.append(inner_arg) lambda_params.append(im.sym(p)) + # visit each branch of the if-statement as it was a Lambda node lambda_node = gtir.Lambda(params=lambda_params, expr=expr) return visit_lambda(nsdfg, state, self.subgraph_builder, lambda_node, lambda_args) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index 33c333a9f3..ad120e2502 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -30,7 +30,7 @@ def get_map_variable(dim: gtx_common.Dimension) -> str: def make_symbol_tuple(tuple_name: str, tuple_type: ts.TupleType) -> tuple[gtir.Sym, ...]: """ Creates a tuple representation of the symbols corresponding to the tuple fields. - The constructed tuple presrves the nested nature of the type, is any. + The constructed tuple preserves the nested nature of the type, if any. Examples -------- @@ -53,7 +53,7 @@ def make_symbol_tuple(tuple_name: str, tuple_type: ts.TupleType) -> tuple[gtir.S def flatten_tuple_fields(tuple_name: str, tuple_type: ts.TupleType) -> list[gtir.Sym]: """ - Creates a list of names with the corresponding data type for all elements of the given tuple. + Creates a list of symbols, annotated with the data type, for all elements of the given tuple. Examples --------