Skip to content

Commit

Permalink
fix: indexing into RegularArray with typetracer (#2227)
Browse files Browse the repository at this point in the history
  • Loading branch information
agoose77 authored Feb 9, 2023
1 parent e8600ca commit 25955f4
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 37 deletions.
32 changes: 25 additions & 7 deletions src/awkward/_nplikes/array_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import numpy

import awkward as ak
from awkward._errors import wrap_error
from awkward._nplikes.numpylike import ArrayLike, IndexType, NumpyLike, NumpyMetadata
from awkward._nplikes.shape import ShapeItem, unknown_length
from awkward.typing import Final, Literal
Expand All @@ -31,7 +31,7 @@ def asarray(
return self._module.asarray(obj, dtype=dtype)
else:
if getattr(obj, "dtype", dtype) != dtype:
raise ak._errors.wrap_error(
raise wrap_error(
ValueError(
"asarray was called with copy=False for an array of a different dtype"
)
Expand Down Expand Up @@ -121,7 +121,7 @@ def reshape(
self, x: ArrayLike, shape: tuple[int, ...], *, copy: bool | None = None
) -> ArrayLike:
if copy is False:
raise ak._errors.wrap_error(
raise wrap_error(
NotImplementedError(
"reshape was called with copy=False, which is currently not supported"
)
Expand All @@ -134,15 +134,13 @@ def reshape(

def shape_item_as_index(self, x1: ShapeItem) -> int:
if x1 is unknown_length:
raise ak._errors.wrap_error(
raise wrap_error(
TypeError("array module nplikes do not support unknown lengths")
)
elif isinstance(x1, int):
return x1
else:
raise ak._errors.wrap_error(
TypeError(f"expected None or int type, received {x1}")
)
raise wrap_error(TypeError(f"expected None or int type, received {x1}"))

def index_as_shape_item(self, x1: IndexType) -> int:
return int(x1)
Expand All @@ -163,6 +161,26 @@ def derive_slice_for_length(
slice_length = math.ceil((stop - start) / step)
return start, stop, step, slice_length

def regularize_index_for_length(
self, index: IndexType, length: ShapeItem
) -> IndexType:
"""
Args:
index: index value
length: length of array
Returns regularized index that is guaranteed to be in-bounds.
""" # We have known length and index
if index < 0:
index = index + length

if 0 <= index < length:
return index
else:
raise wrap_error(
IndexError(f"index value out of bounds (0, {length}): {index}")
)

def nonzero(self, x: ArrayLike) -> tuple[ArrayLike, ...]:
return self._module.nonzero(x)

Expand Down
6 changes: 6 additions & 0 deletions src/awkward/_nplikes/numpylike.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,12 @@ def derive_slice_for_length(
) -> tuple[IndexType, IndexType, IndexType, ShapeItem]:
...

@abstractmethod
def regularize_index_for_length(
self, index: IndexType, length: ShapeItem
) -> IndexType:
...

@abstractmethod
def reshape(
self, x: ArrayLike, shape: tuple[int, ...], *, copy: bool | None = None
Expand Down
30 changes: 30 additions & 0 deletions src/awkward/_nplikes/typetracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -808,6 +808,36 @@ def index_as_shape_item(self, x1: IndexType) -> ShapeItem:
else:
return int(x1)

def regularize_index_for_length(
self, index: IndexType, length: ShapeItem
) -> IndexType:
"""
Args:
index: index value
length: length of array
Returns regularized index that is guaranteed to be in-bounds.
"""
# Unknown indices are already regularized
if is_unknown_scalar(index):
return index

# Without a known length the result must be unknown, as we cannot regularize the index
length_scalar = self.shape_item_as_index(length)
if length is unknown_length:
return length_scalar

# We have known length and index
if index < 0:
index = index + length

if 0 <= index < length:
return index
else:
raise wrap_error(
IndexError(f"index value out of bounds (0, {length}): {index}")
)

def derive_slice_for_length(
self, slice_: slice, length: ShapeItem
) -> tuple[IndexType, IndexType, IndexType, ShapeItem]:
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/_slicing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from awkward.typing import TYPE_CHECKING, Sequence, TypeAlias

if TYPE_CHECKING:
from awkward._nplikes.numpylike import ArrayLike # noqa: F401
from awkward._nplikes.numpylike import ArrayLike
from awkward.contents.content import Content

np = NumpyMetadata.instance()
Expand Down
32 changes: 3 additions & 29 deletions src/awkward/contents/regulararray.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from awkward._nplikes.numpy import Numpy
from awkward._nplikes.numpylike import IndexType, NumpyMetadata
from awkward._nplikes.shape import unknown_length
from awkward._nplikes.typetracer import is_unknown_scalar
from awkward._util import unset
from awkward.contents.content import Content
from awkward.forms.form import _type_parameters_equal
Expand Down Expand Up @@ -291,14 +290,9 @@ def _getitem_nothing(self):

def _getitem_at(self, where: IndexType):
index_nplike = self._backend.index_nplike
if index_nplike.known_data and where < 0:
where += self._length

if not (self._length is unknown_length or 0 <= where < self._length):
raise ak._errors.index_error(self, where)
start, stop = where * index_nplike.shape_item_as_index(self._size), (
where + 1
) * index_nplike.shape_item_as_index(self._size)
where = index_nplike.regularize_index_for_length(where, self._length)
size_scalar = index_nplike.shape_item_as_index(self._size)
start, stop = where * size_scalar, (where + 1) * size_scalar
return self._content._getitem_range(start, stop)

def _getitem_range(self, start: SupportsIndex, stop: IndexType) -> Content:
Expand Down Expand Up @@ -501,26 +495,6 @@ def _getitem_next(
head, length=self._size
)

if (
is_unknown_scalar(start)
or is_unknown_scalar(stop)
or is_unknown_scalar(step)
):
nextsize = unknown_length
else:
if step > 0 and stop > start:
diff = stop - start
nextsize = diff // step
if diff % step != 0:
nextsize += 1
elif step < 0 and stop < start:
diff = start - stop
nextsize = diff // (step * -1)
if diff % step != 0:
nextsize += 1
else:
nextsize = 0

nextcarry = ak.index.Index64.empty(self._length * nextsize, index_nplike)
assert nextcarry.nplike is index_nplike
self._handle_error(
Expand Down

0 comments on commit 25955f4

Please sign in to comment.