diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index ea2bd22cccc3d..7d6690a0dfa5a 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -833,45 +833,45 @@ def apply(self, func, *args, **kwargs): axis="", ) @Appender(_shared_docs["aggregate"]) - def aggregate(self, func_or_funcs=None, *args, **kwargs): + def aggregate(self, func=None, *args, **kwargs): _level = kwargs.pop("_level", None) - relabeling = func_or_funcs is None + relabeling = func is None columns = None - no_arg_message = "Must provide 'func_or_funcs' or named aggregation **kwargs." + no_arg_message = "Must provide 'func' or named aggregation **kwargs." if relabeling: columns = list(kwargs) if not PY36: # sort for 3.5 and earlier columns = list(sorted(columns)) - func_or_funcs = [kwargs[col] for col in columns] + func = [kwargs[col] for col in columns] kwargs = {} if not columns: raise TypeError(no_arg_message) - if isinstance(func_or_funcs, str): - return getattr(self, func_or_funcs)(*args, **kwargs) + if isinstance(func, str): + return getattr(self, func)(*args, **kwargs) - if isinstance(func_or_funcs, abc.Iterable): + if isinstance(func, abc.Iterable): # Catch instances of lists / tuples # but not the class list / tuple itself. - func_or_funcs = _maybe_mangle_lambdas(func_or_funcs) - ret = self._aggregate_multiple_funcs(func_or_funcs, (_level or 0) + 1) + func = _maybe_mangle_lambdas(func) + ret = self._aggregate_multiple_funcs(func, (_level or 0) + 1) if relabeling: ret.columns = columns else: - cyfunc = self._get_cython_func(func_or_funcs) + cyfunc = self._get_cython_func(func) if cyfunc and not args and not kwargs: return getattr(self, cyfunc)() if self.grouper.nkeys > 1: - return self._python_agg_general(func_or_funcs, *args, **kwargs) + return self._python_agg_general(func, *args, **kwargs) try: - return self._python_agg_general(func_or_funcs, *args, **kwargs) + return self._python_agg_general(func, *args, **kwargs) except Exception: - result = self._aggregate_named(func_or_funcs, *args, **kwargs) + result = self._aggregate_named(func, *args, **kwargs) index = Index(sorted(result), name=self.grouper.names[0]) ret = Series(result, index=index) @@ -1464,8 +1464,8 @@ class DataFrameGroupBy(NDFrameGroupBy): axis="", ) @Appender(_shared_docs["aggregate"]) - def aggregate(self, arg=None, *args, **kwargs): - return super().aggregate(arg, *args, **kwargs) + def aggregate(self, func=None, *args, **kwargs): + return super().aggregate(func, *args, **kwargs) agg = aggregate diff --git a/pandas/core/indexes/interval.py b/pandas/core/indexes/interval.py index 9361408290bb1..c6104c460e0f1 100644 --- a/pandas/core/indexes/interval.py +++ b/pandas/core/indexes/interval.py @@ -788,7 +788,7 @@ def _find_non_overlapping_monotonic_bounds(self, key): return start, stop def get_loc( - self, key: Any, method: Optional[str] = None + self, key: Any, method: Optional[str] = None, tolerance=None ) -> Union[int, slice, np.ndarray]: """ Get integer location, slice or boolean mask for requested label. @@ -982,7 +982,7 @@ def get_indexer_for(self, target: AnyArrayLike, **kwargs) -> np.ndarray: List of indices. """ if self.is_overlapping: - return self.get_indexer_non_unique(target, **kwargs)[0] + return self.get_indexer_non_unique(target)[0] return self.get_indexer(target, **kwargs) @Appender(_index_shared_docs["get_value"] % _index_doc_kwargs) diff --git a/pandas/core/window/ewm.py b/pandas/core/window/ewm.py index 0ce6d5ddec2ad..40e6c679ba72d 100644 --- a/pandas/core/window/ewm.py +++ b/pandas/core/window/ewm.py @@ -206,8 +206,8 @@ def _constructor(self): axis="", ) @Appender(_shared_docs["aggregate"]) - def aggregate(self, arg, *args, **kwargs): - return super().aggregate(arg, *args, **kwargs) + def aggregate(self, func, *args, **kwargs): + return super().aggregate(func, *args, **kwargs) agg = aggregate diff --git a/pandas/core/window/expanding.py b/pandas/core/window/expanding.py index c43ca6b0565f3..47bd8f2ec593b 100644 --- a/pandas/core/window/expanding.py +++ b/pandas/core/window/expanding.py @@ -136,8 +136,8 @@ def _get_window(self, other=None, **kwargs): axis="", ) @Appender(_shared_docs["aggregate"]) - def aggregate(self, arg, *args, **kwargs): - return super().aggregate(arg, *args, **kwargs) + def aggregate(self, func, *args, **kwargs): + return super().aggregate(func, *args, **kwargs) agg = aggregate diff --git a/pandas/core/window/rolling.py b/pandas/core/window/rolling.py index 323089b3fdf6b..a7e122fa3528f 100644 --- a/pandas/core/window/rolling.py +++ b/pandas/core/window/rolling.py @@ -901,12 +901,12 @@ def func(arg, window, min_periods=None, closed=None): axis="", ) @Appender(_shared_docs["aggregate"]) - def aggregate(self, arg, *args, **kwargs): - result, how = self._aggregate(arg, *args, **kwargs) + def aggregate(self, func, *args, **kwargs): + result, how = self._aggregate(func, *args, **kwargs) if result is None: # these must apply directly - result = arg(self) + result = func(self) return result @@ -1788,8 +1788,8 @@ def _validate_freq(self): axis="", ) @Appender(_shared_docs["aggregate"]) - def aggregate(self, arg, *args, **kwargs): - return super().aggregate(arg, *args, **kwargs) + def aggregate(self, func, *args, **kwargs): + return super().aggregate(func, *args, **kwargs) agg = aggregate diff --git a/pandas/util/_decorators.py b/pandas/util/_decorators.py index 5c7d481ff2586..8a25e511b5fc4 100644 --- a/pandas/util/_decorators.py +++ b/pandas/util/_decorators.py @@ -1,21 +1,35 @@ from functools import wraps import inspect from textwrap import dedent -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, + cast, +) import warnings from pandas._libs.properties import cache_readonly # noqa +FuncType = Callable[..., Any] +F = TypeVar("F", bound=FuncType) + def deprecate( name: str, - alternative: Callable, + alternative: Callable[..., Any], version: str, alt_name: Optional[str] = None, klass: Optional[Type[Warning]] = None, stacklevel: int = 2, msg: Optional[str] = None, -) -> Callable: +) -> Callable[..., Any]: """ Return a new function that emits a deprecation warning on use. @@ -47,7 +61,7 @@ def deprecate( warning_msg = msg or "{} is deprecated, use {} instead".format(name, alt_name) @wraps(alternative) - def wrapper(*args, **kwargs): + def wrapper(*args, **kwargs) -> Callable[..., Any]: warnings.warn(warning_msg, klass, stacklevel=stacklevel) return alternative(*args, **kwargs) @@ -90,9 +104,9 @@ def wrapper(*args, **kwargs): def deprecate_kwarg( old_arg_name: str, new_arg_name: Optional[str], - mapping: Optional[Union[Dict, Callable[[Any], Any]]] = None, + mapping: Optional[Union[Dict[Any, Any], Callable[[Any], Any]]] = None, stacklevel: int = 2, -) -> Callable: +) -> Callable[..., Any]: """ Decorator to deprecate a keyword argument of a function. @@ -160,27 +174,27 @@ def deprecate_kwarg( "mapping from old to new argument values " "must be dict or callable!" ) - def _deprecate_kwarg(func): + def _deprecate_kwarg(func: F) -> F: @wraps(func) - def wrapper(*args, **kwargs): + def wrapper(*args, **kwargs) -> Callable[..., Any]: old_arg_value = kwargs.pop(old_arg_name, None) - if new_arg_name is None and old_arg_value is not None: - msg = ( - "the '{old_name}' keyword is deprecated and will be " - "removed in a future version. " - "Please take steps to stop the use of '{old_name}'" - ).format(old_name=old_arg_name) - warnings.warn(msg, FutureWarning, stacklevel=stacklevel) - kwargs[old_arg_name] = old_arg_value - return func(*args, **kwargs) - if old_arg_value is not None: - if mapping is not None: - if hasattr(mapping, "get"): - new_arg_value = mapping.get(old_arg_value, old_arg_value) - else: + if new_arg_name is None: + msg = ( + "the '{old_name}' keyword is deprecated and will be " + "removed in a future version. " + "Please take steps to stop the use of '{old_name}'" + ).format(old_name=old_arg_name) + warnings.warn(msg, FutureWarning, stacklevel=stacklevel) + kwargs[old_arg_name] = old_arg_value + return func(*args, **kwargs) + + elif mapping is not None: + if callable(mapping): new_arg_value = mapping(old_arg_value) + else: + new_arg_value = mapping.get(old_arg_value, old_arg_value) msg = ( "the {old_name}={old_val!r} keyword is deprecated, " "use {new_name}={new_val!r} instead" @@ -198,7 +212,7 @@ def wrapper(*args, **kwargs): ).format(old_name=old_arg_name, new_name=new_arg_name) warnings.warn(msg, FutureWarning, stacklevel=stacklevel) - if kwargs.get(new_arg_name, None) is not None: + if kwargs.get(new_arg_name) is not None: msg = ( "Can only specify '{old_name}' or '{new_name}', " "not both" ).format(old_name=old_arg_name, new_name=new_arg_name) @@ -207,17 +221,17 @@ def wrapper(*args, **kwargs): kwargs[new_arg_name] = new_arg_value return func(*args, **kwargs) - return wrapper + return cast(F, wrapper) return _deprecate_kwarg def rewrite_axis_style_signature( name: str, extra_params: List[Tuple[str, Any]] -) -> Callable: - def decorate(func): +) -> Callable[..., Any]: + def decorate(func: F) -> F: @wraps(func) - def wrapper(*args, **kwargs): + def wrapper(*args, **kwargs) -> Callable[..., Any]: return func(*args, **kwargs) kind = inspect.Parameter.POSITIONAL_OR_KEYWORD @@ -234,8 +248,9 @@ def wrapper(*args, **kwargs): sig = inspect.Signature(params) - func.__signature__ = sig - return wrapper + # https://github.com/python/typing/issues/598 + func.__signature__ = sig # type: ignore + return cast(F, wrapper) return decorate @@ -279,18 +294,17 @@ def __init__(self, *args, **kwargs): self.params = args or kwargs - def __call__(self, func: Callable) -> Callable: + def __call__(self, func: F) -> F: func.__doc__ = func.__doc__ and func.__doc__ % self.params return func def update(self, *args, **kwargs) -> None: """ Update self.params with supplied args. - - If called, we assume self.params is a dict. """ - self.params.update(*args, **kwargs) + if isinstance(self.params, dict): + self.params.update(*args, **kwargs) class Appender: @@ -320,7 +334,7 @@ def __init__(self, addendum: Optional[str], join: str = "", indents: int = 0): self.addendum = addendum self.join = join - def __call__(self, func: Callable) -> Callable: + def __call__(self, func: F) -> F: func.__doc__ = func.__doc__ if func.__doc__ else "" self.addendum = self.addendum if self.addendum else "" docitems = [func.__doc__, self.addendum]