Skip to content

Commit

Permalink
Python: Support complex left-hand side expressions, argmin/max (#1843)
Browse files Browse the repository at this point in the history
Fixes #1842
  • Loading branch information
tbennun authored Jan 6, 2025
1 parent 7dc7957 commit 43883ea
Show file tree
Hide file tree
Showing 8 changed files with 159 additions and 62 deletions.
5 changes: 3 additions & 2 deletions dace/frontend/python/memlet_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,9 +273,10 @@ def parse_memlet_subset(array: data.Data,
def ParseMemlet(visitor,
defined_arrays_and_symbols: Dict[str, Any],
node: MemletType,
parsed_slice: Any = None) -> MemletExpr:
parsed_slice: Any = None,
arrname: Optional[str] = None) -> MemletExpr:
das = defined_arrays_and_symbols
arrname = rname(node)
arrname = arrname or rname(node)
if arrname not in das:
raise DaceSyntaxError(visitor, node, 'Use of undefined data "%s" in memlet' % arrname)
array = das[arrname]
Expand Down
45 changes: 38 additions & 7 deletions dace/frontend/python/newast.py
Original file line number Diff line number Diff line change
Expand Up @@ -3306,12 +3306,37 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False):
name = rname(target)
tokens = name.split('.')
name = tokens[0]
tokens.pop(0)
true_name = None
true_array = None
visited_target = False

if name in defined_vars:
true_name = defined_vars[name]
if len(tokens) > 1:
true_name = '.'.join([true_name, *tokens[1:]])
# Handle complex object assignment (e.g., A.flat[:])
if isinstance(target, ast.Subscript): # In case of nested subscripts, find the root AST node
last_subscript = target
# Find the first non-subscript target
while isinstance(last_subscript.value, ast.Subscript):
last_subscript = last_subscript.value
if isinstance(target, ast.Subscript) and not isinstance(last_subscript.value, ast.Name):
store_target = copy.copy(last_subscript.value)
store_target.ctx = ast.Store()
true_name = self.visit(store_target)
# Refresh defined variables and arrays
defined_vars = {**self.variables, **self.scope_vars}
defined_arrays = dace.sdfg.NestedDict({**self.sdfg.arrays, **self.scope_arrays})
visited_target = True
else:
true_name = defined_vars[name]
while len(tokens) > 1:
true_name = true_name + '.' + tokens.pop(0)
if true_name not in self.sdfg.arrays:
break
if tokens: # The non-struct remainder will be considered an attribute
attribute_name = '.'.join(tokens)
raise DaceSyntaxError(
self, target, f'Cannot assign to attribute "{attribute_name}" of variable "{true_name}"')

true_array = defined_arrays[true_name]

# If type was already annotated
Expand Down Expand Up @@ -3431,13 +3456,18 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False):
if new_data:
rng = rng or dace.subsets.Range.from_array(new_data)
else:
true_target = copy.copy(target)
true_target = astutils.copy_tree(target)
nslice = None
if isinstance(target, ast.Name):
true_target.id = true_name
elif isinstance(target, ast.Subscript):
true_target.value = copy.copy(true_target.value)
true_target.value.id = true_name
# In case of nested subscripts, find the root AST node
last_subscript = true_target
# Find the first non-subscript target and modify its value to the new name
while isinstance(last_subscript.value, ast.Subscript):
last_subscript = last_subscript.value
last_subscript.value = ast.copy_location(ast.Name(id=true_name, ctx=ast.Store()),
last_subscript.value)

# Visit slice contents
nslice = self._parse_subscript_slice(true_target.slice)
Expand Down Expand Up @@ -3485,7 +3515,8 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False):
if boolarr is not None and indirect_indices:
raise IndexError('Boolean array indexing cannot be combined with indirect access')

if self.nested and not new_data:

