[data] fix reading multiple parquet files with ragged ndarrays (ray-project#47961)

## Why are these changes needed?

PyArrow infers the Parquet schema based only on the first file. This causes
errors when reading multiple files whose ndarray columns have different shapes
(ragged ndarrays).

This PR fixes the issue by not using the inferred schema for read operations.
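A minimal reproduction sketch of the failure mode (it mirrors the regression test added in this PR; the output directory and array sizes below are illustrative assumptions):

```python
import numpy as np
import ray

def add_ragged_ndarray(row):
    # Give each row a differently shaped ndarray in the same "data" column.
    n = (row["id"] + 1) * 100
    row["data"] = np.zeros((n, n), dtype=np.int8)
    return row

# Write three Parquet files, each holding a different tensor shape.
ds = ray.data.range(3).map(add_ragged_ndarray).repartition(3)
ds.write_parquet("/tmp/ragged_parquet")

# Before this fix, reading the files back could error out or produce blocks with
# the wrong schema, because the schema inferred from the first file was applied
# to every file.
ray.data.read_parquet("/tmp/ragged_parquet").take_all()
```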


## Related issue number
Fixes ray-project#47960

---------

Signed-off-by: Hao Chen <chenh1024@gmail.com>
Signed-off-by: ujjawal-khare <ujjawal.khare@dream11.com>
raulchen authored and ujjawal-khare committed Oct 15, 2024
1 parent d0b6bfc commit ee06939
Showing 2 changed files with 42 additions and 11 deletions.
27 changes: 16 additions & 11 deletions python/ray/data/_internal/datasource/parquet_datasource.py
@@ -230,6 +230,14 @@ def __init__(
         # duplicating the partition data, we disable PyArrow's partitioning.
         dataset_kwargs["partitioning"] = None
 
+        # `read_schema` is the schema object that will be used to perform
+        # read operations.
+        # It should be None, unless user has specified the schema or columns.
+        # We don't use the inferred schema for read, because the pyarrow only infers
+        # schema based on the first file. Thus, files with different schemas will end
+        # up producing blocks with wrong schema.
+        # See https://github.com/ray-project/ray/issues/47960 for more context.
+        read_schema = schema
         pq_ds = get_parquet_dataset(paths, filesystem, dataset_kwargs)
 
         if schema is None:
@@ -240,24 +248,21 @@ def __init__(
             schema = pa.schema(
                 [schema.field(column) for column in columns], schema.metadata
             )
+            read_schema = schema
 
         check_for_legacy_tensor_type(schema)
 
         if _block_udf is not None:
             # Try to infer dataset schema by passing dummy table through UDF.
             dummy_table = schema.empty_table()
             try:
-                inferred_schema = _block_udf(dummy_table).schema
-                inferred_schema = inferred_schema.with_metadata(schema.metadata)
+                schema = _block_udf(dummy_table).schema.with_metadata(schema.metadata)
             except Exception:
                 logger.debug(
                     "Failed to infer schema of dataset by passing dummy table "
                     "through UDF due to the following exception:",
                     exc_info=True,
                 )
-                inferred_schema = schema
-        else:
-            inferred_schema = schema
 
         try:
             prefetch_remote_args = {}
@@ -291,10 +296,10 @@ def __init__(
         self._pq_fragments = [SerializedFragment(p) for p in pq_ds.fragments]
         self._pq_paths = [p.path for p in pq_ds.fragments]
         self._meta_provider = meta_provider
-        self._inferred_schema = inferred_schema
         self._block_udf = _block_udf
         self._to_batches_kwargs = to_batch_kwargs
         self._columns = columns
+        self._read_schema = read_schema
         self._schema = schema
         self._file_metadata_shuffler = None
         self._include_paths = include_paths
@@ -306,7 +311,7 @@ def __init__(
             self._pq_fragments,
             to_batches_kwargs=to_batch_kwargs,
             columns=columns,
-            schema=schema,
+            schema=self._read_schema,
             local_scheduling=self._local_scheduling,
         )
         self._encoding_ratio = estimate_files_encoding_ratio(sample_infos)
@@ -358,7 +363,7 @@ def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
 
             meta = self._meta_provider(
                 paths,
-                self._inferred_schema,
+                self._schema,
                 num_fragments=len(fragments),
                 prefetched_metadata=metadata,
             )
@@ -375,15 +380,15 @@ def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
                 to_batches_kwargs,
                 default_read_batch_size_rows,
                 columns,
-                schema,
+                read_schema,
                 include_paths,
                 partitioning,
             ) = (
                 self._block_udf,
                 self._to_batches_kwargs,
                 self._default_read_batch_size_rows,
                 self._columns,
-                self._schema,
+                self._read_schema,
                 self._include_paths,
                 self._partitioning,
             )
@@ -394,7 +399,7 @@ def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
                     to_batches_kwargs,
                     default_read_batch_size_rows,
                     columns,
-                    schema,
+                    read_schema,
                     f,
                     include_paths,
                     partitioning,
26 changes: 26 additions & 0 deletions python/ray/data/tests/test_parquet.py
@@ -1291,6 +1291,32 @@ def _assert_equal(rows, expected):
     _assert_equal(ds.take_all(), expected_tuples)
 
 
+def test_multiple_files_with_ragged_arrays(ray_start_regular_shared, tmp_path):
+    # Test reading multiple parquet files, each of which has different-shaped
+    # ndarrays in the same column.
+    # See https://github.com/ray-project/ray/issues/47960 for more context.
+    num_rows = 3
+    ds = ray.data.range(num_rows)
+
+    def map(row):
+        id = row["id"] + 1
+        row["data"] = np.zeros((id * 100, id * 100), dtype=np.int8)
+        return row
+
+    # Write 3 parquet files with different-shaped ndarray values in the
+    # "data" column.
+    ds.map(map).repartition(num_rows).write_parquet(tmp_path)
+
+    # Read these 3 files, check that the result is correct.
+    ds2 = ray.data.read_parquet(tmp_path, override_num_blocks=1)
+    res = ds2.take_all()
+    res = sorted(res, key=lambda row: row["id"])
+    assert len(res) == num_rows
+    for index, item in enumerate(res):
+        assert item["id"] == index
+        assert item["data"].shape == (100 * (index + 1), 100 * (index + 1))
+
+
 if __name__ == "__main__":
     import sys
 
