Skip to content

Commit

Permalink
fix: Fix all-null list aggregations returning Null dtype (#20992)
Browse files Browse the repository at this point in the history
  • Loading branch information
lukemanley authored Jan 31, 2025
1 parent ea1ea5a commit 233f9b3
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 6 deletions.
10 changes: 6 additions & 4 deletions crates/polars-ops/src/chunked_array/list/min_max.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,15 +89,16 @@ pub(super) fn list_min_function(ca: &ListChunked) -> PolarsResult<Series> {
unsafe { out.into_series().from_physical_unchecked(dt) }
})
},
_ => Ok(ca
dt => ca
.try_apply_amortized(|s| {
let s = s.as_ref();
let sc = s.min_reduce()?;
Ok(sc.into_series(s.name().clone()))
})?
.explode()
.unwrap()
.into_series()),
.into_series()
.cast(dt),
}
}

Expand Down Expand Up @@ -199,15 +200,16 @@ pub(super) fn list_max_function(ca: &ListChunked) -> PolarsResult<Series> {
unsafe { out.into_series().from_physical_unchecked(dt) }
})
},
_ => Ok(ca
dt => ca
.try_apply_amortized(|s| {
let s = s.as_ref();
let sc = s.max_reduce()?;
Ok(sc.into_series(s.name().clone()))
})?
.explode()
.unwrap()
.into_series()),
.into_series()
.cast(dt),
}
}

Expand Down
5 changes: 3 additions & 2 deletions crates/polars-ops/src/chunked_array/list/sum_mean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,15 +106,16 @@ pub(super) fn sum_with_nulls(ca: &ListChunked, inner_dtype: &DataType) -> Polars
out.into_series()
},
// slowest sum_as_series path
_ => ca
dt => ca
.try_apply_amortized(|s| {
s.as_ref()
.sum_reduce()
.map(|sc| sc.into_series(PlSmallStr::EMPTY))
})?
.explode()
.unwrap()
.into_series(),
.into_series()
.cast(dt)?,
};
out.rename(ca.name().clone());
Ok(out)
Expand Down
22 changes: 22 additions & 0 deletions py-polars/tests/unit/operations/namespaces/list/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import re
from datetime import date, datetime
from typing import TYPE_CHECKING

import numpy as np
import pytest
Expand All @@ -14,6 +15,9 @@
)
from polars.testing import assert_frame_equal, assert_series_equal

if TYPE_CHECKING:
from polars._typing import PolarsDataType


def test_list_arr_get() -> None:
a = pl.Series("a", [[1, 2, 3], [4, 5], [6, 7, 8, 9]])
Expand Down Expand Up @@ -979,3 +983,21 @@ def test_list_eval_element_schema_19345() -> None:
),
pl.DataFrame({"a": [[1]]}),
)


@pytest.mark.parametrize(
("agg", "inner_dtype", "expected_dtype"),
[
("sum", pl.Int8, pl.Int64),
("max", pl.Int8, pl.Int8),
("sum", pl.Duration("us"), pl.Duration("us")),
("min", pl.Duration("ms"), pl.Duration("ms")),
("min", pl.String, pl.String),
("max", pl.String, pl.String),
],
)
def test_list_agg_all_null(
agg: str, inner_dtype: PolarsDataType, expected_dtype: PolarsDataType
) -> None:
s = pl.Series([None, None], dtype=pl.List(inner_dtype))
assert getattr(s.list, agg)().dtype == expected_dtype

0 comments on commit 233f9b3

Please sign in to comment.