Fully implement NumPy advanced indexing for reads (#1837)
* Adds support for multi-dimensional integer arrays as indices
* Adds support for mixing advanced and basic indexing (see
https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing)
* Adds support for `newaxis` in conjunction with advanced indexing
* Fixes array indirection promotion for multidimensional slices with
offset dimensions
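
As a sketch of the patterns this enables (a hypothetical usage example; the program, array names, and shapes are illustrative and not taken from the changeset):

import numpy as np
import dace

@dace.program
def advanced_reads(A: dace.float64[20, 10], rows: dace.int64[5, 3], idx: dace.int64[5]):
    # Multi-dimensional integer array as an index: B has shape (5, 3, 10)
    B = A[rows]
    # Mixing advanced and basic indexing: the slice applies after the
    # advanced axes, so C has shape (5, 3, 5)
    C = A[rows, 2:7]
    # `newaxis` in conjunction with advanced indexing: D has shape (5, 1, 10)
    D = A[idx, np.newaxis, :]
    return B, C, D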
tbennun authored Jan 12, 2025
1 parent 3ad9f82 commit f94324a
Showing 11 changed files with 560 additions and 152 deletions.
19 changes: 12 additions & 7 deletions dace/frontend/python/memlet_parser.py
@@ -15,7 +15,6 @@

MemletType = Union[ast.Call, ast.Attribute, ast.Subscript, ast.Name]


if sys.version_info < (3, 8):
_simple_ast_nodes = (ast.Constant, ast.Name, ast.NameConstant, ast.Num)
BytesConstant = ast.Bytes
@@ -107,6 +106,11 @@ def _fill_missing_slices(das, ast_ndslice, array, indices):
idx = 0
new_idx = 0
has_ellipsis = False

# Count new axes
num_new_axes = sum(1 for dim in ast_ndslice
if (dim is None or (isinstance(dim, (ast.Constant, NameConstant)) and dim.value is None)))

for dim in ast_ndslice:
if isinstance(dim, (str, list, slice)):
dim = ast.Name(id=dim)
@@ -136,7 +140,7 @@ def _fill_missing_slices(das, ast_ndslice, array, indices):
if has_ellipsis:
raise IndexError('an index can only have a single ellipsis ("...")')
has_ellipsis = True
remaining_dims = len(ast_ndslice) - idx - 1
remaining_dims = len(ast_ndslice) - num_new_axes - idx - 1
for j in range(idx, len(ndslice) - remaining_dims):
ndslice[j] = (0, array.shape[j] - 1, 1)
idx += 1
@@ -170,7 +174,7 @@ def _fill_missing_slices(das, ast_ndslice, array, indices):
if desc.dtype == dtypes.bool:
# Boolean array indexing
if len(ast_ndslice) > 1:
raise IndexError(f'Invalid indexing into array "{dim.id}". ' 'Only one boolean array is allowed.')
raise IndexError(f'Invalid indexing into array "{dim.id}". Only one boolean array is allowed.')
if tuple(desc.shape) != tuple(array.shape):
raise IndexError(f'Invalid indexing into array "{dim.id}". '
'Shape of boolean index must match original array.')
@@ -251,9 +255,9 @@ def parse_memlet_subset(array: data.Data,
# Loop over the N dimensions
ndslice, offsets, new_extra_dims, arrdims = _fill_missing_slices(das, ast_ndslice, narray, offsets)
if new_extra_dims and idx != (len(ast_ndslices) - 1):
raise NotImplementedError('New axes only implemented for last ' 'slice')
raise NotImplementedError('New axes only implemented for last slice')
if arrdims and len(ast_ndslices) != 1:
raise NotImplementedError('Array dimensions not implemented ' 'for consecutive subscripts')
raise NotImplementedError('Array dimensions not implemented for consecutive subscripts')
extra_dims = new_extra_dims
subset_array.append(_ndslice_to_subset(ndslice))

@@ -305,8 +309,9 @@ def ParseMemlet(visitor,
try:
subset, new_axes, arrdims = parse_memlet_subset(array, node, das, parsed_slice)
except IndexError:
raise DaceSyntaxError(visitor, node, 'Failed to parse memlet expression due to dimensionality. '
f'Array dimensions: {array.shape}, expression in code: {astutils.unparse(node)}')
raise DaceSyntaxError(
visitor, node, 'Failed to parse memlet expression due to dimensionality. '
f'Array dimensions: {array.shape}, expression in code: {astutils.unparse(node)}')

# If undefined, default number of accesses is the slice size
if num_accesses is None:
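The `num_new_axes` adjustment above is needed because `None`/`newaxis` entries add output dimensions without consuming input dimensions, so they must be excluded when expanding an ellipsis. A plain NumPy illustration of that invariant (hypothetical example, not part of the changeset):

import numpy as np

A = np.zeros((4, 5, 6))
# '...' expands to the input dimensions not covered by other indices.
# The None (newaxis) entry consumes no input dimension, so '...' must
# still expand to two dimensions (5 and 6) here, not one.
B = A[1, None, ...]
assert B.shape == (1, 5, 6)
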
370 changes: 261 additions & 109 deletions dace/frontend/python/newast.py

Large diffs are not rendered by default.

24 changes: 12 additions & 12 deletions dace/frontend/python/replacements.py
@@ -664,7 +664,7 @@ def _linspace(pv: ProgramVisitor,
start_shape = sdfg.arrays[start].shape if (isinstance(start, str) and start in sdfg.arrays) else []
stop_shape = sdfg.arrays[stop].shape if (isinstance(stop, str) and stop in sdfg.arrays) else []

shape, ranges, outind, ind1, ind2 = _broadcast_together(start_shape, stop_shape)
shape, ranges, outind, ind1, ind2 = broadcast_together(start_shape, stop_shape)
shape_with_axis = _add_axis_to_shape(shape, axis, num)
ranges_with_axis = _add_axis_to_shape(ranges, axis, ('__sind', f'0:{symbolic.symstr(num)}'))
if outind:
@@ -1325,10 +1325,10 @@ def _array_array_where(visitor: ProgramVisitor,
right_shape = right_arr.shape if right_arr else [1]
cond_shape = cond_arr.shape if cond_arr else [1]

(out_shape, all_idx_dict, out_idx, left_idx, right_idx) = _broadcast_together(left_shape, right_shape)
(out_shape, all_idx_dict, out_idx, left_idx, right_idx) = broadcast_together(left_shape, right_shape)

# Broadcast condition with broadcasted left+right
_, _, _, cond_idx, _ = _broadcast_together(cond_shape, out_shape)
_, _, _, cond_idx, _ = broadcast_together(cond_shape, out_shape)

# Fix for Scalars
if isinstance(left_arr, data.Scalar):
@@ -1464,18 +1464,18 @@ def _unop(sdfg: SDFG, state: SDFGState, op1: str, opcode: str, opname: str):
return name


def _broadcast_to(target_shape, operand_shape):
def broadcast_to(target_shape, operand_shape):
# the difference to normal broadcasting is that the broadcasted shape is the same as the target
# I was unable to find documentation for this in numpy, so we follow the description from ONNX
results = _broadcast_together(target_shape, operand_shape, unidirectional=True)
results = broadcast_together(target_shape, operand_shape, unidirectional=True)

# the output_shape should be equal to the target_shape
assert all(i == o for i, o in zip(target_shape, results[0]))

return results


def _broadcast_together(arr1_shape, arr2_shape, unidirectional=False):
def broadcast_together(arr1_shape, arr2_shape, unidirectional=False):

all_idx_dict, all_idx, a1_idx, a2_idx = {}, [], [], []

@@ -1523,9 +1523,9 @@ def get_idx(i):
all_idx_dict[get_idx(i)] = dim1
else:
if unidirectional:
raise SyntaxError(f"could not broadcast input array from shape {arr2_shape} into shape {arr1_shape}")
raise IndexError(f"could not broadcast input array from shape {arr2_shape} into shape {arr1_shape}")
else:
raise SyntaxError("operands could not be broadcast together with shapes {}, {}".format(
raise IndexError("operands could not be broadcast together with shapes {}, {}".format(
arr1_shape, arr2_shape))

def to_string(idx):
@@ -1543,7 +1543,7 @@ def _binop(sdfg: SDFG, state: SDFGState, op1: str, op2: str, opcode: str, opname
arr1 = sdfg.arrays[op1]
arr2 = sdfg.arrays[op2]

out_shape, all_idx_tup, all_idx, arr1_idx, arr2_idx = _broadcast_together(arr1.shape, arr2.shape)
out_shape, all_idx_tup, all_idx, arr1_idx, arr2_idx = broadcast_together(arr1.shape, arr2.shape)

name, _ = sdfg.add_temp_transient(out_shape, restype, arr1.storage)
state.add_mapped_tasklet("_%s_" % opname,
@@ -1928,7 +1928,7 @@ def _array_array_binop(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, le
left_shape = left_arr.shape
right_shape = right_arr.shape

(out_shape, all_idx_dict, out_idx, left_idx, right_idx) = _broadcast_together(left_shape, right_shape)
(out_shape, all_idx_dict, out_idx, left_idx, right_idx) = broadcast_together(left_shape, right_shape)

# Fix for Scalars
if isinstance(left_arr, data.Scalar):
@@ -1996,7 +1996,7 @@ def _array_const_binop(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, le
if right_cast is not None:
tasklet_args[1] = "{c}({o})".format(c=str(right_cast).replace('::', '.'), o=tasklet_args[1])

(out_shape, all_idx_dict, out_idx, left_idx, right_idx) = _broadcast_together(left_shape, right_shape)
(out_shape, all_idx_dict, out_idx, left_idx, right_idx) = broadcast_together(left_shape, right_shape)

out_operand, out_arr = sdfg.add_temp_transient(out_shape, result_type, storage)

@@ -2066,7 +2066,7 @@ def _array_sym_binop(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, left
if right_cast is not None:
tasklet_args[1] = "{c}({o})".format(c=str(right_cast).replace('::', '.'), o=tasklet_args[1])

(out_shape, all_idx_dict, out_idx, left_idx, right_idx) = _broadcast_together(left_shape, right_shape)
(out_shape, all_idx_dict, out_idx, left_idx, right_idx) = broadcast_together(left_shape, right_shape)

out_operand, out_arr = sdfg.add_temp_transient(out_shape, result_type, storage)

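The `broadcast_to` helper renamed above (made public together with `broadcast_together`) performs unidirectional broadcasting: only the operand may be stretched along missing or size-1 dimensions to match the target, mirroring `numpy.broadcast_to`. A short NumPy sketch of the asymmetry (illustrative, not from the changeset):

import numpy as np

# The (3, 1) operand stretches to the (2, 3, 4) target:
assert np.broadcast_to(np.zeros((3, 1)), (2, 3, 4)).shape == (2, 3, 4)

# The reverse fails: the target shape is never widened to fit the operand.
try:
    np.broadcast_to(np.zeros((2, 3, 4)), (3, 1))
except ValueError:
    pass  # raised, as expected for unidirectional broadcasting
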
22 changes: 11 additions & 11 deletions dace/sdfg/sdfg.py
@@ -175,8 +175,7 @@ class InterstateEdge(object):
loop iterates).
"""

assignments = Property(dtype=dict,
desc="Assignments to perform upon transition (e.g., 'x=x+1; y = 0')")
assignments = Property(dtype=dict, desc="Assignments to perform upon transition (e.g., 'x=x+1; y = 0')")
condition = CodeProperty(desc="Transition condition", default=CodeBlock("1"))
guid = Property(dtype=str, allow_none=False)

@@ -214,7 +213,7 @@ def __deepcopy__(self, memo):
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
if k == 'guid': # Skip ID
if k == 'guid': # Skip ID
continue
setattr(result, k, copy.deepcopy(v, memo))
return result
@@ -416,7 +415,11 @@ class SDFG(ControlFlowRegion):

name = Property(dtype=str, desc="Name of the SDFG")
arg_names = ListProperty(element_type=str, desc='Ordered argument names (used for calling conventions).')
constants_prop = Property(dtype=dict, default={}, desc="Compile-time constants")
constants_prop: Dict[str, Tuple[dt.Data, Any]] = Property(
dtype=dict,
default={},
desc='Compile-time constants. The dictionary maps between a constant name to '
'a tuple of its type and the actual constant data.')
_arrays = Property(dtype=NestedDict,
desc="Data descriptors for this SDFG",
to_json=_arrays_to_json,
@@ -463,7 +466,8 @@ class SDFG(ControlFlowRegion):
desc='Mapping between callback name and its original callback '
'(for when the same callback is used with a different signature)')

using_explicit_control_flow = Property(dtype=bool, default=False,
using_explicit_control_flow = Property(dtype=bool,
default=False,
desc="Whether the SDFG contains explicit control flow constructs")

def __init__(self,
@@ -612,9 +616,7 @@ def from_json(cls, json_obj, context=None):

ret = SDFG(name=attrs['name'], constants=constants_prop, parent=context['sdfg'])

dace.serialize.set_properties_from_json(ret,
json_obj,
ignore_properties={'constants_prop', 'name', 'hash'})
dace.serialize.set_properties_from_json(ret, json_obj, ignore_properties={'constants_prop', 'name', 'hash'})

nodelist = []
for n in nodes:
@@ -742,7 +744,6 @@ def replace_dict(self,
if symrepl:
symrepl = {k: v for k, v in symrepl.items() if str(k) != str(v)}


symrepl = symrepl or {
symbolic.pystr_to_symbolic(k): symbolic.pystr_to_symbolic(v) if isinstance(k, str) else v
for k, v in repldict.items()
@@ -2318,8 +2319,7 @@ def is_loaded(self) -> bool:
dll = cs.ReloadableDLL(binary_filename, self.name)
return dll.is_loaded()

def compile(self, output_file=None, validate=True,
return_program_handle=True) -> 'CompiledSDFG':
def compile(self, output_file=None, validate=True, return_program_handle=True) -> 'CompiledSDFG':
""" Compiles a runnable binary from this SDFG.
:param output_file: If not None, copies the output library file to
4 changes: 2 additions & 2 deletions dace/sdfg/utils.py
@@ -1609,9 +1609,9 @@ def traverse_sdfg_with_defined_symbols(
:return: A generator that yields tuples of (state, node in state, currently-defined symbols)
"""
# Start with global symbols
# Start with global symbols and scalar constants
symbols = copy.copy(sdfg.symbols)
symbols.update({k: dt.create_datadescriptor(v).dtype for k, v in sdfg.constants.items()})
symbols.update({k: desc.dtype for k, (desc, _) in sdfg.constants_prop.items() if isinstance(desc, dt.Scalar)})
for desc in sdfg.arrays.values():
symbols.update({str(s): s.dtype for s in desc.free_symbols})

2 changes: 1 addition & 1 deletion dace/sdfg/validation.py
@@ -257,7 +257,7 @@ def validate_sdfg(sdfg: 'dace.sdfg.SDFG', references: Set[int] = None, **context
# Ensure that there is a mentioning of constants in either the array or symbol.
for const_name, (const_type, _) in sdfg.constants_prop.items():
if const_name in sdfg.arrays:
if const_type != sdfg.arrays[const_name].dtype:
if const_type.dtype != sdfg.arrays[const_name].dtype:
# This should actually be an error, but there is a lot of code that depends on it.
warnings.warn(f'Mismatch between constant and data descriptor of "{const_name}", '
f'expected to find "{const_type}" but found "{sdfg.arrays[const_name]}".')
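The fix above follows from the `constants_prop` layout documented in the `SDFG` property change: each entry maps a constant name to a tuple of its data descriptor and value, so dtype comparisons must go through the descriptor. A minimal sketch (assuming the standard `add_constant` API; illustrative only):

import numpy as np
import dace
from dace import data as dt

sdfg = dace.SDFG('constants_example')
sdfg.add_constant('cst', np.float32(42))

# Each constants_prop entry is a (data descriptor, value) tuple, so type
# checks use the descriptor's dtype rather than the tuple itself.
desc, value = sdfg.constants_prop['cst']
assert isinstance(desc, dt.Scalar) and desc.dtype == dace.float32
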
27 changes: 22 additions & 5 deletions dace/transformation/passes/scalar_to_symbol.py
@@ -331,6 +331,25 @@ def __init__(self, in_edges: Dict[str, mm.Memlet], out_edges: Dict[str, mm.Memle
self.out_mapping: Dict[str, Tuple[str, subsets.Range]] = {}
self.do_not_remove: Set[str] = set()

def _get_requested_range(self, node: ast.Subscript, memlet_subset: subsets.Subset) -> subsets.Subset:
"""
Returns the requested range from a subscript node, which consists of the memlet subset composed with the
tasklet subset.
:param node: The subscript node.
:param memlet_subset: The memlet subset.
:return: The requested range.
"""
arrname, tasklet_slice = astutils.subscript_to_ast_slice(node)
arrname = arrname if arrname in self.arrays else None
if len(tasklet_slice) < len(memlet_subset):
# Unsqueeze all index dimensions from orig_subset into tasklet_subset
for i, (start, end, _) in reversed(list(enumerate(memlet_subset.ndrange()))):
if start == end:
tasklet_slice.insert(i, (None, None, None))
tasklet_subset = subsets.Range(astutils.astrange_to_symrange(tasklet_slice, self.arrays, arrname))
return memlet_subset.compose(tasklet_subset)

def visit_Subscript(self, node: ast.Subscript) -> Any:
# Convert subscript to symbol name
node = self.generic_visit(node)
@@ -339,8 +358,7 @@ def visit_Subscript(self, node: ast.Subscript) -> Any:
new_name = dt.find_new_name(node_name, self.connector_names)
self.connector_names.add(new_name)

orig_subset = self.in_edges[node_name].subset
subset = orig_subset.compose(subsets.Range(astutils.subscript_to_slice(node, self.arrays)[1]))
subset = self._get_requested_range(node, self.in_edges[node_name].subset)
# Check if range can be collapsed
if _range_is_promotable(subset, self.defined):
self.in_mapping[new_name] = (node_name, subset)
@@ -351,8 +369,7 @@ def visit_Subscript(self, node: ast.Subscript) -> Any:
new_name = dt.find_new_name(node_name, self.connector_names)
self.connector_names.add(new_name)

orig_subset = self.out_edges[node_name].subset
subset = orig_subset.compose(subsets.Range(astutils.subscript_to_slice(node, self.arrays)[1]))
subset = self._get_requested_range(node, self.out_edges[node_name].subset)
# Check if range can be collapsed
if _range_is_promotable(subset, self.defined):
self.out_mapping[new_name] = (node_name, subset)
@@ -750,4 +767,4 @@ def apply_pass(self, sdfg: SDFG, _: Dict[Any, Any]) -> Set[str]:
return to_promote or None

def report(self, pass_retval: Set[str]) -> str:
return f'Promoted {len(pass_retval)} scalars to symbols.'
return f'Promoted {len(pass_retval)} scalars to symbols: {pass_retval}'
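
The `_get_requested_range` helper added above composes the tasklet's subscript with the memlet subset, first re-inserting ("unsqueezing") index dimensions that the memlet has already collapsed. A rough NumPy analogy of that composition (illustrative only; the pass itself operates on symbolic `subsets.Range` objects):

import numpy as np

A = np.arange(100).reshape(10, 10)
window = A[5, 0:10]   # memlet subset: the first dimension is a plain index
elem = window[3]      # tasklet subscript: a 1-D access into the window
# The requested range in terms of A is A[5, 3]; computing it requires
# re-inserting the collapsed index dimension before composing the subsets.
assert elem == A[5, 3]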
@@ -55,7 +55,7 @@ def apply(self, region: ControlFlowRegion, _) -> Optional[int]:
region.remove_branch(branch)
removed_branches += 1
# If the else branch remains, make sure it now has the new negate-all condition.
if new_else_cond is not None and region.branches[-1][0] is None:
if region.branches and new_else_cond is not None and region.branches[-1][0] is None:
region._branches[-1] = (new_else_cond, region._branches[-1][1])

if len(region.branches) == 0: