Skip to content

Commit

Permalink
TYPING: --check-untyped-defs util._decorators (pandas-dev#28128)
Browse files Browse the repository at this point in the history
  • Loading branch information
simonjayhawkins authored and proost committed Dec 19, 2019
1 parent 3fd0e2b commit 33ce4fb
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 60 deletions.
30 changes: 15 additions & 15 deletions pandas/core/groupby/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,45 +833,45 @@ def apply(self, func, *args, **kwargs):
axis="",
)
@Appender(_shared_docs["aggregate"])
def aggregate(self, func_or_funcs=None, *args, **kwargs):
def aggregate(self, func=None, *args, **kwargs):
_level = kwargs.pop("_level", None)

relabeling = func_or_funcs is None
relabeling = func is None
columns = None
no_arg_message = "Must provide 'func_or_funcs' or named aggregation **kwargs."
no_arg_message = "Must provide 'func' or named aggregation **kwargs."
if relabeling:
columns = list(kwargs)
if not PY36:
# sort for 3.5 and earlier
columns = list(sorted(columns))

func_or_funcs = [kwargs[col] for col in columns]
func = [kwargs[col] for col in columns]
kwargs = {}
if not columns:
raise TypeError(no_arg_message)

if isinstance(func_or_funcs, str):
return getattr(self, func_or_funcs)(*args, **kwargs)
if isinstance(func, str):
return getattr(self, func)(*args, **kwargs)

if isinstance(func_or_funcs, abc.Iterable):
if isinstance(func, abc.Iterable):
# Catch instances of lists / tuples
# but not the class list / tuple itself.
func_or_funcs = _maybe_mangle_lambdas(func_or_funcs)
ret = self._aggregate_multiple_funcs(func_or_funcs, (_level or 0) + 1)
func = _maybe_mangle_lambdas(func)
ret = self._aggregate_multiple_funcs(func, (_level or 0) + 1)
if relabeling:
ret.columns = columns
else:
cyfunc = self._get_cython_func(func_or_funcs)
cyfunc = self._get_cython_func(func)
if cyfunc and not args and not kwargs:
return getattr(self, cyfunc)()

if self.grouper.nkeys > 1:
return self._python_agg_general(func_or_funcs, *args, **kwargs)
return self._python_agg_general(func, *args, **kwargs)

try:
return self._python_agg_general(func_or_funcs, *args, **kwargs)
return self._python_agg_general(func, *args, **kwargs)
except Exception:
result = self._aggregate_named(func_or_funcs, *args, **kwargs)
result = self._aggregate_named(func, *args, **kwargs)

index = Index(sorted(result), name=self.grouper.names[0])
ret = Series(result, index=index)
Expand Down Expand Up @@ -1464,8 +1464,8 @@ class DataFrameGroupBy(NDFrameGroupBy):
axis="",
)
@Appender(_shared_docs["aggregate"])
def aggregate(self, arg=None, *args, **kwargs):
return super().aggregate(arg, *args, **kwargs)
def aggregate(self, func=None, *args, **kwargs):
return super().aggregate(func, *args, **kwargs)

agg = aggregate

Expand Down
4 changes: 2 additions & 2 deletions pandas/core/indexes/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,7 +788,7 @@ def _find_non_overlapping_monotonic_bounds(self, key):
return start, stop

def get_loc(
self, key: Any, method: Optional[str] = None
self, key: Any, method: Optional[str] = None, tolerance=None
) -> Union[int, slice, np.ndarray]:
"""
Get integer location, slice or boolean mask for requested label.
Expand Down Expand Up @@ -982,7 +982,7 @@ def get_indexer_for(self, target: AnyArrayLike, **kwargs) -> np.ndarray:
List of indices.
"""
if self.is_overlapping:
return self.get_indexer_non_unique(target, **kwargs)[0]
return self.get_indexer_non_unique(target)[0]
return self.get_indexer(target, **kwargs)

@Appender(_index_shared_docs["get_value"] % _index_doc_kwargs)
Expand Down
4 changes: 2 additions & 2 deletions pandas/core/window/ewm.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,8 @@ def _constructor(self):
axis="",
)
@Appender(_shared_docs["aggregate"])
def aggregate(self, arg, *args, **kwargs):
return super().aggregate(arg, *args, **kwargs)
def aggregate(self, func, *args, **kwargs):
return super().aggregate(func, *args, **kwargs)

agg = aggregate

Expand Down
4 changes: 2 additions & 2 deletions pandas/core/window/expanding.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,8 @@ def _get_window(self, other=None, **kwargs):
axis="",
)
@Appender(_shared_docs["aggregate"])
def aggregate(self, arg, *args, **kwargs):
return super().aggregate(arg, *args, **kwargs)
def aggregate(self, func, *args, **kwargs):
return super().aggregate(func, *args, **kwargs)

agg = aggregate

Expand Down
10 changes: 5 additions & 5 deletions pandas/core/window/rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -901,12 +901,12 @@ def func(arg, window, min_periods=None, closed=None):
axis="",
)
@Appender(_shared_docs["aggregate"])
def aggregate(self, arg, *args, **kwargs):
result, how = self._aggregate(arg, *args, **kwargs)
def aggregate(self, func, *args, **kwargs):
result, how = self._aggregate(func, *args, **kwargs)
if result is None:

# these must apply directly
result = arg(self)
result = func(self)

return result

Expand Down Expand Up @@ -1788,8 +1788,8 @@ def _validate_freq(self):
axis="",
)
@Appender(_shared_docs["aggregate"])
def aggregate(self, arg, *args, **kwargs):
return super().aggregate(arg, *args, **kwargs)
def aggregate(self, func, *args, **kwargs):
return super().aggregate(func, *args, **kwargs)

agg = aggregate

Expand Down
82 changes: 48 additions & 34 deletions pandas/util/_decorators.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,35 @@
from functools import wraps
import inspect
from textwrap import dedent
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
cast,
)
import warnings

from pandas._libs.properties import cache_readonly # noqa

FuncType = Callable[..., Any]
F = TypeVar("F", bound=FuncType)


def deprecate(
name: str,
alternative: Callable,
alternative: Callable[..., Any],
version: str,
alt_name: Optional[str] = None,
klass: Optional[Type[Warning]] = None,
stacklevel: int = 2,
msg: Optional[str] = None,
) -> Callable:
) -> Callable[..., Any]:
"""
Return a new function that emits a deprecation warning on use.
Expand Down Expand Up @@ -47,7 +61,7 @@ def deprecate(
warning_msg = msg or "{} is deprecated, use {} instead".format(name, alt_name)

@wraps(alternative)
def wrapper(*args, **kwargs):
def wrapper(*args, **kwargs) -> Callable[..., Any]:
warnings.warn(warning_msg, klass, stacklevel=stacklevel)
return alternative(*args, **kwargs)

Expand Down Expand Up @@ -90,9 +104,9 @@ def wrapper(*args, **kwargs):
def deprecate_kwarg(
old_arg_name: str,
new_arg_name: Optional[str],
mapping: Optional[Union[Dict, Callable[[Any], Any]]] = None,
mapping: Optional[Union[Dict[Any, Any], Callable[[Any], Any]]] = None,
stacklevel: int = 2,
) -> Callable:
) -> Callable[..., Any]:
"""
Decorator to deprecate a keyword argument of a function.
Expand Down Expand Up @@ -160,27 +174,27 @@ def deprecate_kwarg(
"mapping from old to new argument values " "must be dict or callable!"
)

def _deprecate_kwarg(func):
def _deprecate_kwarg(func: F) -> F:
@wraps(func)
def wrapper(*args, **kwargs):
def wrapper(*args, **kwargs) -> Callable[..., Any]:
old_arg_value = kwargs.pop(old_arg_name, None)

