Skip to content

Commit

Permalink
fix: Ensure ignore_nulls is respected in horizontal sum/mean (#20469)
Browse files Browse the repository at this point in the history
  • Loading branch information
mcrumiller authored Dec 31, 2024
1 parent c5cf3f9 commit dff1ad7
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 31 deletions.
76 changes: 46 additions & 30 deletions crates/polars-ops/src/series/ops/horizontal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::borrow::Cow;
use polars_core::chunked_array::cast::CastOptions;
use polars_core::prelude::*;
use polars_core::series::arithmetic::coerce_lhs_rhs;
use polars_core::utils::dtypes_to_supertype;
use polars_core::{with_match_physical_numeric_polars_type, POOL};
use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator};

Expand Down Expand Up @@ -37,7 +38,7 @@ impl MinMaxHorizontal for DataFrame {
}
}

#[derive(Copy, Clone, Debug)]
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum NullStrategy {
Ignore,
Propagate,
Expand Down Expand Up @@ -194,20 +195,19 @@ pub fn sum_horizontal(
null_strategy: NullStrategy,
) -> PolarsResult<Option<Column>> {
validate_column_lengths(columns)?;
let ignore_nulls = null_strategy == NullStrategy::Ignore;

let apply_null_strategy = |s: Series, null_strategy: NullStrategy| -> PolarsResult<Series> {
if let NullStrategy::Ignore = null_strategy {
// if has nulls
if s.null_count() > 0 {
return s.fill_null(FillNullStrategy::Zero);
}
let apply_null_strategy = |s: Series| -> PolarsResult<Series> {
if ignore_nulls && s.null_count() > 0 {
s.fill_null(FillNullStrategy::Zero)
} else {
Ok(s)
}
Ok(s)
};

let sum_fn = |acc: Series, s: Series, null_strategy: NullStrategy| -> PolarsResult<Series> {
let acc: Series = apply_null_strategy(acc, null_strategy)?;
let s = apply_null_strategy(s, null_strategy)?;
let sum_fn = |acc: Series, s: Series| -> PolarsResult<Series> {
let acc: Series = apply_null_strategy(acc)?;
let s = apply_null_strategy(s)?;
// This will do owned arithmetic and can be mutable
std::ops::Add::add(acc, s)
};
Expand All @@ -219,6 +219,20 @@ pub fn sum_horizontal(
.map(|c| c.as_materialized_series())
.collect::<Vec<_>>();

// If we have any null columns and null strategy is not `Ignore`, we can return immediately.
if !ignore_nulls && non_null_cols.len() < columns.len() {
// We must first determine the correct return dtype.
let return_dtype = match dtypes_to_supertype(non_null_cols.iter().map(|c| c.dtype()))? {
DataType::Boolean => DataType::UInt32,
dt => dt,
};
return Ok(Some(Column::full_null(
columns[0].name().clone(),
columns[0].len(),
&return_dtype,
)));
}

match non_null_cols.len() {
0 => {
if columns.is_empty() {
Expand All @@ -229,23 +243,16 @@ pub fn sum_horizontal(
}
},
1 => Ok(Some(
apply_null_strategy(
if non_null_cols[0].dtype() == &DataType::Boolean {
non_null_cols[0].cast(&DataType::UInt32)?
} else {
non_null_cols[0].clone()
},
null_strategy,
)?
apply_null_strategy(if non_null_cols[0].dtype() == &DataType::Boolean {
non_null_cols[0].cast(&DataType::UInt32)?
} else {
non_null_cols[0].clone()
})?
.into(),
)),
2 => sum_fn(
non_null_cols[0].clone(),
non_null_cols[1].clone(),
null_strategy,
)
.map(Column::from)
.map(Some),
2 => sum_fn(non_null_cols[0].clone(), non_null_cols[1].clone())
.map(Column::from)
.map(Some),
_ => {
// the try_reduce_with is a bit slower in parallelism,
// but I don't think it matters here as we parallelize over columns, not over elements
Expand All @@ -254,7 +261,7 @@ pub fn sum_horizontal(
.into_par_iter()
.cloned()
.map(Ok)
.try_reduce_with(|l, r| sum_fn(l, r, null_strategy))
.try_reduce_with(sum_fn)
// We can unwrap because we started with at least 3 columns, so we always get a Some
.unwrap()
});
Expand Down Expand Up @@ -283,7 +290,8 @@ pub fn mean_horizontal(
);
}
let columns = numeric_columns.into_iter().cloned().collect::<Vec<_>>();
match columns.len() {
let num_rows = columns.len();
match num_rows {
0 => Ok(None),
1 => Ok(Some(match columns[0].dtype() {
dt if dt != &DataType::Float32 && !dt.is_decimal() => {
Expand Down Expand Up @@ -319,7 +327,7 @@ pub fn mean_horizontal(
// value lengths: len - null_count
let value_length: UInt32Chunked = (Column::new_scalar(
PlSmallStr::EMPTY,
Scalar::from(columns.len() as u32),
Scalar::from(num_rows as u32),
null_count.len(),
) - null_count)?
.u32()
Expand All @@ -328,10 +336,18 @@ pub fn mean_horizontal(

// make sure that we do not divide by zero
// by replacing with None
let dt = if sum
.as_ref()
.is_some_and(|s| s.dtype() == &DataType::Float32)
{
&DataType::Float32
} else {
&DataType::Float64
};
let value_length = value_length
.set(&value_length.equal(0), None)?
.into_column()
.cast(&DataType::Float64)?;
.cast(dt)?;

sum.map(|sum| std::ops::Div::div(&sum, &value_length))
.transpose()
Expand Down
10 changes: 9 additions & 1 deletion crates/polars-plan/src/dsl/function_expr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,15 @@ impl FunctionExpr {
}
})
},
MeanHorizontal { .. } => mapper.map_to_float_dtype(),
MeanHorizontal { .. } => {
mapper.map_to_supertype().map(|mut f| {
match f.dtype {
dt @ DataType::Float32 => { f.dtype = dt; },
_ => { f.dtype = DataType::Float64; },
};
f
})
}
#[cfg(feature = "ewma")]
EwmMean { .. } => mapper.map_to_float_dtype(),
#[cfg(feature = "ewma_by")]
Expand Down
65 changes: 65 additions & 0 deletions py-polars/tests/unit/operations/aggregation/test_horizontal.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,3 +556,68 @@ def test_horizontal_sum_boolean_with_null() -> None:
)

assert_frame_equal(out.collect(), expected_df)


@pytest.mark.parametrize("ignore_nulls", [True, False])
@pytest.mark.parametrize(
("dtype_in", "dtype_out"),
[
(pl.Null, pl.Null),
(pl.Boolean, pl.UInt32),
(pl.UInt8, pl.UInt8),
(pl.Float32, pl.Float32),
(pl.Float64, pl.Float64),
(pl.Decimal(None, 5), pl.Decimal(None, 5)),
],
)
def test_horizontal_sum_with_null_col_ignore_strategy(
dtype_in: PolarsDataType,
dtype_out: PolarsDataType,
ignore_nulls: bool,
) -> None:
lf = pl.LazyFrame(
{
"null": [None, None, None],
"s": pl.Series([1, 0, 1], dtype=dtype_in, strict=False),
"s2": pl.Series([1, 0, None], dtype=dtype_in, strict=False),
}
)
result = lf.select(pl.sum_horizontal("null", "s", "s2", ignore_nulls=ignore_nulls))
if ignore_nulls and dtype_in != pl.Null:
values = [2, 0, 1]
else:
values = [None, None, None] # type: ignore[list-item]
expected = pl.LazyFrame(pl.Series("null", values, dtype=dtype_out))
assert_frame_equal(result, expected)


@pytest.mark.parametrize("ignore_nulls", [True, False])
@pytest.mark.parametrize(
("dtype_in", "dtype_out"),
[
(pl.Null, pl.Float64),
(pl.Boolean, pl.Float64),
(pl.UInt8, pl.Float64),
(pl.Float32, pl.Float32),
(pl.Float64, pl.Float64),
],
)
def test_horizontal_mean_with_null_col_ignore_strategy(
dtype_in: PolarsDataType,
dtype_out: PolarsDataType,
ignore_nulls: bool,
) -> None:
lf = pl.LazyFrame(
{
"null": [None, None, None],
"s": pl.Series([1, 0, 1], dtype=dtype_in, strict=False),
"s2": pl.Series([1, 0, None], dtype=dtype_in, strict=False),
}
)
result = lf.select(pl.mean_horizontal("null", "s", "s2", ignore_nulls=ignore_nulls))
if ignore_nulls and dtype_in != pl.Null:
values = [1, 0, 1]
else:
values = [None, None, None] # type: ignore[list-item]
expected = pl.LazyFrame(pl.Series("null", values, dtype=dtype_out))
assert_frame_equal(result, expected)

0 comments on commit dff1ad7

Please sign in to comment.