Skip to content

Commit

Permalink
rolling_exp: keep_attrs and typing (#4592)
Browse files Browse the repository at this point in the history
* rolling_exp: keep_attrs and typing

* Update doc/whats-new.rst

* update whats-new
  • Loading branch information
mathause authored Nov 20, 2020
1 parent 19c2626 commit 4be6653
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 5 deletions.
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ New Features
By `Michal Baumgartner <https://github.com/m1so>`_.
- :py:meth:`Dataset.weighted` and :py:meth:`DataArray.weighted` are now executing value checks lazily if weights are provided as dask arrays (:issue:`4541`, :pull:`4559`).
By `Julius Busecke <https://github.com/jbusecke>`_.
- Added the ``keep_attrs`` keyword to ``rolling_exp.mean()``; it now keeps attributes
per default. By `Mathias Hauser <https://github.com/mathause>`_ (:pull:`4592`).

Bug fixes
~~~~~~~~~
Expand Down
36 changes: 31 additions & 5 deletions xarray/core/rolling_exp.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
from typing import TYPE_CHECKING, Generic, Hashable, Mapping, Optional, TypeVar

import numpy as np

from .options import _get_keep_attrs
from .pdcompat import count_not_none
from .pycompat import is_duck_dask_array

if TYPE_CHECKING:
from .dataarray import DataArray # noqa: F401
from .dataset import Dataset # noqa: F401

T_DSorDA = TypeVar("T_DSorDA", "DataArray", "Dataset")


def _get_alpha(com=None, span=None, halflife=None, alpha=None):
# pandas defines in terms of com (converting to alpha in the algo)
Expand Down Expand Up @@ -56,7 +65,7 @@ def _get_center_of_mass(comass, span, halflife, alpha):
return float(comass)


class RollingExp:
class RollingExp(Generic[T_DSorDA]):
"""
Exponentially-weighted moving window object.
Similar to EWM in pandas
Expand All @@ -78,16 +87,28 @@ class RollingExp:
RollingExp : type of input argument
"""

def __init__(self, obj, windows, window_type="span"):
self.obj = obj
def __init__(
self,
obj: T_DSorDA,
windows: Mapping[Hashable, int],
window_type: str = "span",
):
self.obj: T_DSorDA = obj
dim, window = next(iter(windows.items()))
self.dim = dim
self.alpha = _get_alpha(**{window_type: window})

def mean(self):
def mean(self, keep_attrs: Optional[bool] = None) -> T_DSorDA:
"""
Exponentially weighted moving average
Parameters
----------
keep_attrs : bool, default: None
If True, the attributes (``attrs``) will be copied from the original
object to the new one. If False, the new object will be returned
without attributes. If None uses the global default.
Examples
--------
>>> da = xr.DataArray([1, 1, 2, 2, 2], dims="x")
Expand All @@ -97,4 +118,9 @@ def mean(self):
Dimensions without coordinates: x
"""

return self.obj.reduce(move_exp_nanmean, dim=self.dim, alpha=self.alpha)
if keep_attrs is None:
keep_attrs = _get_keep_attrs(default=True)

return self.obj.reduce(
move_exp_nanmean, dim=self.dim, alpha=self.alpha, keep_attrs=keep_attrs
)
29 changes: 29 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -6931,6 +6931,35 @@ def test_rolling_exp(da, dim, window_type, window):
assert_allclose(expected.variable, result.variable)


@requires_numbagg
def test_rolling_exp_keep_attrs(da):

attrs = {"attrs": "da"}
da.attrs = attrs

# attrs are kept per default
result = da.rolling_exp(time=10).mean()
assert result.attrs == attrs

# discard attrs
result = da.rolling_exp(time=10).mean(keep_attrs=False)
assert result.attrs == {}

# test discard attrs using global option
with set_options(keep_attrs=False):
result = da.rolling_exp(time=10).mean()
assert result.attrs == {}

# keyword takes precedence over global option
with set_options(keep_attrs=False):
result = da.rolling_exp(time=10).mean(keep_attrs=True)
assert result.attrs == attrs

with set_options(keep_attrs=True):
result = da.rolling_exp(time=10).mean(keep_attrs=False)
assert result.attrs == {}


def test_no_dict():
d = DataArray()
with pytest.raises(AttributeError):
Expand Down
37 changes: 37 additions & 0 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6150,6 +6150,43 @@ def test_rolling_exp(ds):
assert isinstance(result, Dataset)


@requires_numbagg
def test_rolling_exp_keep_attrs(ds):

attrs_global = {"attrs": "global"}
attrs_z1 = {"attr": "z1"}

ds.attrs = attrs_global
ds.z1.attrs = attrs_z1

# attrs are kept per default
result = ds.rolling_exp(time=10).mean()
assert result.attrs == attrs_global
assert result.z1.attrs == attrs_z1

# discard attrs
result = ds.rolling_exp(time=10).mean(keep_attrs=False)
assert result.attrs == {}
assert result.z1.attrs == {}

# test discard attrs using global option
with set_options(keep_attrs=False):
result = ds.rolling_exp(time=10).mean()
assert result.attrs == {}
assert result.z1.attrs == {}

# keyword takes precedence over global option
with set_options(keep_attrs=False):
result = ds.rolling_exp(time=10).mean(keep_attrs=True)
assert result.attrs == attrs_global
assert result.z1.attrs == attrs_z1

with set_options(keep_attrs=True):
result = ds.rolling_exp(time=10).mean(keep_attrs=False)
assert result.attrs == {}
assert result.z1.attrs == {}


@pytest.mark.parametrize("center", (True, False))
@pytest.mark.parametrize("min_periods", (None, 1, 2, 3))
@pytest.mark.parametrize("window", (1, 2, 3, 4))
Expand Down

0 comments on commit 4be6653

Please sign in to comment.