diff --git a/Cargo.toml b/Cargo.toml index 1e493f864c03..cb6e23cdc752 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -70,4 +70,4 @@ lto = false opt-level = 3 overflow-checks = false panic = 'unwind' -rpath = false +rpath = false \ No newline at end of file diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 149c6e2c5bdf..3db230bf0554 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1197,6 +1197,7 @@ dependencies = [ "itertools 0.11.0", "lazy_static", "libc", + "log", "md-5", "paste", "petgraph", diff --git a/datafusion/core/src/physical_plan/aggregates/mod.rs b/datafusion/core/src/physical_plan/aggregates/mod.rs index 4bf5f664450e..e1583ebe0548 100644 --- a/datafusion/core/src/physical_plan/aggregates/mod.rs +++ b/datafusion/core/src/physical_plan/aggregates/mod.rs @@ -49,6 +49,7 @@ use std::sync::Arc; mod bounded_aggregate_stream; mod no_grouping; mod row_hash; +mod row_hash2; mod utils; pub use datafusion_expr::AggregateFunction; @@ -58,6 +59,7 @@ use datafusion_physical_expr::utils::{ get_finer_ordering, ordering_satisfy_requirement_concrete, }; +use self::row_hash2::GroupedHashAggregateStream2; use super::DisplayAs; /// Hash aggregate modes @@ -198,6 +200,7 @@ impl PartialEq for PhysicalGroupBy { enum StreamType { AggregateStream(AggregateStream), GroupedHashAggregateStream(GroupedHashAggregateStream), + GroupedHashAggregateStream2(GroupedHashAggregateStream2), BoundedAggregate(BoundedAggregateStream), } @@ -206,6 +209,7 @@ impl From for SendableRecordBatchStream { match stream { StreamType::AggregateStream(stream) => Box::pin(stream), StreamType::GroupedHashAggregateStream(stream) => Box::pin(stream), + StreamType::GroupedHashAggregateStream2(stream) => Box::pin(stream), StreamType::BoundedAggregate(stream) => Box::pin(stream), } } @@ -713,12 +717,23 @@ impl AggregateExec { partition, aggregation_ordering, )?)) + } else if self.use_poc_group_by() { + Ok(StreamType::GroupedHashAggregateStream2( + 
GroupedHashAggregateStream2::new(self, context, partition)?, + )) } else { Ok(StreamType::GroupedHashAggregateStream( GroupedHashAggregateStream::new(self, context, partition)?, )) } } + + /// Returns true if we should use the POC group by stream + /// TODO: check for actually supported aggregates, etc + fn use_poc_group_by(&self) -> bool { + //info!("AAL Checking POC group by: {self:#?}"); + true + } } impl DisplayAs for AggregateExec { @@ -984,7 +999,7 @@ fn group_schema(schema: &Schema, group_count: usize) -> SchemaRef { Arc::new(Schema::new(group_fields)) } -/// returns physical expressions to evaluate against a batch +/// returns physical expressions for arguments to evaluate against a batch /// The expressions are different depending on `mode`: /// * Partial: AggregateExpr::expressions /// * Final: columns of `AggregateExpr::state_fields()` @@ -1787,10 +1802,10 @@ mod tests { assert!(matches!(stream, StreamType::AggregateStream(_))); } 1 => { - assert!(matches!(stream, StreamType::GroupedHashAggregateStream(_))); + assert!(matches!(stream, StreamType::GroupedHashAggregateStream2(_))); } 2 => { - assert!(matches!(stream, StreamType::GroupedHashAggregateStream(_))); + assert!(matches!(stream, StreamType::GroupedHashAggregateStream2(_))); } _ => panic!("Unknown version: {version}"), } diff --git a/datafusion/core/src/physical_plan/aggregates/row_hash.rs b/datafusion/core/src/physical_plan/aggregates/row_hash.rs index beb70f1b4c55..4741f181f9ea 100644 --- a/datafusion/core/src/physical_plan/aggregates/row_hash.rs +++ b/datafusion/core/src/physical_plan/aggregates/row_hash.rs @@ -17,6 +17,7 @@ //! 
Hash aggregation through row format +use log::info; use std::cmp::min; use std::ops::Range; use std::sync::Arc; @@ -110,6 +111,8 @@ pub(crate) struct GroupedHashAggregateStream { /// first element in the array corresponds to normal accumulators /// second element in the array corresponds to row accumulators indices: [Vec>; 2], + // buffer to be reused to store hashes + hashes_buffer: Vec, } impl GroupedHashAggregateStream { @@ -119,6 +122,7 @@ impl GroupedHashAggregateStream { context: Arc, partition: usize, ) -> Result { + info!("Creating GroupedHashAggregateStream"); let agg_schema = Arc::clone(&agg.schema); let agg_group_by = agg.group_by.clone(); let agg_filter_expr = agg.filter_expr.clone(); @@ -229,6 +233,7 @@ impl GroupedHashAggregateStream { scalar_update_factor, row_group_skip_position: 0, indices: [normal_agg_indices, row_agg_indices], + hashes_buffer: vec![], }) } } @@ -322,15 +327,17 @@ impl GroupedHashAggregateStream { let mut groups_with_rows = vec![]; // 1.1 Calculate the group keys for the group values - let mut batch_hashes = vec![0; n_rows]; - create_hashes(group_values, &self.random_state, &mut batch_hashes)?; + let batch_hashes = &mut self.hashes_buffer; + batch_hashes.clear(); + batch_hashes.resize(n_rows, 0); + create_hashes(group_values, &self.random_state, batch_hashes)?; let AggregationState { map, group_states, .. 
} = &mut self.aggr_state; - for (row, hash) in batch_hashes.into_iter().enumerate() { - let entry = map.get_mut(hash, |(_hash, group_idx)| { + for (row, hash) in batch_hashes.iter_mut().enumerate() { + let entry = map.get_mut(*hash, |(_hash, group_idx)| { // verify that a group that we are inserting with hash is // actually the same key value as the group in // existing_idx (aka group_values @ row) @@ -385,7 +392,7 @@ impl GroupedHashAggregateStream { // for hasher function, use precomputed hash value map.insert_accounted( - (hash, group_idx), + (*hash, group_idx), |(hash, _group_index)| *hash, allocated, ); diff --git a/datafusion/core/src/physical_plan/aggregates/row_hash2.rs b/datafusion/core/src/physical_plan/aggregates/row_hash2.rs new file mode 100644 index 000000000000..335ce40754fa --- /dev/null +++ b/datafusion/core/src/physical_plan/aggregates/row_hash2.rs @@ -0,0 +1,533 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Hash aggregation through row format +//! +//! 
POC demonstration of GroupByHashApproach + +use datafusion_physical_expr::{ + AggregateExpr, GroupsAccumulator, GroupsAccumulatorAdapter, +}; +use log::debug; +use std::sync::Arc; +use std::task::{Context, Poll}; +use std::vec; + +use ahash::RandomState; +use arrow::row::{RowConverter, Rows, SortField}; +use datafusion_physical_expr::hash_utils::create_hashes; +use futures::ready; +use futures::stream::{Stream, StreamExt}; + +use crate::physical_plan::aggregates::{ + evaluate_group_by, evaluate_many, evaluate_optional, group_schema, AggregateMode, + PhysicalGroupBy, +}; +use crate::physical_plan::metrics::{BaselineMetrics, RecordOutput}; +use crate::physical_plan::{aggregates, PhysicalExpr}; +use crate::physical_plan::{RecordBatchStream, SendableRecordBatchStream}; +use arrow::array::*; +use arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; +use datafusion_common::Result; +use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt}; +use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; +use datafusion_execution::TaskContext; +use hashbrown::raw::RawTable; + +#[derive(Debug, Clone)] +/// This object tracks the aggregation phase (input/output) +pub(crate) enum ExecutionState { + ReadingInput, + /// When producing output, the remaining rows to output are stored + /// here and are sliced off as needed in batch_size chunks + ProducingOutput(RecordBatch), + Done, +} + +use super::AggregateExec; + +/// Hash based Grouping Aggregator +/// +/// # Design Goals +/// +/// This structure is designed so that much can be vectorized (done in +/// a tight loop) as possible +/// +/// # Architecture +/// +/// ```text +/// +/// stores "group stores group values, internally stores aggregate +/// indexes" in arrow_row format values, for all groups +/// +/// ┌─────────────┐ ┌────────────┐ ┌──────────────┐ ┌──────────────┐ +/// │ ┌─────┐ │ │ ┌────────┐ │ │┌────────────┐│ │┌────────────┐│ +/// │ │ 5 │ │ ┌────┼▶│ "A" │ │ ││accumulator ││ 
││accumulator ││ +/// │ ├─────┤ │ │ │ ├────────┤ │ ││ 0 ││ ││ N ││ +/// │ │ 9 │ │ │ │ │ "Z" │ │ ││ ┌────────┐ ││ ││ ┌────────┐ ││ +/// │ └─────┘ │ │ │ └────────┘ │ ││ │ state │ ││ ││ │ state │ ││ +/// │ ... │ │ │ │ ││ │┌─────┐ │ ││ ... ││ │┌─────┐ │ ││ +/// │ ┌─────┐ │ │ │ ... │ ││ │├─────┤ │ ││ ││ │├─────┤ │ ││ +/// │ │ 1 │───┼─┘ │ │ ││ │└─────┘ │ ││ ││ │└─────┘ │ ││ +/// │ ├─────┤ │ │ │ ││ │ │ ││ ││ │ │ ││ +/// │ │ 13 │───┼─┐ │ ┌────────┐ │ ││ │ ... │ ││ ││ │ ... │ ││ +/// │ └─────┘ │ └────┼▶│ "Q" │ │ ││ │ │ ││ ││ │ │ ││ +/// └─────────────┘ │ └────────┘ │ ││ │┌─────┐ │ ││ ││ │┌─────┐ │ ││ +/// │ │ ││ │└─────┘ │ ││ ││ │└─────┘ │ ││ +/// └────────────┘ ││ └────────┘ ││ ││ └────────┘ ││ +/// │└────────────┘│ │└────────────┘│ +/// └──────────────┘ └──────────────┘ +/// +/// map group_values accumulators +/// (Hash Table) +/// +/// ``` +/// +/// For example, given a query like `COUNT(x), SUM(y) ... GROUP BY z`, +/// `group_values` will store the distinct values of `z`. There will +/// be one accumulator for `COUNT(x)`, specialized for the data type +/// of `x` and one accumulator for `SUM(y)`, specialized for the data +/// type of `y`. +/// +/// # Description +/// +/// The hash table stores "group indices", one for each (distinct) +/// group value. +/// +/// The group values are stored in [`Self::group_values`] at the +/// corresponding group index. +/// +/// The accumulator state (e.g partial sums) is managed by and stored +/// by a [`GroupsAccumulator`] accumulator. There is one accumulator +/// per aggregate expression (COUNT, AVG, etc) in the +/// query. Internally, each `GroupsAccumulator` manages the state for +/// multiple groups, and is passed `group_indexes` during update. Note +/// The accumulator state is not managed by this operator (e.g in the +/// hash table). 
+pub(crate) struct GroupedHashAggregateStream2 { + schema: SchemaRef, + input: SendableRecordBatchStream, + mode: AggregateMode, + + /// Accumulators, one for each `AggregateExpr` in the query + /// + /// For example, if the query has aggregates, `SUM(x)`, + /// `COUNT(y)`, there will be two accumulators, each one + /// specialized for that particular aggregate and its input types + accumulators: Vec>, + + /// Arguments to pass to accumulator. + aggregate_arguments: Vec>>, + + /// Optional filter expression to evaluate, one for each + /// accumulator. If present, only those rows for which the filter + /// evaluates to true should be included in the aggregate results. + /// + /// For example, for an aggregate like `SUM(x FILTER x > 100)`, + /// the filter expression is `x > 100`. + filter_expressions: Vec>>, + + /// Converter for each row + row_converter: RowConverter, + + /// GROUP BY expressions + group_by: PhysicalGroupBy, + + /// The memory reservation for this grouping + reservation: MemoryReservation, + + /// Logically maps group values to a group_index in + /// [`Self::group_values`] and in each accumulator + /// + /// Uses the raw API of hashbrown to avoid actually storing the + /// keys (group values) in the table + /// + /// keys: u64 hashes of the GroupValue + /// values: (hash, group_index) + map: RawTable<(u64, usize)>, + + /// The actual group by values, stored in arrow [`Row`] format. The + /// `group_values[i]` holds the group value for group_index `i`. + /// + /// The row format is used to compare group keys quickly. This is + /// especially important for multi-column group keys. + /// + /// [`Row`]: arrow::row::Row + group_values: Rows, + + /// scratch space for the current input [`RecordBatch`] being + /// processed. 
Reused across batches here to avoid reallocations + current_group_indices: Vec, + + /// Tracks if this stream is generating input or output + exec_state: ExecutionState, + + /// Execution metrics + baseline_metrics: BaselineMetrics, + + /// Random state for creating hashes + random_state: RandomState, + + /// max rows in output RecordBatches + batch_size: usize, +} + +impl GroupedHashAggregateStream2 { + /// Create a new GroupedHashAggregateStream2 + pub fn new( + agg: &AggregateExec, + context: Arc, + partition: usize, + ) -> Result { + debug!("Creating GroupedHashAggregateStream2"); + let agg_schema = Arc::clone(&agg.schema); + let agg_group_by = agg.group_by.clone(); + let agg_filter_expr = agg.filter_expr.clone(); + + let batch_size = context.session_config().batch_size(); + let input = agg.input.execute(partition, Arc::clone(&context))?; + let baseline_metrics = BaselineMetrics::new(&agg.metrics, partition); + + let timer = baseline_metrics.elapsed_compute().timer(); + + let aggregate_exprs = agg.aggr_expr.clone(); + + // arguments for each aggregate, one vec of expressions per + // aggregate + let aggregate_arguments = aggregates::aggregate_expressions( + &agg.aggr_expr, + &agg.mode, + agg_group_by.expr.len(), + )?; + + let filter_expressions = match agg.mode { + AggregateMode::Partial | AggregateMode::Single => agg_filter_expr, + AggregateMode::Final | AggregateMode::FinalPartitioned => { + vec![None; agg.aggr_expr.len()] + } + }; + + // Instantiate the accumulators + let accumulators: Vec<_> = aggregate_exprs + .iter() + .map(create_group_accumulator) + .collect::>()?; + + let group_schema = group_schema(&agg_schema, agg_group_by.expr.len()); + let row_converter = RowConverter::new( + group_schema + .fields() + .iter() + .map(|f| SortField::new(f.data_type().clone())) + .collect(), + )?; + + let name = format!("GroupedHashAggregateStream2[{partition}]"); + let reservation = MemoryConsumer::new(name).register(context.memory_pool()); + let map = 
RawTable::with_capacity(0); + let group_values = row_converter.empty_rows(0, 0); + let current_group_indices = vec![]; + + timer.done(); + + let exec_state = ExecutionState::ReadingInput; + + Ok(GroupedHashAggregateStream2 { + schema: agg_schema, + input, + mode: agg.mode, + accumulators, + aggregate_arguments, + filter_expressions, + row_converter, + group_by: agg_group_by, + reservation, + map, + group_values, + current_group_indices, + exec_state, + baseline_metrics, + random_state: Default::default(), + batch_size, + }) + } +} + +/// Create an accumulator for `agg_expr` -- a [`GroupsAccumulator`] if +/// that is supported by the aggregate, or a +/// [`GroupsAccumulatorAdapter`] if not. +fn create_group_accumulator( + agg_expr: &Arc, +) -> Result> { + if agg_expr.groups_accumulator_supported() { + agg_expr.create_groups_accumulator() + } else { + let agg_expr_captured = agg_expr.clone(); + let factory = move || agg_expr_captured.create_accumulator(); + Ok(Box::new(GroupsAccumulatorAdapter::new(factory))) + } +} + +impl Stream for GroupedHashAggregateStream2 { + type Item = Result; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let elapsed_compute = self.baseline_metrics.elapsed_compute().clone(); + + loop { + let exec_state = self.exec_state.clone(); + match exec_state { + ExecutionState::ReadingInput => { + match ready!(self.input.poll_next_unpin(cx)) { + // new batch to aggregate + Some(Ok(batch)) => { + let timer = elapsed_compute.timer(); + let result = self.group_aggregate_batch(batch); + timer.done(); + + // allocate memory AFTER we actually used + // the memory, which simplifies the whole + // accounting and we are OK with + // overshooting a bit. + // + // Also this means we either store the + // whole record batch or not. 
+ let result = result.and_then(|allocated| { + self.reservation.try_grow(allocated) + }); + + if let Err(e) = result { + return Poll::Ready(Some(Err(e))); + } + } + // inner had error, return to caller + Some(Err(e)) => return Poll::Ready(Some(Err(e))), + // inner is done, producing output + None => { + let timer = elapsed_compute.timer(); + match self.create_batch_from_map() { + Ok(batch) => { + self.exec_state = + ExecutionState::ProducingOutput(batch) + } + Err(e) => return Poll::Ready(Some(Err(e))), + } + timer.done(); + } + } + } + + ExecutionState::ProducingOutput(batch) => { + // slice off a part of the batch, if needed + let output_batch = if batch.num_rows() <= self.batch_size { + self.exec_state = ExecutionState::Done; + batch + } else { + // output first batch_size rows + let num_remaining = batch.num_rows() - self.batch_size; + let remaining = batch.slice(self.batch_size, num_remaining); + self.exec_state = ExecutionState::ProducingOutput(remaining); + batch.slice(0, self.batch_size) + }; + return Poll::Ready(Some(Ok( + output_batch.record_output(&self.baseline_metrics) + ))); + } + + ExecutionState::Done => return Poll::Ready(None), + } + } + } +} + +impl RecordBatchStream for GroupedHashAggregateStream2 { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + +impl GroupedHashAggregateStream2 { + /// Calculates the group indices for each input row of + /// `group_values`. + /// + /// At the return of this function, + /// [`Self::current_group_indices`] has the same number of entries + /// as each array in `group_values` and holds the correct + /// group_index for that row. 
+ fn update_group_state( + &mut self, + group_values: &[ArrayRef], + allocated: &mut usize, + ) -> Result<()> { + // Convert the group keys into the row format + // Avoid reallocation when https://github.com/apache/arrow-rs/issues/4479 is available + let group_rows = self.row_converter.convert_columns(group_values)?; + let n_rows = group_rows.num_rows(); + + // tracks to which group each of the input rows belongs + let group_indices = &mut self.current_group_indices; + let group_indices_size_pre = group_indices.allocated_size(); + let group_values_size_pre = self.group_values.size(); + + // 1.1 Calculate the group keys for the group values + group_indices.clear(); + let mut batch_hashes = vec![0; n_rows]; + create_hashes(group_values, &self.random_state, &mut batch_hashes)?; + + for (row, hash) in batch_hashes.into_iter().enumerate() { + let entry = self.map.get_mut(hash, |(_hash, group_idx)| { + // verify that a group that we are inserting with hash is + // actually the same key value as the group in + // existing_idx (aka group_values @ row) + group_rows.row(row) == self.group_values.row(*group_idx) + }); + + let group_idx = match entry { + // Existing group_index for this group value + Some((_hash, group_idx)) => *group_idx, + // 1.2 Need to create new entry for the group + None => { + // Add new entry to group_values and save newly created index + let group_idx = self.group_values.num_rows(); + self.group_values.push(group_rows.row(row)); + + // for hasher function, use precomputed hash value + self.map.insert_accounted( + (hash, group_idx), + |(hash, _group_index)| *hash, + allocated, + ); + group_idx + } + }; + group_indices.push(group_idx); + } + + // memory growth in group_indices + *allocated += group_indices.allocated_size(); + *allocated -= group_indices_size_pre; // subtract after adding to avoid underflow + + // account for any memory increase used to store group_values + *allocated += self + .group_values + .size() + 
.saturating_sub(group_values_size_pre); + + Ok(()) + } + + /// Perform group-by aggregation for the given [`RecordBatch`]. + /// + /// If successful, returns the additional amount of memory, in + /// bytes, that was allocated during this process. + fn group_aggregate_batch(&mut self, batch: RecordBatch) -> Result { + // Evaluate the grouping expressions + let group_by_values = evaluate_group_by(&self.group_by, &batch)?; + + // Keep track of memory allocated: + let mut allocated = 0usize; + + // Evaluate the aggregation expressions. + let input_values = evaluate_many(&self.aggregate_arguments, &batch)?; + + // Evaluate the filter expressions, if any, against the inputs + let filter_values = evaluate_optional(&self.filter_expressions, &batch)?; + + let row_converter_size_pre = self.row_converter.size(); + + for group_values in &group_by_values { + // calculate the group indices for each input row + self.update_group_state(group_values, &mut allocated)?; + let group_indices = &self.current_group_indices; + + // Gather the inputs to call the actual aggregation + let t = self + .accumulators + .iter_mut() + .zip(input_values.iter()) + .zip(filter_values.iter()); + + let total_num_groups = self.group_values.num_rows(); + + for ((acc, values), opt_filter) in t { + let acc_size_pre = acc.size(); + let opt_filter = opt_filter.as_ref().map(|filter| filter.as_boolean()); + + // Call the appropriate method on each aggregator with + // the entire input row and the relevant group indexes + match self.mode { + AggregateMode::Partial | AggregateMode::Single => { + acc.update_batch( + values, + group_indices, + opt_filter, + total_num_groups, + )?; + } + AggregateMode::FinalPartitioned | AggregateMode::Final => { + // if aggregation is over intermediate states, + // use merge + acc.merge_batch( + values, + group_indices, + opt_filter, + total_num_groups, + )?; + } + } + + allocated += acc.size().saturating_sub(acc_size_pre); + } + } + allocated += self + .row_converter + .size() + 
.saturating_sub(row_converter_size_pre); + + Ok(allocated) + } +} + +impl GroupedHashAggregateStream2 { + /// Create an output RecordBatch with all group keys and accumulator states/values + fn create_batch_from_map(&mut self) -> Result { + if self.group_values.num_rows() == 0 { + let schema = self.schema.clone(); + return Ok(RecordBatch::new_empty(schema)); + } + + // First output rows are the groups + let groups_rows = self.group_values.iter(); + + let mut output: Vec = self.row_converter.convert_rows(groups_rows)?; + + // Next output the accumulators + for acc in self.accumulators.iter_mut() { + match self.mode { + AggregateMode::Partial => output.extend(acc.state()?), + AggregateMode::Final + | AggregateMode::FinalPartitioned + | AggregateMode::Single => output.push(acc.evaluate()?), + } + } + + Ok(RecordBatch::try_new(self.schema.clone(), output)?) + } +} diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index 74370049e81f..74dd9ee1d13e 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -28,8 +28,8 @@ use datafusion::physical_plan::aggregates::{ use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; -use datafusion::physical_plan::collect; use datafusion::physical_plan::memory::MemoryExec; +use datafusion::physical_plan::{collect, displayable, ExecutionPlan}; use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_physical_expr::expressions::{col, Sum}; use datafusion_physical_expr::{AggregateExpr, PhysicalSortExpr}; @@ -107,6 +107,10 @@ async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str .map(|elem| (col(elem, &schema).unwrap(), elem.to_string())) .collect::>(); let group_by = PhysicalGroupBy::new_single(expr); + + println!("aggregate_expr: {aggregate_expr:?}"); + println!("group_by: {group_by:?}"); + let aggregate_exec_running = Arc::new( AggregateExec::try_new( AggregateMode::Partial, @@ 
-118,7 +122,7 @@ async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str schema.clone(), ) .unwrap(), - ) as _; + ) as Arc; let aggregate_exec_usual = Arc::new( AggregateExec::try_new( @@ -131,14 +135,14 @@ async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str schema.clone(), ) .unwrap(), - ) as _; + ) as Arc; let task_ctx = ctx.task_ctx(); - let collected_usual = collect(aggregate_exec_usual, task_ctx.clone()) + let collected_usual = collect(aggregate_exec_usual.clone(), task_ctx.clone()) .await .unwrap(); - let collected_running = collect(aggregate_exec_running, task_ctx.clone()) + let collected_running = collect(aggregate_exec_running.clone(), task_ctx.clone()) .await .unwrap(); assert!(collected_running.len() > 2); @@ -162,7 +166,23 @@ async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str .zip(&running_formatted_sorted) .enumerate() { - assert_eq!((i, usual_line), (i, running_line), "Inconsistent result"); + assert_eq!( + (i, usual_line), + (i, running_line), + "Inconsistent result\n\n\ + Left Plan:\n{}\n\ + Right Plan:\n{}\n\ + schema:\n{schema}\n\ + Left Ouptut:\n{}\n\ + Right Output:\n{}\n\ + input:\n{}\n\ + ", + displayable(aggregate_exec_usual.as_ref()).indent(false), + displayable(aggregate_exec_running.as_ref()).indent(false), + usual_formatted, + running_formatted, + pretty_format_batches(&input1).unwrap(), + ); } } diff --git a/datafusion/execution/src/memory_pool/proxy.rs b/datafusion/execution/src/memory_pool/proxy.rs index 43532f9a81f1..2bf485c6ee76 100644 --- a/datafusion/execution/src/memory_pool/proxy.rs +++ b/datafusion/execution/src/memory_pool/proxy.rs @@ -26,6 +26,11 @@ pub trait VecAllocExt { /// [Push](Vec::push) new element to vector and store additional allocated bytes in `accounting` (additive). 
fn push_accounted(&mut self, x: Self::T, accounting: &mut usize); + + /// Return the amount of memory allocated by this Vec (not + /// recursively counting any heap allocations contained within the + /// structure). Does not include the size of `self` + fn allocated_size(&self) -> usize; } impl VecAllocExt for Vec { @@ -44,6 +49,9 @@ impl VecAllocExt for Vec { self.push(x); } + fn allocated_size(&self) -> usize { + std::mem::size_of::() * self.capacity() + } } /// Extension trait for [`RawTable`] to account for allocations. diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index d1c2f7bf3377..b7ffa1810cce 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -62,6 +62,7 @@ indexmap = "2.0.0" itertools = { version = "0.11", features = ["use_std"] } lazy_static = { version = "^1.4.0" } libc = "0.2.140" +log = "^0.4" md-5 = { version = "^0.10.0", optional = true } paste = "^1.0" petgraph = "0.6.2" diff --git a/datafusion/physical-expr/src/aggregate/average.rs b/datafusion/physical-expr/src/aggregate/average.rs index 3c76da51a9d4..6081564ccdf1 100644 --- a/datafusion/physical-expr/src/aggregate/average.rs +++ b/datafusion/physical-expr/src/aggregate/average.rs @@ -17,10 +17,14 @@ //! 
Defines physical expressions that can evaluated at runtime during query execution +use arrow::array::{AsArray, PrimitiveBuilder}; +use log::debug; + use std::any::Any; use std::convert::TryFrom; use std::sync::Arc; +use crate::aggregate::groups_accumulator::accumulate::NullState; use crate::aggregate::row_accumulator::{ is_row_accumulator_support_dtype, RowAccumulator, }; @@ -29,19 +33,23 @@ use crate::aggregate::sum::sum_batch; use crate::aggregate::utils::calculate_result_decimal_for_avg; use crate::aggregate::utils::down_cast_any_ref; use crate::expressions::format_state_name; -use crate::{AggregateExpr, PhysicalExpr}; +use crate::{AggregateExpr, GroupsAccumulator, PhysicalExpr}; use arrow::compute; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Decimal128Type, Float64Type, UInt64Type}; use arrow::{ array::{ArrayRef, UInt64Array}, datatypes::Field, }; -use arrow_array::Array; +use arrow_array::{ + Array, ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType, PrimitiveArray, +}; use datafusion_common::{downcast_value, ScalarValue}; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::Accumulator; use datafusion_row::accessor::RowAccessor; +use super::utils::{adjust_output_array, Decimal128Averager}; + /// AVG aggregate expression #[derive(Debug, Clone)] pub struct Avg { @@ -155,6 +163,50 @@ impl AggregateExpr for Avg { &self.rt_data_type, )?)) } + + fn groups_accumulator_supported(&self) -> bool { + use DataType::*; + + matches!(&self.rt_data_type, Float64 | Decimal128(_, _)) + } + + fn create_groups_accumulator(&self) -> Result> { + use DataType::*; + // instantiate specialized accumulator based for the type + match (&self.sum_data_type, &self.rt_data_type) { + (Float64, Float64) => { + Ok(Box::new(AvgGroupsAccumulator::::new( + &self.sum_data_type, + &self.rt_data_type, + |sum: f64, count: u64| Ok(sum / count as f64), + ))) + } + ( + Decimal128(_sum_precision, sum_scale), + Decimal128(target_precision, target_scale), + ) 
=> { + let decimal_averager = Decimal128Averager::try_new( + *sum_scale, + *target_precision, + *target_scale, + )?; + + let avg_fn = + move |sum: i128, count: u64| decimal_averager.avg(sum, count as i128); + + Ok(Box::new(AvgGroupsAccumulator::::new( + &self.sum_data_type, + &self.rt_data_type, + avg_fn, + ))) + } + + _ => Err(DataFusionError::NotImplemented(format!( + "AvgGroupsAccumulator for ({} --> {})", + self.sum_data_type, self.rt_data_type, + ))), + } + } } impl PartialEq for Avg { @@ -383,6 +435,189 @@ impl RowAccumulator for AvgRowAccumulator { } } +/// An accumulator to compute the average of [`PrimitiveArray`]. +/// Stores values as native types; sums are accumulated with wrapping addition +/// +/// F: Function that calculates the average value from a sum of +/// T::Native and a total count +#[derive(Debug)] +struct AvgGroupsAccumulator +where + T: ArrowNumericType + Send, + F: Fn(T::Native, u64) -> Result + Send, +{ + /// The type of the internal sum + sum_data_type: DataType, + + /// The type of the returned average + return_data_type: DataType, + + /// Count per group (use u64 to make UInt64Array) + counts: Vec, + + /// Sums per group, stored as the native type + sums: Vec, + + /// Track nulls in the input / filters + null_state: NullState, + + /// Function that computes the final average (value / count) + avg_fn: F, +} + +impl AvgGroupsAccumulator +where + T: ArrowNumericType + Send, + F: Fn(T::Native, u64) -> Result + Send, +{ + pub fn new(sum_data_type: &DataType, return_data_type: &DataType, avg_fn: F) -> Self { + debug!( + "AvgGroupsAccumulator ({}, sum type: {sum_data_type:?}) --> {return_data_type:?}", + std::any::type_name::() + ); + + Self { + return_data_type: return_data_type.clone(), + sum_data_type: sum_data_type.clone(), + counts: vec![], + sums: vec![], + null_state: NullState::new(), + avg_fn, + } + } +} + +impl GroupsAccumulator for AvgGroupsAccumulator +where + T: ArrowNumericType + Send, + F: Fn(T::Native, u64) -> Result + Send, +{ + fn update_batch( + 
&mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&arrow_array::BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 1, "single argument to update_batch"); + let values = values.get(0).unwrap().as_primitive::(); + + // increment counts, update sums + self.counts.resize(total_num_groups, 0); + self.sums.resize(total_num_groups, T::default_value()); + self.null_state.accumulate( + group_indices, + values, + opt_filter, + total_num_groups, + |group_index, new_value| { + let sum = &mut self.sums[group_index]; + *sum = sum.add_wrapping(new_value); + + self.counts[group_index] += 1; + }, + ); + + Ok(()) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&arrow_array::BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 2, "two arguments to merge_batch"); + // first batch is counts, second is partial sums + let partial_counts = values.get(0).unwrap().as_primitive::(); + let partial_sums = values.get(1).unwrap().as_primitive::(); + // update counts with partial counts + self.counts.resize(total_num_groups, 0); + self.null_state.accumulate( + group_indices, + partial_counts, + opt_filter, + total_num_groups, + |group_index, partial_count| { + self.counts[group_index] += partial_count; + }, + ); + + // update sums + self.sums + .resize_with(total_num_groups, || T::default_value()); + self.null_state.accumulate( + group_indices, + partial_sums, + opt_filter, + total_num_groups, + |group_index, new_value: ::Native| { + let sum = &mut self.sums[group_index]; + *sum = sum.add_wrapping(new_value); + }, + ); + + Ok(()) + } + + fn evaluate(&mut self) -> Result { + let counts = std::mem::take(&mut self.counts); + let sums = std::mem::take(&mut self.sums); + let nulls = self.null_state.build(); + + assert_eq!(counts.len(), sums.len()); + + // don't evaluate averages with null inputs to avoid errors on null values + let array: 
PrimitiveArray = if let Some(nulls) = nulls.as_ref() { + assert_eq!(nulls.len(), sums.len()); + let mut builder = PrimitiveBuilder::::with_capacity(nulls.len()); + let iter = sums.into_iter().zip(counts.into_iter()).zip(nulls.iter()); + + for ((sum, count), is_valid) in iter { + if is_valid { + builder.append_value((self.avg_fn)(sum, count)?) + } else { + builder.append_null(); + } + } + builder.finish() + } else { + let averages: Vec = sums + .into_iter() + .zip(counts.into_iter()) + .map(|(sum, count)| (self.avg_fn)(sum, count)) + .collect::>>()?; + PrimitiveArray::new(averages.into(), nulls) // no copy + }; + + // fix up decimal precision and scale for decimals + let array = adjust_output_array(&self.return_data_type, Arc::new(array))?; + + Ok(array) + } + + // return arrays for sums and counts + fn state(&mut self) -> Result> { + let nulls = self.null_state.build(); + let counts = std::mem::take(&mut self.counts); + let counts = UInt64Array::from(counts); // zero copy + + let sums = std::mem::take(&mut self.sums); + let sums = PrimitiveArray::::new(sums.into(), nulls); // zero copy + let sums = adjust_output_array(&self.sum_data_type, Arc::new(sums))?; + + Ok(vec![ + Arc::new(counts) as ArrayRef, + Arc::new(sums) as ArrayRef, + ]) + } + + fn size(&self) -> usize { + self.counts.capacity() * std::mem::size_of::() + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs b/datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs index 4bbe563edce8..9f239a09ddef 100644 --- a/datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs +++ b/datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs @@ -1,5 +1,5 @@ // Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file +// or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. 
The ASF licenses this file // to you under the Apache License, Version 2.0 (the @@ -15,15 +15,18 @@ // specific language governing permissions and limitations // under the License. -//! Defines physical expressions that can evaluated at runtime during query execution +//! Defines BitAnd, BitOr, and BitXor Aggregate accumulators use ahash::RandomState; use std::any::Any; use std::convert::TryFrom; use std::sync::Arc; -use crate::{AggregateExpr, PhysicalExpr}; -use arrow::datatypes::DataType; +use crate::{AggregateExpr, GroupsAccumulator, PhysicalExpr}; +use arrow::datatypes::{ + DataType, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, + UInt64Type, UInt8Type, +}; use arrow::{ array::{ ArrayRef, Int16Array, Int32Array, Int64Array, Int8Array, UInt16Array, @@ -35,6 +38,7 @@ use datafusion_common::{downcast_value, DataFusionError, Result, ScalarValue}; use datafusion_expr::Accumulator; use std::collections::HashSet; +use crate::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; use crate::aggregate::row_accumulator::{ is_row_accumulator_support_dtype, RowAccumulator, }; @@ -44,6 +48,18 @@ use arrow::array::Array; use arrow::compute::{bit_and, bit_or, bit_xor}; use datafusion_row::accessor::RowAccessor; +/// Creates a [`PrimitiveGroupsAccumulator`] with the specified +/// [`ArrowPrimitiveType`] which applies `$FN` to each element +/// +/// [`ArrowPrimitiveType`]: arrow::datatypes::ArrowPrimitiveType +macro_rules! instantiate_primitive_accumulator { + ($PRIMTYPE:ident, $FN:expr) => {{ + Ok(Box::new(PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new( + $FN, + ))) + }}; +} + // returns the new value after bit_and/bit_or/bit_xor with the new values, taking nullability into account macro_rules! 
typed_bit_and_or_xor_batch { ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident) => {{ @@ -254,6 +270,46 @@ impl AggregateExpr for BitAnd { ))) } + fn groups_accumulator_supported(&self) -> bool { + true + } + + fn create_groups_accumulator(&self) -> Result> { + use std::ops::BitAndAssign; + match self.data_type { + DataType::Int8 => { + instantiate_primitive_accumulator!(Int8Type, |x, y| x.bitand_assign(y)) + } + DataType::Int16 => { + instantiate_primitive_accumulator!(Int16Type, |x, y| x.bitand_assign(y)) + } + DataType::Int32 => { + instantiate_primitive_accumulator!(Int32Type, |x, y| x.bitand_assign(y)) + } + DataType::Int64 => { + instantiate_primitive_accumulator!(Int64Type, |x, y| x.bitand_assign(y)) + } + DataType::UInt8 => { + instantiate_primitive_accumulator!(UInt8Type, |x, y| x.bitand_assign(y)) + } + DataType::UInt16 => { + instantiate_primitive_accumulator!(UInt16Type, |x, y| x.bitand_assign(y)) + } + DataType::UInt32 => { + instantiate_primitive_accumulator!(UInt32Type, |x, y| x.bitand_assign(y)) + } + DataType::UInt64 => { + instantiate_primitive_accumulator!(UInt64Type, |x, y| x.bitand_assign(y)) + } + + _ => Err(DataFusionError::NotImplemented(format!( + "GroupsAccumulator not supported for {} with {}", + self.name(), + self.data_type + ))), + } + } + fn reverse_expr(&self) -> Option> { Some(Arc::new(self.clone())) } @@ -444,6 +500,46 @@ impl AggregateExpr for BitOr { ))) } + fn groups_accumulator_supported(&self) -> bool { + true + } + + fn create_groups_accumulator(&self) -> Result> { + use std::ops::BitOrAssign; + match self.data_type { + DataType::Int8 => { + instantiate_primitive_accumulator!(Int8Type, |x, y| x.bitor_assign(y)) + } + DataType::Int16 => { + instantiate_primitive_accumulator!(Int16Type, |x, y| x.bitor_assign(y)) + } + DataType::Int32 => { + instantiate_primitive_accumulator!(Int32Type, |x, y| x.bitor_assign(y)) + } + DataType::Int64 => { + instantiate_primitive_accumulator!(Int64Type, |x, y| x.bitor_assign(y)) + } + 
DataType::UInt8 => { + instantiate_primitive_accumulator!(UInt8Type, |x, y| x.bitor_assign(y)) + } + DataType::UInt16 => { + instantiate_primitive_accumulator!(UInt16Type, |x, y| x.bitor_assign(y)) + } + DataType::UInt32 => { + instantiate_primitive_accumulator!(UInt32Type, |x, y| x.bitor_assign(y)) + } + DataType::UInt64 => { + instantiate_primitive_accumulator!(UInt64Type, |x, y| x.bitor_assign(y)) + } + + _ => Err(DataFusionError::NotImplemented(format!( + "GroupsAccumulator not supported for {} with {}", + self.name(), + self.data_type + ))), + } + } + fn reverse_expr(&self) -> Option> { Some(Arc::new(self.clone())) } @@ -635,6 +731,46 @@ impl AggregateExpr for BitXor { ))) } + fn groups_accumulator_supported(&self) -> bool { + true + } + + fn create_groups_accumulator(&self) -> Result> { + use std::ops::BitXorAssign; + match self.data_type { + DataType::Int8 => { + instantiate_primitive_accumulator!(Int8Type, |x, y| x.bitxor_assign(y)) + } + DataType::Int16 => { + instantiate_primitive_accumulator!(Int16Type, |x, y| x.bitxor_assign(y)) + } + DataType::Int32 => { + instantiate_primitive_accumulator!(Int32Type, |x, y| x.bitxor_assign(y)) + } + DataType::Int64 => { + instantiate_primitive_accumulator!(Int64Type, |x, y| x.bitxor_assign(y)) + } + DataType::UInt8 => { + instantiate_primitive_accumulator!(UInt8Type, |x, y| x.bitxor_assign(y)) + } + DataType::UInt16 => { + instantiate_primitive_accumulator!(UInt16Type, |x, y| x.bitxor_assign(y)) + } + DataType::UInt32 => { + instantiate_primitive_accumulator!(UInt32Type, |x, y| x.bitxor_assign(y)) + } + DataType::UInt64 => { + instantiate_primitive_accumulator!(UInt64Type, |x, y| x.bitxor_assign(y)) + } + + _ => Err(DataFusionError::NotImplemented(format!( + "GroupsAccumulator not supported for {} with {}", + self.name(), + self.data_type + ))), + } + } + fn reverse_expr(&self) -> Option> { Some(Arc::new(self.clone())) } diff --git a/datafusion/physical-expr/src/aggregate/bool_and_or.rs 
b/datafusion/physical-expr/src/aggregate/bool_and_or.rs index e444dc61ee1b..6107b0972c81 100644 --- a/datafusion/physical-expr/src/aggregate/bool_and_or.rs +++ b/datafusion/physical-expr/src/aggregate/bool_and_or.rs @@ -17,10 +17,7 @@ //! Defines physical expressions that can evaluated at runtime during query execution -use std::any::Any; -use std::sync::Arc; - -use crate::{AggregateExpr, PhysicalExpr}; +use crate::{AggregateExpr, GroupsAccumulator, PhysicalExpr}; use arrow::datatypes::DataType; use arrow::{ array::{ArrayRef, BooleanArray}, @@ -28,7 +25,10 @@ use arrow::{ }; use datafusion_common::{downcast_value, DataFusionError, Result, ScalarValue}; use datafusion_expr::Accumulator; +use std::any::Any; +use std::sync::Arc; +use crate::aggregate::groups_accumulator::bool_op::BooleanGroupsAccumulator; use crate::aggregate::row_accumulator::{ is_row_accumulator_support_dtype, RowAccumulator, }; @@ -193,6 +193,23 @@ impl AggregateExpr for BoolAnd { ))) } + fn groups_accumulator_supported(&self) -> bool { + true + } + + fn create_groups_accumulator(&self) -> Result> { + match self.data_type { + DataType::Boolean => { + Ok(Box::new(BooleanGroupsAccumulator::new(|x, y| x && y))) + } + _ => Err(DataFusionError::NotImplemented(format!( + "GroupsAccumulator not supported for {} with {}", + self.name(), + self.data_type + ))), + } + } + fn reverse_expr(&self) -> Option> { Some(Arc::new(self.clone())) } @@ -381,6 +398,23 @@ impl AggregateExpr for BoolOr { ))) } + fn groups_accumulator_supported(&self) -> bool { + true + } + + fn create_groups_accumulator(&self) -> Result> { + match self.data_type { + DataType::Boolean => { + Ok(Box::new(BooleanGroupsAccumulator::new(|x, y| x || y))) + } + _ => Err(DataFusionError::NotImplemented(format!( + "GroupsAccumulator not supported for {} with {}", + self.name(), + self.data_type + ))), + } + } + fn reverse_expr(&self) -> Option> { Some(Arc::new(self.clone())) } diff --git a/datafusion/physical-expr/src/aggregate/count.rs 
b/datafusion/physical-expr/src/aggregate/count.rs index 22cb2512fc42..37a756894c72 100644 --- a/datafusion/physical-expr/src/aggregate/count.rs +++ b/datafusion/physical-expr/src/aggregate/count.rs @@ -24,11 +24,14 @@ use std::sync::Arc; use crate::aggregate::row_accumulator::RowAccumulator; use crate::aggregate::utils::down_cast_any_ref; -use crate::{AggregateExpr, PhysicalExpr}; +use crate::{AggregateExpr, GroupsAccumulator, PhysicalExpr}; use arrow::array::{Array, Int64Array}; use arrow::compute; use arrow::datatypes::DataType; use arrow::{array::ArrayRef, datatypes::Field}; +use arrow_array::cast::AsArray; +use arrow_array::types::Int64Type; +use arrow_array::PrimitiveArray; use arrow_buffer::BooleanBuffer; use datafusion_common::{downcast_value, ScalarValue}; use datafusion_common::{DataFusionError, Result}; @@ -37,6 +40,8 @@ use datafusion_row::accessor::RowAccessor; use crate::expressions::format_state_name; +use super::groups_accumulator::accumulate::accumulate_indices; + /// COUNT aggregate expression /// Returns the amount of non-null values of the given expression. #[derive(Debug, Clone)] @@ -44,6 +49,10 @@ pub struct Count { name: String, data_type: DataType, nullable: bool, + /// Input exprs + /// + /// For `COUNT(c1)` this is `[c1]` + /// For `COUNT(c1, c2)` this is `[c1, c2]` exprs: Vec>, } @@ -76,6 +85,109 @@ impl Count { } } +/// An accumulator to compute the counts of [`PrimitiveArray`]. +/// Stores values as native types, and does overflow checking +/// +/// Unlike most other accumulators, COUNT never produces NULLs. If no +/// non-null values are seen in any group the output is 0. Thus, this +/// accumulator has no additional null or seen filter tracking. 
+#[derive(Debug)] +struct CountGroupsAccumulator { + /// Count per group (use i64 to make Int64Array) + counts: Vec, +} + +impl CountGroupsAccumulator { + pub fn new() -> Self { + Self { counts: vec![] } + } +} + +impl GroupsAccumulator for CountGroupsAccumulator { + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&arrow_array::BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 1, "single argument to update_batch"); + let values = values.get(0).unwrap(); + + // Add one to each group's counter for each non null, non + // filtered value + self.counts.resize(total_num_groups, 0); + accumulate_indices( + group_indices, + values.nulls(), // ignore values + opt_filter, + |group_index| { + self.counts[group_index] += 1; + }, + ); + + Ok(()) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&arrow_array::BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 1, "one argument to merge_batch"); + // first batch is counts, second is partial sums + let partial_counts = values.get(0).unwrap().as_primitive::(); + + // intermediate counts are always created as non null + assert_eq!(partial_counts.null_count(), 0); + let partial_counts = partial_counts.values(); + + // Adds the counts with the partial counts + self.counts.resize(total_num_groups, 0); + match opt_filter { + Some(filter) => filter + .iter() + .zip(group_indices.iter()) + .zip(partial_counts.iter()) + .for_each(|((filter_value, &group_index), partial_count)| { + if let Some(true) = filter_value { + self.counts[group_index] += partial_count; + } + }), + None => group_indices.iter().zip(partial_counts.iter()).for_each( + |(&group_index, partial_count)| { + self.counts[group_index] += partial_count; + }, + ), + } + + Ok(()) + } + + fn evaluate(&mut self) -> Result { + let counts = std::mem::take(&mut self.counts); + + // Count is always non 
null (null inputs just don't contribute to the overall values) + let nulls = None; + let array = PrimitiveArray::::new(counts.into(), nulls); + + Ok(Arc::new(array)) + } + + // return arrays for sums and counts + fn state(&mut self) -> Result> { + let counts = std::mem::take(&mut self.counts); + let counts: PrimitiveArray = Int64Array::from(counts); // zero copy, no nulls + Ok(vec![Arc::new(counts) as ArrayRef]) + } + + fn size(&self) -> usize { + self.counts.capacity() * std::mem::size_of::() + } +} + /// count null values for multiple columns /// for each row if one column value is null, then null_count + 1 fn null_count_for_multiple_cols(values: &[ArrayRef]) -> usize { @@ -133,6 +245,13 @@ impl AggregateExpr for Count { true } + fn groups_accumulator_supported(&self) -> bool { + // groups accumulator only supports `COUNT(c1)`, not + // `COUNT(c1, c2)`, etc + // TODO file a ticket to optimize + self.exprs.len() == 1 + } + fn create_row_accumulator( &self, start_index: usize, @@ -147,6 +266,11 @@ impl AggregateExpr for Count { fn create_sliding_accumulator(&self) -> Result> { Ok(Box::new(CountAccumulator::new())) } + + fn create_groups_accumulator(&self) -> Result> { + // instantiate specialized accumulator + Ok(Box::new(CountGroupsAccumulator::new())) + } } impl PartialEq for Count { diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs new file mode 100644 index 000000000000..a9627ece7c43 --- /dev/null +++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs @@ -0,0 +1,879 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`GroupsAccumulator`] helpers: [`NullState`] and [`accumulate_indices`] +//! +//! These functions are designed to be the performance critical inner +//! loops of [`GroupsAccumulator`], so there are multiple type +//! specific methods, invoked depending on the input. +//! +//! [`GroupsAccumulator`]: crate::GroupsAccumulator + +use arrow::datatypes::ArrowPrimitiveType; +use arrow_array::{Array, BooleanArray, PrimitiveArray}; +use arrow_buffer::{BooleanBufferBuilder, NullBuffer}; + +/// Track the accumulator null state per group: if any values for that +/// group were null and if any values have been seen at all for that group. +/// +/// This is part of the inner loop for many GroupsAccumulators, and +/// thus the performance is critical. +/// +/// There are typically 4 potential combinations of input values that +/// accumulators need to special case for performance. +/// +/// GroupsAccumulators need to handle all four combinations of: +/// +/// * With / Without filter +/// * With / Without nulls in the input +/// +/// If there are filters present, `NullState` tracks if it has seen +/// *any* value for that group (as some values may be filtered +/// out). Without a filter, the accumulator is only passed groups +/// that actually had a value to accumulate so they do not need to +/// track if they have seen values for a particular group. 
+/// +/// If the input has nulls, then the accumulator must potentially +/// handle each input null value specially (e.g. for `SUM` to mark the +/// corresponding sum as null) +#[derive(Debug)] +pub struct NullState { + /// Tracks if a null input value has been seen for `group_index`, + /// if there were any nulls in the input. + /// + /// If `null_inputs[i]` is true, have not seen any null values for + /// that group, or have not seen any values + /// + /// If `null_inputs[i]` is false, saw at least one null value for + /// that group + null_inputs: Option, + + /// If there has been a filter value, has it seen any non-filtered + /// input values for `group_index`? + /// + /// If `seen_values[i]` is true, it has seen at least one value + /// that passed the filter for this group + /// + /// If `seen_values[i]` is false, have not seen any values that + /// pass the filter yet for the group + seen_values: Option, +} + +impl NullState { + pub fn new() -> Self { + Self { + null_inputs: None, + seen_values: None, + } + } + + /// Invokes `value_fn(group_index, value)` for each non null, non + /// filtered value, while tracking which groups have seen null + /// inputs and which groups have seen any inputs + /// + /// # Arguments: + /// + /// * `values`: the input arguments to the accumulator + /// * `group_indices`: To which groups do the rows in `values` belong, (aka group_index) + /// * `opt_filter`: if present, only rows for which the filter is Some(true) are included + /// * `value_fn`: function invoked for (group_index, value) where value is non null + /// + /// `F`: Invoked for each input row like `value_fn(group_index, + /// value)` for each non null, non filtered value. 
+ /// + /// # Example + /// + /// ```text + /// ┌─────────┐ ┌─────────┐ ┌ ─ ─ ─ ─ ┐ + /// │ ┌─────┐ │ │ ┌─────┐ │ ┌─────┐ + /// │ │ 2 │ │ │ │ 200 │ │ │ │ t │ │ + /// │ ├─────┤ │ │ ├─────┤ │ ├─────┤ + /// │ │ 2 │ │ │ │ 100 │ │ │ │ f │ │ + /// │ ├─────┤ │ │ ├─────┤ │ ├─────┤ + /// │ │ 0 │ │ │ │ 200 │ │ │ │ t │ │ + /// │ ├─────┤ │ │ ├─────┤ │ ├─────┤ + /// │ │ 1 │ │ │ │ 200 │ │ │ │NULL │ │ + /// │ ├─────┤ │ │ ├─────┤ │ ├─────┤ + /// │ │ 0 │ │ │ │ 300 │ │ │ │ t │ │ + /// │ └─────┘ │ │ └─────┘ │ └─────┘ + /// └─────────┘ └─────────┘ └ ─ ─ ─ ─ ┘ + /// + /// group_indices values opt_filter + /// ``` + /// + /// In the example above, `value_fn` is invoked for each (group_index, + /// value) pair where `opt_filter[i]` is true + /// + /// ```text + /// value_fn(2, 200) + /// value_fn(0, 200) + /// value_fn(0, 300) + /// ``` + /// + /// It also sets + /// + /// 1. `self.seen_values[group_index]` to true for all groups that had a value pass the filter, if there is a filter + /// + /// 2. `self.null_inputs[group_index]` to false for all groups that had a null in the input + pub fn accumulate( + &mut self, + group_indices: &[usize], + values: &PrimitiveArray, + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + mut value_fn: F, + ) where + T: ArrowPrimitiveType + Send, + F: FnMut(usize, T::Native) + Send, + { + let data: &[T::Native] = values.values(); + assert_eq!(data.len(), group_indices.len()); + + match (values.null_count() > 0, opt_filter) { + // no nulls, no filter, + (false, None) => { + // if we have previously seen nulls, ensure the null + // buffer is big enough (start everything at valid) + if self.null_inputs.is_some() { + initialize_builder(&mut self.null_inputs, total_num_groups, true); + } + let iter = group_indices.iter().zip(data.iter()); + for (&group_index, &new_value) in iter { + value_fn(group_index, new_value) + } + } + // nulls, no filter + (true, None) => { + let nulls = values.nulls().unwrap(); + // All groups start as valid (true), and are set to + // null if we 
see a null in the input) + let null_inputs = + initialize_builder(&mut self.null_inputs, total_num_groups, true); + + // This is based on (ahem, COPY/PASTA) arrow::compute::aggregate::sum + // iterate over in chunks of 64 bits for more efficient null checking + let group_indices_chunks = group_indices.chunks_exact(64); + let data_chunks = data.chunks_exact(64); + let bit_chunks = nulls.inner().bit_chunks(); + + let group_indices_remainder = group_indices_chunks.remainder(); + let data_remainder = data_chunks.remainder(); + + group_indices_chunks + .zip(data_chunks) + .zip(bit_chunks.iter()) + .for_each(|((group_index_chunk, data_chunk), mask)| { + // index_mask has value 1 << i in the loop + let mut index_mask = 1; + group_index_chunk.iter().zip(data_chunk.iter()).for_each( + |(&group_index, &new_value)| { + // valid bit was set, real vale + let is_valid = (mask & index_mask) != 0; + if is_valid { + value_fn(group_index, new_value); + } else { + // input null means this group is now null + null_inputs.set_bit(group_index, false); + } + index_mask <<= 1; + }, + ) + }); + + // handle any remaining bits (after the intial 64) + let remainder_bits = bit_chunks.remainder_bits(); + group_indices_remainder + .iter() + .zip(data_remainder.iter()) + .enumerate() + .for_each(|(i, (&group_index, &new_value))| { + let is_valid = remainder_bits & (1 << i) != 0; + if is_valid { + value_fn(group_index, new_value); + } else { + // input null means this group is now null + null_inputs.set_bit(group_index, false); + } + }); + } + // no nulls, but a filter + (false, Some(filter)) => { + assert_eq!(filter.len(), group_indices.len()); + + // default seen to false (we fill it in as we go) + let seen_values = + initialize_builder(&mut self.seen_values, total_num_groups, false); + // The performance with a filter could be improved by + // iterating over the filter in chunks, rather than a single + // iterator. 
TODO file a ticket + group_indices + .iter() + .zip(data.iter()) + .zip(filter.iter()) + .for_each(|((&group_index, &new_value), filter_value)| { + if let Some(true) = filter_value { + value_fn(group_index, new_value); + // remember we have seen a value for this index + seen_values.set_bit(group_index, true); + } + }) + } + // both null values and filters + (true, Some(filter)) => { + let null_inputs = + initialize_builder(&mut self.null_inputs, total_num_groups, true); + let seen_values = + initialize_builder(&mut self.seen_values, total_num_groups, false); + + assert_eq!(filter.len(), group_indices.len()); + // The performance with a filter could be improved by + // iterating over the filter in chunks, rather than using + // iterators. TODO file a ticket + filter + .iter() + .zip(group_indices.iter()) + .zip(values.iter()) + .for_each(|((filter_value, group_index), new_value)| { + if let Some(true) = filter_value { + if let Some(new_value) = new_value { + value_fn(*group_index, new_value) + } else { + // input null means this group is now null + null_inputs.set_bit(*group_index, false); + } + // remember we have seen a value for this index + seen_values.set_bit(*group_index, true); + } + }) + } + } + } + + /// Invokes `value_fn(group_index, value)` for each non null, non + /// filtered value in `values`, while tracking which groups have + /// seen null inputs and which groups have seen any inputs, for + /// [`BooleanArray`]s. + /// + /// See [`Self::accumulate`], which handles [`PrimitiveArray`]s, + /// for more details. 
+ pub fn accumulate_boolean( + &mut self, + group_indices: &[usize], + values: &BooleanArray, + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + mut value_fn: F, + ) where + F: FnMut(usize, bool) + Send, + { + let data = values.values(); + assert_eq!(data.len(), group_indices.len()); + + // These could be made more performant by iterating in chunks of 64 bits at a time + match (values.null_count() > 0, opt_filter) { + // no nulls, no filter, + (false, None) => { + // if we have previously seen nulls, ensure the null + // buffer is big enough (start everything at valid) + if self.null_inputs.is_some() { + initialize_builder(&mut self.null_inputs, total_num_groups, true); + } + group_indices.iter().zip(data.iter()).for_each( + |(&group_index, new_value)| value_fn(group_index, new_value), + ) + } + // nulls, no filter + (true, None) => { + let nulls = values.nulls().unwrap(); + // All groups start as valid (true), and are set to + // null if we see a null in the input) + let null_inputs = + initialize_builder(&mut self.null_inputs, total_num_groups, true); + + group_indices + .iter() + .zip(data.iter()) + .zip(nulls.iter()) + .for_each(|((&group_index, new_value), is_valid)| { + if is_valid { + value_fn(group_index, new_value); + } else { + // input null means this group is now null + null_inputs.set_bit(group_index, false); + } + }) + } + // no nulls, but a filter + (false, Some(filter)) => { + assert_eq!(filter.len(), group_indices.len()); + + // default seen to false (we fill it in as we go) + let seen_values = + initialize_builder(&mut self.seen_values, total_num_groups, false); + + group_indices + .iter() + .zip(data.iter()) + .zip(filter.iter()) + .for_each(|((&group_index, new_value), filter_value)| { + if let Some(true) = filter_value { + value_fn(group_index, new_value); + // remember we have seen a value for this index + seen_values.set_bit(group_index, true); + } + }) + } + // both null values and filters + (true, Some(filter)) => { + let 
null_inputs = + initialize_builder(&mut self.null_inputs, total_num_groups, true); + let seen_values = + initialize_builder(&mut self.seen_values, total_num_groups, false); + + assert_eq!(filter.len(), group_indices.len()); + filter + .iter() + .zip(group_indices.iter()) + .zip(values.iter()) + .for_each(|((filter_value, group_index), new_value)| { + if let Some(true) = filter_value { + if let Some(new_value) = new_value { + value_fn(*group_index, new_value) + } else { + // input null means this group is now null + null_inputs.set_bit(*group_index, false); + } + // remember we have seen a value for this index + seen_values.set_bit(*group_index, true); + } + }) + } + } + } + + /// Creates the final NullBuffer representing which group_indices have + /// null values (if they saw a null input, or because they never saw any values) + /// + /// resets the internal state to empty + /// + /// nulls (validity) set false for any group that saw a null + /// seen_values (validity) set true for any group that saw a value + pub fn build(&mut self) -> Option { + let nulls = self + .null_inputs + .as_mut() + .map(|null_inputs| NullBuffer::new(null_inputs.finish())) + .and_then(|nulls| { + if nulls.null_count() > 0 { + Some(nulls) + } else { + None + } + }); + + // if we had filters, some groups may never have seen a value + // so they are only non-null if we have seen values + let seen_values = self + .seen_values + .as_mut() + .map(|seen_values| NullBuffer::new(seen_values.finish())); + + match (nulls, seen_values) { + (None, None) => None, + (Some(nulls), None) => Some(nulls), + (None, Some(seen_values)) => Some(seen_values), + (Some(seen_values), Some(nulls)) => { + NullBuffer::union(Some(&seen_values), Some(&nulls)) + } + } + } +} + +/// This function is called to update the accumulator state per row +/// when the value is not needed (e.g. COUNT) +/// +/// `F`: Invoked like `index_fn(group_index)` for all non null values +/// passing the filter. 
Note that no tracking is done for null inputs +/// or which groups have seen any values +pub fn accumulate_indices( + group_indices: &[usize], + nulls: Option<&NullBuffer>, + opt_filter: Option<&BooleanArray>, + mut index_fn: F, +) where + F: FnMut(usize) + Send, +{ + match (nulls, opt_filter) { + (None, None) => { + for &group_index in group_indices.iter() { + index_fn(group_index) + } + } + (None, Some(filter)) => { + assert_eq!(filter.len(), group_indices.len()); + // The performance with a filter could be improved by + // iterating over the filter in chunks, rather than a single + // iterator. TODO file a ticket + let iter = group_indices.iter().zip(filter.iter()); + for (&group_index, filter_value) in iter { + if let Some(true) = filter_value { + index_fn(group_index) + } + } + } + (Some(valids), None) => { + assert_eq!(valids.len(), group_indices.len()); + // This is based on (ahem, COPY/PASTA) arrow::compute::aggregate::sum + // iterate over in chunks of 64 bits for more efficient null checking + let group_indices_chunks = group_indices.chunks_exact(64); + let bit_chunks = valids.inner().bit_chunks(); + + let group_indices_remainder = group_indices_chunks.remainder(); + + group_indices_chunks.zip(bit_chunks.iter()).for_each( + |(group_index_chunk, mask)| { + // index_mask has value 1 << i in the loop + let mut index_mask = 1; + group_index_chunk.iter().for_each(|&group_index| { + // valid bit was set, real vale + let is_valid = (mask & index_mask) != 0; + if is_valid { + index_fn(group_index); + } + index_mask <<= 1; + }) + }, + ); + + // handle any remaining bits (after the intial 64) + let remainder_bits = bit_chunks.remainder_bits(); + group_indices_remainder + .iter() + .enumerate() + .for_each(|(i, &group_index)| { + let is_valid = remainder_bits & (1 << i) != 0; + if is_valid { + index_fn(group_index) + } + }); + } + + (Some(valids), Some(filter)) => { + assert_eq!(filter.len(), group_indices.len()); + assert_eq!(valids.len(), group_indices.len()); + 
// The performance with a filter could likely be improved by + // iterating over the filter in chunks, rather than using + // iterators. TODO file a ticket + filter + .iter() + .zip(group_indices.iter()) + .zip(valids.iter()) + .for_each(|((filter_value, &group_index), is_valid)| { + if let (Some(true), true) = (filter_value, is_valid) { + index_fn(group_index) + } + }) + } + } +} + +/// Enures that `builder` contains a `BooleanBufferBuilder with at +/// least `total_num_groups`. +/// +/// All new entries are initialized to `default_value` +fn initialize_builder( + builder: &mut Option, + total_num_groups: usize, + default_value: bool, +) -> &mut BooleanBufferBuilder { + if builder.is_none() { + *builder = Some(BooleanBufferBuilder::new(total_num_groups)); + } + let builder = builder.as_mut().unwrap(); + + if builder.len() < total_num_groups { + let new_groups = total_num_groups - builder.len(); + builder.append_n(new_groups, default_value); + } + builder +} + +#[cfg(test)] +mod test { + use super::*; + + use arrow_array::UInt32Array; + use hashbrown::HashSet; + use rand::{rngs::ThreadRng, Rng}; + + #[test] + fn accumulate() { + let group_indices = (0..100).collect(); + let values = (0..100).map(|i| (i + 1) * 10).collect(); + let values_with_nulls = (0..100) + .map(|i| if i % 3 == 0 { None } else { Some((i + 1) * 10) }) + .collect(); + + // default to every fifth value being false, every even + // being null + let filter: BooleanArray = (0..100) + .map(|i| { + let is_even = i % 2 == 0; + let is_fifth = i % 5 == 0; + if is_even { + None + } else if is_fifth { + Some(false) + } else { + Some(true) + } + }) + .collect(); + + Fixture { + group_indices, + values, + values_with_nulls, + filter, + } + .run() + } + + #[test] + fn accumulate_fuzz() { + let mut rng = rand::thread_rng(); + for _ in 0..100 { + Fixture::new_random(&mut rng).run(); + } + } + + /// Values for testing (there are enough values to exercise the 64 bit chunks + struct Fixture { + /// 100..0 + 
group_indices: Vec, + + /// 10, 20, ... 1010 + values: Vec, + + /// same as values, but every third is null: + /// None, Some(20), Some(30), None ... + values_with_nulls: Vec>, + + /// filter (defaults to None) + filter: BooleanArray, + } + + impl Fixture { + fn new_random(rng: &mut ThreadRng) -> Self { + // Number of input values in a batch + let num_values: usize = rng.gen_range(1..200); + // number of distinct groups + let num_groups: usize = rng.gen_range(2..1000); + let max_group = num_groups - 1; + + let group_indices: Vec = (0..num_values) + .map(|_| rng.gen_range(0..max_group)) + .collect(); + + let values: Vec = (0..num_values).map(|_| rng.gen()).collect(); + + // 10% chance of false + // 10% change of null + // 80% chance of true + let filter: BooleanArray = (0..num_values) + .map(|_| { + let filter_value = rng.gen_range(0.0..1.0); + if filter_value < 0.1 { + Some(false) + } else if filter_value < 0.2 { + None + } else { + Some(true) + } + }) + .collect(); + + // random values with random number and location of nulls + // random null percentage + let null_pct: f32 = rng.gen_range(0.0..1.0); + let values_with_nulls: Vec> = (0..num_values) + .map(|_| { + let is_null = null_pct < rng.gen_range(0.0..1.0); + if is_null { + None + } else { + Some(rng.gen()) + } + }) + .collect(); + + Self { + group_indices, + values, + values_with_nulls, + filter, + } + } + + /// returns `Self::values` an Array + fn values_array(&self) -> UInt32Array { + UInt32Array::from(self.values.clone()) + } + + /// returns `Self::values_with_nulls` as an Array + fn values_with_nulls_array(&self) -> UInt32Array { + UInt32Array::from(self.values_with_nulls.clone()) + } + + /// Calls `NullState::accumulate` and `accumulate_indices` + /// with all combinations of nulls and filter values + fn run(&self) { + let total_num_groups = *self.group_indices.iter().max().unwrap() + 1; + + let group_indices = &self.group_indices; + let values_array = self.values_array(); + let values_with_nulls_array = 
self.values_with_nulls_array(); + let filter = &self.filter; + + // no null, no filters + Self::accumulate_test(group_indices, &values_array, None, total_num_groups); + + // nulls, no filters + Self::accumulate_test( + group_indices, + &values_with_nulls_array, + None, + total_num_groups, + ); + + // no nulls, filters + Self::accumulate_test( + group_indices, + &values_array, + Some(filter), + total_num_groups, + ); + + // nulls, filters + Self::accumulate_test( + group_indices, + &values_with_nulls_array, + Some(filter), + total_num_groups, + ); + } + + /// Calls `NullState::accumulate` and `accumulate_indices` to + /// ensure it generates the correct values. + /// + fn accumulate_test( + group_indices: &[usize], + values: &UInt32Array, + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) { + Self::accumulate_values_test( + group_indices, + values, + opt_filter, + total_num_groups, + ); + Self::accumulate_indices_test(group_indices, values.nulls(), opt_filter); + } + + /// This is effectively a different implementation of + /// accumulate that we compare with the above implementation + fn accumulate_values_test( + group_indices: &[usize], + values: &UInt32Array, + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) { + let mut accumulated_values = vec![]; + let mut null_state = NullState::new(); + + null_state.accumulate( + group_indices, + values, + opt_filter, + total_num_groups, + |group_index, value| { + accumulated_values.push((group_index, value)); + }, + ); + + // Figure out the expected values + let mut expected_values = vec![]; + let mut expected_null_input = HashSet::new(); + let mut expected_seen_values = HashSet::new(); + + match opt_filter { + None => group_indices.iter().zip(values.iter()).for_each( + |(&group_index, value)| { + expected_seen_values.insert(group_index); + if let Some(value) = value { + expected_values.push((group_index, value)); + } else { + expected_null_input.insert(group_index); + } + }, + ), + 
Some(filter) => { + group_indices + .iter() + .zip(values.iter()) + .zip(filter.iter()) + .for_each(|((&group_index, value), is_included)| { + // if value passed filter + if let Some(true) = is_included { + expected_seen_values.insert(group_index); + if let Some(value) = value { + expected_values.push((group_index, value)); + } else { + expected_null_input.insert(group_index); + } + } + }); + } + } + + assert_eq!(accumulated_values, expected_values, + "\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}"); + + // validate null state + if values.null_count() > 0 { + let null_inputs = + null_state.null_inputs.as_ref().unwrap().finish_cloned(); + for (group_index, is_valid) in null_inputs.iter().enumerate() { + let expected_valid = !expected_null_input.contains(&group_index); + assert_eq!( + expected_valid, is_valid, + "mismatch at for group {group_index}" + ); + } + } + + // validate seen_values + + if opt_filter.is_some() { + let seen_values = + null_state.seen_values.as_ref().unwrap().finish_cloned(); + for (group_index, is_seen) in seen_values.iter().enumerate() { + let expected_seen = expected_seen_values.contains(&group_index); + assert_eq!( + expected_seen, is_seen, + "mismatch at for group {group_index}" + ); + } + } + + // Validate the final buffer (one value per group) + let expected_null_buffer = + match (values.null_count() > 0, opt_filter.is_some()) { + (false, false) => None, + // only nulls + (true, false) => { + let null_buffer: NullBuffer = (0..total_num_groups) + .map(|group_index| { + // there was and no null inputs + !expected_null_input.contains(&group_index) + }) + .collect(); + Some(null_buffer) + } + // only filter + (false, true) => { + let null_buffer: NullBuffer = (0..total_num_groups) + .map(|group_index| { + // we saw a value + expected_seen_values.contains(&group_index) + }) + .collect(); + Some(null_buffer) + } + // nulls and filter + (true, true) => { + let null_buffer: NullBuffer = 
(0..total_num_groups) + .map(|group_index| { + // output is valid if there was at least one + // input value and no null inputs + expected_seen_values.contains(&group_index) + && !expected_null_input.contains(&group_index) + }) + .collect(); + Some(null_buffer) + } + }; + + let null_buffer = null_state.build(); + + assert_eq!(null_buffer, expected_null_buffer); + } + + // Calls `accumulate_indices` + // and opt_filter and ensures it calls the right values + fn accumulate_indices_test( + group_indices: &[usize], + nulls: Option<&NullBuffer>, + opt_filter: Option<&BooleanArray>, + ) { + let mut accumulated_values = vec![]; + + accumulate_indices(group_indices, nulls, opt_filter, |group_index| { + accumulated_values.push(group_index); + }); + + // Figure out the expected values + let mut expected_values = vec![]; + + match (nulls, opt_filter) { + (None, None) => group_indices.iter().for_each(|&group_index| { + expected_values.push(group_index); + }), + (Some(nulls), None) => group_indices.iter().zip(nulls.iter()).for_each( + |(&group_index, is_valid)| { + if is_valid { + expected_values.push(group_index); + } + }, + ), + (None, Some(filter)) => group_indices.iter().zip(filter.iter()).for_each( + |(&group_index, is_included)| { + if let Some(true) = is_included { + expected_values.push(group_index); + } + }, + ), + (Some(nulls), Some(filter)) => { + group_indices + .iter() + .zip(nulls.iter()) + .zip(filter.iter()) + .for_each(|((&group_index, is_valid), is_included)| { + // if value passed filter + if let (true, Some(true)) = (is_valid, is_included) { + expected_values.push(group_index); + } + }); + } + } + + assert_eq!(accumulated_values, expected_values, + "\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}"); + } + } +} diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs new file mode 100644 index 000000000000..a403a6d584c0 --- 
/dev/null +++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs @@ -0,0 +1,357 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Adapter that makes [`GroupsAccumulator`] out of [`Accumulator`] + +use super::GroupsAccumulator; +use arrow::{ + array::{AsArray, UInt32Builder}, + compute, + datatypes::UInt32Type, +}; +use arrow_array::{ArrayRef, BooleanArray, PrimitiveArray}; +use datafusion_common::{ + utils::get_arrayref_at_indices, DataFusionError, Result, ScalarValue, +}; +use datafusion_expr::Accumulator; + +/// An adpater that implements [`GroupsAccumulator`] for any [`Accumulator`] +/// +/// While [`Accumulator`] are simpler to implement and can support +/// more general calculations (like retractable), but are not as fast +/// as `GroupsAccumulator`. This interface bridges the gap. +pub struct GroupsAccumulatorAdapter { + factory: Box Result> + Send>, + + /// state for each group, stored in group_index order + states: Vec, + + /// Current memory usage, in bytes. + /// + /// Note this is incrementally updated to avoid size() being a + /// bottleneck, which we saw in earlier implementations. 
+ allocation_bytes: usize, +} + +struct AccumulatorState { + /// [`Accumulator`] that stores the per-group state + accumulator: Box, + + // scratch space: indexes in the input array that will be fed to + // this accumulator. Stores indexes as `u32` to match the arrow + // `take` kernel input. + indices: Vec, +} + +impl AccumulatorState { + fn new(accumulator: Box) -> Self { + Self { + accumulator, + indices: vec![], + } + } + + /// Returns the amount of memory taken by this structre and its accumulator + fn size(&self) -> usize { + self.accumulator.size() + + std::mem::size_of_val(self) + + std::mem::size_of::() * self.indices.capacity() + } +} + +impl GroupsAccumulatorAdapter { + /// Create a new adapter that will create a new [`Accumulator`] + /// for each group, using the specified factory function + pub fn new(factory: F) -> Self + where + F: Fn() -> Result> + Send + 'static, + { + let mut new_self = Self { + factory: Box::new(factory), + states: vec![], + allocation_bytes: 0, + }; + new_self.reset_allocation(); + new_self + } + + // Reset the allocation bytes to empty state + fn reset_allocation(&mut self) { + assert!(self.states.is_empty()); + self.allocation_bytes = std::mem::size_of::(); + } + + /// Ensure that self.accumulators has total_num_groups + fn make_accumulators_if_needed(&mut self, total_num_groups: usize) -> Result<()> { + // can't shrink + assert!(total_num_groups >= self.states.len()); + let vec_size_pre = + std::mem::size_of::() * self.states.capacity(); + + // instanatiate new accumulators + let new_accumulators = total_num_groups - self.states.len(); + for _ in 0..new_accumulators { + let accumulator = (self.factory)()?; + let state = AccumulatorState::new(accumulator); + self.allocation_bytes += state.size(); + self.states.push(state); + } + let vec_size_post = + std::mem::size_of::() * self.states.capacity(); + + self.allocation_bytes += vec_size_post.saturating_sub(vec_size_pre); + Ok(()) + } + + /// invokes f(accumulator, values) for 
each group that has values + /// in group_indices. + /// + /// This function first reorders the input and filter so that + /// values for each group_index are contiguous and then invokes f + /// on the contiguous ranges, to minimize per-row overhead + /// + /// ```text + /// ┌─────────┐ ┌─────────┐ ┌ ─ ─ ─ ─ ┐ ┌─────────┐ ┌ ─ ─ ─ ─ ┐ + /// │ ┌─────┐ │ │ ┌─────┐ │ ┌─────┐ ┏━━━━━┓ │ ┌─────┐ │ ┌─────┐ + /// │ │ 2 │ │ │ │ 200 │ │ │ │ t │ │ ┃ 0 ┃ │ │ 200 │ │ │ │ t │ │ + /// │ ├─────┤ │ │ ├─────┤ │ ├─────┤ ┣━━━━━┫ │ ├─────┤ │ ├─────┤ + /// │ │ 2 │ │ │ │ 100 │ │ │ │ f │ │ ┃ 0 ┃ │ │ 300 │ │ │ │ t │ │ + /// │ ├─────┤ │ │ ├─────┤ │ ├─────┤ ┣━━━━━┫ │ ├─────┤ │ ├─────┤ + /// │ │ 0 │ │ │ │ 200 │ │ │ │ t │ │ ┃ 1 ┃ │ │ 200 │ │ │ │NULL │ │ + /// │ ├─────┤ │ │ ├─────┤ │ ├─────┤ ────────▶ ┣━━━━━┫ │ ├─────┤ │ ├─────┤ + /// │ │ 1 │ │ │ │ 200 │ │ │ │NULL │ │ ┃ 2 ┃ │ │ 200 │ │ │ │ t │ │ + /// │ ├─────┤ │ │ ├─────┤ │ ├─────┤ ┣━━━━━┫ │ ├─────┤ │ ├─────┤ + /// │ │ 0 │ │ │ │ 300 │ │ │ │ t │ │ ┃ 2 ┃ │ │ 100 │ │ │ │ f │ │ + /// │ └─────┘ │ │ └─────┘ │ └─────┘ ┗━━━━━┛ │ └─────┘ │ └─────┘ + /// └─────────┘ └─────────┘ └ ─ ─ ─ ─ ┘ └─────────┘ └ ─ ─ ─ ─ ┘ + /// + /// values opt_filter logical group values opt_filter + /// index + /// ``` + fn invoke_per_accumulator( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + f: F, + ) -> Result<()> + where + F: Fn(&mut dyn Accumulator, &[ArrayRef]) -> Result<()>, + { + self.make_accumulators_if_needed(total_num_groups)?; + + // reorderes the input and filter so that values for group_indexes are contiguous. 
+ // Then it invokes Accumulator::update / merge for each of those contiguous ranges + assert_eq!(values[0].len(), group_indices.len()); + + // figure out which input rows correspond to which groups Note + // that self.state.indices empty for all groups always (it is + // cleared out below) + for (idx, group_index) in group_indices.iter().enumerate() { + self.states[*group_index].indices.push(idx as u32); + } + + // groups_per_rows holds a list of group indexes that have + // any rows that need to be accumulated, stored in order of group_index + + let mut groups_with_rows = vec![]; + + // batch_indices holds indices in values, each group contiguously + let mut batch_indices = UInt32Builder::with_capacity(0); + + // offsets[i] is index into batch_indices where the rows for + // group_index i starts + let mut offsets = vec![0]; + + let mut offset_so_far = 0; + for (group_index, state) in self.states.iter_mut().enumerate() { + let indices = &state.indices; + if indices.is_empty() { + continue; + } + + groups_with_rows.push(group_index); + batch_indices.append_slice(indices); + offset_so_far += indices.len(); + offsets.push(offset_so_far); + } + let batch_indices = batch_indices.finish(); + + // reorder the values and opt_filter by batch_indices so that + // all values for each group are contiguous, then invoke the + // accumulator once per group with values + let values = get_arrayref_at_indices(values, &batch_indices)?; + let opt_filter = get_filter_at_indices(opt_filter, &batch_indices)?; + + // invoke each accumulator with the appropriate rows, first + // pulling the input arguments for this group into their own + // RecordBatch(es) + let iter = groups_with_rows.iter().zip(offsets.windows(2)); + + for (&group_idx, offsets) in iter { + let state = &mut self.states[group_idx]; + let size_pre = state.size(); + + let values_to_accumulate = + slice_and_maybe_filter(&values, opt_filter.as_ref(), offsets)?; + (f)(state.accumulator.as_mut(), &values_to_accumulate)?; + + // 
clear out the state + state.indices.clear(); + + let size_post = state.size(); + self.allocation_bytes += size_post.saturating_sub(size_pre); + } + Ok(()) + } +} + +impl GroupsAccumulator for GroupsAccumulatorAdapter { + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + self.invoke_per_accumulator( + values, + group_indices, + opt_filter, + total_num_groups, + |accumulator, values_to_accumulate| { + accumulator.update_batch(values_to_accumulate) + }, + )?; + Ok(()) + } + + fn evaluate(&mut self) -> Result { + let states = std::mem::take(&mut self.states); + + // todo update memory usage + + let results: Vec = states + .into_iter() + .map(|state| state.accumulator.evaluate()) + .collect::>()?; + + let result = ScalarValue::iter_to_array(results); + self.reset_allocation(); + result + } + + fn state(&mut self) -> Result> { + let states = std::mem::take(&mut self.states); + + // todo update memory usage + + // each accumulator produces a potential vector of values + // which we need to form into columns + let mut results: Vec> = vec![]; + + for state in states { + let accumulator_state = state.accumulator.state()?; + results.resize_with(accumulator_state.len(), Vec::new); + for (idx, state_val) in accumulator_state.into_iter().enumerate() { + results[idx].push(state_val); + } + } + + // create an array for each intermediate column + let arrays = results + .into_iter() + .map(ScalarValue::iter_to_array) + .collect::>>()?; + + // double check each array has the same length (aka the + // accumulator was written correctly + if let Some(first_col) = arrays.get(0) { + for arr in &arrays { + assert_eq!(arr.len(), first_col.len()) + } + } + + self.reset_allocation(); + Ok(arrays) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + 
self.invoke_per_accumulator( + values, + group_indices, + opt_filter, + total_num_groups, + |accumulator, values_to_accumulate| { + accumulator.merge_batch(values_to_accumulate) + }, + )?; + Ok(()) + } + + fn size(&self) -> usize { + self.allocation_bytes + } +} + +fn get_filter_at_indices( + opt_filter: Option<&BooleanArray>, + indices: &PrimitiveArray, +) -> Result> { + opt_filter + .map(|filter| { + compute::take( + &filter, indices, None, // None: no index check + ) + }) + .transpose() + .map_err(DataFusionError::ArrowError) +} + +// Copied from physical-plan +pub(crate) fn slice_and_maybe_filter( + aggr_array: &[ArrayRef], + filter_opt: Option<&ArrayRef>, + offsets: &[usize], +) -> Result> { + let (offset, length) = (offsets[0], offsets[1] - offsets[0]); + let sliced_arrays: Vec = aggr_array + .iter() + .map(|array| array.slice(offset, length)) + .collect(); + + if let Some(f) = filter_opt { + let filter_array = f.slice(offset, length); + let filter_array = filter_array.as_boolean(); + + sliced_arrays + .iter() + .map(|array| { + compute::filter(array, filter_array).map_err(DataFusionError::ArrowError) + }) + .collect() + } else { + Ok(sliced_arrays) + } +} diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/bool_op.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/bool_op.rs new file mode 100644 index 000000000000..0f6dfdf045f4 --- /dev/null +++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/bool_op.rs @@ -0,0 +1,127 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::array::AsArray; +use arrow_array::{ArrayRef, BooleanArray}; +use arrow_buffer::BooleanBufferBuilder; +use datafusion_common::Result; + +use crate::GroupsAccumulator; + +use super::accumulate::NullState; + +/// An accumulator that implements a single operation over +/// Boolean where the accumulated state is the same as the input +/// type (such as [`BitAndAssign`]) +/// +/// F: The function to apply to two elements. The first argument is +/// the existing value and should be updated with the second value +/// (e.g. [`BitAndAssign`] style). 
+/// +/// [`BitAndAssign`]: std::ops::BitAndAssign +#[derive(Debug)] +pub struct BooleanGroupsAccumulator +where + F: Fn(bool, bool) -> bool + Send + Sync, +{ + /// values per group + values: BooleanBufferBuilder, + + /// Track nulls in the input / filters + null_state: NullState, + + /// Function that computes the output + bool_fn: F, +} + +impl BooleanGroupsAccumulator +where + F: Fn(bool, bool) -> bool + Send + Sync, +{ + pub fn new(bitop_fn: F) -> Self { + Self { + values: BooleanBufferBuilder::new(0), + null_state: NullState::new(), + bool_fn: bitop_fn, + } + } +} + +impl GroupsAccumulator for BooleanGroupsAccumulator +where + F: Fn(bool, bool) -> bool + Send + Sync, +{ + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 1, "single argument to update_batch"); + let values = values.get(0).unwrap().as_boolean(); + + if self.values.len() < total_num_groups { + let new_groups = total_num_groups - self.values.len(); + self.values.append_n(new_groups, Default::default()); + } + + // NullState dispatches / handles tracking nulls and groups that saw no values + self.null_state.accumulate_boolean( + group_indices, + values, + opt_filter, + total_num_groups, + |group_index, new_value| { + let current_value = self.values.get_bit(group_index); + let value = (self.bool_fn)(current_value, new_value); + self.values.set_bit(group_index, value); + }, + ); + + Ok(()) + } + + fn evaluate(&mut self) -> Result { + let values = self.values.finish(); + let nulls = self.null_state.build(); + let values = BooleanArray::new(values, nulls); + Ok(Arc::new(values)) + } + + fn state(&mut self) -> Result> { + self.evaluate().map(|arr| vec![arr]) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + // update / merge are the same + 
self.update_batch(values, group_indices, opt_filter, total_num_groups) + } + + fn size(&self) -> usize { + // capacity is in bits, so convert to bytes + self.values.capacity() / 8 + } +} diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs new file mode 100644 index 000000000000..5741aab7a24d --- /dev/null +++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs @@ -0,0 +1,122 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Vectorized [`GroupsAccumulator`] + +pub(crate) mod accumulate; +mod adapter; +pub use adapter::GroupsAccumulatorAdapter; + +pub(crate) mod bool_op; +pub(crate) mod prim_op; + +use arrow_array::{ArrayRef, BooleanArray}; +use datafusion_common::Result; + +/// `GroupAccumulator` implements a single aggregate (e.g. AVG) and +/// stores the state for *all* groups internally. +/// +/// Each group is assigned a `group_index` by the hash table and each +/// accumulator manages the specific state, one per group_index. 
+/// +/// group_indexes are contiguous (there aren't gaps), and thus it is +/// expected that each GroupAccumulator will use something like `Vec<..>` +/// to store the group states. +pub trait GroupsAccumulator: Send { + /// Updates the accumulator's state from its arguments, encoded as + /// a vector of arrow [`ArrayRef`]s. + /// + /// * `values`: the input arguments to the accumulator + /// + /// * `group_indices`: To which groups do the rows in `values` + /// belong, group id) + /// + /// * `opt_filter`: if present, only update aggregate state using + /// `values[i]` if `opt_filter[i]` is true + /// + /// * `total_num_groups`: the number of groups (the largest + /// group_index is thus `total_num_groups - 1`) + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()>; + + /// Returns the final aggregate value for each group as a single + /// `RecordBatch`. + /// + /// The rows returned *must* be in group_index order: The value + /// for group_index 0, followed by 1, etc. + /// + /// OPEN QUESTION: Should this method take a "batch_size: usize" + /// and produce a `Vec` as output to avoid requiring + /// a contiguous intermediate buffer? + /// + /// For example, the `SUM` accumulator maintains a running sum, + /// and `evaluate` will produce that running sum as its output for + /// all groups, in group_index order + /// + /// This call should be treated as consuming (takes `self`) as no + /// other functions will be called after this. This can not + /// actually take `self` otherwise the trait would not be object + /// safe). The accumulator is free to release / reset it is + /// internal state after this call and error on any subsequent + /// call. + fn evaluate(&mut self) -> Result; + + /// Returns the intermediate aggregate state for this accumulator, + /// used for multi-phase grouping. 
+ /// + /// The rows returned *must* be in group_index order: The value + /// for group_index 0, followed by 1, etc. Any group_index that + /// did not have values, should be null. + /// + /// For example, AVG returns two arrays: `SUM` and `COUNT`. + /// + /// Note more sophisticated internal state can be passed as + /// single `StructArray` rather than multiple arrays. + /// + /// This call should be treated as consuming, as described in the + /// comments of [`Self::evaluate`]. + fn state(&mut self) -> Result>; + + /// Merges intermediate state (from [`Self::state`]) into this + /// accumulator's values. + /// + /// For some aggregates (such as `SUM`), merge_batch is the same + /// as `update_batch`, but for some aggregrates (such as `COUNT`) + /// the operations differ. See [`Self::state`] for more details on how + /// state is used and merged. + /// + /// * `values`: arrays produced from calling `state` previously to the accumulator + /// + /// Other arguments are the same as for [`Self::update_batch`]; + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()>; + + /// Amount of memory used to store the state of this + /// accumulator. This function is called once per batch, so it + /// should be O(n) to compute + fn size(&self) -> usize; +} diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/prim_op.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/prim_op.rs new file mode 100644 index 000000000000..fe7fc7ecbaf0 --- /dev/null +++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/prim_op.rs @@ -0,0 +1,126 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::{array::AsArray, datatypes::ArrowPrimitiveType}; +use arrow_array::{ArrayRef, BooleanArray, PrimitiveArray}; +use datafusion_common::Result; + +use crate::GroupsAccumulator; + +use super::accumulate::NullState; + +/// An accumulator that implements a single operation over +/// PrimtiveTypes where the accumulated state is the same as the input +/// type (such as [`BitAndAssign`]) +/// +/// F: The function to apply to two elements. The first argument is +/// the existing value and should be updated with the second value +/// (e.g. [`BitAndAssign`] style). 
+/// +/// [`BitAndAssign`]: std::ops::BitAndAssign +#[derive(Debug)] +pub struct PrimitiveGroupsAccumulator +where + T: ArrowPrimitiveType + Send, + F: Fn(&mut T::Native, T::Native) + Send + Sync, +{ + /// values per group, stored as the native type + values: Vec, + + /// Track nulls in the input / filters + null_state: NullState, + + /// Function that computes the bitwise function + bitop_fn: F, +} + +impl PrimitiveGroupsAccumulator +where + T: ArrowPrimitiveType + Send, + F: Fn(&mut T::Native, T::Native) + Send + Sync, +{ + pub fn new(bitop_fn: F) -> Self { + Self { + values: vec![], + null_state: NullState::new(), + bitop_fn, + } + } +} + +impl GroupsAccumulator for PrimitiveGroupsAccumulator +where + T: ArrowPrimitiveType + Send, + F: Fn(&mut T::Native, T::Native) + Send + Sync, +{ + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 1, "single argument to update_batch"); + let values = values.get(0).unwrap().as_primitive::(); + + // update values + self.values + .resize_with(total_num_groups, || T::default_value()); + + // NullState dispatches / handles tracking nulls and groups that saw no values + self.null_state.accumulate( + group_indices, + values, + opt_filter, + total_num_groups, + |group_index, new_value| { + let value = &mut self.values[group_index]; + (self.bitop_fn)(value, new_value); + }, + ); + + Ok(()) + } + + fn evaluate(&mut self) -> Result { + let values = std::mem::take(&mut self.values); + let nulls = self.null_state.build(); + let values = PrimitiveArray::::new(values.into(), nulls); // no copy + Ok(Arc::new(values)) + } + + fn state(&mut self) -> Result> { + self.evaluate().map(|arr| vec![arr]) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + // update / merge are the same + 
self.update_batch(values, group_indices, opt_filter, total_num_groups) + } + + fn size(&self) -> usize { + self.values.capacity() * std::mem::size_of::() + } +} diff --git a/datafusion/physical-expr/src/aggregate/min_max.rs b/datafusion/physical-expr/src/aggregate/min_max.rs index e3c061dc1354..914299bcedbd 100644 --- a/datafusion/physical-expr/src/aggregate/min_max.rs +++ b/datafusion/physical-expr/src/aggregate/min_max.rs @@ -21,7 +21,7 @@ use std::any::Any; use std::convert::TryFrom; use std::sync::Arc; -use crate::{AggregateExpr, PhysicalExpr}; +use crate::{AggregateExpr, GroupsAccumulator, PhysicalExpr}; use arrow::compute; use arrow::datatypes::{DataType, TimeUnit}; use arrow::{ @@ -35,9 +35,16 @@ use arrow::{ }, datatypes::Field, }; +use arrow_array::cast::AsArray; +use arrow_array::types::{ + Decimal128Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, + UInt16Type, UInt32Type, UInt64Type, UInt8Type, +}; +use arrow_array::{ArrowNumericType, PrimitiveArray}; use datafusion_common::ScalarValue; use datafusion_common::{downcast_value, DataFusionError, Result}; use datafusion_expr::Accumulator; +use log::debug; use crate::aggregate::row_accumulator::{ is_row_accumulator_support_dtype, RowAccumulator, @@ -48,7 +55,9 @@ use arrow::array::Array; use arrow::array::Decimal128Array; use datafusion_row::accessor::RowAccessor; +use super::groups_accumulator::accumulate::NullState; use super::moving_min_max; +use super::utils::adjust_output_array; // Min/max aggregation can take Dictionary encode input but always produces unpacked // (aka non Dictionary) output. We need to adjust the output data type to reflect this. @@ -87,6 +96,15 @@ impl Max { } } +macro_rules! 
instantiate_min_max_accumulator { + ($SELF:expr, $NUMERICTYPE:ident, $MIN:expr) => {{ + Ok(Box::new(MinMaxGroupsPrimitiveAccumulator::< + $NUMERICTYPE, + $MIN, + >::new(&$SELF.data_type))) + }}; +} + impl AggregateExpr for Max { /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { @@ -125,6 +143,10 @@ impl AggregateExpr for Max { is_row_accumulator_support_dtype(&self.data_type) } + fn groups_accumulator_supported(&self) -> bool { + self.data_type.is_primitive() + } + fn create_row_accumulator( &self, start_index: usize, @@ -135,6 +157,36 @@ impl AggregateExpr for Max { ))) } + fn create_groups_accumulator(&self) -> Result> { + match self.data_type { + DataType::Int8 => instantiate_min_max_accumulator!(self, Int8Type, false), + DataType::Int16 => instantiate_min_max_accumulator!(self, Int16Type, false), + DataType::Int32 => instantiate_min_max_accumulator!(self, Int32Type, false), + DataType::Int64 => instantiate_min_max_accumulator!(self, Int64Type, false), + DataType::UInt8 => instantiate_min_max_accumulator!(self, UInt8Type, false), + DataType::UInt16 => instantiate_min_max_accumulator!(self, UInt16Type, false), + DataType::UInt32 => instantiate_min_max_accumulator!(self, UInt32Type, false), + DataType::UInt64 => instantiate_min_max_accumulator!(self, UInt64Type, false), + DataType::Float32 => { + instantiate_min_max_accumulator!(self, Float32Type, false) + } + DataType::Float64 => { + instantiate_min_max_accumulator!(self, Float64Type, false) + } + + DataType::Decimal128(_, _) => { + Ok(Box::new(MinMaxGroupsPrimitiveAccumulator::< + Decimal128Type, + false, + >::new(&self.data_type))) + } + _ => Err(DataFusionError::NotImplemented(format!( + "MinMaxGroupsPrimitiveAccumulator not supported for {}", + self.data_type + ))), + } + } + fn reverse_expr(&self) -> Option> { Some(Arc::new(self.clone())) } @@ -835,6 +887,44 @@ impl AggregateExpr for Min { ))) } + fn groups_accumulator_supported(&self) -> bool { + 
Max::groups_accumulator_supported(&Max::new( + self.expr.clone(), + self.name.clone(), + self.data_type.clone(), + )) + } + + fn create_groups_accumulator(&self) -> Result> { + match self.data_type { + DataType::Int8 => instantiate_min_max_accumulator!(self, Int8Type, true), + DataType::Int16 => instantiate_min_max_accumulator!(self, Int16Type, true), + DataType::Int32 => instantiate_min_max_accumulator!(self, Int32Type, true), + DataType::Int64 => instantiate_min_max_accumulator!(self, Int64Type, true), + DataType::UInt8 => instantiate_min_max_accumulator!(self, UInt8Type, true), + DataType::UInt16 => instantiate_min_max_accumulator!(self, UInt16Type, true), + DataType::UInt32 => instantiate_min_max_accumulator!(self, UInt32Type, true), + DataType::UInt64 => instantiate_min_max_accumulator!(self, UInt64Type, true), + DataType::Float32 => { + instantiate_min_max_accumulator!(self, Float32Type, true) + } + DataType::Float64 => { + instantiate_min_max_accumulator!(self, Float64Type, true) + } + + DataType::Decimal128(_, _) => { + Ok(Box::new(MinMaxGroupsPrimitiveAccumulator::< + Decimal128Type, + true, + >::new(&self.data_type))) + } + _ => Err(DataFusionError::NotImplemented(format!( + "MinMaxGroupsPrimitiveAccumulator not supported for {}", + self.data_type + ))), + } + } + fn reverse_expr(&self) -> Option> { Some(Arc::new(self.clone())) } @@ -1022,6 +1112,224 @@ impl RowAccumulator for MinRowAccumulator { } } +trait MinMax { + fn min() -> Self; + fn max() -> Self; +} + +impl MinMax for u8 { + fn min() -> Self { + u8::MIN + } + fn max() -> Self { + u8::MAX + } +} +impl MinMax for i8 { + fn min() -> Self { + i8::MIN + } + fn max() -> Self { + i8::MAX + } +} +impl MinMax for u16 { + fn min() -> Self { + u16::MIN + } + fn max() -> Self { + u16::MAX + } +} +impl MinMax for i16 { + fn min() -> Self { + i16::MIN + } + fn max() -> Self { + i16::MAX + } +} +impl MinMax for u32 { + fn min() -> Self { + u32::MIN + } + fn max() -> Self { + u32::MAX + } +} +impl MinMax for i32 
{ + fn min() -> Self { + i32::MIN + } + fn max() -> Self { + i32::MAX + } +} +impl MinMax for i64 { + fn min() -> Self { + i64::MIN + } + fn max() -> Self { + i64::MAX + } +} +impl MinMax for u64 { + fn min() -> Self { + u64::MIN + } + fn max() -> Self { + u64::MAX + } +} +impl MinMax for f32 { + fn min() -> Self { + f32::MIN + } + fn max() -> Self { + f32::MAX + } +} +impl MinMax for f64 { + fn min() -> Self { + f64::MIN + } + fn max() -> Self { + f64::MAX + } +} +impl MinMax for i128 { + fn min() -> Self { + i128::MIN + } + fn max() -> Self { + i128::MAX + } +} + +/// An accumulator to compute the min or max of [`PrimitiveArray`]. +/// Stores values as native/primitive type +#[derive(Debug)] +struct MinMaxGroupsPrimitiveAccumulator +where + T: ArrowNumericType + Send, + T::Native: MinMax, +{ + /// Min/max per group, stored as the native type + min_max: Vec, + + /// Track nulls in the input / filters + null_state: NullState, + + /// The output datatype (needed for decimal precision/scale) + data_type: DataType, +} + +impl MinMaxGroupsPrimitiveAccumulator +where + T: ArrowNumericType + Send, + T::Native: MinMax, +{ + pub fn new(data_type: &DataType) -> Self { + debug!( + "MinMaxGroupsPrimitiveAccumulator ({}, {})", + std::any::type_name::(), + MIN, + ); + + Self { + min_max: vec![], + null_state: NullState::new(), + data_type: data_type.clone(), + } + } +} + +impl GroupsAccumulator for MinMaxGroupsPrimitiveAccumulator +where + T: ArrowNumericType + Send, + T::Native: MinMax, +{ + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&arrow_array::BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 1, "single argument to update_batch"); + let values = values.get(0).unwrap().as_primitive::(); + + self.min_max.resize_with(total_num_groups, || { + if MIN { + T::Native::max() + } else { + T::Native::min() + } + }); + + // NullState dispatches / handles tracking nulls and groups that saw 
no values + self.null_state.accumulate( + group_indices, + values, + opt_filter, + total_num_groups, + |group_index, new_value| { + let val = &mut self.min_max[group_index]; + match MIN { + true => { + if new_value < *val { + *val = new_value; + } + } + false => { + if new_value > *val { + *val = new_value; + } + } + } + }, + ); + + Ok(()) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&arrow_array::BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + Self::update_batch(self, values, group_indices, opt_filter, total_num_groups) + } + + fn evaluate(&mut self) -> Result { + let min_max = std::mem::take(&mut self.min_max); + let nulls = self.null_state.build(); + + let min_max = PrimitiveArray::::new(min_max.into(), nulls); // no copy + let min_max = adjust_output_array(&self.data_type, Arc::new(min_max))?; + + Ok(Arc::new(min_max)) + } + + // return arrays for min/max values + fn state(&mut self) -> Result> { + let nulls = self.null_state.build(); + + let min_max = std::mem::take(&mut self.min_max); + let min_max = PrimitiveArray::::new(min_max.into(), nulls); // zero copy + + let min_max = adjust_output_array(&self.data_type, Arc::new(min_max))?; + + Ok(vec![min_max]) + } + + fn size(&self) -> usize { + self.min_max.capacity() * std::mem::size_of::() + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index 9be6d5e1ba12..a21cddd62c63 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -25,6 +25,8 @@ use std::any::Any; use std::fmt::Debug; use std::sync::Arc; +use self::groups_accumulator::GroupsAccumulator; + pub(crate) mod approx_distinct; pub(crate) mod approx_median; pub(crate) mod approx_percentile_cont; @@ -45,6 +47,7 @@ pub(crate) mod median; #[macro_use] pub(crate) mod min_max; pub mod build_in; +pub(crate) mod 
groups_accumulator; mod hyperloglog; pub mod moving_min_max; pub mod row_accumulator; @@ -118,6 +121,25 @@ pub trait AggregateExpr: Send + Sync + Debug + PartialEq { ))) } + /// If the aggregate expression has a specialized + /// [`GroupsAccumulator`] implementation. If this returns true, + /// `[Self::create_groups_accumulator`] will be called. + fn groups_accumulator_supported(&self) -> bool { + false + } + + /// Return a specialized [`GroupsAccumulator`] that manages state for all groups + /// + /// For maximum performance, [`GroupsAccumulator`] should be + /// implemented in addition to [`Accumulator`]. + fn create_groups_accumulator(&self) -> Result> { + // TODO: The default should implement a wrapper over + // sef.create_accumulator + Err(DataFusionError::NotImplemented(format!( + "GroupsAccumulator hasn't been implemented for {self:?} yet" + ))) + } + /// Construct an expression that calculates the aggregate in reverse. /// Typically the "reverse" expression is itself (e.g. SUM, COUNT). /// For aggregates that do not support calculation in reverse, diff --git a/datafusion/physical-expr/src/aggregate/sum.rs b/datafusion/physical-expr/src/aggregate/sum.rs index efa55f060264..c8e9a4028f40 100644 --- a/datafusion/physical-expr/src/aggregate/sum.rs +++ b/datafusion/physical-expr/src/aggregate/sum.rs @@ -15,14 +15,22 @@ // specific language governing permissions and limitations // under the License. -//! Defines physical expressions that can evaluated at runtime during query execution +//! 
Defines `SUM` and `SUM DISTINCT` aggregate accumulators use std::any::Any; use std::convert::TryFrom; use std::sync::Arc; -use crate::{AggregateExpr, PhysicalExpr}; +use crate::aggregate::row_accumulator::{ + is_row_accumulator_support_dtype, RowAccumulator, +}; +use crate::aggregate::utils::down_cast_any_ref; +use crate::expressions::format_state_name; +use crate::{AggregateExpr, GroupsAccumulator, PhysicalExpr}; +use arrow::array::Array; +use arrow::array::Decimal128Array; use arrow::compute; +use arrow::compute::kernels::cast; use arrow::datatypes::DataType; use arrow::{ array::{ @@ -31,18 +39,19 @@ use arrow::{ }, datatypes::Field, }; +use arrow_array::cast::AsArray; +use arrow_array::types::{ + Decimal128Type, Float32Type, Float64Type, Int32Type, Int64Type, UInt32Type, + UInt64Type, +}; +use arrow_array::{ArrowNativeTypeOp, ArrowNumericType, PrimitiveArray}; use datafusion_common::{downcast_value, DataFusionError, Result, ScalarValue}; use datafusion_expr::Accumulator; - -use crate::aggregate::row_accumulator::{ - is_row_accumulator_support_dtype, RowAccumulator, -}; -use crate::aggregate::utils::down_cast_any_ref; -use crate::expressions::format_state_name; -use arrow::array::Array; -use arrow::array::Decimal128Array; -use arrow::compute::cast; use datafusion_row::accessor::RowAccessor; +use log::debug; + +use super::groups_accumulator::accumulate::NullState; +use super::utils::adjust_output_array; /// SUM aggregate expression #[derive(Debug, Clone)] @@ -105,18 +114,11 @@ impl AggregateExpr for Sum { } fn state_fields(&self) -> Result> { - Ok(vec![ - Field::new( - format_state_name(&self.name, "sum"), - self.data_type.clone(), - self.nullable, - ), - Field::new( - format_state_name(&self.name, "count"), - DataType::UInt64, - self.nullable, - ), - ]) + Ok(vec![Field::new( + format_state_name(&self.name, "sum"), + self.data_type.clone(), + self.nullable, + )]) } fn expressions(&self) -> Vec> { @@ -131,6 +133,10 @@ impl AggregateExpr for Sum { 
is_row_accumulator_support_dtype(&self.data_type) } + fn groups_accumulator_supported(&self) -> bool { + true + } + fn create_row_accumulator( &self, start_index: usize, @@ -141,12 +147,52 @@ impl AggregateExpr for Sum { ))) } + fn create_groups_accumulator(&self) -> Result> { + // instantiate specialized accumulator + match self.data_type { + DataType::UInt64 => Ok(Box::new(SumGroupsAccumulator::::new( + &self.data_type, + &self.data_type, + ))), + DataType::Int64 => Ok(Box::new(SumGroupsAccumulator::::new( + &self.data_type, + &self.data_type, + ))), + DataType::UInt32 => Ok(Box::new(SumGroupsAccumulator::::new( + &self.data_type, + &self.data_type, + ))), + DataType::Int32 => Ok(Box::new(SumGroupsAccumulator::::new( + &self.data_type, + &self.data_type, + ))), + DataType::Float32 => Ok(Box::new(SumGroupsAccumulator::::new( + &self.data_type, + &self.data_type, + ))), + DataType::Float64 => Ok(Box::new(SumGroupsAccumulator::::new( + &self.data_type, + &self.data_type, + ))), + DataType::Decimal128(_target_precision, _target_scale) => { + Ok(Box::new(SumGroupsAccumulator::::new( + &self.data_type, + &self.data_type, + ))) + } + _ => Err(DataFusionError::NotImplemented(format!( + "SumGroupsAccumulator not supported for {}", + self.data_type + ))), + } + } + fn reverse_expr(&self) -> Option> { Some(Arc::new(self.clone())) } fn create_sliding_accumulator(&self) -> Result> { - Ok(Box::new(SumAccumulator::try_new(&self.data_type)?)) + Ok(Box::new(SlidingSumAccumulator::try_new(&self.data_type)?)) } } @@ -164,10 +210,10 @@ impl PartialEq for Sum { } } +/// This accumulator computes SUM incrementally #[derive(Debug)] struct SumAccumulator { sum: ScalarValue, - count: u64, } impl SumAccumulator { @@ -175,12 +221,32 @@ impl SumAccumulator { pub fn try_new(data_type: &DataType) -> Result { Ok(Self { sum: ScalarValue::try_from(data_type)?, + }) + } +} + +/// This accumulator incrementally computes sums over a sliding window +#[derive(Debug)] +struct SlidingSumAccumulator { + 
sum: ScalarValue, + count: u64, +} + +impl SlidingSumAccumulator { + /// new sum accumulator + pub fn try_new(data_type: &DataType) -> Result { + Ok(Self { + // start at zero + sum: ScalarValue::try_from(data_type)?, count: 0, }) } } -// returns the new value after sum with the new values, taking nullability into account +/// Sums the contents of the `$VALUES` array using the arrow compute +/// kernel, and return a `ScalarValue::$SCALAR`. +/// +/// Handles nullability macro_rules! typed_sum_delta_batch { ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident) => {{ let array = downcast_value!($VALUES, $ARRAYTYPE); @@ -322,6 +388,34 @@ pub(crate) fn update_avg_to_row( } impl Accumulator for SumAccumulator { + fn state(&self) -> Result> { + Ok(vec![self.sum.clone()]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = &values[0]; + let delta = sum_batch(values, &self.sum.get_datatype())?; + self.sum = self.sum.add(&delta)?; + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + // sum(sum1, sum2, sum3, ...) = sum1 + sum2 + sum3 + ... + self.update_batch(states) + } + + fn evaluate(&self) -> Result { + // TODO: add the checker for overflow + // For the decimal(precision,_) data type, the absolute of value must be less than 10^precision. 
+ Ok(self.sum.clone()) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) - std::mem::size_of_val(&self.sum) + self.sum.size() + } +} + +impl Accumulator for SlidingSumAccumulator { fn state(&self) -> Result> { Ok(vec![self.sum.clone(), ScalarValue::from(self.count)]) } @@ -424,6 +518,114 @@ impl RowAccumulator for SumRowAccumulator { } } +/// An accumulator to compute the sum of values in [`PrimitiveArray`] +#[derive(Debug)] +struct SumGroupsAccumulator +where + T: ArrowNumericType + Send, +{ + /// The type of the computed sum + sum_data_type: DataType, + + /// The type of the returned sum + return_data_type: DataType, + + /// Sums per group, stored as the native type + sums: Vec, + + /// Track nulls in the input / filters + null_state: NullState, +} + +impl SumGroupsAccumulator +where + T: ArrowNumericType + Send, +{ + pub fn new(sum_data_type: &DataType, return_data_type: &DataType) -> Self { + debug!( + "SumGroupsAccumulator ({}, sum type: {sum_data_type:?}) --> {return_data_type:?}", + std::any::type_name::() + ); + + Self { + return_data_type: sum_data_type.clone(), + sum_data_type: sum_data_type.clone(), + sums: vec![], + null_state: NullState::new(), + } + } +} + +impl GroupsAccumulator for SumGroupsAccumulator +where + T: ArrowNumericType + Send, +{ + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&arrow_array::BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 1, "single argument to update_batch"); + let values = values.get(0).unwrap().as_primitive::(); + + // update sums + self.sums + .resize_with(total_num_groups, || T::default_value()); + + // NullState dispatches / handles tracking nulls and groups that saw no values + self.null_state.accumulate( + group_indices, + values, + opt_filter, + total_num_groups, + |group_index, new_value| { + let sum = &mut self.sums[group_index]; + *sum = sum.add_wrapping(new_value); + }, + ); + + Ok(()) + } + + fn 
merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&arrow_array::BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + self.update_batch(values, group_indices, opt_filter, total_num_groups) + } + + fn evaluate(&mut self) -> Result { + let sums = std::mem::take(&mut self.sums); + let nulls = self.null_state.build(); + + let sums = PrimitiveArray::::new(sums.into(), nulls); // no copy + let sums = adjust_output_array(&self.return_data_type, Arc::new(sums))?; + + Ok(Arc::new(sums)) + } + + // return arrays for sums + fn state(&mut self) -> Result> { + let nulls = self.null_state.build(); + + let sums = std::mem::take(&mut self.sums); + let sums = Arc::new(PrimitiveArray::::new(sums.into(), nulls)); + + let sums = adjust_output_array(&self.sum_data_type, sums)?; + + Ok(vec![sums.clone() as ArrayRef]) + } + + fn size(&self) -> usize { + self.sums.capacity() * std::mem::size_of::() + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/physical-expr/src/aggregate/utils.rs b/datafusion/physical-expr/src/aggregate/utils.rs index dbbe0c3f92c0..63587c925b43 100644 --- a/datafusion/physical-expr/src/aggregate/utils.rs +++ b/datafusion/physical-expr/src/aggregate/utils.rs @@ -20,6 +20,8 @@ use crate::{AggregateExpr, PhysicalSortExpr}; use arrow::array::ArrayRef; use arrow::datatypes::{MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION}; +use arrow_array::cast::AsArray; +use arrow_array::types::Decimal128Type; use arrow_schema::{DataType, Field}; use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_expr::Accumulator; @@ -145,6 +147,28 @@ pub fn calculate_result_decimal_for_avg( } } +/// Adjust array type metadata if needed +/// +/// Since `Decimal128Arrays` created from `Vec` have +/// default precision and scale, this function adjusts the output to +/// match `data_type`. 
+pub fn adjust_output_array( + data_type: &DataType, + array: ArrayRef, +) -> Result { + let array = match data_type { + DataType::Decimal128(p, s) => Arc::new( + array + .as_primitive::() + .clone() + .with_precision_and_scale(*p, *s)?, + ), + // no adjustment needed for other arrays + _ => array, + }; + Ok(array) +} + /// Downcast a `Box` or `Arc` /// and return the inner trait object as [`Any`](std::any::Any) so /// that it can be downcast to a specific implementation. diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index 1484cf7ff52c..b695ee169eed 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -47,7 +47,9 @@ pub mod var_provider; pub mod window; // reexport this to maintain compatibility with anything that used from_slice previously +pub use aggregate::groups_accumulator::{GroupsAccumulator, GroupsAccumulatorAdapter}; pub use aggregate::AggregateExpr; + pub use equivalence::{ project_equivalence_properties, project_ordering_equivalence_properties, EquivalenceProperties, EquivalentClass, OrderingEquivalenceProperties,