From 7e6aee4e41f8f60b4ce23df87ccfd4f39eb042ef Mon Sep 17 00:00:00 2001 From: Christopher Yeh Date: Thu, 1 Aug 2024 15:24:42 -0700 Subject: [PATCH] GroupBy[Series].count() return type should be Series[int] (#966) * GroupBy[Series].count() return type should be Series[int] * Use np.integer instead of np.int_ * Update pyright requirement '>=1.1.369' -> '>=1.1.374' --- pandas-stubs/core/groupby/groupby.pyi | 5 ++++- pyproject.toml | 2 +- tests/test_frame.py | 1 + tests/test_series.py | 7 +++++++ 4 files changed, 13 insertions(+), 2 deletions(-) diff --git a/pandas-stubs/core/groupby/groupby.pyi b/pandas-stubs/core/groupby/groupby.pyi index cac03085..75be9578 100644 --- a/pandas-stubs/core/groupby/groupby.pyi +++ b/pandas-stubs/core/groupby/groupby.pyi @@ -176,7 +176,10 @@ class GroupBy(BaseGroupBy[NDFrameT]): @overload def all(self: GroupBy[DataFrame], skipna: bool = ...) -> DataFrame: ... @final - def count(self) -> NDFrameT: ... + @overload + def count(self: GroupBy[Series]) -> Series[int]: ... + @overload + def count(self: GroupBy[DataFrame]) -> DataFrame: ... @final def mean( self, diff --git a/pyproject.toml b/pyproject.toml index 71c22ca5..c5f8f9eb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ mypy = "1.10.1" pandas = "2.2.2" pyarrow = ">=10.0.1" pytest = ">=7.1.2" -pyright = ">=1.1.369" +pyright = ">= 1.1.374" poethepoet = ">=0.16.5" loguru = ">=0.6.0" typing-extensions = ">=4.4.0" diff --git a/tests/test_frame.py b/tests/test_frame.py index 8ce0dc35..55219cfa 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -1025,6 +1025,7 @@ def test_types_groupby_methods() -> None: check(assert_type(df.groupby("col1").sum(), pd.DataFrame), pd.DataFrame) check(assert_type(df.groupby("col1").prod(), pd.DataFrame), pd.DataFrame) check(assert_type(df.groupby("col1").sample(), pd.DataFrame), pd.DataFrame) + check(assert_type(df.groupby("col1").count(), pd.DataFrame), pd.DataFrame) check( assert_type(df.groupby("col1").value_counts(normalize=False), "pd.Series[int]"), pd.Series, diff --git a/tests/test_series.py b/tests/test_series.py index 9c99dba4..21f36847 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -676,6 +676,13 @@ def test_types_groupby_methods() -> None: check(assert_type(s.groupby(level=0).idxmax(), pd.Series), pd.Series) check(assert_type(s.groupby(level=0).idxmin(), pd.Series), pd.Series) + s2 = pd.Series(["w", "x", "y", "z"], index=[3, 4, 3, 4], dtype=str) + check( + assert_type(s2.groupby(level=0).count(), "pd.Series[int]"), + pd.Series, + np.integer, + ) + def test_groupby_result() -> None: # GH 142