Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): Allow use of Python types in cs.by_dtype and col #20491

Merged
merged 1 commit into from
Dec 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 64 additions & 13 deletions py-polars/polars/functions/col.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,40 @@

import contextlib
from collections.abc import Iterable
from datetime import datetime, timedelta
from typing import TYPE_CHECKING

from polars._utils.wrap import wrap_expr
from polars.datatypes import is_polars_dtype
from polars.datatypes import Datetime, Duration, is_polars_dtype, parse_into_dtype
from polars.datatypes.group import (
DATETIME_DTYPES,
DURATION_DTYPES,
FLOAT_DTYPES,
INTEGER_DTYPES,
)

with contextlib.suppress(ImportError): # Module not available when building docs
import polars.polars as plr

if TYPE_CHECKING:
from polars._typing import PolarsDataType
from polars._typing import PolarsDataType, PythonDataType
from polars.expr.expr import Expr

__all__ = ["col"]


def _create_col(
name: str | PolarsDataType | Iterable[str] | Iterable[PolarsDataType],
*more_names: str | PolarsDataType,
name: (
str
| PolarsDataType
| PythonDataType
| Iterable[str]
| Iterable[PolarsDataType | PythonDataType]
),
*more_names: str | PolarsDataType | PythonDataType,
) -> Expr:
"""Create one or more column expressions representing column(s) in a DataFrame."""
dtypes: list[PolarsDataType]
if more_names:
if isinstance(name, str):
names_str = [name]
Expand All @@ -41,7 +55,11 @@ def _create_col(
if isinstance(name, str):
return wrap_expr(plr.col(name))
elif is_polars_dtype(name):
return wrap_expr(plr.dtype_cols([name]))
dtypes = _polars_dtype_match(name)
return wrap_expr(plr.dtype_cols(dtypes))
elif isinstance(name, type):
dtypes = _python_dtype_match(name)
return wrap_expr(plr.dtype_cols(dtypes))
elif isinstance(name, Iterable):
names = list(name)
if not names:
Expand All @@ -51,7 +69,15 @@ def _create_col(
if isinstance(item, str):
return wrap_expr(plr.cols(names))
elif is_polars_dtype(item):
return wrap_expr(plr.dtype_cols(names))
dtypes = []
for nm in names:
dtypes.extend(_polars_dtype_match(nm)) # type: ignore[arg-type]
return wrap_expr(plr.dtype_cols(dtypes))
elif isinstance(item, type):
dtypes = []
for nm in names:
dtypes.extend(_python_dtype_match(nm)) # type: ignore[arg-type]
return wrap_expr(plr.dtype_cols(dtypes))
else:
msg = (
"invalid input for `col`"
Expand All @@ -67,6 +93,26 @@ def _create_col(
raise TypeError(msg)


def _python_dtype_match(tp: PythonDataType) -> list[PolarsDataType]:
    """
    Expand a Python type into the list of Polars dtypes it should select.

    Parameters
    ----------
    tp
        The Python type to translate.

    Returns
    -------
    list
        A fresh list copy of the matching dtype group when `tp` is exactly
        `int`, `float`, `datetime` or `timedelta`; otherwise a one-element
        list holding whatever `parse_into_dtype(tp)` returns.
    """
    # Identity comparison (not `issubclass`/`==`) is intentional: a subclass
    # such as `bool` must not be treated as `int` and falls through to the
    # generic `parse_into_dtype` path below.
    grouped_types = (
        (int, INTEGER_DTYPES),
        (float, FLOAT_DTYPES),
        (datetime, DATETIME_DTYPES),
        (timedelta, DURATION_DTYPES),
    )
    for python_type, dtype_group in grouped_types:
        if tp is python_type:
            return list(dtype_group)

    # No dtype group is associated with this type; defer to the parser.
    return [parse_into_dtype(tp)]


def _polars_dtype_match(tp: PolarsDataType) -> list[PolarsDataType]:
    """
    Expand a Polars dtype into the list of dtypes it should select.

    Parameters
    ----------
    tp
        The Polars dtype to expand.

    Returns
    -------
    list
        A fresh list copy of `DATETIME_DTYPES` when `Datetime.is_(tp)` holds,
        of `DURATION_DTYPES` when `Duration.is_(tp)` holds, and otherwise a
        one-element list containing `tp` unchanged.
    """
    # NOTE(review): `is_` presumably matches only the bare, unparametrised
    # dtype, so a specific dtype such as `Datetime("ns")` is passed through
    # as-is -- confirm against the `DataType.is_` implementation.
    if Datetime.is_(tp):
        expanded = DATETIME_DTYPES
    elif Duration.is_(tp):
        expanded = DURATION_DTYPES
    else:
        return [tp]

    # Copy so callers never mutate (or depend on the container type of)
    # the shared dtype-group constant.
    return list(expanded)


class Col:
"""
Create Polars column expressions.
Expand All @@ -79,8 +125,7 @@ class Col:

This helper class enables an alternative syntax for creating a column expression
through attribute lookup. For example `col.foo` creates an expression equal to
`col("foo")`.
See the :func:`__getattr__` method for further documentation.
`col("foo")`. See the :func:`__getattr__` method for further documentation.

The function call syntax is considered the idiomatic way of constructing a column
expression. The alternative attribute syntax can be useful for quick prototyping as
Expand Down Expand Up @@ -126,18 +171,24 @@ class Col:

def __call__(
self,
name: str | PolarsDataType | Iterable[str] | Iterable[PolarsDataType],
*more_names: str | PolarsDataType,
name: (
str
| PolarsDataType
| PythonDataType
| Iterable[str]
| Iterable[PolarsDataType | PythonDataType]
),
*more_names: str | PolarsDataType | PythonDataType,
) -> Expr:
"""
Create one or more column expressions representing column(s) in a DataFrame.
Create one or more expressions representing columns in a DataFrame.

Parameters
----------
name
The name or datatype of the column(s) to represent.
Accepts regular expression input.
Regular expressions should start with `^` and end with `$`.
Accepts regular expression input; regular expressions
should start with `^` and end with `$`.
*more_names
Additional names or datatypes of columns to represent,
specified as positional arguments.
Expand Down
16 changes: 11 additions & 5 deletions py-polars/polars/selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,10 @@

if TYPE_CHECKING:
import sys
from collections.abc import Iterable

from polars import DataFrame, LazyFrame
from polars._typing import PolarsDataType, SelectorType, TimeUnit
from polars._typing import PolarsDataType, PythonDataType, SelectorType, TimeUnit

if sys.version_info >= (3, 11):
from typing import Self
Expand Down Expand Up @@ -868,7 +869,12 @@ def boolean() -> SelectorType:


def by_dtype(
*dtypes: PolarsDataType | Collection[PolarsDataType],
*dtypes: (
PolarsDataType
| PythonDataType
| Iterable[PolarsDataType]
| Iterable[PythonDataType]
),
) -> SelectorType:
"""
Select all columns matching the given dtypes.
Expand Down Expand Up @@ -931,13 +937,13 @@ def by_dtype(
│ foo ┆ -3265500 │
└───────┴──────────┘
"""
all_dtypes: list[PolarsDataType] = []
all_dtypes: list[PolarsDataType | PythonDataType] = []
for tp in dtypes:
if is_polars_dtype(tp):
if is_polars_dtype(tp) or isinstance(tp, type):
all_dtypes.append(tp)
elif isinstance(tp, Collection):
for t in tp:
if not is_polars_dtype(t):
if not (is_polars_dtype(t) or isinstance(t, type)):
msg = f"invalid dtype: {t!r}"
raise TypeError(msg)
all_dtypes.append(t)
Expand Down
55 changes: 54 additions & 1 deletion py-polars/tests/unit/test_selectors.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections import OrderedDict
from datetime import datetime
from datetime import datetime, timedelta
from decimal import Decimal as PyDecimal
from typing import Any
from zoneinfo import ZoneInfo

Expand Down Expand Up @@ -110,6 +111,58 @@ def test_selector_by_dtype(df: pl.DataFrame) -> None:
"qqR": pl.String(),
}
)
assert df.select(
cs.by_dtype(pl.Datetime("ns"), pl.Float32, pl.UInt32, pl.Date)
).schema == pl.Schema(
{
"bbb": pl.UInt32,
"def": pl.Float32,
"JJK": pl.Date,
}
)

# select using python types
assert df.select(cs.by_dtype(int, float)).schema == pl.Schema(
{
"abc": pl.UInt16,
"bbb": pl.UInt32,
"cde": pl.Float64,
"def": pl.Float32,
}
)
assert df.select(cs.by_dtype(bool, datetime, timedelta)).schema == pl.Schema(
{
"eee": pl.Boolean(),
"fgg": pl.Boolean(),
"Lmn": pl.Duration("us"),
"opp": pl.Datetime("ms"),
}
)

# cover timezones and decimal
dfx = pl.DataFrame(
{"idx": [], "dt1": [], "dt2": []},
schema_overrides={
"idx": pl.Decimal(24),
"dt1": pl.Datetime("ms"),
"dt2": pl.Datetime(time_zone="Asia/Tokyo"),
},
)
assert dfx.select(cs.by_dtype(PyDecimal)).schema == pl.Schema(
{"idx": pl.Decimal(24)},
)
assert dfx.select(cs.by_dtype(pl.Datetime(time_zone="*"))).schema == pl.Schema(
{"dt2": pl.Datetime(time_zone="Asia/Tokyo")}
)
assert dfx.select(cs.by_dtype(pl.Datetime("ms", None))).schema == pl.Schema(
{"dt1": pl.Datetime("ms")},
)
for dt in (datetime, pl.Datetime):
assert dfx.select(cs.by_dtype(dt)).schema == pl.Schema(
{"dt1": pl.Datetime("ms"), "dt2": pl.Datetime(time_zone="Asia/Tokyo")},
)

# empty selection selects nothing
assert df.select(cs.by_dtype()).schema == {}
assert df.select(cs.by_dtype([])).schema == {}

Expand Down
Loading