diff --git a/python/ray/data/_internal/datasource/parquet_datasource.py b/python/ray/data/_internal/datasource/parquet_datasource.py
index 8b06ecfaad60f..b688c2630d686 100644
--- a/python/ray/data/_internal/datasource/parquet_datasource.py
+++ b/python/ray/data/_internal/datasource/parquet_datasource.py
@@ -230,6 +230,14 @@ def __init__(
             # duplicating the partition data, we disable PyArrow's partitioning.
             dataset_kwargs["partitioning"] = None
 
+        # `read_schema` is the schema object that will be used to perform
+        # read operations.
+        # It should be None unless the user has specified the schema or columns.
+        # We don't use the inferred schema for reads, because PyArrow only infers
+        # the schema from the first file. Thus, files with different schemas would
+        # end up producing blocks with the wrong schema.
+        # See https://github.com/ray-project/ray/issues/47960 for more context.
+        read_schema = schema
         pq_ds = get_parquet_dataset(paths, filesystem, dataset_kwargs)
 
         if schema is None:
@@ -240,6 +248,7 @@ def __init__(
             schema = pa.schema(
                 [schema.field(column) for column in columns], schema.metadata
             )
+            read_schema = schema
 
         check_for_legacy_tensor_type(schema)
 
@@ -247,17 +256,13 @@ def __init__(
             # Try to infer dataset schema by passing dummy table through UDF.
             dummy_table = schema.empty_table()
             try:
-                inferred_schema = _block_udf(dummy_table).schema
-                inferred_schema = inferred_schema.with_metadata(schema.metadata)
+                schema = _block_udf(dummy_table).schema.with_metadata(schema.metadata)
             except Exception:
                 logger.debug(
                     "Failed to infer schema of dataset by passing dummy table "
                     "through UDF due to the following exception:",
                     exc_info=True,
                 )
-                inferred_schema = schema
-        else:
-            inferred_schema = schema
 
         try:
             prefetch_remote_args = {}
@@ -291,10 +296,10 @@ def __init__(
         self._pq_fragments = [SerializedFragment(p) for p in pq_ds.fragments]
         self._pq_paths = [p.path for p in pq_ds.fragments]
         self._meta_provider = meta_provider
-        self._inferred_schema = inferred_schema
         self._block_udf = _block_udf
         self._to_batches_kwargs = to_batch_kwargs
         self._columns = columns
+        self._read_schema = read_schema
         self._schema = schema
         self._file_metadata_shuffler = None
         self._include_paths = include_paths
@@ -306,7 +311,7 @@ def __init__(
             self._pq_fragments,
             to_batches_kwargs=to_batch_kwargs,
             columns=columns,
-            schema=schema,
+            schema=self._read_schema,
             local_scheduling=self._local_scheduling,
         )
         self._encoding_ratio = estimate_files_encoding_ratio(sample_infos)
@@ -358,7 +363,7 @@ def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
             meta = self._meta_provider(
                 paths,
-                self._inferred_schema,
+                self._schema,
                 num_fragments=len(fragments),
                 prefetched_metadata=metadata,
             )
 
@@ -375,7 +380,7 @@ def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
             to_batches_kwargs,
             default_read_batch_size_rows,
             columns,
-            schema,
+            read_schema,
             include_paths,
             partitioning,
         ) = (
@@ -383,7 +388,7 @@ def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
             self._to_batches_kwargs,
             self._default_read_batch_size_rows,
             self._columns,
-            self._schema,
+            self._read_schema,
             self._include_paths,
             self._partitioning,
         )
@@ -394,7 +399,7 @@ def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
                 to_batches_kwargs,
                 default_read_batch_size_rows,
                 columns,
-                schema,
+                read_schema,
                 f,
                 include_paths,
                 partitioning,
diff --git a/python/ray/data/tests/test_parquet.py b/python/ray/data/tests/test_parquet.py
index 15629e5ef0f79..23969d736f046 100644
--- a/python/ray/data/tests/test_parquet.py
+++ b/python/ray/data/tests/test_parquet.py
@@ -1291,6 +1291,32 @@ def _assert_equal(rows, expected):
     _assert_equal(ds.take_all(), expected_tuples)
 
 
+def test_multiple_files_with_ragged_arrays(ray_start_regular_shared, tmp_path):
+    # Test reading multiple parquet files, each of which has different-shaped
+    # ndarrays in the same column.
+    # See https://github.com/ray-project/ray/issues/47960 for more context.
+    num_rows = 3
+    ds = ray.data.range(num_rows)
+
+    def map(row):
+        id = row["id"] + 1
+        row["data"] = np.zeros((id * 100, id * 100), dtype=np.int8)
+        return row
+
+    # Write 3 parquet files with different-shaped ndarray values in the
+    # "data" column.
+    ds.map(map).repartition(num_rows).write_parquet(tmp_path)
+
+    # Read these 3 files, check that the result is correct.
+    ds2 = ray.data.read_parquet(tmp_path, override_num_blocks=1)
+    res = ds2.take_all()
+    res = sorted(res, key=lambda row: row["id"])
+    assert len(res) == num_rows
+    for index, item in enumerate(res):
+        assert item["id"] == index
+        assert item["data"].shape == (100 * (index + 1), 100 * (index + 1))
+
+
 if __name__ == "__main__":
     import sys
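
For context outside the patch itself, the snippet below is a minimal standalone sketch of the scenario the new test exercises: writing several Parquet files whose "data" column holds differently shaped ndarrays and reading them back with ray.data.read_parquet. It reuses only APIs that already appear in the diff (ray.data.range, map, repartition, write_parquet, read_parquet, take_all); the helper name add_ragged_array and the use of a temporary directory are illustrative assumptions, not part of the change.

import tempfile

import numpy as np
import ray


# Give each row a differently shaped ndarray so that every written Parquet
# file ends up with a different tensor shape in the "data" column.
def add_ragged_array(row):  # hypothetical helper, mirrors the test's `map`
    n = (row["id"] + 1) * 100
    row["data"] = np.zeros((n, n), dtype=np.int8)
    return row


with tempfile.TemporaryDirectory() as tmp_dir:
    # Write three files, one per block, each with a different "data" shape.
    ray.data.range(3).map(add_ragged_array).repartition(3).write_parquet(tmp_dir)

    # Read them back. With this change, `read_schema` stays None when neither a
    # schema nor columns are specified, so each file is read with its own schema
    # rather than the schema inferred from the first file.
    rows = ray.data.read_parquet(tmp_dir).take_all()
    for row in sorted(rows, key=lambda r: r["id"]):
        print(row["id"], row["data"].shape)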