From ed37acea3da55df7ef33c33aaace4f1f34cf8239 Mon Sep 17 00:00:00 2001
From: brifitz <95299320+brifitz@users.noreply.github.com>
Date: Thu, 2 Jan 2025 09:49:24 +0000
Subject: [PATCH] fix(rust): `slice_pushdown` optimization leading to
 incorrectly sliced row index on parquet file (#20508)

---
 .../polars-io/src/parquet/read/read_impl.rs   | 29 +++++++++++++++----
 py-polars/tests/unit/io/test_lazy_parquet.py  | 28 ++++++++++++++++++
 2 files changed, 52 insertions(+), 5 deletions(-)

diff --git a/crates/polars-io/src/parquet/read/read_impl.rs b/crates/polars-io/src/parquet/read/read_impl.rs
index eb4448eebeb1..9f5281280c51 100644
--- a/crates/polars-io/src/parquet/read/read_impl.rs
+++ b/crates/polars-io/src/parquet/read/read_impl.rs
@@ -671,7 +671,7 @@ fn rg_to_dfs_par_over_rg(
     store: &mmap::ColumnStore,
     row_group_start: usize,
     row_group_end: usize,
-    previous_row_count: &mut IdxSize,
+    rows_read: &mut IdxSize,
     slice: (usize, usize),
     file_metadata: &FileMetadata,
     schema: &ArrowSchemaRef,
@@ -689,15 +689,34 @@ fn rg_to_dfs_par_over_rg(
         .sum();
     let slice_end = slice.0 + slice.1;
 
+    // rows_scanned is the number of rows that have been scanned so far when checking for overlap with the slice.
+    // rows_read is the number of rows found to overlap with the slice, and thus the number of rows that will be
+    // read into a dataframe.
+    let mut rows_scanned: IdxSize;
+
+    if row_group_start > 0 {
+        // In the case of async reads, we need to account for the fact that row_group_start may be greater than
+        // zero due to earlier processing.
+        // For details, see: https://github.com/pola-rs/polars/pull/20508#discussion_r1900165649
+        rows_scanned = (0..row_group_start)
+            .map(|i| file_metadata.row_groups[i].num_rows() as IdxSize)
+            .sum();
+    } else {
+        rows_scanned = 0;
+    }
+
     for i in row_group_start..row_group_end {
-        let row_count_start = *previous_row_count;
+        let row_count_start = rows_scanned;
         let rg_md = &file_metadata.row_groups[i];
+        let n_rows_this_file = rg_md.num_rows();
         let rg_slice =
-            split_slice_at_file(&mut n_rows_processed, rg_md.num_rows(), slice.0, slice_end);
-        *previous_row_count = previous_row_count
-            .checked_add(rg_slice.1 as IdxSize)
+            split_slice_at_file(&mut n_rows_processed, n_rows_this_file, slice.0, slice_end);
+        rows_scanned = rows_scanned
+            .checked_add(n_rows_this_file as IdxSize)
             .ok_or(ROW_COUNT_OVERFLOW_ERR)?;
+
+        *rows_read += rg_slice.1 as IdxSize;
 
         if rg_slice.1 == 0 {
             continue;
         }
diff --git a/py-polars/tests/unit/io/test_lazy_parquet.py b/py-polars/tests/unit/io/test_lazy_parquet.py
index 05589332cc99..78ffb6b1379b 100644
--- a/py-polars/tests/unit/io/test_lazy_parquet.py
+++ b/py-polars/tests/unit/io/test_lazy_parquet.py
@@ -564,6 +564,34 @@ def trim_to_metadata(path: str | Path) -> None:
     )
 
 
+@pytest.mark.write_disk
+def test_predicate_slice_pushdown_row_index_20485(tmp_path: Path) -> None:
+    tmp_path.mkdir(exist_ok=True)
+
+    file_path = tmp_path / "slice_pushdown.parquet"
+    row_group_size = 100000
+    num_row_groups = 3
+
+    df = pl.select(ref=pl.int_range(num_row_groups * row_group_size))
+    df.write_parquet(file_path, row_group_size=row_group_size)
+
+    # Use a slice that starts near the end of one row group and extends into the next
+    # to test handling of slices that span multiple row groups.
+    slice_start = 199995
+    slice_len = 10
+    ldf = pl.scan_parquet(file_path)
+    sliced_df = ldf.with_row_index().slice(slice_start, slice_len).collect()
+    sliced_df_no_pushdown = (
+        ldf.with_row_index().slice(slice_start, slice_len).collect(slice_pushdown=False)
+    )
+
+    expected_index = list(range(slice_start, slice_start + slice_len))
+    actual_index = list(sliced_df["index"])
+    assert actual_index == expected_index
+
+    assert_frame_equal(sliced_df, sliced_df_no_pushdown)
+
+
 @pytest.mark.write_disk
 @pytest.mark.parametrize("streaming", [True, False])
 def test_parquet_row_groups_shift_bug_18739(tmp_path: Path, streaming: bool) -> None:
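
For illustration, the following is a minimal, self-contained Rust sketch of the bookkeeping this patch introduces; it is not the polars implementation. The helper split_slice_at_chunk is a hypothetical stand-in for split_slice_at_file, and the row-group sizes and slice values mirror the regression test above. It demonstrates the point of the fix: the row-index base must come from rows_scanned, the absolute position in the file, while rows_read only counts the rows that overlap the slice.

fn split_slice_at_chunk(
    rows_scanned: &mut usize,
    chunk_len: usize,
    slice_start: usize,
    slice_end: usize,
) -> (usize, usize) {
    // Intersect the global slice [slice_start, slice_end) with the chunk that
    // begins at the current scan position, then advance the scan position.
    let chunk_start = *rows_scanned;
    let chunk_end = chunk_start + chunk_len;
    *rows_scanned = chunk_end;
    let start = slice_start.max(chunk_start);
    let end = slice_end.min(chunk_end);
    if start >= end {
        (0, 0)
    } else {
        (start - chunk_start, end - start)
    }
}

fn main() {
    // Three row groups of 100_000 rows and the slice used by the regression test.
    let row_groups = [100_000usize, 100_000, 100_000];
    let (slice_offset, slice_len) = (199_995usize, 10usize);
    let slice_end = slice_offset + slice_len;

    let mut rows_scanned = 0usize; // absolute position in the file: row-index base
    let mut rows_read = 0usize; // rows overlapping the slice: output height

    for (rg, &n_rows) in row_groups.iter().enumerate() {
        // Analogue of `row_count_start` in the patch: captured before scanning
        // this row group, so it reflects rows scanned, not rows read.
        let row_index_base = rows_scanned;
        let (local_offset, local_len) =
            split_slice_at_chunk(&mut rows_scanned, n_rows, slice_offset, slice_end);
        rows_read += local_len;
        if local_len == 0 {
            continue;
        }
        // Before the fix, the base was the rows-read-so-far counter, which misses
        // the 100_000 rows skipped in row group 0 and shifts the row index down.
        let first_index = row_index_base + local_offset;
        println!("row group {rg}: {local_len} rows, row index starts at {first_index}");
    }
    assert_eq!(rows_read, slice_len);
}

Running this sketch reports row-index starts of 199995 and 200000 for row groups 1 and 2, which matches the expected_index check in the new test.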