From 3ac1c40dbcc05041d18ad6ae162a7e246a944e21 Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Thu, 30 May 2024 08:24:21 +0200 Subject: [PATCH] feat(python): Add `replace_all` expression to complement `replace` (#16557) --- .../reference/expressions/modify_select.rst | 1 + .../source/reference/series/computation.rst | 1 + py-polars/polars/_utils/various.py | 5 +- py-polars/polars/expr/expr.py | 192 +++++++++- py-polars/polars/series/series.py | 133 ++++++- .../tests/unit/operations/test_replace.py | 324 +---------------- .../tests/unit/operations/test_replace_all.py | 343 ++++++++++++++++++ 7 files changed, 646 insertions(+), 353 deletions(-) create mode 100644 py-polars/tests/unit/operations/test_replace_all.py diff --git a/py-polars/docs/source/reference/expressions/modify_select.rst b/py-polars/docs/source/reference/expressions/modify_select.rst index 31e82ba56fc2..eed974efb1f7 100644 --- a/py-polars/docs/source/reference/expressions/modify_select.rst +++ b/py-polars/docs/source/reference/expressions/modify_select.rst @@ -44,6 +44,7 @@ Manipulation/selection Expr.reinterpret Expr.repeat_by Expr.replace + Expr.replace_all Expr.reshape Expr.reverse Expr.rle diff --git a/py-polars/docs/source/reference/series/computation.rst b/py-polars/docs/source/reference/series/computation.rst index 68f3ca57e640..3f6b324b54d5 100644 --- a/py-polars/docs/source/reference/series/computation.rst +++ b/py-polars/docs/source/reference/series/computation.rst @@ -50,6 +50,7 @@ Computation Series.peak_min Series.rank Series.replace + Series.replace_all Series.rolling_apply Series.rolling_map Series.rolling_max diff --git a/py-polars/polars/_utils/various.py b/py-polars/polars/_utils/various.py index 5644b1cd2a86..225b2ecd5056 100644 --- a/py-polars/polars/_utils/various.py +++ b/py-polars/polars/_utils/various.py @@ -312,10 +312,7 @@ def str_duration_(td: str | None) -> int | None: .cast(tp) ) elif tp == Boolean: - cast_cols[c] = F.col(c).replace( - {"true": True, "false": False}, - default=None, - ) + cast_cols[c] = F.col(c).replace_all({"true": True, "false": False}) elif tp in INTEGER_DTYPES: int_string = F.col(c).str.replace_all(r"[^\d+-]", "") cast_cols[c] = ( diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index 2a33f8cb54ef..86f8c92c5a7d 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -11371,16 +11371,26 @@ def replace( Accepts expression input. Sequences are parsed as Series, other non-expression inputs are parsed as literals. Length must match the length of `old` or have length 1. + default Set values that were not replaced to this value. Defaults to keeping the original value. Accepts expression input. Non-expression inputs are parsed as literals. + + .. deprecated:: 0.20.31 + Use :meth:`replace_all` instead to set a default while replacing values. + return_dtype The data type of the resulting expression. If set to `None` (default), the data type is determined automatically based on the other inputs. + .. deprecated:: 0.20.31 + Use :meth:`replace_all` instead to set a return data type while + replacing values. + See Also -------- + replace_all str.replace Notes @@ -11422,25 +11432,23 @@ def replace( └─────┴──────────┘ Passing a mapping with replacements is also supported as syntactic sugar. - Specify a default to set all values that were not matched. >>> mapping = {2: 100, 3: 200} - >>> df.with_columns(replaced=pl.col("a").replace(mapping, default=-1)) + >>> df.with_columns(replaced=pl.col("a").replace(mapping)) shape: (4, 2) ┌─────┬──────────┐ │ a ┆ replaced │ │ --- ┆ --- │ │ i64 ┆ i64 │ ╞═════╪══════════╡ - │ 1 ┆ -1 │ + │ 1 ┆ 1 │ │ 2 ┆ 100 │ │ 2 ┆ 100 │ │ 3 ┆ 200 │ └─────┴──────────┘ Replacing by values of a different data type sets the return type based on - a combination of the `new` data type and either the original data type or the - default data type if it was set. + a combination of the `new` data type and the original data type. >>> df = pl.DataFrame({"a": ["x", "y", "z"]}) >>> mapping = {"x": 1, "y": 2, "z": 3} @@ -11455,7 +11463,156 @@ def replace( │ y ┆ 2 │ │ z ┆ 3 │ └─────┴──────────┘ - >>> df.with_columns(replaced=pl.col("a").replace(mapping, default=None)) + + Expression input is supported. + + >>> df = pl.DataFrame({"a": [1, 2, 2, 3], "b": [1.5, 2.5, 5.0, 1.0]}) + >>> df.with_columns( + ... replaced=pl.col("a").replace( + ... old=pl.col("a").max(), + ... new=pl.col("b").sum(), + ... ) + ... ) + shape: (4, 3) + ┌─────┬─────┬──────────┐ + │ a ┆ b ┆ replaced │ + │ --- ┆ --- ┆ --- │ + │ i64 ┆ f64 ┆ f64 │ + ╞═════╪═════╪══════════╡ + │ 1 ┆ 1.5 ┆ 1.0 │ + │ 2 ┆ 2.5 ┆ 2.0 │ + │ 2 ┆ 5.0 ┆ 2.0 │ + │ 3 ┆ 1.0 ┆ 10.0 │ + └─────┴─────┴──────────┘ + """ + if new is no_default and isinstance(old, Mapping): + new = pl.Series(old.values()) + old = pl.Series(old.keys()) + else: + if isinstance(old, Sequence) and not isinstance(old, (str, pl.Series)): + old = pl.Series(old) + if isinstance(new, Sequence) and not isinstance(new, (str, pl.Series)): + new = pl.Series(new) + + old = parse_as_expression(old, str_as_lit=True) # type: ignore[arg-type] + new = parse_as_expression(new, str_as_lit=True) # type: ignore[arg-type] + + if default is no_default: + default = None + else: + issue_deprecation_warning( + "The `default` parameter for `replace` is deprecated." + " Use `replace_all` instead to set a default while replacing values.", + version="0.20.31", + ) + default = parse_as_expression(default, str_as_lit=True) + + if return_dtype is not None: + issue_deprecation_warning( + "The `return_dtype` parameter for `replace` is deprecated." + " Use `replace_all` instead to set a return data type while replacing values.", + version="0.20.31", + ) + + return self._from_pyexpr(self._pyexpr.replace(old, new, default, return_dtype)) + + def replace_all( + self, + old: IntoExpr | Sequence[Any] | Mapping[Any, Any], + new: IntoExpr | Sequence[Any] | NoDefault = no_default, + *, + default: IntoExpr = None, + return_dtype: PolarsDataType | None = None, + ) -> Self: + """ + Replace all values by different values. + + Parameters + ---------- + old + Value or sequence of values to replace. + Accepts expression input. Sequences are parsed as Series, + other non-expression inputs are parsed as literals. + Also accepts a mapping of values to their replacement as syntactic sugar for + `replace_all(old=Series(mapping.keys()), new=Series(mapping.values()))`. + new + Value or sequence of values to replace by. + Accepts expression input. Sequences are parsed as Series, + other non-expression inputs are parsed as literals. + Length must match the length of `old` or have length 1. + default + Set values that were not replaced to this value. Defaults to null. + Accepts expression input. Non-expression inputs are parsed as literals. + return_dtype + The data type of the resulting expression. If set to `None` (default), + the data type is determined automatically based on the other inputs. + + See Also + -------- + replace + str.replace + + Notes + ----- + The global string cache must be enabled when replacing categorical values. + + Examples + -------- + Replace a single value by another value. Values that were not replaced are set + to null. + + >>> df = pl.DataFrame({"a": [1, 2, 2, 3]}) + >>> df.with_columns(replaced=pl.col("a").replace_all(2, 100)) + shape: (4, 2) + ┌─────┬──────────┐ + │ a ┆ replaced │ + │ --- ┆ --- │ + │ i64 ┆ i32 │ + ╞═════╪══════════╡ + │ 1 ┆ null │ + │ 2 ┆ 100 │ + │ 2 ┆ 100 │ + │ 3 ┆ null │ + └─────┴──────────┘ + + Replace multiple values by passing sequences to the `old` and `new` parameters. + + >>> df.with_columns(replaced=pl.col("a").replace_all([2, 3], [100, 200])) + shape: (4, 2) + ┌─────┬──────────┐ + │ a ┆ replaced │ + │ --- ┆ --- │ + │ i64 ┆ i64 │ + ╞═════╪══════════╡ + │ 1 ┆ null │ + │ 2 ┆ 100 │ + │ 2 ┆ 100 │ + │ 3 ┆ 200 │ + └─────┴──────────┘ + + Passing a mapping with replacements is also supported as syntactic sugar. + Specify a default to set all values that were not matched. + + >>> mapping = {2: 100, 3: 200} + >>> df.with_columns(replaced=pl.col("a").replace_all(mapping, default=-1)) + shape: (4, 2) + ┌─────┬──────────┐ + │ a ┆ replaced │ + │ --- ┆ --- │ + │ i64 ┆ i64 │ + ╞═════╪══════════╡ + │ 1 ┆ -1 │ + │ 2 ┆ 100 │ + │ 2 ┆ 100 │ + │ 3 ┆ 200 │ + └─────┴──────────┘ + + Replacing by values of a different data type sets the return type based on + a combination of the `new` data type and the `default` data type. + + >>> df = pl.DataFrame({"a": ["x", "y", "z"]}) + >>> mapping = {"x": 1, "y": 2, "z": 3} + >>> df.with_columns(replaced=pl.col("a").replace_all(mapping)) shape: (3, 2) ┌─────┬──────────┐ │ a ┆ replaced │ @@ -11466,11 +11623,22 @@ def replace( │ y ┆ 2 │ │ z ┆ 3 │ └─────┴──────────┘ + >>> df.with_columns(replaced=pl.col("a").replace_all(mapping, default="x")) + shape: (3, 2) + ┌─────┬──────────┐ + │ a ┆ replaced │ + │ --- ┆ --- │ + │ str ┆ str │ + ╞═════╪══════════╡ + │ x ┆ 1 │ + │ y ┆ 2 │ + │ z ┆ 3 │ + └─────┴──────────┘ Set the `return_dtype` parameter to control the resulting data type directly. >>> df.with_columns( - ... replaced=pl.col("a").replace(mapping, return_dtype=pl.UInt8) + ... replaced=pl.col("a").replace_all(mapping, return_dtype=pl.UInt8) ... ) shape: (3, 2) ┌─────┬──────────┐ @@ -11487,7 +11655,7 @@ def replace( >>> df = pl.DataFrame({"a": [1, 2, 2, 3], "b": [1.5, 2.5, 5.0, 1.0]}) >>> df.with_columns( - ... replaced=pl.col("a").replace( + ... replaced=pl.col("a").replace_all( ... old=pl.col("a").max(), ... new=pl.col("b").sum(), ... default=pl.col("b"), @@ -11517,11 +11685,7 @@ def replace( old = parse_as_expression(old, str_as_lit=True) # type: ignore[arg-type] new = parse_as_expression(new, str_as_lit=True) # type: ignore[arg-type] - default = ( - None - if default is no_default - else parse_as_expression(default, str_as_lit=True) - ) + default = parse_as_expression(default, str_as_lit=True) return self._from_pyexpr(self._pyexpr.replace(old, new, default, return_dtype)) @@ -11974,7 +12138,7 @@ def map_dict( return_dtype Set return dtype to override automatic return dtype determination. """ - return self.replace(mapping, default=default, return_dtype=return_dtype) + return self.replace_all(mapping, default=default, return_dtype=return_dtype) @classmethod def from_json(cls, value: str) -> Self: diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index bb51e859fbec..e61213175d6f 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -6853,16 +6853,27 @@ def replace( new Value or sequence of values to replace by. Length must match the length of `old` or have length 1. + default Set values that were not replaced to this value. Defaults to keeping the original value. Accepts expression input. Non-expression inputs are parsed as literals. + + .. deprecated:: 0.20.31 + Use :meth:`replace_all` instead to set a default while replacing values. + return_dtype - The data type of the resulting Series. If set to `None` (default), + The data type of the resulting expression. If set to `None` (default), the data type is determined automatically based on the other inputs. + .. deprecated:: 0.20.31 + Use :meth:`replace_all` instead to set a return data type while + replacing values. + + See Also -------- + replace_all str.replace Notes @@ -6897,11 +6908,103 @@ def replace( 200 ] + Passing a mapping with replacements is also supported as syntactic sugar. + + >>> mapping = {2: 100, 3: 200} + >>> s.replace(mapping) + shape: (4,) + Series: '' [i64] + [ + 1 + 100 + 100 + 200 + ] + + Replacing by values of a different data type sets the return type based on + a combination of the `new` data type and the original data type. + + >>> s = pl.Series(["x", "y", "z"]) + >>> mapping = {"x": 1, "y": 2, "z": 3} + >>> s.replace(mapping) + shape: (3,) + Series: '' [str] + [ + "1" + "2" + "3" + ] + """ + + def replace_all( + self, + old: IntoExpr | Sequence[Any] | Mapping[Any, Any], + new: IntoExpr | Sequence[Any] | NoDefault = no_default, + *, + default: IntoExpr = None, + return_dtype: PolarsDataType | None = None, + ) -> Self: + """ + Replace all values by different values. + + Parameters + ---------- + old + Value or sequence of values to replace. + Also accepts a mapping of values to their replacement as syntactic sugar for + `replace_all(old=Series(mapping.keys()), new=Series(mapping.values()))`. + new + Value or sequence of values to replace by. + Length must match the length of `old` or have length 1. + default + Set values that were not replaced to this value. Defaults to null. + Accepts expression input. Non-expression inputs are parsed as literals. + return_dtype + The data type of the resulting Series. If set to `None` (default), + the data type is determined automatically based on the other inputs. + + See Also + -------- + replace + str.replace + + Notes + ----- + The global string cache must be enabled when replacing categorical values. + + Examples + -------- + Replace a single value by another value. Values that were not replaced are set + to null. + + >>> s = pl.Series([1, 2, 2, 3]) + >>> s.replace_all(2, 100) + shape: (4,) + Series: '' [i32] + [ + null + 100 + 100 + null + ] + + Replace multiple values by passing sequences to the `old` and `new` parameters. + + >>> s.replace_all([2, 3], [100, 200]) + shape: (4,) + Series: '' [i64] + [ + null + 100 + 100 + 200 + ] + Passing a mapping with replacements is also supported as syntactic sugar. Specify a default to set all values that were not matched. >>> mapping = {2: 100, 3: 200} - >>> s.replace(mapping, default=-1) + >>> s.replace_all(mapping, default=-1) shape: (4,) Series: '' [i64] [ @@ -6911,11 +7014,10 @@ def replace( 200 ] - The default can be another Series. >>> default = pl.Series([2.5, 5.0, 7.5, 10.0]) - >>> s.replace(2, 100, default=default) + >>> s.replace_all(2, 100, default=default) shape: (4,) Series: '' [f64] [ @@ -6926,20 +7028,11 @@ def replace( ] Replacing by values of a different data type sets the return type based on - a combination of the `new` data type and either the original data type or the - default data type if it was set. + a combination of the `new` data type and the `default` data type. >>> s = pl.Series(["x", "y", "z"]) >>> mapping = {"x": 1, "y": 2, "z": 3} - >>> s.replace(mapping) - shape: (3,) - Series: '' [str] - [ - "1" - "2" - "3" - ] - >>> s.replace(mapping, default=None) + >>> s.replace_all(mapping) shape: (3,) Series: '' [i64] [ @@ -6947,10 +7040,18 @@ def replace( 2 3 ] + >>> s.replace_all(mapping, default="x") + shape: (3,) + Series: '' [str] + [ + "1" + "2" + "3" + ] Set the `return_dtype` parameter to control the resulting data type directly. - >>> s.replace(mapping, return_dtype=pl.UInt8) + >>> s.replace_all(mapping, return_dtype=pl.UInt8) shape: (3,) Series: '' [u8] [ diff --git a/py-polars/tests/unit/operations/test_replace.py b/py-polars/tests/unit/operations/test_replace.py index cd5f80a73364..f2d1fa7f4681 100644 --- a/py-polars/tests/unit/operations/test_replace.py +++ b/py-polars/tests/unit/operations/test_replace.py @@ -1,12 +1,10 @@ from __future__ import annotations -import contextlib from typing import Any import pytest import polars as pl -from polars.exceptions import CategoricalRemappingWarning from polars.testing import assert_frame_equal, assert_series_equal @@ -27,44 +25,6 @@ def test_replace_str_to_str(str_mapping: dict[str | None, str]) -> None: assert_frame_equal(result, expected) -def test_replace_str_to_str_default_self(str_mapping: dict[str | None, str]) -> None: - df = pl.DataFrame({"country_code": ["FR", None, "ES", "DE"]}) - result = df.select( - replaced=pl.col("country_code").replace( - str_mapping, default=pl.col("country_code") - ) - ) - expected = pl.DataFrame({"replaced": ["France", "Not specified", "ES", "Germany"]}) - assert_frame_equal(result, expected) - - -def test_replace_str_to_str_default_null(str_mapping: dict[str | None, str]) -> None: - df = pl.DataFrame({"country_code": ["FR", None, "ES", "DE"]}) - result = df.select( - replaced=pl.col("country_code").replace(str_mapping, default=None) - ) - expected = pl.DataFrame({"replaced": ["France", "Not specified", None, "Germany"]}) - assert_frame_equal(result, expected) - - -def test_replace_str_to_str_default_other(str_mapping: dict[str | None, str]) -> None: - df = pl.DataFrame({"country_code": ["FR", None, "ES", "DE"]}) - - result = df.with_row_index().select( - replaced=pl.col("country_code").replace(str_mapping, default=pl.col("index")) - ) - expected = pl.DataFrame({"replaced": ["France", "Not specified", "2", "Germany"]}) - assert_frame_equal(result, expected) - - -def test_replace_str_to_cat() -> None: - s = pl.Series(["a", "b", "c"]) - mapping = {"a": "c", "b": "d"} - result = s.replace(mapping, return_dtype=pl.Categorical) - expected = pl.Series(["c", "d", "c"], dtype=pl.Categorical) - assert_series_equal(result, expected, categorical_as_str=True) - - def test_replace_enum() -> None: dtype = pl.Enum(["a", "b", "c", "d"]) s = pl.Series(["a", "b", "c"], dtype=dtype) @@ -87,19 +47,6 @@ def test_replace_enum_to_str() -> None: assert_series_equal(result, expected) -def test_replace_enum_to_new_enum() -> None: - s = pl.Series(["a", "b", "c"], dtype=pl.Enum(["a", "b", "c", "d"])) - old = ["a", "b"] - - new_dtype = pl.Enum(["a", "b", "c", "d", "e"]) - new = pl.Series(["c", "e"], dtype=new_dtype) - - result = s.replace(old, new, return_dtype=new_dtype) - - expected = pl.Series(["c", "e", "c"], dtype=new_dtype) - assert_series_equal(result, expected) - - @pl.StringCache() def test_replace_cat_to_cat(str_mapping: dict[str | None, str]) -> None: lf = pl.LazyFrame( @@ -165,42 +112,6 @@ def test_replace_int_to_str_with_null() -> None: assert_frame_equal(result, expected) -def test_replace_int_to_int_null() -> None: - df = pl.DataFrame({"int": [None, 1, None, 3]}, schema={"int": pl.Int16}) - mapping = {3: None} - result = df.select( - replaced=pl.col("int").replace(mapping, default=pl.lit(6).cast(pl.Int16)) - ) - expected = pl.DataFrame( - {"replaced": [6, 6, 6, None]}, schema={"replaced": pl.Int16} - ) - assert_frame_equal(result, expected) - - -def test_replace_int_to_int_null_default_null() -> None: - df = pl.DataFrame({"int": [None, 1, None, 3]}, schema={"int": pl.Int16}) - mapping = {3: None} - result = df.select(replaced=pl.col("int").replace(mapping, default=None)) - expected = pl.DataFrame( - {"replaced": [None, None, None, None]}, schema={"replaced": pl.Null} - ) - assert_frame_equal(result, expected) - - -def test_replace_int_to_int_null_return_dtype() -> None: - df = pl.DataFrame({"int": [None, 1, None, 3]}, schema={"int": pl.Int16}) - mapping = {3: None} - - result = df.select( - replaced=pl.col("int").replace(mapping, default=6, return_dtype=pl.Int32) - ) - - expected = pl.DataFrame( - {"replaced": [6, 6, 6, None]}, schema={"replaced": pl.Int32} - ) - assert_frame_equal(result, expected) - - def test_replace_empty_mapping() -> None: df = pl.DataFrame({"int": [None, 1, None, 3]}, schema={"int": pl.Int16}) mapping: dict[Any, Any] = {} @@ -208,14 +119,6 @@ def test_replace_empty_mapping() -> None: assert_frame_equal(result, df) -def test_replace_empty_mapping_default() -> None: - df = pl.DataFrame({"int": [None, 1, None, 3]}, schema={"int": pl.Int16}) - mapping: dict[Any, Any] = {} - result = df.select(pl.col("int").replace(mapping, default=pl.lit("A"))) - expected = pl.DataFrame({"int": ["A", "A", "A", "A"]}) - assert_frame_equal(result, expected) - - def test_replace_mapping_different_dtype_str_int() -> None: df = pl.DataFrame({"int": [None, "1", None, "3"]}) mapping = {1: "b", 3: "d"} @@ -250,60 +153,6 @@ def test_replace_str_to_str_replace_all() -> None: assert_frame_equal(result, expected) -def test_replace_int_to_int_df() -> None: - lf = pl.LazyFrame({"a": [1, 2, 3]}, schema={"a": pl.UInt8}) - mapping = {1: 11, 2: 22} - - result = lf.select( - pl.col("a").replace( - old=pl.Series(mapping.keys()), - new=pl.Series(mapping.values(), dtype=pl.UInt8), - default=pl.lit(99).cast(pl.UInt8), - ) - ) - expected = pl.LazyFrame({"a": [11, 22, 99]}, schema_overrides={"a": pl.UInt8}) - assert_frame_equal(result, expected) - - -def test_replace_str_to_int_fill_null() -> None: - lf = pl.LazyFrame({"a": ["one", "two"]}) - mapping = {"one": 1} - - result = lf.select( - pl.col("a") - .replace(mapping, default=None, return_dtype=pl.UInt32) - .fill_null(999) - ) - - expected = pl.LazyFrame({"a": pl.Series([1, 999], dtype=pl.UInt32)}) - assert_frame_equal(result, expected) - - -def test_replace_mix() -> None: - df = pl.DataFrame( - [ - pl.Series("float_to_boolean", [1.0, None]), - pl.Series("boolean_to_int", [True, False]), - pl.Series("boolean_to_str", [True, False]), - ] - ) - - result = df.with_columns( - pl.col("float_to_boolean").replace({1.0: True}, default=None), - pl.col("boolean_to_int").replace({True: 1, False: 0}), - pl.col("boolean_to_str").replace({True: "1", False: "0"}), - ) - - expected = pl.DataFrame( - [ - pl.Series("float_to_boolean", [True, None], dtype=pl.Boolean), - pl.Series("boolean_to_int", [1, 0], dtype=pl.Int64), - pl.Series("boolean_to_str", ["1", "0"], dtype=pl.String), - ] - ) - assert_frame_equal(result, expected) - - @pytest.fixture(scope="module") def int_mapping() -> dict[int, int]: return {1: 11, 2: 22, 3: 33, 4: 44, 5: 55} @@ -316,20 +165,6 @@ def test_replace_int_to_int1(int_mapping: dict[int, int]) -> None: assert_series_equal(result, expected) -def test_replace_int_to_int2(int_mapping: dict[int, int]) -> None: - s = pl.Series([1, 22, None, 44, -5]) - result = s.replace(int_mapping, default=None) - expected = pl.Series([11, None, None, None, None], dtype=pl.Int64) - assert_series_equal(result, expected) - - -def test_replace_int_to_int3(int_mapping: dict[int, int]) -> None: - s = pl.Series([1, 22, None, 44, -5], dtype=pl.Int16) - result = s.replace(int_mapping, default=9) - expected = pl.Series([11, 9, 9, 9, 9], dtype=pl.Int64) - assert_series_equal(result, expected) - - def test_replace_int_to_int4(int_mapping: dict[int, int]) -> None: s = pl.Series([-1, 22, None, 44, -5]) result = s.replace(int_mapping) @@ -337,72 +172,6 @@ def test_replace_int_to_int4(int_mapping: dict[int, int]) -> None: assert_series_equal(result, expected) -def test_replace_int_to_int4_return_dtype(int_mapping: dict[int, int]) -> None: - s = pl.Series([-1, 22, None, 44, -5], dtype=pl.Int16) - result = s.replace(int_mapping, return_dtype=pl.Float32) - expected = pl.Series([-1.0, 22.0, None, 44.0, -5.0], dtype=pl.Float32) - assert_series_equal(result, expected) - - -def test_replace_int_to_int5_return_dtype(int_mapping: dict[int, int]) -> None: - s = pl.Series([1, 22, None, 44, -5], dtype=pl.Int16) - result = s.replace(int_mapping, default=9, return_dtype=pl.Float32) - expected = pl.Series([11.0, 9.0, 9.0, 9.0, 9.0], dtype=pl.Float32) - assert_series_equal(result, expected) - - -def test_replace_bool_to_int() -> None: - s = pl.Series([True, False, False, None]) - mapping = {True: 1, False: 0} - result = s.replace(mapping) - expected = pl.Series([1, 0, 0, None]) - assert_series_equal(result, expected) - - -def test_replace_bool_to_str() -> None: - s = pl.Series([True, False, False, None]) - mapping = {True: "1", False: "0"} - result = s.replace(mapping) - expected = pl.Series(["1", "0", "0", None]) - assert_series_equal(result, expected) - - -def test_replace_str_to_bool_without_default() -> None: - s = pl.Series(["True", "False", "False", None]) - mapping = {"True": True, "False": False} - result = s.replace(mapping) - expected = pl.Series(["true", "false", "false", None]) - assert_series_equal(result, expected) - - -def test_replace_str_to_bool_with_default() -> None: - s = pl.Series(["True", "False", "False", None]) - mapping = {"True": True, "False": False} - result = s.replace(mapping, default=None) - expected = pl.Series([True, False, False, None]) - assert_series_equal(result, expected) - - -def test_replace_int_to_str() -> None: - s = pl.Series("a", [-1, 2, None, 4, -5]) - mapping = {1: "one", 2: "two", 3: "three", 4: "four", 5: "five"} - - result = s.replace(mapping) - - expected = pl.Series("a", ["-1", "two", None, "four", "-5"]) - assert_series_equal(result, expected) - - -def test_replace_int_to_str_with_default() -> None: - s = pl.Series("a", [1, 2, None, 4, 5]) - mapping = {1: "one", 2: "two", 3: "three", 4: "four", 5: "five"} - - result = s.replace(mapping, default="?") - - expected = pl.Series("a", ["one", "two", "?", "four", "five"]) - assert_series_equal(result, expected) - - # https://github.com/pola-rs/polars/issues/12728 def test_replace_str_to_int2() -> None: s = pl.Series(["a", "b"]) @@ -412,11 +181,11 @@ def test_replace_str_to_int2() -> None: assert_series_equal(result, expected) -def test_replace_str_to_int_with_default() -> None: - s = pl.Series(["a", "b"]) - mapping = {"a": 1, "b": 2} - result = s.replace(mapping, default=None) - expected = pl.Series([1, 2]) +def test_replace_str_to_bool_without_default() -> None: + s = pl.Series(["True", "False", "False", None]) + mapping = {"True": True, "False": False} + result = s.replace(mapping) + expected = pl.Series(["true", "false", "false", None]) assert_series_equal(result, expected) @@ -469,20 +238,6 @@ def test_replace_fast_path_many_to_one() -> None: assert_frame_equal(result, expected) -def test_replace_fast_path_many_to_one_default() -> None: - lf = pl.LazyFrame({"a": [1, 2, 2, 3]}) - result = lf.select(pl.col("a").replace([2, 3], 100, default=-1)) - expected = pl.LazyFrame({"a": [-1, 100, 100, 100]}, schema={"a": pl.Int64}) - assert_frame_equal(result, expected) - - -def test_replace_fast_path_many_to_one_null() -> None: - lf = pl.LazyFrame({"a": [1, 2, 2, 3]}) - result = lf.select(pl.col("a").replace([2, 3], None, default=-1)) - expected = pl.LazyFrame({"a": [-1, None, None, None]}, schema={"a": pl.Int64}) - assert_frame_equal(result, expected) - - @pytest.mark.parametrize( ("old", "new"), [ @@ -505,72 +260,3 @@ def test_replace_duplicates_new() -> None: result = s.replace([1, 2], [100, 100]) expected = s = pl.Series([100, 100, 3, 100, 3]) assert_series_equal(result, expected) - - -def test_map_dict_deprecated() -> None: - s = pl.Series("a", [1, 2, 3]) - with pytest.deprecated_call(): - result = s.map_dict({2: 100}) - expected = pl.Series("a", [None, 100, None]) - assert_series_equal(result, expected) - - with pytest.deprecated_call(): - result = s.to_frame().select(pl.col("a").map_dict({2: 100})).to_series() - assert_series_equal(result, expected) - - -@pytest.mark.parametrize( - ("context", "dtype"), - [ - (pl.StringCache(), pl.Categorical), - (pytest.warns(CategoricalRemappingWarning), pl.Categorical), - (contextlib.nullcontext(), pl.Enum(["a", "b", "OTHER"])), - ], -) -def test_replace_cat_str( - context: contextlib.AbstractContextManager, # type: ignore[type-arg] - dtype: pl.DataType, -) -> None: - with context: - for old, new, expected in [ - ("a", "c", pl.Series("s", ["c", None], dtype=pl.Utf8)), - (["a", "b"], ["c", "d"], pl.Series("s", ["c", "d"], dtype=pl.Utf8)), - (pl.lit("a", dtype=dtype), "c", pl.Series("s", ["c", None], dtype=pl.Utf8)), - ( - pl.Series(["a", "b"], dtype=dtype), - ["c", "d"], - pl.Series("s", ["c", "d"], dtype=pl.Utf8), - ), - ]: - s = pl.Series("s", ["a", "b"], dtype=dtype) - s_replaced = s.replace(old, new, default=None) # type: ignore[arg-type] - assert_series_equal(s_replaced, expected) - - s = pl.Series("s", ["a", "b"], dtype=dtype) - s_replaced = s.replace(old, new, default="OTHER") # type: ignore[arg-type] - assert_series_equal(s_replaced, expected.fill_null("OTHER")) - - -@pytest.mark.parametrize( - "context", [pl.StringCache(), pytest.warns(CategoricalRemappingWarning)] -) -def test_replace_cat_cat( - context: contextlib.AbstractContextManager, # type: ignore[type-arg] -) -> None: - with context: - dt = pl.Categorical - for old, new, expected in [ - ("a", pl.lit("c", dtype=dt), pl.Series("s", ["c", None], dtype=dt)), - ( - ["a", "b"], - pl.Series(["c", "d"], dtype=dt), - pl.Series("s", ["c", "d"], dtype=dt), - ), - ]: - s = pl.Series("s", ["a", "b"], dtype=dt) - s_replaced = s.replace(old, new, default=None) # type: ignore[arg-type] - assert_series_equal(s_replaced, expected) - - s = pl.Series("s", ["a", "b"], dtype=dt) - s_replaced = s.replace(old, new, default=pl.lit("OTHER", dtype=dt)) # type: ignore[arg-type] - assert_series_equal(s_replaced, expected.fill_null("OTHER")) diff --git a/py-polars/tests/unit/operations/test_replace_all.py b/py-polars/tests/unit/operations/test_replace_all.py new file mode 100644 index 000000000000..844c2488b953 --- /dev/null +++ b/py-polars/tests/unit/operations/test_replace_all.py @@ -0,0 +1,343 @@ +from __future__ import annotations + +import contextlib +from typing import Any + +import pytest + +import polars as pl +from polars.exceptions import CategoricalRemappingWarning +from polars.testing import assert_frame_equal, assert_series_equal + + +@pytest.fixture(scope="module") +def str_mapping() -> dict[str | None, str]: + return { + "CA": "Canada", + "DE": "Germany", + "FR": "France", + None: "Not specified", + } + + +def test_replace_all_fast_path_many_to_one_default() -> None: + lf = pl.LazyFrame({"a": [1, 2, 2, 3]}) + result = lf.select(pl.col("a").replace_all([2, 3], 100, default=-1)) + expected = pl.LazyFrame({"a": [-1, 100, 100, 100]}, schema={"a": pl.Int64}) + assert_frame_equal(result, expected) + + +def test_replace_all_fast_path_many_to_one_null() -> None: + lf = pl.LazyFrame({"a": [1, 2, 2, 3]}) + result = lf.select(pl.col("a").replace_all([2, 3], None, default=-1)) + expected = pl.LazyFrame({"a": [-1, None, None, None]}, schema={"a": pl.Int64}) + assert_frame_equal(result, expected) + + +def test_replace_all_str_to_str_default_self( + str_mapping: dict[str | None, str], +) -> None: + df = pl.DataFrame({"country_code": ["FR", None, "ES", "DE"]}) + result = df.select( + replaced=pl.col("country_code").replace_all( + str_mapping, default=pl.col("country_code") + ) + ) + expected = pl.DataFrame({"replaced": ["France", "Not specified", "ES", "Germany"]}) + assert_frame_equal(result, expected) + + +def test_replace_all_str_to_str_default_null( + str_mapping: dict[str | None, str], +) -> None: + df = pl.DataFrame({"country_code": ["FR", None, "ES", "DE"]}) + result = df.select(replaced=pl.col("country_code").replace_all(str_mapping)) + expected = pl.DataFrame({"replaced": ["France", "Not specified", None, "Germany"]}) + assert_frame_equal(result, expected) + + +def test_replace_all_str_to_str_default_other( + str_mapping: dict[str | None, str], +) -> None: + df = pl.DataFrame({"country_code": ["FR", None, "ES", "DE"]}) + + result = df.with_row_index().select( + replaced=pl.col("country_code").replace_all( + str_mapping, default=pl.col("index") + ) + ) + expected = pl.DataFrame({"replaced": ["France", "Not specified", "2", "Germany"]}) + assert_frame_equal(result, expected) + + +def test_replace_str_to_cat() -> None: + s = pl.Series(["a", "b", "c"]) + mapping = {"a": "c", "b": "d"} + result = s.replace_all(mapping, return_dtype=pl.Categorical) + expected = pl.Series(["c", "d", None], dtype=pl.Categorical) + assert_series_equal(result, expected, categorical_as_str=True) + + +def test_replace_all_enum_to_new_enum() -> None: + s = pl.Series(["a", "b", "c"], dtype=pl.Enum(["a", "b", "c", "d"])) + old = ["a", "b"] + + new_dtype = pl.Enum(["a", "b", "c", "d", "e"]) + new = pl.Series(["c", "e"], dtype=new_dtype) + + result = s.replace_all(old, new, return_dtype=new_dtype) + + expected = pl.Series(["c", "e", None], dtype=new_dtype) + assert_series_equal(result, expected) + + +def test_replace_all_int_to_int_null() -> None: + df = pl.DataFrame({"int": [None, 1, None, 3]}, schema={"int": pl.Int16}) + mapping = {3: None} + result = df.select( + replaced=pl.col("int").replace_all(mapping, default=pl.lit(6).cast(pl.Int16)) + ) + expected = pl.DataFrame( + {"replaced": [6, 6, 6, None]}, schema={"replaced": pl.Int16} + ) + assert_frame_equal(result, expected) + + +def test_replace_all_int_to_int_null_default_null() -> None: + df = pl.DataFrame({"int": [None, 1, None, 3]}, schema={"int": pl.Int16}) + mapping = {3: None} + result = df.select(replaced=pl.col("int").replace_all(mapping)) + expected = pl.DataFrame( + {"replaced": [None, None, None, None]}, schema={"replaced": pl.Null} + ) + assert_frame_equal(result, expected) + + +def test_replace_all_int_to_int_null_return_dtype() -> None: + df = pl.DataFrame({"int": [None, 1, None, 3]}, schema={"int": pl.Int16}) + mapping = {3: None} + + result = df.select( + replaced=pl.col("int").replace_all(mapping, default=6, return_dtype=pl.Int32) + ) + + expected = pl.DataFrame( + {"replaced": [6, 6, 6, None]}, schema={"replaced": pl.Int32} + ) + assert_frame_equal(result, expected) + + +def test_replace_all_empty_mapping_default() -> None: + df = pl.DataFrame({"int": [None, 1, None, 3]}, schema={"int": pl.Int16}) + mapping: dict[Any, Any] = {} + result = df.select(pl.col("int").replace_all(mapping, default=pl.lit("A"))) + expected = pl.DataFrame({"int": ["A", "A", "A", "A"]}) + assert_frame_equal(result, expected) + + +def test_replace_all_int_to_int_df() -> None: + lf = pl.LazyFrame({"a": [1, 2, 3]}, schema={"a": pl.UInt8}) + mapping = {1: 11, 2: 22} + + result = lf.select( + pl.col("a").replace_all( + old=pl.Series(mapping.keys()), + new=pl.Series(mapping.values(), dtype=pl.UInt8), + default=pl.lit(99).cast(pl.UInt8), + ) + ) + expected = pl.LazyFrame({"a": [11, 22, 99]}, schema_overrides={"a": pl.UInt8}) + assert_frame_equal(result, expected) + + +def test_replace_all_str_to_int_fill_null() -> None: + lf = pl.LazyFrame({"a": ["one", "two"]}) + mapping = {"one": 1} + + result = lf.select( + pl.col("a") + .replace_all(mapping, default=None, return_dtype=pl.UInt32) + .fill_null(999) + ) + + expected = pl.LazyFrame({"a": pl.Series([1, 999], dtype=pl.UInt32)}) + assert_frame_equal(result, expected) + + +def test_replace_mix() -> None: + df = pl.DataFrame( + [ + pl.Series("float_to_boolean", [1.0, None]), + pl.Series("boolean_to_int", [True, False]), + pl.Series("boolean_to_str", [True, False]), + ] + ) + + result = df.with_columns( + pl.col("float_to_boolean").replace_all({1.0: True}), + pl.col("boolean_to_int").replace_all({True: 1, False: 0}), + pl.col("boolean_to_str").replace_all({True: "1", False: "0"}), + ) + + expected = pl.DataFrame( + [ + pl.Series("float_to_boolean", [True, None], dtype=pl.Boolean), + pl.Series("boolean_to_int", [1, 0], dtype=pl.Int64), + pl.Series("boolean_to_str", ["1", "0"], dtype=pl.String), + ] + ) + assert_frame_equal(result, expected) + + +@pytest.fixture(scope="module") +def int_mapping() -> dict[int, int]: + return {1: 11, 2: 22, 3: 33, 4: 44, 5: 55} + + +def test_replace_all_int_to_int2(int_mapping: dict[int, int]) -> None: + s = pl.Series([1, 22, None, 44, -5]) + result = s.replace_all(int_mapping) + expected = pl.Series([11, None, None, None, None], dtype=pl.Int64) + assert_series_equal(result, expected) + + +def test_replace_all_int_to_int3(int_mapping: dict[int, int]) -> None: + s = pl.Series([1, 22, None, 44, -5], dtype=pl.Int16) + result = s.replace_all(int_mapping, default=9) + expected = pl.Series([11, 9, 9, 9, 9], dtype=pl.Int64) + assert_series_equal(result, expected) + + +def test_replace_all_int_to_int4_return_dtype(int_mapping: dict[int, int]) -> None: + s = pl.Series([-1, 22, None, 44, -5], dtype=pl.Int16) + result = s.replace_all(int_mapping, default=s, return_dtype=pl.Float32) + expected = pl.Series([-1.0, 22.0, None, 44.0, -5.0], dtype=pl.Float32) + assert_series_equal(result, expected) + + +def test_replace_all_int_to_int5_return_dtype(int_mapping: dict[int, int]) -> None: + s = pl.Series([1, 22, None, 44, -5], dtype=pl.Int16) + result = s.replace_all(int_mapping, default=9, return_dtype=pl.Float32) + expected = pl.Series([11.0, 9.0, 9.0, 9.0, 9.0], dtype=pl.Float32) + assert_series_equal(result, expected) + + +def test_replace_all_bool_to_int() -> None: + s = pl.Series([True, False, False, None]) + mapping = {True: 1, False: 0} + result = s.replace_all(mapping) + expected = pl.Series([1, 0, 0, None]) + assert_series_equal(result, expected) + + +def test_replace_bool_to_str() -> None: + s = pl.Series([True, False, False, None]) + mapping = {True: "1", False: "0"} + result = s.replace_all(mapping) + expected = pl.Series(["1", "0", "0", None]) + assert_series_equal(result, expected) + + +def test_replace_str_to_bool_with_default() -> None: + s = pl.Series(["True", "False", "False", None]) + mapping = {"True": True, "False": False} + result = s.replace_all(mapping) + expected = pl.Series([True, False, False, None]) + assert_series_equal(result, expected) + + +def test_replace_int_to_str() -> None: + s = pl.Series("a", [-1, 2, None, 4, -5]) + mapping = {1: "one", 2: "two", 3: "three", 4: "four", 5: "five"} + + result = s.replace_all(mapping) + + expected = pl.Series("a", [None, "two", None, "four", None]) + assert_series_equal(result, expected) + + +def test_replace_int_to_str_with_default() -> None: + s = pl.Series("a", [1, 2, None, 4, 5]) + mapping = {1: "one", 2: "two", 3: "three", 4: "four", 5: "five"} + + result = s.replace_all(mapping, default="?") + + expected = pl.Series("a", ["one", "two", "?", "four", "five"]) + assert_series_equal(result, expected) + + +def test_replace_all_str_to_int() -> None: + s = pl.Series(["a", "b"]) + mapping = {"a": 1, "b": 2} + result = s.replace_all(mapping) + expected = pl.Series([1, 2]) + assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + ("context", "dtype"), + [ + (pl.StringCache(), pl.Categorical), + (pytest.warns(CategoricalRemappingWarning), pl.Categorical), + (contextlib.nullcontext(), pl.Enum(["a", "b", "OTHER"])), + ], +) +def test_replace_cat_str( + context: contextlib.AbstractContextManager, # type: ignore[type-arg] + dtype: pl.DataType, +) -> None: + with context: + for old, new, expected in [ + ("a", "c", pl.Series("s", ["c", None], dtype=pl.Utf8)), + (["a", "b"], ["c", "d"], pl.Series("s", ["c", "d"], dtype=pl.Utf8)), + (pl.lit("a", dtype=dtype), "c", pl.Series("s", ["c", None], dtype=pl.Utf8)), + ( + pl.Series(["a", "b"], dtype=dtype), + ["c", "d"], + pl.Series("s", ["c", "d"], dtype=pl.Utf8), + ), + ]: + s = pl.Series("s", ["a", "b"], dtype=dtype) + s_replaced = s.replace_all(old, new) # type: ignore[arg-type] + assert_series_equal(s_replaced, expected) + + s = pl.Series("s", ["a", "b"], dtype=dtype) + s_replaced = s.replace_all(old, new, default="OTHER") # type: ignore[arg-type] + assert_series_equal(s_replaced, expected.fill_null("OTHER")) + + +@pytest.mark.parametrize( + "context", [pl.StringCache(), pytest.warns(CategoricalRemappingWarning)] +) +def test_replace_cat_cat( + context: contextlib.AbstractContextManager, # type: ignore[type-arg] +) -> None: + with context: + dt = pl.Categorical + for old, new, expected in [ + ("a", pl.lit("c", dtype=dt), pl.Series("s", ["c", None], dtype=dt)), + ( + ["a", "b"], + pl.Series(["c", "d"], dtype=dt), + pl.Series("s", ["c", "d"], dtype=dt), + ), + ]: + s = pl.Series("s", ["a", "b"], dtype=dt) + s_replaced = s.replace_all(old, new) # type: ignore[arg-type] + assert_series_equal(s_replaced, expected) + + s = pl.Series("s", ["a", "b"], dtype=dt) + s_replaced = s.replace_all(old, new, default=pl.lit("OTHER", dtype=dt)) # type: ignore[arg-type] + assert_series_equal(s_replaced, expected.fill_null("OTHER")) + + +def test_map_dict_deprecated() -> None: + s = pl.Series("a", [1, 2, 3]) + with pytest.deprecated_call(): + result = s.map_dict({2: 100}) + expected = pl.Series("a", [None, 100, None]) + assert_series_equal(result, expected) + + with pytest.deprecated_call(): + result = s.to_frame().select(pl.col("a").map_dict({2: 100})).to_series() + assert_series_equal(result, expected)