Skip to content

Commit

Permalink
fix: Incorrect output type for map_groups returning all-NULL column (
Browse files Browse the repository at this point in the history
  • Loading branch information
nameexhaustion authored Jan 16, 2025
1 parent 233b396 commit 68f547e
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 0 deletions.
6 changes: 6 additions & 0 deletions crates/polars-expr/src/expressions/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,12 @@ impl ApplyExpr {

fn all_unit_length(ca: &ListChunked) -> bool {
assert_eq!(ca.chunks().len(), 1);

// Handles the Null dtype - in that case the offsets can be (e.g. [0,0,0 ...])
if ca.null_count() == ca.len() {
return true;
}

let list_arr = ca.downcast_iter().next().unwrap();
let offset = list_arr.offsets().as_slice();
(offset[offset.len() - 1] as usize) == list_arr.len()
Expand Down
39 changes: 39 additions & 0 deletions crates/polars-lazy/src/tests/aggregations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -576,3 +576,42 @@ fn test_take_in_groups() -> PolarsResult<()> {
);
Ok(())
}

#[test]
fn test_anonymous_function_returns_scalar_all_null_20679() {
use std::sync::Arc;

fn reduction_function(column: Column) -> PolarsResult<Option<Column>> {
let val = column.get(0)?.into_static();
let col = Column::new_scalar("".into(), Scalar::new(column.dtype().clone(), val), 1);
Ok(Some(col))
}

let a = Column::new("a".into(), &[0, 0, 1]);
let dtype = DataType::Null;
let b = Column::new_scalar("b".into(), Scalar::new(dtype, AnyValue::Null), 3);
let df = DataFrame::new(vec![a, b]).unwrap();

let f = move |c: &mut [Column]| reduction_function(std::mem::take(&mut c[0]));

let expr = Expr::AnonymousFunction {
input: vec![col("b")],
function: LazySerde::Deserialized(SpecialEq::new(Arc::new(f))),
output_type: Default::default(),
options: FunctionOptions {
collect_groups: ApplyOptions::GroupWise,
fmt_str: "",
flags: FunctionFlags::default() | FunctionFlags::RETURNS_SCALAR,
..Default::default()
},
};

let grouped_df = df
.lazy()
.group_by([col("a")])
.agg([expr])
.collect()
.unwrap();

assert_eq!(grouped_df.get_columns()[1].dtype(), &DataType::Null);
}
13 changes: 13 additions & 0 deletions py-polars/tests/unit/operations/map/test_map_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,16 @@ def test_map_groups_numpy_output_3057() -> None:

expected = pl.DataFrame({"id": [0, 1], "result": [2.266666, 7.333333]})
assert_frame_equal(result, expected)


def test_map_groups_return_all_null_15260() -> None:
def foo(x: pl.Series) -> pl.Series:
return pl.Series([x[0][0]], dtype=x[0].dtype)

assert_frame_equal(
pl.DataFrame({"key": [0, 0, 1], "a": [None, None, None]})
.group_by("key")
.agg(pl.map_groups(exprs=["a"], function=foo)) # type: ignore[arg-type]
.sort("a"),
pl.DataFrame({"key": [0, 1], "a": [None, None]}),
)

0 comments on commit 68f547e

Please sign in to comment.