rolling_exp: keep_attrs and typing #4592

Merged: 3 commits, Nov 20, 2020
4 changes: 4 additions & 0 deletions doc/whats-new.rst
@@ -45,6 +45,10 @@ New Features
By `Michal Baumgartner <https://github.com/m1so>`_.
- :py:meth:`Dataset.weighted` and :py:meth:`DataArray.weighted` are now executing value checks lazily if weights are provided as dask arrays (:issue:`4541`, :pull:`4559`).
By `Julius Busecke <https://github.com/jbusecke>`_.
- Added the ``keep_attrs`` keyword to :py:meth:`~xarray.DataArray.rolling_exp.mean` and :py:meth:`~xarray.Dataset.rolling_exp.mean`.
The attributes are now kept by default.
By `Mathias Hauser <https://github.com/mathause>`_.


Bug fixes
~~~~~~~~~
36 changes: 31 additions & 5 deletions xarray/core/rolling_exp.py
@@ -1,8 +1,17 @@
from typing import TYPE_CHECKING, Generic, Hashable, Mapping, Optional, TypeVar

import numpy as np

from .options import _get_keep_attrs
from .pdcompat import count_not_none
from .pycompat import is_duck_dask_array

if TYPE_CHECKING:
from .dataarray import DataArray # noqa: F401
from .dataset import Dataset # noqa: F401

T_DSorDA = TypeVar("T_DSorDA", "DataArray", "Dataset")
Collaborator:
Is this different from DataWithCoords?

Collaborator (author):
from .common import DataWithCoords
T_DataWithCoords = TypeVar("T_DataWithCoords", bound=DataWithCoords)

works as well. A bit to my surprise, as DataWithCoords does not implement reduce. But mypy does not throw an error here:

import xarray as xr
from xarray.core.common import DataWithCoords

class A:
    pass

hasattr(xr.core.common.DataWithCoords, "reduce")  # -> False

def test(x: "A"):
    x.reduce()  # mypy errors

def test2(x: "DataWithCoords"):
    x.reduce()  # mypy does not error

not sure why not...


(Also I am not entirely sure what the difference is between

T_Parent = TypeVar("T_Parent", bound=Parent)
T_Children = TypeVar("T_Children", ChildA, ChildB)

)

Collaborator:
That is odd it doesn't throw an error...

On the second case, I think (not completely sure) that T_Parent is generic with an "upper bound" of Parent, so any function that takes a T_Parent will return the same type.

Whereas a function that takes T_Children could return either of ChildA or ChildB regardless of what's passed in.
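That distinction can be illustrated with a small sketch (the Parent/Child classes are hypothetical, not from xarray):

```python
from typing import TypeVar

class Parent: ...
class ChildA(Parent): ...
class ChildB(Parent): ...

# bound: accepts Parent or any subclass; the checker preserves the exact type
T_Parent = TypeVar("T_Parent", bound=Parent)

# constrained: accepts exactly ChildA or ChildB, nothing else
T_Children = TypeVar("T_Children", ChildA, ChildB)

def keep(x: T_Parent) -> T_Parent:
    # a checker infers keep(ChildA()) -> ChildA, not just Parent
    return x
```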

Collaborator:
And your call on whether we use DataWithCoords or the TypeVar, but I think if we do the latter we could find a clearer name :)

I would marginally vote to go with DataWithCoords, given we do that elsewhere, but I'm totally open to trying another approach.

Collaborator (author):
I think both would need to be TypeVars. I'll stick to T_DSorDA, as DataWithCoords does not implement reduce. T_DSorDA is already used in some places (map_blocks), e.g.:

T_DSorDA = TypeVar("T_DSorDA", DataArray, Dataset)

Collaborator:
Ah I didn't know that, thanks. I wasn't a fan of the name, but we can change them together another time if others agree.



def _get_alpha(com=None, span=None, halflife=None, alpha=None):
# pandas defines in terms of com (converting to alpha in the algo)
@@ -56,7 +65,7 @@ def _get_center_of_mass(comass, span, halflife, alpha):
return float(comass)


class RollingExp:
class RollingExp(Generic[T_DSorDA]):
Collaborator:
Great idea with Generic

Collaborator (author):
Yes, that took quite a while to figure out, but it doesn't work without it: otherwise the type of obj does not bind to the method's return type. That should come in handy for some other classes.
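The binding described here can be sketched with a minimal generic wrapper (illustrative names, not xarray code):

```python
from typing import Generic, TypeVar

T = TypeVar("T")

class Wrapper(Generic[T]):
    def __init__(self, obj: T):
        # because the class is Generic[T], T binds at instantiation...
        self.obj: T = obj

    def unwrap(self) -> T:
        # ...so the checker knows unwrap() returns the concrete type of obj
        return self.obj

w = Wrapper([1, 2, 3])  # inferred as Wrapper[list[int]]; unwrap() -> list[int]
```

Without `Generic[T]`, `unwrap()` would not be tied to the type of the stored object, which is exactly the problem the comment describes for `self.obj` in `RollingExp`.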

"""
Exponentially-weighted moving window object.
Similar to EWM in pandas
@@ -78,16 +87,28 @@ class RollingExp:
RollingExp : type of input argument
"""

def __init__(self, obj, windows, window_type="span"):
self.obj = obj
def __init__(
self,
obj: T_DSorDA,
windows: Mapping[Hashable, int],
window_type: str = "span",
):
self.obj: T_DSorDA = obj
dim, window = next(iter(windows.items()))
self.dim = dim
self.alpha = _get_alpha(**{window_type: window})

def mean(self):
def mean(self, keep_attrs: Optional[bool] = None) -> T_DSorDA:
"""
Exponentially weighted moving average

Parameters
----------
keep_attrs : bool, default: None
If True, the attributes (``attrs``) will be copied from the original
object to the new one. If False, the new object will be returned
without attributes. If None, the global default is used.

Examples
--------
>>> da = xr.DataArray([1, 1, 2, 2, 2], dims="x")
@@ -97,4 +118,9 @@ def mean(self):
Dimensions without coordinates: x
"""

return self.obj.reduce(move_exp_nanmean, dim=self.dim, alpha=self.alpha)
if keep_attrs is None:
keep_attrs = _get_keep_attrs(default=True)

return self.obj.reduce(
move_exp_nanmean, dim=self.dim, alpha=self.alpha, keep_attrs=keep_attrs
)
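The resolution order the new code implements (explicit keyword first, then the global option) can be sketched standalone; `resolve_keep_attrs` is a hypothetical helper mirroring the check in `mean()`, not part of xarray:

```python
def resolve_keep_attrs(keep_attrs=None, global_default=True):
    """Mirror the precedence used in mean(): an explicit True/False wins,
    and None falls back to the (global) default."""
    if keep_attrs is None:
        keep_attrs = global_default
    return keep_attrs

# explicit keyword takes precedence over the global option
resolve_keep_attrs(True, global_default=False)   # -> True
resolve_keep_attrs(None, global_default=False)   # -> False
resolve_keep_attrs(None)                         # -> True (attrs kept by default)
```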
29 changes: 29 additions & 0 deletions xarray/tests/test_dataarray.py
@@ -6927,6 +6927,35 @@ def test_rolling_exp(da, dim, window_type, window):
assert_allclose(expected.variable, result.variable)


@requires_numbagg
def test_rolling_exp_keep_attrs(da):

attrs = {"attrs": "da"}
da.attrs = attrs

# attrs are kept by default
result = da.rolling_exp(time=10).mean()
assert result.attrs == attrs

# discard attrs
result = da.rolling_exp(time=10).mean(keep_attrs=False)
assert result.attrs == {}

# test discard attrs using global option
with set_options(keep_attrs=False):
result = da.rolling_exp(time=10).mean()
assert result.attrs == {}

# keyword takes precedence over global option
with set_options(keep_attrs=False):
result = da.rolling_exp(time=10).mean(keep_attrs=True)
assert result.attrs == attrs

with set_options(keep_attrs=True):
result = da.rolling_exp(time=10).mean(keep_attrs=False)
assert result.attrs == {}


def test_no_dict():
d = DataArray()
with pytest.raises(AttributeError):
37 changes: 37 additions & 0 deletions xarray/tests/test_dataset.py
@@ -6145,6 +6145,43 @@ def test_rolling_exp(ds):
assert isinstance(result, Dataset)


@requires_numbagg
def test_rolling_exp_keep_attrs(ds):

attrs_global = {"attrs": "global"}
attrs_z1 = {"attr": "z1"}

ds.attrs = attrs_global
ds.z1.attrs = attrs_z1

# attrs are kept by default
result = ds.rolling_exp(time=10).mean()
assert result.attrs == attrs_global
assert result.z1.attrs == attrs_z1

# discard attrs
result = ds.rolling_exp(time=10).mean(keep_attrs=False)
assert result.attrs == {}
assert result.z1.attrs == {}

# test discard attrs using global option
with set_options(keep_attrs=False):
result = ds.rolling_exp(time=10).mean()
assert result.attrs == {}
assert result.z1.attrs == {}

# keyword takes precedence over global option
with set_options(keep_attrs=False):
result = ds.rolling_exp(time=10).mean(keep_attrs=True)
assert result.attrs == attrs_global
assert result.z1.attrs == attrs_z1

with set_options(keep_attrs=True):
result = ds.rolling_exp(time=10).mean(keep_attrs=False)
assert result.attrs == {}
assert result.z1.attrs == {}


@pytest.mark.parametrize("center", (True, False))
@pytest.mark.parametrize("min_periods", (None, 1, 2, 3))
@pytest.mark.parametrize("window", (1, 2, 3, 4))