Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the schema mismatch between logical and physical for aggregate function, add AggregateUDFImpl::is_null #11989

Merged
merged 23 commits into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions datafusion/core/src/physical_optimizer/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ pub fn bounded_window_exec(
Arc::new(WindowFrame::new(Some(false))),
schema.as_ref(),
false,
true,
)
.unwrap()],
input.clone(),
Expand Down
20 changes: 16 additions & 4 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ use datafusion_expr::expr_rewriter::unnormalize_cols;
use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary;
use datafusion_expr::{
DescribeTable, DmlStatement, Extension, Filter, RecursiveQuery, StringifiedPlan,
WindowFrame, WindowFrameBound, WriteOp,
WindowFrame, WindowFrameBound, WindowFunctionDefinition, WriteOp,
};
use datafusion_physical_expr::expressions::Literal;
use datafusion_physical_expr::LexOrdering;
Expand Down Expand Up @@ -670,6 +670,12 @@ impl DefaultPhysicalPlanner {
let input_exec = children.one()?;
let physical_input_schema = input_exec.schema();
let logical_input_schema = input.as_ref().schema();
let physical_input_schema_from_logical: Arc<Schema> =
logical_input_schema.as_ref().clone().into();

if physical_input_schema != physical_input_schema_from_logical {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎉

return internal_err!("Physical input schema should be the same as the one converted from logical input schema.");
}

let groups = self.create_grouping_physical_expr(
group_expr,
Expand Down Expand Up @@ -1503,6 +1509,11 @@ pub fn create_window_expr_with_name(
);
}

let is_nullable = match fun {
WindowFunctionDefinition::AggregateUDF(udaf) => udaf.is_nullable(),
_ => true,
};

let window_frame = Arc::new(window_frame.clone());
let ignore_nulls = null_treatment.unwrap_or(NullTreatment::RespectNulls)
== NullTreatment::IgnoreNulls;
Expand All @@ -1515,6 +1526,7 @@ pub fn create_window_expr_with_name(
window_frame,
physical_schema,
ignore_nulls,
is_nullable,
)
}
other => plan_err!("Invalid window expression '{other:?}'"),
Expand Down Expand Up @@ -1548,7 +1560,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
e: &Expr,
name: Option<String>,
logical_input_schema: &DFSchema,
_physical_input_schema: &Schema,
physical_input_schema: &Schema,
execution_props: &ExecutionProps,
) -> Result<AggregateExprWithOptionalArgs> {
match e {
Expand Down Expand Up @@ -1599,14 +1611,14 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
let ordering_reqs: Vec<PhysicalSortExpr> =
physical_sort_exprs.clone().unwrap_or(vec![]);

let schema: Schema = logical_input_schema.clone().into();
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

workaround cleanup

let agg_expr =
AggregateExprBuilder::new(func.to_owned(), physical_args.to_vec())
.order_by(ordering_reqs.to_vec())
.schema(Arc::new(schema))
.schema(Arc::new(physical_input_schema.to_owned()))
.alias(name)
.with_ignore_nulls(ignore_nulls)
.with_distinct(*distinct)
.with_nullable(func.is_nullable())
.build()?;

(agg_expr, filter, physical_sort_exprs)
Expand Down
3 changes: 3 additions & 0 deletions datafusion/core/tests/fuzz_cases/window_fuzz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> {
Arc::new(window_frame),
&extended_schema,
false,
true,
)?;
let running_window_exec = Arc::new(BoundedWindowAggExec::try_new(
vec![window_expr],
Expand Down Expand Up @@ -677,6 +678,7 @@ async fn run_window_test(
Arc::new(window_frame.clone()),
&extended_schema,
false,
true,
)?],
exec1,
vec![],
Expand All @@ -695,6 +697,7 @@ async fn run_window_test(
Arc::new(window_frame.clone()),
&extended_schema,
false,
true,
)?],
exec2,
vec![],
Expand Down
27 changes: 19 additions & 8 deletions datafusion/expr/src/expr_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -320,18 +320,29 @@ impl ExprSchemable for Expr {
}
}
Expr::Cast(Cast { expr, .. }) => expr.nullable(input_schema),
Expr::ScalarFunction(ScalarFunction { func, args }) => {
Ok(func.is_nullable(args, input_schema))
}
Expr::AggregateFunction(AggregateFunction { func, .. }) => {
// TODO: UDF should be able to customize nullability
if func.name() == "count" {
Ok(false)
} else {
Ok(true)
}
Ok(func.is_nullable())
}
Expr::WindowFunction(WindowFunction { fun, .. }) => match fun {
WindowFunctionDefinition::BuiltInWindowFunction(func) => {
if func.name() == "ROW_NUMBER"
|| func.name() == "RANK"
|| func.name() == "NTILE"
|| func.name() == "CUME_DIST"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure if this list is complete. What about DenseRank and PercentRank?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Temporary code, there would be no name checking after #8709 is done. We can see that Row_Number is gone

{
Ok(false)
} else {
Ok(true)
}
}
WindowFunctionDefinition::AggregateUDF(func) => Ok(func.is_nullable()),
_ => Ok(true),
},
Expr::ScalarVariable(_, _)
| Expr::TryCast { .. }
| Expr::ScalarFunction(..)
| Expr::WindowFunction { .. }
| Expr::Unnest(_)
| Expr::Placeholder(_) => Ok(true),
Expr::IsNull(_)
Expand Down
12 changes: 7 additions & 5 deletions datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2015,10 +2015,9 @@ impl Projection {
/// produced by the projection operation. If the schema computation is successful,
/// the `Result` will contain the schema; otherwise, it will contain an error.
pub fn projection_schema(input: &LogicalPlan, exprs: &[Expr]) -> Result<Arc<DFSchema>> {
let mut schema = DFSchema::new_with_metadata(
exprlist_to_fields(exprs, input)?,
input.schema().metadata().clone(),
)?;
let metadata = input.schema().metadata().clone();
let mut schema =
DFSchema::new_with_metadata(exprlist_to_fields(exprs, input)?, metadata)?;
schema = schema.with_functional_dependencies(calc_func_dependencies_for_project(
exprs, input,
)?)?;
Expand Down Expand Up @@ -2659,7 +2658,10 @@ impl Aggregate {

qualified_fields.extend(exprlist_to_fields(aggr_expr.as_slice(), &input)?);

let schema = DFSchema::new_with_metadata(qualified_fields, HashMap::new())?;
let schema = DFSchema::new_with_metadata(
qualified_fields,
input.schema().metadata().clone(),
)?;

Self::try_new_with_schema(input, group_expr, aggr_expr, Arc::new(schema))
}
Expand Down
23 changes: 22 additions & 1 deletion datafusion/expr/src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use std::vec;

use arrow::datatypes::{DataType, Field};

use datafusion_common::{exec_err, not_impl_err, Result};
use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue};

use crate::expr::AggregateFunction;
use crate::function::{
Expand Down Expand Up @@ -163,6 +163,10 @@ impl AggregateUDF {
self.inner.name()
}

pub fn is_nullable(&self) -> bool {
self.inner.is_nullable()
}

/// Returns the aliases for this function.
pub fn aliases(&self) -> &[String] {
self.inner.aliases()
Expand Down Expand Up @@ -257,6 +261,11 @@ impl AggregateUDF {
pub fn is_descending(&self) -> Option<bool> {
self.inner.is_descending()
}

/// See [`AggregateUDFImpl::default_value`] for more details.
pub fn default_value(&self, data_type: &DataType) -> Result<ScalarValue> {
self.inner.default_value(data_type)
}
}

impl<F> From<F> for AggregateUDF
Expand Down Expand Up @@ -342,6 +351,11 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
/// the arguments
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType>;

/// Whether the aggregate function is nullable
jayzhan211 marked this conversation as resolved.
Show resolved Hide resolved
fn is_nullable(&self) -> bool {
true
}

/// Return a new [`Accumulator`] that aggregates values for a specific
/// group during query execution.
///
Expand Down Expand Up @@ -552,6 +566,13 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
fn is_descending(&self) -> Option<bool> {
None
}

/// Returns the default value of the function when the input is Null
jayzhan211 marked this conversation as resolved.
Show resolved Hide resolved
/// Most aggregate functions return Null if the input is Null,
/// while `count` returns 0 if the input is Null
fn default_value(&self, data_type: &DataType) -> Result<ScalarValue> {
ScalarValue::try_from(data_type)
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If an AggregateUDFImpl overrides is_nullable() to return false but does not set the default_value(), it seems that it would indicate it is not nullable, yet its default_value() would return null. Is there a way to prevent this behavior? 🤔

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this case, it should at least fail when creating the record batch, since the schema indicates non-null but a null value was produced.

If you are saying prevent panic, it is a good question, I don't have answer

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can improve the documentation?


pub enum ReversedUDAF {
Expand Down
8 changes: 8 additions & 0 deletions datafusion/expr/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,10 @@ impl ScalarUDF {
self.inner.invoke(args)
}

pub fn is_nullable(&self, args: &[Expr], schema: &dyn ExprSchema) -> bool {
self.inner.is_nullable(args, schema)
}

/// Invoke the function without `args` but number of rows, returning the appropriate result.
///
/// See [`ScalarUDFImpl::invoke_no_args`] for more details.
Expand Down Expand Up @@ -416,6 +420,10 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
self.return_type(arg_types)
}

fn is_nullable(&self, _args: &[Expr], _schema: &dyn ExprSchema) -> bool {
true
}

/// Invoke the function on `args`, returning the appropriate result
///
/// The function will be invoked passed with the slice of [`ColumnarValue`]
Expand Down
10 changes: 7 additions & 3 deletions datafusion/functions-aggregate-common/src/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,8 @@
//! (built-in and custom) need to satisfy.

use crate::order::AggregateOrderSensitivity;
use arrow::datatypes::Field;
use datafusion_common::exec_err;
use datafusion_common::{not_impl_err, Result};
use arrow::datatypes::{DataType, Field};
use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue};
use datafusion_expr_common::accumulator::Accumulator;
use datafusion_expr_common::groups_accumulator::GroupsAccumulator;
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
Expand Down Expand Up @@ -171,6 +170,11 @@ pub trait AggregateExpr: Send + Sync + Debug + PartialEq<dyn Any> {
fn get_minmax_desc(&self) -> Option<(Field, bool)> {
None
}

/// Returns the default value of the function when the input is Null
/// Most aggregate functions return Null if the input is Null,
/// while `count` returns 0 if the input is Null
fn default_value(&self, data_type: &DataType) -> Result<ScalarValue>;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍🏻

}

/// Stores the physical expressions used inside the `AggregateExpr`.
Expand Down
10 changes: 9 additions & 1 deletion datafusion/functions-aggregate/src/count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,10 @@ impl AggregateUDFImpl for Count {
Ok(DataType::Int64)
}

fn is_nullable(&self) -> bool {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

false
}

fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
if args.is_distinct {
Ok(vec![Field::new_list(
Expand All @@ -133,7 +137,7 @@ impl AggregateUDFImpl for Count {
Ok(vec![Field::new(
format_state_name(args.name, "count"),
DataType::Int64,
true,
false,
)])
}
}
Expand Down Expand Up @@ -283,6 +287,10 @@ impl AggregateUDFImpl for Count {
fn reverse_expr(&self) -> ReversedUDAF {
ReversedUDAF::Identical
}

fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {
Ok(ScalarValue::Int64(Some(0)))
}
}

#[derive(Debug)]
Expand Down
9 changes: 7 additions & 2 deletions datafusion/functions/src/core/coalesce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ use arrow::compute::kernels::zip::zip;
use arrow::compute::{and, is_not_null, is_null};
use arrow::datatypes::DataType;

use datafusion_common::{exec_err, Result};
use datafusion_common::{exec_err, ExprSchema, Result};
use datafusion_expr::type_coercion::binary::type_union_resolution;
use datafusion_expr::ColumnarValue;
use datafusion_expr::{ColumnarValue, Expr, ExprSchemable};
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};

#[derive(Debug)]
Expand Down Expand Up @@ -63,6 +63,11 @@ impl ScalarUDFImpl for CoalesceFunc {
Ok(arg_types[0].clone())
}

// If all the elements in coalesce are non-null, the result is non-null
fn is_nullable(&self, args: &[Expr], schema: &dyn ExprSchema) -> bool {
args.iter().any(|e| e.nullable(schema).ok().unwrap_or(true))
}

/// coalesce evaluates to the first value which is not NULL
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
// do not accept 0 arguments.
Expand Down
2 changes: 1 addition & 1 deletion datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ mod tests {
.build()?;

let expected = "Projection: count(Int64(1)) AS count(*) [count(*):Int64]\
\n WindowAggr: windowExpr=[[count(Int64(1)) ORDER BY [test.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING AS count(*) ORDER BY [test.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING]] [a:UInt32, b:UInt32, c:UInt32, count(*) ORDER BY [test.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING:Int64;N]\
\n WindowAggr: windowExpr=[[count(Int64(1)) ORDER BY [test.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING AS count(*) ORDER BY [test.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING]] [a:UInt32, b:UInt32, c:UInt32, count(*) ORDER BY [test.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING:Int64]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
assert_plan_eq(plan, expected)
}
Expand Down
Loading
Loading