diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index 234f5bd0f0..33719cfab6 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -298,16 +298,17 @@ def get_tuple_type( """ Compute the `ts.TupleType` corresponding to the tuple structure of input data expressions. """ - return ts.TupleType( - types=[ - get_tuple_type(d) - if isinstance(d, tuple) - else d.get_field_type() - if isinstance(d, IteratorExpr) - else d.gt_dtype - for d in data - ] - ) + data_types: list[ts.DataType] = [] + for dataitem in data: + if isinstance(dataitem, tuple): + data_types.append(get_tuple_type(dataitem)) + elif isinstance(dataitem, IteratorExpr): + data_types.append(dataitem.get_field_type()) + elif isinstance(dataitem, MemletExpr): + data_types.append(dataitem.gt_dtype) + else: + data_types.append(dataitem.gt_dtype) + return ts.TupleType(data_types) @dataclasses.dataclass(frozen=True)