Commit c70d1d2
fix: don't eagerly materialize fields that the user hasn't asked for (#3442)
We added logic a while back to eagerly materialize fields when they are narrow and there is a filter. However, we forgot to ensure that those fields are actually part of the final projection. As a result, we were loading many columns the user didn't ask for and then throwing them away. This fix restricts the set of fields we load to only those that were asked for.
1 parent 2e2bf1a commit c70d1d2
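The user-visible effect of the fix can be checked from Python via `explain_plan`. A minimal sketch, adapted from the `test_select_none` test added in this commit (the dataset path is illustrative):

import lance
import pyarrow as pa

# Two narrow integer columns; "a" is the filter column.
table = pa.Table.from_pydict({"a": range(100), "b": range(100)})
ds = lance.write_dataset(table, "/tmp/select_none_demo")

# Select no columns and filter on "a": with this fix the scan projects only
# the filter column, instead of eagerly materializing "b" as well.
plan = ds.scanner(columns=[], filter="a < 50", with_row_id=True).explain_plan(True)
assert "projection=[a]" in plan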

9 files changed: +169 -46 lines changed

.pre-commit-config.yaml (+1 -1)

@@ -1,6 +1,6 @@
 repos:
 - repo: https://github.com/astral-sh/ruff-pre-commit
-  rev: v0.2.2
+  rev: v0.4.1
   hooks:
   - id: ruff
     args: [--fix, --exit-non-zero-on-fix]

java/core/src/test/java/com/lancedb/lance/FilterTest.java (+8 -1)

@@ -25,6 +25,7 @@
 
 import java.io.IOException;
 import java.nio.file.Path;
+import java.util.Arrays;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
 
@@ -102,7 +103,13 @@ void testFilters() throws Exception {
   }
 
   private void testFilter(String filter, int expectedCount) throws Exception {
-    try (LanceScanner scanner = dataset.newScan(new ScanOptions.Builder().filter(filter).build())) {
+    try (LanceScanner scanner =
+        dataset.newScan(
+            new ScanOptions.Builder()
+                .columns(Arrays.asList())
+                .withRowId(true)
+                .filter(filter)
+                .build())) {
       assertEquals(expectedCount, scanner.countRows());
     }
   }

java/core/src/test/java/com/lancedb/lance/ScannerTest.java (+6 -2)

@@ -162,7 +162,12 @@ void testDatasetScannerCountRows() throws Exception {
     // write id with value from 0 to 39
     try (Dataset dataset = testDataset.write(1, 40)) {
       try (LanceScanner scanner =
-          dataset.newScan(new ScanOptions.Builder().filter("id < 20").build())) {
+          dataset.newScan(
+              new ScanOptions.Builder()
+                  .columns(Arrays.asList())
+                  .withRowId(true)
+                  .filter("id < 20")
+                  .build())) {
         assertEquals(20, scanner.countRows());
       }
     }
@@ -387,7 +392,6 @@ void testDatasetScannerBatchReadahead() throws Exception {
     // This test is more about ensuring that the batchReadahead parameter is accepted
     // and doesn't cause errors. The actual effect of batchReadahead might not be
     // directly observable in this test.
-    assertEquals(totalRows, scanner.countRows());
     try (ArrowReader reader = scanner.scanBatches()) {
       int rowCount = 0;
       while (reader.loadNextBatch()) {

python/python/lance/dataset.py (+3 -1)

@@ -914,7 +914,9 @@ def count_rows(
         """
         if isinstance(filter, pa.compute.Expression):
             # TODO: consolidate all to use scanner
-            return self.scanner(filter=filter).count_rows()
+            return self.scanner(
+                columns=[], with_row_id=True, filter=filter
+            ).count_rows()
 
         return self._ds.count_rows(filter)
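A hedged usage sketch of the Python change above (dataset path and column name are illustrative; the expected counts assume the 0..100 demo table from earlier):

import lance
import pyarrow.compute as pc

ds = lance.dataset("/tmp/select_none_demo")  # hypothetical path
# Expression filters are now routed through a scanner with columns=[] and
# with_row_id=True, so only the filter column is read from storage.
assert ds.count_rows(filter=pc.field("a") < 50) == 50
# String filters still go through the native count_rows implementation.
assert ds.count_rows(filter="a < 50") == 50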

python/python/lance/fragment.py (+10 -6)

@@ -354,8 +354,10 @@ def fragment_id(self):
     def count_rows(
         self, filter: Optional[Union[pa.compute.Expression, str]] = None
     ) -> int:
-        if filter is not None:
-            return self.scanner(filter=filter).count_rows()
+        if isinstance(filter, pa.compute.Expression):
+            return self.scanner(
+                with_row_id=True, columns=[], filter=filter
+            ).count_rows()
         return self._fragment.count_rows(filter)
 
     @property
@@ -540,10 +542,12 @@ def merge(
 
     def merge_columns(
         self,
-        value_func: Dict[str, str]
-        | BatchUDF
-        | ReaderLike
-        | Callable[[pa.RecordBatch], pa.RecordBatch],
+        value_func: (
+            Dict[str, str]
+            | BatchUDF
+            | ReaderLike
+            | Callable[[pa.RecordBatch], pa.RecordBatch]
+        ),
         columns: Optional[list[str]] = None,
         batch_size: Optional[int] = None,
         reader_schema: Optional[pa.Schema] = None,
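The fragment-level change mirrors the dataset-level one: expression filters now scan with `columns=[]` and `with_row_id=True`, while strings (and `None`) fall through to the native count. A brief sketch (fragment index and column name are illustrative):

import lance
import pyarrow.compute as pc

ds = lance.dataset("/tmp/select_none_demo")  # hypothetical path
frag = ds.get_fragments()[0]
# Only the filter column "a" is materialized for this count.
n = frag.count_rows(filter=pc.field("a") < 50)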

python/python/tests/test_dataset.py (+11 -1)

@@ -778,6 +778,16 @@ def test_count_rows(tmp_path: Path):
     assert dataset.count_rows(filter="a < 50") == 50
 
 
+def test_select_none(tmp_path: Path):
+    table = pa.Table.from_pydict({"a": range(100), "b": range(100)})
+    base_dir = tmp_path / "test"
+    ds = lance.write_dataset(table, base_dir)
+
+    assert "projection=[a]" in ds.scanner(
+        columns=[], filter="a < 50", with_row_id=True
+    ).explain_plan(True)
+
+
 def test_get_fragments(tmp_path: Path):
     table = pa.Table.from_pydict({"a": range(100), "b": range(100)})
     base_dir = tmp_path / "test"
@@ -2200,7 +2210,7 @@ def test_scan_count_rows(tmp_path: Path):
     df = pd.DataFrame({"a": range(42), "b": range(42)})
     dataset = lance.write_dataset(df, base_dir)
 
-    assert dataset.scanner().count_rows() == 42
+    assert dataset.scanner(columns=[], with_row_id=True).count_rows() == 42
     assert dataset.count_rows(filter="a < 10") == 10
     assert dataset.count_rows(filter=pa_ds.field("a") < 20) == 20

rust/lance/src/dataset/fragment.rs (+3)

@@ -903,6 +903,9 @@ impl FileFragment {
         match filter {
             Some(expr) => self
                 .scan()
+                .project(&Vec::<String>::default())
+                .unwrap()
+                .with_row_id()
                 .filter(&expr)?
                 .count_rows()
                 .await

rust/lance/src/dataset/scanner.rs (+126 -30)

@@ -189,6 +189,7 @@ impl MaterializationStyle {
 }
 
 /// Filter for filtering rows
+#[derive(Debug)]
 pub enum LanceFilter {
     /// The filter is an SQL string
     Sql(String),
@@ -1027,11 +1028,22 @@
         Ok(concat_batches(&schema, &batches)?)
     }
 
-    /// Scan and return the number of matching rows
-    #[instrument(skip_all)]
-    pub fn count_rows(&self) -> BoxFuture<Result<u64>> {
+    fn create_count_plan(&self) -> BoxFuture<Result<Arc<dyn ExecutionPlan>>> {
         // Future intentionally boxed here to avoid large futures on the stack
         async move {
+            if !self.projection_plan.physical_schema.fields.is_empty() {
+                return Err(Error::invalid_input(
+                    "count_rows should not be called on a plan selecting columns".to_string(),
+                    location!(),
+                ));
+            }
+
+            if self.limit.is_some() || self.offset.is_some() {
+                log::warn!(
+                    "count_rows called with limit or offset which could have surprising results"
+                );
+            }
+
             let plan = self.create_plan().await?;
             // Datafusion interprets COUNT(*) as COUNT(1)
             let one = Arc::new(Literal::new(ScalarValue::UInt8(Some(1))));
@@ -1046,14 +1058,27 @@
             let count_expr = builder.build()?;
 
             let plan_schema = plan.schema();
-            let count_plan = Arc::new(AggregateExec::try_new(
+            Ok(Arc::new(AggregateExec::try_new(
                 AggregateMode::Single,
                 PhysicalGroupBy::new_single(Vec::new()),
                 vec![Arc::new(count_expr)],
                 vec![None],
                 plan,
                 plan_schema,
-            )?);
+            )?) as Arc<dyn ExecutionPlan>)
+        }
+        .boxed()
+    }
+
+    /// Scan and return the number of matching rows
+    ///
+    /// Note: calling [`Dataset::count_rows`] can be more efficient than calling this method
+    /// especially if there is no filter.
+    #[instrument(skip_all)]
+    pub fn count_rows(&self) -> BoxFuture<Result<u64>> {
+        // Future intentionally boxed here to avoid large futures on the stack
+        async move {
+            let count_plan = self.create_count_plan().await?;
             let mut stream = execute_plan(count_plan, LanceExecutionOptions::default())?;
 
             // A count plan will always return a single batch with a single row.
@@ -1127,15 +1152,25 @@
         }
     }
 
-    fn calc_eager_columns(&self, filter_plan: &FilterPlan) -> Result<Arc<Schema>> {
-        let columns = filter_plan.refine_columns();
+    // If we are going to filter on `filter_plan`, then which columns are so small it is
+    // cheaper to read the entire column and filter in memory.
+    //
+    // Note: only add columns that we actually need to read
+    fn calc_eager_columns(
+        &self,
+        filter_plan: &FilterPlan,
+        desired_schema: &Schema,
+    ) -> Result<Arc<Schema>> {
+        let filter_columns = filter_plan.refine_columns();
         let early_schema = self
             .dataset
             .empty_projection()
-            // We need the filter columns
-            .union_columns(columns, OnMissing::Error)?
-            // And also any columns that are eager
-            .union_predicate(|f| self.is_early_field(f))
+            // Start with the desired schema
+            .union_schema(desired_schema)
+            // Subtract columns that are expensive
+            .subtract_predicate(|f| !self.is_early_field(f))
+            // Add back columns that we need for filtering
+            .union_columns(filter_columns, OnMissing::Error)?
             .into_schema_ref();
 
         if early_schema.fields.iter().any(|f| !f.is_default_storage()) {
@@ -1340,7 +1375,10 @@
             (Some(index_query), Some(_)) => {
                 // If there is a filter then just load the eager columns and
                 // "take" the other columns later.
-                let eager_schema = self.calc_eager_columns(&filter_plan)?;
+                let eager_schema = self.calc_eager_columns(
+                    &filter_plan,
+                    self.projection_plan.physical_schema.as_ref(),
+                )?;
                 self.scalar_indexed_scan(&eager_schema, index_query).await?
             }
             (None, Some(_)) if use_stats && self.batch_size.is_none() => {
@@ -1352,7 +1390,10 @@
                 let eager_schema = if filter_plan.has_refine() {
                     // If there is a filter then only load the filter columns in the
                     // initial scan. We will `take` the remaining columns later
-                    self.calc_eager_columns(&filter_plan)?
+                    self.calc_eager_columns(
+                        &filter_plan,
+                        self.projection_plan.physical_schema.as_ref(),
+                    )?
                 } else {
                     // If there is no filter we eagerly load everything
                     self.projection_plan.physical_schema.clone()
@@ -3913,14 +3954,11 @@
             .unwrap();
 
         let dataset = Dataset::open(test_uri).await.unwrap();
-        assert_eq!(32, dataset.scan().count_rows().await.unwrap());
+        assert_eq!(32, dataset.count_rows(None).await.unwrap());
         assert_eq!(
             16,
             dataset
-                .scan()
-                .filter("`Filter_me` > 15")
-                .unwrap()
-                .count_rows()
+                .count_rows(Some("`Filter_me` > 15".to_string()))
                 .await
                 .unwrap()
         );
@@ -3948,7 +3986,7 @@
             .unwrap();
 
         let dataset = Dataset::open(test_uri).await.unwrap();
-        assert_eq!(32, dataset.scan().count_rows().await.unwrap());
+        assert_eq!(dataset.count_rows(None).await.unwrap(), 32);
 
         let mut scanner = dataset.scan();
 
@@ -3996,7 +4034,7 @@
             .unwrap();
 
         let dataset = Dataset::open(test_uri).await.unwrap();
-        assert_eq!(32, dataset.scan().count_rows().await.unwrap());
+        assert_eq!(dataset.count_rows(None).await.unwrap(), 32);
 
         let mut scanner = dataset.scan();
 
@@ -4519,20 +4557,13 @@
         }
     }
 
-    /// Assert that the plan when formatted matches the expected string.
-    ///
-    /// Within expected, you can use `...` to match any number of characters.
-    async fn assert_plan_equals(
-        dataset: &Dataset,
-        plan: impl Fn(&mut Scanner) -> Result<&mut Scanner>,
+    async fn assert_plan_node_equals(
+        plan_node: Arc<dyn ExecutionPlan>,
         expected: &str,
     ) -> Result<()> {
-        let mut scan = dataset.scan();
-        plan(&mut scan)?;
-        let exec_plan = scan.create_plan().await?;
         let plan_desc = format!(
             "{}",
-            datafusion::physical_plan::displayable(exec_plan.as_ref()).indent(true)
+            datafusion::physical_plan::displayable(plan_node.as_ref()).indent(true)
         );
 
         let to_match = expected.split("...").collect::<Vec<_>>();
@@ -4559,6 +4590,71 @@
         Ok(())
     }
 
+    /// Assert that the plan when formatted matches the expected string.
+    ///
+    /// Within expected, you can use `...` to match any number of characters.
+    async fn assert_plan_equals(
+        dataset: &Dataset,
+        plan: impl Fn(&mut Scanner) -> Result<&mut Scanner>,
+        expected: &str,
+    ) -> Result<()> {
+        let mut scan = dataset.scan();
+        plan(&mut scan)?;
+        let exec_plan = scan.create_plan().await?;
+        assert_plan_node_equals(exec_plan, expected).await
+    }
+
+    #[tokio::test]
+    async fn test_count_plan() {
+        // A count rows operation should load the minimal amount of data
+        let dim = 256;
+        let fixture = TestVectorDataset::new_with_dimension(LanceFileVersion::Stable, true, dim)
+            .await
+            .unwrap();
+
+        // By default, all columns are returned, this is bad for a count_rows op
+        let err = fixture
+            .dataset
+            .scan()
+            .create_count_plan()
+            .await
+            .unwrap_err();
+        assert!(matches!(err, Error::InvalidInput { .. }));
+
+        let mut scan = fixture.dataset.scan();
+        scan.project(&Vec::<String>::default()).unwrap();
+
+        // with_row_id needs to be specified
+        let err = scan.create_count_plan().await.unwrap_err();
+        assert!(matches!(err, Error::InvalidInput { .. }));
+
+        scan.with_row_id();
+
+        let plan = scan.create_count_plan().await.unwrap();
+
+        assert_plan_node_equals(
+            plan,
+            "AggregateExec: mode=Single, gby=[], aggr=[count_rows]
+  LanceScan: uri=..., projection=[], row_id=true, row_addr=false, ordered=true",
+        )
+        .await
+        .unwrap();
+
+        scan.filter("s == ''").unwrap();
+
+        let plan = scan.create_count_plan().await.unwrap();
+
+        assert_plan_node_equals(
+            plan,
+            "AggregateExec: mode=Single, gby=[], aggr=[count_rows]
+  ProjectionExec: expr=[_rowid@1 as _rowid]
+    FilterExec: s@0 =
+      LanceScan: uri=..., projection=[s], row_id=true, row_addr=false, ordered=true",
+        )
+        .await
+        .unwrap();
+    }
+
     #[rstest]
     #[tokio::test]
     async fn test_late_materialization(
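The reworked `calc_eager_columns` starts from the requested projection, subtracts fields that are too expensive to materialize early, and adds back the filter columns, instead of unconditionally unioning in every "eager" field. One rough way to observe the resulting plan from Python (a sketch; the exact plan text varies by version, and the node names are taken from the `test_count_plan` expectations above):

import lance

ds = lance.dataset("/tmp/select_none_demo")  # hypothetical path
# A filtered scan that selects nothing should plan to read only the filter
# column, roughly:
#   ProjectionExec: expr=[_rowid@... as _rowid]
#     FilterExec: a@0 < 50
#       LanceScan: ..., projection=[a], row_id=true, ...
print(ds.scanner(columns=[], filter="a < 50", with_row_id=True).explain_plan(True))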

rust/lance/src/dataset/write/merge_insert.rs (+1 -4)

@@ -1816,10 +1816,7 @@ mod tests {
 
         // Check that the data is as expected
         let updated = ds
-            .scan()
-            .filter("value = 9999999")
-            .unwrap()
-            .count_rows()
+            .count_rows(Some("value = 9999999".to_string()))
             .await
             .unwrap();
         assert_eq!(updated, 2048);
