Skip to content

Commit

Permalink
perf(rust, python): rechunk before aggs (#7903)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Mar 31, 2023
1 parent cbaeddf commit 17b99fc
Showing 1 changed file with 105 additions and 98 deletions.
203 changes: 105 additions & 98 deletions polars/polars-core/src/frame/groupby/aggregations/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -510,15 +510,18 @@ where
return Series::full_null(ca.name(), groups.len(), ca.dtype());
}
match groups {
GroupsProxy::Idx(groups) => agg_helper_idx_on_all::<K, _>(groups, |idx| {
debug_assert!(idx.len() <= ca.len());
if idx.is_empty() {
return None;
}
let take = { ca.take_unchecked(idx.into()) };
// checked with invalid quantile check
take._quantile(quantile, interpol).unwrap_unchecked()
}),
GroupsProxy::Idx(groups) => {
let ca = ca.rechunk();
agg_helper_idx_on_all::<K, _>(groups, |idx| {
debug_assert!(idx.len() <= ca.len());
if idx.is_empty() {
return None;
}
let take = { ca.take_unchecked(idx.into()) };
// checked with invalid quantile check
take._quantile(quantile, interpol).unwrap_unchecked()
})
}
GroupsProxy::Slice { groups, .. } => {
if _use_rolling_kernels(groups, ca.chunks()) {
// this cast is a no-op for floats
Expand Down Expand Up @@ -574,14 +577,17 @@ where
K: PolarsNumericType,
{
match groups {
GroupsProxy::Idx(groups) => agg_helper_idx_on_all::<K, _>(groups, |idx| {
debug_assert!(idx.len() <= ca.len());
if idx.is_empty() {
return None;
}
let take = { ca.take_unchecked(idx.into()) };
take._median()
}),
GroupsProxy::Idx(groups) => {
let ca = ca.rechunk();
agg_helper_idx_on_all::<K, _>(groups, |idx| {
debug_assert!(idx.len() <= ca.len());
if idx.is_empty() {
return None;
}
let take = { ca.take_unchecked(idx.into()) };
take._median()
})
}
GroupsProxy::Slice { .. } => {
agg_quantile_generic::<T, K>(ca, groups, 0.5, QuantileInterpolOptions::Linear)
}
Expand Down Expand Up @@ -610,34 +616,33 @@ where
_ => {}
}
match groups {
GroupsProxy::Idx(groups) => _agg_helper_idx::<T, _>(groups, |(first, idx)| {
debug_assert!(idx.len() <= self.len());
if idx.is_empty() {
None
} else if idx.len() == 1 {
self.get(first as usize)
} else {
match (self.has_validity(), self.chunks.len()) {
(false, 1) => Some(take_agg_no_null_primitive_iter_unchecked(
self.downcast_iter().next().unwrap(),
GroupsProxy::Idx(groups) => {
let ca = self.rechunk();
let arr = ca.downcast_iter().next().unwrap();
_agg_helper_idx::<T, _>(groups, |(first, idx)| {
debug_assert!(idx.len() <= arr.len());
if idx.is_empty() {
None
} else if idx.len() == 1 {
arr.get(first as usize)
} else if arr.null_count() == 0 {
Some(take_agg_no_null_primitive_iter_unchecked(
arr,
idx.iter().map(|i| *i as usize),
take_min,
T::Native::max_value(),
)),
(_, 1) => take_agg_primitive_iter_unchecked::<T::Native, _, _>(
self.downcast_iter().next().unwrap(),
))
} else {
take_agg_primitive_iter_unchecked::<T::Native, _, _>(
arr,
idx.iter().map(|i| *i as usize),
take_min,
T::Native::max_value(),
idx.len() as IdxSize,
),
_ => {
let take = { self.take_unchecked(idx.into()) };
take.min()
}
)
}
}
}),
})
}
GroupsProxy::Slice {
groups: groups_slice,
..
Expand Down Expand Up @@ -688,36 +693,35 @@ where
}

match groups {
GroupsProxy::Idx(groups) => _agg_helper_idx::<T, _>(groups, |(first, idx)| {
debug_assert!(idx.len() <= self.len());
if idx.is_empty() {
None
} else if idx.len() == 1 {
self.get(first as usize)
} else {
match (self.has_validity(), self.chunks.len()) {
(false, 1) => Some({
GroupsProxy::Idx(groups) => {
let ca = self.rechunk();
let arr = ca.downcast_iter().next().unwrap();
_agg_helper_idx::<T, _>(groups, |(first, idx)| {
debug_assert!(idx.len() <= arr.len());
if idx.is_empty() {
None
} else if idx.len() == 1 {
arr.get(first as usize)
} else if arr.null_count() == 0 {
Some({
take_agg_no_null_primitive_iter_unchecked(
self.downcast_iter().next().unwrap(),
arr,
idx.iter().map(|i| *i as usize),
take_max,
T::Native::min_value(),
)
}),
(_, 1) => take_agg_primitive_iter_unchecked::<T::Native, _, _>(
self.downcast_iter().next().unwrap(),
})
} else {
take_agg_primitive_iter_unchecked::<T::Native, _, _>(
arr,
idx.iter().map(|i| *i as usize),
take_max,
T::Native::min_value(),
idx.len() as IdxSize,
),
_ => {
let take = { self.take_unchecked(idx.into()) };
take.max()
}
)
}
}
}),
})
}
GroupsProxy::Slice {
groups: groups_slice,
..
Expand Down Expand Up @@ -757,36 +761,33 @@ where

pub(crate) unsafe fn agg_sum(&self, groups: &GroupsProxy) -> Series {
match groups {
GroupsProxy::Idx(groups) => _agg_helper_idx::<T, _>(groups, |(first, idx)| {
debug_assert!(idx.len() <= self.len());
if idx.is_empty() {
None
} else if idx.len() == 1 {
self.get(first as usize)
} else {
match (self.has_validity(), self.chunks.len()) {
(false, 1) => Some({
take_agg_no_null_primitive_iter_unchecked(
self.downcast_iter().next().unwrap(),
idx.iter().map(|i| *i as usize),
|a, b| a + b,
T::Native::zero(),
)
}),
(_, 1) => take_agg_primitive_iter_unchecked::<T::Native, _, _>(
self.downcast_iter().next().unwrap(),
GroupsProxy::Idx(groups) => {
let ca = self.rechunk();
let arr = ca.downcast_iter().next().unwrap();
_agg_helper_idx::<T, _>(groups, |(first, idx)| {
debug_assert!(idx.len() <= self.len());
if idx.is_empty() {
None
} else if idx.len() == 1 {
arr.get(first as usize)
} else if arr.null_count() == 0 {
Some(take_agg_no_null_primitive_iter_unchecked(
arr,
idx.iter().map(|i| *i as usize),
|a, b| a + b,
T::Native::zero(),
))
} else {
take_agg_primitive_iter_unchecked::<T::Native, _, _>(
arr,
idx.iter().map(|i| *i as usize),
|a, b| a + b,
T::Native::zero(),
idx.len() as IdxSize,
),
_ => {
let take = { self.take_unchecked(idx.into()) };
take.sum()
}
)
}
}
}),
})
}
GroupsProxy::Slice { groups, .. } => {
if _use_rolling_kernels(groups, self.chunks()) {
let arr = self.downcast_iter().next().unwrap();
Expand Down Expand Up @@ -921,14 +922,17 @@ where
pub(crate) unsafe fn agg_var(&self, groups: &GroupsProxy, ddof: u8) -> Series {
let ca = &self.0;
match groups {
GroupsProxy::Idx(groups) => agg_helper_idx_on_all::<T, _>(groups, |idx| {
debug_assert!(idx.len() <= ca.len());
if idx.is_empty() {
return None;
}
let take = { ca.take_unchecked(idx.into()) };
take.var(ddof)
}),
GroupsProxy::Idx(groups) => {
let ca = ca.rechunk();
agg_helper_idx_on_all::<T, _>(groups, |idx| {
debug_assert!(idx.len() <= ca.len());
if idx.is_empty() {
return None;
}
let take = { ca.take_unchecked(idx.into()) };
take.var(ddof)
})
}
GroupsProxy::Slice { groups, .. } => {
if _use_rolling_kernels(groups, self.chunks()) {
let arr = self.downcast_iter().next().unwrap();
Expand Down Expand Up @@ -965,14 +969,17 @@ where
pub(crate) unsafe fn agg_std(&self, groups: &GroupsProxy, ddof: u8) -> Series {
let ca = &self.0;
match groups {
GroupsProxy::Idx(groups) => agg_helper_idx_on_all::<T, _>(groups, |idx| {
debug_assert!(idx.len() <= ca.len());
if idx.is_empty() {
return None;
}
let take = { ca.take_unchecked(idx.into()) };
take.std(ddof)
}),
GroupsProxy::Idx(groups) => {
let ca = ca.rechunk();
agg_helper_idx_on_all::<T, _>(groups, |idx| {
debug_assert!(idx.len() <= ca.len());
if idx.is_empty() {
return None;
}
let take = { ca.take_unchecked(idx.into()) };
take.std(ddof)
})
}
GroupsProxy::Slice { groups, .. } => {
if _use_rolling_kernels(groups, self.chunks()) {
let arr = self.downcast_iter().next().unwrap();
Expand Down

0 comments on commit 17b99fc

Please sign in to comment.