if new_arg_name is None and old_arg_value is not None:
msg = (
"the '{old_name}' keyword is deprecated and will be "
"removed in a future version. "
"Please take steps to stop the use of '{old_name}'"
).format(old_name=old_arg_name)
warnings.warn(msg, FutureWarning, stacklevel=stacklevel)
kwargs[old_arg_name] = old_arg_value
return func(*args, **kwargs)

if old_arg_value is not None:
if mapping is not None:
if hasattr(mapping, "get"):
new_arg_value = mapping.get(old_arg_value, old_arg_value)
else:
if new_arg_name is None:
msg = (
"the '{old_name}' keyword is deprecated and will be "
"removed in a future version. "
"Please take steps to stop the use of '{old_name}'"
).format(old_name=old_arg_name)
warnings.warn(msg, FutureWarning, stacklevel=stacklevel)
kwargs[old_arg_name] = old_arg_value
return func(*args, **kwargs)

elif mapping is not None:
if callable(mapping):
new_arg_value = mapping(old_arg_value)
else:
new_arg_value = mapping.get(old_arg_value, old_arg_value)
msg = (
"the {old_name}={old_val!r} keyword is deprecated, "
"use {new_name}={new_val!r} instead"
Expand All @@ -198,7 +212,7 @@ def wrapper(*args, **kwargs):
).format(old_name=old_arg_name, new_name=new_arg_name)

warnings.warn(msg, FutureWarning, stacklevel=stacklevel)
if kwargs.get(new_arg_name, None) is not None:
if kwargs.get(new_arg_name) is not None:
msg = (
"Can only specify '{old_name}' or '{new_name}', " "not both"
).format(old_name=old_arg_name, new_name=new_arg_name)
Expand All @@ -207,17 +221,17 @@ def wrapper(*args, **kwargs):
kwargs[new_arg_name] = new_arg_value
return func(*args, **kwargs)

return wrapper
return cast(F, wrapper)

return _deprecate_kwarg


def rewrite_axis_style_signature(
name: str, extra_params: List[Tuple[str, Any]]
) -> Callable:
def decorate(func):
) -> Callable[..., Any]:
def decorate(func: F) -> F:
@wraps(func)
def wrapper(*args, **kwargs):
def wrapper(*args, **kwargs) -> Callable[..., Any]:
return func(*args, **kwargs)

kind = inspect.Parameter.POSITIONAL_OR_KEYWORD
Expand All @@ -234,8 +248,9 @@ def wrapper(*args, **kwargs):

sig = inspect.Signature(params)

func.__signature__ = sig
return wrapper
# https://github.com/python/typing/issues/598
func.__signature__ = sig # type: ignore
return cast(F, wrapper)

return decorate

Expand Down Expand Up @@ -279,18 +294,17 @@ def __init__(self, *args, **kwargs):

self.params = args or kwargs

def __call__(self, func: Callable) -> Callable:
def __call__(self, func: F) -> F:
func.__doc__ = func.__doc__ and func.__doc__ % self.params
return func

def update(self, *args, **kwargs) -> None:
"""
Update self.params with supplied args.
If called, we assume self.params is a dict.
"""

self.params.update(*args, **kwargs)
if isinstance(self.params, dict):
self.params.update(*args, **kwargs)


class Appender:
Expand Down Expand Up @@ -320,7 +334,7 @@ def __init__(self, addendum: Optional[str], join: str = "", indents: int = 0):
self.addendum = addendum
self.join = join

def __call__(self, func: Callable) -> Callable:
def __call__(self, func: F) -> F:
func.__doc__ = func.__doc__ if func.__doc__ else ""
self.addendum = self.addendum if self.addendum else ""
docitems = [func.__doc__, self.addendum]
Expand Down

0 comments on commit 33ce4fb

Please sign in to comment.