Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Data] Add Dataset.rename_columns #47906

Merged
merged 5 commits into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions python/ray/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,6 +862,86 @@ def select_columns(batch):
**ray_remote_args,
)

@PublicAPI(api_group=BT_API_GROUP)
def rename_columns(
bveeramani marked this conversation as resolved.
Show resolved Hide resolved
self,
names: Union[List[str], Dict[str, str]],
*,
concurrency: Optional[Union[int, Tuple[int, int]]] = None,
**ray_remote_args,
):
"""Rename columns in the dataset.

Examples:

>>> import ray
>>> ds = ray.data.read_parquet("s3://anonymous@ray-example-data/iris.parquet")
>>> ds.schema()
Column Type
------ ----
sepal.length double
sepal.width double
petal.length double
petal.width double
variety string

You can pass a dictionary mapping old column names to new column names.

>>> ds.rename_columns({"variety": "category"}).schema()
Column Type
------ ----
sepal.length double
sepal.width double
petal.length double
petal.width double
category string

Or you can pass a list of new column names.

>>> ds.rename_columns(
... ["sepal_length", "sepal_width", "petal_length", "petal_width", "variety"]
... ).schema()
Column Type
------ ----
sepal_length double
sepal_width double
petal_length double
petal_width double
variety string

Args:
names: A dictionary that maps old column names to new column names, or a
list of new column names.
concurrency: The maximum number of Ray workers to use concurrently.
ray_remote_args: Additional resource requirements to request from
ray (e.g., num_gpus=1 to request GPUs for the map tasks).
""" # noqa: E501
if concurrency is not None and not isinstance(concurrency, int):
raise ValueError(
"Expected `concurrency` to be an integer or `None`, but got "
f"{concurrency}."
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: concurrency validation will be done within map_batches, so there's no need to add it here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm performing an additional validation here to avoid this confusing error message when you pass a tuple to concurrency:

ValueError: concurrency is set as a tuple of integers, but fn is not a callable class: <function Dataset.rename_columns.<locals>.rename_columns at 0x166532550>. Use concurrency=n to control maximum number of workers to use.


def rename_columns(batch: "pyarrow.Table") -> "pyarrow.Table":
    """Return ``batch`` with its columns renamed according to ``names``.

    ``names`` comes from the enclosing scope. When it is a dict, it maps
    old column names to new ones and columns absent from the mapping keep
    their current names. Otherwise it is passed through unchanged as the
    list of new column names.
    """
    if not isinstance(names, dict):
        # Not a mapping: hand the given names straight to PyArrow.
        return batch.rename_columns(names)

    # Build the full list of names ourselves, because versions of PyArrow
    # before 17 don't support renaming columns with a dict.
    renamed = [names.get(existing, existing) for existing in batch.column_names]
    return batch.rename_columns(renamed)

return self.map_batches(
rename_columns,
batch_format="pyarrow",
zero_copy_batch=True,
concurrency=concurrency,
**ray_remote_args,
)

@PublicAPI(api_group=BT_API_GROUP)
def flat_map(
self,
Expand Down
9 changes: 9 additions & 0 deletions python/ray/data/tests/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,15 @@ def test_add_column(ray_start_regular_shared):
ds = ray.data.range(5).add_column("id", 0)


@pytest.mark.parametrize(
    "names",
    (
        ["foo", "bar"],
        {"spam": "foo", "ham": "bar"},
    ),
)
def test_rename_columns(ray_start_regular_shared, names):
    """Check that both the list form and the dict form rename the columns."""
    source = ray.data.from_items([{"spam": 0, "ham": 0}])
    expected_columns = ["foo", "bar"]

    # Either form of `names` should produce the same final schema.
    actual_columns = source.rename_columns(names).schema().names

    assert actual_columns == expected_columns


def test_drop_columns(ray_start_regular_shared, tmp_path):
df = pd.DataFrame({"col1": [1, 2, 3], "col2": [2, 3, 4], "col3": [3, 4, 5]})
ds1 = ray.data.from_pandas(df)
Expand Down
Loading