-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix the schema mismatch between logical and physical for aggregate function, add AggregateUDFImpl::is_null
#11989
Changes from 10 commits
aed01f0
cbfefc6
b3fc2c8
20d0a5f
1132686
611092e
e732adc
ab38a5a
1d299eb
19a1ac7
984ced7
9b75540
6361bc4
794ce12
cb63514
9c12566
a42654c
e45d1bb
83ce363
3519e75
da30827
356faa8
043c332
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -80,7 +80,7 @@ use datafusion_expr::expr_rewriter::unnormalize_cols; | |
use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; | ||
use datafusion_expr::{ | ||
DescribeTable, DmlStatement, Extension, Filter, RecursiveQuery, StringifiedPlan, | ||
WindowFrame, WindowFrameBound, WriteOp, | ||
WindowFrame, WindowFrameBound, WindowFunctionDefinition, WriteOp, | ||
}; | ||
use datafusion_physical_expr::expressions::Literal; | ||
use datafusion_physical_expr::LexOrdering; | ||
|
@@ -670,6 +670,12 @@ impl DefaultPhysicalPlanner { | |
let input_exec = children.one()?; | ||
let physical_input_schema = input_exec.schema(); | ||
let logical_input_schema = input.as_ref().schema(); | ||
let physical_input_schema_from_logical: Arc<Schema> = | ||
logical_input_schema.as_ref().clone().into(); | ||
|
||
if physical_input_schema != physical_input_schema_from_logical { | ||
return internal_err!("Physical input schema should be the same as the one converted from logical input schema."); | ||
} | ||
|
||
let groups = self.create_grouping_physical_expr( | ||
group_expr, | ||
|
@@ -1503,6 +1509,11 @@ pub fn create_window_expr_with_name( | |
); | ||
} | ||
|
||
let is_nullable = match fun { | ||
WindowFunctionDefinition::AggregateUDF(udaf) => udaf.is_nullable(), | ||
_ => true, | ||
}; | ||
|
||
let window_frame = Arc::new(window_frame.clone()); | ||
let ignore_nulls = null_treatment.unwrap_or(NullTreatment::RespectNulls) | ||
== NullTreatment::IgnoreNulls; | ||
|
@@ -1515,6 +1526,7 @@ pub fn create_window_expr_with_name( | |
window_frame, | ||
physical_schema, | ||
ignore_nulls, | ||
is_nullable, | ||
) | ||
} | ||
other => plan_err!("Invalid window expression '{other:?}'"), | ||
|
@@ -1548,7 +1560,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( | |
e: &Expr, | ||
name: Option<String>, | ||
logical_input_schema: &DFSchema, | ||
_physical_input_schema: &Schema, | ||
physical_input_schema: &Schema, | ||
execution_props: &ExecutionProps, | ||
) -> Result<AggregateExprWithOptionalArgs> { | ||
match e { | ||
|
@@ -1599,14 +1611,14 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( | |
let ordering_reqs: Vec<PhysicalSortExpr> = | ||
physical_sort_exprs.clone().unwrap_or(vec![]); | ||
|
||
let schema: Schema = logical_input_schema.clone().into(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. workaround cleanup |
||
let agg_expr = | ||
AggregateExprBuilder::new(func.to_owned(), physical_args.to_vec()) | ||
.order_by(ordering_reqs.to_vec()) | ||
.schema(Arc::new(schema)) | ||
.schema(Arc::new(physical_input_schema.to_owned())) | ||
.alias(name) | ||
.with_ignore_nulls(ignore_nulls) | ||
.with_distinct(*distinct) | ||
.with_nullable(func.is_nullable()) | ||
.build()?; | ||
|
||
(agg_expr, filter, physical_sort_exprs) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -320,18 +320,29 @@ impl ExprSchemable for Expr { | |
} | ||
} | ||
Expr::Cast(Cast { expr, .. }) => expr.nullable(input_schema), | ||
Expr::ScalarFunction(ScalarFunction { func, args }) => { | ||
Ok(func.is_nullable(args, input_schema)) | ||
} | ||
Expr::AggregateFunction(AggregateFunction { func, .. }) => { | ||
// TODO: UDF should be able to customize nullability | ||
if func.name() == "count" { | ||
Ok(false) | ||
} else { | ||
Ok(true) | ||
} | ||
Ok(func.is_nullable()) | ||
} | ||
Expr::WindowFunction(WindowFunction { fun, .. }) => match fun { | ||
WindowFunctionDefinition::BuiltInWindowFunction(func) => { | ||
if func.name() == "ROW_NUMBER" | ||
|| func.name() == "RANK" | ||
|| func.name() == "NTILE" | ||
|| func.name() == "CUME_DIST" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am not sure if this list is complete. What about There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Temporary code, there would be no name checking after #8709 is done. We can see that |
||
{ | ||
Ok(false) | ||
} else { | ||
Ok(true) | ||
} | ||
} | ||
WindowFunctionDefinition::AggregateUDF(func) => Ok(func.is_nullable()), | ||
_ => Ok(true), | ||
}, | ||
Expr::ScalarVariable(_, _) | ||
| Expr::TryCast { .. } | ||
| Expr::ScalarFunction(..) | ||
| Expr::WindowFunction { .. } | ||
| Expr::Unnest(_) | ||
| Expr::Placeholder(_) => Ok(true), | ||
Expr::IsNull(_) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,7 +25,7 @@ use std::vec; | |
|
||
use arrow::datatypes::{DataType, Field}; | ||
|
||
use datafusion_common::{exec_err, not_impl_err, Result}; | ||
use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue}; | ||
|
||
use crate::expr::AggregateFunction; | ||
use crate::function::{ | ||
|
@@ -163,6 +163,10 @@ impl AggregateUDF { | |
self.inner.name() | ||
} | ||
|
||
pub fn is_nullable(&self) -> bool { | ||
self.inner.is_nullable() | ||
} | ||
|
||
/// Returns the aliases for this function. | ||
pub fn aliases(&self) -> &[String] { | ||
self.inner.aliases() | ||
|
@@ -257,6 +261,11 @@ impl AggregateUDF { | |
pub fn is_descending(&self) -> Option<bool> { | ||
self.inner.is_descending() | ||
} | ||
|
||
/// See [`AggregateUDFImpl::default_value`] for more details. | ||
pub fn default_value(&self, data_type: &DataType) -> Result<ScalarValue> { | ||
self.inner.default_value(data_type) | ||
} | ||
} | ||
|
||
impl<F> From<F> for AggregateUDF | ||
|
@@ -342,6 +351,11 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { | |
/// the arguments | ||
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType>; | ||
|
||
/// Whether the aggregate function is nullable | ||
jayzhan211 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
fn is_nullable(&self) -> bool { | ||
true | ||
} | ||
|
||
/// Return a new [`Accumulator`] that aggregates values for a specific | ||
/// group during query execution. | ||
/// | ||
|
@@ -552,6 +566,13 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { | |
fn is_descending(&self) -> Option<bool> { | ||
None | ||
} | ||
|
||
/// Returns default value of the function given the input is Null | ||
jayzhan211 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
/// Most aggregate functions return Null if the input is Null, | ||
/// while `count` returns 0 if the input is Null. | ||
fn default_value(&self, data_type: &DataType) -> Result<ScalarValue> { | ||
ScalarValue::try_from(data_type) | ||
} | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If an There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In this case, it should at least failed when creating record batch. Since the schema indicates non-null but got null value. If you are saying prevent There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe we can improve the docuemention? |
||
|
||
pub enum ReversedUDAF { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,9 +19,8 @@ | |
//! (built-in and custom) need to satisfy. | ||
|
||
use crate::order::AggregateOrderSensitivity; | ||
use arrow::datatypes::Field; | ||
use datafusion_common::exec_err; | ||
use datafusion_common::{not_impl_err, Result}; | ||
use arrow::datatypes::{DataType, Field}; | ||
use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue}; | ||
use datafusion_expr_common::accumulator::Accumulator; | ||
use datafusion_expr_common::groups_accumulator::GroupsAccumulator; | ||
use datafusion_physical_expr_common::physical_expr::PhysicalExpr; | ||
|
@@ -171,6 +170,11 @@ pub trait AggregateExpr: Send + Sync + Debug + PartialEq<dyn Any> { | |
fn get_minmax_desc(&self) -> Option<(Field, bool)> { | ||
None | ||
} | ||
|
||
/// Returns the default value of the function when the input is Null. | ||
/// Most aggregate functions return Null if the input is Null, | ||
/// while `count` returns 0 if the input is Null. | ||
fn default_value(&self, data_type: &DataType) -> Result<ScalarValue>; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👍🏻 |
||
} | ||
|
||
/// Stores the physical expressions used inside the `AggregateExpr`. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -121,6 +121,10 @@ impl AggregateUDFImpl for Count { | |
Ok(DataType::Int64) | ||
} | ||
|
||
fn is_nullable(&self) -> bool { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👍 |
||
false | ||
} | ||
|
||
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> { | ||
if args.is_distinct { | ||
Ok(vec![Field::new_list( | ||
|
@@ -133,7 +137,7 @@ impl AggregateUDFImpl for Count { | |
Ok(vec![Field::new( | ||
format_state_name(args.name, "count"), | ||
DataType::Int64, | ||
true, | ||
false, | ||
)]) | ||
} | ||
} | ||
|
@@ -283,6 +287,10 @@ impl AggregateUDFImpl for Count { | |
fn reverse_expr(&self) -> ReversedUDAF { | ||
ReversedUDAF::Identical | ||
} | ||
|
||
fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> { | ||
Ok(ScalarValue::Int64(Some(0))) | ||
} | ||
} | ||
|
||
#[derive(Debug)] | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🎉