diff --git a/src/gt4py/next/otf/binding/nanobind.py b/src/gt4py/next/otf/binding/nanobind.py index edd56fad48..3abf49788f 100644 --- a/src/gt4py/next/otf/binding/nanobind.py +++ b/src/gt4py/next/otf/binding/nanobind.py @@ -86,6 +86,7 @@ def _type_string(type_: ts.TypeSpec) -> str: return f"std::tuple<{','.join(_type_string(t) for t in type_.types)}>" elif isinstance(type_, ts.FieldType): ndims = len(type_.dims) + # cannot be ListType: the concept is represented as Field with local Dimension in this interface assert isinstance(type_.dtype, ts.ScalarType) dtype = cpp_interface.render_scalar_type(type_.dtype) shape = f"nanobind::shape<{', '.join(['gridtools::nanobind::dynamic_size'] * ndims)}>" diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index 020b1f55ea..48f15acffb 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -135,8 +135,9 @@ def _process_connectivity_args( # connectivity argument expression nbtbl = ( f"gridtools::fn::sid_neighbor_table::as_neighbor_table<" - f"generated::{connectivity_type.source_dim.value}_t, " - f"generated::{name}_t, {connectivity_type.max_neighbors}" + f"generated::{connectivity_type.domain[0].value}_t, " + f"generated::{connectivity_type.domain[1].value}_t, " + f"{connectivity_type.max_neighbors}" f">(std::forward({GENERATED_CONNECTIVITY_PARAM_PREFIX}{name.lower()}))" ) arg_exprs.append( diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index f7bb1805e0..3dc7998a54 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -209,6 +209,10 @@ def _collect_offset_definitions( ): assert grid_type == common.GridType.UNSTRUCTURED offset_definitions[offset_name] = TagDefinition(name=Sym(id=offset_name)) + if offset_name != connectivity_type.neighbor_dim.value: + offset_definitions[connectivity_type.neighbor_dim.value] = TagDefinition( + name=Sym(id=connectivity_type.neighbor_dim.value) + ) for dim in [connectivity_type.source_dim, connectivity_type.codomain]: if dim.kind != common.DimensionKind.HORIZONTAL: diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index 10895ce66e..baddb7b699 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -290,6 +290,7 @@ def _add_storage( # represent zero-dimensional fields as scalar arguments return self._add_storage(sdfg, symbolic_arguments, name, gt_type.dtype, transient) # handle default case: field with one or more dimensions + # ListType not supported: concept is represented as Field with local Dimension assert isinstance(gt_type.dtype, ts.ScalarType) dc_dtype = dace_utils.as_dace_type(gt_type.dtype) if tuple_name is None: diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py index 060d56aea2..c1c0f0b5e1 100644 --- a/src/gt4py/next/type_system/type_specifications.py +++ b/src/gt4py/next/type_system/type_specifications.py @@ -82,7 +82,7 @@ def __str__(self) -> str: class ListType(DataType): """Represents a neighbor list in the ITIR representation. - Note: not used in the frontend. + Note: not used in the frontend. The concept is represented as Field with local Dimension. """ element_type: DataType diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 759cd1cf1f..8a78307f87 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -28,6 +28,7 @@ common, constructors, field_utils, + utils as gt_utils, ) from gt4py.next.ffront import decorator from gt4py.next.type_system import type_specifications as ts, type_translation @@ -55,7 +56,6 @@ mesh_descriptor, ) -from gt4py.next import utils as gt_utils # mypy does not accept [IDim, ...] as a type diff --git a/tests/next_tests/regression_tests/ffront_tests/test_offset_dimensions_names.py b/tests/next_tests/regression_tests/ffront_tests/test_offset_dimensions_names.py new file mode 100644 index 0000000000..f95ed4c3a7 --- /dev/null +++ b/tests/next_tests/regression_tests/ffront_tests/test_offset_dimensions_names.py @@ -0,0 +1,63 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest + +from gt4py import next as gtx +from gt4py.next import Dims, Field, common + +from next_tests import definitions as test_defs +from next_tests.integration_tests import cases +from next_tests.integration_tests.feature_tests.ffront_tests import ffront_test_utils + + +V = gtx.Dimension("V") +E = gtx.Dimension("E") +Neigh = gtx.Dimension("Neigh", kind=common.DimensionKind.LOCAL) +Off = gtx.FieldOffset("Off", source=E, target=(V, Neigh)) + + +@pytest.fixture +def case(): + mesh = ffront_test_utils.simple_mesh() + exec_alloc_descriptor = test_defs.ProgramBackendId.GTFN_CPU.load() + v2e_arr = mesh.offset_provider["V2E"].ndarray + return cases.Case( + exec_alloc_descriptor, + offset_provider={ + "Off": common._connectivity( + v2e_arr, + codomain=E, + domain={V: v2e_arr.shape[0], Neigh: 4}, + skip_value=None, + ), + }, + default_sizes={ + V: mesh.num_vertices, + E: mesh.num_edges, + }, + grid_type=common.GridType.UNSTRUCTURED, + allocator=exec_alloc_descriptor.allocator, + ) + + +def test_offset_dimension_name_differ(case): + """ + Ensure that gtfn works with offset name that differs from the name of the local dimension. + + If the value of the `NeighborConnectivityType.neighbor_dim` did not match the `FieldOffset` value, + gtfn would silently ignore the neighbor index, see https://github.com/GridTools/gridtools/pull/1814. + """ + + @gtx.field_operator + def foo(a: Field[Dims[E], float]) -> Field[Dims[V], float]: + return a(Off[1]) + + cases.verify_with_default_data( + case, foo, lambda a: a[case.offset_provider["Off"].ndarray[:, 1]] + )