From 711bfec4d410ddc4b4d35522e59fcb8a517deb70 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 8 Sep 2022 22:19:21 +0000 Subject: [PATCH 1/6] [AIR] Add `KBinsDiscretizer` Signed-off-by: Antoni Baum --- python/ray/data/preprocessors/__init__.py | 6 + python/ray/data/preprocessors/discretizer.py | 350 +++++++++++++++++++ python/ray/data/tests/test_preprocessors.py | 146 ++++++++ 3 files changed, 502 insertions(+) create mode 100644 python/ray/data/preprocessors/discretizer.py diff --git a/python/ray/data/preprocessors/__init__.py b/python/ray/data/preprocessors/__init__.py index 61ddabac4aade..780a2391a1c34 100644 --- a/python/ray/data/preprocessors/__init__.py +++ b/python/ray/data/preprocessors/__init__.py @@ -20,6 +20,10 @@ from ray.data.preprocessors.tokenizer import Tokenizer from ray.data.preprocessors.transformer import PowerTransformer from ray.data.preprocessors.vectorizer import CountVectorizer, HashingVectorizer +from ray.data.preprocessors.discretizer import ( + CustomKBinsDiscretizer, + UniformKBinsDiscretizer, +) __all__ = [ "BatchMapper", @@ -41,4 +45,6 @@ "StandardScaler", "Concatenator", "Tokenizer", + "CustomKBinsDiscretizer", + "UniformKBinsDiscretizer", ] diff --git a/python/ray/data/preprocessors/discretizer.py b/python/ray/data/preprocessors/discretizer.py new file mode 100644 index 0000000000000..239cd6c7195c7 --- /dev/null +++ b/python/ray/data/preprocessors/discretizer.py @@ -0,0 +1,350 @@ +from typing import Iterable, List, Dict, Optional, Type, Union + +import pandas as pd +import numpy as np + +from ray.data import Dataset +from ray.data.aggregate import Max, Min +from ray.data.preprocessor import Preprocessor + + +class _AbstractKBinsDiscretizer(Preprocessor): + """Abstract base class for all KBinsDiscretizers. + + Essentially a thin wraper around ``pd.cut``. + + Expects either ``self.stats_`` or ``self.bins`` to be set and + contain {column:list_of_bin_intervals}. + """ + + def _transform_pandas(self, df: pd.DataFrame): + def bin_values(s: pd.Series) -> pd.Series: + labels = self.dtypes.get(s.name) if self.dtypes else False + ordered = True + if labels: + if isinstance(labels, pd.CategoricalDtype): + ordered = labels.ordered + labels = list(labels.categories) + else: + labels = False + + bins = self.stats_ if self._is_fittable else self.bins + return pd.cut( + s, + bins[s.name] if isinstance(bins, dict) else bins, + right=self.right, + labels=labels, + ordered=ordered, + retbins=False, + include_lowest=self.include_lowest, + duplicates=self.duplicates, + ) + + return df.apply(bin_values, axis=0) + + def __repr__(self): + attr_str = ", ".join( + [ + f"{attr_name}={attr_value!r}" + for attr_name, attr_value in vars(self).items() + ] + ) + return f"{self.__class__.__name__}({attr_str})" + + +class CustomKBinsDiscretizer(_AbstractKBinsDiscretizer): + """Bin values into discrete intervals using custom bin edges. + + Columns must contain numerical values. + + Examples: + Use :class:`CustomKBinsDiscretizer` to bin continuous features. + + >>> import pandas as pd + >>> import ray + >>> from ray.data.preprocessors import CustomKBinsDiscretizer + >>> df = pd.DataFrame({ + ... "value_1": [0.2, 1.4, 2.5, 6.2, 9.7, 2.1], + ... "value_2": [10, 15, 13, 12, 23, 25], + ... }) + >>> ds = ray.data.from_pandas(df) # doctest: +SKIP + >>> discretizer = CustomKBinsDiscretizer( + ... columns=["value_1", "value_2"], + ... bins=[0, 1, 4, 10, 25] + ... 
) + >>> discretizer.transform(ds).to_pandas() # doctest: +SKIP + value_1 value_2 + 0 0 2 + 1 1 3 + 2 1 3 + 3 2 3 + 4 2 3 + 5 1 3 + + You can also specify different bin edges per column. + + >>> discretizer = CustomKBinsDiscretizer( + ... columns=["value_1", "value_2"], + ... bins={"value_1": [0, 1, 4], "value_2": [0, 18, 35, 70]}, + ... ) + >>> discretizer.transform(ds).to_pandas() # doctest: +SKIP + value_1 value_2 + 0 0 0 + 1 1 0 + 2 1 0 + 3 1 0 + 4 1 1 + 5 1 1 + + + Args: + columns: The columns to discretize. + bins: Defines custom bin edges. Can be either an interable of numbers, + a ``pd.IntervalIndex``, or a dict mapping columns to either of them. + Note that ``pd.IntervalIndex`` for bins must be non-overlapping. + right: Indicates whether bins includes the rightmost edge or not. + include_lowest: Whether the first interval should be left-inclusive + or not. + duplicates: Can be either 'raise' or 'drop'. If bin edges are not unique, + raise ``ValueError`` or drop non-uniques. + dtypes: An optional dictionary that maps columns to ``pd.CategoricalDtype`` + objects or ``np.integer`` types. If you don't include a column in ``dtypes`` + or specify it as an integer dtype, the outputted column will consist of + ordered integers corresponding to bins. If you use a + ``pd.CategoricalDtype``, the outputted column will be a + ``pd.CategoricalDtype`` with the categories being mapped to bins. + You can use ``pd.CategoricalDtype(categories, ordered=True)`` to + preserve information about bin order. + + .. seealso:: + + :class:`UniformKBinsDiscretizer` + If you want to bin data into uniform width bins. + """ + + def __init__( + self, + columns: List[str], + bins: Union[ + Iterable[float], + pd.IntervalIndex, + Dict[str, Union[Iterable[float], pd.IntervalIndex]], + ], + *, + right: bool = True, + include_lowest: bool = False, + duplicates: str = "raise", + dtypes: Optional[ + Dict[str, Union[pd.CategoricalDtype, Type[np.integer]]] + ] = None, + ): + self.columns = columns + self.bins = bins + self.right = right + self.include_lowest = include_lowest + self.duplicates = duplicates + self.dtypes = dtypes + + _is_fittable = False + + +class UniformKBinsDiscretizer(_AbstractKBinsDiscretizer): + """Bin values into discrete intervals (bins) of uniform width. + + Columns must contain numerical values. + + Examples: + Use :class:`UniformKBinsDiscretizer` to bin continuous features. + + >>> import pandas as pd + >>> import ray + >>> from ray.data.preprocessors import UniformKBinsDiscretizer + >>> df = pd.DataFrame({ + ... "value_1": [0.2, 1.4, 2.5, 6.2, 9.7, 2.1], + ... "value_2": [10, 15, 13, 12, 23, 25], + ... }) + >>> ds = ray.data.from_pandas(df) # doctest: +SKIP + >>> discretizer = UniformKBinsDiscretizer( + ... columns=["value_1", "value_2"], bins=4 + ... ) + >>> discretizer.fit_transform(ds).to_pandas() # doctest: +SKIP + value_1 value_2 + 0 0 0 + 1 0 1 + 2 0 0 + 3 2 0 + 4 3 3 + 5 0 3 + + You can also specify different number of bins per column. + + >>> discretizer = UniformKBinsDiscretizer( + ... columns=["value_1", "value_2"], bins={"value_1": 4, "value_2": 3} + ... ) + >>> discretizer.fit_transform(ds).to_pandas() # doctest: +SKIP + value_1 value_2 + 0 0 0 + 1 0 0 + 2 0 0 + 3 2 0 + 4 3 2 + 5 0 2 + + + Args: + columns: The columns to discretize. + bins: Defines the number of equal-width bins. + Can be either an integer (which will be applied to all columns), + or a dict that maps columns to integers. + The range is extended by .1% on each side to include + the minimum and maximum values. 
+ right: Indicates whether bins includes the rightmost edge or not. + include_lowest: Whether the first interval should be left-inclusive + or not. + duplicates: Can be either 'raise' or 'drop'. If bin edges are not unique, + raise ``ValueError`` or drop non-uniques. + dtypes: An optional dictionary that maps columns to ``pd.CategoricalDtype`` + objects or ``np.integer`` types. If you don't include a column in ``dtypes`` + or specify it as an integer dtype, the outputted column will consist of + ordered integers corresponding to bins. If you use a + ``pd.CategoricalDtype``, the outputted column will be a + ``pd.CategoricalDtype`` with the categories being mapped to bins. + You can use ``pd.CategoricalDtype(categories, ordered=True)`` to + preserve information about bin order. + + .. seealso:: + + :class:`CustomKBinsDiscretizer` + If you want to specify your own bin edges. + """ + + def __init__( + self, + columns: List[str], + bins: Union[int, Dict[str, int]], + *, + right: bool = True, + include_lowest: bool = False, + duplicates: str = "raise", + dtypes: Optional[ + Dict[str, Union[pd.CategoricalDtype, Type[np.integer]]] + ] = None, + ): + self.columns = columns + self.bins = bins + self.right = right + self.include_lowest = include_lowest + self.duplicates = duplicates + self.dtypes = dtypes + + def _fit(self, dataset: Dataset) -> Preprocessor: + self._validate_on_fit() + stats = {} + if isinstance(self.bins, dict): + for column, bin in self.bins.items(): + stats[column] = self._fit_uniform_covert_bin_to_aggregate_if_needed( + column, bin + ) + else: + for column in self.columns: + stats[column] = self._fit_uniform_covert_bin_to_aggregate_if_needed( + column, self.bins + ) + + aggregates = [] + bin_sizes = {} + for column, stat in stats.items(): + aggregates.extend(stat[:-1]) + bin_sizes[column] = stat[-1] + + if aggregates: + aggregate_stats = dataset.aggregate(*aggregates) + mins = {} + maxes = {} + for key, value in aggregate_stats.items(): + column_name = key[4:-1] # min(column) -> column + if key.startswith("min"): + mins[column_name] = value + if key.startswith("max"): + maxes[column_name] = value + + for column in mins.keys(): + stats[column] = _translate_min_max_number_of_bins_to_bin_edges( + mins[column], maxes[column], bin_sizes[column], self.right + ) + + self.stats_ = stats + return self + + def _validate_on_fit(self): + pass + + def _fit_uniform_covert_bin_to_aggregate_if_needed(self, column: str, bin): + if isinstance(bin, int): + return (Min(column), Max(column), bin) + else: + raise TypeError() + + +# Copied from +# https://github.com/pandas-dev/pandas/blob/v1.4.4/pandas/core/reshape/tile.py#L257 +# under +# BSD 3-Clause License +# +# Copyright (c) 2008-2011, AQR Capital Management, LLC, Lambda Foundry, Inc. +# and PyData Development Team +# All rights reserved. +# +# Copyright (c) 2011-2022, Open source contributors. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
+# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +def _translate_min_max_number_of_bins_to_bin_edges( + mn: float, mx: float, bins: int, right: bool +) -> List[float]: + """Translates a range and desired number of bins into list of bin edges.""" + rng = (mn, mx) + mn, mx = (mi + 0.0 for mi in rng) + + if np.isinf(mn) or np.isinf(mx): + raise ValueError( + "Cannot specify integer `bins` when input data contains infinity." + ) + elif mn == mx: # adjust end points before binning + mn -= 0.001 * abs(mn) if mn != 0 else 0.001 + mx += 0.001 * abs(mx) if mx != 0 else 0.001 + bins = np.linspace(mn, mx, bins + 1, endpoint=True) + else: # adjust end points after binning + bins = np.linspace(mn, mx, bins + 1, endpoint=True) + adj = (mx - mn) * 0.001 # 0.1% of the range + if right: + bins[0] -= adj + else: + bins[-1] += adj + return bins + + +# TODO(ml-team) +# Add QuantileKBinsDiscretizer diff --git a/python/ray/data/tests/test_preprocessors.py b/python/ray/data/tests/test_preprocessors.py index 362932b2576b0..1818aa80cb87c 100644 --- a/python/ray/data/tests/test_preprocessors.py +++ b/python/ray/data/tests/test_preprocessors.py @@ -20,6 +20,8 @@ OrdinalEncoder, SimpleImputer, StandardScaler, + CustomKBinsDiscretizer, + UniformKBinsDiscretizer, ) from ray.data.preprocessors.encoder import Categorizer, MultiHotEncoder from ray.data.preprocessors.hasher import FeatureHasher @@ -1278,6 +1280,150 @@ def test_concatenator(): ctx.enable_tensor_extension_casting = old_config +@pytest.mark.parametrize("bins", (3, {"A": 4, "B": 3})) +@pytest.mark.parametrize( + "dtypes", + ( + None, + {"A": int, "B": int}, + {"A": int, "B": pd.CategoricalDtype(["cat1", "cat2", "cat3"], ordered=True)}, + ), +) +@pytest.mark.parametrize("right", (True, False)) +@pytest.mark.parametrize("include_lowest", (True, False)) +def test_uniform_kbins_discretizer( + bins, + dtypes, + right, + include_lowest, +): + """Tests basic UniformKBinsDiscretizer functionality.""" + + col_a = [0.2, 1.4, 2.5, 6.2, 9.7, 2.1] + col_b = [0.2, 1.4, 2.5, 6.2, 9.7, 2.1] + in_df = pd.DataFrame.from_dict({"A": col_a, "B": col_b}) + ds = ray.data.from_pandas(in_df).repartition(2) + + discretizer = UniformKBinsDiscretizer( + ["A", "B"], bins=bins, dtypes=dtypes, right=right, include_lowest=include_lowest + ) + + transformed = discretizer.fit_transform(ds) + out_df = transformed.to_pandas() + + if isinstance(bins, dict): + bins_A = bins["A"] + bins_B = bins["B"] + else: + bins_A = bins_B = bins + + labels_A = False + ordered_A = True + labels_B = False + ordered_B = True + if isinstance(dtypes, dict): + if isinstance(dtypes.get("A"), 
pd.CategoricalDtype): + labels_A = dtypes.get("A").categories + ordered_A = dtypes.get("A").ordered + if isinstance(dtypes.get("B"), pd.CategoricalDtype): + labels_B = dtypes.get("B").categories + ordered_B = dtypes.get("B").ordered + + assert out_df["A"].equals( + pd.cut( + in_df["A"], + bins_A, + labels=labels_A, + ordered=ordered_A, + right=right, + include_lowest=include_lowest, + ) + ) + assert out_df["B"].equals( + pd.cut( + in_df["B"], + bins_B, + labels=labels_B, + ordered=ordered_B, + right=right, + include_lowest=include_lowest, + ) + ) + + +@pytest.mark.parametrize( + "bins", ([3, 4, 6, 9], {"A": [3, 4, 6, 8, 9], "B": [3, 4, 6, 9]}) +) +@pytest.mark.parametrize( + "dtypes", + ( + None, + {"A": int, "B": int}, + {"A": int, "B": pd.CategoricalDtype(["cat1", "cat2", "cat3"], ordered=True)}, + ), +) +@pytest.mark.parametrize("right", (True, False)) +@pytest.mark.parametrize("include_lowest", (True, False)) +def test_custom_kbins_discretizer( + bins, + dtypes, + right, + include_lowest, +): + """Tests basic CustomKBinsDiscretizer functionality.""" + + col_a = [0.2, 1.4, 2.5, 6.2, 9.7, 2.1] + col_b = [0.2, 1.4, 2.5, 6.2, 9.7, 2.1] + in_df = pd.DataFrame.from_dict({"A": col_a, "B": col_b}) + ds = ray.data.from_pandas(in_df).repartition(2) + + discretizer = CustomKBinsDiscretizer( + ["A", "B"], bins=bins, dtypes=dtypes, right=right, include_lowest=include_lowest + ) + + transformed = discretizer.transform(ds) + out_df = transformed.to_pandas() + + if isinstance(bins, dict): + bins_A = bins["A"] + bins_B = bins["B"] + else: + bins_A = bins_B = bins + + labels_A = False + ordered_A = True + labels_B = False + ordered_B = True + if isinstance(dtypes, dict): + if isinstance(dtypes.get("A"), pd.CategoricalDtype): + labels_A = dtypes.get("A").categories + ordered_A = dtypes.get("A").ordered + if isinstance(dtypes.get("B"), pd.CategoricalDtype): + labels_B = dtypes.get("B").categories + ordered_B = dtypes.get("B").ordered + + assert out_df["A"].equals( + pd.cut( + in_df["A"], + bins_A, + labels=labels_A, + ordered=ordered_A, + right=right, + include_lowest=include_lowest, + ) + ) + assert out_df["B"].equals( + pd.cut( + in_df["B"], + bins_B, + labels=labels_B, + ordered=ordered_B, + right=right, + include_lowest=include_lowest, + ) + ) + + def test_tokenizer(): """Tests basic Tokenizer functionality.""" From 87323a99b946ee2a16cfdead791ecfe21491eb1d Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 8 Sep 2022 22:29:58 +0000 Subject: [PATCH 2/6] Add docs Signed-off-by: Antoni Baum --- doc/source/custom_directives.py | 1 + doc/source/ray-air/package-ref.rst | 9 +++++++++ 2 files changed, 10 insertions(+) diff --git a/doc/source/custom_directives.py b/doc/source/custom_directives.py index 4fc8a3176dbed..c674d471d12cf 100644 --- a/doc/source/custom_directives.py +++ b/doc/source/custom_directives.py @@ -59,6 +59,7 @@ def update_context(app, pagename, templatename, context, doctree): "dask.distributed", "datasets", "datasets.iterable_dataset", + "datasets.load", "gym", "gym.spaces", "horovod", diff --git a/doc/source/ray-air/package-ref.rst b/doc/source/ray-air/package-ref.rst index 8b0bdcc5684fb..289577d81c8ea 100644 --- a/doc/source/ray-air/package-ref.rst +++ b/doc/source/ray-air/package-ref.rst @@ -74,6 +74,15 @@ Feature Scalers .. autoclass:: ray.data.preprocessors.StandardScaler :show-inheritance: +K-Bins Discretizers +################### + +.. autoclass:: ray.data.preprocessors.CustomKBinsDiscretizer + :show-inheritance: + +.. 
autoclass:: ray.data.preprocessors.UniformKBinsDiscretizer + :show-inheritance: + Text Encoders ############# From 36659b84cfdb2afb7597b2269248ed9ba9e19cab Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 8 Sep 2022 22:39:04 +0000 Subject: [PATCH 3/6] Cleanup Signed-off-by: Antoni Baum --- python/ray/data/preprocessors/discretizer.py | 71 ++++++++++---------- 1 file changed, 37 insertions(+), 34 deletions(-) diff --git a/python/ray/data/preprocessors/discretizer.py b/python/ray/data/preprocessors/discretizer.py index 239cd6c7195c7..4b04ff95286a7 100644 --- a/python/ray/data/preprocessors/discretizer.py +++ b/python/ray/data/preprocessors/discretizer.py @@ -240,50 +240,53 @@ def __init__( def _fit(self, dataset: Dataset) -> Preprocessor: self._validate_on_fit() stats = {} + aggregates = [] if isinstance(self.bins, dict): - for column, bin in self.bins.items(): - stats[column] = self._fit_uniform_covert_bin_to_aggregate_if_needed( - column, bin - ) + columns = self.bins.keys() else: - for column in self.columns: - stats[column] = self._fit_uniform_covert_bin_to_aggregate_if_needed( - column, self.bins - ) + columns = self.columns - aggregates = [] - bin_sizes = {} - for column, stat in stats.items(): - aggregates.extend(stat[:-1]) - bin_sizes[column] = stat[-1] - - if aggregates: - aggregate_stats = dataset.aggregate(*aggregates) - mins = {} - maxes = {} - for key, value in aggregate_stats.items(): - column_name = key[4:-1] # min(column) -> column - if key.startswith("min"): - mins[column_name] = value - if key.startswith("max"): - maxes[column_name] = value - - for column in mins.keys(): - stats[column] = _translate_min_max_number_of_bins_to_bin_edges( - mins[column], maxes[column], bin_sizes[column], self.right - ) + for column in columns: + aggregates.extend( + self._fit_uniform_covert_bin_to_aggregate_if_needed(column) + ) + + aggregate_stats = dataset.aggregate(*aggregates) + mins = {} + maxes = {} + for key, value in aggregate_stats.items(): + column_name = key[4:-1] # min(column) -> column + if key.startswith("min"): + mins[column_name] = value + if key.startswith("max"): + maxes[column_name] = value + + for column in mins.keys(): + bins = self.bins[column] if isinstance(self.bins, dict) else self.bins + stats[column] = _translate_min_max_number_of_bins_to_bin_edges( + mins[column], maxes[column], bins, self.right + ) self.stats_ = stats return self def _validate_on_fit(self): - pass + if isinstance(self.bins, dict) and not all( + col in self.bins for col in self.columns + ): + raise ValueError( + "If `bins` is a dictionary, all elements of `columns` must be present " + "in it." 
+ ) - def _fit_uniform_covert_bin_to_aggregate_if_needed(self, column: str, bin): - if isinstance(bin, int): - return (Min(column), Max(column), bin) + def _fit_uniform_covert_bin_to_aggregate_if_needed(self, column: str): + bins = self.bins[column] if isinstance(self.bins, dict) else self.bins + if isinstance(bins, int): + return (Min(column), Max(column)) else: - raise TypeError() + raise TypeError( + f"`bins` must be an integer or a dict of integers, got {bins}" + ) # Copied from From b98236b4307a1b6d7165e4c01917128096eaa695 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 8 Sep 2022 22:41:19 +0000 Subject: [PATCH 4/6] Cleanup Signed-off-by: Antoni Baum --- python/ray/data/preprocessors/discretizer.py | 21 +++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/python/ray/data/preprocessors/discretizer.py b/python/ray/data/preprocessors/discretizer.py index 4b04ff95286a7..b601b91435966 100644 --- a/python/ray/data/preprocessors/discretizer.py +++ b/python/ray/data/preprocessors/discretizer.py @@ -42,6 +42,15 @@ def bin_values(s: pd.Series) -> pd.Series: return df.apply(bin_values, axis=0) + def _validate_bins_columns(self): + if isinstance(self.bins, dict) and not all( + col in self.bins for col in self.columns + ): + raise ValueError( + "If `bins` is a dictionary, all elements of `columns` must be present " + "in it." + ) + def __repr__(self): attr_str = ", ".join( [ @@ -147,6 +156,10 @@ def __init__( _is_fittable = False + def _transform(self, dataset: Dataset) -> Dataset: + self._validate_bins_columns() + return super()._transform(dataset) + class UniformKBinsDiscretizer(_AbstractKBinsDiscretizer): """Bin values into discrete intervals (bins) of uniform width. @@ -271,13 +284,7 @@ def _fit(self, dataset: Dataset) -> Preprocessor: return self def _validate_on_fit(self): - if isinstance(self.bins, dict) and not all( - col in self.bins for col in self.columns - ): - raise ValueError( - "If `bins` is a dictionary, all elements of `columns` must be present " - "in it." - ) + self._validate_bins_columns() def _fit_uniform_covert_bin_to_aggregate_if_needed(self, column: str): bins = self.bins[column] if isinstance(self.bins, dict) else self.bins From 33465e015ab9a18eee02a8d162d881bed5682c3a Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Mon, 12 Sep 2022 02:13:46 +0200 Subject: [PATCH 5/6] Apply suggestions from code review Co-authored-by: Balaji Veeramani Signed-off-by: Antoni Baum --- python/ray/data/preprocessors/discretizer.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/python/ray/data/preprocessors/discretizer.py b/python/ray/data/preprocessors/discretizer.py index b601b91435966..0d17eaf395497 100644 --- a/python/ray/data/preprocessors/discretizer.py +++ b/python/ray/data/preprocessors/discretizer.py @@ -108,12 +108,11 @@ class CustomKBinsDiscretizer(_AbstractKBinsDiscretizer): Args: columns: The columns to discretize. - bins: Defines custom bin edges. Can be either an interable of numbers, + bins: Defines custom bin edges. Can be an iterable of numbers, a ``pd.IntervalIndex``, or a dict mapping columns to either of them. Note that ``pd.IntervalIndex`` for bins must be non-overlapping. - right: Indicates whether bins includes the rightmost edge or not. - include_lowest: Whether the first interval should be left-inclusive - or not. + right: Indicates whether bins include the rightmost edge. + include_lowest: Indicates whether the first interval should be left-inclusive. duplicates: Can be either 'raise' or 'drop'. 
If bin edges are not unique, raise ``ValueError`` or drop non-uniques. dtypes: An optional dictionary that maps columns to ``pd.CategoricalDtype`` From 464d4c132b53002871883836aaf55d0c4f97154a Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Mon, 12 Sep 2022 16:42:09 +0000 Subject: [PATCH 6/6] Apply suggestions from code review Signed-off-by: Antoni Baum --- python/ray/data/preprocessors/discretizer.py | 25 ++++++++++---------- 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/python/ray/data/preprocessors/discretizer.py b/python/ray/data/preprocessors/discretizer.py index 0d17eaf395497..11e434fa4384c 100644 --- a/python/ray/data/preprocessors/discretizer.py +++ b/python/ray/data/preprocessors/discretizer.py @@ -56,6 +56,7 @@ def __repr__(self): [ f"{attr_name}={attr_value!r}" for attr_name, attr_value in vars(self).items() + if not attr_name.startswith("_") ] ) return f"{self.__class__.__name__}({attr_str})" @@ -76,12 +77,12 @@ class CustomKBinsDiscretizer(_AbstractKBinsDiscretizer): ... "value_1": [0.2, 1.4, 2.5, 6.2, 9.7, 2.1], ... "value_2": [10, 15, 13, 12, 23, 25], ... }) - >>> ds = ray.data.from_pandas(df) # doctest: +SKIP + >>> ds = ray.data.from_pandas(df) >>> discretizer = CustomKBinsDiscretizer( ... columns=["value_1", "value_2"], ... bins=[0, 1, 4, 10, 25] ... ) - >>> discretizer.transform(ds).to_pandas() # doctest: +SKIP + >>> discretizer.transform(ds).to_pandas() value_1 value_2 0 0 2 1 1 3 @@ -96,14 +97,14 @@ class CustomKBinsDiscretizer(_AbstractKBinsDiscretizer): ... columns=["value_1", "value_2"], ... bins={"value_1": [0, 1, 4], "value_2": [0, 18, 35, 70]}, ... ) - >>> discretizer.transform(ds).to_pandas() # doctest: +SKIP + >>> discretizer.transform(ds).to_pandas() value_1 value_2 - 0 0 0 - 1 1 0 - 2 1 0 - 3 1 0 - 4 1 1 - 5 1 1 + 0 0.0 0 + 1 1.0 0 + 2 1.0 0 + 3 NaN 0 + 4 NaN 1 + 5 1.0 1 Args: @@ -175,11 +176,11 @@ class UniformKBinsDiscretizer(_AbstractKBinsDiscretizer): ... "value_1": [0.2, 1.4, 2.5, 6.2, 9.7, 2.1], ... "value_2": [10, 15, 13, 12, 23, 25], ... }) - >>> ds = ray.data.from_pandas(df) # doctest: +SKIP + >>> ds = ray.data.from_pandas(df) >>> discretizer = UniformKBinsDiscretizer( ... columns=["value_1", "value_2"], bins=4 ... ) - >>> discretizer.fit_transform(ds).to_pandas() # doctest: +SKIP + >>> discretizer.fit_transform(ds).to_pandas() value_1 value_2 0 0 0 1 0 1 @@ -193,7 +194,7 @@ class UniformKBinsDiscretizer(_AbstractKBinsDiscretizer): >>> discretizer = UniformKBinsDiscretizer( ... columns=["value_1", "value_2"], bins={"value_1": 4, "value_2": 3} ... ) - >>> discretizer.fit_transform(ds).to_pandas() # doctest: +SKIP + >>> discretizer.fit_transform(ds).to_pandas() value_1 value_2 0 0 0 1 0 0
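
A minimal usage sketch of the two preprocessors introduced by this patch, reusing the toy columns from the docstring examples in the diffs above (the exact printed output depends on the installed Ray and pandas versions, so this is an illustrative sketch rather than a verbatim doctest):

    import pandas as pd
    import ray
    from ray.data.preprocessors import CustomKBinsDiscretizer, UniformKBinsDiscretizer

    # Illustrative toy data with two numeric columns.
    df = pd.DataFrame({
        "value_1": [0.2, 1.4, 2.5, 6.2, 9.7, 2.1],
        "value_2": [10, 15, 13, 12, 23, 25],
    })
    ds = ray.data.from_pandas(df)

    # CustomKBinsDiscretizer: the caller supplies the bin edges, so the
    # preprocessor is not fittable and transform() can be called directly.
    custom = CustomKBinsDiscretizer(
        columns=["value_1", "value_2"],
        bins=[0, 1, 4, 10, 25],
    )
    print(custom.transform(ds).to_pandas())

    # UniformKBinsDiscretizer: fit() runs Min/Max aggregates per column, then
    # builds equal-width edges with np.linspace(min, max, bins + 1); one outer
    # edge is widened by 0.1% of the range so the min/max values fall inside a bin.
    uniform = UniformKBinsDiscretizer(
        columns=["value_1", "value_2"],
        bins={"value_1": 4, "value_2": 3},
    )
    print(uniform.fit_transform(ds).to_pandas())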