Skip to content

Commit

Permalink
remove support for tuple iterator
Browse files Browse the repository at this point in the history
  • Loading branch information
edopao committed Jan 9, 2025
1 parent 46322ac commit 2c1156b
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ def _construct_tasklet_result(
),
)

def _visit_deref(self, node: gtir.FunCall) -> DataExpr | tuple[DataExpr | tuple[Any, ...], ...]:
def _visit_deref(self, node: gtir.FunCall) -> DataExpr:
"""
Visit a `deref` node, which represents dereferencing of an iterator.
The iterator is the argument of this node.
Expand All @@ -481,32 +481,24 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr | tuple[DataExpr | tuple[
memlet; otherwise dereferencing is a runtime operation represented in
the SDFG as a tasklet node.
"""
assert len(node.args) == 1
arg_expr = self.visit(node.args[0])

if isinstance(arg_expr, tuple):
assert isinstance(node.type, ts.TupleType)
symbol_tuple = dace_gtir_utils.make_symbol_tuple("x", node.type)
return gtx_utils.tree_map(lambda f, fsym: self._visit_deref_field(f, fsym.type))(
arg_expr, symbol_tuple
)
else:
assert isinstance(node.type, (ts.FieldType, ts.ScalarType, itir_ts.ListType))
return self._visit_deref_field(arg_expr, node.type)

def _visit_deref_field(self, arg_expr: DataExpr, node_type: ts.DataType) -> DataExpr:
# format used for field index tasklet connector
IndexConnectorFmt: Final = "__index_{dim}"

# dereferencing a scalar or a literal node results in the node itself
if isinstance(node.type, ts.TupleType):
raise NotImplementedError("Tuple deref not supported.")

assert len(node.args) == 1
arg_expr = self.visit(node.args[0])

if not isinstance(arg_expr, IteratorExpr):
# dereferencing a scalar or a literal node results in the node itself
return arg_expr

field_desc = arg_expr.field.desc(self.sdfg)
if isinstance(field_desc, dace.data.Scalar):
# deref a zero-dimensional field
assert len(arg_expr.field_domain) == 0
assert isinstance(node_type, ts.ScalarType)
assert isinstance(node.type, ts.ScalarType)
return MemletExpr(arg_expr.field, arg_expr.gt_dtype, subset="0")

# default case: deref a field with one or more dimensions
Expand Down Expand Up @@ -1282,6 +1274,7 @@ def _visit_shift_multidim(
else:
it = self.visit(iterator)

assert isinstance(it, IteratorExpr)
return offset_provider_arg, offset_value_arg, it

def _make_cartesian_shift(
Expand Down Expand Up @@ -1430,9 +1423,7 @@ def _make_unstructured_shift(

return IteratorExpr(it.field, it.gt_dtype, it.field_domain, shifted_indices)

def _visit_shift(
self, node: gtir.FunCall
) -> IteratorExpr | tuple[IteratorExpr | tuple[Any, ...], ...]:
def _visit_shift(self, node: gtir.FunCall) -> IteratorExpr:
# convert builtin-index type to dace type
IndexDType: Final = dace_utils.as_dace_type(
ts.ScalarType(kind=getattr(ts.ScalarKind, gtir.INTEGER_INDEX_BUILTIN.upper()))
Expand All @@ -1458,25 +1449,19 @@ def _visit_shift(
else self.visit(offset_value_arg)
)

def shift_field_iterator(field_iterator: IteratorExpr) -> IteratorExpr:
if isinstance(offset_provider_type, gtx_common.Dimension):
return self._make_cartesian_shift(field_iterator, offset_provider_type, offset_expr)
else:
# initially, the storage for the connectivity tables is created as transient;
# when the tables are used, the storage is changed to non-transient,
# so the corresponding arrays are supposed to be allocated by the SDFG caller
offset_table = dace_utils.connectivity_identifier(offset)
self.sdfg.arrays[offset_table].transient = False
offset_table_node = self.state.add_access(offset_table)

return self._make_unstructured_shift(
field_iterator, offset_provider_type, offset_table_node, offset_expr
)

if isinstance(it, tuple):
return gtx_utils.tree_map(shift_field_iterator)(it)
if isinstance(offset_provider_type, gtx_common.Dimension):
return self._make_cartesian_shift(it, offset_provider_type, offset_expr)
else:
return shift_field_iterator(it)
# initially, the storage for the connectivity tables is created as transient;
# when the tables are used, the storage is changed to non-transient,
# so the corresponding arrays are supposed to be allocated by the SDFG caller
offset_table = dace_utils.connectivity_identifier(offset)
self.sdfg.arrays[offset_table].transient = False
offset_table_node = self.state.add_access(offset_table)

return self._make_unstructured_shift(
it, offset_provider_type, offset_table_node, offset_expr
)

def _visit_generic_builtin(self, node: gtir.FunCall) -> ValueExpr:
"""
Expand Down
1 change: 1 addition & 0 deletions tests/next_tests/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
(USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE),
(USES_SCAN_IN_STENCIL, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE),
(USES_SPARSE_FIELDS, XFAIL, UNSUPPORTED_MESSAGE),
(USES_TUPLE_ITERATOR, XFAIL, UNSUPPORTED_MESSAGE),
]
)
EMBEDDED_SKIP_LIST = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,31 +122,31 @@ def k_level_condition_upper_tuple(k_idx, k_level):
@pytest.mark.parametrize(
"fun, k_level, inp_function, ref_function",
[
(
pytest.param(
k_level_condition_lower,
lambda inp: 0,
lambda k_size: gtx.as_field([KDim], np.arange(k_size, dtype=np.int32)),
lambda inp: np.concatenate([[0], inp[:-1]]),
),
(
pytest.param(
k_level_condition_upper,
lambda inp: inp.shape[0] - 1,
lambda k_size: gtx.as_field([KDim], np.arange(k_size, dtype=np.int32)),
lambda inp: np.concatenate([inp[1:], [0]]),
),
(
pytest.param(
k_level_condition_upper_tuple,
lambda inp: inp[0].shape[0] - 1,
lambda k_size: (
gtx.as_field([KDim], np.arange(k_size, dtype=np.int32)),
gtx.as_field([KDim], np.arange(k_size, dtype=np.int32)),
),
lambda inp: np.concatenate([(inp[0][1:] + inp[1][1:]), [0]]),
marks=pytest.mark.uses_tuple_iterator,
),
],
)
@pytest.mark.uses_tuple_args
@pytest.mark.uses_tuple_iterator
def test_k_level_condition(program_processor, fun, k_level, inp_function, ref_function):
program_processor, validate = program_processor

Expand Down

0 comments on commit 2c1156b

Please sign in to comment.