Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Python: Support complex left-hand side expressions, argmin/max #1843

Merged
merged 6 commits into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions dace/frontend/python/memlet_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,9 +273,10 @@ def parse_memlet_subset(array: data.Data,
def ParseMemlet(visitor,
defined_arrays_and_symbols: Dict[str, Any],
node: MemletType,
parsed_slice: Any = None) -> MemletExpr:
parsed_slice: Any = None,
arrname: Optional[str] = None) -> MemletExpr:
das = defined_arrays_and_symbols
arrname = rname(node)
arrname = arrname or rname(node)
if arrname not in das:
raise DaceSyntaxError(visitor, node, 'Use of undefined data "%s" in memlet' % arrname)
array = das[arrname]
Expand Down
45 changes: 38 additions & 7 deletions dace/frontend/python/newast.py
Original file line number Diff line number Diff line change
Expand Up @@ -3306,12 +3306,37 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False):
name = rname(target)
tokens = name.split('.')
name = tokens[0]
tokens.pop(0)
true_name = None
true_array = None
visited_target = False

if name in defined_vars:
true_name = defined_vars[name]
if len(tokens) > 1:
true_name = '.'.join([true_name, *tokens[1:]])
# Handle complex object assignment (e.g., A.flat[:])
if isinstance(target, ast.Subscript): # In case of nested subscripts, find the root AST node
last_subscript = target
# Find the first non-subscript target
while isinstance(last_subscript.value, ast.Subscript):
last_subscript = last_subscript.value
if isinstance(target, ast.Subscript) and not isinstance(last_subscript.value, ast.Name):
store_target = copy.copy(last_subscript.value)
store_target.ctx = ast.Store()
true_name = self.visit(store_target)
# Refresh defined variables and arrays
defined_vars = {**self.variables, **self.scope_vars}
defined_arrays = dace.sdfg.NestedDict({**self.sdfg.arrays, **self.scope_arrays})
visited_target = True
else:
true_name = defined_vars[name]
while len(tokens) > 1:
true_name = true_name + '.' + tokens.pop(0)
if true_name not in self.sdfg.arrays:
break
if tokens: # The non-struct remainder will be considered an attribute
attribute_name = '.'.join(tokens)
raise DaceSyntaxError(
self, target, f'Cannot assign to attribute "{attribute_name}" of variable "{true_name}"')

true_array = defined_arrays[true_name]

# If type was already annotated
Expand Down Expand Up @@ -3431,13 +3456,18 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False):
if new_data:
rng = rng or dace.subsets.Range.from_array(new_data)
else:
true_target = copy.copy(target)
true_target = astutils.copy_tree(target)
nslice = None
if isinstance(target, ast.Name):
true_target.id = true_name
elif isinstance(target, ast.Subscript):
true_target.value = copy.copy(true_target.value)
true_target.value.id = true_name
# In case of nested subscripts, find the root AST node
last_subscript = true_target
# Find the first non-subscript target and modify its value to the new name
while isinstance(last_subscript.value, ast.Subscript):
last_subscript = last_subscript.value
last_subscript.value = ast.copy_location(ast.Name(id=true_name, ctx=ast.Store()),
last_subscript.value)

# Visit slice contents
nslice = self._parse_subscript_slice(true_target.slice)
Expand Down Expand Up @@ -3485,7 +3515,8 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False):
if boolarr is not None and indirect_indices:
raise IndexError('Boolean array indexing cannot be combined with indirect access')

if self.nested and not new_data:

if self.nested and not new_data and not visited_target:
new_name, new_rng = self._add_write_access(name, rng, target)
# Local symbol or local data dependent
if _subset_is_local_symbol_dependent(rng, self):
Expand Down
110 changes: 70 additions & 40 deletions dace/frontend/python/replacements.py
Original file line number Diff line number Diff line change
Expand Up @@ -1159,12 +1159,22 @@ def _slice(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, *args, **kwargs):


