Skip to content

Commit e34fc4d

Browse files
authored
feat: add support for accepting Substrait ExtendedExpression messages as filters (#1863)
1 parent a150a4b commit e34fc4d

File tree

8 files changed

+293
-5
lines changed

8 files changed

+293
-5
lines changed

python/python/lance/dataset.py

+34-2
Original file line numberDiff line numberDiff line change
@@ -1547,6 +1547,7 @@ def __init__(self, ds: LanceDataset):
15471547
self.ds = ds
15481548
self._limit = 0
15491549
self._filter = None
1550+
self._substrait_filter = None
15501551
self._prefilter = None
15511552
self._offset = None
15521553
self._columns = None
@@ -1607,8 +1608,38 @@ def columns(self, cols: Optional[list[str]] = None) -> ScannerBuilder:
16071608

16081609
def filter(self, filter: Union[str, pa.compute.Expression]) -> ScannerBuilder:
16091610
if isinstance(filter, pa.compute.Expression):
1610-
filter = str(filter)
1611-
self._filter = filter
1611+
try:
1612+
from pyarrow.substrait import serialize_expressions
1613+
1614+
fields_without_lists = []
1615+
counter = 0
1616+
# Pyarrow cannot handle fixed size lists when converting
1617+
# types to Substrait. So we can't use those in our filter,
1618+
# which is ok for now but we need to replace them with some
1619+
# kind of placeholder because Substrait is going to use
1620+
# ordinal field references and we want to make sure those are
1621+
# correct.
1622+
for field in self.ds.schema:
1623+
if pa.types.is_fixed_size_list(field.type):
1624+
pos = counter
1625+
counter += 1
1626+
fields_without_lists.append(
1627+
pa.field(f"__unlikely_name_placeholder_{pos}", pa.int8())
1628+
)
1629+
else:
1630+
fields_without_lists.append(field)
1631+
# Serialize the pyarrow compute expression to Substrait and use
1632+
# that as a filter.
1633+
scalar_schema = pa.schema(fields_without_lists)
1634+
self._substrait_filter = serialize_expressions(
1635+
[filter], ["my_filter"], scalar_schema
1636+
)
1637+
except ImportError:
1638+
# serialize_expressions was introduced in pyarrow 14. Fallback to
1639+
# stringifying the expression if pyarrow is too old
1640+
self._filter = str(filter)
1641+
else:
1642+
self._filter = filter
16121643
return self
16131644

16141645
def prefilter(self, prefilter: bool) -> ScannerBuilder:
@@ -1709,6 +1740,7 @@ def to_scanner(self) -> LanceScanner:
17091740
self._fragments,
17101741
self._with_row_id,
17111742
self._use_stats,
1743+
self._substrait_filter,
17121744
)
17131745
return LanceScanner(scanner, self.ds)
17141746

python/python/tests/test_dataset.py

+30-1
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,10 @@ def test_pickle(tmp_path: Path):
404404

405405

406406
def test_polar_scan(tmp_path: Path):
407-
table = pa.Table.from_pydict({"a": range(100), "b": range(100)})
407+
some_structs = [{"x": counter, "y": counter} for counter in range(100)]
408+
table = pa.Table.from_pydict(
409+
{"a": range(100), "b": range(100), "struct": some_structs}
410+
)
408411
base_dir = tmp_path / "test"
409412
lance.write_dataset(table, base_dir)
410413

@@ -413,6 +416,32 @@ def test_polar_scan(tmp_path: Path):
413416
df = dataset.to_table().to_pandas()
414417
tm.assert_frame_equal(polars_df.collect().to_pandas(), df)
415418

