From e0d259a70eface3f5f22d00cff87efc76636b24d Mon Sep 17 00:00:00 2001
From: Jay Zhan
Date: Tue, 25 Jun 2024 07:57:40 +0800
Subject: [PATCH] Rewrite array operator to function in parser (#11101)

* rewrite func

Signed-off-by: jayzhan211

* remove rule in analyzer

Signed-off-by: jayzhan211

* fix test

Signed-off-by: jayzhan211

---------

Signed-off-by: jayzhan211
---
 datafusion/core/tests/sql/select.rs         |   2 +-
 datafusion/expr/src/type_coercion/binary.rs |  24 ++---
 datafusion/functions-array/src/rewrite.rs   | 109 +-------------------
 datafusion/sql/src/expr/mod.rs              |  72 +++++++++++--
 4 files changed, 77 insertions(+), 130 deletions(-)

diff --git a/datafusion/core/tests/sql/select.rs b/datafusion/core/tests/sql/select.rs
index f2710e6592409..d9ef462df26c8 100644
--- a/datafusion/core/tests/sql/select.rs
+++ b/datafusion/core/tests/sql/select.rs
@@ -246,7 +246,7 @@ async fn test_parameter_invalid_types() -> Result<()> {
         .await;
     assert_eq!(
         results.unwrap_err().strip_backtrace(),
-        "Arrow error: Invalid argument error: Invalid comparison operation: List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) == List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })"
+        "type_coercion\ncaused by\nError during planning: Cannot infer common argument type for comparison operation List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) = Int32"
     );
     Ok(())
 }
diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs
index ea9d0c2fe72ec..d83fbfe49bc25 100644
--- a/datafusion/expr/src/type_coercion/binary.rs
+++ b/datafusion/expr/src/type_coercion/binary.rs
@@ -889,21 +889,18 @@ fn dictionary_coercion(
 /// 2. Data type of the other side should be able to cast to string type
 fn string_concat_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
     use arrow::datatypes::DataType::*;
-    string_coercion(lhs_type, rhs_type)
-        .or_else(|| list_coercion(lhs_type, rhs_type))
-        .or(match (lhs_type, rhs_type) {
-            (Utf8, from_type) | (from_type, Utf8) => {
-                string_concat_internal_coercion(from_type, &Utf8)
-            }
-            (LargeUtf8, from_type) | (from_type, LargeUtf8) => {
-                string_concat_internal_coercion(from_type, &LargeUtf8)
-            }
-            _ => None,
-        })
+    string_coercion(lhs_type, rhs_type).or(match (lhs_type, rhs_type) {
+        (Utf8, from_type) | (from_type, Utf8) => {
+            string_concat_internal_coercion(from_type, &Utf8)
+        }
+        (LargeUtf8, from_type) | (from_type, LargeUtf8) => {
+            string_concat_internal_coercion(from_type, &LargeUtf8)
+        }
+        _ => None,
+    })
 }
 
 fn array_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
-    // TODO: cast between array elements (#6558)
     if lhs_type.equals_datatype(rhs_type) {
         Some(lhs_type.to_owned())
     } else {
@@ -952,10 +949,7 @@ fn numeric_string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<
 fn list_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
     use arrow::datatypes::DataType::*;
     match (lhs_type, rhs_type) {
-        // TODO: cast between array elements (#6558)
         (List(_), List(_)) => Some(lhs_type.clone()),
-        (List(_), _) => Some(lhs_type.clone()),
-        (_, List(_)) => Some(rhs_type.clone()),
         _ => None,
     }
 }
diff --git a/datafusion/functions-array/src/rewrite.rs b/datafusion/functions-array/src/rewrite.rs
index d18f1f8a3cbb6..28bc2d5e43730 100644
--- a/datafusion/functions-array/src/rewrite.rs
+++ b/datafusion/functions-array/src/rewrite.rs
@@ -18,12 +18,10 @@
 //! Rewrites for using Array Functions
 
 use crate::array_has::array_has_all;
-use crate::concat::{array_append, array_concat, array_prepend};
 use datafusion_common::config::ConfigOptions;
 use datafusion_common::tree_node::Transformed;
-use datafusion_common::utils::list_ndims;
+use datafusion_common::DFSchema;
 use datafusion_common::Result;
-use datafusion_common::{Column, DFSchema};
 use datafusion_expr::expr::ScalarFunction;
 use datafusion_expr::expr_rewriter::FunctionRewrite;
 use datafusion_expr::{BinaryExpr, Expr, Operator};
@@ -39,7 +37,7 @@ impl FunctionRewrite for ArrayFunctionRewriter {
     fn rewrite(
         &self,
         expr: Expr,
-        schema: &DFSchema,
+        _schema: &DFSchema,
         _config: &ConfigOptions,
     ) -> Result<Transformed<Expr>> {
         let transformed = match expr {
@@ -61,91 +59,6 @@ impl FunctionRewrite for ArrayFunctionRewriter {
                 Transformed::yes(array_has_all(*right, *left))
             }
 
-            // Column cases:
-            // 1) array_prepend/append/concat || column
-            Expr::BinaryExpr(BinaryExpr { left, op, right })
-                if op == Operator::StringConcat
-                    && is_one_of_func(
-                        &left,
-                        &["array_append", "array_prepend", "array_concat"],
-                    )
-                    && as_col(&right).is_some() =>
-            {
-                let c = as_col(&right).unwrap();
-                let d = schema.field_from_column(c)?.data_type();
-                let ndim = list_ndims(d);
-                match ndim {
-                    0 => Transformed::yes(array_append(*left, *right)),
-                    _ => Transformed::yes(array_concat(vec![*left, *right])),
-                }
-            }
-            // 2) select column1 || column2
-            Expr::BinaryExpr(BinaryExpr { left, op, right })
-                if op == Operator::StringConcat
-                    && as_col(&left).is_some()
-                    && as_col(&right).is_some() =>
-            {
-                let c1 = as_col(&left).unwrap();
-                let c2 = as_col(&right).unwrap();
-                let d1 = schema.field_from_column(c1)?.data_type();
-                let d2 = schema.field_from_column(c2)?.data_type();
-                let ndim1 = list_ndims(d1);
-                let ndim2 = list_ndims(d2);
-                match (ndim1, ndim2) {
-                    (0, _) => Transformed::yes(array_prepend(*left, *right)),
-                    (_, 0) => Transformed::yes(array_append(*left, *right)),
-                    _ => Transformed::yes(array_concat(vec![*left, *right])),
-                }
-            }
-
-            // Chain concat operator (a || b) || array,
-            // (array_concat, array_append, array_prepend) || array -> array concat
-            Expr::BinaryExpr(BinaryExpr { left, op, right })
-                if op == Operator::StringConcat
-                    && is_one_of_func(
-                        &left,
-                        &["array_append", "array_prepend", "array_concat"],
-                    )
-                    && is_func(&right, "make_array") =>
-            {
-                Transformed::yes(array_concat(vec![*left, *right]))
-            }
-
-            // Chain concat operator (a || b) || scalar,
-            // (array_concat, array_append, array_prepend) || scalar -> array append
-            Expr::BinaryExpr(BinaryExpr { left, op, right })
-                if op == Operator::StringConcat
-                    && is_one_of_func(
-                        &left,
-                        &["array_append", "array_prepend", "array_concat"],
-                    ) =>
-            {
-                Transformed::yes(array_append(*left, *right))
-            }
-
-            // array || array -> array concat
-            Expr::BinaryExpr(BinaryExpr { left, op, right })
-                if op == Operator::StringConcat
-                    && is_func(&left, "make_array")
-                    && is_func(&right, "make_array") =>
-            {
-                Transformed::yes(array_concat(vec![*left, *right]))
-            }
-
-            // array || scalar -> array append
-            Expr::BinaryExpr(BinaryExpr { left, op, right })
-                if op == Operator::StringConcat && is_func(&left, "make_array") =>
-            {
-                Transformed::yes(array_append(*left, *right))
-            }
-
-            // scalar || array -> array prepend
-            Expr::BinaryExpr(BinaryExpr { left, op, right })
-                if op == Operator::StringConcat && is_func(&right, "make_array") =>
-            {
-                Transformed::yes(array_prepend(*left, *right))
-            }
-
             _ => Transformed::no(expr),
         };
         Ok(transformed)
@@ -161,21 +74,3 @@ fn is_func(expr: &Expr, func_name: &str) -> bool {
 
     func.name() == func_name
 }
-
-/// Returns true if expr is a function call with one of the specified names
-fn is_one_of_func(expr: &Expr, func_names: &[&str]) -> bool {
-    let Expr::ScalarFunction(ScalarFunction { func, args: _ }) = expr else {
-        return false;
-    };
-
-    func_names.contains(&func.name())
-}
-
-/// returns Some(col) if this is Expr::Column
-fn as_col(expr: &Expr) -> Option<&Column> {
-    if let Expr::Column(c) = expr {
-        Some(c)
-    } else {
-        None
-    }
-}
diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs
index 8b64ccfb52cb6..a8af37ee6a37d 100644
--- a/datafusion/sql/src/expr/mod.rs
+++ b/datafusion/sql/src/expr/mod.rs
@@ -17,6 +17,7 @@
 
 use arrow_schema::DataType;
 use arrow_schema::TimeUnit;
+use datafusion_common::utils::list_ndims;
 use sqlparser::ast::{CastKind, Expr as SQLExpr, Subscript, TrimWhereField, Value};
 
 use datafusion_common::{
@@ -86,13 +87,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
             StackEntry::Operator(op) => {
                 let right = eval_stack.pop().unwrap();
                 let left = eval_stack.pop().unwrap();
-
-                let expr = Expr::BinaryExpr(BinaryExpr::new(
-                    Box::new(left),
-                    op,
-                    Box::new(right),
-                ));
-
+                let expr = self.build_logical_expr(op, left, right, schema)?;
                 eval_stack.push(expr);
             }
         }
@@ -103,6 +98,69 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         Ok(expr)
     }
 
+    fn build_logical_expr(
+        &self,
+        op: Operator,
+        left: Expr,
+        right: Expr,
+        schema: &DFSchema,
+    ) -> Result<Expr> {
+        // Rewrite the string concat operator to a function call based on the operand types:
+        // if we get list || list, rewrite it to array_concat()
+        // if we get list || non-list, rewrite it to array_append()
+        // if we get non-list || list, rewrite it to array_prepend()
+        // if we get string || string, rewrite it to concat()
+        if op == Operator::StringConcat {
+            let left_type = left.get_type(schema)?;
+            let right_type = right.get_type(schema)?;
+            let left_list_ndims = list_ndims(&left_type);
+            let right_list_ndims = list_ndims(&right_type);
+
+            // We pick the target function based on the list n-dimensions of the operands; the check is not exact, but it is sufficient.
+            // The exact validity check is handled by the function itself, so rewriting a 3-D list appended with a 1-D list is still fine.
+            if left_list_ndims + right_list_ndims == 0 {
+                // TODO: the concat function ignores nulls, while the string concat operator propagates them;
+                // we can rewrite to concat() once the concat function's null handling can be configured to match the `string concat operator`
+            } else if left_list_ndims == right_list_ndims {
+                if let Some(udf) = self.context_provider.get_function_meta("array_concat")
+                {
+                    return Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
+                        udf,
+                        vec![left, right],
+                    )));
+                } else {
+                    return internal_err!("array_concat not found");
+                }
+            } else if left_list_ndims > right_list_ndims {
+                if let Some(udf) = self.context_provider.get_function_meta("array_append")
+                {
+                    return Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
+                        udf,
+                        vec![left, right],
+                    )));
+                } else {
+                    return internal_err!("array_append not found");
+                }
+            } else if left_list_ndims < right_list_ndims {
+                if let Some(udf) =
+                    self.context_provider.get_function_meta("array_prepend")
+                {
+                    return Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
+                        udf,
+                        vec![left, right],
+                    )));
+                } else {
+                    return internal_err!("array_prepend not found");
+                }
+            }
+        }
+        Ok(Expr::BinaryExpr(BinaryExpr::new(
+            Box::new(left),
+            op,
+            Box::new(right),
+        )))
+    }
+
     /// Generate a relational expression from a SQL expression
     pub fn sql_to_expr(
         &self,
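
For reference, the `||` dispatch that the new `build_logical_expr` performs can be exercised outside the planner. The standalone sketch below assumes only the `arrow-schema` crate; its `list_ndims` is a minimal local stand-in for `datafusion_common::utils::list_ndims`, and `concat_rewrite_target` is an illustrative helper name, not an API from this patch.

use std::sync::Arc;

use arrow_schema::{DataType, Field};

/// Depth of list nesting: 0 for scalars, 1 for List<T>, 2 for List<List<T>>, ...
/// (minimal stand-in for `datafusion_common::utils::list_ndims`)
fn list_ndims(dt: &DataType) -> u64 {
    match dt {
        DataType::List(field) | DataType::LargeList(field) => {
            1 + list_ndims(field.data_type())
        }
        _ => 0,
    }
}

/// Name of the target the parser-level rewrite would pick for `left || right`.
fn concat_rewrite_target(left: &DataType, right: &DataType) -> &'static str {
    let (l, r) = (list_ndims(left), list_ndims(right));
    if l + r == 0 {
        // No list operand: the expression stays a binary StringConcat.
        "StringConcat operator"
    } else if l == r {
        "array_concat"
    } else if l > r {
        "array_append"
    } else {
        "array_prepend"
    }
}

fn main() {
    let int_list = DataType::List(Arc::new(Field::new("item", DataType::Int32, true)));
    let nested = DataType::List(Arc::new(Field::new("item", int_list.clone(), true)));

    // list || list     -> array_concat
    assert_eq!(concat_rewrite_target(&int_list, &int_list), "array_concat");
    // list || non-list -> array_append
    assert_eq!(concat_rewrite_target(&int_list, &DataType::Int32), "array_append");
    // non-list || list -> array_prepend
    assert_eq!(concat_rewrite_target(&DataType::Int32, &int_list), "array_prepend");
    // deeper list || shallower list -> array_append; exact validity is checked by the function itself
    assert_eq!(concat_rewrite_target(&nested, &int_list), "array_append");
    // string || string -> still the StringConcat operator (the concat() rewrite is the TODO above)
    assert_eq!(
        concat_rewrite_target(&DataType::Utf8, &DataType::Utf8),
        "StringConcat operator"
    );
}

So, for example, `column1 || 3` with `column1: List(Int32)` now plans as `array_append(column1, 3)` and `3 || column1` as `array_prepend(3, column1)`, covering at plan time the column cases that the removed `ArrayFunctionRewriter` arms used to handle.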