Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Adding ewm_mean #1298

Merged
merged 35 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
fc64937
wip
DeaMariaLeon Oct 30, 2024
3d1e466
wip
DeaMariaLeon Oct 30, 2024
ac0c3f7
latex works
DeaMariaLeon Oct 30, 2024
14dd1c5
doc test series
DeaMariaLeon Oct 31, 2024
a4b5bd7
expr docstring
DeaMariaLeon Oct 31, 2024
686e33c
added to pandaslikeexpr
DeaMariaLeon Oct 31, 2024
9113a9d
added wip test
DeaMariaLeon Nov 1, 2024
212b78a
wip
DeaMariaLeon Nov 1, 2024
1bf1571
Merge remote-tracking branch 'upstream/main' into ewm
DeaMariaLeon Nov 1, 2024
cd986f0
after merge
DeaMariaLeon Nov 2, 2024
e5b9486
added dask not implemented error test
DeaMariaLeon Nov 2, 2024
1dfab2c
added test with nulls
DeaMariaLeon Nov 3, 2024
6f738cd
example with nulls
DeaMariaLeon Nov 3, 2024
0cdb0c3
Merge remote-tracking branch 'upstream/main' into ewm
DeaMariaLeon Nov 3, 2024
73cc573
fixed mkdocs issue
DeaMariaLeon Nov 3, 2024
5cd4833
Match polars' None in input
DeaMariaLeon Nov 6, 2024
130322e
Merge remote-tracking branch 'upstream/main' into ewm
DeaMariaLeon Nov 6, 2024
19da3a3
polars version
DeaMariaLeon Nov 6, 2024
6cc0a96
polars version again
DeaMariaLeon Nov 6, 2024
afb3ed3
again
DeaMariaLeon Nov 6, 2024
ee2e916
wip
DeaMariaLeon Nov 6, 2024
7f872cf
add modin to xfail
DeaMariaLeon Nov 6, 2024
6fdaa29
Merge remote-tracking branch 'upstream/main' into ewm
DeaMariaLeon Nov 8, 2024
3cbfe53
ewm_mean not implemented yet for old Polars
DeaMariaLeon Nov 10, 2024
6368b04
after conflict
DeaMariaLeon Nov 10, 2024
a6c4525
remove unused test
DeaMariaLeon Nov 10, 2024
a67aef0
parametrize expected
DeaMariaLeon Nov 12, 2024
f8d438a
after merge conflict
DeaMariaLeon Nov 12, 2024
0c34a55
Merge remote-tracking branch 'upstream/main' into ewm
DeaMariaLeon Nov 12, 2024
eddceb4
test parameters
DeaMariaLeon Nov 12, 2024
fdeb4dc
Merge remote-tracking branch 'upstream/main' into ewm
DeaMariaLeon Nov 12, 2024
8787b65
Merge remote-tracking branch 'upstream/main' into ewm
DeaMariaLeon Nov 18, 2024
3f2a26d
added warning
DeaMariaLeon Nov 18, 2024
e8eb645
remove nan example, https://github.com/narwhals-dev/narwhals/issues/1401
MarcoGorelli Nov 19, 2024
bbe2cae
use None in test
MarcoGorelli Nov 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api-reference/expr.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
- cum_sum
- diff
- drop_nulls
- ewm_mean
- fill_null
- filter
- gather_every
Expand Down
1 change: 1 addition & 0 deletions docs/api-reference/series.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
- diff
- drop_nulls
- dtype
- ewm_mean
- fill_null
- filter
- gather_every
Expand Down
7 changes: 7 additions & 0 deletions docs/css/code_select.css
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
.highlight .gp, .highlight .go { /* Generic.Prompt, Generic.Output */
user-select: none;
-webkit-user-select: none; /* Safari */
-moz-user-select: none; /* Firefox */
-ms-user-select: none; /* Internet Explorer/Edge */
color: red;
}
10 changes: 10 additions & 0 deletions docs/javascripts/katex.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
document$.subscribe(({ body }) => {
renderMathInElement(body, {
delimiters: [
{ left: "$$", right: "$$", display: true },
{ left: "$", right: "$", display: false },
{ left: "\\(", right: "\\)", display: false },
{ left: "\\[", right: "\\]", display: true }
],
})
})
9 changes: 9 additions & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,12 @@ markdown_extensions:
- pymdownx.emoji:
emoji_index: !!python/name:material.extensions.emoji.twemoji
emoji_generator: !!python/name:material.extensions.emoji.to_svg
- pymdownx.arithmatex:
generic: true
extra_javascript:
- javascripts/katex.js
- https://unpkg.com/katex@0/dist/katex.min.js
- https://unpkg.com/katex@0/dist/contrib/auto-render.min.js

