Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TYPING: check-untyped-defs for util._decorators #28128

Merged
merged 1 commit into from
Aug 27, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions pandas/core/groupby/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,45 +833,45 @@ def apply(self, func, *args, **kwargs):
axis="",
)
@Appender(_shared_docs["aggregate"])
def aggregate(self, func_or_funcs=None, *args, **kwargs):
def aggregate(self, func=None, *args, **kwargs):
_level = kwargs.pop("_level", None)

relabeling = func_or_funcs is None
relabeling = func is None
columns = None
no_arg_message = "Must provide 'func_or_funcs' or named aggregation **kwargs."
no_arg_message = "Must provide 'func' or named aggregation **kwargs."
if relabeling:
columns = list(kwargs)
if not PY36:
# sort for 3.5 and earlier
columns = list(sorted(columns))

func_or_funcs = [kwargs[col] for col in columns]
func = [kwargs[col] for col in columns]
kwargs = {}
if not columns:
raise TypeError(no_arg_message)

if isinstance(func_or_funcs, str):
return getattr(self, func_or_funcs)(*args, **kwargs)
if isinstance(func, str):
return getattr(self, func)(*args, **kwargs)

if isinstance(func_or_funcs, abc.Iterable):
if isinstance(func, abc.Iterable):
# Catch instances of lists / tuples
# but not the class list / tuple itself.
func_or_funcs = _maybe_mangle_lambdas(func_or_funcs)
ret = self._aggregate_multiple_funcs(func_or_funcs, (_level or 0) + 1)
func = _maybe_mangle_lambdas(func)
ret = self._aggregate_multiple_funcs(func, (_level or 0) + 1)
if relabeling:
ret.columns = columns
else:
cyfunc = self._get_cython_func(func_or_funcs)
cyfunc = self._get_cython_func(func)
if cyfunc and not args and not kwargs:
return getattr(self, cyfunc)()

if self.grouper.nkeys > 1:
return self._python_agg_general(func_or_funcs, *args, **kwargs)
return self._python_agg_general(func, *args, **kwargs)

try:
return self._python_agg_general(func_or_funcs, *args, **kwargs)
return self._python_agg_general(func, *args, **kwargs)
except Exception:
result = self._aggregate_named(func_or_funcs, *args, **kwargs)
result = self._aggregate_named(func, *args, **kwargs)

index = Index(sorted(result), name=self.grouper.names[0])
ret = Series(result, index=index)
Expand Down Expand Up @@ -1464,8 +1464,8 @@ class DataFrameGroupBy(NDFrameGroupBy):
axis="",
)
@Appender(_shared_docs["aggregate"])
def aggregate(self, arg=None, *args, **kwargs):
return super().aggregate(arg, *args, **kwargs)
def aggregate(self, func=None, *args, **kwargs):
return super().aggregate(func, *args, **kwargs)

agg = aggregate

Expand Down
4 changes: 2 additions & 2 deletions pandas/core/indexes/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,7 +788,7 @@ def _find_non_overlapping_monotonic_bounds(self, key):
return start, stop

def get_loc(
self, key: Any, method: Optional[str] = None
self, key: Any, method: Optional[str] = None, tolerance=None
WillAyd marked this conversation as resolved.
Show resolved Hide resolved
) -> Union[int, slice, np.ndarray]:
"""
Get integer location, slice or boolean mask for requested label.
Expand Down Expand Up @@ -982,7 +982,7 @@ def get_indexer_for(self, target: AnyArrayLike, **kwargs) -> np.ndarray:
List of indices.
"""
if self.is_overlapping:
return self.get_indexer_non_unique(target, **kwargs)[0]
return self.get_indexer_non_unique(target)[0]
return self.get_indexer(target, **kwargs)

