Skip to content

Commit

Permalink
Merge pull request #127 from ev-br/1D_array_indices
Browse files Browse the repository at this point in the history
fancy indexing with ints and integer arrays
  • Loading branch information
ev-br authored Feb 22, 2025
2 parents 5ccb0e7 + 6664e6d commit 4b8fbef
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 35 deletions.
30 changes: 20 additions & 10 deletions array_api_strict/_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ def _normalize_two_args(x1, x2) -> Tuple[Array, Array]:

# Note: A large fraction of allowed indices are disallowed here (see the
# docstring below)
def _validate_index(self, key):
def _validate_index(self, key, op="getitem"):
"""
Validate an index according to the array API.
Expand Down Expand Up @@ -390,11 +390,16 @@ def _validate_index(self, key):
"zero-dimensional integer arrays and boolean arrays "
"are specified in the Array API."
)
if op == "setitem":
if isinstance(i, Array) and i.dtype in _integer_dtypes:
raise IndexError("Fancy indexing __setitem__ is not supported.")

nonexpanding_key = []
single_axes = []
n_ellipsis = 0
key_has_mask = False
key_has_index_array = False
key_has_slices = False
for i in _key:
if i is not None:
nonexpanding_key.append(i)
Expand All @@ -403,13 +408,17 @@ def _validate_index(self, key):
if isinstance(i, Array):
if i.dtype in _boolean_dtypes:
key_has_mask = True
elif i.dtype in _integer_dtypes:
key_has_index_array = True
single_axes.append(i)
else:
# i must not be an array here, to avoid elementwise equals
if i == Ellipsis:
n_ellipsis += 1
else:
single_axes.append(i)
if isinstance(i, slice):
key_has_slices = True

n_single_axes = len(single_axes)
if n_ellipsis > 1:
Expand All @@ -427,6 +436,12 @@ def _validate_index(self, key):
"specified in the Array API."
)

if (key_has_index_array and (n_ellipsis > 0 or key_has_slices or key_has_mask)):
raise IndexError(
"Integer index arrays are only allowed with integer indices; "
f"got {key}."
)

if n_ellipsis == 0:
indexed_shape = self.shape
else:
Expand Down Expand Up @@ -483,14 +498,9 @@ def _validate_index(self, key):
"Array API when the array is the sole index."
)
if not get_array_api_strict_flags()['boolean_indexing']:
raise RuntimeError("The boolean_indexing flag has been disabled for array-api-strict")

elif i.dtype in _integer_dtypes and i.ndim != 0:
raise IndexError(
f"Single-axes index {i} is a non-zero-dimensional "
"integer array, but advanced integer indexing is not "
"specified in the Array API."
)
raise RuntimeError(
"The boolean_indexing flag has been disabled for array-api-strict"
)
elif isinstance(i, tuple):
raise IndexError(
f"Single-axes index {i} is a tuple, but nested tuple "
Expand Down Expand Up @@ -902,7 +912,7 @@ def __setitem__(
"""
# Note: Only indices required by the spec are allowed. See the
# docstring of _validate_index
self._validate_index(key)
self._validate_index(key, op="setitem")
if isinstance(key, Array):
# Indexing self._array with array_api_strict arrays can be erroneous
key = key._array
Expand Down
100 changes: 75 additions & 25 deletions array_api_strict/tests/test_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
import pytest

