From 2c1156b1aa20d7fc06594f1a5b9d5b5be34c731c Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Thu, 9 Jan 2025 12:55:14 +0100 Subject: [PATCH] remove support for tuple iterator --- .../runners/dace_fieldview/gtir_dataflow.py | 61 +++++++------------ tests/next_tests/definitions.py | 1 + .../iterator_tests/test_column_stencil.py | 8 +-- 3 files changed, 28 insertions(+), 42 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index 189f01abb5..a682d5dfd2 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -466,7 +466,7 @@ def _construct_tasklet_result( ), ) - def _visit_deref(self, node: gtir.FunCall) -> DataExpr | tuple[DataExpr | tuple[Any, ...], ...]: + def _visit_deref(self, node: gtir.FunCall) -> DataExpr: """ Visit a `deref` node, which represents dereferencing of an iterator. The iterator is the argument of this node. @@ -481,32 +481,24 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr | tuple[DataExpr | tuple[ memlet; otherwise dereferencing is a runtime operation represented in the SDFG as a tasklet node. """ - assert len(node.args) == 1 - arg_expr = self.visit(node.args[0]) - - if isinstance(arg_expr, tuple): - assert isinstance(node.type, ts.TupleType) - symbol_tuple = dace_gtir_utils.make_symbol_tuple("x", node.type) - return gtx_utils.tree_map(lambda f, fsym: self._visit_deref_field(f, fsym.type))( - arg_expr, symbol_tuple - ) - else: - assert isinstance(node.type, (ts.FieldType, ts.ScalarType, itir_ts.ListType)) - return self._visit_deref_field(arg_expr, node.type) - - def _visit_deref_field(self, arg_expr: DataExpr, node_type: ts.DataType) -> DataExpr: # format used for field index tasklet connector IndexConnectorFmt: Final = "__index_{dim}" - # dereferencing a scalar or a literal node results in the node itself + if isinstance(node.type, ts.TupleType): + raise NotImplementedError("Tuple deref not supported.") + + assert len(node.args) == 1 + arg_expr = self.visit(node.args[0]) + if not isinstance(arg_expr, IteratorExpr): + # dereferencing a scalar or a literal node results in the node itself return arg_expr field_desc = arg_expr.field.desc(self.sdfg) if isinstance(field_desc, dace.data.Scalar): # deref a zero-dimensional field assert len(arg_expr.field_domain) == 0 - assert isinstance(node_type, ts.ScalarType) + assert isinstance(node.type, ts.ScalarType) return MemletExpr(arg_expr.field, arg_expr.gt_dtype, subset="0") # default case: deref a field with one or more dimensions @@ -1282,6 +1274,7 @@ def _visit_shift_multidim( else: it = self.visit(iterator) + assert isinstance(it, IteratorExpr) return offset_provider_arg, offset_value_arg, it def _make_cartesian_shift( @@ -1430,9 +1423,7 @@ def _make_unstructured_shift( return IteratorExpr(it.field, it.gt_dtype, it.field_domain, shifted_indices) - def _visit_shift( - self, node: gtir.FunCall - ) -> IteratorExpr | tuple[IteratorExpr | tuple[Any, ...], ...]: + def _visit_shift(self, node: gtir.FunCall) -> IteratorExpr: # convert builtin-index type to dace type IndexDType: Final = dace_utils.as_dace_type( ts.ScalarType(kind=getattr(ts.ScalarKind, gtir.INTEGER_INDEX_BUILTIN.upper())) @@ -1458,25 +1449,19 @@ def _visit_shift( else self.visit(offset_value_arg) ) - def shift_field_iterator(field_iterator: IteratorExpr) -> IteratorExpr: - if isinstance(offset_provider_type, gtx_common.Dimension): - return self._make_cartesian_shift(field_iterator, offset_provider_type, offset_expr) - else: - # initially, the storage for the connectivity tables is created as transient; - # when the tables are used, the storage is changed to non-transient, - # so the corresponding arrays are supposed to be allocated by the SDFG caller - offset_table = dace_utils.connectivity_identifier(offset) - self.sdfg.arrays[offset_table].transient = False - offset_table_node = self.state.add_access(offset_table) - - return self._make_unstructured_shift( - field_iterator, offset_provider_type, offset_table_node, offset_expr - ) - - if isinstance(it, tuple): - return gtx_utils.tree_map(shift_field_iterator)(it) + if isinstance(offset_provider_type, gtx_common.Dimension): + return self._make_cartesian_shift(it, offset_provider_type, offset_expr) else: - return shift_field_iterator(it) + # initially, the storage for the connectivity tables is created as transient; + # when the tables are used, the storage is changed to non-transient, + # so the corresponding arrays are supposed to be allocated by the SDFG caller + offset_table = dace_utils.connectivity_identifier(offset) + self.sdfg.arrays[offset_table].transient = False + offset_table_node = self.state.add_access(offset_table) + + return self._make_unstructured_shift( + it, offset_provider_type, offset_table_node, offset_expr + ) def _visit_generic_builtin(self, node: gtir.FunCall) -> ValueExpr: """ diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index b2b24ae756..80d0932329 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -149,6 +149,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE), (USES_SCAN_IN_STENCIL, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), (USES_SPARSE_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_TUPLE_ITERATOR, XFAIL, UNSUPPORTED_MESSAGE), ] ) EMBEDDED_SKIP_LIST = [ diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py index e7eb605caa..3b4fc0a70c 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py @@ -122,19 +122,19 @@ def k_level_condition_upper_tuple(k_idx, k_level): @pytest.mark.parametrize( "fun, k_level, inp_function, ref_function", [ - ( + pytest.param( k_level_condition_lower, lambda inp: 0, lambda k_size: gtx.as_field([KDim], np.arange(k_size, dtype=np.int32)), lambda inp: np.concatenate([[0], inp[:-1]]), ), - ( + pytest.param( k_level_condition_upper, lambda inp: inp.shape[0] - 1, lambda k_size: gtx.as_field([KDim], np.arange(k_size, dtype=np.int32)), lambda inp: np.concatenate([inp[1:], [0]]), ), - ( + pytest.param( k_level_condition_upper_tuple, lambda inp: inp[0].shape[0] - 1, lambda k_size: ( @@ -142,11 +142,11 @@ def k_level_condition_upper_tuple(k_idx, k_level): gtx.as_field([KDim], np.arange(k_size, dtype=np.int32)), ), lambda inp: np.concatenate([(inp[0][1:] + inp[1][1:]), [0]]), + marks=pytest.mark.uses_tuple_iterator, ), ], ) @pytest.mark.uses_tuple_args -@pytest.mark.uses_tuple_iterator def test_k_level_condition(program_processor, fun, k_level, inp_function, ref_function): program_processor, validate = program_processor