From d99e402be5309af43c11066af8981935083ab823 Mon Sep 17 00:00:00 2001
From: jay
Date: Thu, 2 Jan 2025 19:15:21 +0800
Subject: [PATCH] [data] fix groupby hang when value contains np.nan (#49420)

Signed-off-by: Puyuan Yao
---
 python/ray/data/_internal/arrow_block.py  |  6 +++---
 python/ray/data/_internal/pandas_block.py |  6 +++---
 python/ray/data/_internal/util.py         | 16 ++++++++++++++++
 python/ray/data/tests/test_all_to_all.py  | 18 ++++++++++++++++++
 python/ray/data/tests/test_arrow_block.py | 16 ++++++++++++++++
 5 files changed, 56 insertions(+), 6 deletions(-)

diff --git a/python/ray/data/_internal/arrow_block.py b/python/ray/data/_internal/arrow_block.py
index 4b9f055c83f24..220efa65e6a67 100644
--- a/python/ray/data/_internal/arrow_block.py
+++ b/python/ray/data/_internal/arrow_block.py
@@ -28,7 +28,7 @@
 from ray.data._internal.numpy_support import convert_to_numpy
 from ray.data._internal.row import TableRow
 from ray.data._internal.table_block import TableBlockAccessor, TableBlockBuilder
-from ray.data._internal.util import NULL_SENTINEL, find_partitions
+from ray.data._internal.util import NULL_SENTINEL, find_partitions, keys_equal
 from ray.data.block import (
     Block,
     BlockAccessor,
@@ -472,7 +472,7 @@ def iter_groups() -> Iterator[Tuple[Sequence[KeyType], Block]]:
                     if next_row is None:
                         next_row = next(iter)
                     next_keys = next_row[keys]
-                    while next_row[keys] == next_keys:
+                    while keys_equal(next_row[keys], next_keys):
                         end += 1
                         try:
                             next_row = next(iter)
@@ -592,7 +592,7 @@ def key_fn_with_null_sentinel(r):
                 def gen():
                     nonlocal iter
                     nonlocal next_row
-                    while key_fn(next_row) == next_keys:
+                    while keys_equal(key_fn(next_row), next_keys):
                         yield next_row
                         try:
                             next_row = next(iter)
diff --git a/python/ray/data/_internal/pandas_block.py b/python/ray/data/_internal/pandas_block.py
index 73e0788fbb4b0..54dd12f29f08b 100644
--- a/python/ray/data/_internal/pandas_block.py
+++ b/python/ray/data/_internal/pandas_block.py
@@ -23,7 +23,7 @@
 from ray.data._internal.numpy_support import convert_to_numpy, validate_numpy_batch
 from ray.data._internal.row import TableRow
 from ray.data._internal.table_block import TableBlockAccessor, TableBlockBuilder
-from ray.data._internal.util import find_partitions
+from ray.data._internal.util import find_partitions, keys_equal
 from ray.data.block import (
     Block,
     BlockAccessor,
@@ -553,7 +553,7 @@ def iter_groups() -> Iterator[Tuple[Sequence[KeyType], Block]]:
                     if next_row is None:
                         next_row = next(iter)
                     next_keys = next_row[keys]
-                    while np.all(next_row[keys] == next_keys):
+                    while keys_equal(next_row[keys], next_keys):
                         end += 1
                         try:
                             next_row = next(iter)
@@ -671,7 +671,7 @@ def key_fn(r):
                 def gen():
                     nonlocal iter
                     nonlocal next_row
-                    while key_fn(next_row) == next_keys:
+                    while keys_equal(key_fn(next_row), next_keys):
                         yield next_row
                         try:
                             next_row = next(iter)
diff --git a/python/ray/data/_internal/util.py b/python/ray/data/_internal/util.py
index 0c4969b48e8cd..66fab790d2706 100644
--- a/python/ray/data/_internal/util.py
+++ b/python/ray/data/_internal/util.py
@@ -1102,3 +1102,19 @@ def convert_bytes_to_human_readable_str(num_bytes: int) -> str:
     else:
         num_bytes_str = f"{round(num_bytes / 1e3)}KB"
     return num_bytes_str
+
+
+def is_nan(value):
+    try:
+        return isinstance(value, float) and np.isnan(value)
+    except TypeError:
+        return False
+
+
+def keys_equal(keys1, keys2):
+    if len(keys1) != len(keys2):
+        return False
+    for k1, k2 in zip(keys1, keys2):
+        if not ((is_nan(k1) and is_nan(k2)) or k1 == k2):
+            return False
+    return True
diff --git a/python/ray/data/tests/test_all_to_all.py b/python/ray/data/tests/test_all_to_all.py
index a6b1733831451..09b668fa15c63 100644
--- a/python/ray/data/tests/test_all_to_all.py
+++ b/python/ray/data/tests/test_all_to_all.py
@@ -357,6 +357,24 @@ def test_groupby_agg_name_conflict(ray_start_regular_shared, num_parts):
     ]
 
 
+@pytest.mark.parametrize("ds_format", ["pyarrow", "numpy", "pandas"])
+def test_groupby_nans(ray_start_regular_shared, ds_format):
+    ds = ray.data.from_items(
+        [
+            1.0,
+            1.0,
+            2.0,
+            np.nan,
+            np.nan,
+        ]
+    )
+    ds = ds.map_batches(lambda x: x, batch_format=ds_format)
+    ds = ds.groupby("item").count()
+    ds = ds.filter(lambda v: np.isnan(v["item"]))
+    result = ds.take_all()
+    assert result[0]["count()"] == 2
+
+
 @pytest.mark.parametrize("num_parts", [1, 30])
 @pytest.mark.parametrize("ds_format", ["arrow", "pandas"])
 def test_groupby_tabular_count(
diff --git a/python/ray/data/tests/test_arrow_block.py b/python/ray/data/tests/test_arrow_block.py
index f7b54406f257e..ec6342c7e2b4c 100644
--- a/python/ray/data/tests/test_arrow_block.py
+++ b/python/ray/data/tests/test_arrow_block.py
@@ -355,5 +355,21 @@ def test_build_block_with_null_column(ray_start_regular_shared):
     assert np.array_equal(rows[1]["array"], np.zeros((2, 2)))
 
 
+def test_arrow_nan_element():
+    ds = ray.data.from_items(
+        [
+            1.0,
+            1.0,
+            2.0,
+            np.nan,
+            np.nan,
+        ]
+    )
+    ds = ds.groupby("item").count()
+    ds = ds.filter(lambda v: np.isnan(v["item"]))
+    result = ds.take_all()
+    assert result[0]["count()"] == 2
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main(["-v", __file__]))