diff --git a/pandas-stubs/core/groupby/groupby.pyi b/pandas-stubs/core/groupby/groupby.pyi index cac03085..75be9578 100644 --- a/pandas-stubs/core/groupby/groupby.pyi +++ b/pandas-stubs/core/groupby/groupby.pyi @@ -176,7 +176,10 @@ class GroupBy(BaseGroupBy[NDFrameT]): @overload def all(self: GroupBy[DataFrame], skipna: bool = ...) -> DataFrame: ... @final - def count(self) -> NDFrameT: ... + @overload + def count(self: GroupBy[Series]) -> Series[int]: ... + @overload + def count(self: GroupBy[DataFrame]) -> DataFrame: ... @final def mean( self, diff --git a/pyproject.toml b/pyproject.toml index 71c22ca5..c5f8f9eb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ mypy = "1.10.1" pandas = "2.2.2" pyarrow = ">=10.0.1" pytest = ">=7.1.2" -pyright = ">=1.1.369" +pyright = ">= 1.1.374" poethepoet = ">=0.16.5" loguru = ">=0.6.0" typing-extensions = ">=4.4.0" diff --git a/tests/test_frame.py b/tests/test_frame.py index 8ce0dc35..55219cfa 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -1025,6 +1025,7 @@ def test_types_groupby_methods() -> None: check(assert_type(df.groupby("col1").sum(), pd.DataFrame), pd.DataFrame) check(assert_type(df.groupby("col1").prod(), pd.DataFrame), pd.DataFrame) check(assert_type(df.groupby("col1").sample(), pd.DataFrame), pd.DataFrame) + check(assert_type(df.groupby("col1").count(), pd.DataFrame), pd.DataFrame) check( assert_type(df.groupby("col1").value_counts(normalize=False), "pd.Series[int]"), pd.Series, diff --git a/tests/test_series.py b/tests/test_series.py index 9c99dba4..21f36847 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -676,6 +676,13 @@ def test_types_groupby_methods() -> None: check(assert_type(s.groupby(level=0).idxmax(), pd.Series), pd.Series) check(assert_type(s.groupby(level=0).idxmin(), pd.Series), pd.Series) + s2 = pd.Series(["w", "x", "y", "z"], index=[3, 4, 3, 4], dtype=str) + check( + assert_type(s2.groupby(level=0).count(), "pd.Series[int]"), + pd.Series, + np.integer, + ) + def test_groupby_result() -> None: # GH 142