From c13a36269e8753bcca880534e565581e74c60e2a Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Fri, 10 Jan 2025 08:25:33 +0100 Subject: [PATCH] Implement replacement for numpy's select statement (#1862) Co-authored-by: Tal Ben-Nun --- dace/frontend/python/replacements.py | 82 +++++++++++++++++++++++++--- tests/numpy/searching_test.py | 15 ++++- 2 files changed, 87 insertions(+), 10 deletions(-) diff --git a/dace/frontend/python/replacements.py b/dace/frontend/python/replacements.py index 406b120567..4fc56559c8 100644 --- a/dace/frontend/python/replacements.py +++ b/dace/frontend/python/replacements.py @@ -1289,7 +1289,10 @@ def _array_array_where(visitor: ProgramVisitor, state: SDFGState, cond_operand: str, left_operand: str = None, - right_operand: str = None): + right_operand: str = None, + generated_nodes: Optional[Set[nodes.Node]] = None, + left_operand_node: Optional[nodes.AccessNode] = None, + right_operand_node: Optional[nodes.AccessNode] = None): if left_operand is None or right_operand is None: raise ValueError('numpy.where is only supported for the case where x and y are given') @@ -1346,12 +1349,26 @@ def _array_array_where(visitor: ProgramVisitor, '__out = {i1} if __incond else {i2}'.format(i1=tasklet_args[1], i2=tasklet_args[2])) n0 = state.add_read(cond_operand) n3 = state.add_write(out_operand) + if generated_nodes is not None: + generated_nodes.add(tasklet) + generated_nodes.add(n0) + generated_nodes.add(n3) state.add_edge(n0, None, tasklet, '__incond', dace.Memlet.from_array(cond_operand, cond_arr)) if left_arr: - n1 = state.add_read(left_operand) + if left_operand_node: + n1 = left_operand_node + else: + n1 = state.add_read(left_operand) + if generated_nodes is not None: + generated_nodes.add(n1) state.add_edge(n1, None, tasklet, '__in1', dace.Memlet.from_array(left_operand, left_arr)) if right_arr: - n2 = state.add_read(right_operand) + if right_operand_node: + n2 = right_operand_node + else: + n2 = state.add_read(right_operand) + if generated_nodes is not None: + generated_nodes.add(n2) state.add_edge(n2, None, tasklet, '__in2', dace.Memlet.from_array(right_operand, right_arr)) state.add_edge(tasklet, '__out', n3, None, dace.Memlet.from_array(out_operand, out_arr)) else: @@ -1361,12 +1378,59 @@ def _array_array_where(visitor: ProgramVisitor, inputs['__in1'] = Memlet.simple(left_operand, left_idx) if right_arr: inputs['__in2'] = Memlet.simple(right_operand, right_idx) - state.add_mapped_tasklet("_where_", - all_idx_dict, - inputs, - '__out = {i1} if __incond else {i2}'.format(i1=tasklet_args[1], i2=tasklet_args[2]), - {'__out': Memlet.simple(out_operand, out_idx)}, - external_edges=True) + input_nodes = {} + if left_operand_node: + input_nodes[left_operand] = left_operand_node + if right_operand_node: + input_nodes[right_operand] = right_operand_node + tasklet, me, mx = state.add_mapped_tasklet("_where_", all_idx_dict, inputs, + '__out = {i1} if __incond else {i2}'.format(i1=tasklet_args[1], + i2=tasklet_args[2]), + {'__out': Memlet.simple(out_operand, out_idx)}, external_edges=True, + input_nodes=input_nodes) + if generated_nodes is not None: + generated_nodes.add(tasklet) + generated_nodes.add(me) + for ie in state.in_edges(me): + if ie.src is not left_operand_node and ie.src is not right_operand_node: + generated_nodes.add(ie.src) + generated_nodes.add(mx) + for oe in state.out_edges(mx): + generated_nodes.add(oe.dst) + + return out_operand + + +@oprepo.replaces('numpy.select') +def _array_array_select(visitor: ProgramVisitor, + sdfg: SDFG, + state: SDFGState, + cond_list: List[str], + choice_list: List[str], + default = None): + if len(cond_list) != len(choice_list): + raise ValueError('numpy.select is only valid with same-length condition and choice lists') + + default_operand = default if default is not None else 0 + + i = len(cond_list) - 1 + cond_operand = cond_list[i] + left_operand = choice_list[i] + right_operand = default_operand + right_operand_node = None + out_operand = None + while i >= 0: + generated_nodes = set() + out_operand = _array_array_where(visitor, sdfg, state, cond_operand, left_operand, right_operand, + generated_nodes=generated_nodes, right_operand_node=right_operand_node) + i -= 1 + cond_operand = cond_list[i] + left_operand = choice_list[i] + right_operand = out_operand + right_operand_node = None + for nd in generated_nodes: + if isinstance(nd, nodes.AccessNode) and nd.data == out_operand: + right_operand_node = nd return out_operand diff --git a/tests/numpy/searching_test.py b/tests/numpy/searching_test.py index 631d4dba32..d89d41486b 100644 --- a/tests/numpy/searching_test.py +++ b/tests/numpy/searching_test.py @@ -1,4 +1,4 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. import dace import numpy as np @@ -15,5 +15,18 @@ def numpy_where(A: dace.float64[N]): assert (np.allclose(numpy_where(A), np.where(A > 0.5, A, 0.0))) +def test_numpy_select(): + @dace.program + def numpy_where(A: dace.float64[N], B: dace.float64[N], C: dace.float64[N]): + return np.select([A > 0.5, B > 0.5, C > 0.5], [A, B, C], 0.0) + + for _ in range(10): + A = np.random.randn(N) + B = np.random.randn(N) + C = np.random.randn(N) + assert (np.allclose(numpy_where(A, B, C), np.select([A > 0.5, B > 0.5, C > 0.5], [A, B, C], 0.0))) + + if __name__ == "__main__": test_numpy_where() + test_numpy_select()