extra_css:
- https://unpkg.com/katex@0/dist/katex.min.css
14 changes: 14 additions & 0 deletions narwhals/_dask/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,20 @@ def round(self, decimals: int) -> Self:
returns_scalar=False,
)

def ewm_mean(
self: Self,
*,
com: float | None = None,
span: float | None = None,
half_life: float | None = None,
alpha: float | None = None,
adjust: bool = True,
min_periods: int = 1,
ignore_nulls: bool = False,
) -> NoReturn:
msg = "`Expr.ewm_mean` is not supported for the Dask backend"
raise NotImplementedError(msg)

def unique(self) -> NoReturn:
# We can't (yet?) allow methods which modify the index
msg = "`Expr.unique` is not supported for the Dask backend. Please use `LazyFrame.unique` instead."
Expand Down
23 changes: 23 additions & 0 deletions narwhals/_pandas_like/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,29 @@ def is_in(self, other: Any) -> Self:
def arg_true(self) -> Self:
return reuse_series_implementation(self, "arg_true")

def ewm_mean(
self,
*,
com: float | None = None,
span: float | None = None,
half_life: float | None = None,
alpha: float | None = None,
adjust: bool = True,
min_periods: int = 1,
ignore_nulls: bool = False,
) -> Self:
return reuse_series_implementation(
self,
"ewm_mean",
com=com,
span=span,
half_life=half_life,
alpha=alpha,
adjust=adjust,
min_periods=min_periods,
ignore_nulls=ignore_nulls,
)

def filter(self, *predicates: Any) -> Self:
plx = self.__narwhals_namespace__()
other = plx.all_horizontal(*predicates)
Expand Down
19 changes: 19 additions & 0 deletions narwhals/_pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,25 @@ def dtype(self: Self) -> DType:
self._native_series, self._dtypes, self._implementation
)

