From 13328fe32381894fcd0c1705c4d567f870e84447 Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Fri, 3 Jan 2020 11:15:57 -0800 Subject: [PATCH] REF/TST: PeriodArray comparisons with listlike --- pandas/core/arrays/period.py | 44 ++++++++++++---- pandas/core/indexes/base.py | 5 ++ pandas/tests/arithmetic/test_period.py | 73 ++++++++++++++++++++++++++ 3 files changed, 112 insertions(+), 10 deletions(-) diff --git a/pandas/core/arrays/period.py b/pandas/core/arrays/period.py index 056c80717e54f..b7d841ab5c6a1 100644 --- a/pandas/core/arrays/period.py +++ b/pandas/core/arrays/period.py @@ -29,6 +29,7 @@ is_datetime64_dtype, is_float_dtype, is_list_like, + is_object_dtype, is_period_dtype, pandas_dtype, ) @@ -41,6 +42,7 @@ ) from pandas.core.dtypes.missing import isna, notna +from pandas.core import ops import pandas.core.algorithms as algos from pandas.core.arrays import datetimelike as dtl import pandas.core.common as com @@ -92,22 +94,44 @@ def wrapper(self, other): self._check_compatible_with(other) result = ordinal_op(other.ordinal) - elif isinstance(other, cls): - self._check_compatible_with(other) - - result = ordinal_op(other.asi8) - - mask = self._isnan | other._isnan - if mask.any(): - result[mask] = nat_result - return result elif other is NaT: result = np.empty(len(self.asi8), dtype=bool) result.fill(nat_result) - else: + + elif not is_list_like(other): return invalid_comparison(self, other, op) + else: + if isinstance(other, list): + # TODO: could use pd.Index to do inference? + other = np.array(other) + + if not isinstance(other, (np.ndarray, cls)): + return invalid_comparison(self, other, op) + + if is_object_dtype(other): + with np.errstate(all="ignore"): + result = ops.comp_method_OBJECT_ARRAY( + op, self.astype(object), other + ) + o_mask = isna(other) + + elif not is_period_dtype(other): + # e.g. is_timedelta64_dtype(other) + return invalid_comparison(self, other, op) + + else: + assert isinstance(other, cls), type(other) + + self._check_compatible_with(other) + + result = ordinal_op(other.asi8) + o_mask = other._isnan + + if o_mask.any(): + result[o_mask] = nat_result + if self._hasnans: result[self._isnan] = nat_result diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index ed5c6b450b05e..50040409473d2 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -107,6 +107,11 @@ def cmp_method(self, other): if is_object_dtype(self) and isinstance(other, ABCCategorical): left = type(other)(self._values, dtype=other.dtype) return op(left, other) + elif is_object_dtype(self) and isinstance(other, ExtensionArray): + # e.g. PeriodArray + with np.errstate(all="ignore"): + result = op(self.values, other) + elif is_object_dtype(self) and not isinstance(self, ABCMultiIndex): # don't pass MultiIndex with np.errstate(all="ignore"): diff --git a/pandas/tests/arithmetic/test_period.py b/pandas/tests/arithmetic/test_period.py index 3ad7a6d8e465c..6eef99a124b1a 100644 --- a/pandas/tests/arithmetic/test_period.py +++ b/pandas/tests/arithmetic/test_period.py @@ -50,6 +50,79 @@ def test_compare_invalid_scalar(self, box_with_array, scalar): parr = tm.box_expected(pi, box_with_array) assert_invalid_comparison(parr, scalar, box_with_array) + @pytest.mark.parametrize( + "other", + [ + pd.date_range("2000", periods=4).array, + pd.timedelta_range("1D", periods=4).array, + np.arange(4), + np.arange(4).astype(np.float64), + list(range(4)), + ], + ) + def test_compare_invalid_listlike(self, box_with_array, other): + pi = pd.period_range("2000", periods=4) + parr = tm.box_expected(pi, box_with_array) + assert_invalid_comparison(parr, other, box_with_array) + + @pytest.mark.parametrize("other_box", [list, np.array, lambda x: x.astype(object)]) + def test_compare_object_dtype(self, box_with_array, other_box): + pi = pd.period_range("2000", periods=5) + parr = tm.box_expected(pi, box_with_array) + + xbox = np.ndarray if box_with_array is pd.Index else box_with_array + + other = other_box(pi) + + expected = np.array([True, True, True, True, True]) + expected = tm.box_expected(expected, xbox) + + result = parr == other + tm.assert_equal(result, expected) + result = parr <= other + tm.assert_equal(result, expected) + result = parr >= other + tm.assert_equal(result, expected) + + result = parr != other + tm.assert_equal(result, ~expected) + result = parr < other + tm.assert_equal(result, ~expected) + result = parr > other + tm.assert_equal(result, ~expected) + + other = other_box(pi[::-1]) + + expected = np.array([False, False, True, False, False]) + expected = tm.box_expected(expected, xbox) + result = parr == other + tm.assert_equal(result, expected) + + expected = np.array([True, True, True, False, False]) + expected = tm.box_expected(expected, xbox) + result = parr <= other + tm.assert_equal(result, expected) + + expected = np.array([False, False, True, True, True]) + expected = tm.box_expected(expected, xbox) + result = parr >= other + tm.assert_equal(result, expected) + + expected = np.array([True, True, False, True, True]) + expected = tm.box_expected(expected, xbox) + result = parr != other + tm.assert_equal(result, expected) + + expected = np.array([True, True, False, False, False]) + expected = tm.box_expected(expected, xbox) + result = parr < other + tm.assert_equal(result, expected) + + expected = np.array([False, False, False, True, True]) + expected = tm.box_expected(expected, xbox) + result = parr > other + tm.assert_equal(result, expected) + class TestPeriodIndexComparisons: # TODO: parameterize over boxes