Skip to content

Commit

Permalink
This commit fixes the stride reconstruction during stride propagation.
Browse files Browse the repository at this point in the history
NestedSDFG essentially allows to perform some slices, there are technically three chases:
- The data container on the inside has a smaller rank than the one on the outside, thus some dimensions were removed.
- The data container on the inside has the same rank than the one on the outside.
- The data container on the inside has a larger rank than the one on the outside, thus some dimensions were added.

The last case is not handled, as it does not happens in GT4Py.
Before, the first and second case were handled together, but it was realized that the second case was not implemented properly and it was added explicitly.
  • Loading branch information
philip-paul-mueller committed Jan 16, 2025
1 parent 1b88276 commit 9e1e721
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -498,31 +498,42 @@ def _gt_map_strides_into_nested_sdfg(
inner_shape = inner_desc.shape
inner_strides_init = inner_desc.strides

outer_shape = outer_desc.shape
outer_strides = outer_desc.strides
outer_inflow = outer_subset.size()

new_strides: list = []
for dim_ostride, dim_oinflow in zip(outer_strides, outer_inflow, strict=True):
if dim_oinflow == 1:
# This is the case of implicit slicing along one dimension.
pass
else:
# There is inflow into the SDFG, so we need the stride.
new_strides.append(dim_ostride)
assert len(new_strides) <= len(inner_shape)

# If we have a scalar on the inside, then there is nothing to adjust.
# We could have performed the test above, but doing it here, gives us
# the chance of validating it.
if isinstance(inner_desc, dace_data.Scalar):
if len(new_strides) != 0:
raise ValueError(f"Dimensional error for '{inner_data}' in '{nsdfg_node.label}'.")
# A scalar does not have a stride that must be propagated.
return

if not isinstance(inner_desc, dace_data.Array):
raise TypeError(
f"Expected that '{inner_data}' is an 'Array' but it is '{type(inner_desc).__name__}'."
)
# Now determine the new stride that is needed on the inside.
new_strides: list = []
if len(outer_shape) == len(inner_shape):
# The inner and the outer descriptor have the same dimensionality.
# We now have to decide if we should take the stride from the outside,
# which happens for example in case of `A[0:N, 0:M] -> B[N, M]`, or if we
# must take 1, which happens if we do `A[0:N, i] -> B[N, 1]`, we detect that
# based on the volume that flows in.
for dim_ostride, dim_oinflow in zip(outer_strides, outer_inflow, strict=True):
new_strides.append(1 if dim_oinflow == 1 else dim_ostride)

elif len(inner_shape) < len(outer_shape):
# There are less dimensions on the inside than on the outside. This means
# that some were sliced away. We detect this case by checking if the Memlet
# subset in that dimension has size 1.
# NOTE: That this is not always correct as it might be possible that there
# are some explicit size 1 dimensions at several places.
new_strides = []
for dim_ostride, dim_oinflow in zip(outer_strides, outer_inflow, strict=True):
if dim_oinflow == 1:
pass
else:
new_strides.append(dim_ostride)
assert len(new_strides) <= len(inner_shape)
else:
# The case that we have more dimensions on the inside than on the outside.
# This is currently not supported.
raise NotImplementedError("NestedSDFGs can not be used to increase the rank.")

if len(new_strides) != len(inner_shape):
raise ValueError("Failed to compute the inner strides.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -539,3 +539,99 @@ def ref(a1, b1):
ref(**ref_args)
sdfg_level1(**res_args)
assert np.allclose(ref_args["b1"], res_args["b1"])


def _make_strides_propagation_stride_1_nsdfg() -> dace.SDFG:
sdfg_level1 = dace.SDFG(util.unique_name("strides_propagation_stride_1_nsdfg"))
state = sdfg_level1.add_state(is_start_block=True)

a_stride_sym = dace.symbol("a_stride", dtype=dace.uint64)
b_stride_sym = dace.symbol("b_stride", dtype=dace.uint64)
stride_syms = {"a": a_stride_sym, "b": b_stride_sym}

for name in ["a", "b"]:
sdfg_level1.add_array(
name,
shape=(10, 1),
strides=(stride_syms[name], 1),
dtype=dace.float64,
transient=False,
)

state.add_mapped_tasklet(
"computation",
map_ranges={"__i": "0:10"},
inputs={"__in": dace.Memlet("a[__i, 0]")},
code="__out = __in + 10",
outputs={"__out": dace.Memlet("b[__i, 0]")},
external_edges=True,
)
sdfg_level1.validate()
return sdfg_level1


def _make_strides_propagation_stride_1_sdfg() -> tuple[dace.SDFG, dace_nodes.NestedSDFG]:
sdfg = dace.SDFG(util.unique_name("strides_propagation_stride_1_sdfg"))
state = sdfg.add_state(is_start_block=True)

a_stride_sym = dace.symbol("a_stride", dtype=dace.uint64)
b_stride_sym = dace.symbol("b_stride", dtype=dace.uint64)
stride_syms = {"a": a_stride_sym, "b": b_stride_sym}

for name in ["a", "b"]:
sdfg.add_array(
name,
shape=(10, 10),
strides=(stride_syms[name], 1),
dtype=dace.float64,
transient=False,
)

# Now get the nested SDFG.
sdfg_level1 = _make_strides_propagation_stride_1_nsdfg()

nsdfg = state.add_nested_sdfg(
parent=sdfg,
sdfg=sdfg_level1,
inputs={"a"},
outputs={"b"},
symbol_mapping=None,
)

state.add_edge(state.add_access("a"), None, nsdfg, "a", dace.Memlet("a[0:10, 3]"))
state.add_edge(nsdfg, "b", state.add_access("b"), None, dace.Memlet("b[0:10, 2]"))
sdfg.validate()
return sdfg, nsdfg


def test_strides_propagation_stride_1():
def ref(a, b):
for i in range(10):
b[i, 2] = a[i, 3] + 10.0

sdfg, nsdfg = _make_strides_propagation_stride_1_sdfg()

outer_desc_a = sdfg.arrays["a"]
inner_desc_a = nsdfg.sdfg.arrays["a"]
assert outer_desc_a.strides == inner_desc_a.strides

# Now switch the strides of `a` on the top level.
# Essentially going from `C` to FORTRAN order.
stride_outer_a_0, stride_outer_a_1 = outer_desc_a.strides
outer_desc_a.set_shape(outer_desc_a.shape, (stride_outer_a_1, stride_outer_a_0))

# Now we propagate the data into it.
gtx_transformations.gt_propagate_strides_of(sdfg=sdfg, data_name="a")

# Because of the propagation it must now been changed to `(1, 1)` on the inside.
assert inner_desc_a.strides == (1, 1)

res_args = {
"a": np.array(np.random.rand(10, 10), order="F", dtype=np.float64, copy=True),
"b": np.array(np.random.rand(10, 10), order="C", dtype=np.float64, copy=True),
}
ref_args = copy.deepcopy(res_args)

sdfg(**res_args, a_stride=10, b_stride=10)
ref(**ref_args)
assert np.allclose(ref_args["b"], res_args["b"])

0 comments on commit 9e1e721

Please sign in to comment.