Skip to content

Commit

Permalink
Create well-documented helper for strictly creating ragged ndarrays.
Browse files Browse the repository at this point in the history
  • Loading branch information
clarkzinzow committed Nov 22, 2022
1 parent 95af3b3 commit 7410a0c
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 13 deletions.
9 changes: 5 additions & 4 deletions python/ray/air/util/tensor_extensions/arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
import numpy as np
import pyarrow as pa

from ray.air.util.tensor_extensions.utils import _is_ndarray_variable_shaped_tensor
from ray.air.util.tensor_extensions.utils import (
_is_ndarray_variable_shaped_tensor,
_create_strict_ragged_ndarray,
)
from ray._private.utils import _get_pyarrow_version
from ray.util.annotations import PublicAPI

Expand Down Expand Up @@ -721,9 +724,7 @@ def _to_numpy(self, index: Optional[int] = None, zero_copy_only: bool = False):
arrs = [self._to_numpy(i, zero_copy_only) for i in range(len(self))]
# Return ragged NumPy ndarray in the ndarray of ndarray pointers
# representation.
arr = np.empty(len(self), dtype=object)
arr[:] = arrs
return arr
return _create_strict_ragged_ndarray(arrs)
data = self.storage.field("data")
shapes = self.storage.field("shape")
value_type = data.type.value_type
Expand Down
16 changes: 7 additions & 9 deletions python/ray/air/util/tensor_extensions/pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@
from pandas.core.indexers import check_array_indexer, validate_indices
from pandas.io.formats.format import ExtensionArrayFormatter

from ray.air.util.tensor_extensions.utils import _is_ndarray_variable_shaped_tensor
from ray.air.util.tensor_extensions.utils import (
_is_ndarray_variable_shaped_tensor,
_create_strict_ragged_ndarray,
)
from ray.util.annotations import PublicAPI

try:
Expand Down Expand Up @@ -1422,9 +1425,7 @@ def _is_boolean(self):


def _create_possibly_ragged_ndarray(
values: Union[
np.ndarray, ABCSeries, Sequence[Union[np.ndarray, TensorArrayElement]]
]
values: Union[np.ndarray, ABCSeries, Sequence[np.ndarray]]
) -> np.ndarray:
"""
Create a possibly ragged ndarray.
Expand All @@ -1438,11 +1439,8 @@ def _create_possibly_ragged_ndarray(
return np.array(values, copy=False)
except ValueError as e:
if "could not broadcast input array from shape" in str(e):
# Create an empty object-dtyped 1D array.
arr = np.empty(len(values), dtype=object)
# Try to fill the 1D array of pointers with the (ragged) tensors.
arr[:] = list(values)
return arr
# Fall back to strictly creating a ragged ndarray.
return _create_strict_ragged_ndarray(values)
else:
# Re-raise original error if the failure wasn't a broadcast error.
raise e from None
Expand Down
22 changes: 22 additions & 0 deletions python/ray/air/util/tensor_extensions/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any

import numpy as np


Expand All @@ -20,3 +22,23 @@ def _is_ndarray_variable_shaped_tensor(arr: np.ndarray) -> bool:
if a.shape != shape:
return True
return True


def _create_strict_ragged_ndarray(values: Any) -> np.ndarray:
"""Create a ragged ndarray; the representation will be ragged (1D array of
subndarray pointers) even if it's possible to represent it as a non-ragged ndarray.
"""
# Use the create-empty-and-fill method. This avoids the following pitfalls of the
# np.array constructor - np.array(values, dtype=object):
# 1. It will fail to construct an ndarray if the first element dimension is
# uniform, e.g. for imagery whose first element dimension is the channel.
# 2. It will construct the wrong representation for a single-row column (i.e. unit
# outer dimension). Namely, it will consolidate it into a single multi-dimensional
# ndarray rather than a 1D array of subndarray pointers, resulting in the single
# row not being well-typed (having object dtype).

# Create an empty object-dtyped 1D array.
arr = np.empty(len(values), dtype=object)
# Try to fill the 1D array of pointers with the (ragged) tensors.
arr[:] = list(values)
return arr

0 comments on commit 7410a0c

Please sign in to comment.