Skip to content

Commit

Permalink
skip dace orchestration tests
Browse files Browse the repository at this point in the history
  • Loading branch information
edopao committed Jan 16, 2025
1 parent ad68fac commit a28fbf3
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@


# regex to match the symbols for field shape and strides
FIELD_SYMBOL_RE: Final[re.Pattern] = re.compile(r"^__.+_(range_[01]|((size|stride)_\d+))$")
FIELD_SYMBOL_RE: Final[re.Pattern] = re.compile(r"^__.+_((\d+_range_[01])|((size|stride)_\d+))$")


def as_dace_type(type_: ts.ScalarType) -> dace.typeclass:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,16 @@
from gt4py.next.program_processors.runners.dace_common import dace_backend, utility as dace_utils


class CompiledDaceProgram(stages.CompiledProgram):
class CompiledDaceProgram(stages.ExtendedCompiledProgram):
sdfg_program: dace.CompiledSDFG

# Sorted list of SDFG arguments as they appear in program ABI and corresponding data type;
# scalar arguments that are not used in the SDFG will not be present.
sdfg_arglist: list[tuple[str, dace.dtypes.Data]]

def __init__(self, program: dace.CompiledSDFG):
def __init__(self, program: dace.CompiledSDFG, implicit_domain: bool):
self.sdfg_program = program
self.implicit_domain = implicit_domain
# `dace.CompiledSDFG.arglist()` returns an ordered dictionary that maps the argument
# name to its data type, in the same order as arguments appear in the program ABI.
# This is also the same order of arguments in `dace.CompiledSDFG._lastargs[0]`.
Expand Down Expand Up @@ -88,7 +89,7 @@ def __call__(
dace.config.Config.set("compiler", "cpu", "args", value=compiler_args)
sdfg_program = sdfg.compile(validate=False)

return CompiledDaceProgram(sdfg_program)
return CompiledDaceProgram(sdfg_program, inp.program_source.implicit_domain)


class DaCeCompilationStepFactory(factory.Factory):
Expand All @@ -113,9 +114,11 @@ def decorated_program(
if out is not None:
args = (*args, out)
flat_args: Sequence[Any] = gtx_utils.flatten_nested_tuple(tuple(args))
if len(sdfg.arg_names) > len(flat_args):
# The Ahead-of-Time (AOT) workflow for FieldView programs requires domain size arguments.
flat_args = (*flat_args, *arguments.iter_size_args(args))
if inp.implicit_domain:
# generate implicit domain size arguments only if necessary
size_args = arguments.iter_size_args(args)
flat_size_args: Sequence[int] = gtx_utils.flatten_nested_tuple(tuple(size_args))
flat_args = (*flat_args, *flat_size_args)

if sdfg_program._lastargs:
kwargs = dict(zip(sdfg.arg_names, flat_args, strict=True))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ def test_sdfgConvertible_laplap(cartesian_case): # noqa: F811
if not cartesian_case.backend or "dace" not in cartesian_case.backend.name:
pytest.skip("DaCe-related test: Test SDFGConvertible interface for GT4Py programs")

# TODO(edopao): add support for range symbols in field domain and re-enable this test
pytest.skip("Requires support for field domain range.")

backend = cartesian_case.backend

in_field = cases.allocate(cartesian_case, laplap_program, "in_field")()
Expand Down Expand Up @@ -87,6 +90,9 @@ def test_sdfgConvertible_connectivities(unstructured_case): # noqa: F811
if not unstructured_case.backend or "dace" not in unstructured_case.backend.name:
pytest.skip("DaCe-related test: Test SDFGConvertible interface for GT4Py programs")

# TODO(edopao): add support for range symbols in field domain and re-enable this test
pytest.skip("Requires support for field domain range.")

allocator, backend = unstructured_case.allocator, unstructured_case.backend

if gtx_allocators.is_field_allocator_for(allocator, gtx_allocators.CUPY_DEVICE):
Expand Down

0 comments on commit a28fbf3

Please sign in to comment.