@Appender(_index_shared_docs["get_value"] % _index_doc_kwargs)
Expand Down
4 changes: 2 additions & 2 deletions pandas/core/window/ewm.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,8 @@ def _constructor(self):
axis="",
)
@Appender(_shared_docs["aggregate"])
def aggregate(self, arg, *args, **kwargs):
return super().aggregate(arg, *args, **kwargs)
def aggregate(self, func, *args, **kwargs):
return super().aggregate(func, *args, **kwargs)

agg = aggregate

Expand Down
4 changes: 2 additions & 2 deletions pandas/core/window/expanding.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,8 @@ def _get_window(self, other=None, **kwargs):
axis="",
)
@Appender(_shared_docs["aggregate"])
def aggregate(self, arg, *args, **kwargs):
return super().aggregate(arg, *args, **kwargs)
def aggregate(self, func, *args, **kwargs):
return super().aggregate(func, *args, **kwargs)

agg = aggregate

Expand Down
10 changes: 5 additions & 5 deletions pandas/core/window/rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -901,12 +901,12 @@ def func(arg, window, min_periods=None, closed=None):
axis="",
)
@Appender(_shared_docs["aggregate"])
def aggregate(self, arg, *args, **kwargs):
result, how = self._aggregate(arg, *args, **kwargs)
def aggregate(self, func, *args, **kwargs):
result, how = self._aggregate(func, *args, **kwargs)
if result is None:

# these must apply directly
result = arg(self)
result = func(self)

return result

Expand Down Expand Up @@ -1788,8 +1788,8 @@ def _validate_freq(self):
axis="",
)
@Appender(_shared_docs["aggregate"])
def aggregate(self, arg, *args, **kwargs):
return super().aggregate(arg, *args, **kwargs)
def aggregate(self, func, *args, **kwargs):
return super().aggregate(func, *args, **kwargs)

agg = aggregate

Expand Down
82 changes: 48 additions & 34 deletions pandas/util/_decorators.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,35 @@
from functools import wraps
import inspect
from textwrap import dedent
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
cast,
)
import warnings

from pandas._libs.properties import cache_readonly # noqa

FuncType = Callable[..., Any]
F = TypeVar("F", bound=FuncType)
WillAyd marked this conversation as resolved.
Show resolved Hide resolved


def deprecate(
name: str,
alternative: Callable,
alternative: Callable[..., Any],
version: str,
alt_name: Optional[str] = None,
klass: Optional[Type[Warning]] = None,
stacklevel: int = 2,
msg: Optional[str] = None,
) -> Callable:
) -> Callable[..., Any]:
"""
Return a new function that emits a deprecation warning on use.

Expand Down Expand Up @@ -47,7 +61,7 @@ def deprecate(
warning_msg = msg or "{} is deprecated, use {} instead".format(name, alt_name)

@wraps(alternative)
def wrapper(*args, **kwargs):
def wrapper(*args, **kwargs) -> Callable[..., Any]:
warnings.warn(warning_msg, klass, stacklevel=stacklevel)
return alternative(*args, **kwargs)

Expand Down Expand Up @@ -90,9 +104,9 @@ def wrapper(*args, **kwargs):
def deprecate_kwarg(
old_arg_name: str,
new_arg_name: Optional[str],
mapping: Optional[Union[Dict, Callable[[Any], Any]]] = None,
mapping: Optional[Union[Dict[Any, Any], Callable[[Any], Any]]] = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think OK just to do Dict instead of Dict[Any, Any]; same thing but more readable

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is the same thing but would fail --disallow-any-generics check.

although readability is important, i don't think the benefits of precise type checking can be understated.

stacklevel: int = 2,
) -> Callable:
) -> Callable[..., Any]:
"""
Decorator to deprecate a keyword argument of a function.

Expand Down Expand Up @@ -160,27 +174,27 @@ def deprecate_kwarg(
"mapping from old to new argument values " "must be dict or callable!"
)

