diff --git a/src/gt4py/next/program_processors/runners/dace_common/utility.py b/src/gt4py/next/program_processors/runners/dace_common/utility.py index 29395a30c1..3e96ef3cec 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_common/utility.py @@ -9,7 +9,7 @@ from __future__ import annotations import re -from typing import Final, Optional, Sequence +from typing import Final, Literal, Optional, Sequence import dace @@ -51,12 +51,16 @@ def connectivity_identifier(name: str) -> str: return f"connectivity_{name}" +def field_symbol_name(field_name: str, axis: int, sym: Literal["size", "stride"]) -> str: + return f"__{field_name}_{sym}_{axis}" + + def field_size_symbol_name(field_name: str, axis: int) -> str: - return f"__{field_name}_size_{axis}" + return field_symbol_name(field_name, axis, "size") def field_stride_symbol_name(field_name: str, axis: int) -> str: - return f"__{field_name}_stride_{axis}" + return field_symbol_name(field_name, axis, "stride") def is_field_symbol(name: str) -> bool: diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 69aedf44d6..60dcd8ddc9 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -10,7 +10,7 @@ import abc import dataclasses -from typing import TYPE_CHECKING, Final, Iterable, Optional, Protocol, TypeAlias +from typing import TYPE_CHECKING, Final, Iterable, Optional, Protocol, Sequence, TypeAlias import dace import dace.subsets as sbs @@ -33,6 +33,34 @@ from gt4py.next.program_processors.runners.dace_fieldview import gtir_sdfg +def _get_domain_indices( + dims: Sequence[gtx_common.Dimension], offsets: Optional[Sequence[dace.symbolic.SymExpr]] = None +) -> sbs.Indices: + """ + 
Helper function to construct the list of indices for a field domain, applying + an optional offset in each dimension as start index. + + Args: + dims: The field dimensions. + offsets: The range start index in each dimension. + + Returns: + A list of indices for field access in dace arrays. As this list is returned + as `dace.subsets.Indices`, it should be converted to `dace.subsets.Range` before + being used in memlet subset because ranges are better supported throughout DaCe. + """ + index_variables = [dace.symbolic.SymExpr(dace_gtir_utils.get_map_variable(dim)) for dim in dims] + if offsets is None: + return sbs.Indices(index_variables) + else: + return sbs.Indices( + [ + index - offset if offset != 0 else index + for index, offset in zip(index_variables, offsets, strict=True) + ] + ) + + @dataclasses.dataclass(frozen=True) class FieldopData: """ @@ -45,42 +73,59 @@ class FieldopData: Args: dc_node: DaCe access node to the data storage. gt_type: GT4Py type definition, which includes the field domain information. + offset: List of index offsets, in each dimension, when the dimension range + does not start from zero; assume zero offset, if not set. 
""" dc_node: dace.nodes.AccessNode gt_type: ts.FieldType | ts.ScalarType + offset: Optional[list[dace.symbolic.SymExpr]] + + def make_copy(self, data_node: dace.nodes.AccessNode) -> FieldopData: + """Create a copy of this data descriptor with a different access node.""" + assert data_node != self.dc_node + return FieldopData(data_node, self.gt_type, self.offset) def get_local_view( self, domain: FieldopDomain ) -> gtir_dataflow.IteratorExpr | gtir_dataflow.MemletExpr: - """Helper method to access a field in local view, given a field operator domain.""" + """Helper method to access a field in local view, given the compute domain of a field operator.""" if isinstance(self.gt_type, ts.ScalarType): return gtir_dataflow.MemletExpr( dc_node=self.dc_node, gt_dtype=self.gt_type, subset=sbs.Indices([0]) ) if isinstance(self.gt_type, ts.FieldType): - indices: dict[gtx_common.Dimension, gtir_dataflow.DataExpr] = { - dim: gtir_dataflow.SymbolExpr(dace_gtir_utils.get_map_variable(dim), INDEX_DTYPE) - for dim, _, _ in domain + domain_dims = [dim for dim, _, _ in domain] + domain_indices = _get_domain_indices(domain_dims) + it_indices: dict[gtx_common.Dimension, gtir_dataflow.DataExpr] = { + dim: gtir_dataflow.SymbolExpr(index, INDEX_DTYPE) + for dim, index in zip(domain_dims, domain_indices) } + field_domain = [ + (dim, dace.symbolic.SymExpr(0) if self.offset is None else self.offset[i]) + for i, dim in enumerate(self.gt_type.dims) + ] local_dims = [ dim for dim in self.gt_type.dims if dim.kind == gtx_common.DimensionKind.LOCAL ] - if len(local_dims) == 0: return gtir_dataflow.IteratorExpr( - self.dc_node, self.gt_type.dtype, self.gt_type.dims, indices + self.dc_node, self.gt_type.dtype, field_domain, it_indices ) elif len(local_dims) == 1: field_dtype = itir_ts.ListType( element_type=self.gt_type.dtype, offset_type=local_dims[0] ) - field_dims = [ - dim for dim in self.gt_type.dims if dim.kind != gtx_common.DimensionKind.LOCAL + field_domain = [ + (dim, offset) + for dim, 
offset in field_domain + if dim.kind != gtx_common.DimensionKind.LOCAL ] - return gtir_dataflow.IteratorExpr(self.dc_node, field_dtype, field_dims, indices) + return gtir_dataflow.IteratorExpr( + self.dc_node, field_dtype, field_domain, it_indices + ) else: raise ValueError( @@ -155,9 +200,9 @@ def _parse_fieldop_arg( return arg.get_local_view(domain) -def _get_field_shape( +def _get_field_layout( domain: FieldopDomain, -) -> tuple[list[gtx_common.Dimension], list[dace.symbolic.SymExpr]]: +) -> tuple[list[gtx_common.Dimension], list[dace.symbolic.SymExpr], list[dace.symbolic.SymExpr]]: """ Parse the field operator domain and generates the shape of the result field. @@ -174,11 +219,14 @@ def _get_field_shape( domain: The field operator domain. Returns: - A tuple of two lists: the list of field dimensions and the list of dace - array sizes in each dimension. + A tuple of three lists containing: + - the domain dimensions + - the domain offset in each dimension + - the domain size in each dimension """ - domain_dims, _, domain_ubs = zip(*domain) - return list(domain_dims), list(domain_ubs) + domain_dims, domain_lbs, domain_ubs = zip(*domain) + domain_sizes = [(ub - lb) for lb, ub in zip(domain_lbs, domain_ubs)] + return list(domain_dims), list(domain_lbs), domain_sizes def _create_temporary_field( @@ -189,7 +237,7 @@ def _create_temporary_field( dataflow_output: gtir_dataflow.DataflowOutputEdge, ) -> FieldopData: """Helper method to allocate a temporary field where to write the output of a field operator.""" - field_dims, field_shape = _get_field_shape(domain) + field_dims, field_offset, field_shape = _get_field_layout(domain) output_desc = dataflow_output.result.dc_node.desc(sdfg) if isinstance(output_desc, dace.data.Array): @@ -197,6 +245,7 @@ def _create_temporary_field( assert isinstance(node_type.dtype.element_type, ts.ScalarType) assert output_desc.dtype == dace_utils.as_dace_type(node_type.dtype.element_type) # extend the array with the local dimensions added by 
the field operator (e.g. `neighbors`) + field_offset.extend(output_desc.offset) field_shape.extend(output_desc.shape) elif isinstance(output_desc, dace.data.Scalar): assert output_desc.dtype == dace_utils.as_dace_type(node_type.dtype) @@ -215,7 +264,11 @@ def _create_temporary_field( assert dataflow_output.result.gt_dtype.offset_type is not None field_dims.append(dataflow_output.result.gt_dtype.offset_type) - return FieldopData(field_node, ts.FieldType(field_dims, field_dtype)) + return FieldopData( + field_node, + ts.FieldType(field_dims, field_dtype), + offset=(field_offset if set(field_offset) != {0} else None), + ) def extract_domain(node: gtir.Node) -> FieldopDomain: @@ -285,7 +338,8 @@ def translate_as_fieldop( # parse the domain of the field operator domain = extract_domain(domain_expr) - domain_indices = sbs.Indices([dace_gtir_utils.get_map_variable(dim) for dim, _, _ in domain]) + domain_dims, domain_offsets, _ = zip(*domain) + domain_indices = _get_domain_indices(domain_dims, domain_offsets) # visit the list of arguments to be passed to the lambda expression stencil_args = [_parse_fieldop_arg(arg, sdfg, state, sdfg_builder, domain) for arg in node.args] @@ -350,10 +404,8 @@ def translate_broadcast_scalar( assert cpm.is_ref_to(stencil_expr, "deref") domain = extract_domain(domain_expr) - field_dims, field_shape = _get_field_shape(domain) - field_subset = sbs.Range.from_string( - ",".join(dace_gtir_utils.get_map_variable(dim) for dim in field_dims) - ) + output_dims, output_offset, output_shape = _get_field_layout(domain) + output_subset = sbs.Range.from_indices(_get_domain_indices(output_dims, output_offset)) assert len(node.args) == 1 scalar_expr = _parse_fieldop_arg(node.args[0], sdfg, state, sdfg_builder, domain) @@ -369,26 +421,15 @@ def translate_broadcast_scalar( assert isinstance(scalar_expr, gtir_dataflow.IteratorExpr) if len(node.args[0].type.dims) == 0: # zero-dimensional field input_subset = "0" - elif all( - isinstance(scalar_expr.indices[dim], 
gtir_dataflow.SymbolExpr) - for dim in scalar_expr.dimensions - if dim not in field_dims - ): - input_subset = ",".join( - dace_gtir_utils.get_map_variable(dim) - if dim in field_dims - else scalar_expr.indices[dim].value # type: ignore[union-attr] # catched by exception above - for dim in scalar_expr.dimensions - ) else: - raise ValueError(f"Cannot deref field {scalar_expr.field} in broadcast expression.") + input_subset = scalar_expr.get_memlet_subset(sdfg) input_node = scalar_expr.field gt_dtype = node.args[0].type.dtype else: raise ValueError(f"Unexpected argument {node.args[0]} in broadcast expression.") - output, _ = sdfg.add_temp_transient(field_shape, input_node.desc(sdfg).dtype) + output, _ = sdfg.add_temp_transient(output_shape, input_node.desc(sdfg).dtype) output_node = state.add_access(output) sdfg_builder.add_mapped_tasklet( @@ -400,13 +441,13 @@ def translate_broadcast_scalar( }, inputs={"__inp": dace.Memlet(data=input_node.data, subset=input_subset)}, code="__val = __inp", - outputs={"__val": dace.Memlet(data=output_node.data, subset=field_subset)}, + outputs={"__val": dace.Memlet(data=output_node.data, subset=output_subset)}, input_nodes={input_node.data: input_node}, output_nodes={output_node.data: output_node}, external_edges=True, ) - return FieldopData(output_node, ts.FieldType(field_dims, gt_dtype)) + return FieldopData(output_node, ts.FieldType(output_dims, gt_dtype), output_offset) def translate_if( @@ -467,7 +508,7 @@ def construct_output(inner_data: FieldopData) -> FieldopData: outer, _ = sdfg.add_temp_transient_like(inner_desc) outer_node = state.add_access(outer) - return FieldopData(outer_node, inner_data.gt_type) + return inner_data.make_copy(outer_node) result_temps = gtx_utils.tree_map(construct_output)(true_br_args) @@ -513,7 +554,7 @@ def _get_data_nodes( ) -> FieldopResult: if isinstance(data_type, ts.FieldType): data_node = state.add_access(data_name) - return FieldopData(data_node, data_type) + return 
sdfg_builder.make_field(data_node, data_type) elif isinstance(data_type, ts.ScalarType): if data_name in sdfg.symbols: @@ -522,7 +563,7 @@ def _get_data_nodes( ) else: data_node = state.add_access(data_name) - return FieldopData(data_node, data_type) + return sdfg_builder.make_field(data_node, data_type) elif isinstance(data_type, ts.TupleType): tuple_fields = dace_gtir_utils.get_tuple_fields(data_name, data_type) @@ -579,7 +620,7 @@ def translate_literal( data_type = node.type data_node = _get_symbolic_value(sdfg, state, sdfg_builder, node.value, data_type) - return FieldopData(data_node, data_type) + return FieldopData(data_node, data_type, offset=None) def translate_make_tuple( @@ -708,7 +749,7 @@ def translate_scalar_expr( dace.Memlet(data=temp_name, subset="0"), ) - return FieldopData(temp_node, node.type) + return FieldopData(temp_node, node.type, offset=None) def translate_symbol_ref( diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index 74142dec66..cfba4d61e5 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -90,17 +90,42 @@ class IteratorExpr: Args: field: Access node to the field this iterator operates on. gt_dtype: GT4Py data type, which includes the `offset_type` local dimension for lists. - dimensions: Field domain represented as a sorted list of dimensions, needed - to order the map index variables and dereference an element in the field. + field_domain: Field domain represented as a sorted list of dimensions and offset values, + used to find the position of a map index variable in the memlet subset. 
The offset + value is either the start index of the dimension range or the compile-time value of + a shift expression, or a composition of both, and it must be subtracted from the index + variable when constructing the memlet subset range. indices: Maps each dimension to an index value, which could be either a symbolic value or the result of a tasklet computation like neighbors connectivity or dynamic offset. """ field: dace.nodes.AccessNode gt_dtype: itir_ts.ListType | ts.ScalarType - dimensions: list[gtx_common.Dimension] + field_domain: list[tuple[gtx_common.Dimension, dace.symbolic.SymExpr]] indices: dict[gtx_common.Dimension, DataExpr] + def get_memlet_subset(self, sdfg: dace.SDFG) -> sbs.Range: + if not all(isinstance(self.indices[dim], SymbolExpr) for dim, _ in self.field_domain): + raise ValueError(f"Cannot deref iterator {self}.") + + field_desc = self.field.desc(sdfg) + if isinstance(self.gt_dtype, itir_ts.ListType): + assert len(field_desc.shape) == len(self.field_domain) + 1 + assert self.gt_dtype.offset_type is not None + field_domain = [*self.field_domain, (self.gt_dtype.offset_type, 0)] + else: + assert len(field_desc.shape) == len(self.field_domain) + field_domain = self.field_domain + + return sbs.Range.from_string( + ",".join( + str(self.indices[dim].value - offset) # type: ignore[union-attr] + if dim in self.indices + else f"0:{size}" + for (dim, offset), size in zip(field_domain, field_desc.shape, strict=True) + ) + ) + class DataflowInputEdge(Protocol): """ @@ -271,8 +296,17 @@ def _add_input_data_edge( src_subset: sbs.Range, dst_node: dace.nodes.Node, dst_conn: Optional[str] = None, + src_offset: Optional[list[dace.symbolic.SymExpr]] = None, ) -> None: - edge = MemletInputEdge(self.state, src, src_subset, dst_node, dst_conn) + input_subset = ( + src_subset + if src_offset is None + else sbs.Range( + (start - off, stop - off, step) + for (start, stop, step), off in zip(src_subset, src_offset, strict=True) + ) + ) + edge = MemletInputEdge(self.state, 
src, input_subset, dst_node, dst_conn) self.input_edges.append(edge) def _add_edge( @@ -440,34 +474,21 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr: field_desc = arg_expr.field.desc(self.sdfg) if isinstance(field_desc, dace.data.Scalar): # deref a zero-dimensional field - assert len(arg_expr.dimensions) == 0 + assert len(arg_expr.field_domain) == 0 assert isinstance(node.type, ts.ScalarType) return MemletExpr(arg_expr.field, arg_expr.gt_dtype, subset="0") # default case: deref a field with one or more dimensions if all(isinstance(index, SymbolExpr) for index in arg_expr.indices.values()): - # when all indices are symblic expressions, we can perform direct field access through a memlet - if isinstance(arg_expr.gt_dtype, itir_ts.ListType): - assert len(field_desc.shape) == len(arg_expr.dimensions) + 1 - assert arg_expr.gt_dtype.offset_type is not None - field_dims = [*arg_expr.dimensions, arg_expr.gt_dtype.offset_type] - else: - assert len(field_desc.shape) == len(arg_expr.dimensions) - field_dims = arg_expr.dimensions - - field_subset = sbs.Range( - (arg_expr.indices[dim].value, arg_expr.indices[dim].value, 1) # type: ignore[union-attr] - if dim in arg_expr.indices - else (0, size - 1, 1) - for dim, size in zip(field_dims, field_desc.shape) - ) + # when all indices are symbolic expressions, we can perform direct field access through a memlet + field_subset = arg_expr.get_memlet_subset(self.sdfg) return MemletExpr(arg_expr.field, arg_expr.gt_dtype, field_subset) # we use a tasklet to dereference an iterator when one or more indices are the result of some computation, # either indirection through connectivity table or dynamic cartesian offset. 
- assert all(dim in arg_expr.indices for dim in arg_expr.dimensions) - assert len(field_desc.shape) == len(arg_expr.dimensions) - field_indices = [(dim, arg_expr.indices[dim]) for dim in arg_expr.dimensions] + assert all(dim in arg_expr.indices for dim, _ in arg_expr.field_domain) + assert len(field_desc.shape) == len(arg_expr.field_domain) + field_indices = [(dim, arg_expr.indices[dim]) for dim, _ in arg_expr.field_domain] index_connectors = [ IndexConnectorFmt.format(dim=dim.value) for dim, index in field_indices @@ -494,6 +515,7 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr: sbs.Range.from_array(field_desc), deref_node, "field", + src_offset=[offset for (_, offset) in arg_expr.field_domain], ) for dim, index_expr in field_indices: @@ -532,7 +554,7 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: it = self.visit(node.args[1]) assert isinstance(it, IteratorExpr) - assert offset_provider.codomain in it.dimensions + assert any(dim == offset_provider.codomain for dim, _ in it.field_domain) assert offset_provider.source_dim in it.indices origin_index = it.indices[offset_provider.source_dim] assert isinstance(origin_index, SymbolExpr) @@ -560,10 +582,12 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: gt_dtype=node.type, subset=sbs.Range.from_string( ",".join( - it.indices[dim].value # type: ignore[union-attr] + str(it.indices[dim].value - offset) # type: ignore[union-attr] if dim != offset_provider.codomain else f"0:{size}" - for dim, size in zip(it.dimensions, field_desc.shape, strict=True) + for (dim, offset), size in zip( + it.field_domain, field_desc.shape, strict=True + ) ) ), ) @@ -971,14 +995,13 @@ def _make_cartesian_shift( self, it: IteratorExpr, offset_dim: gtx_common.Dimension, offset_expr: DataExpr ) -> IteratorExpr: """Implements cartesian shift along one dimension.""" - assert offset_dim in it.dimensions + assert any(dim == offset_dim for dim, _ in it.field_domain) new_index: SymbolExpr | ValueExpr - assert 
offset_dim in it.indices index_expr = it.indices[offset_dim] if isinstance(index_expr, SymbolExpr) and isinstance(offset_expr, SymbolExpr): # purely symbolic expression which can be interpreted at compile time new_index = SymbolExpr( - dace.symbolic.pystr_to_symbolic(index_expr.value) + offset_expr.value, + index_expr.value + offset_expr.value, index_expr.dc_dtype, ) else: @@ -1032,15 +1055,10 @@ def _make_cartesian_shift( ) # a new iterator with a shifted index along one dimension - return IteratorExpr( - field=it.field, - gt_dtype=it.gt_dtype, - dimensions=it.dimensions, - indices={ - dim: (new_index if dim == offset_dim else index) - for dim, index in it.indices.items() - }, - ) + shifted_indices = { + dim: (new_index if dim == offset_dim else index) for dim, index in it.indices.items() + } + return IteratorExpr(it.field, it.gt_dtype, it.field_domain, shifted_indices) def _make_dynamic_neighbor_offset( self, @@ -1094,7 +1112,7 @@ def _make_unstructured_shift( offset_expr: DataExpr, ) -> IteratorExpr: """Implements shift in unstructured domain by means of a neighbor table.""" - assert connectivity.codomain in it.dimensions + assert any(dim == connectivity.codomain for dim, _ in it.field_domain) neighbor_dim = connectivity.codomain assert neighbor_dim not in it.indices @@ -1117,9 +1135,7 @@ def _make_unstructured_shift( offset_expr, offset_table_node, origin_index ) - return IteratorExpr( - field=it.field, gt_dtype=it.gt_dtype, dimensions=it.dimensions, indices=shifted_indices - ) + return IteratorExpr(it.field, it.gt_dtype, it.field_domain, shifted_indices) def _visit_shift(self, node: gtir.FunCall) -> IteratorExpr: # convert builtin-index type to dace type diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index 52284edfac..f15287e64c 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ 
b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -16,6 +16,7 @@ import abc import dataclasses +import functools import itertools import operator from typing import Any, Dict, Iterable, List, Optional, Protocol, Sequence, Set, Tuple, Union @@ -98,9 +99,16 @@ def add_mapped_tasklet( class SDFGBuilder(DataflowBuilder, Protocol): """Visitor interface available to GTIR-primitive translators.""" + @abc.abstractmethod + def make_field( + self, data_node: dace.nodes.AccessNode, data_type: ts.FieldType | ts.ScalarType + ) -> gtir_builtin_translators.FieldopData: + """Retrieve the field data descriptor including the domain offset information.""" + ... + @abc.abstractmethod def get_symbol_type(self, symbol_name: str) -> ts.DataType: - """Retrieve the GT4Py type of a symbol used in the program.""" + """Retrieve the GT4Py type of a symbol used in the SDFG.""" ... @abc.abstractmethod @@ -141,6 +149,15 @@ def _collect_symbols_in_domain_expressions( ) +def _get_tuple_type(data: tuple[gtir_builtin_translators.FieldopResult, ...]) -> ts.TupleType: + """ + Compute the `ts.TupleType` corresponding to the structure of a tuple of data nodes. + """ + return ts.TupleType( + types=[_get_tuple_type(d) if isinstance(d, tuple) else d.gt_type for d in data] + ) + + @dataclasses.dataclass(frozen=True) class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): """Provides translation capability from a GTIR program to a DaCe SDFG. 
@@ -157,6 +174,9 @@ class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): offset_provider_type: gtx_common.OffsetProviderType global_symbols: dict[str, ts.DataType] = dataclasses.field(default_factory=lambda: {}) + field_offsets: dict[str, Optional[list[dace.symbolic.SymExpr]]] = dataclasses.field( + default_factory=lambda: {} + ) map_uids: eve.utils.UIDGenerator = dataclasses.field( init=False, repr=False, default_factory=lambda: eve.utils.UIDGenerator(prefix="map") ) @@ -167,6 +187,15 @@ class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): def get_offset_provider_type(self, offset: str) -> gtx_common.OffsetProviderTypeElem: return self.offset_provider_type[offset] + def make_field( + self, data_node: dace.nodes.AccessNode, data_type: ts.FieldType | ts.ScalarType + ) -> gtir_builtin_translators.FieldopData: + if isinstance(data_type, ts.FieldType): + domain_offset = self.field_offsets.get(data_node.data, None) + else: + domain_offset = None + return gtir_builtin_translators.FieldopData(data_node, data_type, domain_offset) + def get_symbol_type(self, symbol_name: str) -> ts.DataType: return self.global_symbols[symbol_name] @@ -248,12 +277,10 @@ def _add_storage( """ if isinstance(gt_type, ts.TupleType): tuple_fields = [] - for tname, tsymbol_type in dace_gtir_utils.get_tuple_fields( - name, gt_type, flatten=True - ): + for tname, ttype in dace_gtir_utils.get_tuple_fields(name, gt_type, flatten=True): tuple_fields.extend( self._add_storage( - sdfg, symbolic_arguments, tname, tsymbol_type, transient, tuple_name=name + sdfg, symbolic_arguments, tname, ttype, transient, tuple_name=name ) ) return tuple_fields @@ -275,7 +302,6 @@ def _add_storage( tuple_name, gt_type.dims ) sdfg.add_array(name, sym_shape, dc_dtype, strides=sym_strides, transient=transient) - return [(name, gt_type)] elif isinstance(gt_type, ts.ScalarType): @@ -344,7 +370,7 @@ def make_temps( head_state.add_nedge( field.dc_node, temp_node, sdfg.make_array_memlet(field.dc_node.data) ) - return 
gtir_builtin_translators.FieldopData(temp_node, field.gt_type) + return field.make_copy(temp_node) temp_result = gtx_utils.tree_map(make_temps)(result) return list(gtx_utils.flatten_nested_tuple((temp_result,))) @@ -405,6 +431,10 @@ def visit_Program(self, node: gtir.Program) -> dace.SDFG: if node.function_definitions: raise NotImplementedError("Functions expected to be inlined as lambda calls.") + # Since program field arguments are passed to the SDFG as full-shape arrays, + # there is no offset that needs to be compensated. + assert len(self.field_offsets) == 0 + sdfg = dace.SDFG(node.id) sdfg.debuginfo = dace_utils.debug_info(node, default=sdfg.debuginfo) @@ -459,7 +489,7 @@ def visit_SetAt( The SDFG head state, eventually updated if the target write requires a new state. """ - temp_fields = self._visit_expression(stmt.expr, sdfg, state) + source_fields = self._visit_expression(stmt.expr, sdfg, state) # the target expression could be a `SymRef` to an output node or a `make_tuple` expression # in case the statement returns more than one field @@ -482,17 +512,26 @@ def visit_SetAt( } target_state: Optional[dace.SDFGState] = None - for temp, target in zip(temp_fields, target_fields, strict=True): + for source, target in zip(source_fields, target_fields, strict=True): target_desc = sdfg.arrays[target.dc_node.data] assert not target_desc.transient if isinstance(target.gt_type, ts.FieldType): - subset = ",".join( + target_subset = ",".join( f"{domain[dim][0]}:{domain[dim][1]}" for dim in target.gt_type.dims ) + source_subset = ( + target_subset + if source.offset is None + else ",".join( + f"{domain[dim][0] - offset}:{domain[dim][1] - offset}" + for dim, offset in zip(target.gt_type.dims, source.offset, strict=True) + ) + ) else: assert len(domain) == 0 - subset = "0" + target_subset = "0" + source_subset = "0" if target.dc_node.data in state_input_data: # if inout argument, write the result in separate next state @@ -501,17 +540,21 @@ def visit_SetAt( target_state = 
sdfg.add_state_after(state, f"post_{state.label}") # create new access nodes in the target state target_state.add_nedge( - target_state.add_access(temp.dc_node.data), + target_state.add_access(source.dc_node.data), target_state.add_access(target.dc_node.data), - dace.Memlet(data=target.dc_node.data, subset=subset, other_subset=subset), + dace.Memlet( + data=target.dc_node.data, subset=target_subset, other_subset=source_subset + ), ) # remove isolated access node state.remove_node(target.dc_node) else: state.add_nedge( - temp.dc_node, + source.dc_node, target.dc_node, - dace.Memlet(data=target.dc_node.data, subset=subset, other_subset=subset), + dace.Memlet( + data=target.dc_node.data, subset=target_subset, other_subset=source_subset + ), ) return target_state or state @@ -574,17 +617,65 @@ def visit_Lambda( (str(param.id), arg) for param, arg in zip(node.params, args, strict=True) ] + def flatten_tuples( + name: str, + arg: gtir_builtin_translators.FieldopResult, + ) -> list[tuple[str, gtir_builtin_translators.FieldopData]]: + if isinstance(arg, tuple): + tuple_type = _get_tuple_type(arg) + tuple_field_names = [ + arg_name for arg_name, _ in dace_gtir_utils.get_tuple_fields(name, tuple_type) + ] + tuple_args = zip(tuple_field_names, arg, strict=True) + return list( + itertools.chain(*[flatten_tuples(fname, farg) for fname, farg in tuple_args]) + ) + else: + return [(name, arg)] + + lambda_arg_nodes = dict( + itertools.chain(*[flatten_tuples(pname, arg) for pname, arg in lambda_args_mapping]) + ) + # inherit symbols from parent scope but eventually override with local symbols lambda_symbols = { sym: self.global_symbols[sym] for sym in symbol_ref_utils.collect_symbol_refs(node.expr, self.global_symbols.keys()) } | { - pname: dace_gtir_utils.get_tuple_type(arg) if isinstance(arg, tuple) else arg.gt_type + pname: _get_tuple_type(arg) if isinstance(arg, tuple) else arg.gt_type for pname, arg in lambda_args_mapping } + def get_field_domain_offset( + p_name: str, p_type: 
ts.DataType + ) -> dict[str, Optional[list[dace.symbolic.SymExpr]]]: + if isinstance(p_type, ts.FieldType): + if p_name in lambda_arg_nodes: + arg = lambda_arg_nodes[p_name] + assert isinstance(arg, gtir_builtin_translators.FieldopData) + return {p_name: arg.offset} + elif field_domain_offset := self.field_offsets.get(p_name, None): + return {p_name: field_domain_offset} + elif isinstance(p_type, ts.TupleType): + p_fields = dace_gtir_utils.get_tuple_fields(p_name, p_type, flatten=True) + return functools.reduce( + lambda field_offsets, field: ( + field_offsets | get_field_domain_offset(field[0], field[1]) + ), + p_fields, + {}, + ) + return {} + + # populate mapping from field name to domain offset + lambda_field_offsets: dict[str, Optional[list[dace.symbolic.SymExpr]]] = {} + for p_name, p_type in lambda_symbols.items(): + lambda_field_offsets |= get_field_domain_offset(p_name, p_type) + # lower let-statement lambda node as a nested SDFG - lambda_translator = GTIRToSDFG(self.offset_provider_type, lambda_symbols) + lambda_translator = GTIRToSDFG( + self.offset_provider_type, lambda_symbols, lambda_field_offsets + ) nsdfg = dace.SDFG(name=self.unique_nsdfg_name(sdfg, "lambda")) nstate = nsdfg.add_state("lambda") @@ -603,30 +694,11 @@ def visit_Lambda( head_state=nstate, ) - def _flatten_tuples( - name: str, - arg: gtir_builtin_translators.FieldopResult, - ) -> list[tuple[str, gtir_builtin_translators.FieldopData]]: - if isinstance(arg, tuple): - tuple_type = dace_gtir_utils.get_tuple_type(arg) - tuple_field_names = [ - arg_name for arg_name, _ in dace_gtir_utils.get_tuple_fields(name, tuple_type) - ] - tuple_args = zip(tuple_field_names, arg, strict=True) - return list( - itertools.chain(*[_flatten_tuples(fname, farg) for fname, farg in tuple_args]) - ) - else: - return [(name, arg)] - # Process lambda inputs # # All input arguments are passed as parameters to the nested SDFG, therefore # we they are stored as non-transient array and scalar objects. 
# - lambda_arg_nodes = dict( - itertools.chain(*[_flatten_tuples(pname, arg) for pname, arg in lambda_args_mapping]) - ) connectivity_arrays = { dace_utils.connectivity_identifier(offset) for offset in dace_utils.filter_connectivity_types(self.offset_provider_type) @@ -739,7 +811,7 @@ def construct_output_for_nested_sdfg( head_state.add_edge( nsdfg_node, connector, outer_node, None, sdfg.make_array_memlet(outer) ) - outer_data = gtir_builtin_translators.FieldopData(outer_node, inner_data.gt_type) + outer_data = inner_data.make_copy(outer_node) elif inner_data.dc_node.data in lambda_arg_nodes: # This if branch and the next one handle the non-transient result nodes. # Non-transient nodes are just input nodes that are immediately returned @@ -748,7 +820,7 @@ def construct_output_for_nested_sdfg( outer_data = lambda_arg_nodes[inner_data.dc_node.data] else: outer_node = head_state.add_access(inner_data.dc_node.data) - outer_data = gtir_builtin_translators.FieldopData(outer_node, inner_data.gt_type) + outer_data = inner_data.make_copy(outer_node) # Isolated access node will make validation fail. # Isolated access nodes can be found in the join-state of an if-expression # or in lambda expressions that just construct tuples from input arguments. diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index caec6cd87e..118f0449c8 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -9,7 +9,7 @@ from __future__ import annotations import itertools -from typing import Any, Dict, TypeVar +from typing import Dict, TypeVar import dace @@ -58,15 +58,6 @@ def get_tuple_fields( return fields -def get_tuple_type(data: tuple[Any, ...]) -> ts.TupleType: - """ - Compute the `ts.TupleType` corresponding to the structure of a tuple of data nodes. 
- """ - return ts.TupleType( - types=[get_tuple_type(d) if isinstance(d, tuple) else d.gt_type for d in data] - ) - - def replace_invalid_symbols(sdfg: dace.SDFG, ir: gtir.Program) -> gtir.Program: """ Ensure that all symbols used in the program IR are valid strings (e.g. no unicode-strings). diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index 9c52ea81c3..f5191fbaaa 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -47,7 +47,7 @@ VFTYPE = ts.FieldType(dims=[Vertex], dtype=FLOAT_TYPE) V2E_FTYPE = ts.FieldType(dims=[Vertex, V2EDim], dtype=EFTYPE.dtype) CARTESIAN_OFFSETS = { - "IDim": IDim, + IDim.value: IDim, } SIMPLE_MESH: MeshDescriptor = simple_mesh() SKIP_VALUE_MESH: MeshDescriptor = skip_value_mesh() @@ -735,13 +735,13 @@ def test_gtir_cartesian_shift_left(): # cartesian shift with literal integer offset stencil1_inlined = im.as_fieldop( - im.lambda_("a")(im.plus(im.deref(im.shift("IDim", OFFSET)("a")), DELTA)), + im.lambda_("a")(im.plus(im.deref(im.shift(IDim.value, OFFSET)("a")), DELTA)), domain, )("x") # fieldview flavor of same stencil, in which a temporary field is initialized with the `DELTA` constant value stencil1_fieldview = im.op_as_fieldop("plus", domain)( im.as_fieldop( - im.lambda_("a")(im.deref(im.shift("IDim", OFFSET)("a"))), + im.lambda_("a")(im.deref(im.shift(IDim.value, OFFSET)("a"))), domain, )("x"), im.as_fieldop(im.lambda_()(DELTA), domain)(), @@ -749,13 +749,15 @@ def test_gtir_cartesian_shift_left(): # use dynamic offset retrieved from field stencil2_inlined = im.as_fieldop( - im.lambda_("a", "off")(im.plus(im.deref(im.shift("IDim", im.deref("off"))("a")), DELTA)), + im.lambda_("a", "off")( + 
im.plus(im.deref(im.shift(IDim.value, im.deref("off"))("a")), DELTA) + ), domain, )("x", "x_offset") # fieldview flavor of same stencil stencil2_fieldview = im.op_as_fieldop("plus", domain)( im.as_fieldop( - im.lambda_("a", "off")(im.deref(im.shift("IDim", im.deref("off"))("a"))), + im.lambda_("a", "off")(im.deref(im.shift(IDim.value, im.deref("off"))("a"))), domain, )("x", "x_offset"), im.as_fieldop(im.lambda_()(DELTA), domain)(), @@ -764,14 +766,14 @@ def test_gtir_cartesian_shift_left(): # use the result of an arithmetic field operation as dynamic offset stencil3_inlined = im.as_fieldop( im.lambda_("a", "off")( - im.plus(im.deref(im.shift("IDim", im.plus(im.deref("off"), 0))("a")), DELTA) + im.plus(im.deref(im.shift(IDim.value, im.plus(im.deref("off"), 0))("a")), DELTA) ), domain, )("x", "x_offset") # fieldview flavor of same stencil stencil3_fieldview = im.op_as_fieldop("plus", domain)( im.as_fieldop( - im.lambda_("a", "off")(im.deref(im.shift("IDim", im.deref("off"))("a"))), + im.lambda_("a", "off")(im.deref(im.shift(IDim.value, im.deref("off"))("a"))), domain, )( "x", @@ -828,13 +830,13 @@ def test_gtir_cartesian_shift_right(): # cartesian shift with literal integer offset stencil1_inlined = im.as_fieldop( - im.lambda_("a")(im.plus(im.deref(im.shift("IDim", -OFFSET)("a")), DELTA)), + im.lambda_("a")(im.plus(im.deref(im.shift(IDim.value, -OFFSET)("a")), DELTA)), domain, )("x") # fieldview flavor of same stencil, in which a temporary field is initialized with the `DELTA` constant value stencil1_fieldview = im.op_as_fieldop("plus", domain)( im.as_fieldop( - im.lambda_("a")(im.deref(im.shift("IDim", -OFFSET)("a"))), + im.lambda_("a")(im.deref(im.shift(IDim.value, -OFFSET)("a"))), domain, )("x"), im.as_fieldop(im.lambda_()(DELTA), domain)(), @@ -842,13 +844,15 @@ def test_gtir_cartesian_shift_right(): # use dynamic offset retrieved from field stencil2_inlined = im.as_fieldop( - im.lambda_("a", "off")(im.plus(im.deref(im.shift("IDim", im.deref("off"))("a")), 
DELTA)), + im.lambda_("a", "off")( + im.plus(im.deref(im.shift(IDim.value, im.deref("off"))("a")), DELTA) + ), domain, )("x", "x_offset") # fieldview flavor of same stencil stencil2_fieldview = im.op_as_fieldop("plus", domain)( im.as_fieldop( - im.lambda_("a", "off")(im.deref(im.shift("IDim", im.deref("off"))("a"))), + im.lambda_("a", "off")(im.deref(im.shift(IDim.value, im.deref("off"))("a"))), domain, )("x", "x_offset"), im.as_fieldop(im.lambda_()(DELTA), domain)(), @@ -857,14 +861,14 @@ def test_gtir_cartesian_shift_right(): # use the result of an arithmetic field operation as dynamic offset stencil3_inlined = im.as_fieldop( im.lambda_("a", "off")( - im.plus(im.deref(im.shift("IDim", im.plus(im.deref("off"), 0))("a")), DELTA) + im.plus(im.deref(im.shift(IDim.value, im.plus(im.deref("off"), 0))("a")), DELTA) ), domain, )("x", "x_offset") # fieldview flavor of same stencil stencil3_fieldview = im.op_as_fieldop("plus", domain)( im.as_fieldop( - im.lambda_("a", "off")(im.deref(im.shift("IDim", im.deref("off"))("a"))), + im.lambda_("a", "off")(im.deref(im.shift(IDim.value, im.deref("off"))("a"))), domain, )( "x", @@ -1539,6 +1543,91 @@ def test_gtir_reduce_with_cond_neighbors(): assert np.allclose(v, v_ref) +def test_gtir_symbolic_domain(): + MARGIN = 2 + assert MARGIN < N + OFFSET = 1000 * 1000 * 1000 + domain = im.domain( + gtx_common.GridType.CARTESIAN, ranges={IDim: (MARGIN, im.minus("size", MARGIN))} + ) + left_domain = im.domain( + gtx_common.GridType.CARTESIAN, + ranges={IDim: (im.minus(MARGIN, OFFSET), im.minus(im.minus("size", MARGIN), OFFSET))}, + ) + right_domain = im.domain( + gtx_common.GridType.CARTESIAN, + ranges={IDim: (im.plus(MARGIN, OFFSET), im.plus(im.plus("size", MARGIN), OFFSET))}, + ) + shift_left_stencil = im.lambda_("a")(im.deref(im.shift(IDim.value, OFFSET)("a"))) + shift_right_stencil = im.lambda_("a")(im.deref(im.shift(IDim.value, -OFFSET)("a"))) + testee = gtir.Program( + id="symbolic_domain", + function_definitions=[], + params=[ + 
gtir.Sym(id="x", type=IFTYPE), + gtir.Sym(id="y", type=IFTYPE), + gtir.Sym(id="size", type=SIZE_TYPE), + ], + declarations=[], + body=[ + gtir.SetAt( + expr=im.let( + "xᐞ1", + im.op_as_fieldop("multiplies", left_domain)( + 4.0, + im.as_fieldop( + shift_left_stencil, + left_domain, + )("x"), + ), + )( + im.let( + "xᐞ2", + im.op_as_fieldop("multiplies", right_domain)( + 3.0, + im.as_fieldop( + shift_right_stencil, + right_domain, + )("x"), + ), + )( + im.let( + "xᐞ3", + im.as_fieldop( + shift_right_stencil, + domain, + )("xᐞ1"), + )( + im.let( + "xᐞ4", + im.as_fieldop( + shift_left_stencil, + domain, + )("xᐞ2"), + )( + im.let("xᐞ5", im.op_as_fieldop("plus", domain)("xᐞ3", "xᐞ4"))( + im.op_as_fieldop("plus", domain)("xᐞ5", "x") + ) + ) + ) + ) + ), + domain=domain, + target=gtir.SymRef(id="y"), + ) + ], + ) + + a = np.random.rand(N) + b = np.random.rand(N) + ref = np.concatenate((b[0:MARGIN], a[MARGIN : N - MARGIN] * 8, b[N - MARGIN : N])) + + sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + + sdfg(a, b, **FSYMBOLS) + assert np.allclose(b, ref) + + def test_gtir_let_lambda(): domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (0, "size")}) subdomain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (1, im.minus("size", 1))}) @@ -1722,7 +1811,7 @@ def test_gtir_let_lambda_with_cond(): def test_gtir_let_lambda_with_tuple1(): - domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (0, "size")}) + domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (1, im.minus("size", 1))}) testee = gtir.Program( id="let_lambda_with_tuple1", function_definitions=[], @@ -1753,10 +1842,12 @@ def test_gtir_let_lambda_with_tuple1(): sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) z_fields = (np.empty_like(a), np.empty_like(a)) + a_ref = np.concatenate((z_fields[0][:1], a[1 : N - 1], z_fields[0][N - 1 :])) + b_ref = np.concatenate((z_fields[1][:1], b[1 : N - 1], z_fields[1][N - 1 :])) sdfg(a, b, *z_fields, 
**FSYMBOLS) - assert np.allclose(z_fields[0], a) - assert np.allclose(z_fields[1], b) + assert np.allclose(z_fields[0], a_ref) + assert np.allclose(z_fields[1], b_ref) def test_gtir_let_lambda_with_tuple2():