Skip to content

Commit

Permalink
bug[next]: Fix codegen in gtfn for unused vertical offset provider (#…
Browse files Browse the repository at this point in the history
…1746)

Providing an offest provider for a vertical dimension without using that
dimension in a program, e.g. no arguments are fields defined on K,
resulted in erroneous C++ code.
  • Loading branch information
nfarabullini authored Dec 3, 2024
1 parent f57d6e9 commit e5abcd2
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,9 @@ def _collect_offset_definitions(
"Mapping an offset to a horizontal dimension in unstructured is not allowed."
)
# create alias from vertical offset to vertical dimension
offset_definitions[dim.value] = TagDefinition(
name=Sym(id=dim.value), alias=_vertical_dimension
)
offset_definitions[offset_name] = TagDefinition(
name=Sym(id=offset_name), alias=SymRef(id=dim.value)
)
Expand Down
10 changes: 9 additions & 1 deletion tests/next_tests/integration_tests/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,13 +499,21 @@ def unstructured_case(
Vertex: mesh_descriptor.num_vertices,
Edge: mesh_descriptor.num_edges,
Cell: mesh_descriptor.num_cells,
KDim: 10,
},
grid_type=common.GridType.UNSTRUCTURED,
allocator=exec_alloc_descriptor.allocator,
)


@pytest.fixture
def unstructured_case_3d(unstructured_case):
return dataclasses.replace(
unstructured_case,
default_sizes={**unstructured_case.default_sizes, KDim: 10},
offset_provider={**unstructured_case.offset_provider, "KOff": KDim},
)


def _allocate_from_type(
case: Case,
arg_type: ts.TypeSpec,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
Edge,
cartesian_case,
unstructured_case,
unstructured_case_3d,
)
from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import (
exec_alloc_descriptor,
Expand Down Expand Up @@ -93,6 +94,20 @@ def testee(a: cases.VField) -> cases.EField:
)


def test_horizontal_only_with_3d_mesh(unstructured_case_3d):
# test field operator operating only on horizontal fields while using an offset provider
# including a vertical dimension.
@gtx.field_operator
def testee(a: cases.VField) -> cases.VField:
return a

cases.verify_with_default_data(
unstructured_case_3d,
testee,
ref=lambda a: a,
)


@pytest.mark.uses_unstructured_shift
def test_composed_unstructured_shift(unstructured_case):
@gtx.field_operator
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
Vertex,
cartesian_case,
unstructured_case,
unstructured_case_3d,
)
from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import (
exec_alloc_descriptor,
Expand Down Expand Up @@ -105,10 +106,10 @@ def reduction_ke_field(
@pytest.mark.parametrize(
"fop", [reduction_e_field, reduction_ek_field, reduction_ke_field], ids=lambda fop: fop.__name__
)
def test_neighbor_sum(unstructured_case, fop):
v2e_table = unstructured_case.offset_provider["V2E"].ndarray
def test_neighbor_sum(unstructured_case_3d, fop):
v2e_table = unstructured_case_3d.offset_provider["V2E"].ndarray

edge_f = cases.allocate(unstructured_case, fop, "edge_f")()
edge_f = cases.allocate(unstructured_case_3d, fop, "edge_f")()

local_dim_idx = edge_f.domain.dims.index(Edge) + 1
adv_indexing = tuple(
Expand All @@ -131,10 +132,10 @@ def test_neighbor_sum(unstructured_case, fop):
where=broadcasted_table != common._DEFAULT_SKIP_VALUE,
)
cases.verify(
unstructured_case,
unstructured_case_3d,
fop,
edge_f,
out=cases.allocate(unstructured_case, fop, cases.RETURN)(),
out=cases.allocate(unstructured_case_3d, fop, cases.RETURN)(),
ref=ref,
)

Expand Down Expand Up @@ -463,11 +464,13 @@ def conditional_program(
)


def test_promotion(unstructured_case):
def test_promotion(unstructured_case_3d):
@gtx.field_operator
def promotion(
inp1: gtx.Field[[Edge, KDim], float64], inp2: gtx.Field[[KDim], float64]
) -> gtx.Field[[Edge, KDim], float64]:
return inp1 / inp2

cases.verify_with_default_data(unstructured_case, promotion, ref=lambda inp1, inp2: inp1 / inp2)
cases.verify_with_default_data(
unstructured_case_3d, promotion, ref=lambda inp1, inp2: inp1 / inp2
)

0 comments on commit e5abcd2

Please sign in to comment.