diff --git a/crates/polars-ops/src/series/ops/horizontal.rs b/crates/polars-ops/src/series/ops/horizontal.rs index 2e96ab27394a..0437abbb21b6 100644 --- a/crates/polars-ops/src/series/ops/horizontal.rs +++ b/crates/polars-ops/src/series/ops/horizontal.rs @@ -3,6 +3,7 @@ use std::borrow::Cow; use polars_core::chunked_array::cast::CastOptions; use polars_core::prelude::*; use polars_core::series::arithmetic::coerce_lhs_rhs; +use polars_core::utils::dtypes_to_supertype; use polars_core::{with_match_physical_numeric_polars_type, POOL}; use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator}; @@ -37,7 +38,7 @@ impl MinMaxHorizontal for DataFrame { } } -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, PartialEq)] pub enum NullStrategy { Ignore, Propagate, @@ -192,20 +193,19 @@ pub fn sum_horizontal( null_strategy: NullStrategy, ) -> PolarsResult> { validate_column_lengths(columns)?; + let ignore_nulls = null_strategy == NullStrategy::Ignore; - let apply_null_strategy = |s: Series, null_strategy: NullStrategy| -> PolarsResult { - if let NullStrategy::Ignore = null_strategy { - // if has nulls - if s.null_count() > 0 { - return s.fill_null(FillNullStrategy::Zero); - } + let apply_null_strategy = |s: Series| -> PolarsResult { + if ignore_nulls && s.null_count() > 0 { + s.fill_null(FillNullStrategy::Zero) + } else { + Ok(s) } - Ok(s) }; - let sum_fn = |acc: Series, s: Series, null_strategy: NullStrategy| -> PolarsResult { - let acc: Series = apply_null_strategy(acc, null_strategy)?; - let s = apply_null_strategy(s, null_strategy)?; + let sum_fn = |acc: Series, s: Series| -> PolarsResult { + let acc: Series = apply_null_strategy(acc)?; + let s = apply_null_strategy(s)?; // This will do owned arithmetic and can be mutable std::ops::Add::add(acc, s) }; @@ -217,6 +217,20 @@ pub fn sum_horizontal( .map(|c| c.as_materialized_series()) .collect::>(); + // If we have any null columns and null strategy is not `Ignore`, we can return immediately. + if !ignore_nulls && non_null_cols.len() < columns.len() { + // We must first determine the correct return dtype. + let return_dtype = match dtypes_to_supertype(non_null_cols.iter().map(|c| c.dtype()))? { + DataType::Boolean => DataType::UInt32, + dt => dt, + }; + return Ok(Some(Column::full_null( + columns[0].name().clone(), + columns[0].len(), + &return_dtype, + ))); + } + match non_null_cols.len() { 0 => { if columns.is_empty() { @@ -227,23 +241,16 @@ pub fn sum_horizontal( } }, 1 => Ok(Some( - apply_null_strategy( - if non_null_cols[0].dtype() == &DataType::Boolean { - non_null_cols[0].cast(&DataType::UInt32)? - } else { - non_null_cols[0].clone() - }, - null_strategy, - )? + apply_null_strategy(if non_null_cols[0].dtype() == &DataType::Boolean { + non_null_cols[0].cast(&DataType::UInt32)? + } else { + non_null_cols[0].clone() + })? .into(), )), - 2 => sum_fn( - non_null_cols[0].clone(), - non_null_cols[1].clone(), - null_strategy, - ) - .map(Column::from) - .map(Some), + 2 => sum_fn(non_null_cols[0].clone(), non_null_cols[1].clone()) + .map(Column::from) + .map(Some), _ => { // the try_reduce_with is a bit slower in parallelism, // but I don't think it matters here as we parallelize over columns, not over elements @@ -252,7 +259,7 @@ pub fn sum_horizontal( .into_par_iter() .cloned() .map(Ok) - .try_reduce_with(|l, r| sum_fn(l, r, null_strategy)) + .try_reduce_with(sum_fn) // We can unwrap because we started with at least 3 columns, so we always get a Some .unwrap() }); @@ -281,7 +288,8 @@ pub fn mean_horizontal( ); } let columns = numeric_columns.into_iter().cloned().collect::>(); - match columns.len() { + let num_rows = columns.len(); + match num_rows { 0 => Ok(None), 1 => Ok(Some(match columns[0].dtype() { dt if dt != &DataType::Float32 && !dt.is_decimal() => { @@ -317,7 +325,7 @@ pub fn mean_horizontal( // value lengths: len - null_count let value_length: UInt32Chunked = (Column::new_scalar( PlSmallStr::EMPTY, - Scalar::from(columns.len() as u32), + Scalar::from(num_rows as u32), null_count.len(), ) - null_count)? .u32() @@ -326,10 +334,18 @@ pub fn mean_horizontal( // make sure that we do not divide by zero // by replacing with None + let dt = if sum + .as_ref() + .is_some_and(|s| s.dtype() == &DataType::Float32) + { + &DataType::Float32 + } else { + &DataType::Float64 + }; let value_length = value_length .set(&value_length.equal(0), None)? .into_column() - .cast(&DataType::Float64)?; + .cast(dt)?; sum.map(|sum| std::ops::Div::div(&sum, &value_length)) .transpose() diff --git a/crates/polars-plan/src/dsl/function_expr/schema.rs b/crates/polars-plan/src/dsl/function_expr/schema.rs index 106a05c2041f..d45f75c01e9d 100644 --- a/crates/polars-plan/src/dsl/function_expr/schema.rs +++ b/crates/polars-plan/src/dsl/function_expr/schema.rs @@ -338,7 +338,15 @@ impl FunctionExpr { } }) }, - MeanHorizontal { .. } => mapper.map_to_float_dtype(), + MeanHorizontal { .. } => { + mapper.map_to_supertype().map(|mut f| { + match f.dtype { + dt @ DataType::Float32 => { f.dtype = dt; }, + _ => { f.dtype = DataType::Float64; }, + }; + f + }) + } #[cfg(feature = "ewma")] EwmMean { .. } => mapper.map_to_float_dtype(), #[cfg(feature = "ewma_by")] diff --git a/py-polars/tests/unit/operations/aggregation/test_horizontal.py b/py-polars/tests/unit/operations/aggregation/test_horizontal.py index ea93b1da97fd..3959e15e22ed 100644 --- a/py-polars/tests/unit/operations/aggregation/test_horizontal.py +++ b/py-polars/tests/unit/operations/aggregation/test_horizontal.py @@ -556,3 +556,68 @@ def test_horizontal_sum_boolean_with_null() -> None: ) assert_frame_equal(out.collect(), expected_df) + + +@pytest.mark.parametrize("ignore_nulls", [True, False]) +@pytest.mark.parametrize( + ("dtype_in", "dtype_out"), + [ + (pl.Null, pl.Null), + (pl.Boolean, pl.UInt32), + (pl.UInt8, pl.UInt8), + (pl.Float32, pl.Float32), + (pl.Float64, pl.Float64), + (pl.Decimal(None, 5), pl.Decimal(None, 5)), + ], +) +def test_horizontal_sum_with_null_col_ignore_strategy( + dtype_in: PolarsDataType, + dtype_out: PolarsDataType, + ignore_nulls: bool, +) -> None: + lf = pl.LazyFrame( + { + "null": [None, None, None], + "s": pl.Series([1, 0, 1], dtype=dtype_in, strict=False), + "s2": pl.Series([1, 0, None], dtype=dtype_in, strict=False), + } + ) + result = lf.select(pl.sum_horizontal("null", "s", "s2", ignore_nulls=ignore_nulls)) + if ignore_nulls and dtype_in != pl.Null: + values = [2, 0, 1] + else: + values = [None, None, None] # type: ignore[list-item] + expected = pl.LazyFrame(pl.Series("null", values, dtype=dtype_out)) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("ignore_nulls", [True, False]) +@pytest.mark.parametrize( + ("dtype_in", "dtype_out"), + [ + (pl.Null, pl.Float64), + (pl.Boolean, pl.Float64), + (pl.UInt8, pl.Float64), + (pl.Float32, pl.Float32), + (pl.Float64, pl.Float64), + ], +) +def test_horizontal_mean_with_null_col_ignore_strategy( + dtype_in: PolarsDataType, + dtype_out: PolarsDataType, + ignore_nulls: bool, +) -> None: + lf = pl.LazyFrame( + { + "null": [None, None, None], + "s": pl.Series([1, 0, 1], dtype=dtype_in, strict=False), + "s2": pl.Series([1, 0, None], dtype=dtype_in, strict=False), + } + ) + result = lf.select(pl.mean_horizontal("null", "s", "s2", ignore_nulls=ignore_nulls)) + if ignore_nulls and dtype_in != pl.Null: + values = [1, 0, 1] + else: + values = [None, None, None] # type: ignore[list-item] + expected = pl.LazyFrame(pl.Series("null", values, dtype=dtype_out)) + assert_frame_equal(result, expected)