if self.nested and not new_data and not visited_target:
new_name, new_rng = self._add_write_access(name, rng, target)
# Local symbol or local data dependent
if _subset_is_local_symbol_dependent(rng, self):
Expand Down
110 changes: 70 additions & 40 deletions dace/frontend/python/replacements.py
Original file line number Diff line number Diff line change
Expand Up @@ -1171,12 +1171,22 @@ def _slice(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, *args, **kwargs):


@oprepo.replaces('numpy.argmax')
def _argmax(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, a: str, axis, result_type=dace.int32):
def _argmax(pv: ProgramVisitor,
            sdfg: SDFG,
            state: SDFGState,
            a: str,
            axis: Optional[int] = None,
            result_type=dace.int32):
    """
    Replacement for ``numpy.argmax``: forwards to ``_argminmax`` with ``func="max"``.

    :param pv: The program visitor, passed through unchanged.
    :param sdfg: The SDFG, passed through unchanged.
    :param state: The SDFG state, passed through unchanged.
    :param a: Name of the input data container.
    :param axis: Axis to reduce over. ``None`` (the default) is passed through as-is;
                 how it is interpreted is decided by ``_argminmax``.
    :param result_type: Type used for the resulting indices (``dace.int32`` by default).
    :return: The result of ``_argminmax``, unchanged.
    """
    return _argminmax(pv, sdfg, state, a, axis, func="max", result_type=result_type)


@oprepo.replaces('numpy.argmin')
def _argmin(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, a: str, axis, result_type=dace.int32):
def _argmin(pv: ProgramVisitor,
            sdfg: SDFG,
            state: SDFGState,
            a: str,
            axis: Optional[int] = None,
            result_type=dace.int32):
    """
    Replacement for ``numpy.argmin``: forwards to ``_argminmax`` with ``func="min"``.

    :param pv: The program visitor, passed through unchanged.
    :param sdfg: The SDFG, passed through unchanged.
    :param state: The SDFG state, passed through unchanged.
    :param a: Name of the input data container.
    :param axis: Axis to reduce over. ``None`` (the default) is passed through as-is;
                 how it is interpreted is decided by ``_argminmax``.
    :param result_type: Type used for the resulting indices (``dace.int32`` by default).
    :return: The result of ``_argminmax``, unchanged.
    """
    return _argminmax(pv, sdfg, state, a, axis, func="min", result_type=result_type)


