@@ -189,6 +189,7 @@ impl MaterializationStyle {
189
189
}
190
190
191
191
/// A filter used to select rows
192
+ #[ derive( Debug ) ]
192
193
pub enum LanceFilter {
193
194
/// The filter is an SQL string
194
195
Sql ( String ) ,
@@ -1027,11 +1028,22 @@ impl Scanner {
1027
1028
Ok ( concat_batches ( & schema, & batches) ?)
1028
1029
}
1029
1030
1030
- /// Scan and return the number of matching rows
1031
- #[ instrument( skip_all) ]
1032
- pub fn count_rows ( & self ) -> BoxFuture < Result < u64 > > {
1031
+ fn create_count_plan ( & self ) -> BoxFuture < Result < Arc < dyn ExecutionPlan > > > {
1033
1032
// Future intentionally boxed here to avoid large futures on the stack
1034
1033
async move {
1034
+ if !self . projection_plan . physical_schema . fields . is_empty ( ) {
1035
+ return Err ( Error :: invalid_input (
1036
+ "count_rows should not be called on a plan selecting columns" . to_string ( ) ,
1037
+ location ! ( ) ,
1038
+ ) ) ;
1039
+ }
1040
+
1041
+ if self . limit . is_some ( ) || self . offset . is_some ( ) {
1042
+ log:: warn!(
1043
+ "count_rows called with limit or offset which could have surprising results"
1044
+ ) ;
1045
+ }
1046
+
1035
1047
let plan = self . create_plan ( ) . await ?;
1036
1048
// Datafusion interprets COUNT(*) as COUNT(1)
1037
1049
let one = Arc :: new ( Literal :: new ( ScalarValue :: UInt8 ( Some ( 1 ) ) ) ) ;
@@ -1046,14 +1058,27 @@ impl Scanner {
1046
1058
let count_expr = builder. build ( ) ?;
1047
1059
1048
1060
let plan_schema = plan. schema ( ) ;
1049
- let count_plan = Arc :: new ( AggregateExec :: try_new (
1061
+ Ok ( Arc :: new ( AggregateExec :: try_new (
1050
1062
AggregateMode :: Single ,
1051
1063
PhysicalGroupBy :: new_single ( Vec :: new ( ) ) ,
1052
1064
vec ! [ Arc :: new( count_expr) ] ,
1053
1065
vec ! [ None ] ,
1054
1066
plan,
1055
1067
plan_schema,
1056
- ) ?) ;
1068
+ ) ?) as Arc < dyn ExecutionPlan > )
1069
+ }
1070
+ . boxed ( )
1071
+ }
1072
+
1073
+ /// Scan and return the number of matching rows
1074
+ ///
1075
+ /// Note: calling [`Dataset::count_rows`] can be more efficient than calling this method
1076
+ /// especially if there is no filter.
1077
+ #[ instrument( skip_all) ]
1078
+ pub fn count_rows ( & self ) -> BoxFuture < Result < u64 > > {
1079
+ // Future intentionally boxed here to avoid large futures on the stack
1080
+ async move {
1081
+ let count_plan = self . create_count_plan ( ) . await ?;
1057
1082
let mut stream = execute_plan ( count_plan, LanceExecutionOptions :: default ( ) ) ?;
1058
1083
1059
1084
// A count plan will always return a single batch with a single row.
@@ -1127,15 +1152,25 @@ impl Scanner {
1127
1152
}
1128
1153
}
1129
1154
1130
- fn calc_eager_columns ( & self , filter_plan : & FilterPlan ) -> Result < Arc < Schema > > {
1131
- let columns = filter_plan. refine_columns ( ) ;
1155
+ // Given that we are going to filter on `filter_plan`, determine which columns are so small it is
1156
+ // cheaper to read the entire column and filter in memory.
1157
+ //
1158
+ // Note: only add columns that we actually need to read
1159
+ fn calc_eager_columns (
1160
+ & self ,
1161
+ filter_plan : & FilterPlan ,
1162
+ desired_schema : & Schema ,
1163
+ ) -> Result < Arc < Schema > > {
1164
+ let filter_columns = filter_plan. refine_columns ( ) ;
1132
1165
let early_schema = self
1133
1166
. dataset
1134
1167
. empty_projection ( )
1135
- // We need the filter columns
1136
- . union_columns ( columns, OnMissing :: Error ) ?
1137
- // And also any columns that are eager
1138
- . union_predicate ( |f| self . is_early_field ( f) )
1168
+ // Start with the desired schema
1169
+ . union_schema ( desired_schema)
1170
+ // Subtract columns that are expensive
1171
+ . subtract_predicate ( |f| !self . is_early_field ( f) )
1172
+ // Add back columns that we need for filtering
1173
+ . union_columns ( filter_columns, OnMissing :: Error ) ?
1139
1174
. into_schema_ref ( ) ;
1140
1175
1141
1176
if early_schema. fields . iter ( ) . any ( |f| !f. is_default_storage ( ) ) {
@@ -1340,7 +1375,10 @@ impl Scanner {
1340
1375
( Some ( index_query) , Some ( _) ) => {
1341
1376
// If there is a filter then just load the eager columns and
1342
1377
// "take" the other columns later.
1343
- let eager_schema = self . calc_eager_columns ( & filter_plan) ?;
1378
+ let eager_schema = self . calc_eager_columns (
1379
+ & filter_plan,
1380
+ self . projection_plan . physical_schema . as_ref ( ) ,
1381
+ ) ?;
1344
1382
self . scalar_indexed_scan ( & eager_schema, index_query) . await ?
1345
1383
}
1346
1384
( None , Some ( _) ) if use_stats && self . batch_size . is_none ( ) => {
@@ -1352,7 +1390,10 @@ impl Scanner {
1352
1390
let eager_schema = if filter_plan. has_refine ( ) {
1353
1391
// If there is a filter then only load the filter columns in the
1354
1392
// initial scan. We will `take` the remaining columns later
1355
- self . calc_eager_columns ( & filter_plan) ?
1393
+ self . calc_eager_columns (
1394
+ & filter_plan,
1395
+ self . projection_plan . physical_schema . as_ref ( ) ,
1396
+ ) ?
1356
1397
} else {
1357
1398
// If there is no filter we eagerly load everything
1358
1399
self . projection_plan . physical_schema . clone ( )
@@ -3913,14 +3954,11 @@ mod test {
3913
3954
. unwrap ( ) ;
3914
3955
3915
3956
let dataset = Dataset :: open ( test_uri) . await . unwrap ( ) ;
3916
- assert_eq ! ( 32 , dataset. scan ( ) . count_rows( ) . await . unwrap( ) ) ;
3957
+ assert_eq ! ( 32 , dataset. count_rows( None ) . await . unwrap( ) ) ;
3917
3958
assert_eq ! (
3918
3959
16 ,
3919
3960
dataset
3920
- . scan( )
3921
- . filter( "`Filter_me` > 15" )
3922
- . unwrap( )
3923
- . count_rows( )
3961
+ . count_rows( Some ( "`Filter_me` > 15" . to_string( ) ) )
3924
3962
. await
3925
3963
. unwrap( )
3926
3964
) ;
@@ -3948,7 +3986,7 @@ mod test {
3948
3986
. unwrap ( ) ;
3949
3987
3950
3988
let dataset = Dataset :: open ( test_uri) . await . unwrap ( ) ;
3951
- assert_eq ! ( 32 , dataset. scan ( ) . count_rows( ) . await . unwrap( ) ) ;
3989
+ assert_eq ! ( dataset. count_rows( None ) . await . unwrap( ) , 32 ) ;
3952
3990
3953
3991
let mut scanner = dataset. scan ( ) ;
3954
3992
@@ -3996,7 +4034,7 @@ mod test {
3996
4034
. unwrap ( ) ;
3997
4035
3998
4036
let dataset = Dataset :: open ( test_uri) . await . unwrap ( ) ;
3999
- assert_eq ! ( 32 , dataset. scan ( ) . count_rows( ) . await . unwrap( ) ) ;
4037
+ assert_eq ! ( dataset. count_rows( None ) . await . unwrap( ) , 32 ) ;
4000
4038
4001
4039
let mut scanner = dataset. scan ( ) ;
4002
4040
@@ -4519,20 +4557,13 @@ mod test {
4519
4557
}
4520
4558
}
4521
4559
4522
- /// Assert that the plan when formatted matches the expected string.
4523
- ///
4524
- /// Within expected, you can use `...` to match any number of characters.
4525
- async fn assert_plan_equals (
4526
- dataset : & Dataset ,
4527
- plan : impl Fn ( & mut Scanner ) -> Result < & mut Scanner > ,
4560
+ async fn assert_plan_node_equals (
4561
+ plan_node : Arc < dyn ExecutionPlan > ,
4528
4562
expected : & str ,
4529
4563
) -> Result < ( ) > {
4530
- let mut scan = dataset. scan ( ) ;
4531
- plan ( & mut scan) ?;
4532
- let exec_plan = scan. create_plan ( ) . await ?;
4533
4564
let plan_desc = format ! (
4534
4565
"{}" ,
4535
- datafusion:: physical_plan:: displayable( exec_plan . as_ref( ) ) . indent( true )
4566
+ datafusion:: physical_plan:: displayable( plan_node . as_ref( ) ) . indent( true )
4536
4567
) ;
4537
4568
4538
4569
let to_match = expected. split ( "..." ) . collect :: < Vec < _ > > ( ) ;
@@ -4559,6 +4590,71 @@ mod test {
4559
4590
Ok ( ( ) )
4560
4591
}
4561
4592
4593
+ /// Assert that the physical plan of a scan over `dataset`, after `plan` has configured the scanner, matches `expected` when formatted.
4594
+ ///
4595
+ /// Within `expected`, you can use `...` to match any number of characters. The comparison itself is delegated to `assert_plan_node_equals`.
4596
+ async fn assert_plan_equals (
4597
+ dataset : & Dataset ,
4598
+ plan : impl Fn ( & mut Scanner ) -> Result < & mut Scanner > ,
4599
+ expected : & str ,
4600
+ ) -> Result < ( ) > {
4601
+ let mut scan = dataset. scan ( ) ;
4602
+ plan ( & mut scan) ?;
4603
+ let exec_plan = scan. create_plan ( ) . await ?;
4604
+ assert_plan_node_equals ( exec_plan, expected) . await
4605
+ }
4606
+
4607
+ #[ tokio:: test]
4608
+ async fn test_count_plan ( ) {
4609
+ // A count_rows operation should load the minimal amount of data. This test checks which scanner configurations are rejected and what the accepted count plans look like.
4610
+ let dim = 256 ;
4611
+ let fixture = TestVectorDataset :: new_with_dimension ( LanceFileVersion :: Stable , true , dim)
4612
+ . await
4613
+ . unwrap ( ) ;
4614
+
4615
+ // By default a scan returns all columns; that is bad for a count_rows op, so an InvalidInput error is expected
4616
+ let err = fixture
4617
+ . dataset
4618
+ . scan ( )
4619
+ . create_count_plan ( )
4620
+ . await
4621
+ . unwrap_err ( ) ;
4622
+ assert ! ( matches!( err, Error :: InvalidInput { .. } ) ) ;
4623
+
4624
+ let mut scan = fixture. dataset . scan ( ) ;
4625
+ scan. project ( & Vec :: < String > :: default ( ) ) . unwrap ( ) ;
4626
+
4627
+ // An empty projection alone is not enough: with_row_id needs to be specified, otherwise InvalidInput is still expected
4628
+ let err = scan. create_count_plan ( ) . await . unwrap_err ( ) ;
4629
+ assert ! ( matches!( err, Error :: InvalidInput { .. } ) ) ;
4630
+
4631
+ scan. with_row_id ( ) ;
4632
+ // With an empty projection plus row ids, the expected plan scans no columns (projection=[])
4633
+ let plan = scan. create_count_plan ( ) . await . unwrap ( ) ;
4634
+
4635
+ assert_plan_node_equals (
4636
+ plan,
4637
+ "AggregateExec: mode=Single, gby=[], aggr=[count_rows]
4638
+ LanceScan: uri=..., projection=[], row_id=true, row_addr=false, ordered=true" ,
4639
+ )
4640
+ . await
4641
+ . unwrap ( ) ;
4642
+ // After adding a filter on `s`, the expected plan scans only that column (projection=[s])
4643
+ scan. filter ( "s == ''" ) . unwrap ( ) ;
4644
+
4645
+ let plan = scan. create_count_plan ( ) . await . unwrap ( ) ;
4646
+
4647
+ assert_plan_node_equals (
4648
+ plan,
4649
+ "AggregateExec: mode=Single, gby=[], aggr=[count_rows]
4650
+ ProjectionExec: expr=[_rowid@1 as _rowid]
4651
+ FilterExec: s@0 =
4652
+ LanceScan: uri=..., projection=[s], row_id=true, row_addr=false, ordered=true" ,
4653
+ )
4654
+ . await
4655
+ . unwrap ( ) ;
4656
+ }
4657
+
4562
4658
#[ rstest]
4563
4659
#[ tokio:: test]
4564
4660
async fn test_late_materialization (
0 commit comments