419+
# Note, this doesn't verify that the filter is actually pushed down.
420+
# It only checks that, if the filter is pushed down, we interpret it
421+
# correctly.
422+
def check_pushdown_filt(pl_filt, sql_filt):
423+
polars_df = pl.scan_pyarrow_dataset(dataset).filter(pl_filt)
424+
df = dataset.to_table(filter=sql_filt).to_pandas()
425+
tm.assert_frame_equal(polars_df.collect().to_pandas(), df)
426+
427+
# These three should push down (but we don't verify)
428+
check_pushdown_filt(pl.col("a") > 50, "a > 50")
429+
check_pushdown_filt(~(pl.col("a") > 50), "a <= 50")
430+
check_pushdown_filt(pl.col("a").is_in([50, 51, 52]), "a IN (50, 51, 52)")
431+
# At the current moment it seems polars cannot pushdown this
432+
# kind of filter
433+
check_pushdown_filt((pl.col("a") + 3) < 100, "(a + 3) < 100")
434+
435+
# I can't seem to get struct["x"] to work in Lance but maybe there is
436+
# a way. For now, let's compare it directly to the pyarrow compute version
437+
438+
# Doesn't yet work today :( due to upstream issue (datafusion's substrait parser
439+
# doesn't yet handle nested refs)
440+
# if pa.cpp_version_info.major >= 14:
441+
# polars_df = pl.scan_pyarrow_dataset(dataset).filter(pl.col("struct.x") < 10)
442+
# df = dataset.to_table(filter=pc.field("struct", "x") < 10).to_pandas()
443+
# tm.assert_frame_equal(polars_df.collect().to_pandas(), df)
444+
416445

417446
def test_count_fragments(tmp_path: Path):
418447
table = pa.Table.from_pydict({"a": range(100), "b": range(100)})

python/python/tests/test_filter.py

+5
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,11 @@ def test_simple_predicates(dataset):
7474
pc.field("float") >= 30.0,
7575
pc.field("str") != "aa",
7676
pc.field("str") == "aa",
77+
(pc.field("int") >= 50) & (pc.field("int") < 200),
78+
pc.invert(pc.field("int") >= 50),
79+
pc.is_null(pc.field("int")),
80+
pc.field("int") + 3 >= 50,
81+
pc.is_valid(pc.field("int")),
7782
]
7883
# test simple
7984
for expr in predicates:

python/src/dataset.rs

