From 708204370ef168f35f250ca75e666246ae1420f6 Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Sun, 1 May 2022 17:56:31 -0700 Subject: [PATCH 01/24] Run mypy tests (but always pass) (#6557) So we can at least see the result --- .github/workflows/ci-additional.yaml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci-additional.yaml b/.github/workflows/ci-additional.yaml index a5ddceebf11..9a10403d44b 100644 --- a/.github/workflows/ci-additional.yaml +++ b/.github/workflows/ci-additional.yaml @@ -105,10 +105,12 @@ jobs: - name: Install mypy run: | python -m pip install mypy - python -m mypy --install-types --non-interactive + # Temporarily overriding to be true due to https://github.com/pydata/xarray/issues/6551 + # python -m mypy --install-types --non-interactive - name: Run mypy - run: python -m mypy + run: | + python -m mypy --install-types --non-interactive || true min-version-policy: name: Minimum Version Policy From cf8f1d6fc306e3bcef27d5acaf0d7c9989642f52 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 2 May 2022 12:55:11 -0600 Subject: [PATCH 02/24] [pre-commit.ci] pre-commit autoupdate (#6562) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/pre-commit/mirrors-mypy: v0.942 → v0.950](https://github.com/pre-commit/mirrors-mypy/compare/v0.942...v0.950) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d4ce6ebc523..1b11e285f1e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -46,7 +46,7 @@ repos: # - id: velin # args: ["--write", "--compact"] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.942 + rev: v0.950 hooks: - id: 
mypy # Copied from setup.cfg From 126051f2bf2ddb7926a7da11b047b852d5ca6b87 Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Tue, 3 May 2022 00:02:10 -0700 Subject: [PATCH 03/24] Run mypy tests (but always pass) (#6568) * Run mypy tests (but always pass) So we can at least see the result --- .github/workflows/ci-additional.yaml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/ci-additional.yaml b/.github/workflows/ci-additional.yaml index 9a10403d44b..e2685f445d7 100644 --- a/.github/workflows/ci-additional.yaml +++ b/.github/workflows/ci-additional.yaml @@ -72,8 +72,7 @@ jobs: runs-on: "ubuntu-latest" needs: detect-ci-trigger # temporarily skipping due to https://github.com/pydata/xarray/issues/6551 - # if: needs.detect-ci-trigger.outputs.triggered == 'false' - if: false + if: needs.detect-ci-trigger.outputs.triggered == 'false' defaults: run: shell: bash -l {0} From 39bda44ad76432afa63cebbfca8a57cf13a2860b Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Thu, 5 May 2022 13:15:40 -0600 Subject: [PATCH 04/24] Bump min deps (#6559) Co-authored-by: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Co-authored-by: Tom Nicholas <35968931+TomNicholas@users.noreply.github.com> Co-authored-by: Anderson Banihirwe --- ci/requirements/min-all-deps.yml | 35 +++---- doc/getting-started-guide/installing.rst | 4 +- doc/whats-new.rst | 20 ++++ setup.cfg | 4 +- xarray/core/dask_array_compat.py | 127 +---------------------- 5 files changed, 43 insertions(+), 147 deletions(-) diff --git a/ci/requirements/min-all-deps.yml b/ci/requirements/min-all-deps.yml index 76e2b28093d..ecabde06622 100644 --- a/ci/requirements/min-all-deps.yml +++ b/ci/requirements/min-all-deps.yml @@ -10,46 +10,45 @@ dependencies: - python=3.8 - boto3=1.13 - bottleneck=1.3 - # cartopy 0.18 conflicts with pynio - - cartopy=0.17 + - cartopy=0.19 - cdms2=3.1 - cfgrib=0.9 - - cftime=1.2 + - cftime=1.4 - coveralls - - 
dask-core=2.30 - - distributed=2.30 - - h5netcdf=0.8 - - h5py=2.10 - # hdf5 1.12 conflicts with h5py=2.10 + - dask-core=2021.04 + - distributed=2021.04 + - h5netcdf=0.11 + - h5py=3.1 + # hdf5 1.12 conflicts with h5py=3.1 - hdf5=1.10 - hypothesis - iris=2.4 - lxml=4.6 # Optional dep of pydap - - matplotlib-base=3.3 + - matplotlib-base=3.4 - nc-time-axis=1.2 # netcdf follows a 1.major.minor[.patch] convention # (see https://github.com/Unidata/netcdf4-python/issues/1090) # bumping the netCDF4 version is currently blocked by #4491 - netcdf4=1.5.3 - - numba=0.51 - - numpy=1.18 + - numba=0.53 + - numpy=1.19 - packaging=20.0 - - pandas=1.1 - - pint=0.16 + - pandas=1.2 + - pint=0.17 - pip - pseudonetcdf=3.1 - pydap=3.2 - - pynio=1.5 + # - pynio=1.5.5 - pytest - pytest-cov - pytest-env - pytest-xdist - - rasterio=1.1 - - scipy=1.5 + - rasterio=1.2 + - scipy=1.6 - seaborn=0.11 - - sparse=0.11 + - sparse=0.12 - toolz=0.11 - typing_extensions=3.7 - - zarr=2.5 + - zarr=2.8 - pip: - numbagg==0.1 diff --git a/doc/getting-started-guide/installing.rst b/doc/getting-started-guide/installing.rst index 0668853946f..faa0fba5dd3 100644 --- a/doc/getting-started-guide/installing.rst +++ b/doc/getting-started-guide/installing.rst @@ -7,9 +7,9 @@ Required dependencies --------------------- - Python (3.8 or later) -- `numpy `__ (1.18 or later) +- `numpy `__ (1.19 or later) - `packaging `__ (20.0 or later) -- `pandas `__ (1.1 or later) +- `pandas `__ (1.2 or later) .. _optional-dependencies: diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 4882402073c..0d8ab5a8b40 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -45,6 +45,26 @@ New Features Breaking changes ~~~~~~~~~~~~~~~~ +- PyNIO support is now untested. 
The minimum versions of some dependencies were changed: + + =============== ===== ==== + Package Old New + =============== ===== ==== + cftime 1.2 1.4 + dask 2.30 2021.4 + distributed 2.30 2021.4 + h5netcdf 0.8 0.11 + matplotlib-base 3.3 3.4 + numba 0.51 0.53 + numpy 1.18 1.19 + pandas 1.1 1.2 + pint 0.16 0.17 + rasterio 1.1 1.2 + scipy 1.5 1.6 + sparse 0.11 0.12 + zarr 2.5 2.8 + =============== ===== ==== + - The Dataset and DataArray ``rename*`` methods do not implicitly add or drop indexes. (:pull:`5692`). By `Benoît Bovy `_. diff --git a/setup.cfg b/setup.cfg index 05b202810b4..6a0a06d2367 100644 --- a/setup.cfg +++ b/setup.cfg @@ -75,8 +75,8 @@ zip_safe = False # https://mypy.readthedocs.io/en/latest/installed_packages.htm include_package_data = True python_requires = >=3.8 install_requires = - numpy >= 1.18 - pandas >= 1.1 + numpy >= 1.19 + pandas >= 1.2 packaging >= 20.0 [options.extras_require] diff --git a/xarray/core/dask_array_compat.py b/xarray/core/dask_array_compat.py index 0e0229cc3ca..4d73867a283 100644 --- a/xarray/core/dask_array_compat.py +++ b/xarray/core/dask_array_compat.py @@ -1,9 +1,6 @@ import warnings import numpy as np -from packaging.version import Version - -from .pycompat import dask_version try: import dask.array as da @@ -57,127 +54,7 @@ def pad(array, pad_width, mode="constant", **kwargs): return padded -if dask_version > Version("2.30.0"): - ensure_minimum_chunksize = da.overlap.ensure_minimum_chunksize -else: - - # copied from dask - def ensure_minimum_chunksize(size, chunks): - """Determine new chunks to ensure that every chunk >= size - - Parameters - ---------- - size : int - The maximum size of any chunk. - chunks : tuple - Chunks along one axis, e.g. 
``(3, 3, 2)`` - - Examples - -------- - >>> ensure_minimum_chunksize(10, (20, 20, 1)) - (20, 11, 10) - >>> ensure_minimum_chunksize(3, (1, 1, 3)) - (5,) - - See Also - -------- - overlap - """ - if size <= min(chunks): - return chunks - - # add too-small chunks to chunks before them - output = [] - new = 0 - for c in chunks: - if c < size: - if new > size + (size - c): - output.append(new - (size - c)) - new = size - else: - new += c - if new >= size: - output.append(new) - new = 0 - if c >= size: - new += c - if new >= size: - output.append(new) - elif len(output) >= 1: - output[-1] += new - else: - raise ValueError( - f"The overlapping depth {size} is larger than your " - f"array {sum(chunks)}." - ) - - return tuple(output) - - -if dask_version > Version("2021.03.0"): +if da is not None: sliding_window_view = da.lib.stride_tricks.sliding_window_view else: - - def sliding_window_view(x, window_shape, axis=None): - from dask.array.overlap import map_overlap - from numpy.core.numeric import normalize_axis_tuple - - from .npcompat import sliding_window_view as _np_sliding_window_view - - window_shape = ( - tuple(window_shape) if np.iterable(window_shape) else (window_shape,) - ) - - window_shape_array = np.array(window_shape) - if np.any(window_shape_array <= 0): - raise ValueError("`window_shape` must contain positive values") - - if axis is None: - axis = tuple(range(x.ndim)) - if len(window_shape) != len(axis): - raise ValueError( - f"Since axis is `None`, must provide " - f"window_shape for all dimensions of `x`; " - f"got {len(window_shape)} window_shape elements " - f"and `x.ndim` is {x.ndim}." - ) - else: - axis = normalize_axis_tuple(axis, x.ndim, allow_duplicate=True) - if len(window_shape) != len(axis): - raise ValueError( - f"Must provide matching length window_shape and " - f"axis; got {len(window_shape)} window_shape " - f"elements and {len(axis)} axes elements." 
- ) - - depths = [0] * x.ndim - for ax, window in zip(axis, window_shape): - depths[ax] += window - 1 - - # Ensure that each chunk is big enough to leave at least a size-1 chunk - # after windowing (this is only really necessary for the last chunk). - safe_chunks = tuple( - ensure_minimum_chunksize(d + 1, c) for d, c in zip(depths, x.chunks) - ) - x = x.rechunk(safe_chunks) - - # result.shape = x_shape_trimmed + window_shape, - # where x_shape_trimmed is x.shape with every entry - # reduced by one less than the corresponding window size. - # trim chunks to match x_shape_trimmed - newchunks = tuple( - c[:-1] + (c[-1] - d,) for d, c in zip(depths, x.chunks) - ) + tuple((window,) for window in window_shape) - - kwargs = dict( - depth=tuple((0, d) for d in depths), # Overlap on +ve side only - boundary="none", - meta=x._meta, - new_axis=range(x.ndim, x.ndim + len(axis)), - chunks=newchunks, - trim=False, - window_shape=window_shape, - axis=axis, - ) - - return map_overlap(_np_sliding_window_view, x, align_arrays=False, **kwargs) + sliding_window_view = None From 6fbeb13105b419cb0a6646909df358d535e09faf Mon Sep 17 00:00:00 2001 From: Mick <43316012+headtr1ck@users.noreply.github.com> Date: Thu, 5 May 2022 21:15:58 +0200 Subject: [PATCH 05/24] polyval: Use Horner's algorithm + support chunked inputs (#6548) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Co-authored-by: Deepak Cherian --- asv_bench/benchmarks/polyfit.py | 38 ++++++++++ doc/whats-new.rst | 7 ++ xarray/core/computation.py | 101 +++++++++++++++++++++----- xarray/tests/test_computation.py | 119 +++++++++++++++++++++++-------- 4 files changed, 220 insertions(+), 45 deletions(-) create mode 100644 asv_bench/benchmarks/polyfit.py diff --git a/asv_bench/benchmarks/polyfit.py b/asv_bench/benchmarks/polyfit.py new file mode 100644 index 00000000000..429ffa19baa --- /dev/null +++ 
b/asv_bench/benchmarks/polyfit.py @@ -0,0 +1,38 @@ +import numpy as np + +import xarray as xr + +from . import parameterized, randn, requires_dask + +NDEGS = (2, 5, 20) +NX = (10**2, 10**6) + + +class Polyval: + def setup(self, *args, **kwargs): + self.xs = {nx: xr.DataArray(randn((nx,)), dims="x", name="x") for nx in NX} + self.coeffs = { + ndeg: xr.DataArray( + randn((ndeg,)), dims="degree", coords={"degree": np.arange(ndeg)} + ) + for ndeg in NDEGS + } + + @parameterized(["nx", "ndeg"], [NX, NDEGS]) + def time_polyval(self, nx, ndeg): + x = self.xs[nx] + c = self.coeffs[ndeg] + xr.polyval(x, c).compute() + + @parameterized(["nx", "ndeg"], [NX, NDEGS]) + def peakmem_polyval(self, nx, ndeg): + x = self.xs[nx] + c = self.coeffs[ndeg] + xr.polyval(x, c).compute() + + +class PolyvalDask(Polyval): + def setup(self, *args, **kwargs): + requires_dask() + super().setup(*args, **kwargs) + self.xs = {k: v.chunk({"x": 10000}) for k, v in self.xs.items()} diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 0d8ab5a8b40..fc5135dc598 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -41,6 +41,9 @@ New Features - Allow passing chunks in ``**kwargs`` form to :py:meth:`Dataset.chunk`, :py:meth:`DataArray.chunk`, and :py:meth:`Variable.chunk`. (:pull:`6471`) By `Tom Nicholas `_. +- :py:meth:`xr.polyval` now supports :py:class:`Dataset` and :py:class:`DataArray` args of any shape, + is faster and requires less memory. (:pull:`6548`) + By `Michael Niklas `_. Breaking changes ~~~~~~~~~~~~~~~~ @@ -74,6 +77,10 @@ Breaking changes - Xarray's ufuncs have been removed, now that they can be replaced by numpy's ufuncs in all supported versions of numpy. By `Maximilian Roos `_. +- :py:meth:`xr.polyval` now uses the ``coord`` argument directly instead of its index coordinate. + (:pull:`6548`) + By `Michael Niklas `_. 
+ Deprecations ~~~~~~~~~~~~ diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 1834622d96e..1a32cda512c 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -17,12 +17,15 @@ Iterable, Mapping, Sequence, + overload, ) import numpy as np from . import dtypes, duck_array_ops, utils from .alignment import align, deep_align +from .common import zeros_like +from .duck_array_ops import datetime_to_numeric from .indexes import Index, filter_indexes_from_coords from .merge import merge_attrs, merge_coordinates_without_align from .options import OPTIONS, _get_keep_attrs @@ -1843,36 +1846,100 @@ def where(cond, x, y, keep_attrs=None): ) -def polyval(coord, coeffs, degree_dim="degree"): +@overload +def polyval(coord: DataArray, coeffs: DataArray, degree_dim: Hashable) -> DataArray: + ... + + +@overload +def polyval(coord: T_Xarray, coeffs: Dataset, degree_dim: Hashable) -> Dataset: + ... + + +@overload +def polyval(coord: Dataset, coeffs: T_Xarray, degree_dim: Hashable) -> Dataset: + ... + + +def polyval( + coord: T_Xarray, coeffs: T_Xarray, degree_dim: Hashable = "degree" +) -> T_Xarray: """Evaluate a polynomial at specific values Parameters ---------- - coord : DataArray - The 1D coordinate along which to evaluate the polynomial. - coeffs : DataArray - Coefficients of the polynomials. - degree_dim : str, default: "degree" + coord : DataArray or Dataset + Values at which to evaluate the polynomial. + coeffs : DataArray or Dataset + Coefficients of the polynomial. + degree_dim : Hashable, default: "degree" Name of the polynomial degree dimension in `coeffs`. + Returns + ------- + DataArray or Dataset + Evaluated polynomial. 
+ See Also -------- xarray.DataArray.polyfit - numpy.polyval + numpy.polynomial.polynomial.polyval """ - from .dataarray import DataArray - from .missing import get_clean_interp_index - x = get_clean_interp_index(coord, coord.name, strict=False) + if degree_dim not in coeffs._indexes: + raise ValueError( + f"Dimension `{degree_dim}` should be a coordinate variable with labels." + ) + if not np.issubdtype(coeffs[degree_dim].dtype, int): + raise ValueError( + f"Dimension `{degree_dim}` should be of integer dtype. Received {coeffs[degree_dim].dtype} instead." + ) + max_deg = coeffs[degree_dim].max().item() + coeffs = coeffs.reindex( + {degree_dim: np.arange(max_deg + 1)}, fill_value=0, copy=False + ) + coord = _ensure_numeric(coord) + + # using Horner's method + # https://en.wikipedia.org/wiki/Horner%27s_method + res = coeffs.isel({degree_dim: max_deg}, drop=True) + zeros_like(coord) + for deg in range(max_deg - 1, -1, -1): + res *= coord + res += coeffs.isel({degree_dim: deg}, drop=True) - deg_coord = coeffs[degree_dim] + return res - lhs = DataArray( - np.vander(x, int(deg_coord.max()) + 1), - dims=(coord.name, degree_dim), - coords={coord.name: coord, degree_dim: np.arange(deg_coord.max() + 1)[::-1]}, - ) - return (lhs * coeffs).sum(degree_dim) + +def _ensure_numeric(data: T_Xarray) -> T_Xarray: + """Converts all datetime64 variables to float64 + + Parameters + ---------- + data : DataArray or Dataset + Variables with possible datetime dtypes. + + Returns + ------- + DataArray or Dataset + Variables with datetime64 dtypes converted to float64. 
+ """ + from .dataset import Dataset + + def to_floatable(x: DataArray) -> DataArray: + if x.dtype.kind in "mM": + return x.copy( + data=datetime_to_numeric( + x.data, + offset=np.datetime64("1970-01-01"), + datetime_unit="ns", + ), + ) + return x + + if isinstance(data, Dataset): + return data.map(to_floatable) + else: + return to_floatable(data) def _calc_idxminmax( diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index 7a397428ba3..127fdc5404f 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -1933,37 +1933,100 @@ def test_where_attrs() -> None: assert actual.attrs == {} -@pytest.mark.parametrize("use_dask", [True, False]) -@pytest.mark.parametrize("use_datetime", [True, False]) -def test_polyval(use_dask, use_datetime) -> None: - if use_dask and not has_dask: - pytest.skip("requires dask") - - if use_datetime: - xcoord = xr.DataArray( - pd.date_range("2000-01-01", freq="D", periods=10), dims=("x",), name="x" - ) - x = xr.core.missing.get_clean_interp_index(xcoord, "x") - else: - x = np.arange(10) - xcoord = xr.DataArray(x, dims=("x",), name="x") - - da = xr.DataArray( - np.stack((1.0 + x + 2.0 * x**2, 1.0 + 2.0 * x + 3.0 * x**2)), - dims=("d", "x"), - coords={"x": xcoord, "d": [0, 1]}, - ) - coeffs = xr.DataArray( - [[2, 1, 1], [3, 2, 1]], - dims=("d", "degree"), - coords={"d": [0, 1], "degree": [2, 1, 0]}, - ) +@pytest.mark.parametrize("use_dask", [False, True]) +@pytest.mark.parametrize( + ["x", "coeffs", "expected"], + [ + pytest.param( + xr.DataArray([1, 2, 3], dims="x"), + xr.DataArray([2, 3, 4], dims="degree", coords={"degree": [0, 1, 2]}), + xr.DataArray([9, 2 + 6 + 16, 2 + 9 + 36], dims="x"), + id="simple", + ), + pytest.param( + xr.DataArray([1, 2, 3], dims="x"), + xr.DataArray( + [[0, 1], [0, 1]], dims=("y", "degree"), coords={"degree": [0, 1]} + ), + xr.DataArray([[1, 2, 3], [1, 2, 3]], dims=("y", "x")), + id="broadcast-x", + ), + pytest.param( + xr.DataArray([1, 2, 3], dims="x"), 
+ xr.DataArray( + [[0, 1], [1, 0], [1, 1]], + dims=("x", "degree"), + coords={"degree": [0, 1]}, + ), + xr.DataArray([1, 1, 1 + 3], dims="x"), + id="shared-dim", + ), + pytest.param( + xr.DataArray([1, 2, 3], dims="x"), + xr.DataArray([1, 0, 0], dims="degree", coords={"degree": [2, 1, 0]}), + xr.DataArray([1, 2**2, 3**2], dims="x"), + id="reordered-index", + ), + pytest.param( + xr.DataArray([1, 2, 3], dims="x"), + xr.DataArray([5], dims="degree", coords={"degree": [3]}), + xr.DataArray([5, 5 * 2**3, 5 * 3**3], dims="x"), + id="sparse-index", + ), + pytest.param( + xr.DataArray([1, 2, 3], dims="x"), + xr.Dataset( + {"a": ("degree", [0, 1]), "b": ("degree", [1, 0])}, + coords={"degree": [0, 1]}, + ), + xr.Dataset({"a": ("x", [1, 2, 3]), "b": ("x", [1, 1, 1])}), + id="array-dataset", + ), + pytest.param( + xr.Dataset({"a": ("x", [1, 2, 3]), "b": ("x", [2, 3, 4])}), + xr.DataArray([1, 1], dims="degree", coords={"degree": [0, 1]}), + xr.Dataset({"a": ("x", [2, 3, 4]), "b": ("x", [3, 4, 5])}), + id="dataset-array", + ), + pytest.param( + xr.Dataset({"a": ("x", [1, 2, 3]), "b": ("y", [2, 3, 4])}), + xr.Dataset( + {"a": ("degree", [0, 1]), "b": ("degree", [1, 1])}, + coords={"degree": [0, 1]}, + ), + xr.Dataset({"a": ("x", [1, 2, 3]), "b": ("y", [3, 4, 5])}), + id="dataset-dataset", + ), + pytest.param( + xr.DataArray(pd.date_range("1970-01-01", freq="s", periods=3), dims="x"), + xr.DataArray([0, 1], dims="degree", coords={"degree": [0, 1]}), + xr.DataArray( + [0, 1e9, 2e9], + dims="x", + coords={"x": pd.date_range("1970-01-01", freq="s", periods=3)}, + ), + id="datetime", + ), + ], +) +def test_polyval(use_dask, x, coeffs, expected) -> None: if use_dask: - coeffs = coeffs.chunk({"d": 2}) + if not has_dask: + pytest.skip("requires dask") + coeffs = coeffs.chunk({"degree": 2}) + x = x.chunk({"x": 2}) + with raise_if_dask_computes(): + actual = xr.polyval(x, coeffs) + xr.testing.assert_allclose(actual, expected) - da_pv = xr.polyval(da.x, coeffs) - 
xr.testing.assert_allclose(da, da_pv.T) +def test_polyval_degree_dim_checks(): + x = (xr.DataArray([1, 2, 3], dims="x"),) + coeffs = xr.DataArray([2, 3, 4], dims="degree", coords={"degree": [0, 1, 2]}) + with pytest.raises(ValueError): + xr.polyval(x, coeffs.drop_vars("degree")) + with pytest.raises(ValueError): + xr.polyval(x, coeffs.assign_coords(degree=coeffs.degree.astype(float))) @pytest.mark.parametrize("use_dask", [False, True]) From c60f9b03f9939ae7b3768821fdb26c811e302102 Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Sun, 8 May 2022 16:32:23 -0700 Subject: [PATCH 06/24] Fix mypy issues & reenable in tests (#6581) * Run mypy tests (but always pass) So we can at least see the result * Fix mypy --- .github/workflows/ci-additional.yaml | 4 +--- xarray/backends/api.py | 2 +- xarray/backends/locks.py | 4 ++-- xarray/core/_typed_ops.pyi | 2 +- xarray/core/computation.py | 27 ++++++++++++++------------- xarray/core/dask_array_compat.py | 2 +- xarray/core/dataarray.py | 7 +++++-- xarray/core/dataset.py | 5 +++-- xarray/core/duck_array_ops.py | 2 +- xarray/core/nanops.py | 2 +- xarray/core/types.py | 2 +- xarray/tests/test_computation.py | 7 +++++-- xarray/tests/test_testing.py | 2 +- xarray/util/generate_ops.py | 2 +- 14 files changed, 38 insertions(+), 32 deletions(-) diff --git a/.github/workflows/ci-additional.yaml b/.github/workflows/ci-additional.yaml index e2685f445d7..ff3e8ab7e63 100644 --- a/.github/workflows/ci-additional.yaml +++ b/.github/workflows/ci-additional.yaml @@ -105,11 +105,9 @@ jobs: run: | python -m pip install mypy - # Temporarily overriding to be true due to https://github.com/pydata/xarray/issues/6551 - # python -m mypy --install-types --non-interactive - name: Run mypy run: | - python -m mypy --install-types --non-interactive || true + python -m mypy --install-types --non-interactive min-version-policy: name: Minimum Version Policy diff --git a/xarray/backends/api.py 
b/xarray/backends/api.py index 9967b0a08c0..4962a4a9c02 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -35,7 +35,7 @@ try: from dask.delayed import Delayed except ImportError: - Delayed = None + Delayed = None # type: ignore DATAARRAY_NAME = "__xarray_dataarray_name__" diff --git a/xarray/backends/locks.py b/xarray/backends/locks.py index 59417336f5f..1cc93779843 100644 --- a/xarray/backends/locks.py +++ b/xarray/backends/locks.py @@ -7,12 +7,12 @@ from dask.utils import SerializableLock except ImportError: # no need to worry about serializing the lock - SerializableLock = threading.Lock + SerializableLock = threading.Lock # type: ignore try: from dask.distributed import Lock as DistributedLock except ImportError: - DistributedLock = None + DistributedLock = None # type: ignore # Locks used by multiple backends. diff --git a/xarray/core/_typed_ops.pyi b/xarray/core/_typed_ops.pyi index e23b5848ff7..e5b3c9112c7 100644 --- a/xarray/core/_typed_ops.pyi +++ b/xarray/core/_typed_ops.pyi @@ -21,7 +21,7 @@ from .variable import Variable try: from dask.array import Array as DaskArray except ImportError: - DaskArray = np.ndarray + DaskArray = np.ndarray # type: ignore # DatasetOpsMixin etc. are parent classes of Dataset etc. # Because of https://github.com/pydata/xarray/issues/5755, we redefine these. Generally diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 1a32cda512c..6f1e08bf84f 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -17,7 +17,6 @@ Iterable, Mapping, Sequence, - overload, ) import numpy as np @@ -1846,24 +1845,26 @@ def where(cond, x, y, keep_attrs=None): ) -@overload -def polyval(coord: DataArray, coeffs: DataArray, degree_dim: Hashable) -> DataArray: - ... +# These overloads seem not to work — mypy says it can't find a matching overload for +# `DataArray` & `DataArray`, despite that being in the first overload. 
Would be nice to +# have overloaded functions rather than just `T_Xarray` for everything. +# @overload +# def polyval(coord: DataArray, coeffs: DataArray, degree_dim: Hashable) -> DataArray: +# ... -@overload -def polyval(coord: T_Xarray, coeffs: Dataset, degree_dim: Hashable) -> Dataset: - ... +# @overload +# def polyval(coord: T_Xarray, coeffs: Dataset, degree_dim: Hashable) -> Dataset: +# ... -@overload -def polyval(coord: Dataset, coeffs: T_Xarray, degree_dim: Hashable) -> Dataset: - ... + +# @overload +# def polyval(coord: Dataset, coeffs: T_Xarray, degree_dim: Hashable) -> Dataset: +# ... -def polyval( - coord: T_Xarray, coeffs: T_Xarray, degree_dim: Hashable = "degree" -) -> T_Xarray: +def polyval(coord: T_Xarray, coeffs: T_Xarray, degree_dim="degree") -> T_Xarray: """Evaluate a polynomial at specific values Parameters diff --git a/xarray/core/dask_array_compat.py b/xarray/core/dask_array_compat.py index 4d73867a283..e114c238b72 100644 --- a/xarray/core/dask_array_compat.py +++ b/xarray/core/dask_array_compat.py @@ -5,7 +5,7 @@ try: import dask.array as da except ImportError: - da = None + da = None # type: ignore def _validate_pad_output_shape(input_shape, pad_width, output_shape): diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index d15cbd00c0d..fc3cbef16f8 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -17,6 +17,7 @@ import numpy as np import pandas as pd +from ..backends.common import AbstractDataStore, ArrayWriter from ..coding.calendar_ops import convert_calendar, interp_calendar from ..coding.cftimeindex import CFTimeIndex from ..plot.plot import _PlotMethods @@ -67,7 +68,7 @@ try: from dask.delayed import Delayed except ImportError: - Delayed = None + Delayed = None # type: ignore try: from cdms2 import Variable as cdms2_Variable except ImportError: @@ -2875,7 +2876,9 @@ def to_masked_array(self, copy: bool = True) -> np.ma.MaskedArray: isnull = pd.isnull(values) return np.ma.MaskedArray(data=values, 
mask=isnull, copy=copy) - def to_netcdf(self, *args, **kwargs) -> bytes | Delayed | None: + def to_netcdf( + self, *args, **kwargs + ) -> tuple[ArrayWriter, AbstractDataStore] | bytes | Delayed | None: """Write DataArray contents to a netCDF file. All parameters are passed directly to :py:meth:`xarray.Dataset.to_netcdf`. diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 76776b4bc44..1166e240120 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -32,6 +32,7 @@ import xarray as xr +from ..backends.common import ArrayWriter from ..coding.calendar_ops import convert_calendar, interp_calendar from ..coding.cftimeindex import CFTimeIndex, _parse_array_of_cftime_strings from ..plot.dataset_plot import _Dataset_PlotMethods @@ -110,7 +111,7 @@ try: from dask.delayed import Delayed except ImportError: - Delayed = None + Delayed = None # type: ignore # list of attributes of pd.DatetimeIndex that are ndarrays of time info @@ -1686,7 +1687,7 @@ def to_netcdf( unlimited_dims: Iterable[Hashable] = None, compute: bool = True, invalid_netcdf: bool = False, - ) -> bytes | Delayed | None: + ) -> tuple[ArrayWriter, AbstractDataStore] | bytes | Delayed | None: """Write dataset contents to a netCDF file. Parameters diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index b85d0e1645e..253a68b7205 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -30,7 +30,7 @@ import dask.array as dask_array from dask.base import tokenize except ImportError: - dask_array = None + dask_array = None # type: ignore def _dask_or_eager_func( diff --git a/xarray/core/nanops.py b/xarray/core/nanops.py index c1a4d629f97..fa96bd6e150 100644 --- a/xarray/core/nanops.py +++ b/xarray/core/nanops.py @@ -11,7 +11,7 @@ from . 
import dask_array_compat except ImportError: - dask_array = None + dask_array = None # type: ignore[assignment] dask_array_compat = None # type: ignore[assignment] diff --git a/xarray/core/types.py b/xarray/core/types.py index 3f368501b25..74cb2fc2d46 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -16,7 +16,7 @@ try: from dask.array import Array as DaskArray except ImportError: - DaskArray = np.ndarray + DaskArray = np.ndarray # type: ignore T_Dataset = TypeVar("T_Dataset", bound="Dataset") diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index 127fdc5404f..c59f1a6584f 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import functools import operator import pickle @@ -22,6 +24,7 @@ unified_dim_sizes, ) from xarray.core.pycompat import dask_version +from xarray.core.types import T_Xarray from . import has_dask, raise_if_dask_computes, requires_dask @@ -2009,14 +2012,14 @@ def test_where_attrs() -> None: ), ], ) -def test_polyval(use_dask, x, coeffs, expected) -> None: +def test_polyval(use_dask, x: T_Xarray, coeffs: T_Xarray, expected) -> None: if use_dask: if not has_dask: pytest.skip("requires dask") coeffs = coeffs.chunk({"degree": 2}) x = x.chunk({"x": 2}) with raise_if_dask_computes(): - actual = xr.polyval(x, coeffs) + actual = xr.polyval(coord=x, coeffs=coeffs) xr.testing.assert_allclose(actual, expected) diff --git a/xarray/tests/test_testing.py b/xarray/tests/test_testing.py index 2bde7529d1e..1470706d0eb 100644 --- a/xarray/tests/test_testing.py +++ b/xarray/tests/test_testing.py @@ -10,7 +10,7 @@ try: from dask.array import from_array as dask_from_array except ImportError: - dask_from_array = lambda x: x + dask_from_array = lambda x: x # type: ignore try: import pint diff --git a/xarray/util/generate_ops.py b/xarray/util/generate_ops.py index f1fd6cbfeb2..0a382642708 100644 --- a/xarray/util/generate_ops.py +++ 
b/xarray/util/generate_ops.py @@ -210,7 +210,7 @@ def inplace(): try: from dask.array import Array as DaskArray except ImportError: - DaskArray = np.ndarray + DaskArray = np.ndarray # type: ignore # DatasetOpsMixin etc. are parent classes of Dataset etc. # Because of https://github.com/pydata/xarray/issues/5755, we redefine these. Generally From bbb14a5d8383520f1a1e7e6d885c03ecddfbcf47 Mon Sep 17 00:00:00 2001 From: Fabien Maussion Date: Mon, 9 May 2022 17:25:01 +0200 Subject: [PATCH 07/24] Allow string formatting of scalar DataArrays (#5981) * Allow string formatting of scalar DataArrays * better comments * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * forgot type check * yeah, typing is new to me * Simpler: pass to numpy in all cases * Add dask test * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- xarray/core/common.py | 4 ++++ xarray/tests/test_formatting.py | 22 +++++++++++++++++++++- 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/xarray/core/common.py b/xarray/core/common.py index 3db9b1cfa0c..cf02bcff77b 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -158,6 +158,10 @@ def _repr_html_(self): return f"
{escape(repr(self))}
" return formatting_html.array_repr(self) + def __format__(self: Any, format_spec: str) -> str: + # we use numpy: scalars will print fine and arrays will raise + return self.values.__format__(format_spec) + def _iter(self: Any) -> Iterator[Any]: for n in range(len(self)): yield self[n] diff --git a/xarray/tests/test_formatting.py b/xarray/tests/test_formatting.py index 4bbf41c7b38..a5c044d8ea7 100644 --- a/xarray/tests/test_formatting.py +++ b/xarray/tests/test_formatting.py @@ -9,7 +9,7 @@ import xarray as xr from xarray.core import formatting -from . import requires_netCDF4 +from . import requires_dask, requires_netCDF4 class TestFormatting: @@ -418,6 +418,26 @@ def test_array_repr_variable(self) -> None: with xr.set_options(display_expand_data=False): formatting.array_repr(var) + @requires_dask + def test_array_scalar_format(self) -> None: + var = xr.DataArray(0) + assert var.__format__("") == "0" + assert var.__format__("d") == "0" + assert var.__format__(".2f") == "0.00" + + var = xr.DataArray([0.1, 0.2]) + assert var.__format__("") == "[0.1 0.2]" + with pytest.raises(TypeError) as excinfo: + var.__format__(".2f") + assert "unsupported format string passed to" in str(excinfo.value) + + # also check for dask + var = var.chunk(chunks={"dim_0": 1}) + assert var.__format__("") == "[0.1 0.2]" + with pytest.raises(TypeError) as excinfo: + var.__format__(".2f") + assert "unsupported format string passed to" in str(excinfo.value) + def test_inline_variable_array_repr_custom_repr() -> None: class CustomArray: From a7654b7f44ccc2e97f8e1fe935a5f1bc27c12d77 Mon Sep 17 00:00:00 2001 From: Philippe Blain Date: Mon, 9 May 2022 13:44:48 -0400 Subject: [PATCH 08/24] terminology.rst: fix link to Unidata's "netcdf_dataset_components" (#6583) --- doc/user-guide/terminology.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/user-guide/terminology.rst b/doc/user-guide/terminology.rst index 1876058323e..c8cfdd5133d 100644 --- a/doc/user-guide/terminology.rst 
+++ b/doc/user-guide/terminology.rst @@ -27,7 +27,7 @@ complete examples, please consult the relevant documentation.* Variable A `NetCDF-like variable - `_ + `_ consisting of dimensions, data, and attributes which describe a single array. The main functional difference between variables and numpy arrays is that numerical operations on variables implement array broadcasting From 4b76831124f8cd11463c9e4ecfdc6842654cf810 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 9 May 2022 12:17:21 -0600 Subject: [PATCH 09/24] [pre-commit.ci] pre-commit autoupdate (#6584) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1b11e285f1e..6d6c94ff88f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,7 +19,7 @@ repos: hooks: - id: isort - repo: https://github.com/asottile/pyupgrade - rev: v2.32.0 + rev: v2.32.1 hooks: - id: pyupgrade args: From 218e77a9f2f6af0fc2a944563eb0ba2e8f457051 Mon Sep 17 00:00:00 2001 From: Fabien Maussion Date: Tue, 10 May 2022 07:54:05 +0200 Subject: [PATCH 10/24] Add some warnings about rechunking to the docs (#6569) * Dask doc changes * small change * More edits * Update doc/user-guide/dask.rst * Update doc/user-guide/dask.rst * Back to one liners Co-authored-by: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> --- doc/user-guide/dask.rst | 43 +++++++++++++++++++++++++++-------------- 1 file changed, 29 insertions(+), 14 deletions(-) diff --git a/doc/user-guide/dask.rst b/doc/user-guide/dask.rst index 5110a970390..56717f5306e 100644 --- a/doc/user-guide/dask.rst +++ b/doc/user-guide/dask.rst @@ -84,7 +84,7 @@ argument to :py:func:`~xarray.open_dataset` or using the In this example ``latitude`` and ``longitude`` do not appear in the ``chunks`` dict, so only one chunk 
will be used along those dimensions. It is also -entirely equivalent to opening a dataset using :py:meth:`~xarray.open_dataset` +entirely equivalent to opening a dataset using :py:func:`~xarray.open_dataset` and then chunking the data using the ``chunk`` method, e.g., ``xr.open_dataset('example-data.nc').chunk({'time': 10})``. @@ -95,13 +95,21 @@ use :py:func:`~xarray.open_mfdataset`:: This function will automatically concatenate and merge datasets into one in the simple cases that it understands (see :py:func:`~xarray.combine_by_coords` -for the full disclaimer). By default, :py:meth:`~xarray.open_mfdataset` will chunk each +for the full disclaimer). By default, :py:func:`~xarray.open_mfdataset` will chunk each netCDF file into a single Dask array; again, supply the ``chunks`` argument to control the size of the resulting Dask arrays. In more complex cases, you can -open each file individually using :py:meth:`~xarray.open_dataset` and merge the result, as -described in :ref:`combining data`. Passing the keyword argument ``parallel=True`` to :py:meth:`~xarray.open_mfdataset` will speed up the reading of large multi-file datasets by +open each file individually using :py:func:`~xarray.open_dataset` and merge the result, as +described in :ref:`combining data`. Passing the keyword argument ``parallel=True`` to +:py:func:`~xarray.open_mfdataset` will speed up the reading of large multi-file datasets by executing those read tasks in parallel using ``dask.delayed``. +.. warning:: + + :py:func:`~xarray.open_mfdataset` called without ``chunks`` argument will return + dask arrays with chunk sizes equal to the individual files. Re-chunking + the dataset after creation with ``ds.chunk()`` will lead to an ineffective use of + memory and is not recommended. + You'll notice that printing a dataset still shows a preview of array values, even if they are actually Dask arrays. 
We can do this quickly with Dask because we only need to compute the first few values (typically from the first block). @@ -224,6 +232,7 @@ disk. available memory. .. note:: + For more on the differences between :py:meth:`~xarray.Dataset.persist` and :py:meth:`~xarray.Dataset.compute` see this `Stack Overflow answer `_ and the `Dask documentation `_. @@ -236,6 +245,11 @@ sizes of Dask arrays is done with the :py:meth:`~xarray.Dataset.chunk` method: rechunked = ds.chunk({"latitude": 100, "longitude": 100}) +.. warning:: + + Rechunking an existing dask array created with :py:func:`~xarray.open_mfdataset` + is not recommended (see above). + You can view the size of existing chunks on an array by viewing the :py:attr:`~xarray.Dataset.chunks` attribute: @@ -295,8 +309,7 @@ each block of your xarray object, you have three options: ``apply_ufunc`` ~~~~~~~~~~~~~~~ -Another option is to use xarray's :py:func:`~xarray.apply_ufunc`, which can -automate `embarrassingly parallel +:py:func:`~xarray.apply_ufunc` automates `embarrassingly parallel `__ "map" type operations where a function written for processing NumPy arrays should be repeatedly applied to xarray objects containing Dask arrays. It works similarly to @@ -542,18 +555,20 @@ larger chunksizes. Optimization Tips ----------------- -With analysis pipelines involving both spatial subsetting and temporal resampling, Dask performance can become very slow in certain cases. Here are some optimization tips we have found through experience: +With analysis pipelines involving both spatial subsetting and temporal resampling, Dask performance +can become very slow or memory hungry in certain cases. Here are some optimization tips we have found +through experience: -1. Do your spatial and temporal indexing (e.g. ``.sel()`` or ``.isel()``) early in the pipeline, especially before calling ``resample()`` or ``groupby()``. 
Grouping and resampling triggers some computation on all the blocks, which in theory should commute with indexing, but this optimization hasn't been implemented in Dask yet. (See `Dask issue #746 `_). +1. Do your spatial and temporal indexing (e.g. ``.sel()`` or ``.isel()``) early in the pipeline, especially before calling ``resample()`` or ``groupby()``. Grouping and resampling trigger some computation on all the blocks, which in theory should commute with indexing, but this optimization hasn't been implemented in Dask yet. (See `Dask issue #746 `_). More generally, ``groupby()`` is a costly operation and does not (yet) perform well on datasets split across multiple files (see :pull:`5734` and linked discussions there). 2. Save intermediate results to disk as a netCDF files (using ``to_netcdf()``) and then load them again with ``open_dataset()`` for further computations. For example, if subtracting temporal mean from a dataset, save the temporal mean to disk before subtracting. Again, in theory, Dask should be able to do the computation in a streaming fashion, but in practice this is a fail case for the Dask scheduler, because it tries to keep every chunk of an array that it computes in memory. (See `Dask issue #874 `_) -3. Specify smaller chunks across space when using :py:meth:`~xarray.open_mfdataset` (e.g., ``chunks={'latitude': 10, 'longitude': 10}``). This makes spatial subsetting easier, because there's no risk you will load chunks of data referring to different chunks (probably not necessary if you follow suggestion 1). +3. Specify smaller chunks across space when using :py:func:`~xarray.open_mfdataset` (e.g., ``chunks={'latitude': 10, 'longitude': 10}``). This makes spatial subsetting easier, because there's no risk you will load subsets of data which span multiple chunks. On individual files, prefer to subset before chunking (suggestion 1). + +4. Chunk as early as possible, and avoid rechunking as much as possible.
Always pass the ``chunks={}`` argument to :py:func:`~xarray.open_mfdataset` to avoid redundant file reads. -4. Using the h5netcdf package by passing ``engine='h5netcdf'`` to :py:meth:`~xarray.open_mfdataset` - can be quicker than the default ``engine='netcdf4'`` that uses the netCDF4 package. +5. Using the h5netcdf package by passing ``engine='h5netcdf'`` to :py:func:`~xarray.open_mfdataset` can be quicker than the default ``engine='netcdf4'`` that uses the netCDF4 package. -5. Some dask-specific tips may be found `here `_. +6. Some dask-specific tips may be found `here `_. -6. The dask `diagnostics `_ can be - useful in identifying performance bottlenecks. +7. The dask `diagnostics `_ can be useful in identifying performance bottlenecks. From fdc3c3d305bfb880a22e0cb9eb57d69c98e774a7 Mon Sep 17 00:00:00 2001 From: Mick <43316012+headtr1ck@users.noreply.github.com> Date: Tue, 10 May 2022 08:18:19 +0200 Subject: [PATCH 11/24] Fix Dataset/DataArray.isel with drop=True and scalar DataArray indexes (#6579) * apply drop argument in isel_fancy * use literal type for error handling * add test for drop support in isel * add isel fix to whats-new * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * correct isel unit tests * add link to issue * type most (all?) 
occurrences of errors/missing * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> --- doc/whats-new.rst | 4 ++++ xarray/core/dataarray.py | 24 +++++++++++----------- xarray/core/dataset.py | 40 +++++++++++++++++++++--------------- xarray/core/indexes.py | 10 ++++----- xarray/core/types.py | 5 ++++- xarray/core/utils.py | 11 +++++++--- xarray/core/variable.py | 8 ++++---- xarray/tests/test_dataset.py | 19 +++++++++++++++++ 8 files changed, 79 insertions(+), 42 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index fc5135dc598..b4dd36e9961 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -44,6 +44,7 @@ New Features - :py:meth:`xr.polyval` now supports :py:class:`Dataset` and :py:class:`DataArray` args of any shape, is faster and requires less memory. (:pull:`6548`) By `Michael Niklas `_. +- Improved overall typing. Breaking changes ~~~~~~~~~~~~~~~~ @@ -119,6 +120,9 @@ Bug fixes :pull:`6489`). By `Spencer Clark `_. - Dark themes are now properly detected in Furo-themed Sphinx documents (:issue:`6500`, :pull:`6501`). By `Kevin Paul `_. +- :py:meth:`isel` with ``drop=True`` works as intended with scalar :py:class:`DataArray` indexers. + (:issue:`6554`, :pull:`6579`) + By `Michael Niklas `_.
Documentation ~~~~~~~~~~~~~ diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index fc3cbef16f8..150e5d9ca0f 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -78,7 +78,7 @@ except ImportError: iris_Cube = None - from .types import T_DataArray, T_Xarray + from .types import ErrorChoice, ErrorChoiceWithWarn, T_DataArray, T_Xarray def _infer_coords_and_dims( @@ -1171,7 +1171,7 @@ def isel( self, indexers: Mapping[Any, Any] = None, drop: bool = False, - missing_dims: str = "raise", + missing_dims: ErrorChoiceWithWarn = "raise", **indexers_kwargs: Any, ) -> DataArray: """Return a new DataArray whose data is given by integer indexing @@ -1186,7 +1186,7 @@ def isel( If DataArrays are passed as indexers, xarray-style indexing will be carried out. See :ref:`indexing` for the details. One of indexers or indexers_kwargs must be provided. - drop : bool, optional + drop : bool, default: False If ``drop=True``, drop coordinates variables indexed by integers instead of making them scalar. missing_dims : {"raise", "warn", "ignore"}, default: "raise" @@ -2335,7 +2335,7 @@ def transpose( self, *dims: Hashable, transpose_coords: bool = True, - missing_dims: str = "raise", + missing_dims: ErrorChoiceWithWarn = "raise", ) -> DataArray: """Return a new DataArray object with transposed dimensions. @@ -2386,7 +2386,7 @@ def T(self) -> DataArray: return self.transpose() def drop_vars( - self, names: Hashable | Iterable[Hashable], *, errors: str = "raise" + self, names: Hashable | Iterable[Hashable], *, errors: ErrorChoice = "raise" ) -> DataArray: """Returns an array with dropped variables. @@ -2394,8 +2394,8 @@ def drop_vars( ---------- names : hashable or iterable of hashable Name(s) of variables to drop. 
- errors : {"raise", "ignore"}, optional - If 'raise' (default), raises a ValueError error if any of the variable + errors : {"raise", "ignore"}, default: "raise" + If 'raise', raises a ValueError error if any of the variable passed are not in the dataset. If 'ignore', any given names that are in the DataArray are dropped and no error is raised. @@ -2412,7 +2412,7 @@ def drop( labels: Mapping = None, dim: Hashable = None, *, - errors: str = "raise", + errors: ErrorChoice = "raise", **labels_kwargs, ) -> DataArray: """Backward compatible method based on `drop_vars` and `drop_sel` @@ -2431,7 +2431,7 @@ def drop_sel( self, labels: Mapping[Any, Any] = None, *, - errors: str = "raise", + errors: ErrorChoice = "raise", **labels_kwargs, ) -> DataArray: """Drop index labels from this DataArray. @@ -2440,8 +2440,8 @@ def drop_sel( ---------- labels : mapping of hashable to Any Index labels to drop - errors : {"raise", "ignore"}, optional - If 'raise' (default), raises a ValueError error if + errors : {"raise", "ignore"}, default: "raise" + If 'raise', raises a ValueError error if any of the index labels passed are not in the dataset. If 'ignore', any given labels that are in the dataset are dropped and no error is raised. 
@@ -4589,7 +4589,7 @@ def query( queries: Mapping[Any, Any] = None, parser: str = "pandas", engine: str = None, - missing_dims: str = "raise", + missing_dims: ErrorChoiceWithWarn = "raise", **queries_kwargs: Any, ) -> DataArray: """Return a new data array indexed along the specified diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 1166e240120..f8c2223157e 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -106,7 +106,7 @@ from ..backends import AbstractDataStore, ZarrStore from .dataarray import DataArray from .merge import CoercibleMapping - from .types import T_Xarray + from .types import ErrorChoice, ErrorChoiceWithWarn, T_Xarray try: from dask.delayed import Delayed @@ -2059,7 +2059,7 @@ def chunk( return self._replace(variables) def _validate_indexers( - self, indexers: Mapping[Any, Any], missing_dims: str = "raise" + self, indexers: Mapping[Any, Any], missing_dims: ErrorChoiceWithWarn = "raise" ) -> Iterator[tuple[Hashable, int | slice | np.ndarray | Variable]]: """Here we make sure + indexer has a valid keys @@ -2164,7 +2164,7 @@ def isel( self, indexers: Mapping[Any, Any] = None, drop: bool = False, - missing_dims: str = "raise", + missing_dims: ErrorChoiceWithWarn = "raise", **indexers_kwargs: Any, ) -> Dataset: """Returns a new dataset with each array indexed along the specified @@ -2183,14 +2183,14 @@ def isel( If DataArrays are passed as indexers, xarray-style indexing will be carried out. See :ref:`indexing` for the details. One of indexers or indexers_kwargs must be provided. - drop : bool, optional + drop : bool, default: False If ``drop=True``, drop coordinates variables indexed by integers instead of making them scalar. 
missing_dims : {"raise", "warn", "ignore"}, default: "raise" What to do if dimensions that should be selected from are not present in the Dataset: - "raise": raise an exception - - "warning": raise a warning, and ignore the missing dimensions + - "warn": raise a warning, and ignore the missing dimensions - "ignore": ignore the missing dimensions **indexers_kwargs : {dim: indexer, ...}, optional The keyword arguments form of ``indexers``. @@ -2255,7 +2255,7 @@ def _isel_fancy( indexers: Mapping[Any, Any], *, drop: bool, - missing_dims: str = "raise", + missing_dims: ErrorChoiceWithWarn = "raise", ) -> Dataset: valid_indexers = dict(self._validate_indexers(indexers, missing_dims)) @@ -2271,6 +2271,10 @@ def _isel_fancy( } if var_indexers: new_var = var.isel(indexers=var_indexers) + # drop scalar coordinates + # https://github.com/pydata/xarray/issues/6554 + if name in self.coords and drop and new_var.ndim == 0: + continue else: new_var = var.copy(deep=False) if name not in indexes: @@ -4521,7 +4525,7 @@ def _assert_all_in_dataset( ) def drop_vars( - self, names: Hashable | Iterable[Hashable], *, errors: str = "raise" + self, names: Hashable | Iterable[Hashable], *, errors: ErrorChoice = "raise" ) -> Dataset: """Drop variables from this dataset. @@ -4529,8 +4533,8 @@ def drop_vars( ---------- names : hashable or iterable of hashable Name(s) of variables to drop. - errors : {"raise", "ignore"}, optional - If 'raise' (default), raises a ValueError error if any of the variable + errors : {"raise", "ignore"}, default: "raise" + If 'raise', raises a ValueError error if any of the variable passed are not in the dataset. If 'ignore', any given names that are in the dataset are dropped and no error is raised. 
@@ -4556,7 +4560,9 @@ def drop_vars( variables, coord_names=coord_names, indexes=indexes ) - def drop(self, labels=None, dim=None, *, errors="raise", **labels_kwargs): + def drop( + self, labels=None, dim=None, *, errors: ErrorChoice = "raise", **labels_kwargs + ): """Backward compatible method based on `drop_vars` and `drop_sel` Using either `drop_vars` or `drop_sel` is encouraged @@ -4605,15 +4611,15 @@ def drop(self, labels=None, dim=None, *, errors="raise", **labels_kwargs): ) return self.drop_sel(labels, errors=errors) - def drop_sel(self, labels=None, *, errors="raise", **labels_kwargs): + def drop_sel(self, labels=None, *, errors: ErrorChoice = "raise", **labels_kwargs): """Drop index labels from this dataset. Parameters ---------- labels : mapping of hashable to Any Index labels to drop - errors : {"raise", "ignore"}, optional - If 'raise' (default), raises a ValueError error if + errors : {"raise", "ignore"}, default: "raise" + If 'raise', raises a ValueError error if any of the index labels passed are not in the dataset. If 'ignore', any given labels that are in the dataset are dropped and no error is raised. @@ -4740,7 +4746,7 @@ def drop_isel(self, indexers=None, **indexers_kwargs): return ds def drop_dims( - self, drop_dims: Hashable | Iterable[Hashable], *, errors: str = "raise" + self, drop_dims: Hashable | Iterable[Hashable], *, errors: ErrorChoice = "raise" ) -> Dataset: """Drop dimensions and associated variables from this dataset. @@ -4780,7 +4786,7 @@ def drop_dims( def transpose( self, *dims: Hashable, - missing_dims: str = "raise", + missing_dims: ErrorChoiceWithWarn = "raise", ) -> Dataset: """Return a new Dataset object with all array dimensions transposed. 
@@ -7714,7 +7720,7 @@ def query( queries: Mapping[Any, Any] = None, parser: str = "pandas", engine: str = None, - missing_dims: str = "raise", + missing_dims: ErrorChoiceWithWarn = "raise", **queries_kwargs: Any, ) -> Dataset: """Return a new dataset with each array indexed along the specified @@ -7747,7 +7753,7 @@ def query( Dataset: - "raise": raise an exception - - "warning": raise a warning, and ignore the missing dimensions + - "warn": raise a warning, and ignore the missing dimensions - "ignore": ignore the missing dimensions **queries_kwargs : {dim: query, ...}, optional diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index e02e1f569b2..9884a756fe6 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -22,10 +22,10 @@ from . import formatting, nputils, utils from .indexing import IndexSelResult, PandasIndexingAdapter, PandasMultiIndexingAdapter -from .types import T_Index from .utils import Frozen, get_valid_numpy_dtype, is_dict_like, is_scalar if TYPE_CHECKING: + from .types import ErrorChoice, T_Index from .variable import Variable IndexVars = Dict[Any, "Variable"] @@ -1098,7 +1098,7 @@ def is_multi(self, key: Hashable) -> bool: return len(self._id_coord_names[self._coord_name_id[key]]) > 1 def get_all_coords( - self, key: Hashable, errors: str = "raise" + self, key: Hashable, errors: ErrorChoice = "raise" ) -> dict[Hashable, Variable]: """Return all coordinates having the same index. @@ -1106,7 +1106,7 @@ def get_all_coords( ---------- key : hashable Index key. - errors : {"raise", "ignore"}, optional + errors : {"raise", "ignore"}, default: "raise" If "raise", raises a ValueError if `key` is not in indexes. If "ignore", an empty tuple is returned instead. 
@@ -1129,7 +1129,7 @@ def get_all_coords( return {k: self._variables[k] for k in all_coord_names} def get_all_dims( - self, key: Hashable, errors: str = "raise" + self, key: Hashable, errors: ErrorChoice = "raise" ) -> Mapping[Hashable, int]: """Return all dimensions shared by an index. @@ -1137,7 +1137,7 @@ def get_all_dims( ---------- key : hashable Index key. - errors : {"raise", "ignore"}, optional + errors : {"raise", "ignore"}, default: "raise" If "raise", raises a ValueError if `key` is not in indexes. If "ignore", an empty tuple is returned instead. diff --git a/xarray/core/types.py b/xarray/core/types.py index 74cb2fc2d46..6dbc57ce797 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, TypeVar, Union +from typing import TYPE_CHECKING, Literal, TypeVar, Union import numpy as np @@ -33,3 +33,6 @@ DaCompatible = Union["DataArray", "Variable", "DataArrayGroupBy", "ScalarOrArray"] VarCompatible = Union["Variable", "ScalarOrArray"] GroupByIncompatible = Union["Variable", "GroupBy"] + +ErrorChoice = Literal["raise", "ignore"] +ErrorChoiceWithWarn = Literal["raise", "warn", "ignore"] diff --git a/xarray/core/utils.py b/xarray/core/utils.py index bab6476e734..aaa087a3532 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -29,6 +29,9 @@ import numpy as np import pandas as pd +if TYPE_CHECKING: + from .types import ErrorChoiceWithWarn + K = TypeVar("K") V = TypeVar("V") T = TypeVar("T") @@ -756,7 +759,9 @@ def __len__(self) -> int: def infix_dims( - dims_supplied: Collection, dims_all: Collection, missing_dims: str = "raise" + dims_supplied: Collection, + dims_all: Collection, + missing_dims: ErrorChoiceWithWarn = "raise", ) -> Iterator: """ Resolves a supplied list containing an ellipsis representing other items, to @@ -804,7 +809,7 @@ def get_temp_dimname(dims: Container[Hashable], new_dim: Hashable) -> Hashable: def drop_dims_from_indexers( indexers: 
Mapping[Any, Any], dims: list | Mapping[Any, int], - missing_dims: str, + missing_dims: ErrorChoiceWithWarn, ) -> Mapping[Hashable, Any]: """Depending on the setting of missing_dims, drop any dimensions from indexers that are not present in dims. @@ -850,7 +855,7 @@ def drop_dims_from_indexers( def drop_missing_dims( - supplied_dims: Collection, dims: Collection, missing_dims: str + supplied_dims: Collection, dims: Collection, missing_dims: ErrorChoiceWithWarn ) -> Collection: """Depending on the setting of missing_dims, drop any dimensions from supplied_dims that are not present in dims. diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 05c70390b46..82a567b2c2a 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -59,7 +59,7 @@ BASIC_INDEXING_TYPES = integer_types + (slice,) if TYPE_CHECKING: - from .types import T_Variable + from .types import ErrorChoiceWithWarn, T_Variable class MissingDimensionsError(ValueError): @@ -1159,7 +1159,7 @@ def _to_dense(self): def isel( self: T_Variable, indexers: Mapping[Any, Any] = None, - missing_dims: str = "raise", + missing_dims: ErrorChoiceWithWarn = "raise", **indexers_kwargs: Any, ) -> T_Variable: """Return a new array indexed along the specified dimension(s). @@ -1173,7 +1173,7 @@ def isel( What to do if dimensions that should be selected from are not present in the DataArray: - "raise": raise an exception - - "warning": raise a warning, and ignore the missing dimensions + - "warn": raise a warning, and ignore the missing dimensions - "ignore": ignore the missing dimensions Returns @@ -1436,7 +1436,7 @@ def roll(self, shifts=None, **shifts_kwargs): def transpose( self, *dims, - missing_dims: str = "raise", + missing_dims: ErrorChoiceWithWarn = "raise", ) -> Variable: """Return a new Variable object with transposed dimensions. 
diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index c0f7f09ff61..c1fb161fb6a 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -1175,6 +1175,25 @@ def test_isel_fancy(self): assert_array_equal(actual["var2"], expected_var2) assert_array_equal(actual["var3"], expected_var3) + # test that drop works + ds = xr.Dataset({"a": (("x",), [1, 2, 3])}, coords={"b": (("x",), [5, 6, 7])}) + + actual = ds.isel({"x": 1}, drop=False) + expected = xr.Dataset({"a": 2}, coords={"b": 6}) + assert_identical(actual, expected) + + actual = ds.isel({"x": 1}, drop=True) + expected = xr.Dataset({"a": 2}) + assert_identical(actual, expected) + + actual = ds.isel({"x": DataArray(1)}, drop=False) + expected = xr.Dataset({"a": 2}, coords={"b": 6}) + assert_identical(actual, expected) + + actual = ds.isel({"x": DataArray(1)}, drop=True) + expected = xr.Dataset({"a": 2}) + assert_identical(actual, expected) + def test_isel_dataarray(self): """Test for indexing by DataArray""" data = create_test_data() From 3920c48d61d1f213a849bae51faa473b9c471946 Mon Sep 17 00:00:00 2001 From: brynjarmorka <79692387+brynjarmorka@users.noreply.github.com> Date: Tue, 10 May 2022 17:40:24 +0200 Subject: [PATCH 12/24] Doc Link to accessors list in extending-xarray.rst (#6587) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- doc/internals/extending-xarray.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/internals/extending-xarray.rst b/doc/internals/extending-xarray.rst index 2951ce10f21..f8b61d12a2f 100644 --- a/doc/internals/extending-xarray.rst +++ b/doc/internals/extending-xarray.rst @@ -92,8 +92,8 @@ on ways to write new accessors and the philosophy behind the approach, see To help users keep things straight, please `let us know `_ if you plan to write a new accessor -for an open source library. 
In the future, we will maintain a list of accessors -and the libraries that implement them on this page. +for an open source library. Existing open source accessors and the libraries +that implement them are available in the list on the :ref:`ecosystem` page. To make documenting accessors with ``sphinx`` and ``sphinx.ext.autosummary`` easier, you can use `sphinx-autosummary-accessors`_. From 770e878663b03bd83d2c28af0643770bdd43c3da Mon Sep 17 00:00:00 2001 From: Michael Bauer Date: Wed, 11 May 2022 10:16:06 +0200 Subject: [PATCH 13/24] Add missing space in exception message (#6590) --- xarray/core/groupby.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index e97499f06b4..151ef844f44 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -312,7 +312,7 @@ def __init__( if not hashable(group): raise TypeError( "`group` must be an xarray.DataArray or the " - "name of an xarray variable or dimension." + "name of an xarray variable or dimension. " f"Received {group!r} instead." 
) group = obj[group] From 4a53e4114bd67eb1cbf82a54e4c65b4999d9150f Mon Sep 17 00:00:00 2001 From: Charles Stern <62192187+cisaacstern@users.noreply.github.com> Date: Wed, 11 May 2022 10:35:09 -0700 Subject: [PATCH 14/24] Fix zarr append dtype checks (#6476) * fix zarr append dtype check first commit * use zstore in _validate_datatype * remove coding.strings.is_unicode_dtype check * test appending fixed length strings * test string length mismatch raises for U and S * add explanatory comment for zarr append dtype checks Co-authored-by: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> --- xarray/backends/api.py | 48 ++++++++++++++++++++++------------- xarray/tests/test_backends.py | 17 ++++++++++++- xarray/tests/test_dataset.py | 35 +++++++++++++++++++++++++ 3 files changed, 81 insertions(+), 19 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 4962a4a9c02..05aa5d04deb 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -17,7 +17,7 @@ import numpy as np -from .. import backends, coding, conventions +from .. import backends, conventions from ..core import indexing from ..core.combine import ( _infer_concat_order_from_positions, @@ -1277,28 +1277,40 @@ def _validate_region(ds, region): ) -def _validate_datatypes_for_zarr_append(dataset): - """DataArray.name and Dataset keys must be a string or None""" +def _validate_datatypes_for_zarr_append(zstore, dataset): + """If variable exists in the store, confirm dtype of the data to append is compatible with + existing dtype. 
+ """ + + existing_vars = zstore.get_variables() - def check_dtype(var): + def check_dtype(vname, var): if ( - not np.issubdtype(var.dtype, np.number) - and not np.issubdtype(var.dtype, np.datetime64) - and not np.issubdtype(var.dtype, np.bool_) - and not coding.strings.is_unicode_dtype(var.dtype) - and not var.dtype == object + vname not in existing_vars + or np.issubdtype(var.dtype, np.number) + or np.issubdtype(var.dtype, np.datetime64) + or np.issubdtype(var.dtype, np.bool_) + or var.dtype == object ): - # and not re.match('^bytes[1-9]+$', var.dtype.name)): + # We can skip dtype equality checks under two conditions: (1) if the var to append is + # new to the dataset, because in this case there is no existing var to compare it to; + # or (2) if var to append's dtype is known to be easy-to-append, because in this case + # we can be confident appending won't cause problems. Examples of dtypes which are not + # easy-to-append include length-specified strings of type `|S*` or ` Date: Wed, 11 May 2022 11:36:05 -0600 Subject: [PATCH 15/24] [docs] add Dataset.assign_coords example (#6336) (#6558) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- xarray/core/common.py | 50 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 49 insertions(+), 1 deletion(-) diff --git a/xarray/core/common.py b/xarray/core/common.py index cf02bcff77b..75518716870 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -454,7 +454,7 @@ def assign_coords(self, coords=None, **coords_kwargs): Examples -------- - Convert longitude coordinates from 0-359 to -180-179: + Convert `DataArray` longitude coordinates from 0-359 to -180-179: >>> da = xr.DataArray( ... np.random.rand(4), @@ -494,6 +494,54 @@ def assign_coords(self, coords=None, **coords_kwargs): >>> _ = da.assign_coords({"lon_2": ("lon", lon_2)}) + Note the same method applies to `Dataset` objects. 
+ + Convert `Dataset` longitude coordinates from 0-359 to -180-179: + + >>> temperature = np.linspace(20, 32, num=16).reshape(2, 2, 4) + >>> precipitation = 2 * np.identity(4).reshape(2, 2, 4) + >>> ds = xr.Dataset( + ... data_vars=dict( + ... temperature=(["x", "y", "time"], temperature), + ... precipitation=(["x", "y", "time"], precipitation), + ... ), + ... coords=dict( + ... lon=(["x", "y"], [[260.17, 260.68], [260.21, 260.77]]), + ... lat=(["x", "y"], [[42.25, 42.21], [42.63, 42.59]]), + ... time=pd.date_range("2014-09-06", periods=4), + ... reference_time=pd.Timestamp("2014-09-05"), + ... ), + ... attrs=dict(description="Weather-related data"), + ... ) + >>> ds + + Dimensions: (x: 2, y: 2, time: 4) + Coordinates: + lon (x, y) float64 260.2 260.7 260.2 260.8 + lat (x, y) float64 42.25 42.21 42.63 42.59 + * time (time) datetime64[ns] 2014-09-06 2014-09-07 ... 2014-09-09 + reference_time datetime64[ns] 2014-09-05 + Dimensions without coordinates: x, y + Data variables: + temperature (x, y, time) float64 20.0 20.8 21.6 22.4 ... 30.4 31.2 32.0 + precipitation (x, y, time) float64 2.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 2.0 + Attributes: + description: Weather-related data + >>> ds.assign_coords(lon=(((ds.lon + 180) % 360) - 180)) + + Dimensions: (x: 2, y: 2, time: 4) + Coordinates: + lon (x, y) float64 -99.83 -99.32 -99.79 -99.23 + lat (x, y) float64 42.25 42.21 42.63 42.59 + * time (time) datetime64[ns] 2014-09-06 2014-09-07 ... 2014-09-09 + reference_time datetime64[ns] 2014-09-05 + Dimensions without coordinates: x, y + Data variables: + temperature (x, y, time) float64 20.0 20.8 21.6 22.4 ... 30.4 31.2 32.0 + precipitation (x, y, time) float64 2.0 0.0 0.0 0.0 0.0 ... 
0.0 0.0 0.0 2.0 + Attributes: + description: Weather-related data + Notes ----- Since ``coords_kwargs`` is a dictionary, the order of your arguments From 11041bdfd903d2b3a36a63b2afd86b6ca752bb74 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 11 May 2022 12:04:40 -0600 Subject: [PATCH 16/24] Restore old MultiIndex dropping behaviour (#6592) Co-authored-by: Benoit Bovy --- xarray/core/dataset.py | 17 +++++++++++++++++ xarray/tests/test_dataarray.py | 9 +++++---- xarray/tests/test_dataset.py | 9 ++++----- 3 files changed, 26 insertions(+), 9 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index f8c2223157e..8b2f4783e34 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -4551,6 +4551,23 @@ def drop_vars( if errors == "raise": self._assert_all_in_dataset(names) + # GH6505 + other_names = set() + for var in names: + maybe_midx = self._indexes.get(var, None) + if isinstance(maybe_midx, PandasMultiIndex): + idx_coord_names = set(maybe_midx.index.names + [maybe_midx.dim]) + idx_other_names = idx_coord_names - set(names) + other_names.update(idx_other_names) + if other_names: + names |= set(other_names) + warnings.warn( + f"Deleting a single level of a MultiIndex is deprecated. Previously, this deleted all levels of a MultiIndex. 
" + f"Please also drop the following variables: {other_names!r} to avoid an error in the future.", + DeprecationWarning, + stacklevel=2, + ) + assert_no_index_corrupted(self.xindexes, names) variables = {k: v for k, v in self._variables.items() if k not in names} diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index b8c9edd7258..8e1099b7e33 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -2360,10 +2360,11 @@ def test_drop_coordinates(self): assert_identical(actual, renamed) def test_drop_multiindex_level(self): - with pytest.raises( - ValueError, match=r"cannot remove coordinate.*corrupt.*index " - ): - self.mda.drop_vars("level_1") + # GH6505 + expected = self.mda.drop_vars(["x", "level_1", "level_2"]) + with pytest.warns(DeprecationWarning): + actual = self.mda.drop_vars("level_1") + assert_identical(expected, actual) def test_drop_all_multiindex_levels(self): dim_levels = ["x", "level_1", "level_2"] diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index f58fb0c9a99..263237d9d30 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -2453,11 +2453,10 @@ def test_drop_variables(self): def test_drop_multiindex_level(self): data = create_test_multiindex() - - with pytest.raises( - ValueError, match=r"cannot remove coordinate.*corrupt.*index " - ): - data.drop_vars("level_1") + expected = data.drop_vars(["x", "level_1", "level_2"]) + with pytest.warns(DeprecationWarning): + actual = data.drop_vars("level_1") + assert_identical(expected, actual) def test_drop_index_labels(self): data = Dataset({"A": (["x", "y"], np.random.randn(2, 3)), "x": ["a", "b"]}) From cad4474a9ecd8acc78e42cf46030c9a1277f10c4 Mon Sep 17 00:00:00 2001 From: Mick <43316012+headtr1ck@users.noreply.github.com> Date: Wed, 11 May 2022 21:42:40 +0200 Subject: [PATCH 17/24] Fix polyval overloads (#6593) * add polyval overloads * [pre-commit.ci] auto fixes from pre-commit.com hooks for more 
information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- xarray/core/computation.py | 34 +++++++++++++++++++------------- xarray/tests/test_computation.py | 9 +++++++-- 2 files changed, 27 insertions(+), 16 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 6f1e08bf84f..823cbe02560 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -17,6 +17,7 @@ Iterable, Mapping, Sequence, + overload, ) import numpy as np @@ -1845,26 +1846,31 @@ def where(cond, x, y, keep_attrs=None): ) -# These overloads seem not to work — mypy says it can't find a matching overload for -# `DataArray` & `DataArray`, despite that being in the first overload. Would be nice to -# have overloaded functions rather than just `T_Xarray` for everything. +@overload +def polyval(coord: DataArray, coeffs: DataArray, degree_dim: Hashable) -> DataArray: + ... -# @overload -# def polyval(coord: DataArray, coeffs: DataArray, degree_dim: Hashable) -> DataArray: -# ... +@overload +def polyval(coord: DataArray, coeffs: Dataset, degree_dim: Hashable) -> Dataset: + ... -# @overload -# def polyval(coord: T_Xarray, coeffs: Dataset, degree_dim: Hashable) -> Dataset: -# ... +@overload +def polyval(coord: Dataset, coeffs: DataArray, degree_dim: Hashable) -> Dataset: + ... -# @overload -# def polyval(coord: Dataset, coeffs: T_Xarray, degree_dim: Hashable) -> Dataset: -# ... + +@overload +def polyval(coord: Dataset, coeffs: Dataset, degree_dim: Hashable) -> Dataset: + ... 
-def polyval(coord: T_Xarray, coeffs: T_Xarray, degree_dim="degree") -> T_Xarray: +def polyval( + coord: Dataset | DataArray, + coeffs: Dataset | DataArray, + degree_dim: Hashable = "degree", +) -> Dataset | DataArray: """Evaluate a polynomial at specific values Parameters @@ -1899,7 +1905,7 @@ def polyval(coord: T_Xarray, coeffs: T_Xarray, degree_dim="degree") -> T_Xarray: coeffs = coeffs.reindex( {degree_dim: np.arange(max_deg + 1)}, fill_value=0, copy=False ) - coord = _ensure_numeric(coord) + coord = _ensure_numeric(coord) # type: ignore # https://github.com/python/mypy/issues/1533 ? # using Horner's method # https://en.wikipedia.org/wiki/Horner%27s_method diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index c59f1a6584f..737ed82bc05 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -2012,14 +2012,19 @@ def test_where_attrs() -> None: ), ], ) -def test_polyval(use_dask, x: T_Xarray, coeffs: T_Xarray, expected) -> None: +def test_polyval( + use_dask: bool, + x: xr.DataArray | xr.Dataset, + coeffs: xr.DataArray | xr.Dataset, + expected: xr.DataArray | xr.Dataset, +) -> None: if use_dask: if not has_dask: pytest.skip("requires dask") coeffs = coeffs.chunk({"degree": 2}) x = x.chunk({"x": 2}) with raise_if_dask_computes(): - actual = xr.polyval(coord=x, coeffs=coeffs) + actual = xr.polyval(coord=x, coeffs=coeffs) # type: ignore xr.testing.assert_allclose(actual, expected) From 0512da117388a451653484b4f45927ac337b596f Mon Sep 17 00:00:00 2001 From: Tom Nicholas <35968931+TomNicholas@users.noreply.github.com> Date: Wed, 11 May 2022 16:26:42 -0400 Subject: [PATCH 18/24] New inline_array kwarg for open_dataset (#6566) * added inline_array kwarg * remove cheeky print statements * Remove another rogue print statement * bump dask dependency * update multiple dependencies based on min-deps-check.py * update environment to match #6559 * Update h5py in ci/requirements/min-all-deps.yml * Update 
ci/requirements/min-all-deps.yml * remove pynio from test env * Update ci/requirements/min-all-deps.yml * promote inline_array kwarg to be top-level kwarg * whatsnew * add test * Remove repeated docstring entry Co-authored-by: Deepak Cherian * Remove repeated docstring entry Co-authored-by: Deepak Cherian * hyperlink to dask functions Co-authored-by: Deepak Cherian --- doc/whats-new.rst | 3 +++ xarray/backends/api.py | 20 ++++++++++++++++++++ xarray/core/dataarray.py | 17 ++++++++++++++++- xarray/core/dataset.py | 8 +++++++- xarray/core/variable.py | 17 +++++++++++++++-- xarray/tests/test_backends.py | 21 +++++++++++++++++++++ 6 files changed, 82 insertions(+), 4 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index b4dd36e9961..6cba8563ecd 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -41,6 +41,9 @@ New Features - Allow passing chunks in ``**kwargs`` form to :py:meth:`Dataset.chunk`, :py:meth:`DataArray.chunk`, and :py:meth:`Variable.chunk`. (:pull:`6471`) By `Tom Nicholas `_. +- Expose `inline_array` kwarg from `dask.array.from_array` in :py:func:`open_dataset`, :py:meth:`Dataset.chunk`, + :py:meth:`DataArray.chunk`, and :py:meth:`Variable.chunk`. (:pull:`6471`) + By `Tom Nicholas `_. - :py:meth:`xr.polyval` now supports :py:class:`Dataset` and :py:class:`DataArray` args of any shape, is faster and requires less memory. (:pull:`6548`) By `Michael Niklas `_. 
diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 05aa5d04deb..95f8dbc6eaf 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -274,6 +274,7 @@ def _chunk_ds( engine, chunks, overwrite_encoded_chunks, + inline_array, **extra_tokens, ): from dask.base import tokenize @@ -292,6 +293,7 @@ def _chunk_ds( overwrite_encoded_chunks=overwrite_encoded_chunks, name_prefix=name_prefix, token=token, + inline_array=inline_array, ) return backend_ds._replace(variables) @@ -303,6 +305,7 @@ def _dataset_from_backend_dataset( chunks, cache, overwrite_encoded_chunks, + inline_array, **extra_tokens, ): if not isinstance(chunks, (int, dict)) and chunks not in {None, "auto"}: @@ -320,6 +323,7 @@ def _dataset_from_backend_dataset( engine, chunks, overwrite_encoded_chunks, + inline_array, **extra_tokens, ) @@ -346,6 +350,7 @@ def open_dataset( concat_characters=None, decode_coords=None, drop_variables=None, + inline_array=False, backend_kwargs=None, **kwargs, ): @@ -430,6 +435,12 @@ def open_dataset( A variable or list of variables to exclude from being parsed from the dataset. This may be useful to drop variables with problems or inconsistent values. + inline_array: bool, optional + How to include the array in the dask task graph. + By default(``inline_array=False``) the array is included in a task by + itself, and each chunk refers to that task by its key. With + ``inline_array=True``, Dask will instead inline the array directly + in the values of the task graph. See :py:func:`dask.array.from_array`. backend_kwargs: dict Additional keyword arguments passed on to the engine open function, equivalent to `**kwargs`. 
@@ -505,6 +516,7 @@ def open_dataset( chunks, cache, overwrite_encoded_chunks, + inline_array, drop_variables=drop_variables, **decoders, **kwargs, @@ -526,6 +538,7 @@ def open_dataarray( concat_characters=None, decode_coords=None, drop_variables=None, + inline_array=False, backend_kwargs=None, **kwargs, ): @@ -613,6 +626,12 @@ def open_dataarray( A variable or list of variables to exclude from being parsed from the dataset. This may be useful to drop variables with problems or inconsistent values. + inline_array: bool, optional + How to include the array in the dask task graph. + By default(``inline_array=False``) the array is included in a task by + itself, and each chunk refers to that task by its key. With + ``inline_array=True``, Dask will instead inline the array directly + in the values of the task graph. See :py:func:`dask.array.from_array`. backend_kwargs: dict Additional keyword arguments passed on to the engine open function, equivalent to `**kwargs`. @@ -660,6 +679,7 @@ def open_dataarray( chunks=chunks, cache=cache, drop_variables=drop_variables, + inline_array=inline_array, backend_kwargs=backend_kwargs, use_cftime=use_cftime, decode_timedelta=decode_timedelta, diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 150e5d9ca0f..64c4e419788 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -1114,6 +1114,7 @@ def chunk( name_prefix: str = "xarray-", token: str = None, lock: bool = False, + inline_array: bool = False, **chunks_kwargs: Any, ) -> DataArray: """Coerce this array's data into a dask arrays with the given chunks. @@ -1138,6 +1139,9 @@ def chunk( lock : optional Passed on to :py:func:`dask.array.from_array`, if the array is not already as dask array. + inline_array: optional + Passed on to :py:func:`dask.array.from_array`, if the array is not + already as dask array. **chunks_kwargs : {dim: chunks, ...}, optional The keyword arguments form of ``chunks``. One of chunks or chunks_kwargs must be provided. 
@@ -1145,6 +1149,13 @@ def chunk( Returns ------- chunked : xarray.DataArray + + See Also + -------- + DataArray.chunks + DataArray.chunksizes + xarray.unify_chunks + dask.array.from_array """ if chunks is None: warnings.warn( @@ -1163,7 +1174,11 @@ def chunk( chunks = either_dict_or_kwargs(chunks, chunks_kwargs, "chunk") ds = self._to_temp_dataset().chunk( - chunks, name_prefix=name_prefix, token=token, lock=lock + chunks, + name_prefix=name_prefix, + token=token, + lock=lock, + inline_array=inline_array, ) return self._from_temp_dataset(ds) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 8b2f4783e34..d255cfacbd1 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -240,6 +240,7 @@ def _maybe_chunk( lock=None, name_prefix="xarray-", overwrite_encoded_chunks=False, + inline_array=False, ): from dask.base import tokenize @@ -251,7 +252,7 @@ def _maybe_chunk( # subtle bugs result otherwise. see GH3350 token2 = tokenize(name, token if token else var._data, chunks) name2 = f"{name_prefix}{name}-{token2}" - var = var.chunk(chunks, name=name2, lock=lock) + var = var.chunk(chunks, name=name2, lock=lock, inline_array=inline_array) if overwrite_encoded_chunks and var.chunks is not None: var.encoding["chunks"] = tuple(x[0] for x in var.chunks) @@ -1995,6 +1996,7 @@ def chunk( name_prefix: str = "xarray-", token: str = None, lock: bool = False, + inline_array: bool = False, **chunks_kwargs: Any, ) -> Dataset: """Coerce all arrays in this dataset into dask arrays with the given @@ -2019,6 +2021,9 @@ def chunk( lock : optional Passed on to :py:func:`dask.array.from_array`, if the array is not already as dask array. + inline_array: optional + Passed on to :py:func:`dask.array.from_array`, if the array is not + already as dask array. **chunks_kwargs : {dim: chunks, ...}, optional The keyword arguments form of ``chunks``. 
One of chunks or chunks_kwargs must be provided @@ -2032,6 +2037,7 @@ def chunk( Dataset.chunks Dataset.chunksizes xarray.unify_chunks + dask.array.from_array """ if chunks is None and chunks_kwargs is None: warnings.warn( diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 82a567b2c2a..1e684a72984 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -1023,6 +1023,7 @@ def chunk( ) = {}, name: str = None, lock: bool = False, + inline_array: bool = False, **chunks_kwargs: Any, ) -> Variable: """Coerce this array's data into a dask array with the given chunks. @@ -1046,6 +1047,9 @@ def chunk( lock : optional Passed on to :py:func:`dask.array.from_array`, if the array is not already as dask array. + inline_array: optional + Passed on to :py:func:`dask.array.from_array`, if the array is not + already as dask array. **chunks_kwargs : {dim: chunks, ...}, optional The keyword arguments form of ``chunks``. One of chunks or chunks_kwargs must be provided. @@ -1053,6 +1057,13 @@ def chunk( Returns ------- chunked : xarray.Variable + + See Also + -------- + Variable.chunks + Variable.chunksizes + xarray.unify_chunks + dask.array.from_array """ import dask.array as da @@ -1098,7 +1109,9 @@ def chunk( if utils.is_dict_like(chunks): chunks = tuple(chunks.get(n, s) for n, s in enumerate(self.shape)) - data = da.from_array(data, chunks, name=name, lock=lock, **kwargs) + data = da.from_array( + data, chunks, name=name, lock=lock, inline_array=inline_array, **kwargs + ) return self._replace(data=data) @@ -2710,7 +2723,7 @@ def values(self, values): f"Please use DataArray.assign_coords, Dataset.assign_coords or Dataset.assign as appropriate." ) - def chunk(self, chunks={}, name=None, lock=False): + def chunk(self, chunks={}, name=None, lock=False, inline_array=False): # Dummy - do not chunk. This method is invoked e.g. 
by Dataset.chunk() return self.copy(deep=False) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 16f630436e8..6b6f6e462bd 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -3840,6 +3840,27 @@ def test_load_dataarray(self): # load_dataarray ds.to_netcdf(tmp) + @pytest.mark.skipif( + ON_WINDOWS, + reason="counting number of tasks in graph fails on windows for some reason", + ) + def test_inline_array(self): + with create_tmp_file() as tmp: + original = Dataset({"foo": ("x", np.random.randn(10))}) + original.to_netcdf(tmp) + chunks = {"time": 10} + + def num_graph_nodes(obj): + return len(obj.__dask_graph__()) + + not_inlined = open_dataset(tmp, inline_array=False, chunks=chunks) + inlined = open_dataset(tmp, inline_array=True, chunks=chunks) + assert num_graph_nodes(inlined) < num_graph_nodes(not_inlined) + + not_inlined = open_dataarray(tmp, inline_array=False, chunks=chunks) + inlined = open_dataarray(tmp, inline_array=True, chunks=chunks) + assert num_graph_nodes(inlined) < num_graph_nodes(not_inlined) + @requires_scipy_or_netCDF4 @requires_pydap From 6bb2b855498b5c68d7cca8cceb710365d58e6048 Mon Sep 17 00:00:00 2001 From: Brewster Malevich Date: Wed, 11 May 2022 15:06:06 -0700 Subject: [PATCH 19/24] Minor Dataset.map docstr clarification (#6595) The summary line in the docstr for `Dataset.map()` says that it operates on variables in the dataset. This is a very minor update clarifying that it operates on *data* variables, as opposed to all variables (data + coord) in the Dataset. 
--- xarray/core/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index d255cfacbd1..b73cd797a8f 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -5308,7 +5308,7 @@ def map( args: Iterable[Any] = (), **kwargs: Any, ) -> Dataset: - """Apply a function to each variable in this dataset + """Apply a function to each data variable in this dataset Parameters ---------- From fc282d5979473a31529f09204d4811cfd7e5cd63 Mon Sep 17 00:00:00 2001 From: Mick <43316012+headtr1ck@users.noreply.github.com> Date: Thu, 12 May 2022 17:43:28 +0200 Subject: [PATCH 20/24] re-add timedelta support for polyval (#6599) --- xarray/core/computation.py | 6 +++++- xarray/tests/test_computation.py | 8 ++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 823cbe02560..8d450cceef9 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1933,7 +1933,8 @@ def _ensure_numeric(data: T_Xarray) -> T_Xarray: from .dataset import Dataset def to_floatable(x: DataArray) -> DataArray: - if x.dtype.kind in "mM": + if x.dtype.kind == "M": + # datetimes return x.copy( data=datetime_to_numeric( x.data, @@ -1941,6 +1942,9 @@ def to_floatable(x: DataArray) -> DataArray: datetime_unit="ns", ), ) + elif x.dtype.kind == "m": + # timedeltas + return x.astype(float) return x if isinstance(data, Dataset): diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index 737ed82bc05..b8aa05c75e7 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -2010,6 +2010,14 @@ def test_where_attrs() -> None: ), id="datetime", ), + pytest.param( + xr.DataArray( + np.array([1000, 2000, 3000], dtype="timedelta64[ns]"), dims="x" + ), + xr.DataArray([0, 1], dims="degree", coords={"degree": [0, 1]}), + xr.DataArray([1000.0, 2000.0, 3000.0], dims="x"), + id="timedelta", + ), ], ) def 
test_polyval( From c34ef8a60227720724e90aa11a6266c0026a812a Mon Sep 17 00:00:00 2001 From: Mick <43316012+headtr1ck@users.noreply.github.com> Date: Thu, 12 May 2022 21:01:58 +0200 Subject: [PATCH 21/24] change polyval dim ordering (#6601) --- xarray/core/computation.py | 2 +- xarray/tests/test_computation.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 8d450cceef9..81b5e3fd915 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1909,7 +1909,7 @@ def polyval( # using Horner's method # https://en.wikipedia.org/wiki/Horner%27s_method - res = coeffs.isel({degree_dim: max_deg}, drop=True) + zeros_like(coord) + res = zeros_like(coord) + coeffs.isel({degree_dim: max_deg}, drop=True) for deg in range(max_deg - 1, -1, -1): res *= coord res += coeffs.isel({degree_dim: deg}, drop=True) diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index b8aa05c75e7..ec8b5a5bc7c 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -1951,7 +1951,7 @@ def test_where_attrs() -> None: xr.DataArray( [[0, 1], [0, 1]], dims=("y", "degree"), coords={"degree": [0, 1]} ), - xr.DataArray([[1, 2, 3], [1, 2, 3]], dims=("y", "x")), + xr.DataArray([[1, 1], [2, 2], [3, 3]], dims=("x", "y")), id="broadcast-x", ), pytest.param( From 9a62c2a8ebf934646b898a137fe0409fe8781350 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 15 May 2022 01:06:43 +0200 Subject: [PATCH 22/24] Add setuptools as dependency in ASV benchmark CI (#6609) * Test adding setuptools to required install * Update asv.conf.json * Update asv.conf.json --- asv_bench/asv.conf.json | 2 ++ 1 file changed, 2 insertions(+) diff --git a/asv_bench/asv.conf.json b/asv_bench/asv.conf.json index 3e4137cf807..c9ddbd94b69 100644 --- a/asv_bench/asv.conf.json +++ b/asv_bench/asv.conf.json @@ -58,6 +58,8 @@ // "pip+emcee": [""], // 
emcee is only available for install with pip. // }, "matrix": { + "setuptools_scm[toml]": [""], // GH6609 + "setuptools_scm_git_archive": [""], // GH6609 "numpy": [""], "pandas": [""], "netcdf4": [""], From e02b1c3f6d18c7afcdf4f78cf3463652b4cc96c9 Mon Sep 17 00:00:00 2001 From: dcherian Date: Sat, 14 May 2022 20:30:58 -0600 Subject: [PATCH 23/24] Enable flox in GroupBy and resample (#5734) Closes #5734 Closes #4473 Closes #4498 Closes #659 Closes #2237 xref https://github.com/pangeo-data/pangeo/issues/271 Co-authored-by: Anderson Banihirwe Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> Co-authored-by: Mathias Hauser Co-authored-by: Stephan Hoyer --- asv_bench/asv.conf.json | 2 + asv_bench/benchmarks/groupby.py | 10 +- ci/install-upstream-wheels.sh | 2 + ci/requirements/all-but-dask.yml | 1 + ci/requirements/environment-windows.yml | 1 + ci/requirements/environment.yml | 1 + ci/requirements/min-all-deps.yml | 1 + doc/whats-new.rst | 5 + setup.cfg | 1 + xarray/core/_reductions.py | 1075 ++++++++++++++++------- xarray/core/groupby.py | 126 ++- xarray/core/options.py | 6 + xarray/core/resample.py | 33 +- xarray/core/utils.py | 18 + xarray/tests/__init__.py | 2 + xarray/tests/test_groupby.py | 116 ++- xarray/tests/test_units.py | 8 +- xarray/util/generate_reductions.py | 130 ++- xarray/util/print_versions.py | 2 + 19 files changed, 1185 insertions(+), 355 deletions(-) diff --git a/asv_bench/asv.conf.json b/asv_bench/asv.conf.json index c9ddbd94b69..5de6d6a4f76 100644 --- a/asv_bench/asv.conf.json +++ b/asv_bench/asv.conf.json @@ -67,6 +67,8 @@ "bottleneck": [""], "dask": [""], "distributed": [""], + "flox": [""], + "numpy_groupies": [""], "sparse": [""] }, diff --git a/asv_bench/benchmarks/groupby.py b/asv_bench/benchmarks/groupby.py index fa93ce9e8b5..490c2ccbd4c 100644 --- a/asv_bench/benchmarks/groupby.py +++ b/asv_bench/benchmarks/groupby.py @@ -13,6 
+13,7 @@ def setup(self, *args, **kwargs): { "a": xr.DataArray(np.r_[np.repeat(1, self.n), np.repeat(2, self.n)]), "b": xr.DataArray(np.arange(2 * self.n)), + "c": xr.DataArray(np.arange(2 * self.n)), } ) self.ds2d = self.ds1d.expand_dims(z=10) @@ -50,10 +51,11 @@ class GroupByDask(GroupBy): def setup(self, *args, **kwargs): requires_dask() super().setup(**kwargs) - self.ds1d = self.ds1d.sel(dim_0=slice(None, None, 2)).chunk({"dim_0": 50}) - self.ds2d = self.ds2d.sel(dim_0=slice(None, None, 2)).chunk( - {"dim_0": 50, "z": 5} - ) + + self.ds1d = self.ds1d.sel(dim_0=slice(None, None, 2)) + self.ds1d["c"] = self.ds1d["c"].chunk({"dim_0": 50}) + self.ds2d = self.ds2d.sel(dim_0=slice(None, None, 2)) + self.ds2d["c"] = self.ds2d["c"].chunk({"dim_0": 50, "z": 5}) self.ds1d_mean = self.ds1d.groupby("b").mean() self.ds2d_mean = self.ds2d.groupby("b").mean() diff --git a/ci/install-upstream-wheels.sh b/ci/install-upstream-wheels.sh index 96a39ccd20b..ff5615c17c6 100755 --- a/ci/install-upstream-wheels.sh +++ b/ci/install-upstream-wheels.sh @@ -15,6 +15,7 @@ conda uninstall -y --force \ pint \ bottleneck \ sparse \ + flox \ h5netcdf \ xarray # to limit the runtime of Upstream CI @@ -47,4 +48,5 @@ python -m pip install \ git+https://github.com/pydata/sparse \ git+https://github.com/intake/filesystem_spec \ git+https://github.com/SciTools/nc-time-axis \ + git+https://github.com/dcherian/flox \ git+https://github.com/h5netcdf/h5netcdf diff --git a/ci/requirements/all-but-dask.yml b/ci/requirements/all-but-dask.yml index cb9ec8d3bc5..e20ec2016ed 100644 --- a/ci/requirements/all-but-dask.yml +++ b/ci/requirements/all-but-dask.yml @@ -13,6 +13,7 @@ dependencies: - cfgrib - cftime - coveralls + - flox - h5netcdf - h5py - hdf5 diff --git a/ci/requirements/environment-windows.yml b/ci/requirements/environment-windows.yml index 6c389c22ce6..634140fe84b 100644 --- a/ci/requirements/environment-windows.yml +++ b/ci/requirements/environment-windows.yml @@ -10,6 +10,7 @@ dependencies: - 
cftime - dask-core - distributed + - flox - fsspec!=2021.7.0 - h5netcdf - h5py diff --git a/ci/requirements/environment.yml b/ci/requirements/environment.yml index 516c964afc7..d37bb7dc44a 100644 --- a/ci/requirements/environment.yml +++ b/ci/requirements/environment.yml @@ -12,6 +12,7 @@ dependencies: - cftime - dask-core - distributed + - flox - fsspec!=2021.7.0 - h5netcdf - h5py diff --git a/ci/requirements/min-all-deps.yml b/ci/requirements/min-all-deps.yml index ecabde06622..34879af730b 100644 --- a/ci/requirements/min-all-deps.yml +++ b/ci/requirements/min-all-deps.yml @@ -17,6 +17,7 @@ dependencies: - coveralls - dask-core=2021.04 - distributed=2021.04 + - flox=0.5 - h5netcdf=0.11 - h5py=3.1 # hdf5 1.12 conflicts with h5py=3.1 diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 6cba8563ecd..680c8219a38 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -141,6 +141,11 @@ Performance - GroupBy binary operations are now vectorized. Previously this involved looping over all groups. (:issue:`5804`,:pull:`6160`) By `Deepak Cherian `_. +- Substantially improved GroupBy operations using `flox `_. + This is auto-enabled when ``flox`` is installed. Use ``xr.set_options(use_flox=False)`` to use + the old algorithm. (:issue:`4473`, :issue:`4498`, :issue:`659`, :issue:`2237`, :pull:`271`). + By `Deepak Cherian `_,`Anderson Banihirwe `_, + `Jimmy Westling `_. Internal Changes ~~~~~~~~~~~~~~~~ diff --git a/setup.cfg b/setup.cfg index 6a0a06d2367..f5dd4dde810 100644 --- a/setup.cfg +++ b/setup.cfg @@ -98,6 +98,7 @@ accel = scipy bottleneck numbagg + flox parallel = dask[complete] diff --git a/xarray/core/_reductions.py b/xarray/core/_reductions.py index 31365f39e65..d782363760a 100644 --- a/xarray/core/_reductions.py +++ b/xarray/core/_reductions.py @@ -4,11 +4,18 @@ from typing import TYPE_CHECKING, Any, Callable, Hashable, Optional, Sequence, Union from . 
import duck_array_ops +from .options import OPTIONS +from .utils import contains_only_dask_or_numpy if TYPE_CHECKING: from .dataarray import DataArray from .dataset import Dataset +try: + import flox +except ImportError: + flox = None # type: ignore + class DatasetReductions: __slots__ = () @@ -1941,7 +1948,7 @@ def median( class DatasetGroupByReductions: - __slots__ = () + _obj: "Dataset" def reduce( self, @@ -1955,6 +1962,13 @@ def reduce( ) -> "Dataset": raise NotImplementedError() + def _flox_reduce( + self, + dim: Union[None, Hashable, Sequence[Hashable]], + **kwargs, + ) -> "Dataset": + raise NotImplementedError() + def count( self, dim: Union[None, Hashable, Sequence[Hashable]] = None, @@ -2021,13 +2035,23 @@ def count( Data variables: da (labels) int64 1 2 2 """ - return self.reduce( - duck_array_ops.count, - dim=dim, - numeric_only=False, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="count", + dim=dim, + numeric_only=False, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.count, + dim=dim, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) def all( self, @@ -2095,13 +2119,23 @@ def all( Data variables: da (labels) bool False True True """ - return self.reduce( - duck_array_ops.array_all, - dim=dim, - numeric_only=False, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="all", + dim=dim, + numeric_only=False, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.array_all, + dim=dim, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) def any( self, @@ -2169,13 +2203,23 @@ def any( Data variables: da (labels) bool True True True """ - return self.reduce( - duck_array_ops.array_any, - dim=dim, - 
numeric_only=False, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="any", + dim=dim, + numeric_only=False, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.array_any, + dim=dim, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) def max( self, @@ -2259,14 +2303,25 @@ def max( Data variables: da (labels) float64 nan 2.0 3.0 """ - return self.reduce( - duck_array_ops.max, - dim=dim, - skipna=skipna, - numeric_only=False, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="max", + dim=dim, + skipna=skipna, + numeric_only=False, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.max, + dim=dim, + skipna=skipna, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) def min( self, @@ -2350,14 +2405,25 @@ def min( Data variables: da (labels) float64 nan 2.0 1.0 """ - return self.reduce( - duck_array_ops.min, - dim=dim, - skipna=skipna, - numeric_only=False, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="min", + dim=dim, + skipna=skipna, + numeric_only=False, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.min, + dim=dim, + skipna=skipna, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) def mean( self, @@ -2445,14 +2511,25 @@ def mean( Data variables: da (labels) float64 nan 2.0 2.0 """ - return self.reduce( - duck_array_ops.mean, - dim=dim, - skipna=skipna, - numeric_only=True, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + 
func="mean", + dim=dim, + skipna=skipna, + numeric_only=True, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.mean, + dim=dim, + skipna=skipna, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) def prod( self, @@ -2557,15 +2634,27 @@ def prod( Data variables: da (labels) float64 nan 4.0 3.0 """ - return self.reduce( - duck_array_ops.prod, - dim=dim, - skipna=skipna, - min_count=min_count, - numeric_only=True, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="prod", + dim=dim, + skipna=skipna, + min_count=min_count, + numeric_only=True, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.prod, + dim=dim, + skipna=skipna, + min_count=min_count, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) def sum( self, @@ -2670,15 +2759,27 @@ def sum( Data variables: da (labels) float64 nan 4.0 4.0 """ - return self.reduce( - duck_array_ops.sum, - dim=dim, - skipna=skipna, - min_count=min_count, - numeric_only=True, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="sum", + dim=dim, + skipna=skipna, + min_count=min_count, + numeric_only=True, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.sum, + dim=dim, + skipna=skipna, + min_count=min_count, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) def std( self, @@ -2780,15 +2881,27 @@ def std( Data variables: da (labels) float64 nan 0.0 1.414 """ - return self.reduce( - duck_array_ops.std, - dim=dim, - skipna=skipna, - ddof=ddof, - numeric_only=True, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + 
func="std", + dim=dim, + skipna=skipna, + ddof=ddof, + numeric_only=True, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.std, + dim=dim, + skipna=skipna, + ddof=ddof, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) def var( self, @@ -2890,15 +3003,27 @@ def var( Data variables: da (labels) float64 nan 0.0 2.0 """ - return self.reduce( - duck_array_ops.var, - dim=dim, - skipna=skipna, - ddof=ddof, - numeric_only=True, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="var", + dim=dim, + skipna=skipna, + ddof=ddof, + numeric_only=True, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.var, + dim=dim, + skipna=skipna, + ddof=ddof, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) def median( self, @@ -2997,7 +3122,7 @@ def median( class DatasetResampleReductions: - __slots__ = () + _obj: "Dataset" def reduce( self, @@ -3011,6 +3136,13 @@ def reduce( ) -> "Dataset": raise NotImplementedError() + def _flox_reduce( + self, + dim: Union[None, Hashable, Sequence[Hashable]], + **kwargs, + ) -> "Dataset": + raise NotImplementedError() + def count( self, dim: Union[None, Hashable, Sequence[Hashable]] = None, @@ -3077,13 +3209,23 @@ def count( Data variables: da (time) int64 1 3 1 """ - return self.reduce( - duck_array_ops.count, - dim=dim, - numeric_only=False, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="count", + dim=dim, + numeric_only=False, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.count, + dim=dim, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) def all( self, @@ -3151,13 +3293,23 @@ def all( Data variables: da (time) 
bool True True False """ - return self.reduce( - duck_array_ops.array_all, - dim=dim, - numeric_only=False, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="all", + dim=dim, + numeric_only=False, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.array_all, + dim=dim, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) def any( self, @@ -3225,13 +3377,23 @@ def any( Data variables: da (time) bool True True True """ - return self.reduce( - duck_array_ops.array_any, - dim=dim, - numeric_only=False, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="any", + dim=dim, + numeric_only=False, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.array_any, + dim=dim, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) def max( self, @@ -3315,14 +3477,25 @@ def max( Data variables: da (time) float64 1.0 3.0 nan """ - return self.reduce( - duck_array_ops.max, - dim=dim, - skipna=skipna, - numeric_only=False, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="max", + dim=dim, + skipna=skipna, + numeric_only=False, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.max, + dim=dim, + skipna=skipna, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) def min( self, @@ -3406,14 +3579,25 @@ def min( Data variables: da (time) float64 1.0 1.0 nan """ - return self.reduce( - duck_array_ops.min, - dim=dim, - skipna=skipna, - numeric_only=False, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and 
contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="min", + dim=dim, + skipna=skipna, + numeric_only=False, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.min, + dim=dim, + skipna=skipna, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) def mean( self, @@ -3501,14 +3685,25 @@ def mean( Data variables: da (time) float64 1.0 2.0 nan """ - return self.reduce( - duck_array_ops.mean, - dim=dim, - skipna=skipna, - numeric_only=True, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="mean", + dim=dim, + skipna=skipna, + numeric_only=True, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.mean, + dim=dim, + skipna=skipna, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) def prod( self, @@ -3613,15 +3808,27 @@ def prod( Data variables: da (time) float64 nan 6.0 nan """ - return self.reduce( - duck_array_ops.prod, - dim=dim, - skipna=skipna, - min_count=min_count, - numeric_only=True, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="prod", + dim=dim, + skipna=skipna, + min_count=min_count, + numeric_only=True, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.prod, + dim=dim, + skipna=skipna, + min_count=min_count, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) def sum( self, @@ -3726,15 +3933,27 @@ def sum( Data variables: da (time) float64 nan 6.0 nan """ - return self.reduce( - duck_array_ops.sum, - dim=dim, - skipna=skipna, - min_count=min_count, - numeric_only=True, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return 
self._flox_reduce( + func="sum", + dim=dim, + skipna=skipna, + min_count=min_count, + numeric_only=True, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.sum, + dim=dim, + skipna=skipna, + min_count=min_count, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) def std( self, @@ -3836,15 +4055,27 @@ def std( Data variables: da (time) float64 nan 1.0 nan """ - return self.reduce( - duck_array_ops.std, - dim=dim, - skipna=skipna, - ddof=ddof, - numeric_only=True, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="std", + dim=dim, + skipna=skipna, + ddof=ddof, + numeric_only=True, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.std, + dim=dim, + skipna=skipna, + ddof=ddof, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) def var( self, @@ -3946,15 +4177,27 @@ def var( Data variables: da (time) float64 nan 1.0 nan """ - return self.reduce( - duck_array_ops.var, - dim=dim, - skipna=skipna, - ddof=ddof, - numeric_only=True, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="var", + dim=dim, + skipna=skipna, + ddof=ddof, + numeric_only=True, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.var, + dim=dim, + skipna=skipna, + ddof=ddof, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) def median( self, @@ -4053,7 +4296,7 @@ def median( class DataArrayGroupByReductions: - __slots__ = () + _obj: "DataArray" def reduce( self, @@ -4067,6 +4310,13 @@ def reduce( ) -> "DataArray": raise NotImplementedError() + def _flox_reduce( + self, + dim: Union[None, Hashable, Sequence[Hashable]], + **kwargs, + ) -> "DataArray": + raise NotImplementedError() 
+ def count( self, dim: Union[None, Hashable, Sequence[Hashable]] = None, @@ -4128,12 +4378,21 @@ def count( Coordinates: * labels (labels) object 'a' 'b' 'c' """ - return self.reduce( - duck_array_ops.count, - dim=dim, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="count", + dim=dim, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.count, + dim=dim, + keep_attrs=keep_attrs, + **kwargs, + ) def all( self, @@ -4196,12 +4455,21 @@ def all( Coordinates: * labels (labels) object 'a' 'b' 'c' """ - return self.reduce( - duck_array_ops.array_all, - dim=dim, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="all", + dim=dim, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.array_all, + dim=dim, + keep_attrs=keep_attrs, + **kwargs, + ) def any( self, @@ -4264,12 +4532,21 @@ def any( Coordinates: * labels (labels) object 'a' 'b' 'c' """ - return self.reduce( - duck_array_ops.array_any, - dim=dim, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="any", + dim=dim, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.array_any, + dim=dim, + keep_attrs=keep_attrs, + **kwargs, + ) def max( self, @@ -4346,13 +4623,23 @@ def max( Coordinates: * labels (labels) object 'a' 'b' 'c' """ - return self.reduce( - duck_array_ops.max, - dim=dim, - skipna=skipna, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="max", + dim=dim, + skipna=skipna, + # fill_value=fill_value, + 
keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.max, + dim=dim, + skipna=skipna, + keep_attrs=keep_attrs, + **kwargs, + ) def min( self, @@ -4429,13 +4716,23 @@ def min( Coordinates: * labels (labels) object 'a' 'b' 'c' """ - return self.reduce( - duck_array_ops.min, - dim=dim, - skipna=skipna, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="min", + dim=dim, + skipna=skipna, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.min, + dim=dim, + skipna=skipna, + keep_attrs=keep_attrs, + **kwargs, + ) def mean( self, @@ -4516,13 +4813,23 @@ def mean( Coordinates: * labels (labels) object 'a' 'b' 'c' """ - return self.reduce( - duck_array_ops.mean, - dim=dim, - skipna=skipna, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="mean", + dim=dim, + skipna=skipna, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.mean, + dim=dim, + skipna=skipna, + keep_attrs=keep_attrs, + **kwargs, + ) def prod( self, @@ -4618,14 +4925,25 @@ def prod( Coordinates: * labels (labels) object 'a' 'b' 'c' """ - return self.reduce( - duck_array_ops.prod, - dim=dim, - skipna=skipna, - min_count=min_count, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="prod", + dim=dim, + skipna=skipna, + min_count=min_count, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.prod, + dim=dim, + skipna=skipna, + min_count=min_count, + keep_attrs=keep_attrs, + **kwargs, + ) def sum( self, @@ -4721,14 +5039,25 @@ def sum( Coordinates: * labels (labels) object 'a' 'b' 'c' """ - 
return self.reduce( - duck_array_ops.sum, - dim=dim, - skipna=skipna, - min_count=min_count, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="sum", + dim=dim, + skipna=skipna, + min_count=min_count, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.sum, + dim=dim, + skipna=skipna, + min_count=min_count, + keep_attrs=keep_attrs, + **kwargs, + ) def std( self, @@ -4821,14 +5150,25 @@ def std( Coordinates: * labels (labels) object 'a' 'b' 'c' """ - return self.reduce( - duck_array_ops.std, - dim=dim, - skipna=skipna, - ddof=ddof, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="std", + dim=dim, + skipna=skipna, + ddof=ddof, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.std, + dim=dim, + skipna=skipna, + ddof=ddof, + keep_attrs=keep_attrs, + **kwargs, + ) def var( self, @@ -4921,14 +5261,25 @@ def var( Coordinates: * labels (labels) object 'a' 'b' 'c' """ - return self.reduce( - duck_array_ops.var, - dim=dim, - skipna=skipna, - ddof=ddof, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="var", + dim=dim, + skipna=skipna, + ddof=ddof, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.var, + dim=dim, + skipna=skipna, + ddof=ddof, + keep_attrs=keep_attrs, + **kwargs, + ) def median( self, @@ -5019,7 +5370,7 @@ def median( class DataArrayResampleReductions: - __slots__ = () + _obj: "DataArray" def reduce( self, @@ -5033,6 +5384,13 @@ def reduce( ) -> "DataArray": raise NotImplementedError() + def _flox_reduce( + self, + dim: Union[None, Hashable, Sequence[Hashable]], 
+ **kwargs, + ) -> "DataArray": + raise NotImplementedError() + def count( self, dim: Union[None, Hashable, Sequence[Hashable]] = None, @@ -5094,12 +5452,21 @@ def count( Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 """ - return self.reduce( - duck_array_ops.count, - dim=dim, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="count", + dim=dim, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.count, + dim=dim, + keep_attrs=keep_attrs, + **kwargs, + ) def all( self, @@ -5162,12 +5529,21 @@ def all( Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 """ - return self.reduce( - duck_array_ops.array_all, - dim=dim, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="all", + dim=dim, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.array_all, + dim=dim, + keep_attrs=keep_attrs, + **kwargs, + ) def any( self, @@ -5230,12 +5606,21 @@ def any( Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 """ - return self.reduce( - duck_array_ops.array_any, - dim=dim, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="any", + dim=dim, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.array_any, + dim=dim, + keep_attrs=keep_attrs, + **kwargs, + ) def max( self, @@ -5312,13 +5697,23 @@ def max( Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 """ - return self.reduce( - duck_array_ops.max, - dim=dim, - skipna=skipna, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and 
OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="max", + dim=dim, + skipna=skipna, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.max, + dim=dim, + skipna=skipna, + keep_attrs=keep_attrs, + **kwargs, + ) def min( self, @@ -5395,13 +5790,23 @@ def min( Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 """ - return self.reduce( - duck_array_ops.min, - dim=dim, - skipna=skipna, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="min", + dim=dim, + skipna=skipna, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.min, + dim=dim, + skipna=skipna, + keep_attrs=keep_attrs, + **kwargs, + ) def mean( self, @@ -5482,13 +5887,23 @@ def mean( Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 """ - return self.reduce( - duck_array_ops.mean, - dim=dim, - skipna=skipna, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="mean", + dim=dim, + skipna=skipna, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.mean, + dim=dim, + skipna=skipna, + keep_attrs=keep_attrs, + **kwargs, + ) def prod( self, @@ -5584,14 +5999,25 @@ def prod( Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 """ - return self.reduce( - duck_array_ops.prod, - dim=dim, - skipna=skipna, - min_count=min_count, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="prod", + dim=dim, + skipna=skipna, + min_count=min_count, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + 
return self.reduce( + duck_array_ops.prod, + dim=dim, + skipna=skipna, + min_count=min_count, + keep_attrs=keep_attrs, + **kwargs, + ) def sum( self, @@ -5687,14 +6113,25 @@ def sum( Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 """ - return self.reduce( - duck_array_ops.sum, - dim=dim, - skipna=skipna, - min_count=min_count, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="sum", + dim=dim, + skipna=skipna, + min_count=min_count, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.sum, + dim=dim, + skipna=skipna, + min_count=min_count, + keep_attrs=keep_attrs, + **kwargs, + ) def std( self, @@ -5787,14 +6224,25 @@ def std( Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 """ - return self.reduce( - duck_array_ops.std, - dim=dim, - skipna=skipna, - ddof=ddof, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="std", + dim=dim, + skipna=skipna, + ddof=ddof, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.std, + dim=dim, + skipna=skipna, + ddof=ddof, + keep_attrs=keep_attrs, + **kwargs, + ) def var( self, @@ -5887,14 +6335,25 @@ def var( Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 """ - return self.reduce( - duck_array_ops.var, - dim=dim, - skipna=skipna, - ddof=ddof, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="var", + dim=dim, + skipna=skipna, + ddof=ddof, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.var, + dim=dim, + skipna=skipna, + ddof=ddof, + keep_attrs=keep_attrs, + 
**kwargs, + ) def median( self, diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 151ef844f44..fec8954c9e2 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -264,6 +264,10 @@ class GroupBy: "_stacked_dim", "_unique_coord", "_dims", + "_squeeze", + # Save unstacked object for flox + "_original_obj", + "_unstacked_group", "_bins", ) @@ -326,6 +330,10 @@ def __init__( if getattr(group, "name", None) is None: group.name = "group" + self._original_obj = obj + self._unstacked_group = group + self._bins = bins + group, obj, stacked_dim, inserted_dims = _ensure_1d(group, obj) (group_dim,) = group.dims @@ -342,7 +350,7 @@ def __init__( if bins is not None: if duck_array_ops.isnull(bins).all(): raise ValueError("All bin edges are NaN.") - binned = pd.cut(group.values, bins, **cut_kwargs) + binned, bins = pd.cut(group.values, bins, **cut_kwargs, retbins=True) new_dim_name = group.name + "_bins" group = DataArray(binned, group.coords, name=new_dim_name) full_index = binned.categories @@ -403,6 +411,7 @@ def __init__( self._full_index = full_index self._restore_coord_dims = restore_coord_dims self._bins = bins + self._squeeze = squeeze # cached attributes self._groups = None @@ -570,6 +579,121 @@ def _maybe_unstack(self, obj): obj._indexes = filter_indexes_from_coords(obj._indexes, set(obj.coords)) return obj + def _flox_reduce(self, dim, **kwargs): + """Adaptor function that translates our groupby API to that of flox.""" + from flox.xarray import xarray_reduce + + from .dataset import Dataset + + obj = self._original_obj + + # preserve current strategy (approximately) for dask groupby. 
+ # We want to control the default anyway to prevent surprises + # if flox decides to change its default + kwargs.setdefault("method", "split-reduce") + + numeric_only = kwargs.pop("numeric_only", None) + if numeric_only: + non_numeric = { + name: var + for name, var in obj.data_vars.items() + if not (np.issubdtype(var.dtype, np.number) or (var.dtype == np.bool_)) + } + else: + non_numeric = {} + + # weird backcompat + # reducing along a unique indexed dimension with squeeze=True + # should raise an error + if ( + dim is None or dim == self._group.name + ) and self._group.name in obj.xindexes: + index = obj.indexes[self._group.name] + if index.is_unique and self._squeeze: + raise ValueError(f"cannot reduce over dimensions {self._group.name!r}") + + # group is only passed by resample + group = kwargs.pop("group", None) + if group is None: + if isinstance(self._unstacked_group, _DummyGroup): + group = self._unstacked_group.name + else: + group = self._unstacked_group + + unindexed_dims = tuple() + if isinstance(group, str): + if group in obj.dims and group not in obj._indexes and self._bins is None: + unindexed_dims = (group,) + group = self._original_obj[group] + + if isinstance(dim, str): + dim = (dim,) + elif dim is None: + dim = group.dims + elif dim is Ellipsis: + dim = tuple(self._original_obj.dims) + + # Do this so we raise the same error message whether flox is present or not. + # Better to control it here than in flox. + if any(d not in group.dims and d not in self._original_obj.dims for d in dim): + raise ValueError(f"cannot reduce over dimensions {dim}.") + + if self._bins is not None: + # TODO: fix this; When binning by time, self._bins is a DatetimeIndex + expected_groups = (np.array(self._bins),) + isbin = (True,) + # This is an annoying hack. Xarray returns np.nan + # when there are no observations in a bin, instead of 0. + # We can fake that here by forcing min_count=1. 
+ if kwargs["func"] == "count": + if "fill_value" not in kwargs or kwargs["fill_value"] is None: + kwargs["fill_value"] = np.nan + # note min_count makes no sense in the xarray world + # as a kwarg for count, so this should be OK + kwargs["min_count"] = 1 + # empty bins have np.nan regardless of dtype + # flox's default would not set np.nan for integer dtypes + kwargs.setdefault("fill_value", np.nan) + else: + expected_groups = (self._unique_coord.values,) + isbin = False + + result = xarray_reduce( + self._original_obj.drop_vars(non_numeric), + group, + dim=dim, + expected_groups=expected_groups, + isbin=isbin, + **kwargs, + ) + + # Ignore error when the groupby reduction is effectively + # a reduction of the underlying dataset + result = result.drop_vars(unindexed_dims, errors="ignore") + + # broadcast and restore non-numeric data variables (backcompat) + for name, var in non_numeric.items(): + if all(d not in var.dims for d in dim): + result[name] = var.variable.set_dims( + (group.name,) + var.dims, (result.sizes[group.name],) + var.shape + ) + + if self._bins is not None: + # bins provided to flox are at full precision + # the bin edge labels have a default precision of 3 + # reassign to fix that. + new_coord = [ + pd.Interval(inter.left, inter.right) for inter in self._full_index + ] + result[self._group.name] = new_coord + # Fix dimension order when binning a dimension coordinate + # Needed as long as we do a separate code path for pint; + # For some reason Datasets and DataArrays behave differently! + if isinstance(self._obj, Dataset) and self._group_dim in self._obj.dims: + result = result.transpose(self._group.name, ...) + + return result + def fillna(self, value): """Fill missing values in this object by group. 
diff --git a/xarray/core/options.py b/xarray/core/options.py index 399afe90b66..d31f2577601 100644 --- a/xarray/core/options.py +++ b/xarray/core/options.py @@ -27,6 +27,7 @@ class T_Options(TypedDict): keep_attrs: Literal["default", True, False] warn_for_unclosed_files: bool use_bottleneck: bool + use_flox: bool OPTIONS: T_Options = { @@ -45,6 +46,7 @@ class T_Options(TypedDict): "file_cache_maxsize": 128, "keep_attrs": "default", "use_bottleneck": True, + "use_flox": True, "warn_for_unclosed_files": False, } @@ -70,6 +72,7 @@ def _positive_integer(value): "file_cache_maxsize": _positive_integer, "keep_attrs": lambda choice: choice in [True, False, "default"], "use_bottleneck": lambda value: isinstance(value, bool), + "use_flox": lambda value: isinstance(value, bool), "warn_for_unclosed_files": lambda value: isinstance(value, bool), } @@ -180,6 +183,9 @@ class set_options: use_bottleneck : bool, default: True Whether to use ``bottleneck`` to accelerate 1D reductions and 1D rolling reduction operations. + use_flox : bool, default: True + Whether to use ``numpy_groupies`` and ``flox`` to + accelerate groupby and resampling reductions. warn_for_unclosed_files : bool, default: False Whether or not to issue a warning when unclosed files are deallocated. This is mostly useful for debugging. diff --git a/xarray/core/resample.py b/xarray/core/resample.py index ed665ad4048..bcc4bfb90cd 100644 --- a/xarray/core/resample.py +++ b/xarray/core/resample.py @@ -1,13 +1,15 @@ import warnings from typing import Any, Callable, Hashable, Sequence, Union +import numpy as np + from ._reductions import DataArrayResampleReductions, DatasetResampleReductions -from .groupby import DataArrayGroupByBase, DatasetGroupByBase +from .groupby import DataArrayGroupByBase, DatasetGroupByBase, GroupBy RESAMPLE_DIM = "__resample_dim__" -class Resample: +class Resample(GroupBy): """An object that extends the `GroupBy` object with additional logic for handling specialized re-sampling operations. 
@@ -21,6 +23,29 @@ class Resample: """ + def _flox_reduce(self, dim, **kwargs): + + from .dataarray import DataArray + + kwargs.setdefault("method", "cohorts") + + # now create a label DataArray since resample doesn't do that somehow + repeats = [] + for slicer in self._group_indices: + stop = ( + slicer.stop + if slicer.stop is not None + else self._obj.sizes[self._group_dim] + ) + repeats.append(stop - slicer.start) + labels = np.repeat(self._unique_coord.data, repeats) + group = DataArray(labels, dims=(self._group_dim,), name=self._unique_coord.name) + + result = super()._flox_reduce(dim=dim, group=group, **kwargs) + result = self._maybe_restore_empty_groups(result) + result = result.rename({RESAMPLE_DIM: self._group_dim}) + return result + def _upsample(self, method, *args, **kwargs): """Dispatch function to call appropriate up-sampling methods on data. @@ -158,7 +183,7 @@ def _interpolate(self, kind="linear"): ) -class DataArrayResample(DataArrayGroupByBase, DataArrayResampleReductions, Resample): +class DataArrayResample(Resample, DataArrayGroupByBase, DataArrayResampleReductions): """DataArrayGroupBy object specialized to time resampling operations over a specified dimension """ @@ -249,7 +274,7 @@ def apply(self, func, args=(), shortcut=None, **kwargs): return self.map(func=func, shortcut=shortcut, args=args, **kwargs) -class DatasetResample(DatasetGroupByBase, DatasetResampleReductions, Resample): +class DatasetResample(Resample, DatasetGroupByBase, DatasetResampleReductions): """DatasetGroupBy object specialized to resampling a specified dimension""" def __init__(self, *args, dim=None, resample_dim=None, **kwargs): diff --git a/xarray/core/utils.py b/xarray/core/utils.py index aaa087a3532..eda08becc20 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -928,3 +928,21 @@ def iterate_nested(nested_list): yield from iterate_nested(item) else: yield item + + +def contains_only_dask_or_numpy(obj) -> bool: + """Returns True if xarray object contains 
only numpy or dask arrays. + + Expects obj to be Dataset or DataArray""" + from .dataarray import DataArray + from .pycompat import is_duck_dask_array + + if isinstance(obj, DataArray): + obj = obj._to_temp_dataset() + + return all( + [ + isinstance(var.data, np.ndarray) or is_duck_dask_array(var.data) + for var in obj.variables.values() + ] + ) diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 7872fec2e62..65f0bc08261 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -78,6 +78,8 @@ def _importorskip(modname, minversion=None): has_cartopy, requires_cartopy = _importorskip("cartopy") has_pint, requires_pint = _importorskip("pint") has_numexpr, requires_numexpr = _importorskip("numexpr") +has_flox, requires_flox = _importorskip("flox") + # some special cases has_scipy_or_netCDF4 = has_scipy or has_netCDF4 diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index b4b93d1dba3..8c745dc640d 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -17,6 +17,7 @@ assert_identical, create_test_data, requires_dask, + requires_flox, requires_scipy, ) @@ -24,7 +25,10 @@ @pytest.fixture def dataset(): ds = xr.Dataset( - {"foo": (("x", "y", "z"), np.random.randn(3, 4, 2))}, + { + "foo": (("x", "y", "z"), np.random.randn(3, 4, 2)), + "baz": ("x", ["e", "f", "g"]), + }, {"x": ["a", "b", "c"], "y": [1, 2, 3, 4], "z": [1, 2]}, ) ds["boo"] = (("z", "y"), [["f", "g", "h", "j"]] * 2) @@ -71,6 +75,15 @@ def test_multi_index_groupby_map(dataset) -> None: assert_equal(expected, actual) +def test_reduce_numeric_only(dataset) -> None: + gb = dataset.groupby("x", squeeze=False) + with xr.set_options(use_flox=False): + expected = gb.sum() + with xr.set_options(use_flox=True): + actual = gb.sum() + assert_identical(expected, actual) + + def test_multi_index_groupby_sum() -> None: # regression test for GH873 ds = xr.Dataset( @@ -961,6 +974,17 @@ def test_groupby_dataarray_map_dataset_func(): 
assert_identical(actual, expected) +@requires_flox +@pytest.mark.parametrize("kwargs", [{"method": "map-reduce"}, {"engine": "numpy"}]) +def test_groupby_flox_kwargs(kwargs): + ds = Dataset({"a": ("x", range(5))}, {"c": ("x", [0, 0, 1, 1, 1])}) + with xr.set_options(use_flox=False): + expected = ds.groupby("c").mean() + with xr.set_options(use_flox=True): + actual = ds.groupby("c").mean(**kwargs) + assert_identical(expected, actual) + + class TestDataArrayGroupBy: @pytest.fixture(autouse=True) def setup(self): @@ -1016,19 +1040,22 @@ def test_groupby_properties(self): assert_array_equal(expected_groups[key], grouped.groups[key]) assert 3 == len(grouped) - def test_groupby_map_identity(self): + @pytest.mark.parametrize( + "by, use_da", [("x", False), ("y", False), ("y", True), ("abc", False)] + ) + @pytest.mark.parametrize("shortcut", [True, False]) + @pytest.mark.parametrize("squeeze", [True, False]) + def test_groupby_map_identity(self, by, use_da, shortcut, squeeze) -> None: expected = self.da - idx = expected.coords["y"] + if use_da: + by = expected.coords[by] def identity(x): return x - for g in ["x", "y", "abc", idx]: - for shortcut in [False, True]: - for squeeze in [False, True]: - grouped = expected.groupby(g, squeeze=squeeze) - actual = grouped.map(identity, shortcut=shortcut) - assert_identical(expected, actual) + grouped = expected.groupby(by, squeeze=squeeze) + actual = grouped.map(identity, shortcut=shortcut) + assert_identical(expected, actual) def test_groupby_sum(self): array = self.da @@ -1083,19 +1110,21 @@ def test_groupby_sum(self): assert_allclose(expected_sum_axis1, grouped.reduce(np.sum, "y")) assert_allclose(expected_sum_axis1, grouped.sum("y")) - def test_groupby_sum_default(self): + @pytest.mark.parametrize("method", ["sum", "mean", "median"]) + def test_groupby_reductions(self, method): array = self.da grouped = array.groupby("abc") - expected_sum_all = Dataset( + reduction = getattr(np, method) + expected = Dataset( { "foo": Variable( 
["x", "abc"], np.array( [ - self.x[:, :9].sum(axis=-1), - self.x[:, 10:].sum(axis=-1), - self.x[:, 9:10].sum(axis=-1), + reduction(self.x[:, :9], axis=-1), + reduction(self.x[:, 10:], axis=-1), + reduction(self.x[:, 9:10], axis=-1), ] ).T, ), @@ -1103,7 +1132,14 @@ def test_groupby_sum_default(self): } )["foo"] - assert_allclose(expected_sum_all, grouped.sum(dim="y")) + with xr.set_options(use_flox=False): + actual_legacy = getattr(grouped, method)(dim="y") + + with xr.set_options(use_flox=True): + actual_npg = getattr(grouped, method)(dim="y") + + assert_allclose(expected, actual_legacy) + assert_allclose(expected, actual_npg) def test_groupby_count(self): array = DataArray( @@ -1318,13 +1354,23 @@ def test_groupby_bins(self): expected = DataArray( [1, 5], dims="dim_0_bins", coords={"dim_0_bins": bin_coords} ) - # the problem with this is that it overwrites the dimensions of array! - # actual = array.groupby('dim_0', bins=bins).sum() - actual = array.groupby_bins("dim_0", bins).map(lambda x: x.sum()) + actual = array.groupby_bins("dim_0", bins=bins).sum() + assert_identical(expected, actual) + + actual = array.groupby_bins("dim_0", bins=bins).map(lambda x: x.sum()) assert_identical(expected, actual) + # make sure original array dims are unchanged assert len(array.dim_0) == 4 + da = xr.DataArray(np.ones((2, 3, 4))) + bins = [-1, 0, 1, 2] + with xr.set_options(use_flox=False): + actual = da.groupby_bins("dim_0", bins).mean(...) + with xr.set_options(use_flox=True): + expected = da.groupby_bins("dim_0", bins).mean(...) 
+ assert_allclose(actual, expected) + def test_groupby_bins_empty(self): array = DataArray(np.arange(4), [("x", range(4))]) # one of these bins will be empty @@ -1350,6 +1396,27 @@ def test_groupby_bins_multidim(self): actual = array.groupby_bins("lat", bins).map(lambda x: x.sum()) assert_identical(expected, actual) + bins = [-2, -1, 0, 1, 2] + field = DataArray(np.ones((5, 3)), dims=("x", "y")) + by = DataArray( + np.array([[-1.5, -1.5, 0.5, 1.5, 1.5] * 3]).reshape(5, 3), dims=("x", "y") + ) + actual = field.groupby_bins(by, bins=bins).count() + + bincoord = np.array( + [ + pd.Interval(left, right, closed="right") + for left, right in zip(bins[:-1], bins[1:]) + ], + dtype=object, + ) + expected = DataArray( + np.array([6, np.nan, 3, 6]), + dims="group_bins", + coords={"group_bins": bincoord}, + ) + assert_identical(actual, expected) + def test_groupby_bins_sort(self): data = xr.DataArray( np.arange(100), dims="x", coords={"x": np.linspace(-100, 100, num=100)} @@ -1357,6 +1424,12 @@ def test_groupby_bins_sort(self): binned_mean = data.groupby_bins("x", bins=11).mean() assert binned_mean.to_index().is_monotonic_increasing + with xr.set_options(use_flox=True): + actual = data.groupby_bins("x", bins=11).count() + with xr.set_options(use_flox=False): + expected = data.groupby_bins("x", bins=11).count() + assert_identical(actual, expected) + def test_groupby_assign_coords(self): array = DataArray([1, 2, 3, 4], {"c": ("x", [0, 0, 1, 1])}, dims="x") @@ -1769,7 +1842,7 @@ def test_resample_min_count(self): ], dim=actual["time"], ) - assert_equal(expected, actual) + assert_allclose(expected, actual) def test_resample_by_mean_with_keep_attrs(self): times = pd.date_range("2000-01-01", freq="6H", periods=10) @@ -1903,7 +1976,7 @@ def test_resample_ds_da_are_the_same(self): "x": np.arange(5), } ) - assert_identical( + assert_allclose( ds.resample(time="M").mean()["foo"], ds.foo.resample(time="M").mean() ) @@ -1916,6 +1989,3 @@ def func(arg1, arg2, arg3=0.0): expected = 
xr.Dataset({"foo": ("time", [3.0, 3.0, 3.0]), "time": times}) actual = ds.resample(time="D").map(func, args=(1.0,), arg3=1.0) assert_identical(expected, actual) - - -# TODO: move other groupby tests from test_dataset and test_dataarray over here diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index c18b7d18c04..679733e1ecf 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -5344,8 +5344,12 @@ def test_computation_objects(self, func, variant, dtype): units = extract_units(ds) args = [] if func.name != "groupby" else ["y"] - expected = attach_units(func(strip_units(ds)).mean(*args), units) - actual = func(ds).mean(*args) + # Doesn't work with flox because pint doesn't implement + # ufunc.reduceat or np.bincount + # kwargs = {"engine": "numpy"} if "groupby" in func.name else {} + kwargs = {} + expected = attach_units(func(strip_units(ds)).mean(*args, **kwargs), units) + actual = func(ds).mean(*args, **kwargs) assert_units_equal(expected, actual) assert_allclose(expected, actual) diff --git a/xarray/util/generate_reductions.py b/xarray/util/generate_reductions.py index e79c94e8907..96b91c16906 100644 --- a/xarray/util/generate_reductions.py +++ b/xarray/util/generate_reductions.py @@ -23,13 +23,19 @@ from typing import TYPE_CHECKING, Any, Callable, Hashable, Optional, Sequence, Union from . 
import duck_array_ops +from .options import OPTIONS +from .utils import contains_only_dask_or_numpy if TYPE_CHECKING: from .dataarray import DataArray - from .dataset import Dataset''' + from .dataset import Dataset +try: + import flox +except ImportError: + flox = None # type: ignore''' -CLASS_PREAMBLE = """ +DEFAULT_PREAMBLE = """ class {obj}{cls}Reductions: __slots__ = () @@ -46,6 +52,54 @@ def reduce( ) -> "{obj}": raise NotImplementedError()""" +GROUPBY_PREAMBLE = """ + +class {obj}{cls}Reductions: + _obj: "{obj}" + + def reduce( + self, + func: Callable[..., Any], + dim: Union[None, Hashable, Sequence[Hashable]] = None, + *, + axis: Union[None, int, Sequence[int]] = None, + keep_attrs: bool = None, + keepdims: bool = False, + **kwargs: Any, + ) -> "{obj}": + raise NotImplementedError() + + def _flox_reduce( + self, + dim: Union[None, Hashable, Sequence[Hashable]], + **kwargs, + ) -> "{obj}": + raise NotImplementedError()""" + +RESAMPLE_PREAMBLE = """ + +class {obj}{cls}Reductions: + _obj: "{obj}" + + def reduce( + self, + func: Callable[..., Any], + dim: Union[None, Hashable, Sequence[Hashable]] = None, + *, + axis: Union[None, int, Sequence[int]] = None, + keep_attrs: bool = None, + keepdims: bool = False, + **kwargs: Any, + ) -> "{obj}": + raise NotImplementedError() + + def _flox_reduce( + self, + dim: Union[None, Hashable, Sequence[Hashable]], + **kwargs, + ) -> "{obj}": + raise NotImplementedError()""" + TEMPLATE_REDUCTION_SIGNATURE = ''' def {method}( self, @@ -113,11 +167,7 @@ def {method}( These could include dask-specific kwargs like ``split_every``.""" NAN_CUM_METHODS = ["cumsum", "cumprod"] - -NUMERIC_ONLY_METHODS = [ - "cumsum", - "cumprod", -] +NUMERIC_ONLY_METHODS = ["cumsum", "cumprod"] _NUMERIC_ONLY_NOTES = "Non-numeric variables will be removed prior to reducing." 
ExtraKwarg = collections.namedtuple("ExtraKwarg", "docs kwarg call example") @@ -182,6 +232,7 @@ def __init__( docref, docref_description, example_call_preamble, + definition_preamble, see_also_obj=None, ): self.datastructure = datastructure @@ -190,7 +241,7 @@ def __init__( self.docref = docref self.docref_description = docref_description self.example_call_preamble = example_call_preamble - self.preamble = CLASS_PREAMBLE.format(obj=datastructure.name, cls=cls) + self.preamble = definition_preamble.format(obj=datastructure.name, cls=cls) if not see_also_obj: self.see_also_obj = self.datastructure.name else: @@ -268,6 +319,53 @@ def generate_example(self, method): >>> {calculation}(){extra_examples}""" +class GroupByReductionGenerator(ReductionGenerator): + def generate_code(self, method): + extra_kwargs = [kwarg.call for kwarg in method.extra_kwargs if kwarg.call] + + if self.datastructure.numeric_only: + extra_kwargs.append(f"numeric_only={method.numeric_only},") + + # numpy_groupies & flox do not support median + # https://github.com/ml31415/numpy-groupies/issues/43 + if method.name == "median": + indent = 12 + else: + indent = 16 + + if extra_kwargs: + extra_kwargs = textwrap.indent("\n" + "\n".join(extra_kwargs), indent * " ") + else: + extra_kwargs = "" + + if method.name == "median": + return f"""\ + return self.reduce( + duck_array_ops.{method.array_method}, + dim=dim,{extra_kwargs} + keep_attrs=keep_attrs, + **kwargs, + )""" + + else: + return f"""\ + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="{method.name}", + dim=dim,{extra_kwargs} + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.{method.array_method}, + dim=dim,{extra_kwargs} + keep_attrs=keep_attrs, + **kwargs, + )""" + + class GenericReductionGenerator(ReductionGenerator): def generate_code(self, method): extra_kwargs = [kwarg.call for kwarg in method.extra_kwargs if 
kwarg.call] @@ -335,6 +433,7 @@ class DataStructure: docref_description="reduction or aggregation operations", example_call_preamble="", see_also_obj="DataArray", + definition_preamble=DEFAULT_PREAMBLE, ) DATAARRAY_GENERATOR = GenericReductionGenerator( cls="", @@ -344,39 +443,43 @@ class DataStructure: docref_description="reduction or aggregation operations", example_call_preamble="", see_also_obj="Dataset", + definition_preamble=DEFAULT_PREAMBLE, ) - -DATAARRAY_GROUPBY_GENERATOR = GenericReductionGenerator( +DATAARRAY_GROUPBY_GENERATOR = GroupByReductionGenerator( cls="GroupBy", datastructure=DATAARRAY_OBJECT, methods=REDUCTION_METHODS, docref="groupby", docref_description="groupby operations", example_call_preamble='.groupby("labels")', + definition_preamble=GROUPBY_PREAMBLE, ) -DATAARRAY_RESAMPLE_GENERATOR = GenericReductionGenerator( +DATAARRAY_RESAMPLE_GENERATOR = GroupByReductionGenerator( cls="Resample", datastructure=DATAARRAY_OBJECT, methods=REDUCTION_METHODS, docref="resampling", docref_description="resampling operations", example_call_preamble='.resample(time="3M")', + definition_preamble=RESAMPLE_PREAMBLE, ) -DATASET_GROUPBY_GENERATOR = GenericReductionGenerator( +DATASET_GROUPBY_GENERATOR = GroupByReductionGenerator( cls="GroupBy", datastructure=DATASET_OBJECT, methods=REDUCTION_METHODS, docref="groupby", docref_description="groupby operations", example_call_preamble='.groupby("labels")', + definition_preamble=GROUPBY_PREAMBLE, ) -DATASET_RESAMPLE_GENERATOR = GenericReductionGenerator( +DATASET_RESAMPLE_GENERATOR = GroupByReductionGenerator( cls="Resample", datastructure=DATASET_OBJECT, methods=REDUCTION_METHODS, docref="resampling", docref_description="resampling operations", example_call_preamble='.resample(time="3M")', + definition_preamble=RESAMPLE_PREAMBLE, ) @@ -386,6 +489,7 @@ class DataStructure: p = Path(os.getcwd()) filepath = p.parent / "xarray" / "xarray" / "core" / "_reductions.py" + # filepath = p.parent / "core" / "_reductions.py" # Run 
from script location with open(filepath, mode="w", encoding="utf-8") as f: f.write(MODULE_PREAMBLE + "\n") for gen in [ diff --git a/xarray/util/print_versions.py b/xarray/util/print_versions.py index 561126ea05f..b8689e3a18f 100755 --- a/xarray/util/print_versions.py +++ b/xarray/util/print_versions.py @@ -122,6 +122,8 @@ def show_versions(file=sys.stdout): ("cupy", lambda mod: mod.__version__), ("pint", lambda mod: mod.__version__), ("sparse", lambda mod: mod.__version__), + ("flox", lambda mod: mod.__version__), + ("numpy_groupies", lambda mod: mod.__version__), # xarray setup/test ("setuptools", lambda mod: mod.__version__), ("pip", lambda mod: mod.__version__), From 8de706151e183f448e1af9115770713d18e229f1 Mon Sep 17 00:00:00 2001 From: Spencer Clark Date: Sun, 15 May 2022 10:42:31 -0400 Subject: [PATCH 24/24] Fix overflow issue in decode_cf_datetime for dtypes <= np.uint32 (#6598) --- doc/whats-new.rst | 3 +++ xarray/coding/times.py | 9 ++++++--- xarray/tests/test_coding_times.py | 27 +++++++++++++++++++++++++++ 3 files changed, 36 insertions(+), 3 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 680c8219a38..c9ee52f3da0 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -126,6 +126,9 @@ Bug fixes - :py:meth:`isel` with `drop=True` works as intended with scalar :py:class:`DataArray` indexers. (:issue:`6554`, :pull:`6579`) By `Michael Niklas `_. +- Fixed silent overflow issue when decoding times encoded with 32-bit and below + unsigned integer data types (:issue:`6589`, :pull:`6598`). By `Spencer Clark + `_. 
Documentation ~~~~~~~~~~~~~ diff --git a/xarray/coding/times.py b/xarray/coding/times.py index 42a815300e5..5cdd9472277 100644 --- a/xarray/coding/times.py +++ b/xarray/coding/times.py @@ -218,9 +218,12 @@ def _decode_datetime_with_pandas(flat_num_dates, units, calendar): pd.to_timedelta(flat_num_dates.max(), delta) + ref_date # To avoid integer overflow when converting to nanosecond units for integer - # dtypes smaller than np.int64 cast all integer-dtype arrays to np.int64 - # (GH 2002). - if flat_num_dates.dtype.kind == "i": + # dtypes smaller than np.int64 cast all integer and unsigned integer dtype + # arrays to np.int64 (GH 2002, GH 6589). Note this is safe even in the case + # of np.uint64 values, because any np.uint64 value that would lead to + # overflow when converting to np.int64 would not be representable with a + # timedelta64 value, and therefore would raise an error in the lines above. + if flat_num_dates.dtype.kind in "iu": flat_num_dates = flat_num_dates.astype(np.int64) # Cast input ordinals to integers of nanoseconds because pd.to_timedelta diff --git a/xarray/tests/test_coding_times.py b/xarray/tests/test_coding_times.py index f10523aecbb..a5344fe4c85 100644 --- a/xarray/tests/test_coding_times.py +++ b/xarray/tests/test_coding_times.py @@ -1121,3 +1121,30 @@ def test_should_cftime_be_used_target_not_npable(): ValueError, match="Calendar 'noleap' is only valid with cftime." 
): _should_cftime_be_used(src, "noleap", False) + + +@pytest.mark.parametrize("dtype", [np.uint8, np.uint16, np.uint32, np.uint64]) +def test_decode_cf_datetime_uint(dtype): + units = "seconds since 2018-08-22T03:23:03Z" + num_dates = dtype(50) + result = decode_cf_datetime(num_dates, units) + expected = np.asarray(np.datetime64("2018-08-22T03:23:53", "ns")) + np.testing.assert_equal(result, expected) + + +@requires_cftime +def test_decode_cf_datetime_uint64_with_cftime(): + units = "days since 1700-01-01" + num_dates = np.uint64(182621) + result = decode_cf_datetime(num_dates, units) + expected = np.asarray(np.datetime64("2200-01-01", "ns")) + np.testing.assert_equal(result, expected) + + +@requires_cftime +def test_decode_cf_datetime_uint64_with_cftime_overflow_error(): + units = "microseconds since 1700-01-01" + calendar = "360_day" + num_dates = np.uint64(1_000_000 * 86_400 * 360 * 500_000) + with pytest.raises(OverflowError): + decode_cf_datetime(num_dates, units, calendar)