Skip to content

Commit

Permalink
Implement replacement for numpy's select statement (spcl#1862)
Browse files Browse the repository at this point in the history
Co-authored-by: Tal Ben-Nun <tbennun@users.noreply.github.com>
  • Loading branch information
phschaad and tbennun authored Jan 10, 2025
1 parent e606ca0 commit c13a362
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 10 deletions.
82 changes: 73 additions & 9 deletions dace/frontend/python/replacements.py
Original file line number Diff line number Diff line change
Expand Up @@ -1289,7 +1289,10 @@ def _array_array_where(visitor: ProgramVisitor,
state: SDFGState,
cond_operand: str,
left_operand: str = None,
right_operand: str = None):
right_operand: str = None,
generated_nodes: Optional[Set[nodes.Node]] = None,
left_operand_node: Optional[nodes.AccessNode] = None,
right_operand_node: Optional[nodes.AccessNode] = None):
if left_operand is None or right_operand is None:
raise ValueError('numpy.where is only supported for the case where x and y are given')

Expand Down Expand Up @@ -1346,12 +1349,26 @@ def _array_array_where(visitor: ProgramVisitor,
'__out = {i1} if __incond else {i2}'.format(i1=tasklet_args[1], i2=tasklet_args[2]))
n0 = state.add_read(cond_operand)
n3 = state.add_write(out_operand)
if generated_nodes is not None:
generated_nodes.add(tasklet)
generated_nodes.add(n0)
generated_nodes.add(n3)
state.add_edge(n0, None, tasklet, '__incond', dace.Memlet.from_array(cond_operand, cond_arr))
if left_arr:
n1 = state.add_read(left_operand)
if left_operand_node:
n1 = left_operand_node
else:
n1 = state.add_read(left_operand)
if generated_nodes is not None:
generated_nodes.add(n1)
state.add_edge(n1, None, tasklet, '__in1', dace.Memlet.from_array(left_operand, left_arr))
if right_arr:
n2 = state.add_read(right_operand)
if right_operand_node:
n2 = right_operand_node
else:
n2 = state.add_read(right_operand)
if generated_nodes is not None:
generated_nodes.add(n2)
state.add_edge(n2, None, tasklet, '__in2', dace.Memlet.from_array(right_operand, right_arr))
state.add_edge(tasklet, '__out', n3, None, dace.Memlet.from_array(out_operand, out_arr))
else:
Expand All @@ -1361,12 +1378,59 @@ def _array_array_where(visitor: ProgramVisitor,
inputs['__in1'] = Memlet.simple(left_operand, left_idx)
if right_arr:
inputs['__in2'] = Memlet.simple(right_operand, right_idx)
state.add_mapped_tasklet("_where_",
all_idx_dict,
inputs,
'__out = {i1} if __incond else {i2}'.format(i1=tasklet_args[1], i2=tasklet_args[2]),
{'__out': Memlet.simple(out_operand, out_idx)},
external_edges=True)
input_nodes = {}
if left_operand_node:
input_nodes[left_operand] = left_operand_node
if right_operand_node:
input_nodes[right_operand] = right_operand_node
tasklet, me, mx = state.add_mapped_tasklet("_where_", all_idx_dict, inputs,
'__out = {i1} if __incond else {i2}'.format(i1=tasklet_args[1],
i2=tasklet_args[2]),
{'__out': Memlet.simple(out_operand, out_idx)}, external_edges=True,
input_nodes=input_nodes)
if generated_nodes is not None:
generated_nodes.add(tasklet)
generated_nodes.add(me)
for ie in state.in_edges(me):
if ie.src is not left_operand_node and ie.src is not right_operand_node:
generated_nodes.add(ie.src)
generated_nodes.add(mx)
for oe in state.out_edges(mx):
generated_nodes.add(oe.dst)

return out_operand


@oprepo.replaces('numpy.select')
def _array_array_select(visitor: ProgramVisitor,
sdfg: SDFG,
state: SDFGState,
cond_list: List[str],
choice_list: List[str],
default = None):
if len(cond_list) != len(choice_list):
raise ValueError('numpy.select is only valid with same-length condition and choice lists')

default_operand = default if default is not None else 0

i = len(cond_list) - 1
cond_operand = cond_list[i]
left_operand = choice_list[i]
right_operand = default_operand
right_operand_node = None
out_operand = None
while i >= 0:
generated_nodes = set()
out_operand = _array_array_where(visitor, sdfg, state, cond_operand, left_operand, right_operand,
generated_nodes=generated_nodes, right_operand_node=right_operand_node)
i -= 1
cond_operand = cond_list[i]
left_operand = choice_list[i]
right_operand = out_operand
right_operand_node = None
for nd in generated_nodes:
if isinstance(nd, nodes.AccessNode) and nd.data == out_operand:
right_operand_node = nd

return out_operand

Expand Down
15 changes: 14 additions & 1 deletion tests/numpy/searching_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.
import dace
import numpy as np

Expand All @@ -15,5 +15,18 @@ def numpy_where(A: dace.float64[N]):
assert (np.allclose(numpy_where(A), np.where(A > 0.5, A, 0.0)))


def test_numpy_select():
@dace.program
def numpy_where(A: dace.float64[N], B: dace.float64[N], C: dace.float64[N]):
return np.select([A > 0.5, B > 0.5, C > 0.5], [A, B, C], 0.0)

for _ in range(10):
A = np.random.randn(N)
B = np.random.randn(N)
C = np.random.randn(N)
assert (np.allclose(numpy_where(A, B, C), np.select([A > 0.5, B > 0.5, C > 0.5], [A, B, C], 0.0)))


if __name__ == "__main__":
test_numpy_where()
test_numpy_select()

0 comments on commit c13a362

Please sign in to comment.