Skip to content

Commit

Permalink
refactor(rust): MinMaxKernel in primitive/binary parquet stats (#17158)
Browse files Browse the repository at this point in the history
  • Loading branch information
coastalwhite authored Jun 24, 2024
1 parent c8166a8 commit afb4741
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 53 deletions.
70 changes: 34 additions & 36 deletions crates/polars-compute/src/min_max/dyn_array.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
use arrow::array::{Array, BooleanArray, PrimitiveArray};
use arrow::scalar::{BooleanScalar, PrimitiveScalar, Scalar};
use arrow::array::{
Array, BinaryArray, BinaryViewArray, BooleanArray, PrimitiveArray, Utf8Array, Utf8ViewArray,
};
use arrow::scalar::{BinaryScalar, BinaryViewScalar, BooleanScalar, PrimitiveScalar, Scalar};

use crate::min_max::MinMaxKernel;

macro_rules! call_op {
($T:ty, $scalar:ty, $(=> ($($arg:expr),+),)? $arr:expr, $op:path) => {{
($T:ty, $scalar:ty, $arr:expr, $op:path) => {{
let arr: &$T = $arr.as_any().downcast_ref().unwrap();
$op(arr).map(|v| Box::new(<$scalar>::new(Some(v))) as Box<dyn Scalar>)
}};
(dt: $T:ty, $scalar:ty, $arr:expr, $op:path) => {{
let arr: &$T = $arr.as_any().downcast_ref().unwrap();
$op(arr).map(|v| Box::new(<$scalar>::new($($($arg,)+)? Some(v))) as Box<dyn Scalar>)
$op(arr)
.map(|v| Box::new(<$scalar>::new(arr.data_type().clone(), Some(v))) as Box<dyn Scalar>)
}};
}

Expand All @@ -14,40 +22,30 @@ macro_rules! call {
let arr = $arr;

use arrow::datatypes::{PhysicalType as PH, PrimitiveType as PR};
use PrimitiveArray as PArr;
use PrimitiveScalar as PScalar;
match arr.data_type().to_physical_type() {
PH::Boolean => call_op!(BooleanArray, BooleanScalar, arr, $op),
PH::Primitive(PR::Int8) => call_op!(PrimitiveArray<i8>, PrimitiveScalar<i8>, => (arr.data_type().clone()), arr, $op),
PH::Primitive(PR::Int16) => {
call_op!(PrimitiveArray<i16>, PrimitiveScalar<i16>, => (arr.data_type().clone()), arr, $op)
},
PH::Primitive(PR::Int32) => {
call_op!(PrimitiveArray<i32>, PrimitiveScalar<i32>, => (arr.data_type().clone()), arr, $op)
},
PH::Primitive(PR::Int64) => {
call_op!(PrimitiveArray<i64>, PrimitiveScalar<i64>, => (arr.data_type().clone()), arr, $op)
},
PH::Primitive(PR::Int128) => {
call_op!(PrimitiveArray<i128>, PrimitiveScalar<i128>, => (arr.data_type().clone()), arr, $op)
},
PH::Primitive(PR::UInt8) => call_op!(PrimitiveArray<u8>, PrimitiveScalar<u8>, => (arr.data_type().clone()), arr, $op),
PH::Primitive(PR::UInt16) => {
call_op!(PrimitiveArray<u16>, PrimitiveScalar<u16>, => (arr.data_type().clone()), arr, $op)
},
PH::Primitive(PR::UInt32) => {
call_op!(PrimitiveArray<u32>, PrimitiveScalar<u32>, => (arr.data_type().clone()), arr, $op)
},
PH::Primitive(PR::UInt64) => {
call_op!(PrimitiveArray<u64>, PrimitiveScalar<u64>, => (arr.data_type().clone()), arr, $op)
},
PH::Primitive(PR::UInt128) => {
call_op!(PrimitiveArray<u128>, PrimitiveScalar<u128>, => (arr.data_type().clone()), arr, $op)
},
PH::Primitive(PR::Float32) => {
call_op!(PrimitiveArray<f32>, PrimitiveScalar<f32>, => (arr.data_type().clone()), arr, $op)
},
PH::Primitive(PR::Float64) => {
call_op!(PrimitiveArray<f64>, PrimitiveScalar<f64>, => (arr.data_type().clone()), arr, $op)
},
PH::Primitive(PR::Int8) => call_op!(dt: PArr<i8>, PScalar<i8>, arr, $op),
PH::Primitive(PR::Int16) => call_op!(dt: PArr<i16>, PScalar<i16>, arr, $op),
PH::Primitive(PR::Int32) => call_op!(dt: PArr<i32>, PScalar<i32>, arr, $op),
PH::Primitive(PR::Int64) => call_op!(dt: PArr<i64>, PScalar<i64>, arr, $op),
PH::Primitive(PR::Int128) => call_op!(dt: PArr<i128>, PScalar<i128>, arr, $op),
PH::Primitive(PR::UInt8) => call_op!(dt: PArr<u8>, PScalar<u8>, arr, $op),
PH::Primitive(PR::UInt16) => call_op!(dt: PArr<u16>, PScalar<u16>, arr, $op),
PH::Primitive(PR::UInt32) => call_op!(dt: PArr<u32>, PScalar<u32>, arr, $op),
PH::Primitive(PR::UInt64) => call_op!(dt: PArr<u64>, PScalar<u64>, arr, $op),
PH::Primitive(PR::UInt128) => call_op!(dt: PArr<u128>, PScalar<u128>, arr, $op),
PH::Primitive(PR::Float32) => call_op!(dt: PArr<f32>, PScalar<f32>, arr, $op),
PH::Primitive(PR::Float64) => call_op!(dt: PArr<f64>, PScalar<f64>, arr, $op),

PH::BinaryView => call_op!(BinaryViewArray, BinaryViewScalar<[u8]>, arr, $op),
PH::Utf8View => call_op!(Utf8ViewArray, BinaryViewScalar<str>, arr, $op),

PH::Binary => call_op!(BinaryArray<i32>, BinaryScalar<i32>, arr, $op),
PH::LargeBinary => call_op!(BinaryArray<i64>, BinaryScalar<i64>, arr, $op),
PH::Utf8 => call_op!(Utf8Array<i32>, BinaryScalar<i32>, arr, $op),
PH::LargeUtf8 => call_op!(Utf8Array<i64>, BinaryScalar<i64>, arr, $op),

_ => todo!("Dynamic MinMax is not yet implemented for {:?}", arr.data_type()),
}
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-compute/src/min_max/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ pub use self::dyn_array::{
dyn_array_min_propagate_nan,
};

// Low-level min/max kernel.
/// Low-level min/max kernel.
pub trait MinMaxKernel {
type Scalar<'a>: MinMax
where
Expand Down
68 changes: 66 additions & 2 deletions crates/polars-compute/src/min_max/scalar.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use arrow::array::{Array, BinaryViewArray, BooleanArray, PrimitiveArray, Utf8ViewArray};
use arrow::types::NativeType;
use arrow::array::{
Array, BinaryArray, BinaryViewArray, BooleanArray, PrimitiveArray, Utf8Array, Utf8ViewArray,
};
use arrow::types::{NativeType, Offset};
use polars_utils::min_max::MinMax;

use super::MinMaxKernel;
Expand Down Expand Up @@ -149,3 +151,65 @@ impl MinMaxKernel for Utf8ViewArray {
self.max_ignore_nan_kernel()
}
}

impl<O: Offset> MinMaxKernel for BinaryArray<O> {
type Scalar<'a> = &'a [u8];

fn min_ignore_nan_kernel(&self) -> Option<Self::Scalar<'_>> {
if self.null_count() == 0 {
self.values_iter().reduce(MinMax::min_ignore_nan)
} else {
self.non_null_values_iter().reduce(MinMax::min_ignore_nan)
}
}

fn max_ignore_nan_kernel(&self) -> Option<Self::Scalar<'_>> {
if self.null_count() == 0 {
self.values_iter().reduce(MinMax::max_ignore_nan)
} else {
self.non_null_values_iter().reduce(MinMax::max_ignore_nan)
}
}

#[inline(always)]
fn min_propagate_nan_kernel(&self) -> Option<Self::Scalar<'_>> {
self.min_ignore_nan_kernel()
}

#[inline(always)]
fn max_propagate_nan_kernel(&self) -> Option<Self::Scalar<'_>> {
self.max_ignore_nan_kernel()
}
}

