Skip to content

Commit

Permalink
ENH: Support mask in groupby sum (pandas-dev#48018)
Browse files Browse the repository at this point in the history
* ENH: Support mask in groupby sum

* ENH: Support mask in groupby sum

* Fix mypy

* Refactor if condition
  • Loading branch information
phofl authored and noatamir committed Nov 9, 2022
1 parent c15d075 commit 5b39d00
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 11 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.5.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1066,6 +1066,7 @@ Groupby/resample/rolling
- Bug when using ``engine="numba"`` would return the same jitted function when modifying ``engine_kwargs`` (:issue:`46086`)
- Bug in :meth:`.DataFrameGroupBy.transform` fails when ``axis=1`` and ``func`` is ``"first"`` or ``"last"`` (:issue:`45986`)
- Bug in :meth:`DataFrameGroupBy.cumsum` with ``skipna=False`` giving incorrect results (:issue:`46216`)
- Bug in :meth:`GroupBy.sum` with integer dtypes losing precision (:issue:`37493`)
- Bug in :meth:`.GroupBy.cumsum` with ``timedelta64[ns]`` dtype failing to recognize ``NaT`` as a null value (:issue:`46216`)
- Bug in :meth:`.GroupBy.cummin` and :meth:`.GroupBy.cummax` with nullable dtypes incorrectly altering the original data in place (:issue:`46220`)
- Bug in :meth:`DataFrame.groupby` raising error when ``None`` is in first level of :class:`MultiIndex` (:issue:`47348`)
Expand Down
6 changes: 4 additions & 2 deletions pandas/_libs/groupby.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,12 @@ def group_any_all(
skipna: bool,
) -> None: ...
def group_sum(
out: np.ndarray, # complexfloating_t[:, ::1]
out: np.ndarray, # complexfloatingintuint_t[:, ::1]
counts: np.ndarray, # int64_t[::1]
values: np.ndarray, # ndarray[complexfloating_t, ndim=2]
values: np.ndarray, # ndarray[complexfloatingintuint_t, ndim=2]
labels: np.ndarray, # const intp_t[:]
mask: np.ndarray | None,
result_mask: np.ndarray | None = ...,
min_count: int = ...,
is_datetimelike: bool = ...,
) -> None: ...
Expand Down
51 changes: 44 additions & 7 deletions pandas/_libs/groupby.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,15 @@ ctypedef fused mean_t:

ctypedef fused sum_t:
mean_t
int8_t
int16_t
int32_t
int64_t

uint8_t
uint16_t
uint32_t
uint64_t
object


Expand All @@ -523,6 +532,8 @@ def group_sum(
int64_t[::1] counts,
ndarray[sum_t, ndim=2] values,
const intp_t[::1] labels,
const uint8_t[:, :] mask,
uint8_t[:, ::1] result_mask=None,
Py_ssize_t min_count=0,
bint is_datetimelike=False,
) -> None:
Expand All @@ -535,6 +546,8 @@ def group_sum(
sum_t[:, ::1] sumx, compensation
int64_t[:, ::1] nobs
Py_ssize_t len_values = len(values), len_labels = len(labels)
bint uses_mask = mask is not None
bint isna_entry

if len_values != len_labels:
raise ValueError("len(index) != len(labels)")
Expand Down Expand Up @@ -572,7 +585,8 @@ def group_sum(
for i in range(ncounts):
for j in range(K):
if nobs[i, j] < min_count:
out[i, j] = NAN
out[i, j] = None

else:
out[i, j] = sumx[i, j]
else:
Expand All @@ -590,11 +604,18 @@ def group_sum(
# With dt64/td64 values, values have been cast to float64
# instead if int64 for group_sum, but the logic
# is otherwise the same as in _treat_as_na
if val == val and not (
sum_t is float64_t
and is_datetimelike
and val == <float64_t>NPY_NAT
):
if uses_mask:
isna_entry = mask[i, j]
elif (sum_t is float32_t or sum_t is float64_t
or sum_t is complex64_t or sum_t is complex64_t):
# avoid warnings because of equality comparison
isna_entry = not val == val
elif sum_t is int64_t and is_datetimelike and val == NPY_NAT:
isna_entry = True
else:
isna_entry = False

if not isna_entry:
nobs[lab, j] += 1
y = val - compensation[lab, j]
t = sumx[lab, j] + y
Expand All @@ -604,7 +625,23 @@ def group_sum(
for i in range(ncounts):
for j in range(K):
if nobs[i, j] < min_count:
out[i, j] = NAN
# if we are integer dtype, not is_datetimelike, and
# not uses_mask, then getting here implies that
# counts[i] < min_count, which means we will
# be cast to float64 and masked at the end
# of WrappedCythonOp._call_cython_op. So we can safely
# set a placeholder value in out[i, j].
if uses_mask:
result_mask[i, j] = True
elif (sum_t is float32_t or sum_t is float64_t
or sum_t is complex64_t or sum_t is complex64_t):
out[i, j] = NAN
elif sum_t is int64_t:
out[i, j] = NPY_NAT
else:
# placeholder, see above
out[i, j] = 0

else:
out[i, j] = sumx[i, j]

Expand Down
8 changes: 6 additions & 2 deletions pandas/core/groupby/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ def __init__(self, kind: str, how: str, has_dropped_na: bool) -> None:
"last",
"first",
"rank",
"sum",
}

_cython_arity = {"ohlc": 4} # OHLC
Expand Down Expand Up @@ -217,7 +218,7 @@ def _get_cython_vals(self, values: np.ndarray) -> np.ndarray:
values = ensure_float64(values)

elif values.dtype.kind in ["i", "u"]:
if how in ["sum", "var", "prod", "mean", "ohlc"] or (
if how in ["var", "prod", "mean", "ohlc"] or (
self.kind == "transform" and self.has_dropped_na
):
# result may still include NaN, so we have to cast
Expand Down Expand Up @@ -578,6 +579,8 @@ def _call_cython_op(
counts=counts,
values=values,
labels=comp_ids,
mask=mask,
result_mask=result_mask,
min_count=min_count,
is_datetimelike=is_datetimelike,
)
Expand Down Expand Up @@ -613,7 +616,8 @@ def _call_cython_op(
# need to have the result set to np.nan, which may require casting,
# see GH#40767
if is_integer_dtype(result.dtype) and not is_datetimelike:
cutoff = max(1, min_count)
# Neutral value for sum is 0, so don't fill empty groups with nan
cutoff = max(0 if self.how == "sum" else 1, min_count)
empty_groups = counts < cutoff
if empty_groups.any():
if result_mask is not None and self.uses_mask():
Expand Down
21 changes: 21 additions & 0 deletions pandas/tests/groupby/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -2808,3 +2808,24 @@ def test_single_element_list_grouping():
)
with tm.assert_produces_warning(FutureWarning, match=msg):
values, _ = next(iter(df.groupby(["a"])))


def test_groupby_sum_avoid_casting_to_float():
# GH#37493
val = 922337203685477580
df = DataFrame({"a": 1, "b": [val]})
result = df.groupby("a").sum() - val
expected = DataFrame({"b": [0]}, index=Index([1], name="a"))
tm.assert_frame_equal(result, expected)


def test_groupby_sum_support_mask(any_numeric_ea_dtype):
# GH#37493
df = DataFrame({"a": 1, "b": [1, 2, pd.NA]}, dtype=any_numeric_ea_dtype)
result = df.groupby("a").sum()
expected = DataFrame(
{"b": [3]},
index=Index([1], name="a", dtype=any_numeric_ea_dtype),
dtype=any_numeric_ea_dtype,
)
tm.assert_frame_equal(result, expected)

0 comments on commit 5b39d00

Please sign in to comment.