diff --git a/datafusion/functions/src/regex/regexplike.rs b/datafusion/functions/src/regex/regexplike.rs
index a698913fff541..aad67e4ecab6f 100644
--- a/datafusion/functions/src/regex/regexplike.rs
+++ b/datafusion/functions/src/regex/regexplike.rs
@@ -15,19 +15,20 @@
 // specific language governing permissions and limitations
 // under the License.
 
-//! Regx expressions
-use arrow::array::{Array, ArrayRef, GenericStringArray, OffsetSizeTrait};
+//! Regex expressions
+
+use arrow::array::{Array, ArrayRef, AsArray, GenericStringArray};
 use arrow::compute::kernels::regexp;
 use arrow::datatypes::DataType;
+use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View};
 use datafusion_common::exec_err;
 use datafusion_common::ScalarValue;
 use datafusion_common::{arrow_datafusion_err, plan_err};
-use datafusion_common::{
-    cast::as_generic_string_array, internal_err, DataFusionError, Result,
-};
+use datafusion_common::{internal_err, DataFusionError, Result};
 use datafusion_expr::scalar_doc_sections::DOC_SECTION_REGEX;
 use datafusion_expr::{ColumnarValue, Documentation, TypeSignature};
 use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
+
 use std::any::Any;
 use std::sync::{Arc, OnceLock};
 