from .. import ones, asarray, result_type, all, equal
from .. import ones, arange, reshape, asarray, result_type, all, equal
from .._array_object import Array, CPU_DEVICE, Device
from .._dtypes import (
_all_dtypes,
Expand Down Expand Up @@ -45,35 +45,46 @@ def test_validate_index():
a = ones((3, 4))

# Out of bounds slices are not allowed
assert_raises(IndexError, lambda: a[:4])
assert_raises(IndexError, lambda: a[:-4])
assert_raises(IndexError, lambda: a[:3:-1])
assert_raises(IndexError, lambda: a[:-5:-1])
assert_raises(IndexError, lambda: a[4:])
assert_raises(IndexError, lambda: a[-4:])
assert_raises(IndexError, lambda: a[4::-1])
assert_raises(IndexError, lambda: a[-4::-1])

assert_raises(IndexError, lambda: a[...,:5])
assert_raises(IndexError, lambda: a[...,:-5])
assert_raises(IndexError, lambda: a[...,:5:-1])
assert_raises(IndexError, lambda: a[...,:-6:-1])
assert_raises(IndexError, lambda: a[...,5:])
assert_raises(IndexError, lambda: a[...,-5:])
assert_raises(IndexError, lambda: a[...,5::-1])
assert_raises(IndexError, lambda: a[...,-5::-1])
assert_raises(IndexError, lambda: a[:4, 0])
assert_raises(IndexError, lambda: a[:-4, 0])
assert_raises(IndexError, lambda: a[:3:-1]) # XXX raises for a wrong reason
assert_raises(IndexError, lambda: a[:-5:-1, 0])
assert_raises(IndexError, lambda: a[4:, 0])
assert_raises(IndexError, lambda: a[-4:, 0])
assert_raises(IndexError, lambda: a[4::-1, 0])
assert_raises(IndexError, lambda: a[-4::-1, 0])

assert_raises(IndexError, lambda: a[..., :5])
assert_raises(IndexError, lambda: a[..., :-5])
assert_raises(IndexError, lambda: a[..., :5:-1])
assert_raises(IndexError, lambda: a[..., :-6:-1])
assert_raises(IndexError, lambda: a[..., 5:])
assert_raises(IndexError, lambda: a[..., -5:])
assert_raises(IndexError, lambda: a[..., 5::-1])
assert_raises(IndexError, lambda: a[..., -5::-1])

# Boolean indices cannot be part of a larger tuple index
assert_raises(IndexError, lambda: a[a[:,0]==1,0])
assert_raises(IndexError, lambda: a[a[:,0]==1,...])
assert_raises(IndexError, lambda: a[..., a[0]==1])
assert_raises(IndexError, lambda: a[a[:, 0] == 1, 0])
assert_raises(IndexError, lambda: a[a[:, 0] == 1, ...])
assert_raises(IndexError, lambda: a[..., a[0] == 1])
assert_raises(IndexError, lambda: a[[True, True, True]])
assert_raises(IndexError, lambda: a[(True, True, True),])

# Integer array indices are not allowed (except for 0-D)
idx = asarray([[0, 1]])
assert_raises(IndexError, lambda: a[idx])
assert_raises(IndexError, lambda: a[idx,])
# Mixing 1D integer array indices with slices, ellipsis or booleans is not allowed
idx = asarray([0, 1])
assert_raises(IndexError, lambda: a[..., idx])
assert_raises(IndexError, lambda: a[:, idx])
assert_raises(IndexError, lambda: a[asarray([True, True]), idx])

# 1D integer array indices must have the same length
idx1 = asarray([0, 1])
idx2 = asarray([0, 1, 1])
assert_raises(IndexError, lambda: a[idx1, idx2])

# Non-integer array indices are not allowed
assert_raises(IndexError, lambda: a[ones(2), 0])

# Array-likes (lists, tuples) are not allowed as indices
assert_raises(IndexError, lambda: a[[0, 1]])
assert_raises(IndexError, lambda: a[(0, 1), (0, 1)])
assert_raises(IndexError, lambda: a[[0, 1]])
Expand All @@ -87,6 +98,45 @@ def test_validate_index():
assert_raises(IndexError, lambda: a[0,])
assert_raises(IndexError, lambda: a[0])
assert_raises(IndexError, lambda: a[:])
assert_raises(IndexError, lambda: a[idx])


def test_indexing_arrays():
# indexing with 1D integer arrays and mixes of integers and 1D integer are allowed

# 1D array
a = arange(5)
idx = asarray([1, 0, 1, 2, -1])
a_idx = a[idx]

a_idx_loop = asarray([a[idx[i]] for i in range(idx.shape[0])])
assert all(a_idx == a_idx_loop)

# setitem with arrays is not allowed
with assert_raises(IndexError):
a[idx] = 42

# mixed array and integer indexing
a = reshape(arange(3*4), (3, 4))
idx = asarray([1, 0, 1, 2, -1])
a_idx = a[idx, 1]

a_idx_loop = asarray([a[idx[i], 1] for i in range(idx.shape[0])])
assert all(a_idx == a_idx_loop)

# index with two arrays
a_idx = a[idx, idx]
a_idx_loop = asarray([a[idx[i], idx[i]] for i in range(idx.shape[0])])
assert all(a_idx == a_idx_loop)

# setitem with arrays is not allowed
with assert_raises(IndexError):
a[idx, idx] = 42

# smoke test indexing with ndim > 1 arrays
idx = idx[..., None]
a[idx, idx]


def test_promoted_scalar_inherits_device():
device1 = Device("device1")
Expand Down

0 comments on commit 4b8fbef

Please sign in to comment.