From 4be6653a0cebc368998e4fe69c0f35231aa39621 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Fri, 20 Nov 2020 20:39:10 +0100 Subject: [PATCH] rolling_exp: keep_attrs and typing (#4592) * rolling_exp: keep_attrs and typing * Update doc/whats-new.rst * update whats-new --- doc/whats-new.rst | 2 ++ xarray/core/rolling_exp.py | 36 ++++++++++++++++++++++++++++----- xarray/tests/test_dataarray.py | 29 ++++++++++++++++++++++++++ xarray/tests/test_dataset.py | 37 ++++++++++++++++++++++++++++++++++ 4 files changed, 99 insertions(+), 5 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index a499c8b3505..5427e85c2ad 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -53,6 +53,8 @@ New Features By `Michal Baumgartner `_. - :py:meth:`Dataset.weighted` and :py:meth:`DataArray.weighted` are now executing value checks lazily if weights are provided as dask arrays (:issue:`4541`, :pull:`4559`). By `Julius Busecke `_. +- Added the ``keep_attrs`` keyword to ``rolling_exp.mean()``; it now keeps attributes + per default. By `Mathias Hauser `_ (:pull:`4592`). Bug fixes ~~~~~~~~~ diff --git a/xarray/core/rolling_exp.py b/xarray/core/rolling_exp.py index b80a4d313d9..0ae85a870e8 100644 --- a/xarray/core/rolling_exp.py +++ b/xarray/core/rolling_exp.py @@ -1,8 +1,17 @@ +from typing import TYPE_CHECKING, Generic, Hashable, Mapping, Optional, TypeVar + import numpy as np +from .options import _get_keep_attrs from .pdcompat import count_not_none from .pycompat import is_duck_dask_array +if TYPE_CHECKING: + from .dataarray import DataArray # noqa: F401 + from .dataset import Dataset # noqa: F401 + +T_DSorDA = TypeVar("T_DSorDA", "DataArray", "Dataset") + def _get_alpha(com=None, span=None, halflife=None, alpha=None): # pandas defines in terms of com (converting to alpha in the algo) @@ -56,7 +65,7 @@ def _get_center_of_mass(comass, span, halflife, alpha): return float(comass) -class RollingExp: +class RollingExp(Generic[T_DSorDA]): """ Exponentially-weighted moving window object. Similar to EWM in pandas @@ -78,16 +87,28 @@ class RollingExp: RollingExp : type of input argument """ - def __init__(self, obj, windows, window_type="span"): - self.obj = obj + def __init__( + self, + obj: T_DSorDA, + windows: Mapping[Hashable, int], + window_type: str = "span", + ): + self.obj: T_DSorDA = obj dim, window = next(iter(windows.items())) self.dim = dim self.alpha = _get_alpha(**{window_type: window}) - def mean(self): + def mean(self, keep_attrs: Optional[bool] = None) -> T_DSorDA: """ Exponentially weighted moving average + Parameters + ---------- + keep_attrs : bool, default: None + If True, the attributes (``attrs``) will be copied from the original + object to the new one. If False, the new object will be returned + without attributes. If None uses the global default. + Examples -------- >>> da = xr.DataArray([1, 1, 2, 2, 2], dims="x") @@ -97,4 +118,9 @@ def mean(self): Dimensions without coordinates: x """ - return self.obj.reduce(move_exp_nanmean, dim=self.dim, alpha=self.alpha) + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=True) + + return self.obj.reduce( + move_exp_nanmean, dim=self.dim, alpha=self.alpha, keep_attrs=keep_attrs + ) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 540e97b380d..599584e0081 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -6931,6 +6931,35 @@ def test_rolling_exp(da, dim, window_type, window): assert_allclose(expected.variable, result.variable) +@requires_numbagg +def test_rolling_exp_keep_attrs(da): + + attrs = {"attrs": "da"} + da.attrs = attrs + + # attrs are kept per default + result = da.rolling_exp(time=10).mean() + assert result.attrs == attrs + + # discard attrs + result = da.rolling_exp(time=10).mean(keep_attrs=False) + assert result.attrs == {} + + # test discard attrs using global option + with set_options(keep_attrs=False): + result = da.rolling_exp(time=10).mean() + assert result.attrs == {} + + # keyword takes precedence over global option + with set_options(keep_attrs=False): + result = da.rolling_exp(time=10).mean(keep_attrs=True) + assert result.attrs == attrs + + with set_options(keep_attrs=True): + result = da.rolling_exp(time=10).mean(keep_attrs=False) + assert result.attrs == {} + + def test_no_dict(): d = DataArray() with pytest.raises(AttributeError): diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 21d9bc9ca01..6c4311c3791 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -6150,6 +6150,43 @@ def test_rolling_exp(ds): assert isinstance(result, Dataset) +@requires_numbagg +def test_rolling_exp_keep_attrs(ds): + + attrs_global = {"attrs": "global"} + attrs_z1 = {"attr": "z1"} + + ds.attrs = attrs_global + ds.z1.attrs = attrs_z1 + + # attrs are kept per default + result = ds.rolling_exp(time=10).mean() + assert result.attrs == attrs_global + assert result.z1.attrs == attrs_z1 + + # discard attrs + result = ds.rolling_exp(time=10).mean(keep_attrs=False) + assert result.attrs == {} + assert result.z1.attrs == {} + + # test discard attrs using global option + with set_options(keep_attrs=False): + result = ds.rolling_exp(time=10).mean() + assert result.attrs == {} + assert result.z1.attrs == {} + + # keyword takes precedence over global option + with set_options(keep_attrs=False): + result = ds.rolling_exp(time=10).mean(keep_attrs=True) + assert result.attrs == attrs_global + assert result.z1.attrs == attrs_z1 + + with set_options(keep_attrs=True): + result = ds.rolling_exp(time=10).mean(keep_attrs=False) + assert result.attrs == {} + assert result.z1.attrs == {} + + @pytest.mark.parametrize("center", (True, False)) @pytest.mark.parametrize("min_periods", (None, 1, 2, 3)) @pytest.mark.parametrize("window", (1, 2, 3, 4))