Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support GroupsAccumulator accumulator for udaf #8892

Merged
merged 5 commits into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions datafusion-examples/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ datafusion = { path = "../datafusion/core", features = ["avro"] }
datafusion-common = { path = "../datafusion/common" }
datafusion-expr = { path = "../datafusion/expr" }
datafusion-optimizer = { path = "../datafusion/optimizer" }
datafusion-physical-expr = { workspace = true }
datafusion-sql = { path = "../datafusion/sql" }
env_logger = { workspace = true }
futures = { workspace = true }
Expand Down
214 changes: 207 additions & 7 deletions datafusion-examples/examples/advanced_udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,29 +16,35 @@
// under the License.

use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
use datafusion_physical_expr::NullState;
use std::{any::Any, sync::Arc};

use arrow::{
array::{ArrayRef, Float32Array},
array::{
ArrayRef, AsArray, Float32Array, PrimitiveArray, PrimitiveBuilder, UInt64Array,
},
datatypes::{ArrowNativeTypeOp, ArrowPrimitiveType, Float64Type, UInt64Type},
record_batch::RecordBatch,
};
use datafusion::error::Result;
use datafusion::prelude::*;
use datafusion_common::{cast::as_float64_array, ScalarValue};
use datafusion_expr::{Accumulator, AggregateUDF, AggregateUDFImpl, Signature};
use datafusion_expr::{
Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature,
};

