fix: Ensure ignore_nulls is respected in horizontal sum/mean #20469

Merged: 1 commit, Dec 31, 2024
76 changes: 46 additions & 30 deletions crates/polars-ops/src/series/ops/horizontal.rs
@@ -3,6 +3,7 @@ use std::borrow::Cow;
use polars_core::chunked_array::cast::CastOptions;
use polars_core::prelude::*;
use polars_core::series::arithmetic::coerce_lhs_rhs;
use polars_core::utils::dtypes_to_supertype;
use polars_core::{with_match_physical_numeric_polars_type, POOL};
use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator};

@@ -37,7 +38,7 @@ impl MinMaxHorizontal for DataFrame {
}
}

#[derive(Copy, Clone, Debug)]
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum NullStrategy {
Ignore,
Propagate,
@@ -192,20 +193,19 @@ pub fn sum_horizontal(
null_strategy: NullStrategy,
) -> PolarsResult<Option<Column>> {
validate_column_lengths(columns)?;
let ignore_nulls = null_strategy == NullStrategy::Ignore;

let apply_null_strategy = |s: Series, null_strategy: NullStrategy| -> PolarsResult<Series> {
if let NullStrategy::Ignore = null_strategy {
// if has nulls
if s.null_count() > 0 {
return s.fill_null(FillNullStrategy::Zero);
}
let apply_null_strategy = |s: Series| -> PolarsResult<Series> {
if ignore_nulls && s.null_count() > 0 {
s.fill_null(FillNullStrategy::Zero)
} else {
Ok(s)
}
Ok(s)
};

let sum_fn = |acc: Series, s: Series, null_strategy: NullStrategy| -> PolarsResult<Series> {
let acc: Series = apply_null_strategy(acc, null_strategy)?;
let s = apply_null_strategy(s, null_strategy)?;
let sum_fn = |acc: Series, s: Series| -> PolarsResult<Series> {
let acc: Series = apply_null_strategy(acc)?;
let s = apply_null_strategy(s)?;
// This will do owned arithmetic and can be mutable
std::ops::Add::add(acc, s)
};
@@ -217,6 +217,20 @@
.map(|c| c.as_materialized_series())
.collect::<Vec<_>>();

// If we have any null columns and null strategy is not `Ignore`, we can return immediately.
if !ignore_nulls && non_null_cols.len() < columns.len() {
// We must first determine the correct return dtype.
let return_dtype = match dtypes_to_supertype(non_null_cols.iter().map(|c| c.dtype()))? {
DataType::Boolean => DataType::UInt32,
Review comment (Member):
Doesn't have to be in this PR, but this should actually be IndexType.

dt => dt,
};
return Ok(Some(Column::full_null(
columns[0].name().clone(),
columns[0].len(),
&return_dtype,
)));
}

match non_null_cols.len() {
0 => {
if columns.is_empty() {
@@ -227,23 +241,16 @@
}
},
1 => Ok(Some(
apply_null_strategy(
if non_null_cols[0].dtype() == &DataType::Boolean {
non_null_cols[0].cast(&DataType::UInt32)?
} else {
non_null_cols[0].clone()
},
null_strategy,
)?
apply_null_strategy(if non_null_cols[0].dtype() == &DataType::Boolean {
non_null_cols[0].cast(&DataType::UInt32)?
} else {
non_null_cols[0].clone()
})?
.into(),
)),
2 => sum_fn(
non_null_cols[0].clone(),
non_null_cols[1].clone(),
null_strategy,
)
.map(Column::from)
.map(Some),
2 => sum_fn(non_null_cols[0].clone(), non_null_cols[1].clone())
.map(Column::from)
.map(Some),
_ => {
// the try_reduce_with is a bit slower in parallelism,
// but I don't think it matters here as we parallelize over columns, not over elements
@@ -252,7 +259,7 @@
.into_par_iter()
.cloned()
.map(Ok)
.try_reduce_with(|l, r| sum_fn(l, r, null_strategy))
.try_reduce_with(sum_fn)
// We can unwrap because we started with at least 3 columns, so we always get a Some
.unwrap()
});
@@ -281,7 +288,8 @@ pub fn mean_horizontal(
);
}
let columns = numeric_columns.into_iter().cloned().collect::<Vec<_>>();
match columns.len() {
let num_rows = columns.len();
match num_rows {
0 => Ok(None),
1 => Ok(Some(match columns[0].dtype() {
dt if dt != &DataType::Float32 && !dt.is_decimal() => {
@@ -317,7 +325,7 @@
// value lengths: len - null_count
let value_length: UInt32Chunked = (Column::new_scalar(
PlSmallStr::EMPTY,
Scalar::from(columns.len() as u32),
Scalar::from(num_rows as u32),
null_count.len(),
) - null_count)?
.u32()

// make sure that we do not divide by zero
// by replacing with None
let dt = if sum
.as_ref()
.is_some_and(|s| s.dtype() == &DataType::Float32)
{
&DataType::Float32
} else {
&DataType::Float64
};
let value_length = value_length
.set(&value_length.equal(0), None)?
.into_column()
.cast(&DataType::Float64)?;
.cast(dt)?;

sum.map(|sum| std::ops::Div::div(&sum, &value_length))
.transpose()
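To make the new early-return above concrete, here is a short editorial sketch against the Python API; it is not part of the PR diff, the column names and values are made up, and the exact output dtype follows the supertype of the non-null columns.

```python
import polars as pl

df = pl.DataFrame(
    {
        "all_null": pl.Series([None, None, None]),       # dtype Null
        "flags": pl.Series([True, False, True]),         # Boolean
        "vals": pl.Series([1, 0, None], dtype=pl.UInt8),
    }
)

# ignore_nulls=False: the Null-typed column now forces an all-null result,
# with the dtype derived from the supertype of the non-null columns.
print(df.select(pl.sum_horizontal("all_null", "flags", "vals", ignore_nulls=False)))

# ignore_nulls=True: the Null column is skipped and nulls count as 0,
# so each row still gets a sum.
print(df.select(pl.sum_horizontal("all_null", "flags", "vals", ignore_nulls=True)))
```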
10 changes: 9 additions & 1 deletion crates/polars-plan/src/dsl/function_expr/schema.rs
@@ -338,7 +338,15 @@
}
})
},
MeanHorizontal { .. } => mapper.map_to_float_dtype(),
MeanHorizontal { .. } => {
mapper.map_to_supertype().map(|mut f| {
match f.dtype {
dt @ DataType::Float32 => { f.dtype = dt; },
_ => { f.dtype = DataType::Float64; },
};
f
})
}
#[cfg(feature = "ewma")]
EwmMean { .. } => mapper.map_to_float_dtype(),
#[cfg(feature = "ewma_by")]
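A hedged sketch of what this schema change implies on the Python side (not part of the PR; column names are made up and `LazyFrame.collect_schema` assumes a recent polars): the planner should now report Float32 for a horizontal mean over Float32 inputs, and Float64 otherwise.

```python
import polars as pl

lf = pl.LazyFrame(
    {
        "x": pl.Series([1.0, 2.0], dtype=pl.Float32),
        "y": pl.Series([3.0, None], dtype=pl.Float32),
    }
)

# Two Float32 inputs: the reported output dtype should now be Float32
# (previously Float64 was reported).
print(lf.select(pl.mean_horizontal("x", "y")).collect_schema())

# Mix in a Float64 column and the reported dtype should fall back to Float64.
print(
    lf.with_columns(pl.col("y").cast(pl.Float64))
    .select(pl.mean_horizontal("x", "y"))
    .collect_schema()
)
```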
65 changes: 65 additions & 0 deletions py-polars/tests/unit/operations/aggregation/test_horizontal.py
@@ -556,3 +556,68 @@ def test_horizontal_sum_boolean_with_null() -> None:
)

assert_frame_equal(out.collect(), expected_df)


@pytest.mark.parametrize("ignore_nulls", [True, False])
@pytest.mark.parametrize(
("dtype_in", "dtype_out"),
[
(pl.Null, pl.Null),
(pl.Boolean, pl.UInt32),
(pl.UInt8, pl.UInt8),
(pl.Float32, pl.Float32),
(pl.Float64, pl.Float64),
(pl.Decimal(None, 5), pl.Decimal(None, 5)),
],
)
def test_horizontal_sum_with_null_col_ignore_strategy(
dtype_in: PolarsDataType,
dtype_out: PolarsDataType,
ignore_nulls: bool,
) -> None:
lf = pl.LazyFrame(
{
"null": [None, None, None],
"s": pl.Series([1, 0, 1], dtype=dtype_in, strict=False),
"s2": pl.Series([1, 0, None], dtype=dtype_in, strict=False),
}
)
result = lf.select(pl.sum_horizontal("null", "s", "s2", ignore_nulls=ignore_nulls))
if ignore_nulls and dtype_in != pl.Null:
values = [2, 0, 1]
else:
values = [None, None, None] # type: ignore[list-item]

Review comment (Member):
This is wrong. If our sum doesn't ignore nulls, it doesn't propagate them, but replaces them with the identity: 0.

The horizontal semantics should be the same as the vertical semantics.

Reply (Contributor Author):
I think you have it reversed.

  • Ignore nulls: 1 + 2 + null = 3 (nulls replaced by 0)
  • Don't ignore nulls: 1 + 2 + null = null (nulls are propagated)
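
A small editorial illustration of the two semantics described above (not from the thread), using the Python API's ignore_nulls flag:

```python
import polars as pl

df = pl.DataFrame({"a": [1], "b": [2], "c": pl.Series([None], dtype=pl.Int64)})

# Ignore nulls: the null is replaced by the identity (0) -> 1 + 2 + 0 = 3
print(df.select(pl.sum_horizontal("a", "b", "c", ignore_nulls=True)).item())

# Don't ignore nulls: the null propagates -> the result is null (None)
print(df.select(pl.sum_horizontal("a", "b", "c", ignore_nulls=False)).item())
```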

Review comment (Member):
I mean that our sum is agnostic to nulls. I think we made a mistake exposing this to sum_horizontal as our vertical sum is agnostic to nulls.

Reply (Contributor Author):
I think what you're saying is that the ignore_nulls parameter should be removed entirely, and the ignore_nulls=True behavior should simply be the default. Is that correct? If so, that sounds like a breaking change.

Given that the current implementation has the parameter, and that there is an issue with it (pl.Null columns are not subject to the parameter, but nulls in other columns are), is the path forward to 1) accept this PR (if I didn't mess anything up, which I don't think I did), and 2) remove the parameter in a follow-up PR to be merged in 1.19.0?

This PR contains a small fix for the float32 case where mean_horizontal returns f64 for f32 columns. I could make that a separate PR as well if you don't want to accept this one.
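
For the float32 point, a hedged sketch of the expected eager behavior (editorial, not from the PR; column names and values are illustrative):

```python
import polars as pl

df = pl.DataFrame(
    {
        "x": pl.Series([1.0, None], dtype=pl.Float32),
        "y": pl.Series([3.0, 4.0], dtype=pl.Float32),
    }
)

out = df.select(pl.mean_horizontal("x", "y"))
print(out.dtypes)  # expected after the fix: [Float32], not [Float64]
print(out)         # row 0: (1.0 + 3.0) / 2 = 2.0; row 1: 4.0 (null ignored by default)
```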

Review comment (Member):
Yeah, I think you're right. Consider it an observation. ;) Will take a look a bit later.

Reply (Contributor Author):
Thanks Ritchie. I have a follow-up to this adding temporals for mean horizontal but I'll wait for this one first.

expected = pl.LazyFrame(pl.Series("null", values, dtype=dtype_out))
assert_frame_equal(result, expected)


@pytest.mark.parametrize("ignore_nulls", [True, False])
@pytest.mark.parametrize(
("dtype_in", "dtype_out"),
[
(pl.Null, pl.Float64),
(pl.Boolean, pl.Float64),
(pl.UInt8, pl.Float64),
(pl.Float32, pl.Float32),
(pl.Float64, pl.Float64),
],
)
def test_horizontal_mean_with_null_col_ignore_strategy(
dtype_in: PolarsDataType,
dtype_out: PolarsDataType,
ignore_nulls: bool,
) -> None:
lf = pl.LazyFrame(
{
"null": [None, None, None],
"s": pl.Series([1, 0, 1], dtype=dtype_in, strict=False),
"s2": pl.Series([1, 0, None], dtype=dtype_in, strict=False),
}
)
result = lf.select(pl.mean_horizontal("null", "s", "s2", ignore_nulls=ignore_nulls))
if ignore_nulls and dtype_in != pl.Null:
values = [1, 0, 1]
else:
values = [None, None, None] # type: ignore[list-item]
expected = pl.LazyFrame(pl.Series("null", values, dtype=dtype_out))
assert_frame_equal(result, expected)