[Data] Add Dataset.rename_columns (#47906)
Fixes #32261

---------

Signed-off-by: Balaji Veeramani <bveeramani@berkeley.edu>
bveeramani authored Oct 9, 2024
1 parent b69b929 commit c9bb68b
Showing 2 changed files with 89 additions and 0 deletions.
80 changes: 80 additions & 0 deletions python/ray/data/dataset.py
@@ -862,6 +862,86 @@ def select_columns(batch):
            **ray_remote_args,
        )

    @PublicAPI(api_group=BT_API_GROUP)
    def rename_columns(
        self,
        names: Union[List[str], Dict[str, str]],
        *,
        concurrency: Optional[Union[int, Tuple[int, int]]] = None,
        **ray_remote_args,
    ):
        """Rename columns in the dataset.

        Examples:

            >>> import ray
            >>> ds = ray.data.read_parquet("s3://anonymous@ray-example-data/iris.parquet")
            >>> ds.schema()
            Column        Type
            ------        ----
            sepal.length  double
            sepal.width   double
            petal.length  double
            petal.width   double
            variety       string

            You can pass a dictionary mapping old column names to new column names.

            >>> ds.rename_columns({"variety": "category"}).schema()
            Column        Type
            ------        ----
            sepal.length  double
            sepal.width   double
            petal.length  double
            petal.width   double
            category      string

            Or you can pass a list of new column names.

            >>> ds.rename_columns(
            ...     ["sepal_length", "sepal_width", "petal_length", "petal_width", "variety"]
            ... ).schema()
            Column        Type
            ------        ----
            sepal_length  double
            sepal_width   double
            petal_length  double
            petal_width   double
            variety       string

        Args:
            names: A dictionary that maps old column names to new column names, or a
                list of new column names.
            concurrency: The maximum number of Ray workers to use concurrently.
            ray_remote_args: Additional resource requirements to request from
                ray (e.g., num_gpus=1 to request GPUs for the map tasks).
        """  # noqa: E501
        if concurrency is not None and not isinstance(concurrency, int):
            raise ValueError(
                "Expected `concurrency` to be an integer or `None`, but got "
                f"{concurrency}."
            )

        def rename_columns(batch: "pyarrow.Table") -> "pyarrow.Table":
            # Versions of PyArrow before 17 don't support renaming columns with a
            # dict, so expand the mapping into a full list of column names.
            if isinstance(names, dict):
                column_names_list = batch.column_names
                for i, column_name in enumerate(column_names_list):
                    if column_name in names:
                        column_names_list[i] = names[column_name]
            else:
                column_names_list = names

            return batch.rename_columns(column_names_list)

        return self.map_batches(
            rename_columns,
            batch_format="pyarrow",
            zero_copy_batch=True,
            concurrency=concurrency,
            **ray_remote_args,
        )

@PublicAPI(api_group=BT_API_GROUP)
def flat_map(
self,
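For context, a minimal standalone sketch of the renaming strategy the helper above relies on: pyarrow.Table.rename_columns accepts a full list of new names on both old and new PyArrow releases, so a dict mapping is first expanded into one name per existing column. The table contents, column names, and mapping below are illustrative only and not part of the commit.

# Hypothetical example data; only pyarrow.Table.column_names and
# pyarrow.Table.rename_columns (list form) are assumed, both standard PyArrow APIs.
import pyarrow as pa

table = pa.table({"spam": [1, 2, 3], "ham": [4.0, 5.0, 6.0]})
mapping = {"spam": "foo"}  # old name -> new name; unmapped columns keep their names

# Expand the mapping into one new name per existing column, in order.
new_names = [mapping.get(name, name) for name in table.column_names]
renamed = table.rename_columns(new_names)

print(renamed.column_names)  # ['foo', 'ham']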
9 changes: 9 additions & 0 deletions python/ray/data/tests/test_map.py
@@ -343,6 +343,15 @@ def test_add_column(ray_start_regular_shared):
    ds = ray.data.range(5).add_column("id", 0)


@pytest.mark.parametrize("names", (["foo", "bar"], {"spam": "foo", "ham": "bar"}))
def test_rename_columns(ray_start_regular_shared, names):
    ds = ray.data.from_items([{"spam": 0, "ham": 0}])

    renamed_ds = ds.rename_columns(names)

    assert renamed_ds.schema().names == ["foo", "bar"]


def test_drop_columns(ray_start_regular_shared, tmp_path):
    df = pd.DataFrame({"col1": [1, 2, 3], "col2": [2, 3, 4], "col3": [3, 4, 5]})
    ds1 = ray.data.from_pandas(df)
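To show how the parametrized test exercises both accepted argument forms, here is a hedged usage sketch against a small in-memory dataset, assuming a Ray build that includes this commit; the column names mirror the test above.

import ray

ds = ray.data.from_items([{"spam": 0, "ham": 0}])

# Rename via a mapping of old column names to new column names...
assert ds.rename_columns({"spam": "foo", "ham": "bar"}).schema().names == ["foo", "bar"]

# ...or via a full list of new names, one per column in order.
assert ds.rename_columns(["foo", "bar"]).schema().names == ["foo", "bar"]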
