Skip to content

Commit

Permalink
Fix direct reductions of Xarray objects
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Mar 13, 2024
1 parent 41372e0 commit fb49adc
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 36 deletions.
12 changes: 7 additions & 5 deletions flox/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,14 +303,16 @@ def xarray_reduce(
# reducing along a dimension along which groups do not vary
# This is really just a normal reduction.
# This is not right when binning so we exclude.
if isinstance(func, str):
dsfunc = func[3:] if skipna else func
else:
if isinstance(func, str) and func.startswith("nan"):
raise ValueError(f"Specify func={func[3:]}, skipna=True instead of func={func}")
elif isinstance(func, Aggregation):
raise NotImplementedError(
"func must be a string when reducing along a dimension not present in `by`"
)
# TODO: skipna needs test
result = getattr(ds_broad, dsfunc)(dim=dim_tuple, skipna=skipna)
# skipna is not supported for all reductions
# https://github.com/pydata/xarray/issues/8819
kwargs = {"skipna": skipna} if skipna is not None else {}
result = getattr(ds_broad, func)(dim=dim_tuple, **kwargs)
if isinstance(obj, xr.DataArray):
return obj._from_temp_dataset(result)
else:
Expand Down
32 changes: 32 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,35 @@ def assert_equal_tuple(a, b):
np.testing.assert_array_equal(a_, b_)
else:
assert a_ == b_


# Shared tables of reduction names used by the test modules.

# These two names are wrapped with the `requires_scipy` mark when they are
# appended to ALL_FUNCS below.
SCIPY_STATS_FUNCS = ("mode", "nanmode")
# Median/quantile reductions plus the scipy-stats ones.
# NOTE(review): presumably the reductions that need blockwise handling in
# flox -- confirm against the tests that consume this tuple.
BLOCKWISE_FUNCS = ("median", "nanmedian", "quantile", "nanquantile") + SCIPY_STATS_FUNCS
# Every reduction name used for test parametrization. Plain strings come first;
# the SCIPY_STATS_FUNCS entries are appended as `pytest.param` objects carrying
# the `requires_scipy` mark. Entry order determines parametrization order.
ALL_FUNCS = (
"sum",
"nansum",
"argmax",
"nanfirst",
"nanargmax",
"prod",
"nanprod",
"mean",
"nanmean",
"var",
"nanvar",
"std",
"nanstd",
"max",
"nanmax",
"min",
"nanmin",
"argmin",
"nanargmin",
"any",
"all",
"nanlast",
"median",
"nanmedian",
"quantile",
"nanquantile",
) + tuple(pytest.param(func, marks=requires_scipy) for func in SCIPY_STATS_FUNCS)
34 changes: 3 additions & 31 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,14 @@
)

from . import (
ALL_FUNCS,
BLOCKWISE_FUNCS,
SCIPY_STATS_FUNCS,
assert_equal,
assert_equal_tuple,
has_dask,
raise_if_dask_computes,
requires_dask,
requires_scipy,
)

logger = logging.getLogger("flox")
Expand All @@ -60,36 +62,6 @@ def dask_array_ones(*args):


DEFAULT_QUANTILE = 0.9
# These two names are wrapped with the `requires_scipy` mark when they are
# appended to ALL_FUNCS below.
SCIPY_STATS_FUNCS = ("mode", "nanmode")
# Median/quantile reductions plus the scipy-stats ones.
# NOTE(review): presumably the reductions that need blockwise handling in
# flox -- confirm against the tests that consume this tuple.
BLOCKWISE_FUNCS = ("median", "nanmedian", "quantile", "nanquantile") + SCIPY_STATS_FUNCS
# Every reduction name used for test parametrization. Plain strings come first;
# the SCIPY_STATS_FUNCS entries are appended as `pytest.param` objects carrying
# the `requires_scipy` mark. Entry order determines parametrization order.
ALL_FUNCS = (
"sum",
"nansum",
"argmax",
"nanfirst",
"nanargmax",
"prod",
"nanprod",
"mean",
"nanmean",
"var",
"nanvar",
"std",
"nanstd",
"max",
"nanmax",
"min",
"nanmin",
"argmin",
"nanargmin",
"any",
"all",
"nanlast",
"median",
"nanmedian",
"quantile",
"nanquantile",
) + tuple(pytest.param(func, marks=requires_scipy) for func in SCIPY_STATS_FUNCS)

if TYPE_CHECKING:
from flox.core import T_Agg, T_Engine, T_ExpectedGroupsOpt, T_Method
Expand Down
30 changes: 30 additions & 0 deletions tests/test_xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from flox.xarray import rechunk_for_blockwise, xarray_reduce

from . import (
ALL_FUNCS,
assert_equal,
has_dask,
raise_if_dask_computes,
Expand Down Expand Up @@ -710,3 +711,32 @@ def test_multiple_quantiles(q, chunk, by_ndim, skipna):
with xr.set_options(use_flox=False):
expected = da.groupby(by).quantile(q, skipna=skipna)
xr.testing.assert_allclose(expected, actual)


@pytest.mark.parametrize("func", ALL_FUNCS)
def test_direct_reduction(func):
    """Regression test for https://github.com/pydata/xarray/issues/8819.

    Groups a small 2x3 DataArray by its ``x`` coordinate and checks that the
    reduction gives an identical result with ``use_flox=True`` and
    ``use_flox=False``.

    ``arg*`` and ``mode`` reductions are skipped. A ``nan``-prefixed name is
    translated to the un-prefixed groupby method called with ``skipna=True``.
    """
    if "arg" in func or "mode" in func:
        pytest.skip()

    # Boolean input for any/all; every other reduction gets floats.
    values = np.random.choice([True, False], size=(2, 3))
    if func not in ["any", "all"]:
        values = values.astype(float)

    kwargs = {}
    method_name = func
    if "nan" in func:
        # Drop the "nan" prefix and request NaN-skipping via the keyword.
        method_name = func[3:]
        kwargs["skipna"] = True

    # first/last are called without an explicit `dim`.
    if not ("first" in method_name or "last" in method_name):
        kwargs["dim"] = "y"

    if "quantile" in method_name:
        kwargs["q"] = 0.9

    array = xr.DataArray(
        values,
        dims=("x", "y"),
        coords={"x": [10, 20], "y": [0, 1, 2]},
    )

    def reduce_grouped(use_flox):
        # Run the same groupby reduction with flox switched on or off.
        with xr.set_options(use_flox=use_flox):
            return getattr(array.groupby("x"), method_name)(**kwargs)

    actual = reduce_grouped(True)
    expected = reduce_grouped(False)
    xr.testing.assert_identical(expected, actual)

0 comments on commit fb49adc

Please sign in to comment.