From e0df70693f7cc6eb66b2944042a5f6c7bf2638a0 Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Sun, 29 Dec 2024 08:16:37 +0000 Subject: [PATCH] feat(python): Allow use of Python types in `cs.by_dtype` and `col` --- py-polars/polars/functions/col.py | 77 +++++++++++++++++++++----- py-polars/polars/selectors.py | 16 ++++-- py-polars/tests/unit/test_selectors.py | 55 +++++++++++++++++- 3 files changed, 129 insertions(+), 19 deletions(-) diff --git a/py-polars/polars/functions/col.py b/py-polars/polars/functions/col.py index b354cd878176..ff85be2019ce 100644 --- a/py-polars/polars/functions/col.py +++ b/py-polars/polars/functions/col.py @@ -2,26 +2,40 @@ import contextlib from collections.abc import Iterable +from datetime import datetime, timedelta from typing import TYPE_CHECKING from polars._utils.wrap import wrap_expr -from polars.datatypes import is_polars_dtype +from polars.datatypes import Datetime, Duration, is_polars_dtype, parse_into_dtype +from polars.datatypes.group import ( + DATETIME_DTYPES, + DURATION_DTYPES, + FLOAT_DTYPES, + INTEGER_DTYPES, +) with contextlib.suppress(ImportError): # Module not available when building docs import polars.polars as plr if TYPE_CHECKING: - from polars._typing import PolarsDataType + from polars._typing import PolarsDataType, PythonDataType from polars.expr.expr import Expr __all__ = ["col"] def _create_col( - name: str | PolarsDataType | Iterable[str] | Iterable[PolarsDataType], - *more_names: str | PolarsDataType, + name: ( + str + | PolarsDataType + | PythonDataType + | Iterable[str] + | Iterable[PolarsDataType | PythonDataType] + ), + *more_names: str | PolarsDataType | PythonDataType, ) -> Expr: """Create one or more column expressions representing column(s) in a DataFrame.""" + dtypes: list[PolarsDataType] if more_names: if isinstance(name, str): names_str = [name] @@ -41,7 +55,11 @@ def _create_col( if isinstance(name, str): return wrap_expr(plr.col(name)) elif is_polars_dtype(name): - return wrap_expr(plr.dtype_cols([name])) + dtypes = _polars_dtype_match(name) + return wrap_expr(plr.dtype_cols(dtypes)) + elif isinstance(name, type): + dtypes = _python_dtype_match(name) + return wrap_expr(plr.dtype_cols(dtypes)) elif isinstance(name, Iterable): names = list(name) if not names: @@ -51,7 +69,15 @@ def _create_col( if isinstance(item, str): return wrap_expr(plr.cols(names)) elif is_polars_dtype(item): - return wrap_expr(plr.dtype_cols(names)) + dtypes = [] + for nm in names: + dtypes.extend(_polars_dtype_match(nm)) # type: ignore[arg-type] + return wrap_expr(plr.dtype_cols(dtypes)) + elif isinstance(item, type): + dtypes = [] + for nm in names: + dtypes.extend(_python_dtype_match(nm)) # type: ignore[arg-type] + return wrap_expr(plr.dtype_cols(dtypes)) else: msg = ( "invalid input for `col`" @@ -67,6 +93,26 @@ def _create_col( raise TypeError(msg) +def _python_dtype_match(tp: PythonDataType) -> list[PolarsDataType]: + if tp is int: + return list(INTEGER_DTYPES) + elif tp is float: + return list(FLOAT_DTYPES) + elif tp is datetime: + return list(DATETIME_DTYPES) + elif tp is timedelta: + return list(DURATION_DTYPES) + return [parse_into_dtype(tp)] + + +def _polars_dtype_match(tp: PolarsDataType) -> list[PolarsDataType]: + if Datetime.is_(tp): + return list(DATETIME_DTYPES) + elif Duration.is_(tp): + return list(DURATION_DTYPES) + return [tp] + + class Col: """ Create Polars column expressions. @@ -79,8 +125,7 @@ class Col: This helper class enables an alternative syntax for creating a column expression through attribute lookup. For example `col.foo` creates an expression equal to - `col("foo")`. - See the :func:`__getattr__` method for further documentation. + `col("foo")`. See the :func:`__getattr__` method for further documentation. The function call syntax is considered the idiomatic way of constructing a column expression. The alternative attribute syntax can be useful for quick prototyping as @@ -126,18 +171,24 @@ class Col: def __call__( self, - name: str | PolarsDataType | Iterable[str] | Iterable[PolarsDataType], - *more_names: str | PolarsDataType, + name: ( + str + | PolarsDataType + | PythonDataType + | Iterable[str] + | Iterable[PolarsDataType | PythonDataType] + ), + *more_names: str | PolarsDataType | PythonDataType, ) -> Expr: """ - Create one or more column expressions representing column(s) in a DataFrame. + Create one or more expressions representing columns in a DataFrame. Parameters ---------- name The name or datatype of the column(s) to represent. - Accepts regular expression input. - Regular expressions should start with `^` and end with `$`. + Accepts regular expression input; regular expressions + should start with `^` and end with `$`. *more_names Additional names or datatypes of columns to represent, specified as positional arguments. diff --git a/py-polars/polars/selectors.py b/py-polars/polars/selectors.py index 9d3cedb47e85..fde8390d9e5b 100644 --- a/py-polars/polars/selectors.py +++ b/py-polars/polars/selectors.py @@ -40,9 +40,10 @@ if TYPE_CHECKING: import sys + from collections.abc import Iterable from polars import DataFrame, LazyFrame - from polars._typing import PolarsDataType, SelectorType, TimeUnit + from polars._typing import PolarsDataType, PythonDataType, SelectorType, TimeUnit if sys.version_info >= (3, 11): from typing import Self @@ -868,7 +869,12 @@ def boolean() -> SelectorType: def by_dtype( - *dtypes: PolarsDataType | Collection[PolarsDataType], + *dtypes: ( + PolarsDataType + | PythonDataType + | Iterable[PolarsDataType] + | Iterable[PythonDataType] + ), ) -> SelectorType: """ Select all columns matching the given dtypes. @@ -931,13 +937,13 @@ def by_dtype( │ foo ┆ -3265500 │ └───────┴──────────┘ """ - all_dtypes: list[PolarsDataType] = [] + all_dtypes: list[PolarsDataType | PythonDataType] = [] for tp in dtypes: - if is_polars_dtype(tp): + if is_polars_dtype(tp) or isinstance(tp, type): all_dtypes.append(tp) elif isinstance(tp, Collection): for t in tp: - if not is_polars_dtype(t): + if not (is_polars_dtype(t) or isinstance(t, type)): msg = f"invalid dtype: {t!r}" raise TypeError(msg) all_dtypes.append(t) diff --git a/py-polars/tests/unit/test_selectors.py b/py-polars/tests/unit/test_selectors.py index 5b0358011101..8b13f1bee909 100644 --- a/py-polars/tests/unit/test_selectors.py +++ b/py-polars/tests/unit/test_selectors.py @@ -1,5 +1,6 @@ from collections import OrderedDict -from datetime import datetime +from datetime import datetime, timedelta +from decimal import Decimal as PyDecimal from typing import Any from zoneinfo import ZoneInfo @@ -110,6 +111,58 @@ def test_selector_by_dtype(df: pl.DataFrame) -> None: "qqR": pl.String(), } ) + assert df.select( + cs.by_dtype(pl.Datetime("ns"), pl.Float32, pl.UInt32, pl.Date) + ).schema == pl.Schema( + { + "bbb": pl.UInt32, + "def": pl.Float32, + "JJK": pl.Date, + } + ) + + # select using python types + assert df.select(cs.by_dtype(int, float)).schema == pl.Schema( + { + "abc": pl.UInt16, + "bbb": pl.UInt32, + "cde": pl.Float64, + "def": pl.Float32, + } + ) + assert df.select(cs.by_dtype(bool, datetime, timedelta)).schema == pl.Schema( + { + "eee": pl.Boolean(), + "fgg": pl.Boolean(), + "Lmn": pl.Duration("us"), + "opp": pl.Datetime("ms"), + } + ) + + # cover timezones and decimal + dfx = pl.DataFrame( + {"idx": [], "dt1": [], "dt2": []}, + schema_overrides={ + "idx": pl.Decimal(24), + "dt1": pl.Datetime("ms"), + "dt2": pl.Datetime(time_zone="Asia/Tokyo"), + }, + ) + assert dfx.select(cs.by_dtype(PyDecimal)).schema == pl.Schema( + {"idx": pl.Decimal(24)}, + ) + assert dfx.select(cs.by_dtype(pl.Datetime(time_zone="*"))).schema == pl.Schema( + {"dt2": pl.Datetime(time_zone="Asia/Tokyo")} + ) + assert dfx.select(cs.by_dtype(pl.Datetime("ms", None))).schema == pl.Schema( + {"dt1": pl.Datetime("ms")}, + ) + for dt in (datetime, pl.Datetime): + assert dfx.select(cs.by_dtype(dt)).schema == pl.Schema( + {"dt1": pl.Datetime("ms"), "dt2": pl.Datetime(time_zone="Asia/Tokyo")}, + ) + + # empty selection selects nothing assert df.select(cs.by_dtype()).schema == {} assert df.select(cs.by_dtype([])).schema == {}