def ewm_mean(
self,
*,
com: float | None = None,
span: float | None = None,
half_life: float | None = None,
alpha: float | None = None,
adjust: bool = True,
min_periods: int = 1,
ignore_nulls: bool = False,
Comment on lines +181 to +187
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there's a lot of parameters here, do we have a test which hits each of them?

Copy link
Member Author

@DeaMariaLeon DeaMariaLeon Nov 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll add tests for the parameters then... (I can't hit all the parameters in only one test, at least not the first 4 I think).. Is that what you meant?

) -> PandasLikeSeries:
ser = self._native_series
mask_na = ser.isna()
result = ser.ewm(
com, span, half_life, alpha, min_periods, adjust, ignore_na=ignore_nulls
).mean()
result[mask_na] = None
return self._from_native_series(result)

def scatter(self, indices: int | Sequence[int], values: Any) -> Self:
if isinstance(values, self.__class__):
# .copy() is necessary in some pre-2.2 versions of pandas to avoid
Expand Down
30 changes: 30 additions & 0 deletions narwhals/_polars/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from narwhals._polars.utils import extract_native
from narwhals._polars.utils import narwhals_to_native_dtype
from narwhals.utils import Implementation
from narwhals.utils import parse_version

if TYPE_CHECKING:
import polars as pl
Expand Down Expand Up @@ -49,6 +50,35 @@ def cast(self, dtype: DType) -> Self:
dtype_pl = narwhals_to_native_dtype(dtype, self._dtypes)
return self._from_native_expr(expr.cast(dtype_pl))

def ewm_mean(
self: Self,
*,
com: float | None = None,
span: float | None = None,
half_life: float | None = None,
alpha: float | None = None,
adjust: bool = True,
min_periods: int = 1,
ignore_nulls: bool = False,
) -> Self:
import polars as pl # ignore-banned-import()

if parse_version(pl.__version__) <= (0, 20, 31): # pragma: no cover
msg = "`ewm_mean` not implemented for polars older than 0.20.31"
raise NotImplementedError(msg)
expr = self._native_expr
return self._from_native_expr(
expr.ewm_mean(
com=com,
span=span,
half_life=half_life,
alpha=alpha,
adjust=adjust,
min_periods=min_periods,
ignore_nulls=ignore_nulls,
)
)

def map_batches(
self,
function: Callable[[Any], Self],
Expand Down
27 changes: 27 additions & 0 deletions narwhals/_polars/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,33 @@ def to_dummies(
result, backend_version=self._backend_version, dtypes=self._dtypes
)

def ewm_mean(
self: Self,
*,
com: float | None = None,
span: float | None = None,
half_life: float | None = None,
alpha: float | None = None,
adjust: bool = True,
min_periods: int = 1,
ignore_nulls: bool = False,
) -> Self:
if self._backend_version < (0, 20, 31): # pragma: no cover
msg = "`ewm_mean` not implemented for polars older than 0.20.31"
raise NotImplementedError(msg)
expr = self._native_series
return self._from_native_series(
expr.ewm_mean(
com=com,
span=span,
half_life=half_life,
alpha=alpha,
adjust=adjust,
min_periods=min_periods,
ignore_nulls=ignore_nulls,
)
)

def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self:
if self._backend_version < (0, 20, 6):
result = self._native_series.sort(descending=descending)
Expand Down
97 changes: 97 additions & 0 deletions narwhals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,103 @@ def all(self) -> Self:
"""
return self.__class__(lambda plx: self._call(plx).all())

def ewm_mean(
self: Self,
*,
com: float | None = None,
span: float | None = None,
half_life: float | None = None,
alpha: float | None = None,
adjust: bool = True,
min_periods: int = 1,
ignore_nulls: bool = False,
) -> Self:
r"""Compute exponentially-weighted moving average.

!!! warning
This functionality is considered **unstable**. It may be changed at any point
without it being considered a breaking change.

Arguments:
com: Specify decay in terms of center of mass, $\gamma$, with <br> $\alpha = \frac{1}{1+\gamma}\forall\gamma\geq0$
span: Specify decay in terms of span, $\theta$, with <br> $\alpha = \frac{2}{\theta + 1} \forall \theta \geq 1$
half_life: Specify decay in terms of half-life, $\tau$, with <br> $\alpha = 1 - \exp \left\{ \frac{ -\ln(2) }{ \tau } \right\} \forall \tau > 0$
alpha: Specify smoothing factor alpha directly, $0 < \alpha \leq 1$.
adjust: Divide by decaying adjustment factor in beginning periods to account for imbalance in relative weightings
- When `adjust=True` (the default) the EW function is calculated
using weights $w_i = (1 - \alpha)^i$
- When `adjust=False` the EW function is calculated recursively by
$$
y_0=x_0
$$
$$
y_t = (1 - \alpha)y_{t - 1} + \alpha x_t
$$
min_periods: Minimum number of observations in window required to have a value, (otherwise result is null).
ignore_nulls: Ignore missing values when calculating weights.

- When `ignore_nulls=False` (default), weights are based on absolute
positions.
For example, the weights of $x_0$ and $x_2$ used in
calculating the final weighted average of $[x_0, None, x_2]$ are
$(1-\alpha)^2$ and $1$ if `adjust=True`, and
$(1-\alpha)^2$ and $\alpha$ if `adjust=False`.
- When `ignore_nulls=True`, weights are based
on relative positions. For example, the weights of
$x_0$ and $x_2$ used in calculating the final weighted
average of $[x_0, None, x_2]$ are
$1-\alpha$ and $1$ if `adjust=True`,
and $1-\alpha$ and $\alpha$ if `adjust=False`.

Returns:
Expr

Examples:
>>> import pandas as pd
>>> import polars as pl
>>> import narwhals as nw
>>> data = {"a": [1, 2, 3]}
>>> df_pd = pd.DataFrame(data)
>>> df_pl = pl.DataFrame(data)

We define a library agnostic function:

>>> @nw.narwhalify
... def func(df):
... return df.select(nw.col("a").ewm_mean(com=1, ignore_nulls=False))

We can then pass either pandas or Polars to `func`:

>>> func(df_pd)
a
0 1.000000
1 1.666667
2 2.428571

>>> func(df_pl) # doctest: +NORMALIZE_WHITESPACE
shape: (3, 1)
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚ a β”‚
β”‚ --- β”‚
β”‚ f64 β”‚
β•žβ•β•β•β•β•β•β•β•β•β•β•‘
β”‚ 1.0 β”‚
β”‚ 1.666667 β”‚
β”‚ 2.428571 β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
"""
return self.__class__(
lambda plx: self._call(plx).ewm_mean(
com=com,
span=span,
half_life=half_life,
alpha=alpha,
adjust=adjust,
min_periods=min_periods,
ignore_nulls=ignore_nulls,
)
)

def mean(self) -> Self:
"""Get mean value.

Expand Down
Loading
Loading