Skip to content

Commit

Permalink
GroupBy[Series].count() return type should be Series[int] (#966)
Browse files Browse the repository at this point in the history
* GroupBy[Series].count() return type should be Series[int]

* Use np.integer instead of np.int_

* Update pyright requirement '>=1.1.369' -> '>=1.1.374'
  • Loading branch information
chrisyeh96 authored Aug 1, 2024
1 parent 458ecb4 commit 7e6aee4
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 2 deletions.
5 changes: 4 additions & 1 deletion pandas-stubs/core/groupby/groupby.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,10 @@ class GroupBy(BaseGroupBy[NDFrameT]):
@overload
def all(self: GroupBy[DataFrame], skipna: bool = ...) -> DataFrame: ...
@final
def count(self) -> NDFrameT: ...
@overload
def count(self: GroupBy[Series]) -> Series[int]: ...
@overload
def count(self: GroupBy[DataFrame]) -> DataFrame: ...
@final
def mean(
self,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ mypy = "1.10.1"
pandas = "2.2.2"
pyarrow = ">=10.0.1"
pytest = ">=7.1.2"
pyright = ">=1.1.369"
pyright = ">= 1.1.374"
poethepoet = ">=0.16.5"
loguru = ">=0.6.0"
typing-extensions = ">=4.4.0"
Expand Down
1 change: 1 addition & 0 deletions tests/test_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1025,6 +1025,7 @@ def test_types_groupby_methods() -> None:
check(assert_type(df.groupby("col1").sum(), pd.DataFrame), pd.DataFrame)
check(assert_type(df.groupby("col1").prod(), pd.DataFrame), pd.DataFrame)
check(assert_type(df.groupby("col1").sample(), pd.DataFrame), pd.DataFrame)
check(assert_type(df.groupby("col1").count(), pd.DataFrame), pd.DataFrame)
check(
assert_type(df.groupby("col1").value_counts(normalize=False), "pd.Series[int]"),
pd.Series,
Expand Down
7 changes: 7 additions & 0 deletions tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,13 @@ def test_types_groupby_methods() -> None:
check(assert_type(s.groupby(level=0).idxmax(), pd.Series), pd.Series)
check(assert_type(s.groupby(level=0).idxmin(), pd.Series), pd.Series)

s2 = pd.Series(["w", "x", "y", "z"], index=[3, 4, 3, 4], dtype=str)
check(
assert_type(s2.groupby(level=0).count(), "pd.Series[int]"),
pd.Series,
np.integer,
)


def test_groupby_result() -> None:
# GH 142
Expand Down

0 comments on commit 7e6aee4

Please sign in to comment.