[Data] Support partition_cols in write_parquet #49411

Merged: 4 commits, Dec 24, 2024
100 changes: 85 additions & 15 deletions python/ray/data/_internal/datasource/parquet_datasink.py
@@ -1,7 +1,8 @@
import logging
import posixpath
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Optional
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional

from ray.data._internal.arrow_ops.transform_pyarrow import concat
from ray.data._internal.execution.interfaces import TaskContext
from ray.data._internal.util import call_with_retry
from ray.data.block import Block, BlockAccessor
@@ -24,6 +25,7 @@ def __init__(
self,
path: str,
*,
partition_cols: Optional[List[str]] = None,
arrow_parquet_args_fn: Optional[Callable[[], Dict[str, Any]]] = None,
arrow_parquet_args: Optional[Dict[str, Any]] = None,
num_rows_per_file: Optional[int] = None,
@@ -42,6 +44,7 @@ def __init__(
self.arrow_parquet_args_fn = arrow_parquet_args_fn
self.arrow_parquet_args = arrow_parquet_args
self.num_rows_per_file = num_rows_per_file
self.partition_cols = partition_cols

super().__init__(
path,
@@ -59,7 +62,6 @@ def write(
ctx: TaskContext,
) -> None:
import pyarrow as pa
import pyarrow.parquet as pq

blocks = list(blocks)

@@ -69,34 +71,102 @@ def write(
filename = self.filename_provider.get_filename_for_block(
blocks[0], ctx.task_idx, 0
)
write_path = posixpath.join(self.path, filename)
write_kwargs = _resolve_kwargs(
self.arrow_parquet_args_fn, **self.arrow_parquet_args
)
user_schema = write_kwargs.pop("schema", None)

def write_blocks_to_path():
with self.open_output_stream(write_path) as file:
tables = [BlockAccessor.for_block(block).to_arrow() for block in blocks]
if user_schema is None:
output_schema = pa.unify_schemas([table.schema for table in tables])
else:
output_schema = user_schema
tables = [BlockAccessor.for_block(block).to_arrow() for block in blocks]
if user_schema is None:
output_schema = pa.unify_schemas([table.schema for table in tables])
else:
output_schema = user_schema

with pq.ParquetWriter(file, output_schema, **write_kwargs) as writer:
for table in tables:
table = table.cast(output_schema)
writer.write_table(table)
if not self.partition_cols:
self._write_single_file(tables, filename, output_schema, write_kwargs)
else: # partition writes
self._write_partition_files(
tables, filename, output_schema, write_kwargs
)

logger.debug(f"Writing {filename} file to {self.path}.")

logger.debug(f"Writing {write_path} file.")
call_with_retry(
write_blocks_to_path,
description=f"write '{write_path}'",
description=f"write '{filename}' to '{self.path}'",
match=DataContext.get_current().retried_io_errors,
max_attempts=WRITE_FILE_MAX_ATTEMPTS,
max_backoff_s=WRITE_FILE_RETRY_MAX_BACKOFF_SECONDS,
)

def _write_single_file(
self,
tables: List["pyarrow.Table"],
filename: str,
output_schema: "pyarrow.Schema",
write_kwargs: Dict[str, Any],
) -> None:
import pyarrow.parquet as pq

write_path = posixpath.join(self.path, filename)
with self.open_output_stream(write_path) as file:
with pq.ParquetWriter(file, output_schema, **write_kwargs) as writer:
for table in tables:
table = table.cast(output_schema)
writer.write_table(table)

def _write_partition_files(
self,
tables: List["pyarrow.Table"],
filename: str,
output_schema: "pyarrow.Schema",
write_kwargs: Dict[str, Any],
) -> None:
import pyarrow as pa
import pyarrow.parquet as pq

table = concat(tables)
# Create unique combinations of the partition columns
table_fields = [
field for field in output_schema if field.name not in self.partition_cols
]
non_partition_cols = [f.name for f in table_fields]
output_schema = pa.schema(
[field for field in output_schema if field.name not in self.partition_cols]
)
Comment on lines +135 to +137

Suggested change
output_schema = pa.schema(
[field for field in output_schema if field.name not in self.partition_cols]
)
output_schema = pa.schema(table_fields)

# Group the table by the partition keys.
# For each partition key combination, fetch the list of values
# for the non-partition columns.
# Ex: Here the original table contains two columns (a, b) and we are
# partitioning by column a. The schema of the grouped `groups` table
# is as follows:
# b_list: [[[0,0],[1,1],[2,2]]]
# a: [[1,2,3]]
groups = table.group_by(self.partition_cols).aggregate(
[(col_name, "list") for col_name in non_partition_cols]
)
grouped_keys = [groups.column(k) for k in self.partition_cols]

for i in range(groups.num_rows):
# See https://github.com/apache/arrow/issues/14882 for recommended approach
values = [
groups.column(f"{col.name}_list")[i].values for col in table_fields
]
Comment on lines +153 to +155

Was confused because "col" refers to the string column name in non_partition_cols but refers to a field in this context

Suggested change
values = [
groups.column(f"{col.name}_list")[i].values for col in table_fields
]
values = [
groups.column(f"{field.name}_list")[i].values for field in table_fields
]

group_table = pa.Table.from_arrays(values, names=non_partition_cols)
partition_path = "/".join(
[
f"{col}={values[i]}"
for col, values in zip(self.partition_cols, grouped_keys)
]
)
write_path = posixpath.join(self.path, partition_path)
self._create_dir(write_path)
write_path = posixpath.join(write_path, filename)
with self.open_output_stream(write_path) as file:
with pq.ParquetWriter(file, output_schema, **write_kwargs) as writer:
writer.write_table(group_table)

@property
def num_rows_per_write(self) -> Optional[int]:
return self.num_rows_per_file
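
For readers unfamiliar with the aggregation trick that the new _write_partition_files method relies on, here is a minimal standalone sketch (not part of the PR; the toy table and column names are invented for illustration) of how PyArrow's group_by(...).aggregate([(col, "list")]) collects the non-partition values for each partition key:

import pyarrow as pa

# Toy table: "a" is the partition column, "b" is a data column.
table = pa.table({"a": [1, 1, 2, 2, 3, 3], "b": [0, 0, 1, 1, 2, 2]})

# Group by the partition column and collect the non-partition values into
# lists; the aggregated column is named "<column>_list".
groups = table.group_by(["a"]).aggregate([("b", "list")])
# groups now holds roughly: b_list: [[0, 0], [1, 1], [2, 2]], a: [1, 2, 3]

# Rebuild a plain per-group table from each aggregated list, the same way the
# datasink does before writing one file per Hive-style partition directory.
for i in range(groups.num_rows):
    values = [groups.column("b_list")[i].values]
    group_table = pa.Table.from_arrays(values, names=["b"])
    partition_key = groups.column("a")[i].as_py()
    print(f"a={partition_key}: {group_table.num_rows} rows")
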
10 changes: 10 additions & 0 deletions python/ray/data/dataset.py
@@ -2976,6 +2976,7 @@ def write_parquet(
self,
path: str,
*,
partition_cols: Optional[List[str]] = None,
filesystem: Optional["pyarrow.fs.FileSystem"] = None,
try_create_dir: bool = True,
arrow_open_stream_args: Optional[Dict[str, Any]] = None,
@@ -3009,6 +3010,8 @@ def write_parquet(
Args:
path: The path to the destination root directory, where
parquet files are written to.
partition_cols: Column names by which to partition the dataset.
Files are writted in Hive partition style.
Comment on lines +3013 to +3014

Nit: Use active voice (from our style guide: https://developers.google.com/style/voice)

Also, typo with "writted"

Suggested change
partition_cols: Column names by which to partition the dataset.
Files are writted in Hive partition style.
partition_cols: Column names by which to partition the dataset.
This method writes files in Hive partition style.

filesystem: The pyarrow filesystem implementation to write to.
These filesystems are specified in the
`pyarrow docs <https://arrow.apache.org/docs\
@@ -3057,8 +3060,15 @@
if arrow_parquet_args_fn is None:
arrow_parquet_args_fn = lambda: {} # noqa: E731

if partition_cols and num_rows_per_file:
raise ValueError(
"Cannot pass num_rows_per_file when partition_cols "
"argument is specified"
)

datasink = ParquetDatasink(
path,
partition_cols=partition_cols,
arrow_parquet_args_fn=arrow_parquet_args_fn,
arrow_parquet_args=arrow_parquet_args,
num_rows_per_file=num_rows_per_file,
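
As a quick usage sketch of the new argument (the paths and column names below are hypothetical; the API is as introduced in this PR):

import ray

# Illustrative toy dataset; "year" and "country" are made-up columns.
ds = ray.data.from_items(
    [
        {"year": 2024, "country": "US", "value": 1},
        {"year": 2024, "country": "DE", "value": 2},
        {"year": 2023, "country": "US", "value": 3},
    ]
)

# Writes Hive-style directories such as /tmp/out/year=2024/country=US/<file>.parquet
ds.write_parquet("/tmp/out", partition_cols=["year", "country"])

# Per the validation added above, combining partition_cols with
# num_rows_per_file raises a ValueError:
# ds.write_parquet("/tmp/out", partition_cols=["year"], num_rows_per_file=10)
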
13 changes: 9 additions & 4 deletions python/ray/data/datasource/file_datasink.py
@@ -79,6 +79,9 @@ def open_output_stream(self, path: str) -> "pyarrow.NativeFile":
return self.filesystem.open_output_stream(path, **self.open_stream_args)

def on_write_start(self) -> None:
self.has_created_dir = self._create_dir(self.path)

def _create_dir(self, dest) -> bool:

Do we use the bool return value anywhere? If not, should this just be None?

Suggested change
def _create_dir(self, dest) -> bool:
def _create_dir(self, dest) -> None:

"""Create a directory to write files to.

If ``try_create_dir`` is ``False``, this method is a no-op.
@@ -96,19 +99,21 @@ def on_write_start(self) -> None:
# a policy only allows users to write blobs prefixed with s3://bucket/foo
# a call to create_dir for s3://bucket/foo/bar will fail even though it
# should not.
parsed_uri = urlparse(self.path)
parsed_uri = urlparse(dest)
is_s3_uri = parsed_uri.scheme == "s3"
skip_create_dir_for_s3 = (
is_s3_uri and not DataContext.get_current().s3_try_create_dir
)

if self.try_create_dir and not skip_create_dir_for_s3:
if self.filesystem.get_file_info(self.path).type is FileType.NotFound:
if self.filesystem.get_file_info(dest).type is FileType.NotFound:
# Arrow's S3FileSystem doesn't allow creating buckets by default, so we
# add a query arg enabling bucket creation if an S3 URI is provided.
tmp = _add_creatable_buckets_param_if_s3_uri(self.path)
tmp = _add_creatable_buckets_param_if_s3_uri(dest)
self.filesystem.create_dir(tmp, recursive=True)
self.has_created_dir = True
return True

return False

def write(
self,
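
For context on the S3 special case, here is a minimal standalone approximation (not the actual Ray code; the flag names mirror the try_create_dir and s3_try_create_dir settings referenced above) of the skip logic in _create_dir:

from urllib.parse import urlparse


def should_create_dir(dest: str, try_create_dir: bool, s3_try_create_dir: bool) -> bool:
    # Skip directory creation for s3:// URIs unless explicitly enabled, because
    # create_dir on a prefix like s3://bucket/foo/bar can be rejected by a
    # restrictive bucket policy even when writing objects under that prefix is
    # allowed.
    is_s3_uri = urlparse(dest).scheme == "s3"
    skip_create_dir_for_s3 = is_s3_uri and not s3_try_create_dir
    return try_create_dir and not skip_create_dir_for_s3


assert should_create_dir("s3://bucket/foo/bar", True, False) is False
assert should_create_dir("s3://bucket/foo/bar", True, True) is True
assert should_create_dir("/tmp/output", True, False) is True
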
41 changes: 41 additions & 0 deletions python/ray/data/tests/test_parquet.py
@@ -52,6 +52,47 @@ def test_write_parquet_supports_gzip(ray_start_regular_shared, tmp_path):
assert pq.read_table(tmp_path).to_pydict() == {"id": [0]}


def test_write_parquet_partition_cols(ray_start_regular_shared, tmp_path):
num_partitions = 10
rows_per_partition = 10
num_rows = num_partitions * rows_per_partition

df = pd.DataFrame(
{
"a": list(range(num_partitions)) * rows_per_partition,
"b": list(range(num_partitions)) * rows_per_partition,
"c": list(range(num_rows)),
"d": list(range(num_rows)),
}
)

ds = ray.data.from_pandas(df)
ds.write_parquet(tmp_path, partition_cols=["a", "b"])

# Test that files are written in partition style
for i in range(num_partitions):
partition = os.path.join(tmp_path, f"a={i}", f"b={i}")
ds_partition = ray.data.read_parquet(partition)
dsf_partition = ds_partition.to_pandas()
c_expected = sorted(i + num_partitions * k for k in range(rows_per_partition))
d_expected = sorted(i + num_partitions * k for k in range(rows_per_partition))
assert c_expected == sorted(dsf_partition["c"].tolist())
assert d_expected == sorted(dsf_partition["d"].tolist())

# Test that partitions are read back properly into the original dataset schema
ds1 = ray.data.read_parquet(tmp_path)
assert set(ds.schema().names) == set(ds1.schema().names)
assert ds.count() == ds1.count()

df = df.sort_values(by=["a", "b", "c", "d"])
df1 = ds1.to_pandas().sort_values(by=["a", "b", "c", "d"])
for (index1, row1), (index2, row2) in zip(df.iterrows(), df1.iterrows()):
row1_dict = row1.to_dict()
row2_dict = row2.to_dict()
assert row1_dict["c"] == row2_dict["c"]
assert row1_dict["d"] == row2_dict["d"]


def test_include_paths(ray_start_regular_shared, tmp_path):
path = os.path.join(tmp_path, "test.txt")
table = pa.Table.from_pydict({"animals": ["cat", "dog"]})