diff --git a/python/ray/air/util/tensor_extensions/arrow.py b/python/ray/air/util/tensor_extensions/arrow.py index 2ee8179a626ab..91b5daecbefe9 100644 --- a/python/ray/air/util/tensor_extensions/arrow.py +++ b/python/ray/air/util/tensor_extensions/arrow.py @@ -5,7 +5,10 @@ import numpy as np import pyarrow as pa -from ray.air.util.tensor_extensions.utils import _is_ndarray_variable_shaped_tensor +from ray.air.util.tensor_extensions.utils import ( + _is_ndarray_variable_shaped_tensor, + _create_strict_ragged_ndarray, +) from ray._private.utils import _get_pyarrow_version from ray.util.annotations import PublicAPI @@ -721,9 +724,7 @@ def _to_numpy(self, index: Optional[int] = None, zero_copy_only: bool = False): arrs = [self._to_numpy(i, zero_copy_only) for i in range(len(self))] # Return ragged NumPy ndarray in the ndarray of ndarray pointers # representation. - arr = np.empty(len(self), dtype=object) - arr[:] = arrs - return arr + return _create_strict_ragged_ndarray(arrs) data = self.storage.field("data") shapes = self.storage.field("shape") value_type = data.type.value_type diff --git a/python/ray/air/util/tensor_extensions/pandas.py b/python/ray/air/util/tensor_extensions/pandas.py index e23cce7b91987..4201be97201f9 100644 --- a/python/ray/air/util/tensor_extensions/pandas.py +++ b/python/ray/air/util/tensor_extensions/pandas.py @@ -44,7 +44,10 @@ from pandas.core.indexers import check_array_indexer, validate_indices from pandas.io.formats.format import ExtensionArrayFormatter -from ray.air.util.tensor_extensions.utils import _is_ndarray_variable_shaped_tensor +from ray.air.util.tensor_extensions.utils import ( + _is_ndarray_variable_shaped_tensor, + _create_strict_ragged_ndarray, +) from ray.util.annotations import PublicAPI try: @@ -1422,9 +1425,7 @@ def _is_boolean(self): def _create_possibly_ragged_ndarray( - values: Union[ - np.ndarray, ABCSeries, Sequence[Union[np.ndarray, TensorArrayElement]] - ] + values: Union[np.ndarray, ABCSeries, Sequence[np.ndarray]] ) -> np.ndarray: """ Create a possibly ragged ndarray. @@ -1438,11 +1439,8 @@ def _create_possibly_ragged_ndarray( return np.array(values, copy=False) except ValueError as e: if "could not broadcast input array from shape" in str(e): - # Create an empty object-dtyped 1D array. - arr = np.empty(len(values), dtype=object) - # Try to fill the 1D array of pointers with the (ragged) tensors. - arr[:] = list(values) - return arr + # Fall back to strictly creating a ragged ndarray. + return _create_strict_ragged_ndarray(values) else: # Re-raise original error if the failure wasn't a broadcast error. raise e from None diff --git a/python/ray/air/util/tensor_extensions/utils.py b/python/ray/air/util/tensor_extensions/utils.py index 3b7e60a579fbb..f28928b54c2ce 100644 --- a/python/ray/air/util/tensor_extensions/utils.py +++ b/python/ray/air/util/tensor_extensions/utils.py @@ -1,3 +1,5 @@ +from typing import Any + import numpy as np @@ -20,3 +22,23 @@ def _is_ndarray_variable_shaped_tensor(arr: np.ndarray) -> bool: if a.shape != shape: return True return True + + +def _create_strict_ragged_ndarray(values: Any) -> np.ndarray: + """Create a ragged ndarray; the representation will be ragged (1D array of + subndarray pointers) even if it's possible to represent it as a non-ragged ndarray. + """ + # Use the create-empty-and-fill method. This avoids the following pitfalls of the + # np.array constructor - np.array(values, dtype=object): + # 1. It will fail to construct an ndarray if the first element dimension is + # uniform, e.g. for imagery whose first element dimension is the channel. + # 2. It will construct the wrong representation for a single-row column (i.e. unit + # outer dimension). Namely, it will consolidate it into a single multi-dimensional + # ndarray rather than a 1D array of subndarray pointers, resulting in the single + # row not being well-typed (having object dtype). + + # Create an empty object-dtyped 1D array. + arr = np.empty(len(values), dtype=object) + # Try to fill the 1D array of pointers with the (ragged) tensors. + arr[:] = list(values) + return arr