diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index c92660c7bbf4..f5a6860299ab 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -39,13 +39,20 @@ use datafusion_expr::{ and, expr, lit, or, BinaryExpr, BuiltinScalarFunction, Case, ColumnarValue, Expr, Like, Volatility, }; -use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps}; +use datafusion_physical_expr::{ + create_physical_expr, execution_props::ExecutionProps, intervals::NullableInterval, +}; use crate::simplify_expressions::SimplifyInfo; +use crate::simplify_expressions::guarantees::GuaranteeRewriter; + /// This structure handles API for expression simplification pub struct ExprSimplifier { info: S, + /// Guarantees about the values of columns. This is provided by the user + /// in [ExprSimplifier::with_guarantees()]. + guarantees: Vec<(Expr, NullableInterval)>, } pub const THRESHOLD_INLINE_INLIST: usize = 3; @@ -57,7 +64,10 @@ impl ExprSimplifier { /// /// [`SimplifyContext`]: crate::simplify_expressions::context::SimplifyContext pub fn new(info: S) -> Self { - Self { info } + Self { + info, + guarantees: vec![], + } } /// Simplifies this [`Expr`]`s as much as possible, evaluating @@ -121,6 +131,7 @@ impl ExprSimplifier { let mut simplifier = Simplifier::new(&self.info); let mut const_evaluator = ConstEvaluator::try_new(self.info.execution_props())?; let mut or_in_list_simplifier = OrInListSimplifier::new(); + let mut guarantee_rewriter = GuaranteeRewriter::new(&self.guarantees); // TODO iterate until no changes are made during rewrite // (evaluating constants can enable new simplifications and @@ -129,6 +140,7 @@ impl ExprSimplifier { expr.rewrite(&mut const_evaluator)? .rewrite(&mut simplifier)? .rewrite(&mut or_in_list_simplifier)? + .rewrite(&mut guarantee_rewriter)? 
// run both passes twice to try an minimize simplifications that we missed .rewrite(&mut const_evaluator)? .rewrite(&mut simplifier) @@ -149,6 +161,65 @@ impl ExprSimplifier { expr.rewrite(&mut expr_rewrite) } + + /// Input guarantees about the values of columns. + /// + /// The guarantees can simplify expressions. For example, if a column `x` is + /// guaranteed to be `3`, then the expression `x > 1` can be replaced by the + /// literal `true`. + /// + /// The guarantees are provided as a `Vec<(Expr, NullableInterval)>`, + /// where the [Expr] is a column reference and the [NullableInterval] + /// is an interval representing the known possible values of that column. + /// + /// ```rust + /// use arrow::datatypes::{DataType, Field, Schema}; + /// use datafusion_expr::{col, lit, Expr}; + /// use datafusion_common::{Result, ScalarValue, ToDFSchema}; + /// use datafusion_physical_expr::execution_props::ExecutionProps; + /// use datafusion_physical_expr::intervals::{Interval, NullableInterval}; + /// use datafusion_optimizer::simplify_expressions::{ + /// ExprSimplifier, SimplifyContext}; + /// + /// let schema = Schema::new(vec![ + /// Field::new("x", DataType::Int64, false), + /// Field::new("y", DataType::UInt32, false), + /// Field::new("z", DataType::Int64, false), + /// ]) + /// .to_dfschema_ref().unwrap(); + /// + /// // Create the simplifier + /// let props = ExecutionProps::new(); + /// let context = SimplifyContext::new(&props) + /// .with_schema(schema); + /// + /// // Expression: (x >= 3) AND (y + 2 < 10) AND (z > 5) + /// let expr_x = col("x").gt_eq(lit(3_i64)); + /// let expr_y = (col("y") + lit(2_u32)).lt(lit(10_u32)); + /// let expr_z = col("z").gt(lit(5_i64)); + /// let expr = expr_x.and(expr_y).and(expr_z.clone()); + /// + /// let guarantees = vec![ + /// // x ∈ [3, 5] + /// ( + /// col("x"), + /// NullableInterval::NotNull { + /// values: Interval::make(Some(3_i64), Some(5_i64), (false, false)), + /// } + /// ), + /// // y = 3 + /// (col("y"), 
NullableInterval::from(ScalarValue::UInt32(Some(3)))), + /// ]; + /// let simplifier = ExprSimplifier::new(context).with_guarantees(guarantees); + /// let output = simplifier.simplify(expr).unwrap(); + /// // Expression becomes: true AND true AND (z > 5), which simplifies to + /// // z > 5. + /// assert_eq!(output, expr_z); + /// ``` + pub fn with_guarantees(mut self, guarantees: Vec<(Expr, NullableInterval)>) -> Self { + self.guarantees = guarantees; + self + } } #[allow(rustdoc::private_intra_doc_links)] @@ -1239,7 +1310,9 @@ mod tests { use datafusion_common::{assert_contains, cast::as_int32_array, DFField, ToDFSchema}; use datafusion_expr::*; use datafusion_physical_expr::{ - execution_props::ExecutionProps, functions::make_scalar_function, + execution_props::ExecutionProps, + functions::make_scalar_function, + intervals::{Interval, NullableInterval}, }; // ------------------------------ @@ -2703,6 +2776,19 @@ mod tests { try_simplify(expr).unwrap() } + fn simplify_with_guarantee( + expr: Expr, + guarantees: Vec<(Expr, NullableInterval)>, + ) -> Expr { + let schema = expr_test_schema(); + let execution_props = ExecutionProps::new(); + let simplifier = ExprSimplifier::new( + SimplifyContext::new(&execution_props).with_schema(schema), + ) + .with_guarantees(guarantees); + simplifier.simplify(expr).unwrap() + } + fn expr_test_schema() -> DFSchemaRef { Arc::new( DFSchema::new_with_metadata( @@ -3166,4 +3252,89 @@ mod tests { let expr = not_ilike(null, "%"); assert_eq!(simplify(expr), lit_bool_null()); } + + #[test] + fn test_simplify_with_guarantee() { + // (c3 > 3) AND (c4 + 2 < 10 OR (c1 NOT IN ("a", "b"))) + let expr_x = col("c3").gt(lit(3_i64)); + let expr_y = (col("c4") + lit(2_u32)).lt(lit(10_u32)); + let expr_z = col("c1").in_list(vec![lit("a"), lit("b")], true); + let expr = expr_x.clone().and(expr_y.clone().or(expr_z)); + + // All guaranteed null + let guarantees = vec![ + (col("c3"), NullableInterval::from(ScalarValue::Int64(None))), + (col("c4"), 
NullableInterval::from(ScalarValue::UInt32(None))), + (col("c1"), NullableInterval::from(ScalarValue::Utf8(None))), + ]; + + let output = simplify_with_guarantee(expr.clone(), guarantees); + assert_eq!(output, lit_bool_null()); + + // All guaranteed false + let guarantees = vec![ + ( + col("c3"), + NullableInterval::NotNull { + values: Interval::make(Some(0_i64), Some(2_i64), (false, false)), + }, + ), + ( + col("c4"), + NullableInterval::from(ScalarValue::UInt32(Some(9))), + ), + ( + col("c1"), + NullableInterval::from(ScalarValue::Utf8(Some("a".to_string()))), + ), + ]; + let output = simplify_with_guarantee(expr.clone(), guarantees); + assert_eq!(output, lit(false)); + + // Guaranteed false or null -> no change. + let guarantees = vec![ + ( + col("c3"), + NullableInterval::MaybeNull { + values: Interval::make(Some(0_i64), Some(2_i64), (false, false)), + }, + ), + ( + col("c4"), + NullableInterval::MaybeNull { + values: Interval::make(Some(9_u32), Some(9_u32), (false, false)), + }, + ), + ( + col("c1"), + NullableInterval::NotNull { + values: Interval::make(Some("d"), Some("f"), (false, false)), + }, + ), + ]; + let output = simplify_with_guarantee(expr.clone(), guarantees); + assert_eq!(&output, &expr_x); + + // Sufficient true guarantees + let guarantees = vec![ + ( + col("c3"), + NullableInterval::from(ScalarValue::Int64(Some(9))), + ), + ( + col("c4"), + NullableInterval::from(ScalarValue::UInt32(Some(3))), + ), + ]; + let output = simplify_with_guarantee(expr.clone(), guarantees); + assert_eq!(output, lit(true)); + + // Only partially simplify + let guarantees = vec![( + col("c4"), + NullableInterval::from(ScalarValue::UInt32(Some(3))), + )]; + let output = simplify_with_guarantee(expr.clone(), guarantees); + assert_eq!(&output, &expr_x); + } } diff --git a/datafusion/optimizer/src/simplify_expressions/guarantees.rs b/datafusion/optimizer/src/simplify_expressions/guarantees.rs new file mode 100644 index 000000000000..5504d7d76e35 --- /dev/null +++ 
b/datafusion/optimizer/src/simplify_expressions/guarantees.rs @@ -0,0 +1,520 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Simplifier implementation for [`ExprSimplifier::with_guarantees()`] +//! +//! [`ExprSimplifier::with_guarantees()`]: crate::simplify_expressions::expr_simplifier::ExprSimplifier::with_guarantees +use datafusion_common::{tree_node::TreeNodeRewriter, DataFusionError, Result}; +use datafusion_expr::{expr::InList, lit, Between, BinaryExpr, Expr}; +use std::collections::HashMap; + +use datafusion_physical_expr::intervals::{Interval, IntervalBound, NullableInterval}; + +/// Rewrite expressions to incorporate guarantees. +/// +/// Guarantees are a mapping from an expression (which currently is always a +/// column reference) to a [NullableInterval]. The interval represents the known +/// possible values of the column. Using these known values, expressions are +/// rewritten so they can be simplified using `ConstEvaluator` and `Simplifier`. +/// +/// For example, if we know that a column is not null and has values in the +/// range [1, 10), we can rewrite `x IS NULL` to `false` or `x < 10` to `true`. +/// +/// See a full example in [`ExprSimplifier::with_guarantees()`]. 
+/// +/// [`ExprSimplifier::with_guarantees()`]: crate::simplify_expressions::expr_simplifier::ExprSimplifier::with_guarantees +pub(crate) struct GuaranteeRewriter<'a> { + guarantees: HashMap<&'a Expr, &'a NullableInterval>, +} + +impl<'a> GuaranteeRewriter<'a> { + pub fn new( + guarantees: impl IntoIterator, + ) -> Self { + Self { + guarantees: guarantees.into_iter().map(|(k, v)| (k, v)).collect(), + } + } +} + +impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { + type N = Expr; + + fn mutate(&mut self, expr: Expr) -> Result { + if self.guarantees.is_empty() { + return Ok(expr); + } + + match &expr { + Expr::IsNull(inner) => match self.guarantees.get(inner.as_ref()) { + Some(NullableInterval::Null { .. }) => Ok(lit(true)), + Some(NullableInterval::NotNull { .. }) => Ok(lit(false)), + _ => Ok(expr), + }, + Expr::IsNotNull(inner) => match self.guarantees.get(inner.as_ref()) { + Some(NullableInterval::Null { .. }) => Ok(lit(false)), + Some(NullableInterval::NotNull { .. }) => Ok(lit(true)), + _ => Ok(expr), + }, + Expr::Between(Between { + expr: inner, + negated, + low, + high, + }) => { + if let (Some(interval), Expr::Literal(low), Expr::Literal(high)) = ( + self.guarantees.get(inner.as_ref()), + low.as_ref(), + high.as_ref(), + ) { + let expr_interval = NullableInterval::NotNull { + values: Interval::new( + IntervalBound::new(low.clone(), false), + IntervalBound::new(high.clone(), false), + ), + }; + + let contains = expr_interval.contains(*interval)?; + + if contains.is_certainly_true() { + Ok(lit(!negated)) + } else if contains.is_certainly_false() { + Ok(lit(*negated)) + } else { + Ok(expr) + } + } else { + Ok(expr) + } + } + + Expr::BinaryExpr(BinaryExpr { left, op, right }) => { + // We only support comparisons for now + if !op.is_comparison_operator() { + return Ok(expr); + }; + + // Check if this is a comparison between a column and literal + let (col, op, value) = match (left.as_ref(), right.as_ref()) { + (Expr::Column(_), Expr::Literal(value)) => (left, 
*op, value), + (Expr::Literal(value), Expr::Column(_)) => { + // If we can swap the op, we can simplify the expression + if let Some(op) = op.swap() { + (right, op, value) + } else { + return Ok(expr); + } + } + _ => return Ok(expr), + }; + + if let Some(col_interval) = self.guarantees.get(col.as_ref()) { + let result = + col_interval.apply_operator(&op, &value.clone().into())?; + if result.is_certainly_true() { + Ok(lit(true)) + } else if result.is_certainly_false() { + Ok(lit(false)) + } else { + Ok(expr) + } + } else { + Ok(expr) + } + } + + // Columns (if interval is collapsed to a single value) + Expr::Column(_) => { + if let Some(col_interval) = self.guarantees.get(&expr) { + if let Some(value) = col_interval.single_value() { + Ok(lit(value)) + } else { + Ok(expr) + } + } else { + Ok(expr) + } + } + + Expr::InList(InList { + expr: inner, + list, + negated, + }) => { + if let Some(interval) = self.guarantees.get(inner.as_ref()) { + // Can remove items from the list that don't match the guarantee + let new_list: Vec = list + .iter() + .filter_map(|expr| { + if let Expr::Literal(item) = expr { + match interval + .contains(&NullableInterval::from(item.clone())) + { + // If we know for certain the value isn't in the column's interval, + // we can skip checking it. 
+ Ok(interval) if interval.is_certainly_false() => None, + Ok(_) => Some(Ok(expr.clone())), + Err(e) => Some(Err(e)), + } + } else { + Some(Ok(expr.clone())) + } + }) + .collect::>()?; + + Ok(Expr::InList(InList { + expr: inner.clone(), + list: new_list, + negated: *negated, + })) + } else { + Ok(expr) + } + } + + _ => Ok(expr), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use arrow::datatypes::DataType; + use datafusion_common::{tree_node::TreeNode, ScalarValue}; + use datafusion_expr::{col, lit, Operator}; + + #[test] + fn test_null_handling() { + // IsNull / IsNotNull can be rewritten to true / false + let guarantees = vec![ + // Note: AlwaysNull case handled by test_column_single_value test, + // since it's a special case of a column with a single value. + ( + col("x"), + NullableInterval::NotNull { + values: Default::default(), + }, + ), + ]; + let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); + + // x IS NULL => guaranteed false + let expr = col("x").is_null(); + let output = expr.clone().rewrite(&mut rewriter).unwrap(); + assert_eq!(output, lit(false)); + + // x IS NOT NULL => guaranteed true + let expr = col("x").is_not_null(); + let output = expr.clone().rewrite(&mut rewriter).unwrap(); + assert_eq!(output, lit(true)); + } + + fn validate_simplified_cases(rewriter: &mut GuaranteeRewriter, cases: &[(Expr, T)]) + where + ScalarValue: From, + T: Clone, + { + for (expr, expected_value) in cases { + let output = expr.clone().rewrite(rewriter).unwrap(); + let expected = lit(ScalarValue::from(expected_value.clone())); + assert_eq!( + output, expected, + "{} simplified to {}, but expected {}", + expr, output, expected + ); + } + } + + fn validate_unchanged_cases(rewriter: &mut GuaranteeRewriter, cases: &[Expr]) { + for expr in cases { + let output = expr.clone().rewrite(rewriter).unwrap(); + assert_eq!( + &output, expr, + "{} was simplified to {}, but expected it to be unchanged", + expr, output + ); + } + } + + #[test] + fn 
test_inequalities_non_null_bounded() { + let guarantees = vec![ + // x ∈ (1, 3] (not null) + ( + col("x"), + NullableInterval::NotNull { + values: Interval::make(Some(1_i32), Some(3_i32), (true, false)), + }, + ), + ]; + + let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); + + // (original_expr, expected_simplification) + let simplified_cases = &[ + (col("x").lt_eq(lit(1)), false), + (col("x").lt_eq(lit(3)), true), + (col("x").gt(lit(3)), false), + (col("x").gt(lit(1)), true), + (col("x").eq(lit(0)), false), + (col("x").not_eq(lit(0)), true), + (col("x").between(lit(2), lit(5)), true), + (col("x").between(lit(2), lit(3)), true), + (col("x").between(lit(5), lit(10)), false), + (col("x").not_between(lit(2), lit(5)), false), + (col("x").not_between(lit(2), lit(3)), false), + (col("x").not_between(lit(5), lit(10)), true), + ( + Expr::BinaryExpr(BinaryExpr { + left: Box::new(col("x")), + op: Operator::IsDistinctFrom, + right: Box::new(lit(ScalarValue::Null)), + }), + true, + ), + ( + Expr::BinaryExpr(BinaryExpr { + left: Box::new(col("x")), + op: Operator::IsDistinctFrom, + right: Box::new(lit(5)), + }), + true, + ), + ]; + + validate_simplified_cases(&mut rewriter, simplified_cases); + + let unchanged_cases = &[ + col("x").gt(lit(2)), + col("x").lt_eq(lit(2)), + col("x").eq(lit(2)), + col("x").not_eq(lit(2)), + col("x").between(lit(3), lit(5)), + col("x").not_between(lit(3), lit(10)), + ]; + + validate_unchanged_cases(&mut rewriter, unchanged_cases); + } + + #[test] + fn test_inequalities_non_null_unbounded() { + let guarantees = vec![ + // x ∈ [2021-01-01, ∞) (not null) + ( + col("x"), + NullableInterval::NotNull { + values: Interval::new( + IntervalBound::new(ScalarValue::Date32(Some(18628)), false), + IntervalBound::make_unbounded(DataType::Date32).unwrap(), + ), + }, + ), + ]; + let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); + + // (original_expr, expected_simplification) + let simplified_cases = &[ + 
(col("x").lt(lit(ScalarValue::Date32(Some(18628)))), false), + (col("x").lt_eq(lit(ScalarValue::Date32(Some(17000)))), false), + (col("x").gt(lit(ScalarValue::Date32(Some(18627)))), true), + (col("x").gt_eq(lit(ScalarValue::Date32(Some(18628)))), true), + (col("x").eq(lit(ScalarValue::Date32(Some(17000)))), false), + (col("x").not_eq(lit(ScalarValue::Date32(Some(17000)))), true), + ( + col("x").between( + lit(ScalarValue::Date32(Some(16000))), + lit(ScalarValue::Date32(Some(17000))), + ), + false, + ), + ( + col("x").not_between( + lit(ScalarValue::Date32(Some(16000))), + lit(ScalarValue::Date32(Some(17000))), + ), + true, + ), + ( + Expr::BinaryExpr(BinaryExpr { + left: Box::new(col("x")), + op: Operator::IsDistinctFrom, + right: Box::new(lit(ScalarValue::Null)), + }), + true, + ), + ( + Expr::BinaryExpr(BinaryExpr { + left: Box::new(col("x")), + op: Operator::IsDistinctFrom, + right: Box::new(lit(ScalarValue::Date32(Some(17000)))), + }), + true, + ), + ]; + + validate_simplified_cases(&mut rewriter, simplified_cases); + + let unchanged_cases = &[ + col("x").lt(lit(ScalarValue::Date32(Some(19000)))), + col("x").lt_eq(lit(ScalarValue::Date32(Some(19000)))), + col("x").gt(lit(ScalarValue::Date32(Some(19000)))), + col("x").gt_eq(lit(ScalarValue::Date32(Some(19000)))), + col("x").eq(lit(ScalarValue::Date32(Some(19000)))), + col("x").not_eq(lit(ScalarValue::Date32(Some(19000)))), + col("x").between( + lit(ScalarValue::Date32(Some(18000))), + lit(ScalarValue::Date32(Some(19000))), + ), + col("x").not_between( + lit(ScalarValue::Date32(Some(18000))), + lit(ScalarValue::Date32(Some(19000))), + ), + ]; + + validate_unchanged_cases(&mut rewriter, unchanged_cases); + } + + #[test] + fn test_inequalities_maybe_null() { + let guarantees = vec![ + // x ∈ ("abc", "def"]? 
(maybe null) + ( + col("x"), + NullableInterval::MaybeNull { + values: Interval::make(Some("abc"), Some("def"), (true, false)), + }, + ), + ]; + let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); + + // (original_expr, expected_simplification) + let simplified_cases = &[ + ( + Expr::BinaryExpr(BinaryExpr { + left: Box::new(col("x")), + op: Operator::IsDistinctFrom, + right: Box::new(lit("z")), + }), + true, + ), + ( + Expr::BinaryExpr(BinaryExpr { + left: Box::new(col("x")), + op: Operator::IsNotDistinctFrom, + right: Box::new(lit("z")), + }), + false, + ), + ]; + + validate_simplified_cases(&mut rewriter, simplified_cases); + + let unchanged_cases = &[ + col("x").lt(lit("z")), + col("x").lt_eq(lit("z")), + col("x").gt(lit("a")), + col("x").gt_eq(lit("a")), + col("x").eq(lit("abc")), + col("x").not_eq(lit("a")), + col("x").between(lit("a"), lit("z")), + col("x").not_between(lit("a"), lit("z")), + Expr::BinaryExpr(BinaryExpr { + left: Box::new(col("x")), + op: Operator::IsDistinctFrom, + right: Box::new(lit(ScalarValue::Null)), + }), + ]; + + validate_unchanged_cases(&mut rewriter, unchanged_cases); + } + + #[test] + fn test_column_single_value() { + let scalars = [ + ScalarValue::Null, + ScalarValue::Int32(Some(1)), + ScalarValue::Boolean(Some(true)), + ScalarValue::Boolean(None), + ScalarValue::Utf8(Some("abc".to_string())), + ScalarValue::LargeUtf8(Some("def".to_string())), + ScalarValue::Date32(Some(18628)), + ScalarValue::Date32(None), + ScalarValue::Decimal128(Some(1000), 19, 2), + ]; + + for scalar in scalars { + let guarantees = vec![(col("x"), NullableInterval::from(scalar.clone()))]; + let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); + + let output = col("x").rewrite(&mut rewriter).unwrap(); + assert_eq!(output, Expr::Literal(scalar.clone())); + } + } + + #[test] + fn test_in_list() { + let guarantees = vec![ + // x ∈ [1, 10) (not null) + ( + col("x"), + NullableInterval::NotNull { + values: Interval::make(Some(1_i32), Some(10_i32), 
(false, true)), + }, + ), + ]; + let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); + + // These cases should be simplified so the list doesn't contain any + // values the guarantee says are outside the range. + // (column_name, starting_list, negated, expected_list) + let cases = &[ + // x IN (9, 11) => x IN (9) + ("x", vec![9, 11], false, vec![9]), + // x IN (10, 2) => x IN (2) + ("x", vec![10, 2], false, vec![2]), + // x NOT IN (9, 11) => x NOT IN (9) + ("x", vec![9, 11], true, vec![9]), + // x NOT IN (0, 22) => x NOT IN () + ("x", vec![0, 22], true, vec![]), + ]; + + for (column_name, starting_list, negated, expected_list) in cases { + let expr = col(*column_name).in_list( + starting_list + .iter() + .map(|v| lit(ScalarValue::Int32(Some(*v)))) + .collect(), + *negated, + ); + let output = expr.clone().rewrite(&mut rewriter).unwrap(); + let expected_list = expected_list + .iter() + .map(|v| lit(ScalarValue::Int32(Some(*v)))) + .collect(); + assert_eq!( + output, + Expr::InList(InList { + expr: Box::new(col(*column_name)), + list: expected_list, + negated: *negated, + }) + ); + } + } +} diff --git a/datafusion/optimizer/src/simplify_expressions/mod.rs b/datafusion/optimizer/src/simplify_expressions/mod.rs index dfa0fe70433b..2cf6ed166cdd 100644 --- a/datafusion/optimizer/src/simplify_expressions/mod.rs +++ b/datafusion/optimizer/src/simplify_expressions/mod.rs @@ -17,6 +17,7 @@ pub mod context; pub mod expr_simplifier; +mod guarantees; mod or_in_list_simplifier; mod regex; pub mod simplify_exprs; diff --git a/datafusion/physical-expr/src/intervals/interval_aritmetic.rs b/datafusion/physical-expr/src/intervals/interval_aritmetic.rs index 3f72ef588cb2..5501c8cae090 100644 --- a/datafusion/physical-expr/src/intervals/interval_aritmetic.rs +++ b/datafusion/physical-expr/src/intervals/interval_aritmetic.rs @@ -396,6 +396,22 @@ impl Interval { } } + /// Compute the logical negation of this (boolean) interval. 
+ pub(crate) fn not(&self) -> Result { + if !matches!(self.get_datatype()?, DataType::Boolean) { + return internal_err!( + "Cannot apply logical negation to non-boolean interval" + ); + } + if self == &Interval::CERTAINLY_TRUE { + Ok(Interval::CERTAINLY_FALSE) + } else if self == &Interval::CERTAINLY_FALSE { + Ok(Interval::CERTAINLY_TRUE) + } else { + Ok(Interval::UNCERTAIN) + } + } + /// Compute the intersection of the interval with the given interval. /// If the intersection is empty, return None. pub(crate) fn intersect>( @@ -426,6 +442,23 @@ impl Interval { Ok(non_empty.then_some(Interval::new(lower, upper))) } + /// Decide if this interval certainly contains, possibly contains, + /// or can't contain `other` by returning [true, true], + /// [false, true] or [false, false] respectively. + pub fn contains>(&self, other: T) -> Result { + match self.intersect(other.borrow())? { + Some(intersection) => { + // Need to compare with same bounds close-ness. + if intersection.close_bounds() == other.borrow().clone().close_bounds() { + Ok(Interval::CERTAINLY_TRUE) + } else { + Ok(Interval::UNCERTAIN) + } + } + None => Ok(Interval::CERTAINLY_FALSE), + } + } + /// Add the given interval (`other`) to this interval. Say we have /// intervals [a1, b1] and [a2, b2], then their sum is [a1 + a2, b1 + b2]. /// Note that this represents all possible values the sum can take if @@ -633,6 +666,7 @@ pub fn cardinality_ratio( pub fn apply_operator(op: &Operator, lhs: &Interval, rhs: &Interval) -> Result { match *op { Operator::Eq => Ok(lhs.equal(rhs)), + Operator::NotEq => Ok(lhs.equal(rhs).not()?), Operator::Gt => Ok(lhs.gt(rhs)), Operator::GtEq => Ok(lhs.gt_eq(rhs)), Operator::Lt => Ok(lhs.lt(rhs)), @@ -667,6 +701,283 @@ fn calculate_cardinality_based_on_bounds( } } +/// An [Interval] that also tracks null status using a boolean interval. +/// +/// This represents values that may be in a particular range or be null. 
+/// +/// # Examples +/// +/// ``` +/// use arrow::datatypes::DataType; +/// use datafusion_physical_expr::intervals::{Interval, NullableInterval}; +/// use datafusion_common::ScalarValue; +/// +/// // [1, 2) U {NULL} +/// NullableInterval::MaybeNull { +/// values: Interval::make(Some(1), Some(2), (false, true)), +/// }; +/// +/// // (0, ∞) +/// NullableInterval::NotNull { +/// values: Interval::make(Some(0), None, (true, true)), +/// }; +/// +/// // {NULL} +/// NullableInterval::Null { datatype: DataType::Int32 }; +/// +/// // {4} +/// NullableInterval::from(ScalarValue::Int32(Some(4))); +/// ``` +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum NullableInterval { + /// The value is always null in this interval + /// + /// This is typed so it can be used in physical expressions, which don't do + /// type coercion. + Null { datatype: DataType }, + /// The value may or may not be null in this interval. If it is non null its value is within + /// the specified values interval + MaybeNull { values: Interval }, + /// The value is definitely not null in this interval and is within values + NotNull { values: Interval }, +} + +impl Default for NullableInterval { + fn default() -> Self { + NullableInterval::MaybeNull { + values: Interval::default(), + } + } +} + +impl Display for NullableInterval { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + Self::Null { .. } => write!(f, "NullableInterval: {{NULL}}"), + Self::MaybeNull { values } => { + write!(f, "NullableInterval: {} U {{NULL}}", values) + } + Self::NotNull { values } => write!(f, "NullableInterval: {}", values), + } + } +} + +impl From for NullableInterval { + /// Create an interval that represents a single value. 
+ fn from(value: ScalarValue) -> Self { + if value.is_null() { + Self::Null { + datatype: value.data_type(), + } + } else { + Self::NotNull { + values: Interval::new( + IntervalBound::new(value.clone(), false), + IntervalBound::new(value, false), + ), + } + } + } +} + +impl NullableInterval { + /// Get the values interval, or None if this interval is definitely null. + pub fn values(&self) -> Option<&Interval> { + match self { + Self::Null { .. } => None, + Self::MaybeNull { values } | Self::NotNull { values } => Some(values), + } + } + + /// Get the data type + pub fn get_datatype(&self) -> Result { + match self { + Self::Null { datatype } => Ok(datatype.clone()), + Self::MaybeNull { values } | Self::NotNull { values } => { + values.get_datatype() + } + } + } + + /// Return true if the value is definitely true (and not null). + pub fn is_certainly_true(&self) -> bool { + match self { + Self::Null { .. } | Self::MaybeNull { .. } => false, + Self::NotNull { values } => values == &Interval::CERTAINLY_TRUE, + } + } + + /// Return true if the value is definitely false (and not null). + pub fn is_certainly_false(&self) -> bool { + match self { + Self::Null { .. } => false, + Self::MaybeNull { .. } => false, + Self::NotNull { values } => values == &Interval::CERTAINLY_FALSE, + } + } + + /// Perform logical negation on a boolean nullable interval. + fn not(&self) -> Result { + match self { + Self::Null { datatype } => Ok(Self::Null { + datatype: datatype.clone(), + }), + Self::MaybeNull { values } => Ok(Self::MaybeNull { + values: values.not()?, + }), + Self::NotNull { values } => Ok(Self::NotNull { + values: values.not()?, + }), + } + } + + /// Apply the given operator to this interval and the given interval. 
+ /// + /// # Examples + /// + /// ``` + /// use datafusion_common::ScalarValue; + /// use datafusion_expr::Operator; + /// use datafusion_physical_expr::intervals::{Interval, NullableInterval}; + /// + /// // 4 > 3 -> true + /// let lhs = NullableInterval::from(ScalarValue::Int32(Some(4))); + /// let rhs = NullableInterval::from(ScalarValue::Int32(Some(3))); + /// let result = lhs.apply_operator(&Operator::Gt, &rhs).unwrap(); + /// assert_eq!(result, NullableInterval::from(ScalarValue::Boolean(Some(true)))); + /// + /// // [1, 3) > NULL -> NULL + /// let lhs = NullableInterval::NotNull { + /// values: Interval::make(Some(1), Some(3), (false, true)), + /// }; + /// let rhs = NullableInterval::from(ScalarValue::Int32(None)); + /// let result = lhs.apply_operator(&Operator::Gt, &rhs).unwrap(); + /// assert_eq!(result.single_value(), Some(ScalarValue::Boolean(None))); + /// + /// // [1, 3] > [2, 4] -> [false, true] + /// let lhs = NullableInterval::NotNull { + /// values: Interval::make(Some(1), Some(3), (false, false)), + /// }; + /// let rhs = NullableInterval::NotNull { + /// values: Interval::make(Some(2), Some(4), (false, false)), + /// }; + /// let result = lhs.apply_operator(&Operator::Gt, &rhs).unwrap(); + /// // Both inputs are valid (non-null), so result must be non-null + /// assert_eq!(result, NullableInterval::NotNull { + /// // Uncertain whether inequality is true or false + /// values: Interval::UNCERTAIN, + /// }); + /// + /// ``` + pub fn apply_operator(&self, op: &Operator, rhs: &Self) -> Result { + match op { + Operator::IsDistinctFrom => { + let values = match (self, rhs) { + // NULL is distinct from NULL -> False + (Self::Null { .. }, Self::Null { .. }) => Interval::CERTAINLY_FALSE, + // x is distinct from y -> x != y, + // if at least one of them is never null. + (Self::NotNull { .. }, _) | (_, Self::NotNull { .. 
}) => { + let lhs_values = self.values(); + let rhs_values = rhs.values(); + match (lhs_values, rhs_values) { + (Some(lhs_values), Some(rhs_values)) => { + lhs_values.equal(rhs_values).not()? + } + (Some(_), None) | (None, Some(_)) => Interval::CERTAINLY_TRUE, + (None, None) => unreachable!("Null case handled above"), + } + } + _ => Interval::UNCERTAIN, + }; + // IsDistinctFrom never returns null. + Ok(Self::NotNull { values }) + } + Operator::IsNotDistinctFrom => self + .apply_operator(&Operator::IsDistinctFrom, rhs) + .map(|i| i.not())?, + _ => { + if let (Some(left_values), Some(right_values)) = + (self.values(), rhs.values()) + { + let values = apply_operator(op, left_values, right_values)?; + match (self, rhs) { + (Self::NotNull { .. }, Self::NotNull { .. }) => { + Ok(Self::NotNull { values }) + } + _ => Ok(Self::MaybeNull { values }), + } + } else if op.is_comparison_operator() { + Ok(Self::Null { + datatype: DataType::Boolean, + }) + } else { + Ok(Self::Null { + datatype: self.get_datatype()?, + }) + } + } + } + } + + /// Determine if this interval contains the given interval. Returns a boolean + /// interval that is [true, true] if this interval is a superset of the + /// given interval, [false, false] if this interval is disjoint from the + /// given interval, and [false, true] otherwise. + pub fn contains>(&self, other: T) -> Result { + let rhs = other.borrow(); + if let (Some(left_values), Some(right_values)) = (self.values(), rhs.values()) { + let values = left_values.contains(right_values)?; + match (self, rhs) { + (Self::NotNull { .. }, Self::NotNull { .. }) => { + Ok(Self::NotNull { values }) + } + _ => Ok(Self::MaybeNull { values }), + } + } else { + Ok(Self::Null { + datatype: DataType::Boolean, + }) + } + } + + /// If the interval has collapsed to a single value, return that value. + /// + /// Otherwise returns None. 
+ /// + /// # Examples + /// + /// ``` + /// use datafusion_common::ScalarValue; + /// use datafusion_physical_expr::intervals::{Interval, NullableInterval}; + /// + /// let interval = NullableInterval::from(ScalarValue::Int32(Some(4))); + /// assert_eq!(interval.single_value(), Some(ScalarValue::Int32(Some(4)))); + /// + /// let interval = NullableInterval::from(ScalarValue::Int32(None)); + /// assert_eq!(interval.single_value(), Some(ScalarValue::Int32(None))); + /// + /// let interval = NullableInterval::MaybeNull { + /// values: Interval::make(Some(1), Some(4), (false, true)), + /// }; + /// assert_eq!(interval.single_value(), None); + /// ``` + pub fn single_value(&self) -> Option { + match self { + Self::Null { datatype } => { + Some(ScalarValue::try_from(datatype).unwrap_or(ScalarValue::Null)) + } + Self::MaybeNull { values } | Self::NotNull { values } + if values.lower.value == values.upper.value + && !values.lower.is_unbounded() => + { + Some(values.lower.value.clone()) + } + _ => None, + } + } +} + #[cfg(test)] mod tests { use super::next_value;