Skip to content

Commit

Permalink
[data] fix groupby hang when value contains np.nan (ray-project#49420)
Browse files Browse the repository at this point in the history
Signed-off-by: Puyuan Yao <williamyao034@gmail.com>
  • Loading branch information
Jay-ju authored and anyadontfly committed Feb 13, 2025
1 parent c316798 commit d99e402
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 6 deletions.
6 changes: 3 additions & 3 deletions python/ray/data/_internal/arrow_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from ray.data._internal.numpy_support import convert_to_numpy
from ray.data._internal.row import TableRow
from ray.data._internal.table_block import TableBlockAccessor, TableBlockBuilder
from ray.data._internal.util import NULL_SENTINEL, find_partitions
from ray.data._internal.util import NULL_SENTINEL, find_partitions, keys_equal
from ray.data.block import (
Block,
BlockAccessor,
Expand Down Expand Up @@ -472,7 +472,7 @@ def iter_groups() -> Iterator[Tuple[Sequence[KeyType], Block]]:
if next_row is None:
next_row = next(iter)
next_keys = next_row[keys]
while next_row[keys] == next_keys:
while keys_equal(next_row[keys], next_keys):
end += 1
try:
next_row = next(iter)
Expand Down Expand Up @@ -592,7 +592,7 @@ def key_fn_with_null_sentinel(r):
def gen():
nonlocal iter
nonlocal next_row
while key_fn(next_row) == next_keys:
while keys_equal(key_fn(next_row), next_keys):
yield next_row
try:
next_row = next(iter)
Expand Down
6 changes: 3 additions & 3 deletions python/ray/data/_internal/pandas_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from ray.data._internal.numpy_support import convert_to_numpy, validate_numpy_batch
from ray.data._internal.row import TableRow
from ray.data._internal.table_block import TableBlockAccessor, TableBlockBuilder
from ray.data._internal.util import find_partitions
from ray.data._internal.util import find_partitions, keys_equal
from ray.data.block import (
Block,
BlockAccessor,
Expand Down Expand Up @@ -553,7 +553,7 @@ def iter_groups() -> Iterator[Tuple[Sequence[KeyType], Block]]:
if next_row is None:
next_row = next(iter)
next_keys = next_row[keys]
while np.all(next_row[keys] == next_keys):
while keys_equal(next_row[keys], next_keys):
end += 1
try:
next_row = next(iter)
Expand Down Expand Up @@ -671,7 +671,7 @@ def key_fn(r):
def gen():
nonlocal iter
nonlocal next_row
while key_fn(next_row) == next_keys:
while keys_equal(key_fn(next_row), next_keys):
yield next_row
try:
next_row = next(iter)
Expand Down
16 changes: 16 additions & 0 deletions python/ray/data/_internal/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1102,3 +1102,19 @@ def convert_bytes_to_human_readable_str(num_bytes: int) -> str:
else:
num_bytes_str = f"{round(num_bytes / 1e3)}KB"
return num_bytes_str


def is_nan(value):
    """Return True if ``value`` is a floating-point NaN scalar.

    Both builtin ``float`` values and NumPy floating scalars (for example
    ``np.float16``, ``np.float32``, ``np.float64``) are recognized. Any other
    kind of value (strings, ``None``, integers, arrays, ...) is never
    considered NaN.

    Args:
        value: The object to inspect.

    Returns:
        A plain ``bool``: True only when ``value`` is a float-like scalar
        holding NaN.
    """
    # np.float32 / np.float16 do not subclass the builtin ``float``, so a bare
    # ``isinstance(value, float)`` check would miss NaNs of those dtypes.
    if not isinstance(value, (float, np.floating)):
        return False
    # ``np.isnan`` returns ``np.bool_``; normalize to a builtin bool.
    return bool(np.isnan(value))


def keys_equal(keys1, keys2):
    """Compare two key sequences element-wise, treating NaN as equal to NaN.

    Args:
        keys1: First sequence of key values.
        keys2: Second sequence of key values.

    Returns:
        False if the sequences differ in length or any aligned pair of
        elements is neither equal under ``==`` nor both NaN (per ``is_nan``);
        True otherwise.
    """
    if len(keys1) != len(keys2):
        return False
    # NaN != NaN under ``==``, so a pair of NaNs is accepted explicitly before
    # falling back to ordinary equality.
    return all(
        (is_nan(left) and is_nan(right)) or left == right
        for left, right in zip(keys1, keys2)
    )
18 changes: 18 additions & 0 deletions python/ray/data/tests/test_all_to_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,24 @@ def test_groupby_agg_name_conflict(ray_start_regular_shared, num_parts):
]


@pytest.mark.parametrize("ds_format", ["pyarrow", "numpy", "pandas"])
def test_groupby_nans(ray_start_regular_shared, ds_format):
    """Group a float column containing NaNs and check the NaN group's count.

    The dataset holds two NaN items; after ``groupby("item").count()`` the
    first row whose key is NaN must report a count of 2, for every batch
    format under test.
    """
    items = [1.0, 1.0, 2.0, np.nan, np.nan]
    grouped_counts = (
        ray.data.from_items(items)
        # Identity map_batches: only there to route data through ``ds_format``.
        .map_batches(lambda x: x, batch_format=ds_format)
        .groupby("item")
        .count()
    )
    nan_rows = grouped_counts.filter(lambda v: np.isnan(v["item"])).take_all()
    assert nan_rows[0]["count()"] == 2


@pytest.mark.parametrize("num_parts", [1, 30])
@pytest.mark.parametrize("ds_format", ["arrow", "pandas"])
def test_groupby_tabular_count(
Expand Down
16 changes: 16 additions & 0 deletions python/ray/data/tests/test_arrow_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,5 +355,21 @@ def test_build_block_with_null_column(ray_start_regular_shared):
assert np.array_equal(rows[1]["array"], np.zeros((2, 2)))


def test_arrow_nan_element():
    """Group items containing NaNs and check the NaN group's count.

    Two of the five items are NaN; after ``groupby("item").count()`` the first
    row whose key is NaN must report a count of 2.
    """
    items = [1.0, 1.0, 2.0, np.nan, np.nan]
    grouped_counts = ray.data.from_items(items).groupby("item").count()
    # Keep only the group(s) keyed by NaN before materializing the result.
    nan_rows = grouped_counts.filter(lambda v: np.isnan(v["item"])).take_all()
    assert nan_rows[0]["count()"] == 2


if __name__ == "__main__":
    # Run this file's tests verbosely and use pytest's result as the process
    # exit status.
    exit_status = pytest.main(["-v", __file__])
    sys.exit(exit_status)

0 comments on commit d99e402

Please sign in to comment.