@oprepo.replaces('numpy.argmax')
def _argmax(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, a: str, axis, result_type=dace.int32):
def _argmax(pv: ProgramVisitor,
sdfg: SDFG,
state: SDFGState,
a: str,
axis: Optional[int] = None,
result_type=dace.int32):
return _argminmax(pv, sdfg, state, a, axis, func="max", result_type=result_type)


@oprepo.replaces('numpy.argmin')
def _argmin(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, a: str, axis, result_type=dace.int32):
def _argmin(pv: ProgramVisitor,
sdfg: SDFG,
state: SDFGState,
a: str,
axis: Optional[int] = None,
result_type=dace.int32):
return _argminmax(pv, sdfg, state, a, axis, func="min", result_type=result_type)


Expand All @@ -1180,7 +1190,12 @@ def _argminmax(pv: ProgramVisitor,

assert func in ['min', 'max']

if axis is None or not isinstance(axis, Integral):
# Flatten the array if axis is not given
if axis is None:
axis = 0
a = flat(pv, sdfg, state, a)

if not isinstance(axis, Integral):
raise SyntaxError('Axis must be an int')

a_arr = sdfg.arrays[a]
Expand All @@ -1190,6 +1205,8 @@ def _argminmax(pv: ProgramVisitor,

reduced_shape = list(copy.deepcopy(a_arr.shape))
reduced_shape.pop(axis)
if not reduced_shape:
reduced_shape = [1]

val_and_idx = dace.struct('_val_and_idx', idx=result_type, val=a_arr.dtype)

Expand All @@ -1199,17 +1216,17 @@ def _argminmax(pv: ProgramVisitor,
code = "__init = _val_and_idx(val={}, idx=-1)".format(
dtypes.min_value(a_arr.dtype) if func == 'max' else dtypes.max_value(a_arr.dtype))

nest.add_state().add_mapped_tasklet(
name="_arg{}_convert_".format(func),
map_ranges={'__i%d' % i: '0:%s' % n
for i, n in enumerate(a_arr.shape) if i != axis},
inputs={},
code=code,
outputs={
'__init': Memlet.simple(reduced_structs,
','.join('__i%d' % i for i in range(len(a_arr.shape)) if i != axis))
},
external_edges=True)
reduced_expr = ','.join('__i%d' % i for i in range(len(a_arr.shape)) if i != axis)
reduced_maprange = {'__i%d' % i: '0:%s' % n for i, n in enumerate(a_arr.shape) if i != axis}
if not reduced_expr:
reduced_expr = '0'
reduced_maprange = {'__i0': '0:1'}
nest.add_state().add_mapped_tasklet(name="_arg{}_convert_".format(func),
map_ranges=reduced_maprange,
inputs={},
code=code,
outputs={'__init': Memlet.simple(reduced_structs, reduced_expr)},
external_edges=True)

nest.add_state().add_mapped_tasklet(
name="_arg{}_reduce_".format(func),
Expand All @@ -1220,7 +1237,7 @@ def _argminmax(pv: ProgramVisitor,
outputs={
'__out':
Memlet.simple(reduced_structs,
','.join('__i%d' % i for i in range(len(a_arr.shape)) if i != axis),
reduced_expr,
wcr_str=("lambda x, y:"
"_val_and_idx(val={}(x.val, y.val), "
"idx=(y.idx if x.val {} y.val else x.idx))").format(
Expand All @@ -1234,16 +1251,14 @@ def _argminmax(pv: ProgramVisitor,

nest.add_state().add_mapped_tasklet(
name="_arg{}_extract_".format(func),
map_ranges={'__i%d' % i: '0:%s' % n
for i, n in enumerate(a_arr.shape) if i != axis},
map_ranges=reduced_maprange,
inputs={
'__in': Memlet.simple(reduced_structs,
','.join('__i%d' % i for i in range(len(a_arr.shape)) if i != axis))
'__in': Memlet.simple(reduced_structs, reduced_expr)
},
code="__out_val = __in.val\n__out_idx = __in.idx",
outputs={
'__out_val': Memlet.simple(outval, ','.join('__i%d' % i for i in range(len(a_arr.shape)) if i != axis)),
'__out_idx': Memlet.simple(outidx, ','.join('__i%d' % i for i in range(len(a_arr.shape)) if i != axis))
'__out_val': Memlet.simple(outval, reduced_expr),
'__out_idx': Memlet.simple(outidx, reduced_expr)
},
external_edges=True)

Expand Down Expand Up @@ -4557,25 +4572,40 @@ def _ndarray_min(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, arr: str, kwa
return implement_ufunc_reduce(pv, None, sdfg, state, 'minimum', [arr], kwargs)[0]


# TODO: It looks like `_argminmax` does not work with a flattened array.
# @oprepo.replaces_method('Array', 'argmax')
# @oprepo.replaces_method('Scalar', 'argmax')
# @oprepo.replaces_method('View', 'argmax')
# def _ndarray_argmax(pv: ProgramVisitor,
# sdfg: SDFG,
# state: SDFGState,
# arr: str,
# axis: int = None,
# out: str = None) -> str:
# if not axis:
# axis = 0
# arr = flat(pv, sdfg, state, arr)
# nest, newarr = _argmax(pv, sdfg, state, arr, axis)
# if out:
# r = state.add_read(arr)
# w = state.add_read(newarr)
# state.add_nedge(r, w, dace.Memlet.from_array(newarr, sdfg.arrays[newarr]))
# return new_arr
@oprepo.replaces_method('Array', 'argmax')
@oprepo.replaces_method('Scalar', 'argmax')
@oprepo.replaces_method('View', 'argmax')
def _ndarray_argmax(pv: ProgramVisitor,
sdfg: SDFG,
state: SDFGState,
arr: str,
axis: int = None,
out: str = None) -> str:
nest, newarr = _argmax(pv, sdfg, state, arr, axis)
if out:
r = state.add_read(newarr)
w = state.add_write(out)
state.add_nedge(r, w, dace.Memlet.from_array(newarr, sdfg.arrays[newarr]))
newarr = out
return newarr


@oprepo.replaces_method('Array', 'argmin')
@oprepo.replaces_method('Scalar', 'argmin')
@oprepo.replaces_method('View', 'argmin')
def _ndarray_argmin(pv: ProgramVisitor,
sdfg: SDFG,
state: SDFGState,
arr: str,
axis: int = None,
out: str = None) -> str:
nest, newarr = _argmin(pv, sdfg, state, arr, axis)
if out:
r = state.add_read(newarr)
w = state.add_write(out)
state.add_nedge(r, w, dace.Memlet.from_array(newarr, sdfg.arrays[newarr]))
newarr = out
return newarr


@oprepo.replaces_method('Array', 'conj')
Expand Down
2 changes: 1 addition & 1 deletion dace/runtime/include/dace/math.h
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ namespace dace
{
namespace math
{
static DACE_CONSTEXPR DACE_HostDev typeless_pi pi{};
static DACE_CONSTEXPR_HOSTDEV typeless_pi pi{};
static DACE_CONSTEXPR typeless_nan nan{};
//////////////////////////////////////////////////////
template<typename T>
Expand Down
14 changes: 8 additions & 6 deletions dace/runtime/include/dace/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,7 @@
#define DACE_HDFI __host__ __device__ __forceinline__
#define DACE_HFI __host__ __forceinline__
#define DACE_DFI __device__ __forceinline__
#define DACE_HostDev __host__ __device__
#define DACE_Host __host__
#define DACE_Dev __device__
#else
#define DACE_HostDev
#define DACE_Host
#define DACE_Dev
#define DACE_HDFI inline
#define DACE_HFI inline
#define DACE_DFI inline
Expand All @@ -67,6 +61,14 @@
#define __DACE_UNROLL
#endif

// If CUDA version is 11.4 or higher, __device__ variables can be declared as constexpr
#if defined(__CUDACC__) && (__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 4))
#define DACE_CONSTEXPR_HOSTDEV constexpr __host__ __device__
#elif defined(__CUDACC__) || defined(__HIPCC__)
#define DACE_CONSTEXPR_HOSTDEV const __host__ __device__
#else
#define DACE_CONSTEXPR_HOSTDEV const
#endif


namespace dace
Expand Down
2 changes: 1 addition & 1 deletion tests/inlining_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ def test(A: dace.float64[96, 32], B: dace.float64[42, 32]):
sdfg.expand_library_nodes()
sdfg.simplify()

state = sdfg.nodes()[1]
state = sdfg.sink_nodes()[0]
# find nested_sdfg
nsdfg = [n for n in state.nodes() if isinstance(n, dace.sdfg.nodes.NestedSDFG)][0]
# delete gemm initialization state
Expand Down
15 changes: 15 additions & 0 deletions tests/numpy/advanced_indexing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,20 @@ def indexing_test(A: dace.float64[5, 5, 5, 5, 5]):
assert np.allclose(A, regression)


def test_aug_implicit_attribute():

@dace.program
def indexing_test(A: dace.float64[5, 5, 5, 5, 5]):
A.flat[10:15][0:2] += 5

A = np.random.rand(5, 5, 5, 5, 5)
regression = np.copy(A)
# FIXME: NumPy does not support augmented assignment on a sub-iterator of a flat iterator
regression.flat[10:12] += 5
indexing_test(A)
assert np.allclose(A, regression)


def test_ellipsis_aug():

@dace.program
Expand Down Expand Up @@ -329,6 +343,7 @@ def indexing_test(A: dace.float64[N, N, N]):
test_flat_noncontiguous()
test_ellipsis()
test_aug_implicit()
test_aug_implicit_attribute()
test_ellipsis_aug()
test_newaxis()
test_multiple_newaxis()
Expand Down
28 changes: 23 additions & 5 deletions tests/numpy/ndarray_attributes_methods_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ def test_copy(A: dace.float32[M, N]):
return A.copy()


@compare_numpy_output()
def test_lhs_flat(A: dace.float64[M, N]):
A.flat[150] = 5
return A


@compare_numpy_output()
def test_astype(A: dace.int32[M, N]):
return A.astype(np.float32)
Expand Down Expand Up @@ -85,17 +91,26 @@ def test_max(A: dace.float32[M, N, N, M]):
return A.max()


# TODO: Need to debug `_argminmax`
# @compare_numpy_output()
# def test_argmax(A: dace.float32[M, N, N, M]):
# return A.argmax()
@compare_numpy_output()
def test_argmax(A: dace.float32[M, N, N, M]):
return A.argmax()


@compare_numpy_output()
def test_argmax_dim(A: dace.float32[M, N, N, M]):
return A.argmax(1)


@compare_numpy_output()
def test_min(A: dace.float32[M, N, N, M]):
return A.min()


@compare_numpy_output()
def test_argmin(A: dace.float32[M, N, N, M]):
return A.argmin()


@compare_numpy_output()
def test_conj(A: dace.complex64[M, N, N, M]):
return A.conj()
Expand Down Expand Up @@ -134,6 +149,7 @@ def test_any():
test_real()
test_imag()
test_copy()
test_lhs_flat()
test_astype()
test_fill()
test_fill2()
Expand All @@ -145,8 +161,10 @@ def test_any():
test_flatten()
test_ravel()
test_max()
# test_argmax()
test_argmax()
test_argmax_dim()
test_min()
test_argmin()
test_conj()
test_sum()
test_mean()
Expand Down
Loading