def _deprecate_kwarg(func):
def _deprecate_kwarg(func: F) -> F:
@wraps(func)
def wrapper(*args, **kwargs):
def wrapper(*args, **kwargs) -> Callable[..., Any]:
old_arg_value = kwargs.pop(old_arg_name, None)

if new_arg_name is None and old_arg_value is not None:
msg = (
"the '{old_name}' keyword is deprecated and will be "
"removed in a future version. "
"Please take steps to stop the use of '{old_name}'"
).format(old_name=old_arg_name)
warnings.warn(msg, FutureWarning, stacklevel=stacklevel)
kwargs[old_arg_name] = old_arg_value
return func(*args, **kwargs)

if old_arg_value is not None:
if mapping is not None:
if hasattr(mapping, "get"):
new_arg_value = mapping.get(old_arg_value, old_arg_value)
else:
if new_arg_name is None:
msg = (
"the '{old_name}' keyword is deprecated and will be "
"removed in a future version. "
"Please take steps to stop the use of '{old_name}'"
).format(old_name=old_arg_name)
warnings.warn(msg, FutureWarning, stacklevel=stacklevel)
kwargs[old_arg_name] = old_arg_value
return func(*args, **kwargs)

elif mapping is not None:
if callable(mapping):
new_arg_value = mapping(old_arg_value)
else:
new_arg_value = mapping.get(old_arg_value, old_arg_value)
msg = (
"the {old_name}={old_val!r} keyword is deprecated, "
"use {new_name}={new_val!r} instead"
Expand All @@ -198,7 +212,7 @@ def wrapper(*args, **kwargs):
).format(old_name=old_arg_name, new_name=new_arg_name)

warnings.warn(msg, FutureWarning, stacklevel=stacklevel)
if kwargs.get(new_arg_name, None) is not None:
if kwargs.get(new_arg_name) is not None:
msg = (
"Can only specify '{old_name}' or '{new_name}', " "not both"
).format(old_name=old_arg_name, new_name=new_arg_name)
Expand All @@ -207,17 +221,17 @@ def wrapper(*args, **kwargs):
kwargs[new_arg_name] = new_arg_value
return func(*args, **kwargs)

return wrapper
return cast(F, wrapper)

return _deprecate_kwarg


def rewrite_axis_style_signature(
name: str, extra_params: List[Tuple[str, Any]]
) -> Callable:
def decorate(func):
) -> Callable[..., Any]:
def decorate(func: F) -> F:
@wraps(func)
def wrapper(*args, **kwargs):
def wrapper(*args, **kwargs) -> Callable[..., Any]:
return func(*args, **kwargs)

kind = inspect.Parameter.POSITIONAL_OR_KEYWORD
Expand All @@ -234,8 +248,9 @@ def wrapper(*args, **kwargs):

sig = inspect.Signature(params)

func.__signature__ = sig
return wrapper
# https://github.com/python/typing/issues/598
func.__signature__ = sig # type: ignore
return cast(F, wrapper)

return decorate

Expand Down Expand Up @@ -279,18 +294,17 @@ def __init__(self, *args, **kwargs):

self.params = args or kwargs

def __call__(self, func: Callable) -> Callable:
def __call__(self, func: F) -> F:
func.__doc__ = func.__doc__ and func.__doc__ % self.params
return func

def update(self, *args, **kwargs) -> None:
"""
Update self.params with supplied args.

If called, we assume self.params is a dict.
"""

self.params.update(*args, **kwargs)
if isinstance(self.params, dict):
self.params.update(*args, **kwargs)


class Appender:
Expand Down Expand Up @@ -320,7 +334,7 @@ def __init__(self, addendum: Optional[str], join: str = "", indents: int = 0):
self.addendum = addendum
self.join = join

def __call__(self, func: Callable) -> Callable:
def __call__(self, func: F) -> F:
func.__doc__ = func.__doc__ if func.__doc__ else ""
self.addendum = self.addendum if self.addendum else ""
docitems = [func.__doc__, self.addendum]
Expand Down