@@ -18,8 +18,24 @@ use std::sync::Arc;
18
18
19
19
use arrow:: compute:: cast;
20
20
use arrow_array:: { cast:: AsArray , ArrayRef } ;
21
- use arrow_schema:: DataType ;
22
- use datafusion_common:: ScalarValue ;
21
+ use arrow_schema:: { DataType , Schema } ;
22
+ use datafusion:: {
23
+ datasource:: empty:: EmptyTable , execution:: context:: SessionContext , logical_expr:: Expr ,
24
+ } ;
25
+ use datafusion_common:: {
26
+ tree_node:: { Transformed , TreeNode } ,
27
+ Column , DataFusionError , ScalarValue , TableReference ,
28
+ } ;
29
+ use prost:: Message ;
30
+ use snafu:: { location, Location } ;
31
+
32
+ use lance_core:: { Error , Result } ;
33
+ use substrait:: proto:: {
34
+ expression_reference:: ExprType ,
35
+ plan_rel:: RelType ,
36
+ read_rel:: { NamedTable , ReadType } ,
37
+ rel, ExtendedExpression , Plan , PlanRel , ProjectRel , ReadRel , Rel , RelRoot ,
38
+ } ;
23
39
24
40
// This is slightly tedious but when we convert expressions from SQL strings to logical
25
41
// datafusion expressions there is no type coercion that happens. In other words "x = 7"
@@ -284,3 +300,175 @@ pub fn safe_coerce_scalar(value: &ScalarValue, ty: &DataType) -> Option<ScalarVa
284
300
_ => None ,
285
301
}
286
302
}
303
+
304
+ /// Convert a Substrait ExtendedExpressions message into a DF Expr
305
+ ///
306
+ /// The ExtendedExpressions message must contain a single scalar expression
307
+ pub async fn parse_substrait ( expr : & [ u8 ] , input_schema : Arc < Schema > ) -> Result < Expr > {
308
+ let envelope = ExtendedExpression :: decode ( expr) ?;
309
+ if envelope. referred_expr . is_empty ( ) {
310
+ return Err ( Error :: InvalidInput {
311
+ source : "the provided substrait expression is empty (contains no expressions)" . into ( ) ,
312
+ location : location ! ( ) ,
313
+ } ) ;
314
+ }
315
+ if envelope. referred_expr . len ( ) > 1 {
316
+ return Err ( Error :: InvalidInput {
317
+ source : format ! (
318
+ "the provided substrait expression had {} expressions when only 1 was expected" ,
319
+ envelope. referred_expr. len( )
320
+ )
321
+ . into ( ) ,
322
+ location : location ! ( ) ,
323
+ } ) ;
324
+ }
325
+ let expr = match & envelope. referred_expr [ 0 ] . expr_type {
326
+ None => Err ( Error :: InvalidInput {
327
+ source : "the provided substrait had an expression but was missing an expr_type" . into ( ) ,
328
+ location : location ! ( ) ,
329
+ } ) ,
330
+ Some ( ExprType :: Expression ( expr) ) => Ok ( expr. clone ( ) ) ,
331
+ _ => Err ( Error :: InvalidInput {
332
+ source : "the provided substrait was not a scalar expression" . into ( ) ,
333
+ location : location ! ( ) ,
334
+ } ) ,
335
+ } ?;
336
+
337
+ // Datafusion's substrait consumer only supports Plan (not ExtendedExpression) and so
338
+ // we need to create a dummy plan with a single project node
339
+ let plan = Plan {
340
+ version : None ,
341
+ extensions : envelope. extensions . clone ( ) ,
342
+ advanced_extensions : envelope. advanced_extensions . clone ( ) ,
343
+ expected_type_urls : envelope. expected_type_urls . clone ( ) ,
344
+ extension_uris : envelope. extension_uris . clone ( ) ,
345
+ relations : vec ! [ PlanRel {
346
+ rel_type: Some ( RelType :: Root ( RelRoot {
347
+ input: Some ( Rel {
348
+ rel_type: Some ( rel:: RelType :: Project ( Box :: new( ProjectRel {
349
+ common: None ,
350
+ input: Some ( Box :: new( Rel {
351
+ rel_type: Some ( rel:: RelType :: Read ( Box :: new( ReadRel {
352
+ common: None ,
353
+ base_schema: envelope. base_schema. clone( ) ,
354
+ filter: None ,
355
+ best_effort_filter: None ,
356
+ projection: None ,
357
+ advanced_extension: None ,
358
+ read_type: Some ( ReadType :: NamedTable ( NamedTable {
359
+ names: vec![ "dummy" . to_string( ) ] ,
360
+ advanced_extension: None ,
361
+ } ) ) ,
362
+ } ) ) ) ,
363
+ } ) ) ,
364
+ expressions: vec![ expr] ,
365
+ advanced_extension: None ,
366
+ } ) ) ) ,
367
+ } ) ,
368
+ // Not technically accurate but pretty sure DF ignores this
369
+ names: vec![ ] ,
370
+ } ) ) ,
371
+ } ] ,
372
+ } ;
373
+
374
+ let session_context = SessionContext :: new ( ) ;
375
+ let dummy_table = Arc :: new ( EmptyTable :: new ( input_schema) ) ;
376
+ session_context. register_table (
377
+ TableReference :: Bare {
378
+ table : "dummy" . into ( ) ,
379
+ } ,
380
+ dummy_table,
381
+ ) ?;
382
+ let df_plan =
383
+ datafusion_substrait:: logical_plan:: consumer:: from_substrait_plan ( & session_context, & plan)
384
+ . await ?;
385
+
386
+ let expr = df_plan. expressions ( ) . pop ( ) . unwrap ( ) ;
387
+
388
+ // When DF parses the above plan it turns column references into qualified references
389
+ // into `dummy` (e.g. we get `WHERE dummy.x < 0` instead of `WHERE x < 0`) We want
390
+ // these to be unqualified references instead and so we need a quick trasnformation pass
391
+
392
+ let expr = expr. transform ( & |node| match node {
393
+ Expr :: Column ( column) => {
394
+ if let Some ( relation) = column. relation {
395
+ match relation {
396
+ TableReference :: Bare { table } => {
397
+ if table == "dummy" {
398
+ Ok ( Transformed :: Yes ( Expr :: Column ( Column {
399
+ relation : None ,
400
+ name : column. name ,
401
+ } ) ) )
402
+ } else {
403
+ // This should not be possible
404
+ Err ( DataFusionError :: Substrait ( format ! (
405
+ "Unexpected reference to table {} found when parsing filter" ,
406
+ table
407
+ ) ) )
408
+ }
409
+ }
410
+ // This should not be possible
411
+ _ => Err ( DataFusionError :: Substrait ( "Unexpected partially or fully qualified table reference encountered when parsing filter" . into ( ) ) )
412
+ }
413
+ } else {
414
+ Ok ( Transformed :: No ( Expr :: Column ( column) ) )
415
+ }
416
+ }
417
+ _ => Ok ( Transformed :: No ( node) ) ,
418
+ } ) ?;
419
+ Ok ( expr)
420
+ }
421
+
422
+ #[ cfg( test) ]
423
+ mod tests {
424
+ use super :: * ;
425
+
426
+ use arrow_schema:: Field ;
427
+ use datafusion:: logical_expr:: { BinaryExpr , Operator } ;
428
+ use datafusion_common:: Column ;
429
+ use prost:: Message ;
430
+ use substrait_expr:: {
431
+ builder:: { schema:: SchemaBuildersExt , BuilderParams , ExpressionsBuilder } ,
432
+ functions:: functions_comparison:: FunctionsComparisonExt ,
433
+ helpers:: { literals:: literal, schema:: SchemaInfo } ,
434
+ } ;
435
+
436
+ #[ tokio:: test]
437
+ async fn test_substrait_conversion ( ) {
438
+ let schema = SchemaInfo :: new_full ( )
439
+ . field ( "x" , substrait_expr:: helpers:: types:: i32 ( true ) )
440
+ . build ( ) ;
441
+ let expr_builder = ExpressionsBuilder :: new ( schema, BuilderParams :: default ( ) ) ;
442
+ expr_builder
443
+ . add_expression (
444
+ "filter_mask" ,
445
+ expr_builder
446
+ . functions ( )
447
+ . lt (
448
+ expr_builder. fields ( ) . resolve_by_name ( "x" ) . unwrap ( ) ,
449
+ literal ( 0_i32 ) ,
450
+ )
451
+ . build ( )
452
+ . unwrap ( ) ,
453
+ )
454
+ . unwrap ( ) ;
455
+ let expr = expr_builder. build ( ) ;
456
+ let expr_bytes = expr. encode_to_vec ( ) ;
457
+
458
+ let schema = Arc :: new ( Schema :: new ( vec ! [ Field :: new( "x" , DataType :: Int32 , true ) ] ) ) ;
459
+
460
+ let df_expr = parse_substrait ( expr_bytes. as_slice ( ) , schema)
461
+ . await
462
+ . unwrap ( ) ;
463
+
464
+ let expected = Expr :: BinaryExpr ( BinaryExpr {
465
+ left : Box :: new ( Expr :: Column ( Column {
466
+ relation : None ,
467
+ name : "x" . to_string ( ) ,
468
+ } ) ) ,
469
+ op : Operator :: Lt ,
470
+ right : Box :: new ( Expr :: Literal ( ScalarValue :: Int32 ( Some ( 0 ) ) ) ) ,
471
+ } ) ;
472
+ assert_eq ! ( df_expr, expected) ;
473
+ }
474
+ }
0 commit comments