From 9e1e72139f55e0c9dda75fd32e318abfa88ce89e Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 16 Jan 2025 15:01:56 +0100 Subject: [PATCH] This commit fixes the stride reconstruction during stride propagation. NestedSDFG essentially allows to perform some slices, there are technically three chases: - The data container on the inside has a smaller rank than the one on the outside, thus some dimensions were removed. - The data container on the inside has the same rank than the one on the outside. - The data container on the inside has a larger rank than the one on the outside, thus some dimensions were added. The last case is not handled, as it does not happens in GT4Py. Before, the first and second case were handled together, but it was realized that the second case was not implemented properly and it was added explicitly. --- .../dace_fieldview/transformations/strides.py | 49 ++++++---- .../transformation_tests/test_strides.py | 96 +++++++++++++++++++ 2 files changed, 126 insertions(+), 19 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index 980b2a8fdf..d1bf8fe266 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -498,31 +498,42 @@ def _gt_map_strides_into_nested_sdfg( inner_shape = inner_desc.shape inner_strides_init = inner_desc.strides + outer_shape = outer_desc.shape outer_strides = outer_desc.strides outer_inflow = outer_subset.size() - new_strides: list = [] - for dim_ostride, dim_oinflow in zip(outer_strides, outer_inflow, strict=True): - if dim_oinflow == 1: - # This is the case of implicit slicing along one dimension. - pass - else: - # There is inflow into the SDFG, so we need the stride. - new_strides.append(dim_ostride) - assert len(new_strides) <= len(inner_shape) - - # If we have a scalar on the inside, then there is nothing to adjust. - # We could have performed the test above, but doing it here, gives us - # the chance of validating it. if isinstance(inner_desc, dace_data.Scalar): - if len(new_strides) != 0: - raise ValueError(f"Dimensional error for '{inner_data}' in '{nsdfg_node.label}'.") + # A scalar does not have a stride that must be propagated. return - if not isinstance(inner_desc, dace_data.Array): - raise TypeError( - f"Expected that '{inner_data}' is an 'Array' but it is '{type(inner_desc).__name__}'." - ) + # Now determine the new stride that is needed on the inside. + new_strides: list = [] + if len(outer_shape) == len(inner_shape): + # The inner and the outer descriptor have the same dimensionality. + # We now have to decide if we should take the stride from the outside, + # which happens for example in case of `A[0:N, 0:M] -> B[N, M]`, or if we + # must take 1, which happens if we do `A[0:N, i] -> B[N, 1]`, we detect that + # based on the volume that flows in. + for dim_ostride, dim_oinflow in zip(outer_strides, outer_inflow, strict=True): + new_strides.append(1 if dim_oinflow == 1 else dim_ostride) + + elif len(inner_shape) < len(outer_shape): + # There are less dimensions on the inside than on the outside. This means + # that some were sliced away. We detect this case by checking if the Memlet + # subset in that dimension has size 1. + # NOTE: That this is not always correct as it might be possible that there + # are some explicit size 1 dimensions at several places. + new_strides = [] + for dim_ostride, dim_oinflow in zip(outer_strides, outer_inflow, strict=True): + if dim_oinflow == 1: + pass + else: + new_strides.append(dim_ostride) + assert len(new_strides) <= len(inner_shape) + else: + # The case that we have more dimensions on the inside than on the outside. + # This is currently not supported. + raise NotImplementedError("NestedSDFGs can not be used to increase the rank.") if len(new_strides) != len(inner_shape): raise ValueError("Failed to compute the inner strides.") diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py index 5b16e41bc3..19b33d0bef 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py @@ -539,3 +539,99 @@ def ref(a1, b1): ref(**ref_args) sdfg_level1(**res_args) assert np.allclose(ref_args["b1"], res_args["b1"]) + + +def _make_strides_propagation_stride_1_nsdfg() -> dace.SDFG: + sdfg_level1 = dace.SDFG(util.unique_name("strides_propagation_stride_1_nsdfg")) + state = sdfg_level1.add_state(is_start_block=True) + + a_stride_sym = dace.symbol("a_stride", dtype=dace.uint64) + b_stride_sym = dace.symbol("b_stride", dtype=dace.uint64) + stride_syms = {"a": a_stride_sym, "b": b_stride_sym} + + for name in ["a", "b"]: + sdfg_level1.add_array( + name, + shape=(10, 1), + strides=(stride_syms[name], 1), + dtype=dace.float64, + transient=False, + ) + + state.add_mapped_tasklet( + "computation", + map_ranges={"__i": "0:10"}, + inputs={"__in": dace.Memlet("a[__i, 0]")}, + code="__out = __in + 10", + outputs={"__out": dace.Memlet("b[__i, 0]")}, + external_edges=True, + ) + sdfg_level1.validate() + return sdfg_level1 + + +def _make_strides_propagation_stride_1_sdfg() -> tuple[dace.SDFG, dace_nodes.NestedSDFG]: + sdfg = dace.SDFG(util.unique_name("strides_propagation_stride_1_sdfg")) + state = sdfg.add_state(is_start_block=True) + + a_stride_sym = dace.symbol("a_stride", dtype=dace.uint64) + b_stride_sym = dace.symbol("b_stride", dtype=dace.uint64) + stride_syms = {"a": a_stride_sym, "b": b_stride_sym} + + for name in ["a", "b"]: + sdfg.add_array( + name, + shape=(10, 10), + strides=(stride_syms[name], 1), + dtype=dace.float64, + transient=False, + ) + + # Now get the nested SDFG. + sdfg_level1 = _make_strides_propagation_stride_1_nsdfg() + + nsdfg = state.add_nested_sdfg( + parent=sdfg, + sdfg=sdfg_level1, + inputs={"a"}, + outputs={"b"}, + symbol_mapping=None, + ) + + state.add_edge(state.add_access("a"), None, nsdfg, "a", dace.Memlet("a[0:10, 3]")) + state.add_edge(nsdfg, "b", state.add_access("b"), None, dace.Memlet("b[0:10, 2]")) + sdfg.validate() + return sdfg, nsdfg + + +def test_strides_propagation_stride_1(): + def ref(a, b): + for i in range(10): + b[i, 2] = a[i, 3] + 10.0 + + sdfg, nsdfg = _make_strides_propagation_stride_1_sdfg() + + outer_desc_a = sdfg.arrays["a"] + inner_desc_a = nsdfg.sdfg.arrays["a"] + assert outer_desc_a.strides == inner_desc_a.strides + + # Now switch the strides of `a` on the top level. + # Essentially going from `C` to FORTRAN order. + stride_outer_a_0, stride_outer_a_1 = outer_desc_a.strides + outer_desc_a.set_shape(outer_desc_a.shape, (stride_outer_a_1, stride_outer_a_0)) + + # Now we propagate the data into it. + gtx_transformations.gt_propagate_strides_of(sdfg=sdfg, data_name="a") + + # Because of the propagation it must now been changed to `(1, 1)` on the inside. + assert inner_desc_a.strides == (1, 1) + + res_args = { + "a": np.array(np.random.rand(10, 10), order="F", dtype=np.float64, copy=True), + "b": np.array(np.random.rand(10, 10), order="C", dtype=np.float64, copy=True), + } + ref_args = copy.deepcopy(res_args) + + sdfg(**res_args, a_stride=10, b_stride=10) + ref(**ref_args) + assert np.allclose(ref_args["b"], res_args["b"])