Skip to content

Commit

Permalink
[data] [docs] Generalize fix for converting lists to np.ndarray in UD…
Browse files Browse the repository at this point in the history
…Fs (#34930)
  • Loading branch information
ericl authored May 2, 2023
1 parent 5641891 commit fef04ac
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 19 deletions.
16 changes: 14 additions & 2 deletions python/ray/data/_internal/arrow_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def from_bytes(cls, data: bytes) -> "ArrowBlockAccessor":

@staticmethod
def numpy_to_block(
batch: Union[np.ndarray, Dict[str, np.ndarray]],
batch: Union[np.ndarray, Dict[str, np.ndarray], Dict[str, list]],
passthrough_arrow_not_implemented_errors: bool = False,
) -> "pyarrow.Table":
import pyarrow as pa
Expand All @@ -163,7 +163,7 @@ def numpy_to_block(
if isinstance(batch, np.ndarray):
batch = {TENSOR_COLUMN_NAME: batch}
elif not isinstance(batch, collections.abc.Mapping) or any(
not isinstance(col, np.ndarray) for col in batch.values()
not isinstance(col, (list, np.ndarray)) for col in batch.values()
):
raise ValueError(
"Batch must be an ndarray or dictionary of ndarrays when converting "
Expand All @@ -172,6 +172,18 @@ def numpy_to_block(
)
new_batch = {}
for col_name, col in batch.items():
if isinstance(col, list):
# Try to convert list values into an numpy array via
# np.array(), so users don't need to manually cast.
# NOTE: we don't cast generic iterables, since types like
# `str` are also Iterable.
try:
col = np.array(col)
except Exception:
raise ValueError(
"Failed to convert column values to numpy array: "
f"({_truncated_repr(col)})."
)
# Use Arrow's native *List types for 1-dimensional ndarrays.
if col.dtype.type is np.object_ or col.ndim > 1:
try:
Expand Down
12 changes: 0 additions & 12 deletions python/ray/data/_internal/planner/map_batches.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,18 +62,6 @@ def validate_batch(batch: Block) -> None:
f"{type(value)}. To fix this issue, convert "
f"the {type(value)} to a `np.ndarray`."
)
if isinstance(value, list):
# Try to convert list values into an numpy array via
# np.array(), so users don't need to manually cast.
# NOTE: we don't cast generic iterables, since types like
# `str` are also Iterable.
try:
batch[key] = np.array(value)
except Exception:
raise ValueError(
"Failed to convert column values to numpy array: "
f"({_truncated_repr(value)})."
)

def process_next_batch(batch: DataBatch) -> Iterator[Block]:
# Apply UDF.
Expand Down
29 changes: 24 additions & 5 deletions python/ray/data/tests/test_strict_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ def test_strict_map_output(ray_start_regular_shared, enable_strict_mode):

with pytest.raises(StrictModeError):
ds.map_batches(lambda x: np.array([0]), max_retries=0).materialize()
ds.map_batches(lambda x: {"id": np.array([0])}).materialize()
ds.map_batches(lambda x: UserDict({"id": np.array([0])})).materialize()
ds.map_batches(lambda x: {"id": [0]}).materialize()
ds.map_batches(lambda x: UserDict({"id": [0]})).materialize()

with pytest.raises(StrictModeError):
ds.map(lambda x: np.ones(10), max_retries=0).materialize()
Expand All @@ -71,8 +71,8 @@ def test_strict_map_output(ray_start_regular_shared, enable_strict_mode):
ds.map_batches(lambda x: object(), max_retries=0).materialize()
with pytest.raises(ValueError):
ds.map_batches(lambda x: {"x": object()}, max_retries=0).materialize()
ds.map_batches(lambda x: {"x": np.array([object()])}).materialize()
ds.map_batches(lambda x: UserDict({"x": np.array([object()])})).materialize()
ds.map_batches(lambda x: {"x": [object()]}).materialize()
ds.map_batches(lambda x: UserDict({"x": [object()]})).materialize()

with pytest.raises(StrictModeError):
ds.map(lambda x: object(), max_retries=0).materialize()
Expand All @@ -86,7 +86,9 @@ def test_strict_convert_map_output(ray_start_regular_shared, enable_strict_mode)

with pytest.raises(ValueError):
# Strings not converted into array.
ray.data.range(1).map_batches(lambda x: {"id": "string"}).materialize()
ray.data.range(1).map_batches(
lambda x: {"id": "string"}, max_retries=0
).materialize()

class UserObj:
def __eq__(self, other):
Expand All @@ -100,6 +102,23 @@ def __eq__(self, other):
assert ds.take_batch()["id"].tolist() == [0, 1, 2, UserObj()]


def test_strict_convert_map_groups(ray_start_regular_shared, enable_strict_mode):
ds = ray.data.read_csv("example://iris.csv")

def process_group(group):
variety = group["variety"][0]
count = len(group["variety"])

# Test implicit list->array conversion here.
return {
"variety": [variety],
"count": [count],
}

ds = ds.groupby("variety").map_groups(process_group)
ds.show()


def test_strict_default_batch_format(ray_start_regular_shared, enable_strict_mode):
ds = ray.data.range(1)

Expand Down

0 comments on commit fef04ac

Please sign in to comment.