From 3c463a62a98ef44cd47b52d9752fcb06f2066c49 Mon Sep 17 00:00:00 2001 From: edopao Date: Tue, 31 Oct 2023 10:53:59 +0100 Subject: [PATCH] fix[next]: Improvements in DaCe backend (#1354) This PR contains some fixes and code refactoring in DaCe backend: * (refactoring) Use memlet API for full array subset * Fix for gpu execution: import cupy for sorting of field dimensions. * Fix for symbolic analysis of memlet volume: define symbols before visiting the closure domain in order to allow symbolic analysis of memlet volume --- .../runners/dace_iterator/__init__.py | 37 +++++++++++-------- .../runners/dace_iterator/itir_to_sdfg.py | 6 +-- .../runners/dace_iterator/itir_to_tasklet.py | 33 ++++++++--------- .../runners/dace_iterator/utility.py | 6 +-- .../ffront_tests/test_gpu_backend.py | 4 +- 5 files changed, 44 insertions(+), 42 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 1c1bed9c5e..be63d6809d 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -31,6 +31,12 @@ from .utility import connectivity_identifier, filter_neighbor_tables, get_sorted_dims +try: + import cupy as cp +except ImportError: + cp = None + + def get_sorted_dim_ranges(domain: Domain) -> Sequence[UnitRange]: sorted_dims = get_sorted_dims(domain.dims) return [domain.ranges[dim_index] for dim_index, _ in sorted_dims] @@ -49,8 +55,11 @@ def convert_arg(arg: Any): sorted_dims = get_sorted_dims(arg.domain.dims) ndim = len(sorted_dims) dim_indices = [dim_index for dim_index, _ in sorted_dims] - assert isinstance(arg.ndarray, np.ndarray) - return np.moveaxis(arg.ndarray, range(ndim), dim_indices) + if isinstance(arg.ndarray, np.ndarray): + return np.moveaxis(arg.ndarray, range(ndim), dim_indices) + else: + assert cp is not None and isinstance(arg.ndarray, cp.ndarray) + return cp.moveaxis(arg.ndarray, range(ndim), dim_indices) return arg @@ -226,24 +235,22 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None: @program_executor -def run_dace_cpu(program: itir.FencilDefinition, *args, **kwargs) -> None: - run_dace_iterator( - program, - *args, - **kwargs, - build_cache=_build_cache_cpu, - build_type=_build_type, - run_on_gpu=False, - ) +def run_dace(program: itir.FencilDefinition, *args, **kwargs) -> None: + run_on_gpu = any(not isinstance(arg.ndarray, np.ndarray) for arg in args if is_field(arg)) + if run_on_gpu: + if cp is None: + raise RuntimeError( + f"Non-numpy field argument passed to program {program.id} but module cupy not installed" + ) + if not all(isinstance(arg.ndarray, cp.ndarray) for arg in args if is_field(arg)): + raise RuntimeError("Execution on GPU requires all fields to be stored as cupy arrays") -@program_executor -def run_dace_gpu(program: itir.FencilDefinition, *args, **kwargs) -> None: run_dace_iterator( program, *args, **kwargs, - build_cache=_build_cache_gpu, + build_cache=_build_cache_gpu if run_on_gpu else _build_cache_cpu, build_type=_build_type, - run_on_gpu=True, + run_on_gpu=run_on_gpu, ) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 1f9692356e..9e9cc4bf29 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -258,9 +258,9 @@ def visit_StencilClosure( # Update symbol table and get output domain of the closure for name, type_ in self.storage_types.items(): if isinstance(type_, ts.ScalarType): + dtype = as_dace_type(type_) + closure_sdfg.add_symbol(name, dtype) if name in input_names: - dtype = as_dace_type(type_) - closure_sdfg.add_symbol(name, dtype) out_name = unique_var_name() closure_sdfg.add_scalar(out_name, dtype, transient=True) out_tasklet = closure_init_state.add_tasklet( @@ -272,7 +272,7 @@ def visit_StencilClosure( closure_init_state.add_edge(out_tasklet, "__result", access, None, memlet) program_arg_syms[name] = value else: - program_arg_syms[name] = SymbolExpr(name, as_dace_type(type_)) + program_arg_syms[name] = SymbolExpr(name, dtype) closure_domain = self._visit_domain(node.domain, closure_ctx) # Map SDFG tasklet arguments to parameters diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 1634596afa..5d47cad909 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -34,6 +34,7 @@ add_mapped_nested_sdfg, as_dace_type, connectivity_identifier, + create_memlet_at, create_memlet_full, filter_neighbor_tables, flatten_list, @@ -199,7 +200,6 @@ def builtin_neighbors( result_access = state.add_access(result_name) table_name = connectivity_identifier(offset_dim) - table_array = sdfg.arrays[table_name] # generate unique map index name to avoid conflict with other maps inside same state index_name = unique_name("__neigh_idx") @@ -225,14 +225,14 @@ def builtin_neighbors( state.add_access(table_name), me, shift_tasklet, - memlet=dace.Memlet(data=table_name, subset=",".join(f"0:{s}" for s in table_array.shape)), + memlet=create_memlet_full(table_name, sdfg.arrays[table_name]), dst_conn="__table", ) state.add_memlet_path( iterator.indices[shifted_dim], me, shift_tasklet, - memlet=dace.Memlet(data=iterator.indices[shifted_dim].data, subset="0"), + memlet=dace.Memlet.simple(iterator.indices[shifted_dim].data, "0"), dst_conn="__idx", ) state.add_edge( @@ -240,28 +240,25 @@ def builtin_neighbors( "__result", data_access_tasklet, "__idx", - dace.Memlet(data=idx_name, subset="0"), + dace.Memlet.simple(idx_name, "0"), ) # select full shape only in the neighbor-axis dimension - field_subset = [ + field_subset = tuple( f"0:{shape}" if dim == table.neighbor_axis.value else f"i_{dim}" for dim, shape in zip(sorted(iterator.dimensions), sdfg.arrays[iterator.field.data].shape) - ] + ) state.add_memlet_path( iterator.field, me, data_access_tasklet, - memlet=dace.Memlet( - data=iterator.field.data, - subset=",".join(field_subset), - ), + memlet=create_memlet_at(iterator.field.data, field_subset), dst_conn="__field", ) state.add_memlet_path( data_access_tasklet, mx, result_access, - memlet=dace.Memlet(data=result_name, subset=index_name), + memlet=dace.Memlet.simple(result_name, index_name), src_conn="__result", ) @@ -438,7 +435,7 @@ def visit_Lambda( result_access, None, # in case of reduction lambda, the output edge from lambda tasklet performs write-conflict resolution - dace.Memlet(f"{result_access.data}[0]", wcr=context.reduce_wcr), + dace.Memlet.simple(result_access.data, "0", wcr_str=context.reduce_wcr), ) result = ValueExpr(value=result_access, dtype=expr.dtype) else: @@ -616,7 +613,7 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: deref_tasklet, mx, result_access, - memlet=dace.Memlet(data=result_name, subset=index_name), + memlet=dace.Memlet.simple(result_name, index_name), src_conn="__result", ) @@ -738,13 +735,13 @@ def _visit_reduce(self, node: itir.FunCall): assert isinstance(op_name, itir.SymRef) init = node.fun.args[1] - nreduce = self.context.body.arrays[neighbors_expr.value.data].shape[0] + reduce_array_desc = neighbors_expr.value.desc(self.context.body) self.context.body.add_scalar(result_name, result_dtype, transient=True) op_str = _MATH_BUILTINS_MAPPING[str(op_name)].format("__result", "__values[__idx]") reduce_tasklet = self.context.state.add_tasklet( "reduce", - code=f"__result = {init}\nfor __idx in range({nreduce}):\n __result = {op_str}", + code=f"__result = {init}\nfor __idx in range({reduce_array_desc.shape[0]}):\n __result = {op_str}", inputs={"__values"}, outputs={"__result"}, ) @@ -753,14 +750,14 @@ def _visit_reduce(self, node: itir.FunCall): None, reduce_tasklet, "__values", - dace.Memlet(data=neighbors_expr.value.data, subset=f"0:{nreduce}"), + create_memlet_full(neighbors_expr.value.data, reduce_array_desc), ) self.context.state.add_edge( reduce_tasklet, "__result", result_access, None, - dace.Memlet(data=result_name, subset="0"), + dace.Memlet.simple(result_name, "0"), ) else: assert isinstance(node.fun, itir.FunCall) @@ -973,7 +970,7 @@ def closure_to_tasklet_sdfg( tasklet = state.add_tasklet(f"get_{dim}", set(), {"value"}, f"value = {idx}") access = state.add_access(name) idx_accesses[dim] = access - state.add_edge(tasklet, "value", access, None, dace.Memlet(data=name, subset="0")) + state.add_edge(tasklet, "value", access, None, dace.Memlet.simple(name, "0")) for name, ty in inputs: if isinstance(ty, ts.FieldType): ndim = len(ty.dims) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py index 1fdd022a49..c17a39ef2d 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py @@ -48,14 +48,12 @@ def connectivity_identifier(name: str): def create_memlet_full(source_identifier: str, source_array: dace.data.Array): - bounds = [(0, size) for size in source_array.shape] - subset = ", ".join(f"{lb}:{ub}" for lb, ub in bounds) - return dace.Memlet.simple(source_identifier, subset) + return dace.Memlet.from_array(source_identifier, source_array) def create_memlet_at(source_identifier: str, index: tuple[str, ...]): subset = ", ".join(index) - return dace.Memlet(data=source_identifier, subset=subset) + return dace.Memlet.simple(source_identifier, subset) def get_sorted_dims(dims: Sequence[Dimension]) -> Sequence[tuple[int, Dimension]]: diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gpu_backend.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gpu_backend.py index 290cece3fa..381cc740c5 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gpu_backend.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gpu_backend.py @@ -16,7 +16,7 @@ import gt4py.next as gtx from gt4py.next.iterator import embedded -from gt4py.next.program_processors.runners import gtfn +from gt4py.next.program_processors.runners import dace_iterator, gtfn from next_tests.integration_tests import cases from next_tests.integration_tests.cases import cartesian_case # noqa: F401 @@ -26,7 +26,7 @@ @pytest.mark.requires_gpu -@pytest.mark.parametrize("fieldview_backend", [gtfn.run_gtfn_gpu]) +@pytest.mark.parametrize("fieldview_backend", [dace_iterator.run_dace, gtfn.run_gtfn_gpu]) def test_copy(cartesian_case, fieldview_backend): # noqa: F811 # fixtures import cupy as cp # TODO(ricoh): replace with storages solution when available