Skip to content

Commit

Permalink
feat: track whether expressions change length but don't aggregate, an…
Browse files Browse the repository at this point in the history
…d only allow length-changing expressions if they're followed by aggregations in the lazy API (#1828)
  • Loading branch information
MarcoGorelli authored Jan 20, 2025
1 parent 5ca7688 commit b26358b
Show file tree
Hide file tree
Showing 17 changed files with 581 additions and 106 deletions.
56 changes: 13 additions & 43 deletions narwhals/_dask/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import Any
from typing import Callable
from typing import Literal
from typing import NoReturn
from typing import Sequence

from narwhals._dask.expr_dt import DaskExprDateTimeNamespace
Expand Down Expand Up @@ -448,46 +447,27 @@ def round(self, decimals: int) -> Self:
returns_scalar=self._returns_scalar,
)

def ewm_mean(
self: Self,
*,
com: float | None = None,
span: float | None = None,
half_life: float | None = None,
alpha: float | None = None,
adjust: bool = True,
min_periods: int = 1,
ignore_nulls: bool = False,
) -> NoReturn:
msg = "`Expr.ewm_mean` is not supported for the Dask backend"
raise NotImplementedError(msg)

def unique(self) -> NoReturn:
# We can't (yet?) allow methods which modify the index
msg = "`Expr.unique` is not supported for the Dask backend. Please use `LazyFrame.unique` instead."
raise NotImplementedError(msg)

def drop_nulls(self) -> NoReturn:
# We can't (yet?) allow methods which modify the index
msg = "`Expr.drop_nulls` is not supported for the Dask backend. Please use `LazyFrame.drop_nulls` instead."
raise NotImplementedError(msg)
def unique(self, *, maintain_order: bool) -> Self:
# TODO(marco): maintain_order has no effect and will be deprecated
return self._from_call(
lambda _input: _input.unique(),
"unique",
returns_scalar=self._returns_scalar,
)

def head(self) -> NoReturn:
# We can't (yet?) allow methods which modify the index
msg = "`Expr.head` is not supported for the Dask backend. Please use `LazyFrame.head` instead."
raise NotImplementedError(msg)
def drop_nulls(self) -> Self:
return self._from_call(
lambda _input: _input.dropna(),
"drop_nulls",
returns_scalar=self._returns_scalar,
)

def replace_strict(
self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType | None
) -> Self:
msg = "`replace_strict` is not yet supported for Dask expressions"
raise NotImplementedError(msg)

def sort(self, *, descending: bool = False, nulls_last: bool = False) -> NoReturn:
# We can't (yet?) allow methods which modify the index
msg = "`Expr.sort` is not supported for the Dask backend. Please use `LazyFrame.sort` instead."
raise NotImplementedError(msg)

def abs(self) -> Self:
return self._from_call(
lambda _input: _input.abs(), "abs", returns_scalar=self._returns_scalar
Expand Down Expand Up @@ -678,16 +658,6 @@ def null_count(self: Self) -> Self:
returns_scalar=True,
)

def tail(self: Self) -> NoReturn:
# We can't (yet?) allow methods which modify the index
msg = "`Expr.tail` is not supported for the Dask backend. Please use `LazyFrame.tail` instead."
raise NotImplementedError(msg)

def gather_every(self: Self, n: int, offset: int = 0) -> NoReturn:
# We can't (yet?) allow methods which modify the index
msg = "`Expr.gather_every` is not supported for the Dask backend. Please use `LazyFrame.gather_every` instead."
raise NotImplementedError(msg)

def over(self: Self, keys: list[str]) -> Self:
def func(df: DaskLazyFrame) -> list[Any]:
if self._output_names is None:
Expand Down
41 changes: 41 additions & 0 deletions narwhals/_expression_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from narwhals.dependencies import is_numpy_array
from narwhals.exceptions import InvalidIntoExprError
from narwhals.exceptions import LengthChangingExprError
from narwhals.utils import Implementation

if TYPE_CHECKING:
Expand Down Expand Up @@ -342,3 +343,43 @@ def operation_is_order_dependent(*args: IntoExpr | Any) -> bool:
# it means that it was a scalar (e.g. nw.col('a') + 1) or a column name,
# neither of which is order-dependent, so we default to `False`.
return any(getattr(x, "_is_order_dependent", False) for x in args)


def operation_changes_length(*args: IntoExpr | Any) -> bool:
"""Track whether operation changes length.
n-ary operations between expressions which change length are not
allowed. This is because the output might be non-relational. For
example:
df = pl.LazyFrame({'a': [1,2,None], 'b': [4,None,6]})
df.select(pl.col('a', 'b').drop_nulls())
Polars does allow this, but in the result we end up with the
tuple (2, 6) which wasn't part of the original data.
Rules are:
- in an n-ary operation, if any one of them changes length, then
it must be the only expression present
- in a comparison between a changes-length expression and a
scalar, the output changes length
"""
from narwhals.expr import Expr

n_exprs = len([x for x in args if isinstance(x, Expr)])
changes_length = any(isinstance(x, Expr) and x._changes_length for x in args)
if n_exprs > 1 and changes_length:
msg = (
"Found multiple expressions at least one of which changes length.\n"
"Any length-changing expression can only be used in isolation, unless\n"
"it is followed by an aggregation."
)
raise LengthChangingExprError(msg)
return changes_length


def operation_aggregates(*args: IntoExpr | Any) -> bool:
# If an arg is an Expr, we look at `_aggregates`. If it isn't,
# it means that it was a scalar (e.g. nw.col('a').sum() + 1),
# which is already length-1, so we default to `True`. If any
# expression does not aggregate, then broadcasting will take
# place and the result will not be an aggregate.
return all(getattr(x, "_aggregates", True) for x in args)
11 changes: 11 additions & 0 deletions narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from narwhals.dependencies import get_polars
from narwhals.dependencies import is_numpy_array
from narwhals.exceptions import LengthChangingExprError
from narwhals.exceptions import OrderDependentExprError
from narwhals.schema import Schema
from narwhals.translate import to_native
Expand Down Expand Up @@ -3648,6 +3649,16 @@ def _extract_compliant(self, arg: Any) -> Any:
" they will be supported."
)
raise OrderDependentExprError(msg)
if arg._changes_length:
msg = (
"Length-changing expressions are not supported for use in LazyFrame, unless\n"
"followed by an aggregation.\n\n"
"Hints:\n"
"- Instead of `lf.select(nw.col('a').head())`, use `lf.select('a').head()\n"
"- Instead of `lf.select(nw.col('a').drop_nulls()).select(nw.sum('a'))`,\n"
" use `lf.select(nw.col('a').drop_nulls().sum())\n"
)
raise LengthChangingExprError(msg)
return arg._to_compliant_expr(self.__narwhals_namespace__())
if get_polars() is not None and "polars" in str(type(arg)): # pragma: no cover
msg = (
Expand Down
8 changes: 8 additions & 0 deletions narwhals/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,14 @@ def __init__(self, message: str) -> None:
super().__init__(self.message)


class LengthChangingExprError(ValueError):
"""Exception raised when trying to use an expression which changes length with LazyFrames."""

def __init__(self, message: str) -> None:
self.message = message
super().__init__(self.message)


class UnsupportedDTypeError(ValueError):
"""Exception raised when trying to convert to a DType which is not supported by the given backend."""

Expand Down
Loading

0 comments on commit b26358b

Please sign in to comment.