From 7683476b2f415b02688c7acef70617d9fab25a72 Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Tue, 10 Sep 2024 13:49:24 +0200 Subject: [PATCH 1/2] refactor: Fix a bunch of tests for new-streaming --- crates/polars-stream/src/nodes/reduce.rs | 3 +- py-polars/tests/unit/datatypes/test_list.py | 18 +++++------ py-polars/tests/unit/datatypes/test_struct.py | 1 + .../tests/unit/datatypes/test_temporal.py | 4 +-- .../namespaces/temporal/test_datetime.py | 30 ++++++++----------- 5 files changed, 23 insertions(+), 33 deletions(-) diff --git a/crates/polars-stream/src/nodes/reduce.rs b/crates/polars-stream/src/nodes/reduce.rs index 2ce9ee2c9464..f6de3bd1124a 100644 --- a/crates/polars-stream/src/nodes/reduce.rs +++ b/crates/polars-stream/src/nodes/reduce.rs @@ -59,9 +59,8 @@ impl ReduceNode { scope.spawn_task(TaskPriority::High, async move { while let Ok(morsel) = recv.recv().await { for (reducer, selector) in local_reducers.iter_mut().zip(selectors) { - // TODO: don't convert to physical representation here. let input = selector.evaluate(morsel.df(), state).await?; - reducer.update(&input.to_physical_repr())?; + reducer.update(&input)?; } } diff --git a/py-polars/tests/unit/datatypes/test_list.py b/py-polars/tests/unit/datatypes/test_list.py index 4607cfa89426..8c5502d698fd 100644 --- a/py-polars/tests/unit/datatypes/test_list.py +++ b/py-polars/tests/unit/datatypes/test_list.py @@ -114,20 +114,16 @@ def test_cast_inner() -> None: def test_list_empty_group_by_result_3521() -> None: - # Create a left relation where the join column contains a null value - left = pl.DataFrame().with_columns( - pl.lit(1).alias("group_by_column"), - pl.lit(None).cast(pl.Int32).alias("join_column"), + # Create a left relation where the join column contains a null value. + left = pl.DataFrame( + {"group_by_column": [1], "join_column": [None]}, + schema_overrides={"join_column": pl.Int64}, ) - # Create a right relation where there is a column to count distinct on - right = pl.DataFrame().with_columns( - pl.lit(1).alias("join_column"), - pl.lit(1).alias("n_unique_column"), - ) + # Create a right relation where there is a column to count distinct on. + right = pl.DataFrame({"join_column": [1], "n_unique_column": [1]}) - # Calculate n_unique after dropping nulls - # This will panic on polars version 0.13.38 and 0.13.39 + # Calculate n_unique after dropping nulls. result = ( left.join(right, on="join_column", how="left") .group_by("group_by_column") diff --git a/py-polars/tests/unit/datatypes/test_struct.py b/py-polars/tests/unit/datatypes/test_struct.py index 49a223f76fd4..6489a83e5a6b 100644 --- a/py-polars/tests/unit/datatypes/test_struct.py +++ b/py-polars/tests/unit/datatypes/test_struct.py @@ -265,6 +265,7 @@ def test_from_dicts_struct() -> None: ] +@pytest.mark.may_fail_auto_streaming def test_list_to_struct() -> None: df = pl.DataFrame({"a": [[1, 2, 3], [1, 2]]}) assert df.select([pl.col("a").list.to_struct()]).to_series().to_list() == [ diff --git a/py-polars/tests/unit/datatypes/test_temporal.py b/py-polars/tests/unit/datatypes/test_temporal.py index ea1798fe7114..e0c9f6498c65 100644 --- a/py-polars/tests/unit/datatypes/test_temporal.py +++ b/py-polars/tests/unit/datatypes/test_temporal.py @@ -1399,12 +1399,12 @@ def test_replace_time_zone_sortedness_expressions( from_tz: str | None, expected_sortedness: bool, ambiguous: str ) -> None: df = ( - pl.Series("ts", [1603584000000000, 1603587600000000]) + pl.Series("ts", [1603584000000000, 1603584060000000, 1603587600000000]) .cast(pl.Datetime("us", from_tz)) .sort() .to_frame() ) - df = df.with_columns(ambiguous=pl.Series([ambiguous] * 2)) + df = df.with_columns(ambiguous=pl.Series([ambiguous] * 3)) assert df["ts"].flags["SORTED_ASC"] result = df.select( pl.col("ts").dt.replace_time_zone("UTC", ambiguous=pl.col("ambiguous")) diff --git a/py-polars/tests/unit/operations/namespaces/temporal/test_datetime.py b/py-polars/tests/unit/operations/namespaces/temporal/test_datetime.py index fb4ddee68146..fc9484470604 100644 --- a/py-polars/tests/unit/operations/namespaces/temporal/test_datetime.py +++ b/py-polars/tests/unit/operations/namespaces/temporal/test_datetime.py @@ -138,15 +138,13 @@ def test_local_date_sortedness(time_zone: str | None, expected: bool) -> None: ser = (pl.Series([datetime(2022, 1, 1, 23)]).dt.replace_time_zone(time_zone)).sort() result = ser.dt.date() assert result.flags["SORTED_ASC"] - assert result.flags["SORTED_DESC"] is False # 2 elements - depends on time zone ser = ( pl.Series([datetime(2022, 1, 1, 23)] * 2).dt.replace_time_zone(time_zone) ).sort() result = ser.dt.date() - assert result.flags["SORTED_ASC"] == expected - assert result.flags["SORTED_DESC"] is False + assert result.flags["SORTED_ASC"] >= expected @pytest.mark.parametrize("time_zone", [None, "Asia/Kathmandu", "UTC"]) @@ -155,11 +153,10 @@ def test_local_time_sortedness(time_zone: str | None) -> None: ser = (pl.Series([datetime(2022, 1, 1, 23)]).dt.replace_time_zone(time_zone)).sort() result = ser.dt.time() assert result.flags["SORTED_ASC"] - assert not result.flags["SORTED_DESC"] - # two elements - not sorted + # three elements - not sorted ser = ( - pl.Series([datetime(2022, 1, 1, 23)] * 2).dt.replace_time_zone(time_zone) + pl.Series([datetime(2022, 1, 1, 23), datetime(2022, 1, 2, 21), datetime(2022, 1, 3, 22)]).dt.replace_time_zone(time_zone) ).sort() result = ser.dt.time() assert not result.flags["SORTED_ASC"] @@ -180,31 +177,28 @@ def test_local_time_before_epoch(time_unit: TimeUnit) -> None: ("time_zone", "offset", "expected"), [ (None, "1d", True), - ("Asia/Kathmandu", "1d", False), + ("Europe/London", "1d", False), ("UTC", "1d", True), (None, "1mo", True), - ("Asia/Kathmandu", "1mo", False), + ("Europe/London", "1mo", False), ("UTC", "1mo", True), (None, "1w", True), - ("Asia/Kathmandu", "1w", False), + ("Europe/London", "1w", False), ("UTC", "1w", True), (None, "1h", True), - ("Asia/Kathmandu", "1h", True), + ("Europe/London", "1h", True), ("UTC", "1h", True), ], ) def test_offset_by_sortedness( time_zone: str | None, offset: str, expected: bool ) -> None: - # create 2 values, as a single value is always sorted - ser = ( - pl.Series( - [datetime(2022, 1, 1, 22), datetime(2022, 1, 1, 22)] - ).dt.replace_time_zone(time_zone) - ).sort() - result = ser.dt.offset_by(offset) + s = pl.datetime_range(datetime(2020, 10, 25), datetime(2020, 10, 25, 3), '30m', time_zone=time_zone, eager=True).sort() + assert s.flags["SORTED_ASC"] + assert not s.flags["SORTED_DESC"] + result = s.dt.offset_by(offset) assert result.flags["SORTED_ASC"] == expected - assert result.flags["SORTED_DESC"] is False + assert not result.flags["SORTED_DESC"] def test_dt_datetime_date_time_invalid() -> None: From 85964d2e984eb0dde0a5f4e0eb7e8d9eef5c274e Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Tue, 10 Sep 2024 13:50:01 +0200 Subject: [PATCH 2/2] fmt --- .../namespaces/temporal/test_datetime.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/py-polars/tests/unit/operations/namespaces/temporal/test_datetime.py b/py-polars/tests/unit/operations/namespaces/temporal/test_datetime.py index fc9484470604..a4fcfde344cc 100644 --- a/py-polars/tests/unit/operations/namespaces/temporal/test_datetime.py +++ b/py-polars/tests/unit/operations/namespaces/temporal/test_datetime.py @@ -156,7 +156,13 @@ def test_local_time_sortedness(time_zone: str | None) -> None: # three elements - not sorted ser = ( - pl.Series([datetime(2022, 1, 1, 23), datetime(2022, 1, 2, 21), datetime(2022, 1, 3, 22)]).dt.replace_time_zone(time_zone) + pl.Series( + [ + datetime(2022, 1, 1, 23), + datetime(2022, 1, 2, 21), + datetime(2022, 1, 3, 22), + ] + ).dt.replace_time_zone(time_zone) ).sort() result = ser.dt.time() assert not result.flags["SORTED_ASC"] @@ -193,7 +199,13 @@ def test_local_time_before_epoch(time_unit: TimeUnit) -> None: def test_offset_by_sortedness( time_zone: str | None, offset: str, expected: bool ) -> None: - s = pl.datetime_range(datetime(2020, 10, 25), datetime(2020, 10, 25, 3), '30m', time_zone=time_zone, eager=True).sort() + s = pl.datetime_range( + datetime(2020, 10, 25), + datetime(2020, 10, 25, 3), + "30m", + time_zone=time_zone, + eager=True, + ).sort() assert s.flags["SORTED_ASC"] assert not s.flags["SORTED_DESC"] result = s.dt.offset_by(offset)