Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom type narrowing with a special decorator #7870

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,12 @@ class TypeChecker(NodeVisitor[None], CheckerPluginInterface):
binder = None # type: ConditionalTypeBinder
# Helper for type checking expressions
expr_checker = None # type: mypy.checkexpr.ExpressionChecker

# temporary container for ctn
ctns_queue = None # type:List[str]
# uniqueness check for ctn
ctns_keys = None # type: Set[str]
# custom type narrowers
ctns = None # type: List[Tuple[str, Expression]]
tscope = None # type: Scope
scope = None # type: CheckerScope
# Stack of function return types
Expand Down Expand Up @@ -231,6 +236,9 @@ def __init__(self, errors: Errors, modules: Dict[str, MypyFile], options: Option
self.type_map = {}
self.module_refs = set()
self.pass_num = 0
self.ctns_queue = []
self.ctns_keys = set()
self.ctns = []
self.current_node_deferred = False
self.is_stub = tree.is_stub
self.is_typeshed_stub = errors.is_typeshed_file(path)
Expand Down Expand Up @@ -3379,6 +3387,11 @@ def visit_decorator(self, e: Decorator) -> None:
e.var.type = AnyType(TypeOfAny.special_form)
e.var.is_ready = True
return
elif isinstance(d, CallExpr):
assert isinstance(d.callee, RefExpr)
if d.callee.fullname == 'mypy.extern.narrow_cast':
# this function is a CTN
self.ctns_queue.append(e.func.fullname()) # type will be added later

if self.recurse_into_functions:
with self.tscope.function_scope(e.func):
Expand Down Expand Up @@ -3702,14 +3715,25 @@ def find_isinstance_check(self, node: Expression
elif is_false_literal(node):
return None, {}
elif isinstance(node, CallExpr):
expr = None
vartype = None
type = None
for name, expr_ in self.ctns:
if refers_to_fullname(node.callee, name):
expr = node.args[0]
if literal(expr) == LITERAL_TYPE:
vartype = type_map[node.args[0]]
type = get_isinstance_type(expr_, type_map)
break # name is unique
if refers_to_fullname(node.callee, 'builtins.isinstance'):
if len(node.args) != 2: # the error will be reported elsewhere
return {}, {}
expr = node.args[0]
if literal(expr) == LITERAL_TYPE:
vartype = type_map[expr]
type = get_isinstance_type(node.args[1], type_map)
return conditional_type_map(expr, vartype, type)
if expr and vartype and type:
return conditional_type_map(expr, vartype, type)
elif refers_to_fullname(node.callee, 'builtins.issubclass'):
if len(node.args) != 2: # the error will be reported elsewhere
return {}, {}
Expand Down
17 changes: 13 additions & 4 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
tuple_fallback, make_simplified_union, true_only, false_only, erase_to_union_or_bound,
function_type, callable_type, try_getting_str_literals
)
from mypy.semanal import refers_to_fullname
import mypy.errorcodes as codes

# Type of callback user for checking individual function arguments. See
Expand Down Expand Up @@ -270,9 +271,12 @@ def visit_call_expr_inner(self, e: CallExpr, allow_none_return: bool = False) ->
typeddict_type = e.callee.node.typeddict_type.copy_modified(
fallback=Instance(e.callee.node, []))
return self.check_typeddict_call(typeddict_type, e.arg_kinds, e.arg_names, e.args, e)
if (isinstance(e.callee, NameExpr) and e.callee.name in ('isinstance', 'issubclass')
and len(e.args) == 2):
for typ in mypy.checker.flatten(e.args[1]):
if (isinstance(e.callee, NameExpr)
and ((e.callee.name in ('isinstance', 'issubclass') and len(e.args) == 2)
or refers_to_fullname(e.callee, 'mypy.extern.narrow_cast') and len(e.args) == 1)):
is_narrow_cast = refers_to_fullname(e.callee, 'mypy.extern.narrow_cast')
arg = e.args[1] if not is_narrow_cast else e.args[0]
for typ in mypy.checker.flatten(arg):
node = None
if isinstance(typ, NameExpr):
try:
Expand All @@ -297,6 +301,11 @@ def visit_call_expr_inner(self, e: CallExpr, allow_none_return: bool = False) ->
self.msg.cannot_use_function_with_type(e.callee.name, "TypedDict", e)
elif typ.node.is_newtype:
self.msg.cannot_use_function_with_type(e.callee.name, "NewType", e)
if is_narrow_cast:
ctn_name = self.chk.ctns_queue.pop()
if ctn_name not in self.chk.ctns_keys:
self.chk.ctns.append((ctn_name, arg))
self.chk.ctns_keys.add(ctn_name)
self.try_infer_partial_type(e)
type_context = None
if isinstance(e.callee, LambdaExpr):
Expand Down Expand Up @@ -3308,7 +3317,7 @@ def visit_super_expr(self, e: SuperExpr) -> Type:
self.chk.fail(message_registry.SUPER_ARG_2_NOT_INSTANCE_OF_ARG_1, e)
return AnyType(TypeOfAny.from_error)

for base in mro[index+1:]:
for base in mro[index + 1:]:
if e.name in base.names or base == mro[-1]:
if e.info and e.info.fallback_to_any and base == mro[-1]:
# There's an undefined base class, and we're at the end of the
Expand Down
10 changes: 10 additions & 0 deletions mypy/extern.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from typing import Union, Tuple, Any, Callable


# signature is partially copy pasted from typeshed isinstance
def narrow_cast(T: Union[type, Tuple[Union[type, Tuple[Any, ...]], ...]]) \
-> Callable[..., Callable[..., bool]]:
def narrow_cast_inner(f: Callable[..., bool]) -> Callable[..., bool]:
return f # binds first argument of f to T

return narrow_cast_inner
49 changes: 49 additions & 0 deletions test-data/unit/check-isinstance.test
Original file line number Diff line number Diff line change
Expand Up @@ -2279,3 +2279,52 @@ var = 'some string'
if isinstance(var, *(str, int)): # E: Too many arguments for "isinstance"
pass
[builtins fixtures/isinstancelist.pyi]

[case testCustomTypeNarrower]
from typing import Union, List, Tuple, Any, Callable
from mypy.extern import narrow_cast
@narrow_cast(int)
def isint(x: Union[str, int], *args) -> bool:
return isinstance(x, int)

@narrow_cast(str)
def isstr(x, name) -> bool:
return name == 'STRING'

u: Union[str, int] = 5
x: Union[str, float]
if isint(u):
reveal_type(u) # N: Revealed type is 'builtins.int'
if isinstance(u, int):
reveal_type(u) # N: Revealed type is 'builtins.int'
if isstr(x, 'STRING'):
x + ""

@narrow_cast(str)
def is_fizz_buzz(foo):
return foo in ['fizz', 'buzz']

def foobar(foo: Union[str, float]):
if foo in ['fizz', 'buzz']:
reveal_type(foo) # N: Revealed type is 'Union[builtins.str, builtins.float]'
if is_fizz_buzz(foo):
reveal_type(foo) # N: Revealed type is 'builtins.str'

@narrow_cast((str, (int,)))
def is_str_or_int(x):
return isinstance(x, (str, (int,)))

@narrow_cast((str, (list,)))
def is_str_or_list(x):
return isinstance(x, (str, (list,)))

def f(x: Union[int, str, List]) -> None:
if isinstance(x, (str, (int,))):
reveal_type(x) # N: Revealed type is 'Union[builtins.int, builtins.str]'
if is_str_or_int(x):
reveal_type(x) # N: Revealed type is 'Union[builtins.int, builtins.str]'
if isinstance(x, (str, (list,))):
reveal_type(x) # N: Revealed type is 'Union[builtins.str, builtins.list[Any]]'
if is_str_or_list(x):
reveal_type(x) # N: Revealed type is 'Union[builtins.str, builtins.list[Any]]'
[builtins fixtures/isinstancelist.pyi]
Empty file.
10 changes: 10 additions & 0 deletions test-data/unit/lib-stub/mypy/extern.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from typing import Union, Tuple, Any, Callable


# copy pasted from typeshed
def narrow_cast(T: Union[type, Tuple[Union[type, Tuple[Any, ...]], ...]]) \
-> Callable[..., Callable[..., bool]]:
def narrow_cast_inner(f: Callable[..., bool]) -> Callable[..., bool]:
return f # binds first argument of f to T

return narrow_cast_inner