Commit
[Data] Support partition_cols in write_parquet (#49411)
gvspraveen authored and srinathk10 committed Jan 3, 2025
1 parent 877e016 commit edbb1c2
Showing 4 changed files with 145 additions and 19 deletions.
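
In short, write_parquet gains a partition_cols argument that writes Hive-style partition directories. A minimal usage sketch (hypothetical output path and column names, not taken from the commit):

    import ray

    ds = ray.data.from_items(
        [{"country": "US", "year": 2024, "value": 1},
         {"country": "CA", "year": 2024, "value": 2}]
    )
    # Produces e.g. /tmp/out/country=US/year=2024/<file>.parquet
    ds.write_parquet("/tmp/out", partition_cols=["country", "year"])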
100 changes: 85 additions & 15 deletions python/ray/data/_internal/datasource/parquet_datasink.py
@@ -1,7 +1,8 @@
import logging
import posixpath
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Optional
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional

from ray.data._internal.arrow_ops.transform_pyarrow import concat
from ray.data._internal.execution.interfaces import TaskContext
from ray.data._internal.util import call_with_retry
from ray.data.block import Block, BlockAccessor
@@ -24,6 +25,7 @@ def __init__(
self,
path: str,
*,
partition_cols: Optional[List[str]] = None,
arrow_parquet_args_fn: Optional[Callable[[], Dict[str, Any]]] = None,
arrow_parquet_args: Optional[Dict[str, Any]] = None,
num_rows_per_file: Optional[int] = None,
@@ -42,6 +44,7 @@ def __init__(
self.arrow_parquet_args_fn = arrow_parquet_args_fn
self.arrow_parquet_args = arrow_parquet_args
self.num_rows_per_file = num_rows_per_file
self.partition_cols = partition_cols

super().__init__(
path,
@@ -59,7 +62,6 @@ def write(
ctx: TaskContext,
) -> None:
import pyarrow as pa
import pyarrow.parquet as pq

blocks = list(blocks)

@@ -69,34 +71,102 @@
filename = self.filename_provider.get_filename_for_block(
blocks[0], ctx.task_idx, 0
)
write_path = posixpath.join(self.path, filename)
write_kwargs = _resolve_kwargs(
self.arrow_parquet_args_fn, **self.arrow_parquet_args
)
user_schema = write_kwargs.pop("schema", None)

def write_blocks_to_path():
with self.open_output_stream(write_path) as file:
tables = [BlockAccessor.for_block(block).to_arrow() for block in blocks]
if user_schema is None:
output_schema = pa.unify_schemas([table.schema for table in tables])
else:
output_schema = user_schema
tables = [BlockAccessor.for_block(block).to_arrow() for block in blocks]
if user_schema is None:
output_schema = pa.unify_schemas([table.schema for table in tables])
else:
output_schema = user_schema

with pq.ParquetWriter(file, output_schema, **write_kwargs) as writer:
for table in tables:
table = table.cast(output_schema)
writer.write_table(table)
if not self.partition_cols:
self._write_single_file(tables, filename, output_schema, write_kwargs)
else: # partition writes
self._write_partition_files(
tables, filename, output_schema, write_kwargs
)

logger.debug(f"Writing {filename} file to {self.path}.")

logger.debug(f"Writing {write_path} file.")
call_with_retry(
write_blocks_to_path,
description=f"write '{write_path}'",
description=f"write '{filename}' to '{self.path}'",
match=DataContext.get_current().retried_io_errors,
max_attempts=WRITE_FILE_MAX_ATTEMPTS,
max_backoff_s=WRITE_FILE_RETRY_MAX_BACKOFF_SECONDS,
)

def _write_single_file(
self,
tables: List["pyarrow.Table"],
filename: str,
output_schema: "pyarrow.Schema",
write_kwargs: Dict[str, Any],
) -> None:
import pyarrow.parquet as pq

write_path = posixpath.join(self.path, filename)
with self.open_output_stream(write_path) as file:
with pq.ParquetWriter(file, output_schema, **write_kwargs) as writer:
for table in tables:
table = table.cast(output_schema)
writer.write_table(table)

def _write_partition_files(
self,
tables: List["pyarrow.Table"],
filename: str,
output_schema: "pyarrow.Schema",
write_kwargs: Dict[str, Any],
) -> None:
import pyarrow as pa
import pyarrow.parquet as pq

table = concat(tables)
# Create unique combinations of the partition columns
table_fields = [
field for field in output_schema if field.name not in self.partition_cols
]
non_partition_cols = [f.name for f in table_fields]
output_schema = pa.schema(
[field for field in output_schema if field.name not in self.partition_cols]
)
# Group the table by the partition keys. For each unique combination of
# partition-key values, collect the values of the non-partition columns
# into lists.
# Ex: the original table has two columns (a, b) and we partition by
# column a. The grouped `groups` table then has the schema:
#   b_list: [[[0,0],[1,1],[2,2]]]
#   a: [[1,2,3]]
groups = table.group_by(self.partition_cols).aggregate(
[(col_name, "list") for col_name in non_partition_cols]
)
grouped_keys = [groups.column(k) for k in self.partition_cols]

for i in range(groups.num_rows):
# See https://github.com/apache/arrow/issues/14882 for recommended approach
values = [
groups.column(f"{col.name}_list")[i].values for col in table_fields
]
group_table = pa.Table.from_arrays(values, names=non_partition_cols)
partition_path = "/".join(
[
f"{col}={values[i]}"
for col, values in zip(self.partition_cols, grouped_keys)
]
)
write_path = posixpath.join(self.path, partition_path)
self._create_dir(write_path)
write_path = posixpath.join(write_path, filename)
with self.open_output_stream(write_path) as file:
with pq.ParquetWriter(file, output_schema, **write_kwargs) as writer:
writer.write_table(group_table)

@property
def num_rows_per_write(self) -> Optional[int]:
return self.num_rows_per_file
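
The partition split in _write_partition_files above avoids per-row filtering by combining PyArrow's group_by with a "list" aggregation. A standalone sketch of the same pattern (illustrative table and column names, not part of the commit):

    import pyarrow as pa

    table = pa.table({"a": [1, 2, 1], "b": [10, 20, 30]})
    # One output row per unique value of the partition column "a"; the
    # non-partition column "b" is collected into a list column "b_list".
    groups = table.group_by(["a"]).aggregate([("b", "list")])
    for i in range(groups.num_rows):
        key = groups.column("a")[i]
        values = groups.column("b_list")[i].values
        group_table = pa.Table.from_arrays([values], names=["b"])
        # group_table now holds exactly the rows where a == key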
10 changes: 10 additions & 0 deletions python/ray/data/dataset.py
@@ -2976,6 +2976,7 @@ def write_parquet(
self,
path: str,
*,
partition_cols: Optional[List[str]] = None,
filesystem: Optional["pyarrow.fs.FileSystem"] = None,
try_create_dir: bool = True,
arrow_open_stream_args: Optional[Dict[str, Any]] = None,
@@ -3009,6 +3010,8 @@ def write_parquet(
Args:
path: The path to the destination root directory, where
parquet files are written to.
partition_cols: Column names by which to partition the dataset.
Files are written in Hive partition style.
filesystem: The pyarrow filesystem implementation to write to.
These filesystems are specified in the
`pyarrow docs <https://arrow.apache.org/docs\
@@ -3057,8 +3060,15 @@ def write_parquet(
if arrow_parquet_args_fn is None:
arrow_parquet_args_fn = lambda: {} # noqa: E731

if partition_cols and num_rows_per_file:
raise ValueError(
"Cannot pass num_rows_per_file when partition_cols "
"argument is specified"
)

datasink = ParquetDatasink(
path,
partition_cols=partition_cols,
arrow_parquet_args_fn=arrow_parquet_args_fn,
arrow_parquet_args=arrow_parquet_args,
num_rows_per_file=num_rows_per_file,
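
One behavioral note on the dataset.py change: partition_cols cannot be combined with num_rows_per_file, since rows are regrouped by partition key rather than by row count. A sketch of the failure mode (assuming a Dataset ds as in the earlier sketch):

    # Raises ValueError: "Cannot pass num_rows_per_file when partition_cols
    # argument is specified"
    ds.write_parquet("/tmp/out", partition_cols=["country"], num_rows_per_file=100)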
13 changes: 9 additions & 4 deletions python/ray/data/datasource/file_datasink.py
@@ -79,6 +79,9 @@ def open_output_stream(self, path: str) -> "pyarrow.NativeFile":
return self.filesystem.open_output_stream(path, **self.open_stream_args)

def on_write_start(self) -> None:
self.has_created_dir = self._create_dir(self.path)

def _create_dir(self, dest) -> bool:
"""Create a directory to write files to.
If ``try_create_dir`` is ``False``, this method is a no-op.
@@ -96,19 +99,21 @@ def on_write_start(self) -> None:
# a policy only allows users to write blobs prefixed with s3://bucket/foo
# a call to create_dir for s3://bucket/foo/bar will fail even though it
# should not.
parsed_uri = urlparse(self.path)
parsed_uri = urlparse(dest)
is_s3_uri = parsed_uri.scheme == "s3"
skip_create_dir_for_s3 = (
is_s3_uri and not DataContext.get_current().s3_try_create_dir
)

if self.try_create_dir and not skip_create_dir_for_s3:
if self.filesystem.get_file_info(self.path).type is FileType.NotFound:
if self.filesystem.get_file_info(dest).type is FileType.NotFound:
# Arrow's S3FileSystem doesn't allow creating buckets by default, so we
# add a query arg enabling bucket creation if an S3 URI is provided.
tmp = _add_creatable_buckets_param_if_s3_uri(self.path)
tmp = _add_creatable_buckets_param_if_s3_uri(dest)
self.filesystem.create_dir(tmp, recursive=True)
self.has_created_dir = True
return True

return False

def write(
self,
41 changes: 41 additions & 0 deletions python/ray/data/tests/test_parquet.py
@@ -52,6 +52,47 @@ def test_write_parquet_supports_gzip(ray_start_regular_shared, tmp_path):
assert pq.read_table(tmp_path).to_pydict() == {"id": [0]}


def test_write_parquet_partition_cols(ray_start_regular_shared, tmp_path):
num_partitions = 10
rows_per_partition = 10
num_rows = num_partitions * rows_per_partition

df = pd.DataFrame(
{
"a": list(range(num_partitions)) * rows_per_partition,
"b": list(range(num_partitions)) * rows_per_partition,
"c": list(range(num_rows)),
"d": list(range(num_rows)),
}
)

ds = ray.data.from_pandas(df)
ds.write_parquet(tmp_path, partition_cols=["a", "b"])

# Test that files are written in partition style
for i in range(num_partitions):
partition = os.path.join(tmp_path, f"a={i}", f"b={i}")
ds_partition = ray.data.read_parquet(partition)
dsf_partition = ds_partition.to_pandas()
c_expected = sorted(i + num_partitions * k for k in range(rows_per_partition))
d_expected = sorted(i + num_partitions * k for k in range(rows_per_partition))
assert c_expected == sorted(dsf_partition["c"].tolist())
assert d_expected == sorted(dsf_partition["d"].tolist())

# Test that partitions are read back properly into the original dataset schema
ds1 = ray.data.read_parquet(tmp_path)
assert set(ds.schema().names) == set(ds1.schema().names)
assert ds.count() == ds1.count()

df = df.sort_values(by=["a", "b", "c", "d"])
df1 = ds1.to_pandas().sort_values(by=["a", "b", "c", "d"])
for (index1, row1), (index2, row2) in zip(df.iterrows(), df1.iterrows()):
row1_dict = row1.to_dict()
row2_dict = row2.to_dict()
assert row1_dict["c"] == row2_dict["c"]
assert row1_dict["d"] == row2_dict["d"]


def test_include_paths(ray_start_regular_shared, tmp_path):
path = os.path.join(tmp_path, "test.txt")
table = pa.Table.from_pydict({"animals": ["cat", "dog"]})
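
For reference, the partitioned-write test above relies on the Hive-style layout produced by the new code path; a sketch of the on-disk layout and a single-partition read (hypothetical root path; file names come from the filename provider):

    import ray

    # Layout under the output root after write_parquet(..., partition_cols=["a", "b"]):
    #   a=0/b=0/<file>.parquet
    #   a=1/b=1/<file>.parquet
    #   ...
    # Partition values are encoded in the directory names, not in the data files.
    ds_partition = ray.data.read_parquet("/tmp/out/a=0/b=0")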