+11
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,7 @@ impl Dataset {
289289
fragments: Option<Vec<FileFragment>>,
290290
with_row_id: Option<bool>,
291291
use_stats: Option<bool>,
292+
substrait_filter: Option<Vec<u8>>,
292293
) -> PyResult<Scanner> {
293294
let mut scanner: LanceScanner = self_.ds.scan();
294295
if let Some(c) = columns {
@@ -297,10 +298,20 @@ impl Dataset {
297298
.map_err(|err| PyValueError::new_err(err.to_string()))?;
298299
}
299300
if let Some(f) = filter {
301+
if substrait_filter.is_some() {
302+
return Err(PyValueError::new_err(
303+
"cannot specify both a string filter and a substrait filter",
304+
));
305+
}
300306
scanner
301307
.filter(f.as_str())
302308
.map_err(|err| PyValueError::new_err(err.to_string()))?;
303309
}
310+
if let Some(f) = substrait_filter {
311+
RT.runtime
312+
.block_on(scanner.filter_substrait(f.as_slice()))
313+
.map_err(|err| PyIOError::new_err(err.to_string()))?;
314+
}
304315
if let Some(prefilter) = prefilter {
305316
scanner.prefilter(prefilter);
306317
}

rust/Cargo.toml

+3
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ datafusion-common = "35.0"
8484
datafusion-sql = "35.0"
8585
datafusion-expr = "35.0"
8686
datafusion-physical-expr = "35.0"
87+
datafusion-substrait = "35.0"
8788
either = "1.0"
8889
futures = "0.3"
8990
http = "0.2.9"
@@ -109,6 +110,8 @@ serde = { version = "^1" }
109110
serde_json = { version = "1" }
110111
shellexpand = "3.0"
111112
snafu = "0.7.4"
113+
substrait = "0.22.1"
114+
substrait-expr = "0.2.0"
112115
tempfile = "3"
113116
tokio = { version = "1.23", features = [
114117
"rt-multi-thread",

rust/lance-datafusion/Cargo.toml

+7
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,14 @@ async-trait.workspace = true
1818
datafusion.workspace = true
1919
datafusion-common.workspace = true
2020
datafusion-physical-expr.workspace = true
21+
datafusion-substrait.workspace = true
2122
futures.workspace = true
2223
lance-arrow.workspace = true
2324
lance-core = { workspace = true, features = ["datafusion"] }
25+
prost.workspace = true
26+
snafu.workspace = true
27+
substrait.workspace = true
2428
tokio.workspace = true
29+
30+
[dev-dependencies]
31+
substrait-expr.workspace = true

rust/lance-datafusion/src/expr.rs

+190-2
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,24 @@ use std::sync::Arc;
1818

1919
use arrow::compute::cast;
2020
use arrow_array::{cast::AsArray, ArrayRef};
21-
use arrow_schema::DataType;
22-
use datafusion_common::ScalarValue;
21+
use arrow_schema::{DataType, Schema};
22+
use datafusion::{
23+
datasource::empty::EmptyTable, execution::context::SessionContext, logical_expr::Expr,
24+
};
25+
use datafusion_common::{
26+
tree_node::{Transformed, TreeNode},
27+
Column, DataFusionError, ScalarValue, TableReference,
28+
};
29+
use prost::Message;
30+
use snafu::{location, Location};
31+
32+
use lance_core::{Error, Result};
33+
use substrait::proto::{
34+
expression_reference::ExprType,
35+
plan_rel::RelType,
36+
read_rel::{NamedTable, ReadType},
37+
rel, ExtendedExpression, Plan, PlanRel, ProjectRel, ReadRel, Rel, RelRoot,
38+
};
2339

2440
// This is slightly tedious but when we convert expressions from SQL strings to logical
2541
// datafusion expressions there is no type coercion that happens. In other words "x = 7"
@@ -284,3 +300,175 @@ pub fn safe_coerce_scalar(value: &ScalarValue, ty: &DataType) -> Option<ScalarVa
284300
_ => None,
285301
}
286302
}
303+
304+
/// Convert a Substrait ExtendedExpressions message into a DF Expr
305+
///
306+
/// The ExtendedExpressions message must contain a single scalar expression
307+
pub async fn parse_substrait(expr: &[u8], input_schema: Arc<Schema>) -> Result<Expr> {
308+
let envelope = ExtendedExpression::decode(expr)?;
309+
if envelope.referred_expr.is_empty() {
310+
return Err(Error::InvalidInput {
311+
source: "the provided substrait expression is empty (contains no expressions)".into(),
312+
location: location!(),
313+
});
314+
}
315+
if envelope.referred_expr.len() > 1 {
316+
return Err(Error::InvalidInput {
317+
source: format!(
318+
"the provided substrait expression had {} expressions when only 1 was expected",
319+
envelope.referred_expr.len()
320+
)
321+
.into(),
322+
location: location!(),
323+
});
324+
}
325+
let expr = match &envelope.referred_expr[0].expr_type {
326+
None => Err(Error::InvalidInput {
327+
source: "the provided substrait had an expression but was missing an expr_type".into(),
328+
location: location!(),
329+
}),
330+
Some(ExprType::Expression(expr)) => Ok(expr.clone()),
331+
_ => Err(Error::InvalidInput {
332+
source: "the provided substrait was not a scalar expression".into(),
333+
location: location!(),
334+
}),
335+
}?;
336+
337+
// Datafusion's substrait consumer only supports Plan (not ExtendedExpression) and so
338+
// we need to create a dummy plan with a single project node
339+
let plan = Plan {
340+
version: None,
341+
extensions: envelope.extensions.clone(),
342+
advanced_extensions: envelope.advanced_extensions.clone(),
343+
expected_type_urls: envelope.expected_type_urls.clone(),
344+
extension_uris: envelope.extension_uris.clone(),
345+
relations: vec![PlanRel {
346+
rel_type: Some(RelType::Root(RelRoot {
347+
input: Some(Rel {
348+
rel_type: Some(rel::RelType::Project(Box::new(ProjectRel {
349+
common: None,
350+
input: Some(Box::new(Rel {
351+
rel_type: Some(rel::RelType::Read(Box::new(ReadRel {
352+
common: None,
353+
base_schema: envelope.base_schema.clone(),
354+
filter: None,
355+
best_effort_filter: None,
356+
projection: None,
357+
advanced_extension: None,
358+
read_type: Some(ReadType::NamedTable(NamedTable {
359+
names: vec!["dummy".to_string()],
360+
advanced_extension: None,
361+
})),
362+
}))),
363+
})),
364+
expressions: vec![expr],
365+
advanced_extension: None,
366+
}))),
367+
}),
368+
// Not technically accurate but pretty sure DF ignores this
369+
names: vec![],
370+
})),
371+
}],
372+
};
373+
374+
let session_context = SessionContext::new();
375+
let dummy_table = Arc::new(EmptyTable::new(input_schema));
376+
session_context.register_table(
377+
TableReference::Bare {
378+
table: "dummy".into(),
379+
},
380+
dummy_table,
381+
)?;
382+
let df_plan =
383+
datafusion_substrait::logical_plan::consumer::from_substrait_plan(&session_context, &plan)
384+
.await?;
385+
386+
let expr = df_plan.expressions().pop().unwrap();
387+
388+
// When DF parses the above plan it turns column references into qualified references
389+
// into `dummy` (e.g. we get `WHERE dummy.x < 0` instead of `WHERE x < 0`). We want
390+
// these to be unqualified references instead and so we need a quick transformation pass
391+
392+
let expr = expr.transform(&|node| match node {
393+
Expr::Column(column) => {
394+
if let Some(relation) = column.relation {
395+
match relation {
396+
TableReference::Bare { table } => {
397+
if table == "dummy" {
398+
Ok(Transformed::Yes(Expr::Column(Column {
399+
relation: None,
400+
name: column.name,
401+
})))
402+
} else {
403+
// This should not be possible
404+
Err(DataFusionError::Substrait(format!(
405+
"Unexpected reference to table {} found when parsing filter",
406+
table
407+
)))
408+
}
409+
}
410+
// This should not be possible
411+
_ => Err(DataFusionError::Substrait("Unexpected partially or fully qualified table reference encountered when parsing filter".into()))
412+
}
413+
} else {
414+
Ok(Transformed::No(Expr::Column(column)))
415+
}
416+
}
417+
_ => Ok(Transformed::No(node)),
418+
})?;
419+
Ok(expr)
420+
}
421+
422+
#[cfg(test)]
423+
mod tests {
424+
use super::*;
425+
426+
use arrow_schema::Field;
427+
use datafusion::logical_expr::{BinaryExpr, Operator};
428+
use datafusion_common::Column;
429+
use prost::Message;
430+
use substrait_expr::{
431+
builder::{schema::SchemaBuildersExt, BuilderParams, ExpressionsBuilder},
432+
functions::functions_comparison::FunctionsComparisonExt,
433+
helpers::{literals::literal, schema::SchemaInfo},
434+
};
435+
436+
#[tokio::test]
437+
async fn test_substrait_conversion() {
438+
let schema = SchemaInfo::new_full()
439+
.field("x", substrait_expr::helpers::types::i32(true))
440+
.build();
441+
let expr_builder = ExpressionsBuilder::new(schema, BuilderParams::default());
442+
expr_builder
443+
.add_expression(
444+
"filter_mask",
445+
expr_builder
446+
.functions()
447+
.lt(
448+
expr_builder.fields().resolve_by_name("x").unwrap(),
449+
literal(0_i32),
450+
)
451+
.build()
452+
.unwrap(),
453+
)
454+
.unwrap();
455+
let expr = expr_builder.build();
456+
let expr_bytes = expr.encode_to_vec();
457+
458+
let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int32, true)]));
459+
460+
let df_expr = parse_substrait(expr_bytes.as_slice(), schema)
461+
.await
462+
.unwrap();
463+
464+
let expected = Expr::BinaryExpr(BinaryExpr {
465+
left: Box::new(Expr::Column(Column {
466+
relation: None,
467+
name: "x".to_string(),
468+
})),
469+
op: Operator::Lt,
470+
right: Box::new(Expr::Literal(ScalarValue::Int32(Some(0)))),
471+
});
472+
assert_eq!(df_expr, expected);
473+
}
474+
}

0 commit comments

Comments
 (0)