@@ -82,14 +83,27 @@ Additional examples can be found [here](https://github.com/apache/datafusion/blo
 
 impl RegexpLikeFunc {
     pub fn new() -> Self {
-        use DataType::*;
         Self {
             signature: Signature::one_of(
                 vec![
+                    TypeSignature::Exact(vec![Utf8View, Utf8]),
+                    TypeSignature::Exact(vec![Utf8View, Utf8View]),
+                    TypeSignature::Exact(vec![Utf8View, LargeUtf8]),
                     TypeSignature::Exact(vec![Utf8, Utf8]),
+                    TypeSignature::Exact(vec![Utf8, Utf8View]),
+                    TypeSignature::Exact(vec![Utf8, LargeUtf8]),
+                    TypeSignature::Exact(vec![LargeUtf8, Utf8]),
+                    TypeSignature::Exact(vec![LargeUtf8, Utf8View]),
                     TypeSignature::Exact(vec![LargeUtf8, LargeUtf8]),
+                    TypeSignature::Exact(vec![Utf8View, Utf8, Utf8]),
+                    TypeSignature::Exact(vec![Utf8View, Utf8View, Utf8]),
+                    TypeSignature::Exact(vec![Utf8View, LargeUtf8, Utf8]),
                     TypeSignature::Exact(vec![Utf8, Utf8, Utf8]),
-                    TypeSignature::Exact(vec![LargeUtf8, LargeUtf8, LargeUtf8]),
+                    TypeSignature::Exact(vec![Utf8, Utf8View, Utf8]),
+                    TypeSignature::Exact(vec![Utf8, LargeUtf8, Utf8]),
+                    TypeSignature::Exact(vec![LargeUtf8, Utf8, Utf8]),
+                    TypeSignature::Exact(vec![LargeUtf8, Utf8View, Utf8]),
+                    TypeSignature::Exact(vec![LargeUtf8, LargeUtf8, Utf8]),
                 ],
                 Volatility::Immutable,
             ),
@@ -120,6 +134,7 @@ impl ScalarUDFImpl for RegexpLikeFunc {
             _ => Boolean,
         })
     }
+
     fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
         let len = args
             .iter()
@@ -135,7 +150,7 @@ impl ScalarUDFImpl for RegexpLikeFunc {
             .map(|arg| arg.clone().into_array(inferred_length))
             .collect::<Result<Vec<_>>>()?;
 
-        let result = regexp_like_func(&args);
+        let result = regexp_like(&args);
         if is_scalar {
             // If all inputs are scalar, keeps output as scalar
             let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
@@ -149,15 +164,7 @@ impl ScalarUDFImpl for RegexpLikeFunc {
         Some(get_regexp_like_doc())
     }
 }
-fn regexp_like_func(args: &[ArrayRef]) -> Result<ArrayRef> {
-    match args[0].data_type() {
-        DataType::Utf8 => regexp_like::<i32>(args),
-        DataType::LargeUtf8 => regexp_like::<i64>(args),
-        other => {
-            internal_err!("Unsupported data type {other:?} for function regexp_like")
-        }
-    }
-}
+
 /// Tests a string using a regular expression returning true if at
 /// least one match, false otherwise.
 ///
@@ -200,47 +207,119 @@ fn regexp_like_func(args: &[ArrayRef]) -> Result<ArrayRef> {
 /// # Ok(())
 /// # }
 /// ```
-pub fn regexp_like<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
+pub fn regexp_like(args: &[ArrayRef]) -> Result<ArrayRef> {
     match args.len() {
-        2 => {
-            let values = as_generic_string_array::<T>(&args[0])?;
-            let regex = as_generic_string_array::<T>(&args[1])?;
-            let flags: Option<&GenericStringArray<T>> = None;
-            let array = regexp::regexp_is_match(values, regex, flags)
-                .map_err(|e| arrow_datafusion_err!(e))?;
-
-            Ok(Arc::new(array) as ArrayRef)
-        }
+        2 => handle_regexp_like(&args[0], &args[1], None),
         3 => {
-            let values = as_generic_string_array::<T>(&args[0])?;
-            let regex = as_generic_string_array::<T>(&args[1])?;
-            let flags = as_generic_string_array::<T>(&args[2])?;
+            let flags = args[2].as_string::<i32>();
 
             if flags.iter().any(|s| s == Some("g")) {
                 return plan_err!("regexp_like() does not support the \"global\" option");
             }
 
-            let array = regexp::regexp_is_match(values, regex, Some(flags))
-                .map_err(|e| arrow_datafusion_err!(e))?;
-
-            Ok(Arc::new(array) as ArrayRef)
-        }
+            handle_regexp_like(&args[0], &args[1], Some(flags))
+        },
         other => exec_err!(
-            "regexp_like was called with {other} arguments. It requires at least 2 and at most 3."
+            "`regexp_like` was called with {other} arguments. It requires at least 2 and at most 3."
         ),
     }
 }
+
+/// Evaluates `regexp_is_match` for `values` against `patterns`, downcasting
+/// each array according to its data type (`Utf8`, `LargeUtf8` or `Utf8View`,
+/// in any combination) and forwarding the optional `flags` array.
+///
+/// Returns an internal error for any other pair of data types.
+fn handle_regexp_like(
+    values: &ArrayRef,
+    patterns: &ArrayRef,
+    flags: Option<&GenericStringArray<i32>>,
+) -> Result<ArrayRef> {
+    let array = match (values.data_type(), patterns.data_type()) {
+        (Utf8View, Utf8) => {
+            let value = values.as_string_view();
+            let pattern = patterns.as_string::<i32>();
+
+            regexp::regexp_is_match(value, pattern, flags)
+                .map_err(|e| arrow_datafusion_err!(e))?
+        }
+        (Utf8View, Utf8View) => {
+            let value = values.as_string_view();
+            let pattern = patterns.as_string_view();
+
+            regexp::regexp_is_match(value, pattern, flags)
+                .map_err(|e| arrow_datafusion_err!(e))?
+        }
+        (Utf8View, LargeUtf8) => {
+            let value = values.as_string_view();
+            let pattern = patterns.as_string::<i64>();
+
+            regexp::regexp_is_match(value, pattern, flags)
+                .map_err(|e| arrow_datafusion_err!(e))?
+        }
+        (Utf8, Utf8) => {
+            let value = values.as_string::<i32>();
+            let pattern = patterns.as_string::<i32>();
+
+            regexp::regexp_is_match(value, pattern, flags)
+                .map_err(|e| arrow_datafusion_err!(e))?
+        }
+        (Utf8, Utf8View) => {
+            let value = values.as_string::<i32>();
+            let pattern = patterns.as_string_view();
+
+            regexp::regexp_is_match(value, pattern, flags)
+                .map_err(|e| arrow_datafusion_err!(e))?
+        }
+        (Utf8, LargeUtf8) => {
+            let value = values.as_string::<i32>();
+            let pattern = patterns.as_string::<i64>();
+
+            regexp::regexp_is_match(value, pattern, flags)
+                .map_err(|e| arrow_datafusion_err!(e))?
+        }
+        (LargeUtf8, Utf8) => {
+            let value = values.as_string::<i64>();
+            let pattern = patterns.as_string::<i32>();
+
+            regexp::regexp_is_match(value, pattern, flags)
+                .map_err(|e| arrow_datafusion_err!(e))?
+        }
+        (LargeUtf8, Utf8View) => {
+            let value = values.as_string::<i64>();
+            let pattern = patterns.as_string_view();
+
+            regexp::regexp_is_match(value, pattern, flags)
+                .map_err(|e| arrow_datafusion_err!(e))?
+        }
+        (LargeUtf8, LargeUtf8) => {
+            let value = values.as_string::<i64>();
+            let pattern = patterns.as_string::<i64>();
+
+            regexp::regexp_is_match(value, pattern, flags)
+                .map_err(|e| arrow_datafusion_err!(e))?
+        }
+        other => {
+            return internal_err!(
+                "Unsupported data type {other:?} for function `regexp_like`"
+            )
+        }
+    };
+
+    Ok(Arc::new(array) as ArrayRef)
+}
+
 #[cfg(test)]
 mod tests {
     use std::sync::Arc;
 
-    use arrow::array::BooleanBuilder;
     use arrow::array::StringArray;
+    use arrow::array::{BooleanBuilder, StringViewArray};
 
     use crate::regex::regexplike::regexp_like;
 
     #[test]
-    fn test_case_sensitive_regexp_like() {
+    fn test_case_sensitive_regexp_like_utf8() {
         let values = StringArray::from(vec!["abc"; 5]);
 
         let patterns =
@@ -254,13 +333,33 @@ mod tests {
         expected_builder.append_value(false);
         let expected = expected_builder.finish();
 
-        let re = regexp_like::<i32>(&[Arc::new(values), Arc::new(patterns)]).unwrap();
+        let re = regexp_like(&[Arc::new(values), Arc::new(patterns)]).unwrap();
+
+        assert_eq!(re.as_ref(), &expected);
+    }
+
+    #[test]
+    fn test_case_sensitive_regexp_like_utf8view() {
+        let values = StringViewArray::from(vec!["abc"; 5]);
+
+        let patterns =
+            StringArray::from(vec!["^(a)", "^(A)", "(b|d)", "(B|D)", "^(b|c)"]);
+
+        let mut expected_builder: BooleanBuilder = BooleanBuilder::new();
+        expected_builder.append_value(true);
+        expected_builder.append_value(false);
+        expected_builder.append_value(true);
+        expected_builder.append_value(false);
+        expected_builder.append_value(false);
+        let expected = expected_builder.finish();
+
+        let re = regexp_like(&[Arc::new(values), Arc::new(patterns)]).unwrap();
 
         assert_eq!(re.as_ref(), &expected);
     }
 
     #[test]
-    fn test_case_insensitive_regexp_like() {
+    fn test_case_insensitive_regexp_like_utf8() {
         let values = StringArray::from(vec!["abc"; 5]);
         let patterns =
             StringArray::from(vec!["^(a)", "^(A)", "(b|d)", "(B|D)", "^(b|c)"]);
@@ -274,9 +373,29 @@ mod tests {
         expected_builder.append_value(false);
         let expected = expected_builder.finish();
 
-        let re =
-            regexp_like::<i32>(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)])
-                .unwrap();
+        let re = regexp_like(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)])
+            .unwrap();
+
+        assert_eq!(re.as_ref(), &expected);
+    }
+
+    #[test]
+    fn test_case_insensitive_regexp_like_utf8view() {
+        let values = StringViewArray::from(vec!["abc"; 5]);
+        let patterns =
+            StringViewArray::from(vec!["^(a)", "^(A)", "(b|d)", "(B|D)", "^(b|c)"]);
+        let flags = StringArray::from(vec!["i"; 5]);
+
+        let mut expected_builder: BooleanBuilder = BooleanBuilder::new();
+        expected_builder.append_value(true);
+        expected_builder.append_value(true);
+        expected_builder.append_value(true);
+        expected_builder.append_value(true);
+        expected_builder.append_value(false);
+        let expected = expected_builder.finish();
+
+        let re = regexp_like(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)])
+            .unwrap();
 
         assert_eq!(re.as_ref(), &expected);
     }
@@ -288,7 +407,7 @@ mod tests {
         let flags = StringArray::from(vec!["g"]);
 
         let re_err =
-            regexp_like::<i32>(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)])
+            regexp_like(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)])
                 .expect_err("unsupported flag should have failed");
 
         assert_eq!(