-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support GroupsAccumulator accumulator for udaf #8892
Changes from 4 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -16,29 +16,35 @@ | |||||||||||||||||
// under the License. | ||||||||||||||||||
|
||||||||||||||||||
use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility}; | ||||||||||||||||||
use datafusion_physical_expr::NullState; | ||||||||||||||||||
use std::{any::Any, sync::Arc}; | ||||||||||||||||||
|
||||||||||||||||||
use arrow::{ | ||||||||||||||||||
array::{ArrayRef, Float32Array}, | ||||||||||||||||||
array::{ | ||||||||||||||||||
ArrayRef, AsArray, Float32Array, PrimitiveArray, PrimitiveBuilder, UInt64Array, | ||||||||||||||||||
}, | ||||||||||||||||||
datatypes::{ArrowNativeTypeOp, ArrowPrimitiveType, Float64Type, UInt64Type}, | ||||||||||||||||||
record_batch::RecordBatch, | ||||||||||||||||||
}; | ||||||||||||||||||
use datafusion::error::Result; | ||||||||||||||||||
use datafusion::prelude::*; | ||||||||||||||||||
use datafusion_common::{cast::as_float64_array, ScalarValue}; | ||||||||||||||||||
use datafusion_expr::{Accumulator, AggregateUDF, AggregateUDFImpl, Signature}; | ||||||||||||||||||
use datafusion_expr::{ | ||||||||||||||||||
Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature, | ||||||||||||||||||
}; | ||||||||||||||||||
|
||||||||||||||||||
/// This example shows how to use the full AggregateUDFImpl API to implement a user | ||||||||||||||||||
/// defined aggregate function. As in the `simple_udaf.rs` example, this struct implements | ||||||||||||||||||
/// a function `accumulator` that returns the `Accumulator` instance. | ||||||||||||||||||
/// | ||||||||||||||||||
/// To do so, we must implement the `AggregateUDFImpl` trait. | ||||||||||||||||||
#[derive(Debug, Clone)] | ||||||||||||||||||
struct GeoMeanUdf { | ||||||||||||||||||
struct GeoMeanUdaf { | ||||||||||||||||||
signature: Signature, | ||||||||||||||||||
} | ||||||||||||||||||
|
||||||||||||||||||
impl GeoMeanUdf { | ||||||||||||||||||
/// Create a new instance of the GeoMeanUdf struct | ||||||||||||||||||
impl GeoMeanUdaf { | ||||||||||||||||||
/// Create a new instance of the GeoMeanUdaf struct | ||||||||||||||||||
fn new() -> Self { | ||||||||||||||||||
Self { | ||||||||||||||||||
signature: Signature::exact( | ||||||||||||||||||
|
@@ -52,7 +58,7 @@ impl GeoMeanUdf { | |||||||||||||||||
} | ||||||||||||||||||
} | ||||||||||||||||||
|
||||||||||||||||||
impl AggregateUDFImpl for GeoMeanUdf { | ||||||||||||||||||
impl AggregateUDFImpl for GeoMeanUdaf { | ||||||||||||||||||
/// We implement as_any so that we can downcast the AggregateUDFImpl trait object | ||||||||||||||||||
fn as_any(&self) -> &dyn Any { | ||||||||||||||||||
self | ||||||||||||||||||
|
@@ -82,6 +88,16 @@ impl AggregateUDFImpl for GeoMeanUdf { | |||||||||||||||||
fn state_type(&self, _return_type: &DataType) -> Result<Vec<DataType>> { | ||||||||||||||||||
Ok(vec![DataType::Float64, DataType::UInt32]) | ||||||||||||||||||
} | ||||||||||||||||||
|
||||||||||||||||||
fn groups_accumulator_supported(&self) -> bool { | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it would be good to add some context annotating this function for the example:
Suggested change
|
||||||||||||||||||
true | ||||||||||||||||||
} | ||||||||||||||||||
|
||||||||||||||||||
fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> { | ||||||||||||||||||
Ok(Box::new(GeometricMeanGroupsAccumulator::new( | ||||||||||||||||||
|pord: f64, count: u64| Ok(pord.powf(1.0 / count as f64)), | ||||||||||||||||||
))) | ||||||||||||||||||
} | ||||||||||||||||||
} | ||||||||||||||||||
|
||||||||||||||||||
/// A UDAF has state across multiple rows, and thus we require a `struct` with that state. | ||||||||||||||||||
|
@@ -194,12 +210,196 @@ fn create_context() -> Result<SessionContext> { | |||||||||||||||||
Ok(ctx) | ||||||||||||||||||
} | ||||||||||||||||||
|
||||||||||||||||||
struct GeometricMeanGroupsAccumulator<F> | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||
where | ||||||||||||||||||
F: Fn( | ||||||||||||||||||
<Float64Type as ArrowPrimitiveType>::Native, | ||||||||||||||||||
u64, | ||||||||||||||||||
) -> Result<<Float64Type as ArrowPrimitiveType>::Native> | ||||||||||||||||||
+ Send, | ||||||||||||||||||
{ | ||||||||||||||||||
/// The type of the internal sum | ||||||||||||||||||
prod_data_type: DataType, | ||||||||||||||||||
|
||||||||||||||||||
/// The type of the returned sum | ||||||||||||||||||
return_data_type: DataType, | ||||||||||||||||||
|
||||||||||||||||||
/// Count per group (use u64 to make UInt64Array) | ||||||||||||||||||
counts: Vec<u64>, | ||||||||||||||||||
|
||||||||||||||||||
/// product per group, stored as the native type | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||
prods: Vec<f64>, | ||||||||||||||||||
|
||||||||||||||||||
/// Track nulls in the input / filters | ||||||||||||||||||
null_state: NullState, | ||||||||||||||||||
|
||||||||||||||||||
/// Function that computes the final geometric mean (value / count) | ||||||||||||||||||
geo_mean_fn: F, | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think the example would be simpler if you removed the generics and simply inlined the definition of |
||||||||||||||||||
} | ||||||||||||||||||
|
||||||||||||||||||
impl<F> GeometricMeanGroupsAccumulator<F> | ||||||||||||||||||
where | ||||||||||||||||||
F: Fn( | ||||||||||||||||||
<Float64Type as ArrowPrimitiveType>::Native, | ||||||||||||||||||
u64, | ||||||||||||||||||
) -> Result<<Float64Type as ArrowPrimitiveType>::Native> | ||||||||||||||||||
+ Send, | ||||||||||||||||||
{ | ||||||||||||||||||
fn new(geo_mean_fn: F) -> Self { | ||||||||||||||||||
Self { | ||||||||||||||||||
prod_data_type: DataType::Float64, | ||||||||||||||||||
return_data_type: DataType::Float64, | ||||||||||||||||||
counts: vec![], | ||||||||||||||||||
prods: vec![], | ||||||||||||||||||
null_state: NullState::new(), | ||||||||||||||||||
geo_mean_fn, | ||||||||||||||||||
} | ||||||||||||||||||
} | ||||||||||||||||||
} | ||||||||||||||||||
|
||||||||||||||||||
impl<F> GroupsAccumulator for GeometricMeanGroupsAccumulator<F> | ||||||||||||||||||
where | ||||||||||||||||||
F: Fn( | ||||||||||||||||||
<Float64Type as ArrowPrimitiveType>::Native, | ||||||||||||||||||
u64, | ||||||||||||||||||
) -> Result<<Float64Type as ArrowPrimitiveType>::Native> | ||||||||||||||||||
+ Send, | ||||||||||||||||||
{ | ||||||||||||||||||
fn update_batch( | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||
&mut self, | ||||||||||||||||||
values: &[ArrayRef], | ||||||||||||||||||
group_indices: &[usize], | ||||||||||||||||||
opt_filter: Option<&arrow::array::BooleanArray>, | ||||||||||||||||||
total_num_groups: usize, | ||||||||||||||||||
) -> Result<()> { | ||||||||||||||||||
assert_eq!(values.len(), 1, "single argument to update_batch"); | ||||||||||||||||||
let values = values[0].as_primitive::<Float64Type>(); | ||||||||||||||||||
|
||||||||||||||||||
// increment counts, update sums | ||||||||||||||||||
self.counts.resize(total_num_groups, 0); | ||||||||||||||||||
self.prods | ||||||||||||||||||
.resize(total_num_groups, Float64Type::default_value()); | ||||||||||||||||||
self.null_state.accumulate( | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||
group_indices, | ||||||||||||||||||
values, | ||||||||||||||||||
opt_filter, | ||||||||||||||||||
total_num_groups, | ||||||||||||||||||
|group_index, new_value| { | ||||||||||||||||||
let prod = &mut self.prods[group_index]; | ||||||||||||||||||
*prod = prod.mul_wrapping(new_value); | ||||||||||||||||||
|
||||||||||||||||||
self.counts[group_index] += 1; | ||||||||||||||||||
}, | ||||||||||||||||||
); | ||||||||||||||||||
|
||||||||||||||||||
Ok(()) | ||||||||||||||||||
} | ||||||||||||||||||
|
||||||||||||||||||
fn merge_batch( | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||
&mut self, | ||||||||||||||||||
values: &[ArrayRef], | ||||||||||||||||||
group_indices: &[usize], | ||||||||||||||||||
opt_filter: Option<&arrow::array::BooleanArray>, | ||||||||||||||||||
total_num_groups: usize, | ||||||||||||||||||
) -> Result<()> { | ||||||||||||||||||
assert_eq!(values.len(), 2, "two arguments to merge_batch"); | ||||||||||||||||||
// first batch is counts, second is partial sums | ||||||||||||||||||
let partial_counts = values[0].as_primitive::<UInt64Type>(); | ||||||||||||||||||
let partial_prods = values[1].as_primitive::<Float64Type>(); | ||||||||||||||||||
// update counts with partial counts | ||||||||||||||||||
self.counts.resize(total_num_groups, 0); | ||||||||||||||||||
self.null_state.accumulate( | ||||||||||||||||||
group_indices, | ||||||||||||||||||
partial_counts, | ||||||||||||||||||
opt_filter, | ||||||||||||||||||
total_num_groups, | ||||||||||||||||||
|group_index, partial_count| { | ||||||||||||||||||
self.counts[group_index] += partial_count; | ||||||||||||||||||
}, | ||||||||||||||||||
); | ||||||||||||||||||
|
||||||||||||||||||
// update prods | ||||||||||||||||||
self.prods | ||||||||||||||||||
.resize(total_num_groups, Float64Type::default_value()); | ||||||||||||||||||
self.null_state.accumulate( | ||||||||||||||||||
group_indices, | ||||||||||||||||||
partial_prods, | ||||||||||||||||||
opt_filter, | ||||||||||||||||||
total_num_groups, | ||||||||||||||||||
|group_index, new_value: <Float64Type as ArrowPrimitiveType>::Native| { | ||||||||||||||||||
let prod = &mut self.prods[group_index]; | ||||||||||||||||||
*prod = prod.mul_wrapping(new_value); | ||||||||||||||||||
}, | ||||||||||||||||||
); | ||||||||||||||||||
|
||||||||||||||||||
Ok(()) | ||||||||||||||||||
} | ||||||||||||||||||
|
||||||||||||||||||
fn evaluate(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<ArrayRef> { | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||
let counts = emit_to.take_needed(&mut self.counts); | ||||||||||||||||||
let prods = emit_to.take_needed(&mut self.prods); | ||||||||||||||||||
let nulls = self.null_state.build(emit_to); | ||||||||||||||||||
|
||||||||||||||||||
assert_eq!(nulls.len(), prods.len()); | ||||||||||||||||||
assert_eq!(counts.len(), prods.len()); | ||||||||||||||||||
|
||||||||||||||||||
// don't evaluate geometric mean with null inputs to avoid errors on null values | ||||||||||||||||||
|
||||||||||||||||||
let array: PrimitiveArray<Float64Type> = if nulls.null_count() > 0 { | ||||||||||||||||||
let mut builder = PrimitiveBuilder::<Float64Type>::with_capacity(nulls.len()); | ||||||||||||||||||
let iter = prods.into_iter().zip(counts).zip(nulls.iter()); | ||||||||||||||||||
|
||||||||||||||||||
for ((prod, count), is_valid) in iter { | ||||||||||||||||||
if is_valid { | ||||||||||||||||||
builder.append_value((self.geo_mean_fn)(prod, count)?) | ||||||||||||||||||
} else { | ||||||||||||||||||
builder.append_null(); | ||||||||||||||||||
} | ||||||||||||||||||
} | ||||||||||||||||||
builder.finish() | ||||||||||||||||||
} else { | ||||||||||||||||||
let geo_mean: Vec<<Float64Type as ArrowPrimitiveType>::Native> = prods | ||||||||||||||||||
.into_iter() | ||||||||||||||||||
.zip(counts.into_iter()) | ||||||||||||||||||
.map(|(prod, count)| (self.geo_mean_fn)(prod, count)) | ||||||||||||||||||
.collect::<Result<Vec<_>>>()?; | ||||||||||||||||||
PrimitiveArray::new(geo_mean.into(), Some(nulls)) // no copy | ||||||||||||||||||
.with_data_type(self.return_data_type.clone()) | ||||||||||||||||||
}; | ||||||||||||||||||
|
||||||||||||||||||
Ok(Arc::new(array)) | ||||||||||||||||||
} | ||||||||||||||||||
|
||||||||||||||||||
// return arrays for counts and prods | ||||||||||||||||||
fn state(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<Vec<ArrayRef>> { | ||||||||||||||||||
let nulls = self.null_state.build(emit_to); | ||||||||||||||||||
let nulls = Some(nulls); | ||||||||||||||||||
|
||||||||||||||||||
let counts = emit_to.take_needed(&mut self.counts); | ||||||||||||||||||
let counts = UInt64Array::new(counts.into(), nulls.clone()); // zero copy | ||||||||||||||||||
|
||||||||||||||||||
let prods = emit_to.take_needed(&mut self.prods); | ||||||||||||||||||
let prods = PrimitiveArray::<Float64Type>::new(prods.into(), nulls) // zero copy | ||||||||||||||||||
.with_data_type(self.prod_data_type.clone()); | ||||||||||||||||||
|
||||||||||||||||||
Ok(vec![ | ||||||||||||||||||
Arc::new(counts) as ArrayRef, | ||||||||||||||||||
Arc::new(prods) as ArrayRef, | ||||||||||||||||||
]) | ||||||||||||||||||
} | ||||||||||||||||||
|
||||||||||||||||||
fn size(&self) -> usize { | ||||||||||||||||||
self.counts.capacity() * std::mem::size_of::<u64>() | ||||||||||||||||||
+ self.prods.capacity() * std::mem::size_of::<Float64Type>() | ||||||||||||||||||
} | ||||||||||||||||||
} | ||||||||||||||||||
|
||||||||||||||||||
#[tokio::main] | ||||||||||||||||||
async fn main() -> Result<()> { | ||||||||||||||||||
let ctx = create_context()?; | ||||||||||||||||||
|
||||||||||||||||||
// create the AggregateUDF | ||||||||||||||||||
let geometric_mean = AggregateUDF::from(GeoMeanUdf::new()); | ||||||||||||||||||
let geometric_mean = AggregateUDF::from(GeoMeanUdaf::new()); | ||||||||||||||||||
ctx.register_udaf(geometric_mean.clone()); | ||||||||||||||||||
|
||||||||||||||||||
let sql_df = ctx.sql("SELECT geo_mean(a) FROM t").await?; | ||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I recommend we add a note to
accumulator()
above about when this is used. Now that I write this maybe we should also put some of this information on the docstrings forAggregateUDF::groups_accumulator