From da91ef967567959e9dcb62f057dfcff169cce858 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Fri, 5 Mar 2021 12:09:58 +0100 Subject: [PATCH] [ArrayManager] Fix window operations with axis=1 --- .github/workflows/ci.yml | 1 + pandas/conftest.py | 2 +- pandas/core/internals/array_manager.py | 24 ++++++++++++++++++++++++ pandas/core/window/rolling.py | 10 +++++++++- pandas/tests/window/test_rolling.py | 10 ++++++---- 5 files changed, 41 insertions(+), 6 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 59c3fc8f05105..6c60522092739 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -192,3 +192,4 @@ jobs: pytest pandas/tests/tseries/ pytest pandas/tests/tslibs/ pytest pandas/tests/util/ + pytest pandas/tests/window/ diff --git a/pandas/conftest.py b/pandas/conftest.py index 0c2e9a6942a13..07a20d9e0eb5c 100644 --- a/pandas/conftest.py +++ b/pandas/conftest.py @@ -190,7 +190,7 @@ def add_imports(doctest_namespace): # ---------------------------------------------------------------- # Common arguments # ---------------------------------------------------------------- -@pytest.fixture(params=[0, 1, "index", "columns"], ids=lambda x: f"axis {repr(x)}") +@pytest.fixture(params=[0, 1, "index", "columns"], ids=lambda x: f"axis={repr(x)}") def axis(request): """ Fixture for returning the axis numbers of a DataFrame. diff --git a/pandas/core/internals/array_manager.py b/pandas/core/internals/array_manager.py index 0449be84bdcf7..5131cd63e23be 100644 --- a/pandas/core/internals/array_manager.py +++ b/pandas/core/internals/array_manager.py @@ -403,6 +403,30 @@ def apply( return type(self)(result_arrays, new_axes) + def apply_2d( + self: T, + f, + ignore_failures: bool = False, + **kwargs, + ) -> T: + """ + Variant of `apply`, but where the function should not be applied to + each column independently, but to the full data as a 2D array. 
+ """ + values = self.as_array() + try: + result = f(values, **kwargs) + except (TypeError, NotImplementedError): + if not ignore_failures: + raise + result_arrays = [] + new_axes = [self._axes[0], self.axes[1].take([])] + else: + result_arrays = [result[:, i] for i in range(len(self._axes[1]))] + new_axes = self._axes + + return type(self)(result_arrays, new_axes) + def apply_with_block(self: T, f, align_keys=None, **kwargs) -> T: align_keys = align_keys or [] diff --git a/pandas/core/window/rolling.py b/pandas/core/window/rolling.py index 6f1e2ce121775..299e1755c0025 100644 --- a/pandas/core/window/rolling.py +++ b/pandas/core/window/rolling.py @@ -69,6 +69,7 @@ Index, MultiIndex, ) +from pandas.core.internals import ArrayManager from pandas.core.reshape.concat import concat from pandas.core.util.numba_ import ( NUMBA_FUNC_CACHE, @@ -410,7 +411,14 @@ def hfunc(bvalues: ArrayLike) -> ArrayLike: res_values = homogeneous_func(values) return getattr(res_values, "T", res_values) - new_mgr = mgr.apply(hfunc, ignore_failures=True) + def hfunc2d(values: ArrayLike) -> ArrayLike: + values = self._prep_values(values) + return homogeneous_func(values) + + if isinstance(mgr, ArrayManager) and self.axis == 1: + new_mgr = mgr.apply_2d(hfunc2d, ignore_failures=True) + else: + new_mgr = mgr.apply(hfunc, ignore_failures=True) out = obj._constructor(new_mgr) if out.shape[1] == 0 and obj.shape[1] > 0: diff --git a/pandas/tests/window/test_rolling.py b/pandas/tests/window/test_rolling.py index fc2e86310dae9..70c076e086fb7 100644 --- a/pandas/tests/window/test_rolling.py +++ b/pandas/tests/window/test_rolling.py @@ -397,7 +397,7 @@ def test_rolling_datetime(axis_frame, tz_naive_fixture): tm.assert_frame_equal(result, expected) -def test_rolling_window_as_string(): +def test_rolling_window_as_string(using_array_manager): # see gh-22590 date_today = datetime.now() days = date_range(date_today, date_today + timedelta(365), freq="D") @@ -450,9 +450,11 @@ def 
test_rolling_window_as_string(): + [95.0] * 20 ) - expected = Series( - expData, index=days.rename("DateCol")._with_freq(None), name="metric" - ) + index = days.rename("DateCol") + if not using_array_manager: + # INFO(ArrayManager) preserves the frequency of the index + index = index._with_freq(None) + expected = Series(expData, index=index, name="metric") tm.assert_series_equal(result, expected)