/// This example shows how to use the full AggregateUDFImpl API to implement a user
/// defined aggregate function. As in the `simple_udaf.rs` example, this struct implements
/// a function `accumulator` that returns the `Accumulator` instance.
///
/// To do so, we must implement the `AggregateUDFImpl` trait.
#[derive(Debug, Clone)]
struct GeoMeanUdf {
struct GeoMeanUdaf {
signature: Signature,
}

impl GeoMeanUdf {
/// Create a new instance of the GeoMeanUdf struct
impl GeoMeanUdaf {
/// Create a new instance of the GeoMeanUdaf struct
fn new() -> Self {
Self {
signature: Signature::exact(
Expand All @@ -52,7 +58,7 @@ impl GeoMeanUdf {
}
}

impl AggregateUDFImpl for GeoMeanUdf {
impl AggregateUDFImpl for GeoMeanUdaf {
/// We implement as_any so that we can downcast the AggregateUDFImpl trait object
fn as_any(&self) -> &dyn Any {
self
Expand Down Expand Up @@ -82,6 +88,16 @@ impl AggregateUDFImpl for GeoMeanUdf {
fn state_type(&self, _return_type: &DataType) -> Result<Vec<DataType>> {
Ok(vec![DataType::Float64, DataType::UInt32])
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I recommend we add a note to accumulator() above about when this is used. Now that I write this maybe we should also put some of this information on the docstrings for AggregateUDF::groups_accumulator

-    /// This is the accumulator factory; DataFusion uses it to create new accumulators.
+   /// This is the accumulator factory for row wise accumulation; Even when `GroupsAccumulator`
+   /// is supported, DataFusion will use this row oriented
+   /// accumulator when the aggregate function is used as a window function
+   /// or when there are only aggregates (no GROUP BY columns) in the plan.

fn groups_accumulator_supported(&self) -> bool {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be good to add some context annotating this function for the example:

Suggested change
fn groups_accumulator_supported(&self) -> bool {
/// Tell DataFusion that this aggregate supports the more performant `GroupsAccumulator`
/// which is used for cases when there are grouping columns in the query
fn groups_accumulator_supported(&self) -> bool {

true
}

fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
Ok(Box::new(GeometricMeanGroupsAccumulator::new(
|pord: f64, count: u64| Ok(pord.powf(1.0 / count as f64)),
)))
}
}

/// A UDAF has state across multiple rows, and thus we require a `struct` with that state.
Expand Down Expand Up @@ -194,12 +210,196 @@ fn create_context() -> Result<SessionContext> {
Ok(ctx)
}

struct GeometricMeanGroupsAccumulator<F>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
struct GeometricMeanGroupsAccumulator<F>
/// Define a `GroupsAccumulator` for GeometricMean
/// which handles accumulator state for multiple groups at once.
/// This API is significantly more complicated than `Accumulator`, which manages
/// the state for a single group, but for queries with a large number of groups
/// can be significantly faster. See the `GroupsAccumulator` documentation for
/// more information.
struct GeometricMeanGroupsAccumulator<F>

where
F: Fn(
<Float64Type as ArrowPrimitiveType>::Native,
u64,
) -> Result<<Float64Type as ArrowPrimitiveType>::Native>
+ Send,
{
/// The type of the internal sum
prod_data_type: DataType,

/// The type of the returned sum
return_data_type: DataType,

/// Count per group (use u64 to make UInt64Array)
counts: Vec<u64>,

/// product per group, stored as the native type
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
/// product per group, stored as the native type
/// product per group, stored as the native type (not `ScalarValue`)

prods: Vec<f64>,

/// Track nulls in the input / filters
null_state: NullState,

/// Function that computes the final geometric mean (value / count)
geo_mean_fn: F,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the example would be simpler if you removed the generics and simply inlined the definition of geo_mean_fn into the callsite in evaluate. The generics are needed for GroupsAccumulators that are specialized on type (e.g. a special one for Float32, Float64, etc).

}

impl<F> GeometricMeanGroupsAccumulator<F>
where
F: Fn(
<Float64Type as ArrowPrimitiveType>::Native,
u64,
) -> Result<<Float64Type as ArrowPrimitiveType>::Native>
+ Send,
{
fn new(geo_mean_fn: F) -> Self {
Self {
prod_data_type: DataType::Float64,
return_data_type: DataType::Float64,
counts: vec![],
prods: vec![],
null_state: NullState::new(),
geo_mean_fn,
}
}
}

impl<F> GroupsAccumulator for GeometricMeanGroupsAccumulator<F>
where
F: Fn(
<Float64Type as ArrowPrimitiveType>::Native,
u64,
) -> Result<<Float64Type as ArrowPrimitiveType>::Native>
+ Send,
{
fn update_batch(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
fn update_batch(
/// Updates the accumulator state given input. DataFusion provides `group_indices`, the groups that each
/// row in `values` belongs to as well as an optional filter of which rows passed.
fn update_batch(

&mut self,
values: &[ArrayRef],
group_indices: &[usize],
opt_filter: Option<&arrow::array::BooleanArray>,
total_num_groups: usize,
) -> Result<()> {
assert_eq!(values.len(), 1, "single argument to update_batch");
let values = values[0].as_primitive::<Float64Type>();

// increment counts, update sums
self.counts.resize(total_num_groups, 0);
self.prods
.resize(total_num_groups, Float64Type::default_value());
self.null_state.accumulate(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self.null_state.accumulate(
/// Use the `NullState` structure to generate specialized code for null / non null input elements
self.null_state.accumulate(

group_indices,
values,
opt_filter,
total_num_groups,
|group_index, new_value| {
let prod = &mut self.prods[group_index];
*prod = prod.mul_wrapping(new_value);

self.counts[group_index] += 1;
},
);

Ok(())
}

fn merge_batch(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
fn merge_batch(
/// Merge the results from previous invocations of `evaluate` into this accumulator's state
fn merge_batch(

&mut self,
values: &[ArrayRef],
group_indices: &[usize],
opt_filter: Option<&arrow::array::BooleanArray>,
total_num_groups: usize,
) -> Result<()> {
assert_eq!(values.len(), 2, "two arguments to merge_batch");
// first batch is counts, second is partial sums
let partial_counts = values[0].as_primitive::<UInt64Type>();
let partial_prods = values[1].as_primitive::<Float64Type>();
// update counts with partial counts
self.counts.resize(total_num_groups, 0);
self.null_state.accumulate(
group_indices,
partial_counts,
opt_filter,
total_num_groups,
|group_index, partial_count| {
self.counts[group_index] += partial_count;
},
);

// update prods
self.prods
.resize(total_num_groups, Float64Type::default_value());
self.null_state.accumulate(
group_indices,
partial_prods,
opt_filter,
total_num_groups,
|group_index, new_value: <Float64Type as ArrowPrimitiveType>::Native| {
let prod = &mut self.prods[group_index];
*prod = prod.mul_wrapping(new_value);
},
);

Ok(())
}

fn evaluate(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<ArrayRef> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
fn evaluate(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<ArrayRef> {
/// Generate output, as specififed by `emit_to` and update the intermediate state
fn evaluate(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<ArrayRef> {

let counts = emit_to.take_needed(&mut self.counts);
let prods = emit_to.take_needed(&mut self.prods);
let nulls = self.null_state.build(emit_to);

assert_eq!(nulls.len(), prods.len());
assert_eq!(counts.len(), prods.len());

// don't evaluate geometric mean with null inputs to avoid errors on null values

let array: PrimitiveArray<Float64Type> = if nulls.null_count() > 0 {
let mut builder = PrimitiveBuilder::<Float64Type>::with_capacity(nulls.len());
let iter = prods.into_iter().zip(counts).zip(nulls.iter());

for ((prod, count), is_valid) in iter {
if is_valid {
builder.append_value((self.geo_mean_fn)(prod, count)?)
} else {
builder.append_null();
}
}
builder.finish()
} else {
let geo_mean: Vec<<Float64Type as ArrowPrimitiveType>::Native> = prods
.into_iter()
.zip(counts.into_iter())
.map(|(prod, count)| (self.geo_mean_fn)(prod, count))
.collect::<Result<Vec<_>>>()?;
PrimitiveArray::new(geo_mean.into(), Some(nulls)) // no copy
.with_data_type(self.return_data_type.clone())
};

Ok(Arc::new(array))
}

// return arrays for counts and prods
fn state(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<Vec<ArrayRef>> {
let nulls = self.null_state.build(emit_to);
let nulls = Some(nulls);

let counts = emit_to.take_needed(&mut self.counts);
let counts = UInt64Array::new(counts.into(), nulls.clone()); // zero copy

let prods = emit_to.take_needed(&mut self.prods);
let prods = PrimitiveArray::<Float64Type>::new(prods.into(), nulls) // zero copy
.with_data_type(self.prod_data_type.clone());

Ok(vec![
Arc::new(counts) as ArrayRef,
Arc::new(prods) as ArrayRef,
])
}

fn size(&self) -> usize {
self.counts.capacity() * std::mem::size_of::<u64>()
+ self.prods.capacity() * std::mem::size_of::<Float64Type>()
}
}

#[tokio::main]
async fn main() -> Result<()> {
let ctx = create_context()?;

// create the AggregateUDF
let geometric_mean = AggregateUDF::from(GeoMeanUdf::new());
let geometric_mean = AggregateUDF::from(GeoMeanUdaf::new());
ctx.register_udaf(geometric_mean.clone());

let sql_df = ctx.sql("SELECT geo_mean(a) FROM t").await?;
Expand Down
Loading