From 233f9b385a0a56b047483a77f292f904c8f4b6eb Mon Sep 17 00:00:00 2001 From: Luke Manley Date: Fri, 31 Jan 2025 02:56:16 -0500 Subject: [PATCH] fix: Fix all-null list aggregations returning Null dtype (#20992) --- .../src/chunked_array/list/min_max.rs | 10 +++++---- .../src/chunked_array/list/sum_mean.rs | 5 +++-- .../operations/namespaces/list/test_list.py | 22 +++++++++++++++++++ 3 files changed, 31 insertions(+), 6 deletions(-) diff --git a/crates/polars-ops/src/chunked_array/list/min_max.rs b/crates/polars-ops/src/chunked_array/list/min_max.rs index ee76b74e7579..5d795ecaca7e 100644 --- a/crates/polars-ops/src/chunked_array/list/min_max.rs +++ b/crates/polars-ops/src/chunked_array/list/min_max.rs @@ -89,7 +89,7 @@ pub(super) fn list_min_function(ca: &ListChunked) -> PolarsResult { unsafe { out.into_series().from_physical_unchecked(dt) } }) }, - _ => Ok(ca + dt => ca .try_apply_amortized(|s| { let s = s.as_ref(); let sc = s.min_reduce()?; @@ -97,7 +97,8 @@ pub(super) fn list_min_function(ca: &ListChunked) -> PolarsResult { })? .explode() .unwrap() - .into_series()), + .into_series() + .cast(dt), } } @@ -199,7 +200,7 @@ pub(super) fn list_max_function(ca: &ListChunked) -> PolarsResult { unsafe { out.into_series().from_physical_unchecked(dt) } }) }, - _ => Ok(ca + dt => ca .try_apply_amortized(|s| { let s = s.as_ref(); let sc = s.max_reduce()?; @@ -207,7 +208,8 @@ pub(super) fn list_max_function(ca: &ListChunked) -> PolarsResult { })? .explode() .unwrap() - .into_series()), + .into_series() + .cast(dt), } } diff --git a/crates/polars-ops/src/chunked_array/list/sum_mean.rs b/crates/polars-ops/src/chunked_array/list/sum_mean.rs index fb26cb17872e..9eb0c85feb71 100644 --- a/crates/polars-ops/src/chunked_array/list/sum_mean.rs +++ b/crates/polars-ops/src/chunked_array/list/sum_mean.rs @@ -106,7 +106,7 @@ pub(super) fn sum_with_nulls(ca: &ListChunked, inner_dtype: &DataType) -> Polars out.into_series() }, // slowest sum_as_series path - _ => ca + dt => ca .try_apply_amortized(|s| { s.as_ref() .sum_reduce() @@ -114,7 +114,8 @@ pub(super) fn sum_with_nulls(ca: &ListChunked, inner_dtype: &DataType) -> Polars })? .explode() .unwrap() - .into_series(), + .into_series() + .cast(dt)?, }; out.rename(ca.name().clone()); Ok(out) diff --git a/py-polars/tests/unit/operations/namespaces/list/test_list.py b/py-polars/tests/unit/operations/namespaces/list/test_list.py index bed606e7c7ef..1ee41d52bea7 100644 --- a/py-polars/tests/unit/operations/namespaces/list/test_list.py +++ b/py-polars/tests/unit/operations/namespaces/list/test_list.py @@ -2,6 +2,7 @@ import re from datetime import date, datetime +from typing import TYPE_CHECKING import numpy as np import pytest @@ -14,6 +15,9 @@ ) from polars.testing import assert_frame_equal, assert_series_equal +if TYPE_CHECKING: + from polars._typing import PolarsDataType + def test_list_arr_get() -> None: a = pl.Series("a", [[1, 2, 3], [4, 5], [6, 7, 8, 9]]) @@ -979,3 +983,21 @@ def test_list_eval_element_schema_19345() -> None: ), pl.DataFrame({"a": [[1]]}), ) + + +@pytest.mark.parametrize( + ("agg", "inner_dtype", "expected_dtype"), + [ + ("sum", pl.Int8, pl.Int64), + ("max", pl.Int8, pl.Int8), + ("sum", pl.Duration("us"), pl.Duration("us")), + ("min", pl.Duration("ms"), pl.Duration("ms")), + ("min", pl.String, pl.String), + ("max", pl.String, pl.String), + ], +) +def test_list_agg_all_null( + agg: str, inner_dtype: PolarsDataType, expected_dtype: PolarsDataType +) -> None: + s = pl.Series([None, None], dtype=pl.List(inner_dtype)) + assert getattr(s.list, agg)().dtype == expected_dtype