Skip to content

Commit

Permalink
Improve speed of median by implementing special GroupsAccumulator (
Browse files Browse the repository at this point in the history
…#13681)

* draft of `MedianGroupAccumulator`.

* impl `state`.

* impl rest methods of `MedianGroupsAccumulator`.

* improve comments.

* use `MedianGroupsAccumulator`.

* remove unused import.

* add `group_median_table` to test group median.

* complete group median test cases in aggregate slt.

* fix type of state.

* Clippy

* Fmt

* add fuzzy tests for median.

* fix decimal.

* fix clippy.

* improve comments.

* add median cases with nulls.

* Update datafusion/functions-aggregate/src/median.rs

Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>

* use `OffsetBuffer::new_unchecked` in `convert_to_state`.

* add todo.

* remove assert and switch to i32 try from.

* return error when try from failed.

---------

Co-authored-by: Daniël Heres <danielheres@gmail.com>
Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
  • Loading branch information
3 people authored Jan 31, 2025
1 parent 11435de commit 53728b3
Show file tree
Hide file tree
Showing 3 changed files with 541 additions and 2 deletions.
20 changes: 20 additions & 0 deletions datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,26 @@ async fn test_count() {
.await;
}

/// Fuzz test for the `median` aggregate, in both its plain and `DISTINCT` forms.
#[tokio::test(flavor = "multi_thread")]
async fn test_median() {
    let data_gen_config = baseline_config();

    // median only works on numeric columns, so only those are offered as
    // aggregate arguments; every column is offered as a grouping key.
    let numeric_columns = data_gen_config.numeric_columns();
    let group_by_columns = data_gen_config.all_columns();

    // Queries like SELECT median(a), median(distinct) FROM fuzz_table GROUP BY b
    let query_builder = QueryBuilder::new()
        .with_table_name("fuzz_table")
        .with_aggregate_function("median")
        .with_distinct_aggregate_function("median")
        .with_aggregate_arguments(numeric_columns)
        .set_group_by_columns(group_by_columns);

    AggregationFuzzerBuilder::from(data_gen_config)
        .add_query_builder(query_builder)
        .build()
        .run()
        .await;
}

/// Return a standard set of columns for testing data generation
///
/// Includes numeric and string types
Expand Down
262 changes: 260 additions & 2 deletions datafusion/functions-aggregate/src/median.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@ use std::fmt::{Debug, Formatter};
use std::mem::{size_of, size_of_val};
use std::sync::Arc;

use arrow::array::{downcast_integer, ArrowNumericType};
use arrow::array::{
downcast_integer, ArrowNumericType, BooleanArray, ListArray, PrimitiveArray,
PrimitiveBuilder,
};
use arrow::buffer::{OffsetBuffer, ScalarBuffer};
use arrow::{
array::{ArrayRef, AsArray},
datatypes::{
Expand All @@ -33,12 +37,17 @@ use arrow::array::Array;
use arrow::array::ArrowNativeTypeOp;
use arrow::datatypes::{ArrowNativeType, ArrowPrimitiveType};

use datafusion_common::{DataFusionError, HashSet, Result, ScalarValue};
use datafusion_common::{
internal_datafusion_err, internal_err, DataFusionError, HashSet, Result, ScalarValue,
};
use datafusion_expr::function::StateFieldsArgs;
use datafusion_expr::{
function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl,
Documentation, Signature, Volatility,
};
use datafusion_expr::{EmitTo, GroupsAccumulator};
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate;
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filtered_null_mask;
use datafusion_functions_aggregate_common::utils::Hashable;
use datafusion_macros::user_doc;

Expand Down Expand Up @@ -165,6 +174,45 @@ impl AggregateUDFImpl for Median {
}
}

/// Reports whether the groups accumulator can be used for this call.
///
/// Returns `true` for every non-`DISTINCT` invocation and `false` when
/// `args.is_distinct` is set.
fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
!args.is_distinct
}

/// Builds the [`MedianGroupsAccumulator`] matching the argument's data type.
///
/// # Errors
///
/// * An internal error when the call does not carry exactly one argument
///   expression.
/// * Any error raised while resolving the argument's data type.
/// * `NotImplemented` when the data type is none of the types dispatched
///   below (the `downcast_integer!` arms, floats and decimals).
fn create_groups_accumulator(
    &self,
    args: AccumulatorArgs,
) -> Result<Box<dyn GroupsAccumulator>> {
    // Exactly one argument expression is expected
    let expr = match args.exprs {
        [expr] => expr,
        exprs => {
            return internal_err!(
                "median should only have 1 arg, but found num args:{}",
                exprs.len()
            )
        }
    };

    let dt = expr.data_type(args.schema)?;

    // Instantiates the accumulator for one concrete arrow primitive type
    macro_rules! accumulator_for {
        ($t:ty, $dt:expr) => {
            Ok(Box::new(MedianGroupsAccumulator::<$t>::new($dt)))
        };
    }

    downcast_integer! {
        dt => (accumulator_for, dt),
        DataType::Float16 => accumulator_for!(Float16Type, dt),
        DataType::Float32 => accumulator_for!(Float32Type, dt),
        DataType::Float64 => accumulator_for!(Float64Type, dt),
        DataType::Decimal128(_, _) => accumulator_for!(Decimal128Type, dt),
        DataType::Decimal256(_, _) => accumulator_for!(Decimal256Type, dt),
        _ => Err(DataFusionError::NotImplemented(format!(
            "MedianGroupsAccumulator not supported for {} with {}",
            args.name,
            dt,
        ))),
    }
}

/// Returns the alternative names of this function; the slice is always empty.
fn aliases(&self) -> &[String] {
&[]
}
Expand Down Expand Up @@ -230,6 +278,216 @@ impl<T: ArrowNumericType> Accumulator for MedianAccumulator<T> {
}
}

