From 076208d3799b86193d2ba32ff63d7f0b5943e609 Mon Sep 17 00:00:00 2001 From: Yanxin Xiang Date: Sun, 9 Jun 2024 15:05:47 -0700 Subject: [PATCH 1/2] support tpch_1 consumer_producer_test --- datafusion/substrait/Cargo.toml | 1 + .../substrait/src/logical_plan/consumer.rs | 97 ++- datafusion/substrait/tests/cases/mod.rs | 1 + datafusion/substrait/tests/cases/tpch.rs | 63 ++ .../substrait/tests/testdata/query_1.json | 810 ++++++++++++++++++ .../tests/testdata/tpch/lineitem.csv | 2 + 6 files changed, 971 insertions(+), 3 deletions(-) create mode 100644 datafusion/substrait/tests/cases/tpch.rs create mode 100644 datafusion/substrait/tests/testdata/query_1.json create mode 100644 datafusion/substrait/tests/testdata/tpch/lineitem.csv diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml index 9322412c0ddb..59e4738074ae 100644 --- a/datafusion/substrait/Cargo.toml +++ b/datafusion/substrait/Cargo.toml @@ -41,6 +41,7 @@ object_store = { workspace = true } pbjson-types = "0.6" prost = "0.12" substrait = { version = "0.34.0", features = ["serde"] } +url = { workspace = true } [dev-dependencies] serde_json = "1.0" diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index d68711e8609c..caa9c9437083 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -22,6 +22,9 @@ use datafusion::arrow::datatypes::{ use datafusion::common::{ not_impl_err, substrait_datafusion_err, substrait_err, DFSchema, DFSchemaRef, }; +use substrait::proto::expression::literal::IntervalDayToSecond; +use substrait::proto::read_rel::local_files::file_or_files::PathType::UriFile; +use url::Url; use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; use datafusion::execution::FunctionRegistry; @@ -408,7 +411,6 @@ pub async fn from_substrait_rel( }; aggr_expr.push(agg_func?.as_ref().clone()); } - input.aggregate(group_expr, aggr_expr)?.build() } else { 
not_impl_err!("Aggregate without an input is not valid") @@ -569,7 +571,80 @@ pub async fn from_substrait_rel( Ok(LogicalPlan::Values(Values { schema, values })) } - _ => not_impl_err!("Only NamedTable and VirtualTable reads are supported"), + Some(ReadType::LocalFiles(lf)) => { + fn extract_filename(name: &str) -> Option { + let corrected_url = + if name.starts_with("file://") && !name.starts_with("file:///") { + name.replacen("file://", "file:///", 1) + } else { + name.to_string() + }; + + Url::parse(&corrected_url).ok().and_then(|url| { + let path = url.path(); + std::path::Path::new(path) + .file_name() + .map(|filename| filename.to_string_lossy().to_string()) + }) + } + + // we could use the file name to check the original table provider + // TODO: currently does not support multiple local files + let filename: Option = + lf.items.first().and_then(|x| match x.path_type.as_ref() { + Some(UriFile(name)) => extract_filename(name), + _ => None, + }); + + if lf.items.len() > 1 || filename.is_none() { + return not_impl_err!( + "Only NamedTable and VirtualTable reads are supported" + ); + } + let name = filename.unwrap(); + // directly use unwrap here since we could determine it is a valid one + let table_reference = TableReference::Bare { table: name.into() }; + let t = ctx.table(table_reference).await?; + let t = t.into_optimized_plan()?; + match &read.projection { + Some(MaskExpression { select, .. 
}) => match &select.as_ref() { + Some(projection) => { + let column_indices: Vec = projection + .struct_items + .iter() + .map(|item| item.field as usize) + .collect(); + match &t { + LogicalPlan::TableScan(scan) => { + let fields = column_indices + .iter() + .map(|i| { + scan.projected_schema.qualified_field(*i) + }) + .map(|(qualifier, field)| { + (qualifier.cloned(), Arc::new(field.clone())) + }) + .collect(); + let mut scan = scan.clone(); + scan.projection = Some(column_indices); + scan.projected_schema = + DFSchemaRef::new(DFSchema::new_with_metadata( + fields, + HashMap::new(), + )?); + Ok(LogicalPlan::TableScan(scan)) + } + _ => plan_err!("unexpected plan for table"), + } + } + _ => Ok(t), + }, + _ => Ok(t), + } + } + _ => { + not_impl_err!("Only NamedTable and VirtualTable reads are supported") + } }, Some(RelType::Set(set)) => match set_rel::SetOp::try_from(set.op) { Ok(set_op) => match set_op { @@ -810,7 +885,7 @@ pub async fn from_substrait_agg_func( f.function_reference ); }; - + let function_name = function_name.split(':').next().unwrap_or(function_name); // try udaf first, then built-in aggr fn. 
if let Ok(fun) = ctx.udaf(function_name) { Ok(Arc::new(Expr::AggregateFunction( @@ -818,6 +893,13 @@ pub async fn from_substrait_agg_func( ))) } else if let Ok(fun) = aggregate_function::AggregateFunction::from_str(function_name) { + match &fun { + // deal with situation that count(*) got no arguments + aggregate_function::AggregateFunction::Count if args.is_empty() => { + args.push(Expr::Literal(ScalarValue::Int64(Some(1)))); + } + _ => {} + } Ok(Arc::new(Expr::AggregateFunction( expr::AggregateFunction::new(fun, args, distinct, filter, order_by, None), ))) @@ -1253,6 +1335,8 @@ fn from_substrait_type( r#type::Kind::Struct(s) => Ok(DataType::Struct(from_substrait_struct_type( s, dfs_names, name_idx, )?)), + r#type::Kind::Varchar(_) => Ok(DataType::Utf8), + r#type::Kind::FixedChar(_) => Ok(DataType::Utf8), _ => not_impl_err!("Unsupported Substrait type: {s_kind:?}"), }, _ => not_impl_err!("`None` Substrait kind is not supported"), @@ -1541,6 +1625,13 @@ fn from_substrait_literal( Some(LiteralType::Null(ntype)) => { from_substrait_null(ntype, dfs_names, name_idx)? 
} + Some(LiteralType::IntervalDayToSecond(IntervalDayToSecond { + days, + seconds, + microseconds, + })) => { + ScalarValue::new_interval_dt(*days, (seconds * 1000) + (microseconds / 1000)) + } Some(LiteralType::UserDefined(user_defined)) => { match user_defined.type_reference { INTERVAL_YEAR_MONTH_TYPE_REF => { diff --git a/datafusion/substrait/tests/cases/mod.rs b/datafusion/substrait/tests/cases/mod.rs index d049eb2c2121..365fbbb89a35 100644 --- a/datafusion/substrait/tests/cases/mod.rs +++ b/datafusion/substrait/tests/cases/mod.rs @@ -19,3 +19,4 @@ mod logical_plans; mod roundtrip_logical_plan; mod roundtrip_physical_plan; mod serialize; +mod tpch; diff --git a/datafusion/substrait/tests/cases/tpch.rs b/datafusion/substrait/tests/cases/tpch.rs new file mode 100644 index 000000000000..d1c34ffe496b --- /dev/null +++ b/datafusion/substrait/tests/cases/tpch.rs @@ -0,0 +1,63 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
tests contains in + +#[cfg(test)] +mod tests { + use datafusion::common::Result; + use datafusion::execution::options::CsvReadOptions; + use datafusion::prelude::SessionContext; + use datafusion_substrait::logical_plan::consumer::from_substrait_plan; + use std::fs::File; + use std::io::BufReader; + use substrait::proto::Plan; + + #[tokio::test] + async fn tpch_test_1() -> Result<()> { + let ctx = create_context().await?; + let path = "tests/testdata/query_1.json"; + let proto = serde_json::from_reader::<_, Plan>(BufReader::new( + File::open(path).expect("file not found"), + )) + .expect("failed to parse json"); + + let plan = from_substrait_plan(&ctx, &proto).await?; + + assert!( + format!("{:?}", plan).eq_ignore_ascii_case( + "Sort: FILENAME_PLACEHOLDER_0.l_returnflag ASC NULLS LAST, FILENAME_PLACEHOLDER_0.l_linestatus ASC NULLS LAST\n \ + Aggregate: groupBy=[[FILENAME_PLACEHOLDER_0.l_returnflag, FILENAME_PLACEHOLDER_0.l_linestatus]], aggr=[[SUM(FILENAME_PLACEHOLDER_0.l_quantity), SUM(FILENAME_PLACEHOLDER_0.l_extendedprice), SUM(FILENAME_PLACEHOLDER_0.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_0.l_discount), SUM(FILENAME_PLACEHOLDER_0.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_0.l_discount * Int32(1) + FILENAME_PLACEHOLDER_0.l_tax), AVG(FILENAME_PLACEHOLDER_0.l_quantity), AVG(FILENAME_PLACEHOLDER_0.l_extendedprice), AVG(FILENAME_PLACEHOLDER_0.l_discount), COUNT(Int64(1))]]\n \ + Projection: FILENAME_PLACEHOLDER_0.l_returnflag, FILENAME_PLACEHOLDER_0.l_linestatus, FILENAME_PLACEHOLDER_0.l_quantity, FILENAME_PLACEHOLDER_0.l_extendedprice, FILENAME_PLACEHOLDER_0.l_extendedprice * (CAST(Int32(1) AS Decimal128(19, 0)) - FILENAME_PLACEHOLDER_0.l_discount), FILENAME_PLACEHOLDER_0.l_extendedprice * (CAST(Int32(1) AS Decimal128(19, 0)) - FILENAME_PLACEHOLDER_0.l_discount) * (CAST(Int32(1) AS Decimal128(19, 0)) + FILENAME_PLACEHOLDER_0.l_tax), FILENAME_PLACEHOLDER_0.l_discount\n \ + Filter: FILENAME_PLACEHOLDER_0.l_shipdate <= Date32(\"1998-12-01\") - 
IntervalDayTime(\"IntervalDayTime { days: 120, milliseconds: 0 }\")\n \ + TableScan: FILENAME_PLACEHOLDER_0 projection=[l_orderkey, l_partkey, l_suppkey, l_linenumber, l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate, l_commitdate, l_receiptdate, l_shipinstruct, l_shipmode, l_comment]" + ) + ); + Ok(()) + } + + async fn create_context() -> datafusion::common::Result { + let ctx = SessionContext::new(); + ctx.register_csv( + "FILENAME_PLACEHOLDER_0", + "tests/testdata/tpch/lineitem.csv", + CsvReadOptions::default(), + ) + .await?; + Ok(ctx) + } +} diff --git a/datafusion/substrait/tests/testdata/query_1.json b/datafusion/substrait/tests/testdata/query_1.json new file mode 100644 index 000000000000..7dbce9959e5e --- /dev/null +++ b/datafusion/substrait/tests/testdata/query_1.json @@ -0,0 +1,810 @@ +{ + "extensionUris": [{ + "extensionUriAnchor": 3, + "uri": "/functions_aggregate_generic.yaml" + }, { + "extensionUriAnchor": 2, + "uri": "/functions_arithmetic_decimal.yaml" + }, { + "extensionUriAnchor": 1, + "uri": "/functions_datetime.yaml" + }], + "extensions": [{ + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 0, + "name": "lte:date_date" + } + }, { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 1, + "name": "subtract:date_day" + } + }, { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 2, + "name": "multiply:opt_decimal_decimal" + } + }, { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 3, + "name": "subtract:opt_decimal_decimal" + } + }, { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 4, + "name": "add:opt_decimal_decimal" + } + }, { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 5, + "name": "sum:opt_decimal" + } + }, { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 6, + "name": "avg:opt_decimal" + } + }, { + "extensionFunction": { + 
"extensionUriReference": 3, + "functionAnchor": 7, + "name": "count:opt" + } + }], + "relations": [{ + "root": { + "input": { + "sort": { + "common": { + "direct": { + } + }, + "input": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "project": { + "common": { + + }, + "input": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["L_ORDERKEY", "L_PARTKEY", "L_SUPPKEY", "L_LINENUMBER", "L_QUANTITY", "L_EXTENDEDPRICE", "L_DISCOUNT", "L_TAX", "L_RETURNFLAG", "L_LINESTATUS", "L_SHIPDATE", "L_COMMITDATE", "L_RECEIPTDATE", "L_SHIPINSTRUCT", "L_SHIPMODE", "L_COMMENT"], + "struct": { + "types": [{ + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "fixedChar": { + "length": 1, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "fixedChar": { + "length": 1, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "date": { + "typeVariationReference": 0, + "nullability": 
"NULLABILITY_NULLABLE" + } + }, { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "fixedChar": { + "length": 10, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "varchar": { + "length": 44, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_0", + "parquet": {} + } + ] + } + } + }, + "condition": { + "scalarFunction": { + "functionReference": 0, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 10 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "literal": { + "date": 10561, + "nullable": false, + "typeVariationReference": 0 + } + } + }, { + "value": { + "literal": { + "intervalDayToSecond": { + "days": 120, + "seconds": 0, + "microseconds": 0 + }, + "nullable": false, + "typeVariationReference": 0 + } + } + }] + } + } + }] + } + } + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 8 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 9 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + 
"rootReference": { + } + } + }, { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 3, + "args": [], + "outputType": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "cast": { + "type": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "input": { + "literal": { + "i32": 1, + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 6 + } + }, + "rootReference": { + } + } + } + }] + } + } + }] + } + }, { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 3, + "args": [], + "outputType": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "cast": { + "type": { + 
"decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "input": { + "literal": { + "i32": 1, + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 6 + } + }, + "rootReference": { + } + } + } + }] + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 4, + "args": [], + "outputType": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "cast": { + "type": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "input": { + "literal": { + "i32": 1, + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 7 + } + }, + "rootReference": { + } + } + } + }] + } + } + }] + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 6 + } + }, + "rootReference": { + } + } + }] + } + }, + "groupings": [{ + "groupingExpressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }] + }], + "measures": [{ + "measure": { + "functionReference": 5, + "args": [], + "sorts": [], + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [{ + "value": { + "selection": { + "directReference": { + 
"structField": { + "field": 2 + } + }, + "rootReference": { + } + } + } + }] + } + }, { + "measure": { + "functionReference": 5, + "args": [], + "sorts": [], + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": { + } + } + } + }] + } + }, { + "measure": { + "functionReference": 5, + "args": [], + "sorts": [], + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }] + } + }, { + "measure": { + "functionReference": 5, + "args": [], + "sorts": [], + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": { + } + } + } + }] + } + }, { + "measure": { + "functionReference": 6, + "args": [], + "sorts": [], + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + } + }] + } + }, { + "measure": { + 
"functionReference": 6, + "args": [], + "sorts": [], + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": { + } + } + } + }] + } + }, { + "measure": { + "functionReference": 6, + "args": [], + "sorts": [], + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 6 + } + }, + "rootReference": { + } + } + } + }] + } + }, { + "measure": { + "functionReference": 7, + "args": [], + "sorts": [], + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [] + } + }] + } + }, + "sorts": [{ + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_ASC_NULLS_LAST" + }, { + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_ASC_NULLS_LAST" + }] + } + }, + "names": ["L_RETURNFLAG", "L_LINESTATUS", "SUM_QTY", "SUM_BASE_PRICE", "SUM_DISC_PRICE", "SUM_CHARGE", "AVG_QTY", "AVG_PRICE", "AVG_DISC", "COUNT_ORDER"] + } + }], + "expectedTypeUrls": [] + } + \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/tpch/lineitem.csv b/datafusion/substrait/tests/testdata/tpch/lineitem.csv new file mode 100644 index 
000000000000..192ba86d7ab7 --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch/lineitem.csv @@ -0,0 +1,2 @@ +l_orderkey,l_partkey,l_suppkey,l_linenumber,l_quantity,l_extendedprice,l_discount,l_tax,l_returnflag,l_linestatus,l_shipdate,l_commitdate,l_receiptdate,l_shipinstruct,l_shipmode,l_comment +1,1,1,1,17,21168.23,0.04,0.02,'N','O','1996-03-13','1996-02-12','1996-03-22','DELIVER IN PERSON','TRUCK','egular courts above the' \ No newline at end of file From 27d3f2b643acf6107a08173d0a150b66062b482c Mon Sep 17 00:00:00 2001 From: Lordworms Date: Mon, 10 Jun 2024 11:06:19 -0700 Subject: [PATCH 2/2] refactor and optimize code --- .../substrait/src/logical_plan/consumer.rs | 143 +++++++----------- .../{tpch.rs => consumer_integration.rs} | 9 +- datafusion/substrait/tests/cases/mod.rs | 2 +- .../testdata/tpch_substrait_plans/README.md | 22 +++ .../{ => tpch_substrait_plans}/query_1.json | 0 5 files changed, 87 insertions(+), 89 deletions(-) rename datafusion/substrait/tests/cases/{tpch.rs => consumer_integration.rs} (90%) create mode 100644 datafusion/substrait/tests/testdata/tpch_substrait_plans/README.md rename datafusion/substrait/tests/testdata/{ => tpch_substrait_plans}/query_1.json (100%) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index caa9c9437083..c0e99759ce41 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -48,7 +48,7 @@ use datafusion::{ use substrait::proto::exchange_rel::ExchangeKind; use substrait::proto::expression::literal::user_defined::Val; use substrait::proto::expression::subquery::SubqueryType; -use substrait::proto::expression::{FieldReference, Literal, ScalarFunction}; +use substrait::proto::expression::{self, FieldReference, Literal, ScalarFunction}; use substrait::proto::{ aggregate_function::AggregationInvocation, expression::{ @@ -132,14 +132,7 @@ fn scalar_function_type_from_str( name: &str, 
) -> Result { let s = ctx.state(); - let name = match name.rsplit_once(':') { - // Since 0.32.0, Substrait requires the function names to be in a compound format - // https://substrait.io/extensions/#function-signature-compound-names - // for example, `add:i8_i8`. - // On the consumer side, we don't really care about the signature though, just the name. - Some((name, _)) => name, - None => name, - }; + let name = substrait_fun_name(name); if let Some(func) = s.scalar_functions().get(name) { return Ok(ScalarFunctionType::Udf(func.to_owned())); @@ -156,6 +149,18 @@ fn scalar_function_type_from_str( not_impl_err!("Unsupported function name: {name:?}") } +pub fn substrait_fun_name(name: &str) -> &str { + let name = match name.rsplit_once(':') { + // Since 0.32.0, Substrait requires the function names to be in a compound format + // https://substrait.io/extensions/#function-signature-compound-names + // for example, `add:i8_i8`. + // On the consumer side, we don't really care about the signature though, just the name. + Some((name, _)) => name, + None => name, + }; + name +} + fn split_eq_and_noneq_join_predicate_with_nulls_equality( filter: &Expr, ) -> (Vec<(Column, Column)>, bool, Option) { @@ -242,6 +247,43 @@ pub async fn from_substrait_plan( } } +/// parse projection +pub fn extract_projection( + t: LogicalPlan, + projection: &::core::option::Option, +) -> Result { + match projection { + Some(MaskExpression { select, .. 
}) => match &select.as_ref() { + Some(projection) => { + let column_indices: Vec = projection + .struct_items + .iter() + .map(|item| item.field as usize) + .collect(); + match t { + LogicalPlan::TableScan(mut scan) => { + let fields = column_indices + .iter() + .map(|i| scan.projected_schema.qualified_field(*i)) + .map(|(qualifier, field)| { + (qualifier.cloned(), Arc::new(field.clone())) + }) + .collect(); + scan.projection = Some(column_indices); + scan.projected_schema = DFSchemaRef::new( + DFSchema::new_with_metadata(fields, HashMap::new())?, + ); + Ok(LogicalPlan::TableScan(scan)) + } + _ => plan_err!("unexpected plan for table"), + } + } + _ => Ok(t), + }, + _ => Ok(t), + } +} + /// Convert Substrait Rel to DataFusion DataFrame #[async_recursion] pub async fn from_substrait_rel( @@ -491,41 +533,7 @@ pub async fn from_substrait_rel( }; let t = ctx.table(table_reference).await?; let t = t.into_optimized_plan()?; - match &read.projection { - Some(MaskExpression { select, .. }) => match &select.as_ref() { - Some(projection) => { - let column_indices: Vec = projection - .struct_items - .iter() - .map(|item| item.field as usize) - .collect(); - match &t { - LogicalPlan::TableScan(scan) => { - let fields = column_indices - .iter() - .map(|i| { - scan.projected_schema.qualified_field(*i) - }) - .map(|(qualifier, field)| { - (qualifier.cloned(), Arc::new(field.clone())) - }) - .collect(); - let mut scan = scan.clone(); - scan.projection = Some(column_indices); - scan.projected_schema = - DFSchemaRef::new(DFSchema::new_with_metadata( - fields, - HashMap::new(), - )?); - Ok(LogicalPlan::TableScan(scan)) - } - _ => plan_err!("unexpected plan for table"), - } - } - _ => Ok(t), - }, - _ => Ok(t), - } + extract_projection(t, &read.projection) } Some(ReadType::VirtualTable(vt)) => { let base_schema = read.base_schema.as_ref().ok_or_else(|| { @@ -597,54 +605,16 @@ pub async fn from_substrait_rel( }); if lf.items.len() > 1 || filename.is_none() { - return not_impl_err!( - 
"Only NamedTable and VirtualTable reads are supported" - ); + return not_impl_err!("Only single file reads are supported"); } let name = filename.unwrap(); // directly use unwrap here since we could determine it is a valid one let table_reference = TableReference::Bare { table: name.into() }; let t = ctx.table(table_reference).await?; let t = t.into_optimized_plan()?; - match &read.projection { - Some(MaskExpression { select, .. }) => match &select.as_ref() { - Some(projection) => { - let column_indices: Vec = projection - .struct_items - .iter() - .map(|item| item.field as usize) - .collect(); - match &t { - LogicalPlan::TableScan(scan) => { - let fields = column_indices - .iter() - .map(|i| { - scan.projected_schema.qualified_field(*i) - }) - .map(|(qualifier, field)| { - (qualifier.cloned(), Arc::new(field.clone())) - }) - .collect(); - let mut scan = scan.clone(); - scan.projection = Some(column_indices); - scan.projected_schema = - DFSchemaRef::new(DFSchema::new_with_metadata( - fields, - HashMap::new(), - )?); - Ok(LogicalPlan::TableScan(scan)) - } - _ => plan_err!("unexpected plan for table"), - } - } - _ => Ok(t), - }, - _ => Ok(t), - } - } - _ => { - not_impl_err!("Only NamedTable and VirtualTable reads are supported") + extract_projection(t, &read.projection) } + _ => not_impl_err!("Unsupported ReadType: {:?}", &read.as_ref().read_type), }, Some(RelType::Set(set)) => match set_rel::SetOp::try_from(set.op) { Ok(set_op) => match set_op { @@ -885,7 +855,8 @@ pub async fn from_substrait_agg_func( f.function_reference ); }; - let function_name = function_name.split(':').next().unwrap_or(function_name); + // function_name.split(':').next().unwrap_or(function_name); + let function_name = substrait_fun_name((**function_name).as_str()); // try udaf first, then built-in aggr fn. 
if let Ok(fun) = ctx.udaf(function_name) { Ok(Arc::new(Expr::AggregateFunction( diff --git a/datafusion/substrait/tests/cases/tpch.rs b/datafusion/substrait/tests/cases/consumer_integration.rs similarity index 90% rename from datafusion/substrait/tests/cases/tpch.rs rename to datafusion/substrait/tests/cases/consumer_integration.rs index d1c34ffe496b..c2ae5691134a 100644 --- a/datafusion/substrait/tests/cases/tpch.rs +++ b/datafusion/substrait/tests/cases/consumer_integration.rs @@ -15,7 +15,12 @@ // specific language governing permissions and limitations // under the License. -//! tests contains in +//! TPCH `substrait_consumer` tests +//! +//! This module tests that substrait plans as json encoded protobuf can be +//! correctly read as DataFusion plans. +//! +//! The input data comes from <https://github.com/substrait-io/consumer-testing/tree/main/substrait_consumer/tests/integration/queries/tpch_substrait_plans> #[cfg(test)] mod tests { @@ -30,7 +35,7 @@ mod tests { #[tokio::test] async fn tpch_test_1() -> Result<()> { let ctx = create_context().await?; - let path = "tests/testdata/query_1.json"; + let path = "tests/testdata/tpch_substrait_plans/query_1.json"; let proto = serde_json::from_reader::<_, Plan>(BufReader::new( File::open(path).expect("file not found"), )) diff --git a/datafusion/substrait/tests/cases/mod.rs b/datafusion/substrait/tests/cases/mod.rs index 365fbbb89a35..a31f93087d83 100644 --- a/datafusion/substrait/tests/cases/mod.rs +++ b/datafusion/substrait/tests/cases/mod.rs @@ -15,8 +15,8 @@ // specific language governing permissions and limitations // under the License. 
+mod consumer_integration; mod logical_plans; mod roundtrip_logical_plan; mod roundtrip_physical_plan; mod serialize; -mod tpch; diff --git a/datafusion/substrait/tests/testdata/tpch_substrait_plans/README.md b/datafusion/substrait/tests/testdata/tpch_substrait_plans/README.md new file mode 100644 index 000000000000..ffcd38dfb88d --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch_substrait_plans/README.md @@ -0,0 +1,22 @@ + + +# Apache DataFusion Substrait consumer integration test + +These test JSON files come from [consumer-testing](https://github.com/substrait-io/consumer-testing/tree/main/substrait_consumer/tests/integration/queries/tpch_substrait_plans) diff --git a/datafusion/substrait/tests/testdata/query_1.json b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_1.json similarity index 100% rename from datafusion/substrait/tests/testdata/query_1.json rename to datafusion/substrait/tests/testdata/tpch_substrait_plans/query_1.json