From 1c7bbe4dc1b06525aac809929c80cff96b691071 Mon Sep 17 00:00:00 2001 From: Balaji Veeramani Date: Fri, 4 Oct 2024 13:41:43 -0700 Subject: [PATCH 1/5] Initial commit Signed-off-by: Balaji Veeramani --- python/ray/data/dataset.py | 20 ++++++++++++++++++++ python/ray/data/tests/test_map.py | 8 ++++++++ 2 files changed, 28 insertions(+) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 4a93034fa058..8d96a71b339a 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -862,6 +862,26 @@ def select_columns(batch): **ray_remote_args, ) + def rename_columns( + self, + mapper: Dict[str, str], + *, + compute: Union[str, ComputeStrategy] = None, + concurrency: Optional[Union[int, Tuple[int, int]]] = None, + **ray_remote_args, + ): + def rename_columns(batch: "pyarrow.Table") -> "pyarrow.Table": + return batch.rename_columns(mapper) + + return self.map_batches( + rename_columns, + batch_format="pyarrow", + zero_copy_batch=True, + compute=compute, + concurrency=concurrency, + **ray_remote_args, + ) + @PublicAPI(api_group=BT_API_GROUP) def flat_map( self, diff --git a/python/ray/data/tests/test_map.py b/python/ray/data/tests/test_map.py index 675a0a3dd417..07d5c956e254 100644 --- a/python/ray/data/tests/test_map.py +++ b/python/ray/data/tests/test_map.py @@ -343,6 +343,14 @@ def test_add_column(ray_start_regular_shared): ds = ray.data.range(5).add_column("id", 0) +def test_rename_columns(ray_start_regular_shared): + ds = ray.data.from_items([{"spam": 0, "ham": 0}]) + + renamed_ds = ds.rename_columns({"spam": "foo", "ham": "bar"}) + + assert renamed_ds.schema().names == ["foo", "bar"] + + def test_drop_columns(ray_start_regular_shared, tmp_path): df = pd.DataFrame({"col1": [1, 2, 3], "col2": [2, 3, 4], "col3": [3, 4, 5]}) ds1 = ray.data.from_pandas(df) From 4637c0c0d3cf74152e5e958b6bb4bc40daf2cada Mon Sep 17 00:00:00 2001 From: Balaji Veeramani Date: Fri, 4 Oct 2024 13:53:22 -0700 Subject: [PATCH 2/5] Add docstring Signed-off-by: Balaji Veeramani --- python/ray/data/dataset.py | 37 +++++++++++++++++++++++++++++++++++-- 1 file changed, 35 insertions(+), 2 deletions(-) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 8d96a71b339a..7d6de9f98614 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -866,10 +866,44 @@ def rename_columns( self, mapper: Dict[str, str], *, - compute: Union[str, ComputeStrategy] = None, concurrency: Optional[Union[int, Tuple[int, int]]] = None, **ray_remote_args, ): + """Rename columns in the dataset. + + Examples: + + >>> import ray + >>> ds = ray.data.read_parquet("s3://anonymous@ray-example-data/iris.parquet") + >>> ds.schema() + Column Type + ------ ---- + sepal.length double + sepal.width double + petal.length double + petal.width doubles + variety string + >>> ds.rename_columns({"variety": "category"}).schema() + Column Type + ------ ---- + sepal.length double + sepal.width double + petal.length double + petal.width double + category string + + Args: + mapper: A dictionary that maps old column names to new column names. + concurrency: The maximum number of Ray workers to use concurrently. + ray_remote_args: Additional resource requirements to request from + ray (e.g., num_gpus=1 to request GPUs for the map tasks). + """ # noqa: E501 + if concurrency is not None and not isinstance(concurrency, int): + raise ValueError( + "Expected `concurrency` to be an integer or `None`, but got " + f"{concurrency}." + ) + def rename_columns(batch: "pyarrow.Table") -> "pyarrow.Table": return batch.rename_columns(mapper) @@ -877,7 +911,6 @@ def rename_columns(batch: "pyarrow.Table") -> "pyarrow.Table": rename_columns, batch_format="pyarrow", zero_copy_batch=True, - compute=compute, concurrency=concurrency, **ray_remote_args, ) From 2b94a682c121ca84774c3661cff6ff325f01ef24 Mon Sep 17 00:00:00 2001 From: Balaji Veeramani Date: Fri, 4 Oct 2024 13:55:08 -0700 Subject: [PATCH 3/5] Add decorator Signed-off-by: Balaji Veeramani --- python/ray/data/dataset.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 7d6de9f98614..1134f07c2f48 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -862,6 +862,7 @@ def select_columns(batch): **ray_remote_args, ) + @PublicAPI(api_group=BT_API_GROUP) def rename_columns( self, mapper: Dict[str, str], From 829a222e8e74ecb6d766e18f14bf6336670e3578 Mon Sep 17 00:00:00 2001 From: Balaji Veeramani Date: Wed, 9 Oct 2024 09:09:23 -0700 Subject: [PATCH 4/5] Address review comments Signed-off-by: Balaji Veeramani --- python/ray/data/dataset.py | 34 +++++++++++++++++++++++++++---- python/ray/data/tests/test_map.py | 7 ++++--- 2 files changed, 34 insertions(+), 7 deletions(-) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 1134f07c2f48..94c96a6a9721 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -865,7 +865,7 @@ def select_columns(batch): @PublicAPI(api_group=BT_API_GROUP) def rename_columns( self, - mapper: Dict[str, str], + names: Union[List[str], Dict[str, str]], *, concurrency: Optional[Union[int, Tuple[int, int]]] = None, **ray_remote_args, @@ -882,8 +882,11 @@ def rename_columns( sepal.length double sepal.width double petal.length double - petal.width doubles + petal.width double variety string + + You can pass a dictionary mapping old column names to new column names. + >>> ds.rename_columns({"variety": "category"}).schema() Column Type ------ ---- @@ -893,8 +896,22 @@ def rename_columns( petal.width double category string + Or you can pass a list of new column names. + + >>> ds.rename_columns( + ... ["sepal_length", "sepal_width", "petal_length", "petal_width", "variety"] + ... ).schema() + Column Type + ------ ---- + sepal_length double + sepal_width double + petal_length double + petal_width double + variety string + Args: - mapper: A dictionary that maps old column names to new column names. + mapper: A dictionary that maps old column names to new column names, or a + list of new column names. concurrency: The maximum number of Ray workers to use concurrently. ray_remote_args: Additional resource requirements to request from ray (e.g., num_gpus=1 to request GPUs for the map tasks). @@ -906,7 +923,16 @@ def rename_columns( ) def rename_columns(batch: "pyarrow.Table") -> "pyarrow.Table": - return batch.rename_columns(mapper) + # Versions of PyArrow before 17 don't support renaming columns with a dict. + if isinstance(names, dict): + column_names_list = batch.column_names + for i, column_name in enumerate(column_names_list): + if column_name in names: + column_names_list[i] = names[column_name] + else: + column_names_list = names + + return batch.rename_columns(column_names_list) return self.map_batches( rename_columns, diff --git a/python/ray/data/tests/test_map.py b/python/ray/data/tests/test_map.py index 07d5c956e254..3b3fdb384a43 100644 --- a/python/ray/data/tests/test_map.py +++ b/python/ray/data/tests/test_map.py @@ -343,15 +343,16 @@ def test_add_column(ray_start_regular_shared): ds = ray.data.range(5).add_column("id", 0) -def test_rename_columns(ray_start_regular_shared): +@pytest.mark.parametrize("names", (["foo", "bar"], {"spam": "foo", "ham": "bar"})) +def test_rename_columns(ray_start_regular_shared, names): ds = ray.data.from_items([{"spam": 0, "ham": 0}]) - renamed_ds = ds.rename_columns({"spam": "foo", "ham": "bar"}) + renamed_ds = ds.rename_columns(names) assert renamed_ds.schema().names == ["foo", "bar"] -def test_drop_columns(ray_start_regular_shared, tmp_path): +def test_drop_columns(ray_start_regular_shared, tmp_wpath): df = pd.DataFrame({"col1": [1, 2, 3], "col2": [2, 3, 4], "col3": [3, 4, 5]}) ds1 = ray.data.from_pandas(df) ds1.write_parquet(str(tmp_path)) From 27f09125aaa51bbe20c37af69eee30696e7d8183 Mon Sep 17 00:00:00 2001 From: Balaji Veeramani Date: Wed, 9 Oct 2024 09:09:52 -0700 Subject: [PATCH 5/5] Fix typo Signed-off-by: Balaji Veeramani --- python/ray/data/tests/test_map.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/data/tests/test_map.py b/python/ray/data/tests/test_map.py index 3b3fdb384a43..4f058a9152f5 100644 --- a/python/ray/data/tests/test_map.py +++ b/python/ray/data/tests/test_map.py @@ -352,7 +352,7 @@ def test_rename_columns(ray_start_regular_shared, names): assert renamed_ds.schema().names == ["foo", "bar"] -def test_drop_columns(ray_start_regular_shared, tmp_wpath): +def test_drop_columns(ray_start_regular_shared, tmp_path): df = pd.DataFrame({"col1": [1, 2, 3], "col2": [2, 3, 4], "col3": [3, 4, 5]}) ds1 = ray.data.from_pandas(df) ds1.write_parquet(str(tmp_path))