/// Groups accumulator for `median` that buffers the raw input values.
///
/// An exact median cannot be produced without seeing every value of a group,
/// so all values are retained until final evaluation. The values of each
/// group live in their own `Vec<T::Native>`, and the groups together form a
/// `Vec<Vec<T::Native>>` indexed by group index.
#[derive(Debug)]
struct MedianGroupsAccumulator<T: ArrowNumericType + Send> {
/// Data type applied to the arrays this accumulator builds
data_type: DataType,
/// Buffered values, one inner `Vec` per group, indexed by group index
group_values: Vec<Vec<T::Native>>,
}

impl<T: ArrowNumericType + Send> MedianGroupsAccumulator<T> {
    /// Creates an accumulator with no groups yet.
    ///
    /// `data_type` is stored and later applied to the arrays built by this
    /// accumulator.
    pub fn new(data_type: DataType) -> Self {
        let group_values = Vec::new();
        Self {
            data_type,
            group_values,
        }
    }
}

impl<T: ArrowNumericType + Send> GroupsAccumulator for MedianGroupsAccumulator<T> {
/// Appends each input row to the buffer of the group it belongs to.
///
/// `values` must hold exactly one array (asserted). Rows are routed through
/// `accumulate`, so only rows that are `not null + not filtered` are pushed.
fn update_batch(
    &mut self,
    values: &[ArrayRef],
    group_indices: &[usize],
    opt_filter: Option<&BooleanArray>,
    total_num_groups: usize,
) -> Result<()> {
    assert_eq!(values.len(), 1, "single argument to update_batch");
    let input = values[0].as_primitive::<T>();

    // Make sure a buffer exists for every group seen so far
    self.group_values.resize(total_num_groups, Vec::new());

    let buffers = &mut self.group_values;
    accumulate(group_indices, input, opt_filter, |group_index, new_value| {
        buffers[group_index].push(new_value);
    });

    Ok(())
}

/// Merges intermediate states (one list of values per row) into the groups.
///
/// `values` must hold exactly one array (asserted), which is read as a
/// `ListArray` with `i32` offsets whose inner array has primitive type `T`.
fn merge_batch(
    &mut self,
    values: &[ArrayRef],
    group_indices: &[usize],
    // Since aggregate filter should be applied in partial stage, in final stage there should be no filter
    _opt_filter: Option<&BooleanArray>,
    total_num_groups: usize,
) -> Result<()> {
    assert_eq!(values.len(), 1, "one argument to merge_batch");

    // Shape of the incoming state: a nullable `ListArray` whose inner array
    // is non-nullable. Null list entries usually originate from
    // `convert_to_state`.
    //
    // # Possible values
    // ```text
    // group 0: [1, 2, 3]
    // group 1: null (list array is nullable)
    // group 2: [6, 7, 8]
    // ...
    // group n: [...]
    // ```
    //
    // # Impossible values
    // ```text
    // group x: [1, 2, null] (values in list array is non-nullable)
    // ```
    let incoming_states = values[0].as_list::<i32>();

    // Make sure a buffer exists for every group seen so far
    self.group_values.resize(total_num_groups, Vec::new());

    // TODO: avoid using iterator of the `ListArray`, this will lead to
    // many calls of `slice` of its `inner array`, and `slice` is not
    // so efficient(due to the calculation of `null_count` for each `slice`).
    for (&group_index, state) in group_indices.iter().zip(incoming_states.iter()) {
        // A null list entry contributes nothing to its group
        if let Some(state) = state {
            let state = state.as_primitive::<T>();
            self.group_values[group_index].extend_from_slice(state.values());
        }
    }

    Ok(())
}

fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
// Emit values
let emit_group_values = emit_to.take_needed(&mut self.group_values);

// Build offsets
let mut offsets = Vec::with_capacity(self.group_values.len() + 1);
offsets.push(0);
let mut cur_len = 0_i32;
for group_value in &emit_group_values {
cur_len += group_value.len() as i32;
offsets.push(cur_len);
}
// TODO: maybe we can use `OffsetBuffer::new_unchecked` like what in `convert_to_state`,
// but safety should be considered more carefully here(and I am not sure if it can get
// performance improvement when we introduce checks to keep the safety...).
//
// Can see more details in:
// https://github.com/apache/datafusion/pull/13681#discussion_r1931209791
//
let offsets = OffsetBuffer::new(ScalarBuffer::from(offsets));

// Build inner array
let flatten_group_values =
emit_group_values.into_iter().flatten().collect::<Vec<_>>();
let group_values_array =
PrimitiveArray::<T>::new(ScalarBuffer::from(flatten_group_values), None)
.with_data_type(self.data_type.clone());

// Build the result list array
let result_list_array = ListArray::new(
Arc::new(Field::new_list_field(self.data_type.clone(), true)),
offsets,
Arc::new(group_values_array),
None,
);

Ok(vec![Arc::new(result_list_array)])
}

fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
// Emit values
let emit_group_values = emit_to.take_needed(&mut self.group_values);

// Calculate median for each group
let mut evaluate_result_builder =
PrimitiveBuilder::<T>::new().with_data_type(self.data_type.clone());
for values in emit_group_values {
let median = calculate_median::<T>(values);
evaluate_result_builder.append_option(median);
}

Ok(Arc::new(evaluate_result_builder.finish()))
}

/// Converts an input batch directly into intermediate state without grouping.
///
/// Every row is treated as its own group: the result is a `ListArray` with
/// one entry per input row. A row that is `not null + not filtered` becomes
/// a list with exactly one element; any other row becomes a null list entry.
///
/// # Errors
///
/// Returns an internal error when the number of input rows does not fit
/// into an `i32` list offset.
fn convert_to_state(
    &self,
    values: &[ArrayRef],
    opt_filter: Option<&BooleanArray>,
) -> Result<Vec<ArrayRef>> {
    assert_eq!(values.len(), 1, "one argument to convert_to_state");

    let input_array = values[0].as_primitive::<T>();

    // Reuse values buffer in `input_array` to build `values` in `ListArray`.
    // The inner array carries no null buffer; nullness is expressed on the
    // list entries instead (see `nulls` below).
    let values = PrimitiveArray::<T>::new(input_array.values().clone(), None)
        .with_data_type(self.data_type.clone());

    // `offsets` in `ListArray`, each row as a list element
    let offset_end = i32::try_from(input_array.len()).map_err(|e| {
        internal_datafusion_err!(
            "cast array_len to i32 failed in convert_to_state of group median, err:{e:?}"
        )
    })?;
    let offsets = (0..=offset_end).collect::<Vec<_>>();
    // SAFETY: `offsets` is `0, 1, ..., offset_end`, so it is non-empty, starts
    // at a non-negative value and is monotonically increasing, which are the
    // conditions `OffsetBuffer::new` would otherwise check.
    let offsets = unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(offsets)) };

    // `nulls` for converted `ListArray`
    let nulls = filtered_null_mask(opt_filter, input_array);

    let converted_list_array = ListArray::new(
        Arc::new(Field::new_list_field(self.data_type.clone(), true)),
        offsets,
        Arc::new(values),
        nulls,
    );

    Ok(vec![Arc::new(converted_list_array)])
}

/// Always `true`: this accumulator implements `convert_to_state`.
fn supports_convert_to_state(&self) -> bool {
true
}

fn size(&self) -> usize {
self.group_values
.iter()
.map(|values| values.capacity() * size_of::<T>())
.sum::<usize>()
// account for size of self.grou_values too
+ self.group_values.capacity() * size_of::<Vec<T>>()
}
}

/// The distinct median accumulator accumulates the raw input values
/// as `ScalarValue`s
///
Expand Down
Loading

0 comments on commit 53728b3

Please sign in to comment.