diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 726eb07ff653..027134a3c5b9 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -82,6 +82,9 @@ jobs: - name: Check function packages (encoding_expressions) run: cargo check --no-default-features --features=encoding_expressions -p datafusion + - name: Check function packages (math_expressions) + run: cargo check --no-default-features --features=math_expressions -p datafusion + - name: Check function packages (array_expressions) run: cargo check --no-default-features --features=array_expressions -p datafusion diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 59d568cd7f77..ba7b0614d11d 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -47,6 +47,7 @@ default = ["array_expressions", "crypto_expressions", "encoding_expressions", "r encoding_expressions = ["datafusion-functions/encoding_expressions"] # Used for testing ONLY: causes all values to hash to the same value (test for collisions) force_hash_collisions = [] +math_expressions = ["datafusion-functions/math_expressions"] parquet = ["datafusion-common/parquet", "dep:parquet"] pyarrow = ["datafusion-common/pyarrow", "parquet"] regex_expressions = ["datafusion-physical-expr/regex_expressions", "datafusion-optimizer/regex_expressions"] diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 274a6fa9c2dc..60761259f992 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -23,7 +23,6 @@ use std::fmt; use std::str::FromStr; use std::sync::{Arc, OnceLock}; -use crate::nullif::SUPPORTED_NULLIF_TYPES; use crate::signature::TIMEZONE_WILDCARD; use crate::type_coercion::binary::get_wider_type; use crate::type_coercion::functions::data_types; @@ -83,8 +82,6 @@ pub enum BuiltinScalarFunction { Gcd, /// lcm, Least common multiple Lcm, - /// isnan - Isnan, /// iszero Iszero, /// ln, Natural logarithm @@ -233,8 +230,6 @@ pub enum BuiltinScalarFunction { Ltrim, /// md5 MD5, - /// nullif - NullIf, /// octet_length OctetLength, /// random @@ -384,7 +379,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Factorial => Volatility::Immutable, BuiltinScalarFunction::Floor => Volatility::Immutable, BuiltinScalarFunction::Gcd => Volatility::Immutable, - BuiltinScalarFunction::Isnan => Volatility::Immutable, BuiltinScalarFunction::Iszero => Volatility::Immutable, BuiltinScalarFunction::Lcm => Volatility::Immutable, BuiltinScalarFunction::Ln => Volatility::Immutable, @@ -456,7 +450,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Lower => Volatility::Immutable, BuiltinScalarFunction::Ltrim => Volatility::Immutable, BuiltinScalarFunction::MD5 => Volatility::Immutable, - BuiltinScalarFunction::NullIf => Volatility::Immutable, BuiltinScalarFunction::OctetLength => Volatility::Immutable, BuiltinScalarFunction::Radians => Volatility::Immutable, BuiltinScalarFunction::RegexpLike => Volatility::Immutable, @@ -726,11 +719,6 @@ impl BuiltinScalarFunction { utf8_to_str_type(&input_expr_types[0], "ltrim") } BuiltinScalarFunction::MD5 => utf8_to_str_type(&input_expr_types[0], "md5"), - BuiltinScalarFunction::NullIf => { - // NULLIF has two args and they might get coerced, get a preview of this - let coerced_types = data_types(input_expr_types, &self.signature()); - coerced_types.map(|typs| typs[0].clone()) - } BuiltinScalarFunction::OctetLength => { utf8_to_int_type(&input_expr_types[0], "octet_length") } @@ -871,7 +859,7 @@ impl BuiltinScalarFunction { _ => Ok(Float64), }, - BuiltinScalarFunction::Isnan | BuiltinScalarFunction::Iszero => Ok(Boolean), + BuiltinScalarFunction::Iszero => Ok(Boolean), BuiltinScalarFunction::ArrowTypeof => Ok(Utf8), @@ -1261,9 +1249,6 @@ impl BuiltinScalarFunction { self.volatility(), ), - BuiltinScalarFunction::NullIf => { - Signature::uniform(2, SUPPORTED_NULLIF_TYPES.to_vec(), self.volatility()) - } BuiltinScalarFunction::Pi => Signature::exact(vec![], self.volatility()), BuiltinScalarFunction::Random => Signature::exact(vec![], self.volatility()), BuiltinScalarFunction::Uuid => Signature::exact(vec![], self.volatility()), @@ -1368,12 +1353,10 @@ impl BuiltinScalarFunction { vec![Int32, Int64, UInt32, UInt64, Utf8], self.volatility(), ), - BuiltinScalarFunction::Isnan | BuiltinScalarFunction::Iszero => { - Signature::one_of( - vec![Exact(vec![Float32]), Exact(vec![Float64])], - self.volatility(), - ) - } + BuiltinScalarFunction::Iszero => Signature::one_of( + vec![Exact(vec![Float32]), Exact(vec![Float64])], + self.volatility(), + ), } } @@ -1439,7 +1422,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Factorial => &["factorial"], BuiltinScalarFunction::Floor => &["floor"], BuiltinScalarFunction::Gcd => &["gcd"], - BuiltinScalarFunction::Isnan => &["isnan"], BuiltinScalarFunction::Iszero => &["iszero"], BuiltinScalarFunction::Lcm => &["lcm"], BuiltinScalarFunction::Ln => &["ln"], @@ -1462,7 +1444,6 @@ impl BuiltinScalarFunction { // conditional functions BuiltinScalarFunction::Coalesce => &["coalesce"], - BuiltinScalarFunction::NullIf => &["nullif"], // string functions BuiltinScalarFunction::Ascii => &["ascii"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 9c20763c89dd..28a03d141bd6 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -565,7 +565,6 @@ scalar_expr!(Lcm, lcm, arg_1 arg_2, "least common multiple"); scalar_expr!(Log2, log2, num, "base 2 logarithm"); scalar_expr!(Log10, log10, num, "base 10 logarithm"); scalar_expr!(Ln, ln, num, "natural logarithm"); -scalar_expr!(NullIf, nullif, arg_1 arg_2, "returns NULL if value1 equals value2; otherwise it returns value1. This can be used to perform the inverse operation of the COALESCE expression."); scalar_expr!(Power, power, base exponent, "`base` raised to the power of `exponent`"); scalar_expr!(Atan2, atan2, y x, "inverse tangent of a division given in the argument"); scalar_expr!( @@ -926,12 +925,6 @@ scalar_expr!(Now, now, ,"returns current timestamp in nanoseconds, using the sam scalar_expr!(CurrentTime, current_time, , "returns current UTC time as a [`DataType::Time64`] value"); scalar_expr!(MakeDate, make_date, year month day, "make a date from year, month and day component parts"); scalar_expr!(Nanvl, nanvl, x y, "returns x if x is not NaN otherwise returns y"); -scalar_expr!( - Isnan, - isnan, - num, - "returns true if a given number is +NaN or -NaN otherwise returns false" -); scalar_expr!( Iszero, iszero, @@ -1363,7 +1356,6 @@ mod test { test_unary_scalar_expr!(Ln, ln); test_scalar_expr!(Atan2, atan2, y, x); test_scalar_expr!(Nanvl, nanvl, x, y); - test_scalar_expr!(Isnan, isnan, input); test_scalar_expr!(Iszero, iszero, input); test_scalar_expr!(Ascii, ascii, input); diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index c29535456327..8c73ae5ae709 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -30,7 +30,6 @@ mod built_in_function; mod built_in_window_function; mod columnar_value; mod literal; -mod nullif; mod operator; mod partition_evaluator; mod signature; @@ -74,7 +73,6 @@ pub use function::{ pub use groups_accumulator::{EmitTo, GroupsAccumulator}; pub use literal::{lit, lit_timestamp_nano, Literal, TimestampLiteral}; pub use logical_plan::*; -pub use nullif::SUPPORTED_NULLIF_TYPES; pub use operator::Operator; pub use partition_evaluator::PartitionEvaluator; pub use signature::{ diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index 6d4a716e2e8e..7109261cc78f 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -29,10 +29,14 @@ authors = { workspace = true } rust-version = { workspace = true } [features] +# enable core functions +core_expressions = [] # Enable encoding by default so the doctests work. In general don't automatically enable all packages. -default = ["encoding_expressions"] -# enable the encode/decode functions +default = ["core_expressions", "encoding_expressions", "math_expressions"] +# enable encode/decode functions encoding_expressions = ["base64", "hex"] +# enable math functions +math_expressions = [] [lib] diff --git a/datafusion/expr/src/nullif.rs b/datafusion/functions/src/core/mod.rs similarity index 58% rename from datafusion/expr/src/nullif.rs rename to datafusion/functions/src/core/mod.rs index f17bd793b8fc..9aab4bd450d1 100644 --- a/datafusion/expr/src/nullif.rs +++ b/datafusion/functions/src/core/mod.rs @@ -15,23 +15,15 @@ // specific language governing permissions and limitations // under the License. -use arrow::datatypes::DataType; +//! "core" DataFusion functions + +mod nullif; + +// create UDFs +make_udf_function!(nullif::NullIfFunc, NULLIF, nullif); + +// Export the functions out of this package, both as expr_fn as well as a list of functions +export_functions!( + (nullif, arg_1 arg_2, "returns NULL if value1 equals value2; otherwise it returns value1. This can be used to perform the inverse operation of the COALESCE expression.") +); -/// Currently supported types by the nullif function. -/// The order of these types correspond to the order on which coercion applies -/// This should thus be from least informative to most informative -pub static SUPPORTED_NULLIF_TYPES: &[DataType] = &[ - DataType::Boolean, - DataType::UInt8, - DataType::UInt16, - DataType::UInt32, - DataType::UInt64, - DataType::Int8, - DataType::Int16, - DataType::Int32, - DataType::Int64, - DataType::Float32, - DataType::Float64, - DataType::Utf8, - DataType::LargeUtf8, -]; diff --git a/datafusion/physical-expr/src/expressions/nullif.rs b/datafusion/functions/src/core/nullif.rs similarity index 78% rename from datafusion/physical-expr/src/expressions/nullif.rs rename to datafusion/functions/src/core/nullif.rs index dcd883f92965..1007f349f7a4 100644 --- a/datafusion/physical-expr/src/expressions/nullif.rs +++ b/datafusion/functions/src/core/nullif.rs @@ -15,17 +15,89 @@ // specific language governing permissions and limitations // under the License. +//! Encoding expressions + +use arrow::{ + datatypes::DataType, +}; +use datafusion_common::{internal_err, Result, DataFusionError}; +use datafusion_expr::{ColumnarValue}; + +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use std::any::Any; use arrow::array::Array; use arrow::compute::kernels::cmp::eq; use arrow::compute::kernels::nullif::nullif; -use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue}; -use datafusion_expr::ColumnarValue; +use datafusion_common::{ ScalarValue}; + +#[derive(Debug)] +pub(super) struct NullIfFunc { + signature: Signature, +} + +/// Currently supported types by the nullif function. +/// The order of these types correspond to the order on which coercion applies +/// This should thus be from least informative to most informative +static SUPPORTED_NULLIF_TYPES: &[DataType] = &[ + DataType::Boolean, + DataType::UInt8, + DataType::UInt16, + DataType::UInt32, + DataType::UInt64, + DataType::Int8, + DataType::Int16, + DataType::Int32, + DataType::Int64, + DataType::Float32, + DataType::Float64, + DataType::Utf8, + DataType::LargeUtf8, +]; + + +impl NullIfFunc { + pub fn new() -> Self { + Self { + signature: + Signature::uniform(2, SUPPORTED_NULLIF_TYPES.to_vec(), + Volatility::Immutable, + ) + } + } +} + +impl ScalarUDFImpl for NullIfFunc { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "nullif" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + // NULLIF has two args and they might get coerced, get a preview of this + let coerced_types = datafusion_expr::type_coercion::functions::data_types(arg_types, &self.signature); + coerced_types.map(|typs| typs[0].clone()) + .map_err(|e| e.context("Failed to coerce arguments for NULLIF") + ) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + nullif_func(args) + } +} + + /// Implements NULLIF(expr1, expr2) /// Args: 0 - left expr is any array /// 1 - if the left is equal to this expr2, then the result is NULL, otherwise left value is passed. /// -pub fn nullif_func(args: &[ColumnarValue]) -> Result { +fn nullif_func(args: &[ColumnarValue]) -> Result { if args.len() != 2 { return internal_err!( "{:?} args were supplied but NULLIF takes exactly two args", diff --git a/datafusion/functions/src/lib.rs b/datafusion/functions/src/lib.rs index 91a5c510f0f9..981174c141d6 100644 --- a/datafusion/functions/src/lib.rs +++ b/datafusion/functions/src/lib.rs @@ -84,21 +84,34 @@ use log::debug; #[macro_use] pub mod macros; +make_package!(core, "core_expressions", "Core datafusion expressions"); + make_package!( encoding, "encoding_expressions", "Hex and binary `encode` and `decode` functions." ); +make_package!(math, "math_expressions", "Mathematical functions."); + /// Fluent-style API for creating `Expr`s pub mod expr_fn { + #[cfg(feature = "core_expressions")] + pub use super::core::expr_fn::*; #[cfg(feature = "encoding_expressions")] pub use super::encoding::expr_fn::*; + #[cfg(feature = "math_expressions")] + pub use super::math::expr_fn::*; } /// Registers all enabled packages with a [`FunctionRegistry`] pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { - encoding::functions().into_iter().try_for_each(|udf| { + let mut all_functions = core::functions() + .into_iter() + .chain(encoding::functions()) + .chain(math::functions()); + + all_functions.try_for_each(|udf| { let existing_udf = registry.register_udf(udf)?; if let Some(existing_udf) = existing_udf { debug!("Overwrite existing UDF: {}", existing_udf.name()); diff --git a/datafusion/functions/src/macros.rs b/datafusion/functions/src/macros.rs index 1931ee279421..5debcbda30cc 100644 --- a/datafusion/functions/src/macros.rs +++ b/datafusion/functions/src/macros.rs @@ -121,3 +121,42 @@ macro_rules! make_package { } }; } + +/// Invokes a function on each element of an array and returns the result as a new array +/// +/// $ARG: ArrayRef +/// $NAME: name of the function (for error messages) +/// $ARGS_TYPE: the type of array to cast the argument to +/// $RETURN_TYPE: the type of array to return +/// $FUNC: the function to apply to each element of $ARG +/// +macro_rules! make_function_scalar_inputs_return_type { + ($ARG: expr, $NAME:expr, $ARG_TYPE:ident, $RETURN_TYPE:ident, $FUNC: block) => {{ + let arg = downcast_arg!($ARG, $NAME, $ARG_TYPE); + + arg.iter() + .map(|a| match a { + Some(a) => Some($FUNC(a)), + _ => None, + }) + .collect::<$RETURN_TYPE>() + }}; +} + +/// Downcast an argument to a specific array type, returning an internal error +/// if the cast fails +/// +/// $ARG: ArrayRef +/// $NAME: name of the argument (for error messages) +/// $ARRAY_TYPE: the type of array to cast the argument to +macro_rules! downcast_arg { + ($ARG:expr, $NAME:expr, $ARRAY_TYPE:ident) => {{ + $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| { + DataFusionError::Internal(format!( + "could not cast {} to {}", + $NAME, + std::any::type_name::<$ARRAY_TYPE>() + )) + })? + }}; +} diff --git a/datafusion/functions/src/math/mod.rs b/datafusion/functions/src/math/mod.rs new file mode 100644 index 000000000000..67d2d957ea1f --- /dev/null +++ b/datafusion/functions/src/math/mod.rs @@ -0,0 +1,29 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! "core" DataFusion functions + +mod nans; + +// create UDFs +make_udf_function!(nans::IsNanFunc, ISNAN, isnan); + +// Export the functions out of this package, both as expr_fn as well as a list of functions +export_functions!( + (isnan, num, "returns true if a given number is +NaN or -NaN otherwise returns false") +); + diff --git a/datafusion/functions/src/math/nans.rs b/datafusion/functions/src/math/nans.rs new file mode 100644 index 000000000000..228039e5f6c7 --- /dev/null +++ b/datafusion/functions/src/math/nans.rs @@ -0,0 +1,92 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Encoding expressions + +use arrow::{ + datatypes::DataType, +}; +use datafusion_common::{internal_err, Result, DataFusionError}; +use datafusion_expr::ColumnarValue; + +use datafusion_expr::TypeSignature::*; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use std::any::Any; +use std::sync::Arc; +use arrow::array::{ArrayRef, BooleanArray, Float32Array, Float64Array}; + +#[derive(Debug)] +pub(super) struct IsNanFunc { + signature: Signature, +} + +impl IsNanFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: + Signature::one_of( + vec![Exact(vec![Float32]), Exact(vec![Float64])], + Volatility::Immutable, + ) + } + } +} + +impl ScalarUDFImpl for IsNanFunc { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "isnan" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Boolean) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + let args = ColumnarValue::values_to_arrays(args)?; + + let arr: ArrayRef = match args[0].data_type() { + DataType::Float64 => { + Arc::new(make_function_scalar_inputs_return_type!( + &args[0], + "x", + Float64Array, + BooleanArray, + { f64::is_nan } + )) + }, + DataType::Float32 => { + Arc::new(make_function_scalar_inputs_return_type!( + &args[0], + "x", + Float32Array, + BooleanArray, + { f32::is_nan } + )) + }, + other => return internal_err!("Unsupported data type {other:?} for function isnan"), + }; + Ok(ColumnarValue::Array(arr)) + } +} diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 007a03985f45..09e908586c5b 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -32,7 +32,6 @@ mod literal; mod negative; mod no_op; mod not; -mod nullif; mod try_cast; /// Module with some convenient methods used in expression building @@ -92,7 +91,6 @@ pub use literal::{lit, Literal}; pub use negative::{negative, NegativeExpr}; pub use no_op::NoOp; pub use not::{not, NotExpr}; -pub use nullif::nullif_func; pub use try_cast::{try_cast, TryCastExpr}; /// returns the name of the state diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 81f433611af8..5ea073845314 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -33,9 +33,8 @@ use crate::execution_props::ExecutionProps; use crate::sort_properties::SortProperties; use crate::{ - array_expressions, conditional_expressions, datetime_expressions, - expressions::nullif_func, math_expressions, string_expressions, struct_expressions, - PhysicalExpr, ScalarFunctionExpr, + array_expressions, conditional_expressions, datetime_expressions, math_expressions, + string_expressions, struct_expressions, PhysicalExpr, ScalarFunctionExpr, }; use arrow::{ array::ArrayRef, @@ -282,9 +281,6 @@ pub fn create_physical_fun( BuiltinScalarFunction::Gcd => { Arc::new(|args| make_scalar_function_inner(math_expressions::gcd)(args)) } - BuiltinScalarFunction::Isnan => { - Arc::new(|args| make_scalar_function_inner(math_expressions::isnan)(args)) - } BuiltinScalarFunction::Iszero => { Arc::new(|args| make_scalar_function_inner(math_expressions::iszero)(args)) } @@ -592,7 +588,6 @@ pub fn create_physical_fun( BuiltinScalarFunction::Digest => { Arc::new(invoke_if_crypto_expressions_feature_flag!(digest, "digest")) } - BuiltinScalarFunction::NullIf => Arc::new(nullif_func), BuiltinScalarFunction::OctetLength => Arc::new(|args| match &args[0] { ColumnarValue::Array(v) => Ok(ColumnarValue::Array(length(v.as_ref())?)), ColumnarValue::Scalar(v) => match v { diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index d50336bd0f4c..10eb81e50127 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -580,7 +580,7 @@ enum ScalarFunction { Lower = 33; Ltrim = 34; MD5 = 35; - NullIf = 36; + // 36 was NullIf OctetLength = 37; Random = 38; RegexpReplace = 39; @@ -655,7 +655,7 @@ enum ScalarFunction { ArrayReplaceAll = 110; Nanvl = 111; Flatten = 112; - Isnan = 113; + // 113 was IsNan Iszero = 114; ArrayEmpty = 115; ArrayPopBack = 116; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 592c2609b678..665ea3580e21 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -22339,7 +22339,6 @@ impl serde::Serialize for ScalarFunction { Self::Lower => "Lower", Self::Ltrim => "Ltrim", Self::Md5 => "MD5", - Self::NullIf => "NullIf", Self::OctetLength => "OctetLength", Self::Random => "Random", Self::RegexpReplace => "RegexpReplace", @@ -22413,7 +22412,6 @@ impl serde::Serialize for ScalarFunction { Self::ArrayReplaceAll => "ArrayReplaceAll", Self::Nanvl => "Nanvl", Self::Flatten => "Flatten", - Self::Isnan => "Isnan", Self::Iszero => "Iszero", Self::ArrayEmpty => "ArrayEmpty", Self::ArrayPopBack => "ArrayPopBack", @@ -22483,7 +22481,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Lower", "Ltrim", "MD5", - "NullIf", "OctetLength", "Random", "RegexpReplace", @@ -22557,7 +22554,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayReplaceAll", "Nanvl", "Flatten", - "Isnan", "Iszero", "ArrayEmpty", "ArrayPopBack", @@ -22656,7 +22652,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Lower" => Ok(ScalarFunction::Lower), "Ltrim" => Ok(ScalarFunction::Ltrim), "MD5" => Ok(ScalarFunction::Md5), - "NullIf" => Ok(ScalarFunction::NullIf), "OctetLength" => Ok(ScalarFunction::OctetLength), "Random" => Ok(ScalarFunction::Random), "RegexpReplace" => Ok(ScalarFunction::RegexpReplace), @@ -22730,7 +22725,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayReplaceAll" => Ok(ScalarFunction::ArrayReplaceAll), "Nanvl" => Ok(ScalarFunction::Nanvl), "Flatten" => Ok(ScalarFunction::Flatten), - "Isnan" => Ok(ScalarFunction::Isnan), "Iszero" => Ok(ScalarFunction::Iszero), "ArrayEmpty" => Ok(ScalarFunction::ArrayEmpty), "ArrayPopBack" => Ok(ScalarFunction::ArrayPopBack), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index a65df74bbcf3..b455d2a14ade 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2667,7 +2667,7 @@ pub enum ScalarFunction { Lower = 33, Ltrim = 34, Md5 = 35, - NullIf = 36, + /// 36 was NullIf OctetLength = 37, Random = 38, RegexpReplace = 39, @@ -2742,7 +2742,7 @@ pub enum ScalarFunction { ArrayReplaceAll = 110, Nanvl = 111, Flatten = 112, - Isnan = 113, + /// 113 was IsNan Iszero = 114, ArrayEmpty = 115, ArrayPopBack = 116, @@ -2809,7 +2809,6 @@ impl ScalarFunction { ScalarFunction::Lower => "Lower", ScalarFunction::Ltrim => "Ltrim", ScalarFunction::Md5 => "MD5", - ScalarFunction::NullIf => "NullIf", ScalarFunction::OctetLength => "OctetLength", ScalarFunction::Random => "Random", ScalarFunction::RegexpReplace => "RegexpReplace", @@ -2883,7 +2882,6 @@ impl ScalarFunction { ScalarFunction::ArrayReplaceAll => "ArrayReplaceAll", ScalarFunction::Nanvl => "Nanvl", ScalarFunction::Flatten => "Flatten", - ScalarFunction::Isnan => "Isnan", ScalarFunction::Iszero => "Iszero", ScalarFunction::ArrayEmpty => "ArrayEmpty", ScalarFunction::ArrayPopBack => "ArrayPopBack", @@ -2947,7 +2945,6 @@ impl ScalarFunction { "Lower" => Some(Self::Lower), "Ltrim" => Some(Self::Ltrim), "MD5" => Some(Self::Md5), - "NullIf" => Some(Self::NullIf), "OctetLength" => Some(Self::OctetLength), "Random" => Some(Self::Random), "RegexpReplace" => Some(Self::RegexpReplace), @@ -3021,7 +3018,6 @@ impl ScalarFunction { "ArrayReplaceAll" => Some(Self::ArrayReplaceAll), "Nanvl" => Some(Self::Nanvl), "Flatten" => Some(Self::Flatten), - "Isnan" => Some(Self::Isnan), "Iszero" => Some(Self::Iszero), "ArrayEmpty" => Some(Self::ArrayEmpty), "ArrayPopBack" => Some(Self::ArrayPopBack), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 07590e0d93ae..ecb92db46b2c 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -58,15 +58,15 @@ use datafusion_expr::{ current_time, date_bin, date_part, date_trunc, degrees, digest, ends_with, exp, expr::{self, InList, Sort, WindowFunction}, factorial, find_in_set, flatten, floor, from_unixtime, gcd, gen_range, initcap, - instr, isnan, iszero, lcm, left, levenshtein, ln, log, log10, log2, + instr, iszero, lcm, left, levenshtein, ln, log, log10, log2, logical_plan::{PlanType, StringifiedPlan}, - lower, lpad, ltrim, md5, nanvl, now, nullif, octet_length, overlay, pi, power, - radians, random, regexp_like, regexp_match, regexp_replace, repeat, replace, reverse, - right, round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, sinh, - split_part, sqrt, starts_with, string_to_array, strpos, struct_fun, substr, - substr_index, substring, tan, tanh, to_hex, translate, trim, trunc, upper, uuid, - AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction, - Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet, + lower, lpad, ltrim, md5, nanvl, now, octet_length, overlay, pi, power, radians, + random, regexp_like, regexp_match, regexp_replace, repeat, replace, reverse, right, + round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, sinh, split_part, + sqrt, starts_with, string_to_array, strpos, struct_fun, substr, substr_index, + substring, tan, tanh, to_hex, translate, trim, trunc, upper, uuid, AggregateFunction, + Between, BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction, Case, Cast, Expr, + GetFieldAccess, GetIndexedField, GroupingSet, GroupingSet::GroupingSets, JoinConstraint, JoinType, Like, Operator, TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, @@ -512,7 +512,6 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Range => Self::Range, ScalarFunction::Cardinality => Self::Cardinality, ScalarFunction::Array => Self::MakeArray, - ScalarFunction::NullIf => Self::NullIf, ScalarFunction::DatePart => Self::DatePart, ScalarFunction::DateTrunc => Self::DateTrunc, ScalarFunction::DateBin => Self::DateBin, @@ -568,7 +567,6 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::FromUnixtime => Self::FromUnixtime, ScalarFunction::Atan2 => Self::Atan2, ScalarFunction::Nanvl => Self::Nanvl, - ScalarFunction::Isnan => Self::Isnan, ScalarFunction::Iszero => Self::Iszero, ScalarFunction::ArrowTypeof => Self::ArrowTypeof, ScalarFunction::OverLay => Self::OverLay, @@ -1560,10 +1558,6 @@ pub fn parse_expr( ScalarFunction::Sha384 => Ok(sha384(parse_expr(&args[0], registry)?)), ScalarFunction::Sha512 => Ok(sha512(parse_expr(&args[0], registry)?)), ScalarFunction::Md5 => Ok(md5(parse_expr(&args[0], registry)?)), - ScalarFunction::NullIf => Ok(nullif( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - )), ScalarFunction::Digest => Ok(digest( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, @@ -1789,7 +1783,6 @@ pub fn parse_expr( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, )), - ScalarFunction::Isnan => Ok(isnan(parse_expr(&args[0], registry)?)), ScalarFunction::Iszero => Ok(iszero(parse_expr(&args[0], registry)?)), ScalarFunction::ArrowTypeof => { Ok(arrow_typeof(parse_expr(&args[0], registry)?)) diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 7a8fbde07b6f..9060d0243272 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1491,7 +1491,6 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Range => Self::Range, BuiltinScalarFunction::Cardinality => Self::Cardinality, BuiltinScalarFunction::MakeArray => Self::Array, - BuiltinScalarFunction::NullIf => Self::NullIf, BuiltinScalarFunction::DatePart => Self::DatePart, BuiltinScalarFunction::DateTrunc => Self::DateTrunc, BuiltinScalarFunction::DateBin => Self::DateBin, @@ -1546,7 +1545,6 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::FromUnixtime => Self::FromUnixtime, BuiltinScalarFunction::Atan2 => Self::Atan2, BuiltinScalarFunction::Nanvl => Self::Nanvl, - BuiltinScalarFunction::Isnan => Self::Isnan, BuiltinScalarFunction::Iszero => Self::Iszero, BuiltinScalarFunction::ArrowTypeof => Self::ArrowTypeof, BuiltinScalarFunction::OverLay => Self::OverLay, diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 3c72bf334e7f..2124a5224a76 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use std::any::Any; #[cfg(test)] use std::collections::HashMap; use std::{sync::Arc, vec}; @@ -29,7 +30,8 @@ use datafusion_common::{ use datafusion_common::{plan_err, ParamValues}; use datafusion_expr::{ logical_plan::{LogicalPlan, Prepare}, - AggregateUDF, ScalarUDF, TableSource, WindowUDF, + AggregateUDF, ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, TableSource, + Volatility, WindowUDF, }; use datafusion_sql::{ parser::DFParser, @@ -2671,13 +2673,62 @@ fn logical_plan_with_dialect_and_options( dialect: &dyn Dialect, options: ParserOptions, ) -> Result { - let context = MockContextProvider::default(); + let context = MockContextProvider::default().with_udf(make_udf( + "nullif", + vec![DataType::Int32, DataType::Int32], + DataType::Int32, + )); + let planner = SqlToRel::new_with_options(&context, options); let result = DFParser::parse_sql_with_dialect(sql, dialect); let mut ast = result?; planner.statement_to_plan(ast.pop_front().unwrap()) } +fn make_udf(name: &'static str, args: Vec, return_type: DataType) -> ScalarUDF { + ScalarUDF::new_from_impl(DummyUDF::new(name, args, return_type)) +} + +/// Mocked UDF +#[derive(Debug)] +struct DummyUDF { + name: &'static str, + signature: Signature, + return_type: DataType, +} + +impl DummyUDF { + fn new(name: &'static str, args: Vec, return_type: DataType) -> Self { + Self { + name, + signature: Signature::exact(args, Volatility::Immutable), + return_type, + } + } +} + +impl ScalarUDFImpl for DummyUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(self.return_type.clone()) + } + + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + unimplemented!("DummyUDF::invoke") + } +} + /// Create logical plan, write with formatter, compare to expected output fn quick_test(sql: &str, expected: &str) { let plan = logical_plan(sql).unwrap(); @@ -2724,6 +2775,7 @@ fn prepare_stmt_replace_params_quick_test( #[derive(Default)] struct MockContextProvider { options: ConfigOptions, + udfs: HashMap>, udafs: HashMap>, } @@ -2731,6 +2783,11 @@ impl MockContextProvider { fn options_mut(&mut self) -> &mut ConfigOptions { &mut self.options } + + fn with_udf(mut self, udf: ScalarUDF) -> Self { + self.udfs.insert(udf.name().to_string(), Arc::new(udf)); + self + } } impl ContextProvider for MockContextProvider { @@ -2823,8 +2880,8 @@ impl ContextProvider for MockContextProvider { } } - fn get_function_meta(&self, _name: &str) -> Option> { - None + fn get_function_meta(&self, name: &str) -> Option> { + self.udfs.get(name).map(Arc::clone) } fn get_aggregate_meta(&self, name: &str) -> Option> { diff --git a/datafusion/sqllogictest/test_files/errors.slt b/datafusion/sqllogictest/test_files/errors.slt index e3b2610e51be..3a23f3615d08 100644 --- a/datafusion/sqllogictest/test_files/errors.slt +++ b/datafusion/sqllogictest/test_files/errors.slt @@ -84,7 +84,7 @@ statement error Error during planning: No function matches the given name and ar SELECT concat(); # error message for wrong function signature (Uniform: t args all from some common types) -statement error Error during planning: No function matches the given name and argument types 'nullif\(Int64\)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tnullif\(Boolean/UInt8/UInt16/UInt32/UInt64/Int8/Int16/Int32/Int64/Float32/Float64/Utf8/LargeUtf8, Boolean/UInt8/UInt16/UInt32/UInt64/Int8/Int16/Int32/Int64/Float32/Float64/Utf8/LargeUtf8\) +statement error DataFusion error: Failed to coerce arguments for NULLIF SELECT nullif(1); # error message for wrong function signature (Exact: exact number of args of an exact type) diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index 129eb6508b4d..eb2f05dc742f 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -1840,9 +1840,10 @@ statement error Error during planning: No function matches the given name and ar SELECT concat(); # error message for wrong function signature (Uniform: t args all from some common types) -statement error Error during planning: No function matches the given name and argument types 'nullif\(Int64\)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tnullif\(Boolean/UInt8/UInt16/UInt32/UInt64/Int8/Int16/Int32/Int64/Float32/Float64/Utf8/LargeUtf8, Boolean/UInt8/UInt16/UInt32/UInt64/Int8/Int16/Int32/Int64/Float32/Float64/Utf8/LargeUtf8\) +statement error DataFusion error: Failed to coerce arguments for NULLIF SELECT nullif(1); + # error message for wrong function signature (Exact: exact number of args of an exact type) statement error Error during planning: No function matches the given name and argument types 'pi\(Float64\)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tpi\(\) SELECT pi(3.14);