impl<O: Offset> MinMaxKernel for Utf8Array<O> {
type Scalar<'a> = &'a str;

#[inline(always)]
fn min_ignore_nan_kernel(&self) -> Option<Self::Scalar<'_>> {
self.to_binary().min_ignore_nan_kernel().map(|s| unsafe {
// SAFETY: the lifetime is the same, and it is valid UTF-8.
#[allow(clippy::transmute_bytes_to_str)]
std::mem::transmute::<&[u8], &str>(s)
})
}

#[inline(always)]
fn max_ignore_nan_kernel(&self) -> Option<Self::Scalar<'_>> {
self.to_binary().max_ignore_nan_kernel().map(|s| unsafe {
// SAFETY: the lifetime is the same, and it is valid UTF-8.
#[allow(clippy::transmute_bytes_to_str)]
std::mem::transmute::<&[u8], &str>(s)
})
}

#[inline(always)]
fn min_propagate_nan_kernel(&self) -> Option<Self::Scalar<'_>> {
self.min_ignore_nan_kernel()
}

#[inline(always)]
fn max_propagate_nan_kernel(&self) -> Option<Self::Scalar<'_>> {
self.max_ignore_nan_kernel()
}
}
4 changes: 4 additions & 0 deletions crates/polars-compute/src/min_max/simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ where
F: FnMut(Simd<T, N>, Simd<T, N>) -> Simd<T, N>,
LaneCount<N>: SupportedLaneCount,
{
if arr.is_empty() {
return None;
}

let mut arr_chunks = arr.chunks_exact(N);

let identity = Simd::splat(scalar_identity);
Expand Down
18 changes: 4 additions & 14 deletions crates/polars-parquet/src/arrow/write/binary/basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,29 +91,19 @@ pub(crate) fn build_statistics<O: Offset>(
primitive_type: PrimitiveType,
options: &StatisticsOptions,
) -> ParquetStatistics {
use polars_compute::min_max::MinMaxKernel;

BinaryStatistics {
primitive_type,
null_count: options.null_count.then_some(array.null_count() as i64),
distinct_count: None,
max_value: options
.max_value
.then(|| {
array
.iter()
.flatten()
.max_by(|x, y| ord_binary(x, y))
.map(|x| x.to_vec())
})
.then(|| array.max_propagate_nan_kernel().map(<[u8]>::to_vec))
.flatten(),
min_value: options
.min_value
.then(|| {
array
.iter()
.flatten()
.min_by(|x, y| ord_binary(x, y))
.map(|x| x.to_vec())
})
.then(|| array.min_propagate_nan_kernel().map(<[u8]>::to_vec))
.flatten(),
}
.serialize()
Expand Down

0 comments on commit afb4741

Please sign in to comment.