From 17b99fcd9dda009df9e24d66fa3cbb4916f8c02d Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Fri, 31 Mar 2023 16:35:32 +0200 Subject: [PATCH] perf(rust, python): rechunk before aggs (#7903) --- .../src/frame/groupby/aggregations/mod.rs | 203 +++++++++--------- 1 file changed, 105 insertions(+), 98 deletions(-) diff --git a/polars/polars-core/src/frame/groupby/aggregations/mod.rs b/polars/polars-core/src/frame/groupby/aggregations/mod.rs index 8ecc8919c007..a3df2a47d0dd 100644 --- a/polars/polars-core/src/frame/groupby/aggregations/mod.rs +++ b/polars/polars-core/src/frame/groupby/aggregations/mod.rs @@ -510,15 +510,18 @@ where return Series::full_null(ca.name(), groups.len(), ca.dtype()); } match groups { - GroupsProxy::Idx(groups) => agg_helper_idx_on_all::(groups, |idx| { - debug_assert!(idx.len() <= ca.len()); - if idx.is_empty() { - return None; - } - let take = { ca.take_unchecked(idx.into()) }; - // checked with invalid quantile check - take._quantile(quantile, interpol).unwrap_unchecked() - }), + GroupsProxy::Idx(groups) => { + let ca = ca.rechunk(); + agg_helper_idx_on_all::(groups, |idx| { + debug_assert!(idx.len() <= ca.len()); + if idx.is_empty() { + return None; + } + let take = { ca.take_unchecked(idx.into()) }; + // checked with invalid quantile check + take._quantile(quantile, interpol).unwrap_unchecked() + }) + } GroupsProxy::Slice { groups, .. } => { if _use_rolling_kernels(groups, ca.chunks()) { // this cast is a no-op for floats @@ -574,14 +577,17 @@ where K: PolarsNumericType, { match groups { - GroupsProxy::Idx(groups) => agg_helper_idx_on_all::(groups, |idx| { - debug_assert!(idx.len() <= ca.len()); - if idx.is_empty() { - return None; - } - let take = { ca.take_unchecked(idx.into()) }; - take._median() - }), + GroupsProxy::Idx(groups) => { + let ca = ca.rechunk(); + agg_helper_idx_on_all::(groups, |idx| { + debug_assert!(idx.len() <= ca.len()); + if idx.is_empty() { + return None; + } + let take = { ca.take_unchecked(idx.into()) }; + take._median() + }) + } GroupsProxy::Slice { .. } => { agg_quantile_generic::(ca, groups, 0.5, QuantileInterpolOptions::Linear) } @@ -610,34 +616,33 @@ where _ => {} } match groups { - GroupsProxy::Idx(groups) => _agg_helper_idx::(groups, |(first, idx)| { - debug_assert!(idx.len() <= self.len()); - if idx.is_empty() { - None - } else if idx.len() == 1 { - self.get(first as usize) - } else { - match (self.has_validity(), self.chunks.len()) { - (false, 1) => Some(take_agg_no_null_primitive_iter_unchecked( - self.downcast_iter().next().unwrap(), + GroupsProxy::Idx(groups) => { + let ca = self.rechunk(); + let arr = ca.downcast_iter().next().unwrap(); + _agg_helper_idx::(groups, |(first, idx)| { + debug_assert!(idx.len() <= arr.len()); + if idx.is_empty() { + None + } else if idx.len() == 1 { + arr.get(first as usize) + } else if arr.null_count() == 0 { + Some(take_agg_no_null_primitive_iter_unchecked( + arr, idx.iter().map(|i| *i as usize), take_min, T::Native::max_value(), - )), - (_, 1) => take_agg_primitive_iter_unchecked::( - self.downcast_iter().next().unwrap(), + )) + } else { + take_agg_primitive_iter_unchecked::( + arr, idx.iter().map(|i| *i as usize), take_min, T::Native::max_value(), idx.len() as IdxSize, - ), - _ => { - let take = { self.take_unchecked(idx.into()) }; - take.min() - } + ) } - } - }), + }) + } GroupsProxy::Slice { groups: groups_slice, .. @@ -688,36 +693,35 @@ where } match groups { - GroupsProxy::Idx(groups) => _agg_helper_idx::(groups, |(first, idx)| { - debug_assert!(idx.len() <= self.len()); - if idx.is_empty() { - None - } else if idx.len() == 1 { - self.get(first as usize) - } else { - match (self.has_validity(), self.chunks.len()) { - (false, 1) => Some({ + GroupsProxy::Idx(groups) => { + let ca = self.rechunk(); + let arr = ca.downcast_iter().next().unwrap(); + _agg_helper_idx::(groups, |(first, idx)| { + debug_assert!(idx.len() <= arr.len()); + if idx.is_empty() { + None + } else if idx.len() == 1 { + arr.get(first as usize) + } else if arr.null_count() == 0 { + Some({ take_agg_no_null_primitive_iter_unchecked( - self.downcast_iter().next().unwrap(), + arr, idx.iter().map(|i| *i as usize), take_max, T::Native::min_value(), ) - }), - (_, 1) => take_agg_primitive_iter_unchecked::( - self.downcast_iter().next().unwrap(), + }) + } else { + take_agg_primitive_iter_unchecked::( + arr, idx.iter().map(|i| *i as usize), take_max, T::Native::min_value(), idx.len() as IdxSize, - ), - _ => { - let take = { self.take_unchecked(idx.into()) }; - take.max() - } + ) } - } - }), + }) + } GroupsProxy::Slice { groups: groups_slice, .. @@ -757,36 +761,33 @@ where pub(crate) unsafe fn agg_sum(&self, groups: &GroupsProxy) -> Series { match groups { - GroupsProxy::Idx(groups) => _agg_helper_idx::(groups, |(first, idx)| { - debug_assert!(idx.len() <= self.len()); - if idx.is_empty() { - None - } else if idx.len() == 1 { - self.get(first as usize) - } else { - match (self.has_validity(), self.chunks.len()) { - (false, 1) => Some({ - take_agg_no_null_primitive_iter_unchecked( - self.downcast_iter().next().unwrap(), - idx.iter().map(|i| *i as usize), - |a, b| a + b, - T::Native::zero(), - ) - }), - (_, 1) => take_agg_primitive_iter_unchecked::( - self.downcast_iter().next().unwrap(), + GroupsProxy::Idx(groups) => { + let ca = self.rechunk(); + let arr = ca.downcast_iter().next().unwrap(); + _agg_helper_idx::(groups, |(first, idx)| { + debug_assert!(idx.len() <= self.len()); + if idx.is_empty() { + None + } else if idx.len() == 1 { + arr.get(first as usize) + } else if arr.null_count() == 0 { + Some(take_agg_no_null_primitive_iter_unchecked( + arr, + idx.iter().map(|i| *i as usize), + |a, b| a + b, + T::Native::zero(), + )) + } else { + take_agg_primitive_iter_unchecked::( + arr, idx.iter().map(|i| *i as usize), |a, b| a + b, T::Native::zero(), idx.len() as IdxSize, - ), - _ => { - let take = { self.take_unchecked(idx.into()) }; - take.sum() - } + ) } - } - }), + }) + } GroupsProxy::Slice { groups, .. } => { if _use_rolling_kernels(groups, self.chunks()) { let arr = self.downcast_iter().next().unwrap(); @@ -921,14 +922,17 @@ where pub(crate) unsafe fn agg_var(&self, groups: &GroupsProxy, ddof: u8) -> Series { let ca = &self.0; match groups { - GroupsProxy::Idx(groups) => agg_helper_idx_on_all::(groups, |idx| { - debug_assert!(idx.len() <= ca.len()); - if idx.is_empty() { - return None; - } - let take = { ca.take_unchecked(idx.into()) }; - take.var(ddof) - }), + GroupsProxy::Idx(groups) => { + let ca = ca.rechunk(); + agg_helper_idx_on_all::(groups, |idx| { + debug_assert!(idx.len() <= ca.len()); + if idx.is_empty() { + return None; + } + let take = { ca.take_unchecked(idx.into()) }; + take.var(ddof) + }) + } GroupsProxy::Slice { groups, .. } => { if _use_rolling_kernels(groups, self.chunks()) { let arr = self.downcast_iter().next().unwrap(); @@ -965,14 +969,17 @@ where pub(crate) unsafe fn agg_std(&self, groups: &GroupsProxy, ddof: u8) -> Series { let ca = &self.0; match groups { - GroupsProxy::Idx(groups) => agg_helper_idx_on_all::(groups, |idx| { - debug_assert!(idx.len() <= ca.len()); - if idx.is_empty() { - return None; - } - let take = { ca.take_unchecked(idx.into()) }; - take.std(ddof) - }), + GroupsProxy::Idx(groups) => { + let ca = ca.rechunk(); + agg_helper_idx_on_all::(groups, |idx| { + debug_assert!(idx.len() <= ca.len()); + if idx.is_empty() { + return None; + } + let take = { ca.take_unchecked(idx.into()) }; + take.std(ddof) + }) + } GroupsProxy::Slice { groups, .. } => { if _use_rolling_kernels(groups, self.chunks()) { let arr = self.downcast_iter().next().unwrap();