Expand All @@ -1192,7 +1202,12 @@ def _argminmax(pv: ProgramVisitor,

assert func in ['min', 'max']

if axis is None or not isinstance(axis, Integral):
# Flatten the array if axis is not given
if axis is None:
axis = 0
a = flat(pv, sdfg, state, a)

if not isinstance(axis, Integral):
raise SyntaxError('Axis must be an int')

a_arr = sdfg.arrays[a]
Expand All @@ -1202,6 +1217,8 @@ def _argminmax(pv: ProgramVisitor,

reduced_shape = list(copy.deepcopy(a_arr.shape))
reduced_shape.pop(axis)
if not reduced_shape:
reduced_shape = [1]

val_and_idx = dace.struct('_val_and_idx', idx=result_type, val=a_arr.dtype)

Expand All @@ -1211,17 +1228,17 @@ def _argminmax(pv: ProgramVisitor,
code = "__init = _val_and_idx(val={}, idx=-1)".format(
dtypes.min_value(a_arr.dtype) if func == 'max' else dtypes.max_value(a_arr.dtype))

nest.add_state().add_mapped_tasklet(
name="_arg{}_convert_".format(func),
map_ranges={'__i%d' % i: '0:%s' % n
for i, n in enumerate(a_arr.shape) if i != axis},
inputs={},
code=code,
outputs={
'__init': Memlet.simple(reduced_structs,
','.join('__i%d' % i for i in range(len(a_arr.shape)) if i != axis))
},
external_edges=True)
reduced_expr = ','.join('__i%d' % i for i in range(len(a_arr.shape)) if i != axis)
reduced_maprange = {'__i%d' % i: '0:%s' % n for i, n in enumerate(a_arr.shape) if i != axis}
if not reduced_expr:
reduced_expr = '0'
reduced_maprange = {'__i0': '0:1'}
nest.add_state().add_mapped_tasklet(name="_arg{}_convert_".format(func),
map_ranges=reduced_maprange,
inputs={},
code=code,
outputs={'__init': Memlet.simple(reduced_structs, reduced_expr)},
external_edges=True)

nest.add_state().add_mapped_tasklet(
name="_arg{}_reduce_".format(func),
Expand All @@ -1232,7 +1249,7 @@ def _argminmax(pv: ProgramVisitor,
outputs={
'__out':
Memlet.simple(reduced_structs,
','.join('__i%d' % i for i in range(len(a_arr.shape)) if i != axis),
reduced_expr,
wcr_str=("lambda x, y:"
"_val_and_idx(val={}(x.val, y.val), "
"idx=(y.idx if x.val {} y.val else x.idx))").format(
Expand All @@ -1246,16 +1263,14 @@ def _argminmax(pv: ProgramVisitor,

nest.add_state().add_mapped_tasklet(
name="_arg{}_extract_".format(func),
map_ranges={'__i%d' % i: '0:%s' % n
for i, n in enumerate(a_arr.shape) if i != axis},
map_ranges=reduced_maprange,
inputs={
'__in': Memlet.simple(reduced_structs,
','.join('__i%d' % i for i in range(len(a_arr.shape)) if i != axis))
'__in': Memlet.simple(reduced_structs, reduced_expr)
},
code="__out_val = __in.val\n__out_idx = __in.idx",
outputs={
'__out_val': Memlet.simple(outval, ','.join('__i%d' % i for i in range(len(a_arr.shape)) if i != axis)),
'__out_idx': Memlet.simple(outidx, ','.join('__i%d' % i for i in range(len(a_arr.shape)) if i != axis))
'__out_val': Memlet.simple(outval, reduced_expr),
'__out_idx': Memlet.simple(outidx, reduced_expr)
},
external_edges=True)

Expand Down Expand Up @@ -4569,25 +4584,40 @@ def _ndarray_min(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, arr: str, kwa
return implement_ufunc_reduce(pv, None, sdfg, state, 'minimum', [arr], kwargs)[0]


# TODO: It looks like `_argminmax` does not work with a flattened array.
# @oprepo.replaces_method('Array', 'argmax')
# @oprepo.replaces_method('Scalar', 'argmax')
# @oprepo.replaces_method('View', 'argmax')
# def _ndarray_argmax(pv: ProgramVisitor,
# sdfg: SDFG,
# state: SDFGState,
# arr: str,
# axis: int = None,
# out: str = None) -> str:
# if not axis:
# axis = 0
# arr = flat(pv, sdfg, state, arr)
# nest, newarr = _argmax(pv, sdfg, state, arr, axis)
# if out:
# r = state.add_read(arr)
# w = state.add_read(newarr)
# state.add_nedge(r, w, dace.Memlet.from_array(newarr, sdfg.arrays[newarr]))
# return new_arr
@oprepo.replaces_method('Array', 'argmax')
@oprepo.replaces_method('Scalar', 'argmax')
@oprepo.replaces_method('View', 'argmax')
def _ndarray_argmax(pv: ProgramVisitor,
                    sdfg: SDFG,
                    state: SDFGState,
                    arr: str,
                    axis: Optional[int] = None,
                    out: Optional[str] = None) -> str:
    """
    Implements the ``argmax`` method for ``Array``, ``Scalar`` and ``View`` containers.

    :param pv: The program visitor, forwarded to ``_argmax``.
    :param sdfg: The SDFG that holds the data containers.
    :param state: The state in which the computation (and the optional copy) is added.
    :param arr: Name of the data container to search.
    :param axis: Axis to search along; forwarded unchanged to ``_argmax`` (``None`` by default).
    :param out: Optional name of a data container that receives the result. If given, the
                computed indices are copied into it.
    :return: Name of the data container holding the indices (``out`` when it was given).
    """
    # ``_argmax`` returns a pair; only the second element (the result container name) is used
    _, newarr = _argmax(pv, sdfg, state, arr, axis)
    if out:
        # Copy the full result container into the caller-provided output
        r = state.add_read(newarr)
        w = state.add_write(out)
        state.add_nedge(r, w, dace.Memlet.from_array(newarr, sdfg.arrays[newarr]))
        newarr = out
    return newarr


@oprepo.replaces_method('Array', 'argmin')
@oprepo.replaces_method('Scalar', 'argmin')
@oprepo.replaces_method('View', 'argmin')
def _ndarray_argmin(pv: ProgramVisitor,
                    sdfg: SDFG,
                    state: SDFGState,
                    arr: str,
                    axis: Optional[int] = None,
                    out: Optional[str] = None) -> str:
    """
    Implements the ``argmin`` method for ``Array``, ``Scalar`` and ``View`` containers.

    :param pv: The program visitor, forwarded to ``_argmin``.
    :param sdfg: The SDFG that holds the data containers.
    :param state: The state in which the computation (and the optional copy) is added.
    :param arr: Name of the data container to search.
    :param axis: Axis to search along; forwarded unchanged to ``_argmin`` (``None`` by default).
    :param out: Optional name of a data container that receives the result. If given, the
                computed indices are copied into it.
    :return: Name of the data container holding the indices (``out`` when it was given).
    """
    # ``_argmin`` returns a pair; only the second element (the result container name) is used
    _, newarr = _argmin(pv, sdfg, state, arr, axis)
    if out:
        # Copy the full result container into the caller-provided output
        r = state.add_read(newarr)
        w = state.add_write(out)
        state.add_nedge(r, w, dace.Memlet.from_array(newarr, sdfg.arrays[newarr]))
        newarr = out
    return newarr


@oprepo.replaces_method('Array', 'conj')
Expand Down
2 changes: 1 addition & 1 deletion dace/runtime/include/dace/math.h
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ namespace dace
{
namespace math
{
static DACE_CONSTEXPR DACE_HostDev typeless_pi pi{};
static DACE_CONSTEXPR_HOSTDEV typeless_pi pi{};
static DACE_CONSTEXPR typeless_nan nan{};
//////////////////////////////////////////////////////
template<typename T>
Expand Down
14 changes: 8 additions & 6 deletions dace/runtime/include/dace/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,7 @@
#define DACE_HDFI __host__ __device__ __forceinline__
#define DACE_HFI __host__ __forceinline__
#define DACE_DFI __device__ __forceinline__
#define DACE_HostDev __host__ __device__
#define DACE_Host __host__
#define DACE_Dev __device__
#else
#define DACE_HostDev
#define DACE_Host
#define DACE_Dev
#define DACE_HDFI inline
#define DACE_HFI inline
#define DACE_DFI inline
Expand All @@ -67,6 +61,14 @@
#define __DACE_UNROLL
#endif

// DACE_CONSTEXPR_HOSTDEV expands to the qualifiers used for constants that are
// declared for both host and device code. Three cases are distinguished:
//  1. CUDA compiler (__CUDACC__) version 11.4 or newer: `constexpr __host__ __device__`.
//     If CUDA version is 11.4 or higher, __device__ variables can be declared as constexpr
//  2. Any other CUDA compiler version, or a HIP compiler (__HIPCC__): `const __host__ __device__`.
//  3. No CUDA/HIP compiler detected: plain `const`.
#if defined(__CUDACC__) && (__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 4))
#define DACE_CONSTEXPR_HOSTDEV constexpr __host__ __device__
#elif defined(__CUDACC__) || defined(__HIPCC__)
#define DACE_CONSTEXPR_HOSTDEV const __host__ __device__
#else
#define DACE_CONSTEXPR_HOSTDEV const
#endif


namespace dace
Expand Down
2 changes: 1 addition & 1 deletion tests/inlining_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ def test(A: dace.float64[96, 32], B: dace.float64[42, 32]):
sdfg.expand_library_nodes()
sdfg.simplify()

state = sdfg.nodes()[1]
state = sdfg.sink_nodes()[0]
# find nested_sdfg
nsdfg = [n for n in state.nodes() if isinstance(n, dace.sdfg.nodes.NestedSDFG)][0]
# delete gemm initialization state
Expand Down
15 changes: 15 additions & 0 deletions tests/numpy/advanced_indexing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,20 @@ def indexing_test(A: dace.float64[5, 5, 5, 5, 5]):
assert np.allclose(A, regression)


def test_aug_implicit_attribute():
    """
    Tests augmented assignment through chained subscripts on an attribute
    (``A.flat[10:15][0:2] += 5``) inside a DaCe program.
    """

    @dace.program
    def indexing_test(A: dace.float64[5, 5, 5, 5, 5]):
        A.flat[10:15][0:2] += 5

    A = np.random.rand(5, 5, 5, 5, 5)
    regression = np.copy(A)
    # FIXME: NumPy does not support augmented assignment on a sub-iterator of a flat iterator
    # The reference result is therefore written with the composed range instead:
    # [10:15] followed by [0:2] selects flat elements 10 and 11, i.e., flat[10:12].
    regression.flat[10:12] += 5
    indexing_test(A)
    assert np.allclose(A, regression)


def test_ellipsis_aug():

@dace.program
Expand Down Expand Up @@ -329,6 +343,7 @@ def indexing_test(A: dace.float64[N, N, N]):
test_flat_noncontiguous()
test_ellipsis()
test_aug_implicit()
test_aug_implicit_attribute()
test_ellipsis_aug()
test_newaxis()
test_multiple_newaxis()
Expand Down
28 changes: 23 additions & 5 deletions tests/numpy/ndarray_attributes_methods_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ def test_copy(A: dace.float32[M, N]):
return A.copy()


# Assignment through the ``flat`` attribute on the left-hand side: writes 5 to flat
# index 150 and returns the modified array.
# NOTE(review): assumes M * N > 150 for the symbol values used by the harness - confirm.
@compare_numpy_output()
def test_lhs_flat(A: dace.float64[M, N]):
    A.flat[150] = 5
    return A


# ``astype`` method: converts an int32 array to float32.
@compare_numpy_output()
def test_astype(A: dace.int32[M, N]):
    return A.astype(np.float32)
Expand Down Expand Up @@ -85,17 +91,26 @@ def test_max(A: dace.float32[M, N, N, M]):
return A.max()


# TODO: Need to debug `_argminmax`
# @compare_numpy_output()
# def test_argmax(A: dace.float32[M, N, N, M]):
# return A.argmax()
# ``argmax`` method without an axis argument on a 4-dimensional array.
@compare_numpy_output()
def test_argmax(A: dace.float32[M, N, N, M]):
    return A.argmax()


# ``argmax`` method with an explicit axis (1) given as a positional argument.
@compare_numpy_output()
def test_argmax_dim(A: dace.float32[M, N, N, M]):
    return A.argmax(1)


# ``min`` method without an axis argument on a 4-dimensional array.
@compare_numpy_output()
def test_min(A: dace.float32[M, N, N, M]):
    return A.min()


# ``argmin`` method without an axis argument on a 4-dimensional array.
@compare_numpy_output()
def test_argmin(A: dace.float32[M, N, N, M]):
    return A.argmin()


# ``conj`` method on a 4-dimensional complex64 array.
@compare_numpy_output()
def test_conj(A: dace.complex64[M, N, N, M]):
    return A.conj()
Expand Down Expand Up @@ -134,6 +149,7 @@ def test_any():
test_real()
test_imag()
test_copy()
test_lhs_flat()
test_astype()
test_fill()
test_fill2()
Expand All @@ -145,8 +161,10 @@ def test_any():
test_flatten()
test_ravel()
test_max()
# test_argmax()
test_argmax()
test_argmax_dim()
test_min()
test_argmin()
test_conj()
test_sum()
test_mean()
Expand Down

0 comments on commit 43883ea

Please sign in to comment.