diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index 980b2a8fdf..d1bf8fe266 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -498,31 +498,42 @@ def _gt_map_strides_into_nested_sdfg( inner_shape = inner_desc.shape inner_strides_init = inner_desc.strides + outer_shape = outer_desc.shape outer_strides = outer_desc.strides outer_inflow = outer_subset.size() - new_strides: list = [] - for dim_ostride, dim_oinflow in zip(outer_strides, outer_inflow, strict=True): - if dim_oinflow == 1: - # This is the case of implicit slicing along one dimension. - pass - else: - # There is inflow into the SDFG, so we need the stride. - new_strides.append(dim_ostride) - assert len(new_strides) <= len(inner_shape) - - # If we have a scalar on the inside, then there is nothing to adjust. - # We could have performed the test above, but doing it here, gives us - # the chance of validating it. if isinstance(inner_desc, dace_data.Scalar): - if len(new_strides) != 0: - raise ValueError(f"Dimensional error for '{inner_data}' in '{nsdfg_node.label}'.") + # A scalar does not have a stride that must be propagated. return - if not isinstance(inner_desc, dace_data.Array): - raise TypeError( - f"Expected that '{inner_data}' is an 'Array' but it is '{type(inner_desc).__name__}'." - ) + # Now determine the new stride that is needed on the inside. + new_strides: list = [] + if len(outer_shape) == len(inner_shape): + # The inner and the outer descriptor have the same dimensionality. 
+ # We now have to decide if we should take the stride from the outside, + # which happens for example in case of `A[0:N, 0:M] -> B[N, M]`, or if we + # must take 1, which happens if we do `A[0:N, i] -> B[N, 1]`. We detect that + # based on the volume that flows in. + for dim_ostride, dim_oinflow in zip(outer_strides, outer_inflow, strict=True): + new_strides.append(1 if dim_oinflow == 1 else dim_ostride) + + elif len(inner_shape) < len(outer_shape): + # There are fewer dimensions on the inside than on the outside. This means + # that some were sliced away. We detect this case by checking if the Memlet + # subset in that dimension has size 1. + # NOTE: This is not always correct as it might be possible that there + # are some explicit size 1 dimensions at several places. + new_strides = [] + for dim_ostride, dim_oinflow in zip(outer_strides, outer_inflow, strict=True): + if dim_oinflow == 1: + pass + else: + new_strides.append(dim_ostride) + assert len(new_strides) <= len(inner_shape) + else: + # The case that we have more dimensions on the inside than on the outside. + # This is currently not supported. 
+ raise NotImplementedError("NestedSDFGs can not be used to increase the rank.") if len(new_strides) != len(inner_shape): raise ValueError("Failed to compute the inner strides.") diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py index 5b16e41bc3..19b33d0bef 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py @@ -539,3 +539,99 @@ def ref(a1, b1): ref(**ref_args) sdfg_level1(**res_args) assert np.allclose(ref_args["b1"], res_args["b1"]) + + +def _make_strides_propagation_stride_1_nsdfg() -> dace.SDFG: + sdfg_level1 = dace.SDFG(util.unique_name("strides_propagation_stride_1_nsdfg")) + state = sdfg_level1.add_state(is_start_block=True) + + a_stride_sym = dace.symbol("a_stride", dtype=dace.uint64) + b_stride_sym = dace.symbol("b_stride", dtype=dace.uint64) + stride_syms = {"a": a_stride_sym, "b": b_stride_sym} + + for name in ["a", "b"]: + sdfg_level1.add_array( + name, + shape=(10, 1), + strides=(stride_syms[name], 1), + dtype=dace.float64, + transient=False, + ) + + state.add_mapped_tasklet( + "computation", + map_ranges={"__i": "0:10"}, + inputs={"__in": dace.Memlet("a[__i, 0]")}, + code="__out = __in + 10", + outputs={"__out": dace.Memlet("b[__i, 0]")}, + external_edges=True, + ) + sdfg_level1.validate() + return sdfg_level1 + + +def _make_strides_propagation_stride_1_sdfg() -> tuple[dace.SDFG, dace_nodes.NestedSDFG]: + sdfg = dace.SDFG(util.unique_name("strides_propagation_stride_1_sdfg")) + state = sdfg.add_state(is_start_block=True) + + a_stride_sym = dace.symbol("a_stride", dtype=dace.uint64) + b_stride_sym = dace.symbol("b_stride", dtype=dace.uint64) + stride_syms = {"a": 
a_stride_sym, "b": b_stride_sym} + + for name in ["a", "b"]: + sdfg.add_array( + name, + shape=(10, 10), + strides=(stride_syms[name], 1), + dtype=dace.float64, + transient=False, + ) + + # Now get the nested SDFG. + sdfg_level1 = _make_strides_propagation_stride_1_nsdfg() + + nsdfg = state.add_nested_sdfg( + parent=sdfg, + sdfg=sdfg_level1, + inputs={"a"}, + outputs={"b"}, + symbol_mapping=None, + ) + + state.add_edge(state.add_access("a"), None, nsdfg, "a", dace.Memlet("a[0:10, 3]")) + state.add_edge(nsdfg, "b", state.add_access("b"), None, dace.Memlet("b[0:10, 2]")) + sdfg.validate() + return sdfg, nsdfg + + +def test_strides_propagation_stride_1(): + def ref(a, b): + for i in range(10): + b[i, 2] = a[i, 3] + 10.0 + + sdfg, nsdfg = _make_strides_propagation_stride_1_sdfg() + + outer_desc_a = sdfg.arrays["a"] + inner_desc_a = nsdfg.sdfg.arrays["a"] + assert outer_desc_a.strides == inner_desc_a.strides + + # Now switch the strides of `a` on the top level. + # Essentially going from `C` to FORTRAN order. + stride_outer_a_0, stride_outer_a_1 = outer_desc_a.strides + outer_desc_a.set_shape(outer_desc_a.shape, (stride_outer_a_1, stride_outer_a_0)) + + # Now we propagate the strides into it. + gtx_transformations.gt_propagate_strides_of(sdfg=sdfg, data_name="a") + + # Because of the propagation the strides must now have been changed to `(1, 1)` on the inside. + assert inner_desc_a.strides == (1, 1) + + res_args = { + "a": np.array(np.random.rand(10, 10), order="F", dtype=np.float64, copy=True), + "b": np.array(np.random.rand(10, 10), order="C", dtype=np.float64, copy=True), + } + ref_args = copy.deepcopy(res_args) + + sdfg(**res_args, a_stride=10, b_stride=10) + ref(**ref_args) + assert np.allclose(ref_args["b"], res_args["b"])