diff --git a/CHANGELOG b/CHANGELOG
index bed93916c..2457edac6 100644
--- a/CHANGELOG
+++ b/CHANGELOG
@@ -1,3 +1,14 @@
+Version 2022.07.18:
+
+Bug fixes:
+* Look up methods properly on classes with _HAS_DYNAMIC_ATTRIBUTES.
+* Handle .pyi-1 files in load_pytd.Module.is_package().
+* Adjust opcode line numbers for return statements in python 3.10+.
+* Remove optimize.Factorize, which unnecessarily flattens overloaded functions.
+* Fix coroutine signatures in overriding_checks.
+* Handle generic types correctly in signature compatibility checks.
+* Respect NoReturn annotations even when maximum depth is reached.
+
Version 2022.06.30:
Updates:
diff --git a/docs/faq.md b/docs/faq.md
index 898429429..62713cb53 100644
--- a/docs/faq.md
+++ b/docs/faq.md
@@ -22,7 +22,7 @@
* [How do I annotate *args and **kwargs
?](#how-do-i-annotate-args-and-kwargs)
* [Why are signature mismatches in subclasses bad? {#signature-mismatch}](#why-are-signature-mismatches-in-subclasses-bad-signature-mismatch)
-
+
@@ -371,7 +371,6 @@ TypeError: func() got multiple values for argument 'y'
[pep-561-issue]: https://github.com/google/pytype/issues/151
[typeshed]: https://github.com/python/typeshed
[typing-faq]: typing_faq.md
-[why-is-pytype-taking-so-long]: #why-is-pytype-taking-so-long
diff --git a/pytype/__version__.py b/pytype/__version__.py
index 8cffc8b5f..b3afc0ee5 100644
--- a/pytype/__version__.py
+++ b/pytype/__version__.py
@@ -1,2 +1,2 @@
# pylint: skip-file
-__version__ = '2022.06.30'
+__version__ = '2022.07.18'
diff --git a/pytype/abstract/_interpreter_function.py b/pytype/abstract/_interpreter_function.py
index a802d5adb..6b66bd812 100644
--- a/pytype/abstract/_interpreter_function.py
+++ b/pytype/abstract/_interpreter_function.py
@@ -642,7 +642,12 @@ def call(self, node, func, args, alias_map=None, new_locals=False,
not abstract_utils.func_name_is_class_init(self.name)):
log.info("Maximum depth reached. Not analyzing %r", self.name)
self._set_callself_maybe_missing_members()
- return node, self.ctx.new_unsolvable(node)
+ if self.signature.annotations.get("return") == self.ctx.convert.no_return:
+ # TODO(b/147230757): Use all return annotations, not just NoReturn.
+ ret = self.signature.annotations["return"]
+ else:
+ ret = self.ctx.convert.unsolvable
+ return node, ret.to_variable(node)
args = self._fix_args_for_unannotated_contextmanager_exit(node, func, args)
args = args.simplify(node, self.ctx, self.signature)
sig, substs, callargs = self._find_matching_sig(node, args, alias_map)
diff --git a/pytype/annotation_utils.py b/pytype/annotation_utils.py
index b85d03f86..a3b890f70 100644
--- a/pytype/annotation_utils.py
+++ b/pytype/annotation_utils.py
@@ -54,6 +54,14 @@ def _get_type_parameter_subst(self, node, annot, substs, instantiate_unbound):
return self.ctx.convert.merge_classes(vals)
def sub_one_annotation(self, node, annot, substs, instantiate_unbound=True):
+
+ def get_type_parameter_subst(annotation):
+ return self._get_type_parameter_subst(node, annotation, substs,
+ instantiate_unbound)
+
+ return self._do_sub_one_annotation(node, annot, get_type_parameter_subst)
+
+ def _do_sub_one_annotation(self, node, annot, get_type_parameter_subst_fn):
"""Apply type parameter substitutions to an annotation."""
# We push annotations onto 'stack' and move them to the 'done' stack as they
# are processed. For each annotation, we also track an 'inner_type_keys'
@@ -77,8 +85,8 @@ def sub_one_annotation(self, node, annot, substs, instantiate_unbound=True):
if cur not in late_annotations:
param_strings = []
for t in utils.unique_list(self.get_type_parameters(cur)):
- s = pytd_utils.Print(self._get_type_parameter_subst(
- node, t, substs, instantiate_unbound).get_instance_type(node))
+ s = pytd_utils.Print(
+ get_type_parameter_subst_fn(t).get_instance_type(node))
param_strings.append(s)
expr = f"{cur.expr}[{', '.join(param_strings)}]"
late_annot = abstract.LateAnnotation(expr, cur.stack, cur.ctx)
@@ -107,11 +115,53 @@ def sub_one_annotation(self, node, annot, substs, instantiate_unbound=True):
late_annot.expr.split("[", 1)[0]].append(late_annot)
done.append(done_annot)
else:
- done.append(self._get_type_parameter_subst(
- node, cur, substs, instantiate_unbound))
+ done.append(get_type_parameter_subst_fn(cur))
assert len(done) == 1
return done[0]
+ def sub_annotations_for_parameterized_class(self, cls, annotations):
+ """Apply type parameter substitutions to a dictionary of annotations.
+
+ Args:
+ cls: ParameterizedClass that defines type parameter substitutions.
+ annotations: A dictionary of annotations to which type parameter
+ substition should be applied.
+
+ Returns:
+ Annotations with type parameters substituted.
+ """
+ assert isinstance(cls, abstract.ParameterizedClass)
+ formal_type_parameters = cls.get_formal_type_parameters()
+
+ def get_type_parameter_subst(annotation):
+ assert isinstance(annotation, abstract.TypeParameter)
+ # Normally the type parameter module is set correctly at this point.
+ # Except for the case when a method that references this type parameter
+ # is inherited in a subclass that does not specialize this parameter:
+ # class A(Generic[T]):
+ # def f(self, t: T): ...
+ #
+ # class B(Generic[T], A[T]):
+ # pass
+ #
+ # class C(B[int]): ...
+ # In this case t in A[T].f will be annotated with T with no module set,
+ # since we don't know the correct module until T is specialized in
+ # B[int].
+ annotation = annotation.with_module(cls.full_name)
+ # Method parameter can be annotated with a typevar that doesn't
+ # belong to the class template:
+ # class A(Generic[T]):
+ # def f(self, t: U): ...
+ # In this case we return it as is.
+ return formal_type_parameters.get(annotation.full_name, annotation)
+
+ return {
+ name: self._do_sub_one_annotation(self.ctx.root_node, annot,
+ get_type_parameter_subst)
+ for name, annot in annotations.items()
+ }
+
def get_late_annotations(self, annot):
if annot.is_late_annotation() and not annot.resolved:
yield annot
diff --git a/pytype/convert.py b/pytype/convert.py
index 8185a4205..9be7dcd48 100644
--- a/pytype/convert.py
+++ b/pytype/convert.py
@@ -615,6 +615,8 @@ def _special_constant_to_value(self, name):
return self.function_type
elif name == "types.NoneType":
return self.none_type
+ elif name == "types.CodeType":
+ return self.primitive_classes[types.CodeType]
else:
return None
diff --git a/pytype/directors/directors.py b/pytype/directors/directors.py
index 5dffdc32d..883034dbc 100644
--- a/pytype/directors/directors.py
+++ b/pytype/directors/directors.py
@@ -280,6 +280,7 @@ def __init__(self, src_tree, errorlog, filename, disable, code):
# Store function ranges and return lines to distinguish explicit and
# implicit returns (the bytecode has a `RETURN None` for implcit returns).
self._return_lines = set()
+ self.block_returns = None
self._function_ranges = _BlockRanges({})
# Parse the source code for directives.
self._parse_src_tree(src_tree, code)
@@ -313,7 +314,8 @@ def _parse_src_tree(self, src_tree, code):
else:
opcode_lines = None
- self._return_lines = visitor.returns
+ self.block_returns = visitor.block_returns
+ self._return_lines = visitor.block_returns.all_returns()
self._function_ranges = _BlockRanges(visitor.function_ranges)
for line_range, group in visitor.structured_comment_groups.items():
diff --git a/pytype/directors/parser.py b/pytype/directors/parser.py
index 3c8f02698..f105898ef 100644
--- a/pytype/directors/parser.py
+++ b/pytype/directors/parser.py
@@ -32,6 +32,9 @@ class LineRange:
def from_node(cls, node):
return cls(node.lineno, node.end_lineno)
+ def __contains__(self, line):
+ return self.start_line <= line <= self.end_line
+
@dataclasses.dataclass(frozen=True)
class Call(LineRange):
@@ -66,6 +69,44 @@ class _SourceTree:
structured_comments: Mapping[int, Sequence[_StructuredComment]]
+class BlockReturns:
+ """Tracks return statements in with/try blocks."""
+
+ def __init__(self):
+ self._block_ranges = []
+ self._returns = []
+ self._block_returns = {}
+ self._final = False
+
+ def add_block(self, node):
+ line_range = LineRange.from_node(node)
+ self._block_ranges.append(line_range)
+
+ def add_return(self, node):
+ self._returns.append(node.lineno)
+
+ def finalize(self):
+ for br in self._block_ranges:
+ self._block_returns[br.start_line] = sorted(
+ r for r in self._returns if r in br
+ )
+ self._final = True
+
+ def all_returns(self):
+ return set(self._returns)
+
+ def __iter__(self):
+ assert self._final
+ return iter(self._block_returns.items())
+
+ def __repr__(self):
+ return f"""
+ Blocks: {self._block_ranges}
+ Returns: {self._returns}
+ {self._block_returns}
+ """
+
+
class _ParseVisitor(visitor.BaseVisitor):
"""Visitor for parsing a source tree.
@@ -97,8 +138,9 @@ def __init__(self, raw_structured_comments):
self.variable_annotations = []
self.decorators = []
self.defs_start = None
- self.returns = set()
self.function_ranges = {}
+ self.block_returns = BlockReturns()
+ self.block_depth = 0
def _add_structured_comment_group(self, start_line, end_line, cls=LineRange):
"""Adds an empty _StructuredComment group with the given line range."""
@@ -171,6 +213,9 @@ def should_add(comment, group):
if cls is not LineRange:
group.extend(c for c in structured_comments if should_add(c, group))
+ def leave_Module(self, node):
+ self.block_returns.finalize()
+
def visit_Call(self, node):
self._process_structured_comments(LineRange.from_node(node), cls=Call)
@@ -200,8 +245,22 @@ def visit_Try(self, node):
def _visit_with(self, node):
item = node.items[-1]
end_lineno = (item.optional_vars or item.context_expr).end_lineno
+ if self.block_depth == 1:
+ self.block_returns.add_block(node)
self._process_structured_comments(LineRange(node.lineno, end_lineno))
+ def enter_With(self, node):
+ self.block_depth += 1
+
+ def leave_With(self, node):
+ self.block_depth -= 1
+
+ def enter_AsyncWith(self, node):
+ self.block_depth += 1
+
+ def leave_AsyncWith(self, node):
+ self.block_depth -= 1
+
def visit_With(self, node):
self._visit_with(node)
@@ -226,8 +285,8 @@ def generic_visit(self, node):
self._process_structured_comments(LineRange.from_node(node))
def visit_Return(self, node):
+ self.block_returns.add_return(node)
self._process_structured_comments(LineRange.from_node(node))
- self.returns.add(node.lineno)
def _visit_decorators(self, node):
if not node.decorator_list:
diff --git a/pytype/directors/parser_libcst.py b/pytype/directors/parser_libcst.py
index cabb36c6a..6f20a3315 100644
--- a/pytype/directors/parser_libcst.py
+++ b/pytype/directors/parser_libcst.py
@@ -25,6 +25,9 @@ class LineRange:
start_line: int
end_line: int
+ def __contains__(self, line):
+ return self.start_line <= line <= self.end_line
+
@dataclasses.dataclass(frozen=True)
class Call(LineRange):
@@ -53,6 +56,31 @@ class _VariableAnnotation(LineRange):
annotation: str
+class BlockReturns:
+ """Tracks return statements in with/try blocks."""
+
+ def __init__(self):
+ self._block_ranges = []
+ self._returns = []
+ self._block_returns = {}
+
+ def add_return(self, pos):
+ self._returns.append(pos.start.line)
+
+ def all_returns(self):
+ return set(self._returns)
+
+ def __iter__(self):
+ return iter(self._block_returns.items())
+
+ def __repr__(self):
+ return f"""
+ Blocks: {self._block_ranges}
+ Returns: {self._returns}
+ {self._block_returns}
+ """
+
+
class _ParseVisitor(libcst.CSTVisitor):
"""Visitor for parsing a source tree.
@@ -80,8 +108,8 @@ def __init__(self):
self.variable_annotations = []
self.decorators = []
self.defs_start = None
- self.returns = set()
self.function_ranges = {}
+ self.block_returns = BlockReturns()
def _get_containing_groups(self, start_line, end_line=None):
"""Get _StructuredComment groups that fully contain the given line range."""
@@ -240,7 +268,7 @@ def visit_AnnAssign(self, node):
_VariableAnnotation(pos.start.line, pos.end.line, annotation))
def visit_Return(self, node):
- self.returns.add(self._get_position(node).start.line)
+ self.block_returns.add_return(self._get_position(node))
def _visit_decorators(self, node):
if not node.decorators:
diff --git a/pytype/load_pytd.py b/pytype/load_pytd.py
index 1f760b25b..ee955a002 100644
--- a/pytype/load_pytd.py
+++ b/pytype/load_pytd.py
@@ -70,8 +70,6 @@ class Module:
metadata: The metadata extracted from the picked file.
"""
- _INIT_NAMES = ("__init__.pyi", f"__init__.{pytd_utils.PICKLE_EXT}")
-
# pylint: disable=redefined-outer-name
def __init__(self, module_name, filename, ast, metadata=None, pickle=None,
has_unresolved_pointers=True):
@@ -91,7 +89,10 @@ def is_package(self):
# imports_map_loader adds os.devnull entries for __init__.py files in
# intermediate directories.
return True
- return self.filename and os.path.basename(self.filename) in self._INIT_NAMES
+ if self.filename:
+ base, _ = os.path.splitext(os.path.basename(self.filename))
+ return base == "__init__"
+ return False
class BadDependencyError(Exception):
@@ -209,7 +210,7 @@ def _unpickle_module(self, module):
newly_loaded_asts.append(loaded_ast)
m.ast = loaded_ast.ast
if loaded_ast.is_package:
- init_file = f"__init__.{pytd_utils.PICKLE_EXT}"
+ init_file = f"__init__{pytd_utils.PICKLE_EXT}"
if m.filename and os.path.basename(m.filename) != init_file:
base, _ = os.path.splitext(m.filename)
m.filename = os.path.join(base, init_file)
diff --git a/pytype/load_pytd_test.py b/pytype/load_pytd_test.py
index 979ae52ba..47e569780 100644
--- a/pytype/load_pytd_test.py
+++ b/pytype/load_pytd_test.py
@@ -19,6 +19,20 @@
import unittest
+class ModuleTest(test_base.UnitTest):
+ """Tests for load_pytd.Module."""
+
+ def test_is_package(self):
+ for filename, is_package in [("foo/bar.pyi", False),
+ ("foo/__init__.pyi", True),
+ ("foo/__init__.pyi-1", True),
+ ("foo/__init__.pickled", True),
+ (os.devnull, True)]:
+ with self.subTest(filename=filename):
+ mod = load_pytd.Module(module_name=None, filename=filename, ast=None)
+ self.assertEqual(mod.is_package(), is_package)
+
+
class _LoaderTest(test_base.UnitTest):
@contextlib.contextmanager
diff --git a/pytype/overriding_checks.py b/pytype/overriding_checks.py
index 490c0de32..173b951f4 100644
--- a/pytype/overriding_checks.py
+++ b/pytype/overriding_checks.py
@@ -464,6 +464,32 @@ def _get_pytd_class_signature_map(cls, ctx):
return method_signature_map
+def _get_parameterized_class_signature_map(cls, ctx):
+ """Returns a map from method names to signatures for a ParameterizedClass."""
+ assert isinstance(cls, abstract.ParameterizedClass)
+ if cls in ctx.method_signature_map:
+ return ctx.method_signature_map[cls]
+
+ base_class = cls.base_cls
+
+ if isinstance(base_class, abstract.InterpreterClass):
+ base_signature_map = ctx.method_signature_map[base_class]
+ else:
+ assert isinstance(base_class, abstract.PyTDClass)
+ base_signature_map = _get_pytd_class_signature_map(base_class, ctx)
+
+ method_signature_map = {}
+ for base_method_name, base_method_signature in base_signature_map.items():
+ # Replace formal type parameters with their values.
+ annotations = ctx.annotation_utils.sub_annotations_for_parameterized_class(
+ cls, base_method_signature.annotations)
+ method_signature_map[base_method_name] = base_method_signature._replace(
+ annotations=annotations)
+
+ ctx.method_signature_map[cls] = method_signature_map
+ return method_signature_map
+
+
def check_overriding_members(cls, bases, members, matcher, ctx):
"""Check that the method signatures of the new class match base classes."""
@@ -484,19 +510,31 @@ def check_overriding_members(cls, bases, members, matcher, ctx):
assert member_name not in class_method_map
class_method_map[member_name] = method
- class_signature_map = {
- method_name: method.signature
- for method_name, method in class_method_map.items()
- }
+ class_signature_map = {}
+ for method_name, method in class_method_map.items():
+ if method.is_coroutine():
+ annotations = dict(method.signature.annotations)
+ coroutine_params = {
+ abstract_utils.T: ctx.convert.unsolvable,
+ abstract_utils.T2: ctx.convert.unsolvable,
+ abstract_utils.V: annotations.get("return", ctx.convert.unsolvable),
+ }
+ annotations["return"] = abstract.ParameterizedClass(
+ ctx.convert.coroutine_type, coroutine_params, ctx)
+ signature = method.signature._replace(annotations=annotations)
+ else:
+ signature = method.signature
+ class_signature_map[method_name] = signature
for base in bases:
try:
base_class = abstract_utils.get_atomic_value(base)
except abstract_utils.ConversionError:
continue
- if isinstance(base_class, abstract.ParameterizedClass):
- base_class = base_class.base_cls
if isinstance(base_class, abstract.InterpreterClass):
base_signature_map = ctx.method_signature_map[base_class]
+ elif isinstance(base_class, abstract.ParameterizedClass):
+ base_signature_map = _get_parameterized_class_signature_map(
+ base_class, ctx)
elif isinstance(base_class, abstract.PyTDClass):
base_signature_map = _get_pytd_class_signature_map(base_class, ctx)
else:
diff --git a/pytype/pytd/main_test.py b/pytype/pytd/main_test.py
index fb4d0feaa..bdfabfd25 100644
--- a/pytype/pytd/main_test.py
+++ b/pytype/pytd/main_test.py
@@ -77,26 +77,6 @@ def f(x: str) -> str: ...
with open(outpath) as f:
self.assertMultiLineEqual(f.read(), src)
- def test_optimize(self):
- with file_utils.Tempdir() as d:
- inpath = d.create_file("in.pytd", """
- from typing import overload
-
- @overload
- def f(x: int) -> str: ...
- @overload
- def f(x: str) -> str: ...
- """)
- outpath = os.path.join(d.path, "out.pytd")
- sys.argv = ["main.py", "--optimize", inpath, outpath]
- pytd_tool.main()
- with open(outpath) as f:
- self.assertMultiLineEqual(f.read(), textwrap.dedent("""
- from typing import Union
-
- def f(x: Union[int, str]) -> str: ...
- """).strip())
-
if __name__ == "__main__":
unittest.main()
diff --git a/pytype/pytd/optimize.py b/pytype/pytd/optimize.py
index ee6382ecf..044f897f0 100644
--- a/pytype/pytd/optimize.py
+++ b/pytype/pytd/optimize.py
@@ -307,93 +307,6 @@ def VisitUnionType(self, union):
return result
-class Factorize(visitors.Visitor):
- """Opposite of ExpandSignatures. Factorizes cartesian products of functions.
-
- For example, this transforms
- def f(x: int, y: int)
- def f(x: int, y: float)
- def f(x: float, y: int)
- def f(x: float, y: float)
- to
- def f(x: Union[int, float], y: Union[int, float])
- """
-
- def _GroupByOmittedArg(self, signatures, i):
- """Group functions that are identical if you ignore one of the arguments.
-
- Arguments:
- signatures: A list of function signatures
- i: The index of the argument to ignore during comparison.
-
- Returns:
- A list of tuples (signature, types). "signature" is a signature with
- argument i omitted, "types" is the list of types that argument was
- found to have. signatures that don't have argument i are represented
- as (original, None).
- """
- groups = {}
- for sig in signatures:
- if i >= len(sig.params):
- # We can't omit argument i, because this signature has too few
- # arguments. Represent this signature as (original, None).
- groups[sig] = None
- continue
- if sig.params[i].mutated_type is not None:
- # We can't group mutable parameters. Leave this signature alone.
- groups[sig] = None
- continue
-
- # Set type of parameter i to None
- params = list(sig.params)
- param_i = params[i]
- params[i] = param_i.Replace(type=None)
-
- stripped_signature = sig.Replace(params=tuple(params))
- existing = groups.get(stripped_signature)
- if existing:
- existing.append(param_i.type)
- else:
- groups[stripped_signature] = [param_i.type]
- return groups.items()
-
- def VisitFunction(self, f):
- """Shrink a function, by factorizing cartesian products of arguments.
-
- Greedily groups signatures, looking at the arguments from left to right.
- This algorithm is *not* optimal. But it does the right thing for the
- typical cases.
-
- Arguments:
- f: An instance of pytd.Function. If this function has more
- than one signature, we will try to combine some of these signatures by
- introducing union types.
-
- Returns:
- A new, potentially optimized, instance of pytd.Function.
-
- """
- max_argument_count = max(len(s.params) for s in f.signatures)
- signatures = f.signatures
-
- for i in range(max_argument_count):
- new_sigs = []
- for sig, types in self._GroupByOmittedArg(signatures, i):
- if types:
- # One or more options for argument :
- new_params = list(sig.params)
- new_params[i] = sig.params[i].Replace(
- type=pytd_utils.JoinTypes(types))
- sig = sig.Replace(params=tuple(new_params))
- new_sigs.append(sig)
- else:
- # Signature doesn't have argument , so we store the original:
- new_sigs.append(sig)
- signatures = new_sigs
-
- return f.Replace(signatures=tuple(signatures))
-
-
class SuperClassHierarchy:
"""Utility class for optimizations working with superclasses."""
@@ -962,7 +875,6 @@ def Optimize(node,
node = node.Visit(RemoveDuplicates())
node = node.Visit(SimplifyUnions())
node = node.Visit(CombineReturnsAndExceptions())
- node = node.Visit(Factorize())
node = node.Visit(CombineContainers())
node = node.Visit(SimplifyContainers())
if builtins:
diff --git a/pytype/pytd/optimize_test.py b/pytype/pytd/optimize_test.py
index eb34e13ba..051a402b5 100644
--- a/pytype/pytd/optimize_test.py
+++ b/pytype/pytd/optimize_test.py
@@ -225,46 +225,6 @@ def test_simplify_unions(self):
self.ApplyVisitorToString(src, optimize.SimplifyUnions()),
new_src)
- def test_factorize(self):
- src = pytd_src("""
- def foo(a: int) -> file: ...
- def foo(a: int, x: complex) -> file: ...
- def foo(a: int, x: str) -> file: ...
- def foo(a: float, x: complex) -> file: ...
- def foo(a: float, x: str) -> file: ...
- def foo(a: int, x: file, *args) -> file: ...
- """)
- new_src = pytd_src("""
- def foo(a: int) -> file: ...
- def foo(a: float, x: Union[complex, str]) -> file: ...
- def foo(a: int, x: file, *args) -> file: ...
- """)
- self.AssertSourceEquals(
- self.ApplyVisitorToString(src, optimize.Factorize()), new_src)
-
- def test_factorize_mutable(self):
- src = pytd_src("""
- def foo(a: list[bool], b: X) -> file:
- a = list[int]
- def foo(a: list[bool], b: Y) -> file:
- a = list[int]
- # not groupable:
- def bar(a: int, b: list[int]) -> file:
- b = list[complex]
- def bar(a: int, b: list[float]) -> file:
- b = list[str]
- """)
- new_src = pytd_src("""
- def foo(a: list[bool], b: Union[X, Y]) -> file:
- a = list[int]
- def bar(a: int, b: list[int]) -> file:
- b = list[complex]
- def bar(a: int, b: list[float]) -> file:
- b = list[str]
- """)
- self.AssertSourceEquals(
- self.ApplyVisitorToString(src, optimize.Factorize()), new_src)
-
def test_builtin_superclasses(self):
src = pytd_src("""
def f(x: Union[list, object], y: Union[complex, memoryview]) -> Union[int, bool]: ...
@@ -633,5 +593,18 @@ def ipsum(self, x: K) -> K: ...
new_tree = tree.Visit(optimize.MergeTypeParameters())
self.AssertSourceEquals(new_tree, expected)
+ def test_overloads_not_flattened(self):
+ # This test checks that @overloaded functions are not flattened into a
+ # single signature.
+ src = pytd_src("""
+ from typing import overload
+ @overload
+ def f(x: int) -> str: ...
+ @overload
+ def f(x: str) -> str: ...
+ """)
+ self.AssertOptimizeEquals(src, src)
+
+
if __name__ == "__main__":
unittest.main()
diff --git a/pytype/stubs/builtins/builtins.pytd b/pytype/stubs/builtins/builtins.pytd
index 3860ed19a..a9fe714b4 100644
--- a/pytype/stubs/builtins/builtins.pytd
+++ b/pytype/stubs/builtins/builtins.pytd
@@ -1011,8 +1011,200 @@ class PyCapsule(object):
pass
# types.CodeType, a.k.a., [type 'code']
+# Definition copied from
+# https://github.com/python/typeshed/blob/master/stdlib/types.pyi
class code(object):
- pass
+ @property
+ def co_argcount(self) -> int: ...
+ if sys.version_info >= (3, 8):
+ @property
+ def co_posonlyargcount(self) -> int: ...
+
+ @property
+ def co_kwonlyargcount(self) -> int: ...
+ @property
+ def co_nlocals(self) -> int: ...
+ @property
+ def co_stacksize(self) -> int: ...
+ @property
+ def co_flags(self) -> int: ...
+ @property
+ def co_code(self) -> bytes: ...
+ @property
+ def co_consts(self) -> tuple[Any, ...]: ...
+ @property
+ def co_names(self) -> tuple[str, ...]: ...
+ @property
+ def co_varnames(self) -> tuple[str, ...]: ...
+ @property
+ def co_filename(self) -> str: ...
+ @property
+ def co_name(self) -> str: ...
+ @property
+ def co_firstlineno(self) -> int: ...
+ @property
+ def co_lnotab(self) -> bytes: ...
+ @property
+ def co_freevars(self) -> tuple[str, ...]: ...
+ @property
+ def co_cellvars(self) -> tuple[str, ...]: ...
+ if sys.version_info >= (3, 10):
+ @property
+ def co_linetable(self) -> bytes: ...
+ def co_lines(self) -> Iterator[tuple[int, int, int | None]]: ...
+ if sys.version_info >= (3, 11):
+ @property
+ def co_exceptiontable(self) -> bytes: ...
+ @property
+ def co_qualname(self) -> str: ...
+ def co_positions(self) -> Iterable[tuple[int | None, int | None, int | None, int | None]]: ...
+
+ if sys.version_info >= (3, 11):
+ def __init__(
+ self,
+ __argcount: int,
+ __posonlyargcount: int,
+ __kwonlyargcount: int,
+ __nlocals: int,
+ __stacksize: int,
+ __flags: int,
+ __codestring: bytes,
+ __constants: tuple[object, ...],
+ __names: tuple[str, ...],
+ __varnames: tuple[str, ...],
+ __filename: str,
+ __name: str,
+ __qualname: str,
+ __firstlineno: int,
+ __linetable: bytes,
+ __exceptiontable: bytes,
+ __freevars: tuple[str, ...] = ...,
+ __cellvars: tuple[str, ...] = ...,
+ ) -> None: ...
+ elif sys.version_info >= (3, 10):
+ def __init__(
+ self,
+ __argcount: int,
+ __posonlyargcount: int,
+ __kwonlyargcount: int,
+ __nlocals: int,
+ __stacksize: int,
+ __flags: int,
+ __codestring: bytes,
+ __constants: tuple[object, ...],
+ __names: tuple[str, ...],
+ __varnames: tuple[str, ...],
+ __filename: str,
+ __name: str,
+ __firstlineno: int,
+ __linetable: bytes,
+ __freevars: tuple[str, ...] = ...,
+ __cellvars: tuple[str, ...] = ...,
+ ) -> None: ...
+ elif sys.version_info >= (3, 8):
+ def __init__(
+ self,
+ __argcount: int,
+ __posonlyargcount: int,
+ __kwonlyargcount: int,
+ __nlocals: int,
+ __stacksize: int,
+ __flags: int,
+ __codestring: bytes,
+ __constants: tuple[object, ...],
+ __names: tuple[str, ...],
+ __varnames: tuple[str, ...],
+ __filename: str,
+ __name: str,
+ __firstlineno: int,
+ __lnotab: bytes,
+ __freevars: tuple[str, ...] = ...,
+ __cellvars: tuple[str, ...] = ...,
+ ) -> None: ...
+ else:
+ def __init__(
+ self,
+ __argcount: int,
+ __kwonlyargcount: int,
+ __nlocals: int,
+ __stacksize: int,
+ __flags: int,
+ __codestring: bytes,
+ __constants: tuple[object, ...],
+ __names: tuple[str, ...],
+ __varnames: tuple[str, ...],
+ __filename: str,
+ __name: str,
+ __firstlineno: int,
+ __lnotab: bytes,
+ __freevars: tuple[str, ...] = ...,
+ __cellvars: tuple[str, ...] = ...,
+ ) -> None: ...
+ if sys.version_info >= (3, 11):
+ def replace(
+ self,
+ *,
+ co_argcount: int = ...,
+ co_posonlyargcount: int = ...,
+ co_kwonlyargcount: int = ...,
+ co_nlocals: int = ...,
+ co_stacksize: int = ...,
+ co_flags: int = ...,
+ co_firstlineno: int = ...,
+ co_code: bytes = ...,
+ co_consts: tuple[object, ...] = ...,
+ co_names: tuple[str, ...] = ...,
+ co_varnames: tuple[str, ...] = ...,
+ co_freevars: tuple[str, ...] = ...,
+ co_cellvars: tuple[str, ...] = ...,
+ co_filename: str = ...,
+ co_name: str = ...,
+ co_qualname: str = ...,
+ co_linetable: bytes = ...,
+ co_exceptiontable: bytes = ...,
+ ) -> code: ...
+ elif sys.version_info >= (3, 10):
+ def replace(
+ self,
+ *,
+ co_argcount: int = ...,
+ co_posonlyargcount: int = ...,
+ co_kwonlyargcount: int = ...,
+ co_nlocals: int = ...,
+ co_stacksize: int = ...,
+ co_flags: int = ...,
+ co_firstlineno: int = ...,
+ co_code: bytes = ...,
+ co_consts: tuple[object, ...] = ...,
+ co_names: tuple[str, ...] = ...,
+ co_varnames: tuple[str, ...] = ...,
+ co_freevars: tuple[str, ...] = ...,
+ co_cellvars: tuple[str, ...] = ...,
+ co_filename: str = ...,
+ co_name: str = ...,
+ co_linetable: bytes = ...,
+ ) -> code: ...
+ elif sys.version_info >= (3, 8):
+ def replace(
+ self,
+ *,
+ co_argcount: int = ...,
+ co_posonlyargcount: int = ...,
+ co_kwonlyargcount: int = ...,
+ co_nlocals: int = ...,
+ co_stacksize: int = ...,
+ co_flags: int = ...,
+ co_firstlineno: int = ...,
+ co_code: bytes = ...,
+ co_consts: tuple[object, ...] = ...,
+ co_names: tuple[str, ...] = ...,
+ co_varnames: tuple[str, ...] = ...,
+ co_freevars: tuple[str, ...] = ...,
+ co_cellvars: tuple[str, ...] = ...,
+ co_filename: str = ...,
+ co_name: str = ...,
+ co_lnotab: bytes = ...,
+ ) -> code: ...
class ArithmeticError(StandardError):
pass
diff --git a/pytype/tests/CMakeLists.txt b/pytype/tests/CMakeLists.txt
index 09c54d19d..6aec1f482 100644
--- a/pytype/tests/CMakeLists.txt
+++ b/pytype/tests/CMakeLists.txt
@@ -667,6 +667,15 @@ py_test(
.test_base
)
+py_test(
+ NAME
+ test_returns
+ SRCS
+ test_returns.py
+ DEPS
+ .test_base
+)
+
py_test(
NAME
test_list1
diff --git a/pytype/tests/test_builtins1.py b/pytype/tests/test_builtins1.py
index f1b32959b..ec4af91e7 100644
--- a/pytype/tests/test_builtins1.py
+++ b/pytype/tests/test_builtins1.py
@@ -19,6 +19,7 @@ def t_testRepr1(x):
def t_testRepr1(x: int) -> str: ...
""")
+ @test_base.skip("b/238794928: Function inference will be removed.")
def test_repr2(self):
ty = self.Infer("""
def t_testRepr2(x):
@@ -266,6 +267,7 @@ def t_testCmp(x, y):
def t_testCmp(x, y) -> int: ...
""")
+ @test_base.skip("b/238794928: Function inference will be removed.")
def test_cmp_multi(self):
ty = self.Infer("""
def t_testCmpMulti(x, y):
diff --git a/pytype/tests/test_methods1.py b/pytype/tests/test_methods1.py
index 19177a68a..83360fd97 100644
--- a/pytype/tests/test_methods1.py
+++ b/pytype/tests/test_methods1.py
@@ -57,6 +57,7 @@ def f(x):
self.assertHasSignature(ty.Lookup("f"), (self.int,), self.int)
self.assertHasSignature(ty.Lookup("f"), (self.float,), self.float)
+ @test_base.skip("b/238794928: Function inference will be removed.")
def test_add_float(self):
ty = self.Infer("""
def f(x):
@@ -473,6 +474,7 @@ def f(x):
""", deep=False, show_library_calls=True)
self.assertHasSignature(ty.Lookup("f"), (self.int,), self.int)
+ @test_base.skip("b/238794928: Function inference will be removed.")
def test_ambiguous_starstar(self):
ty = self.Infer("""
def f(x):
diff --git a/pytype/tests/test_overriding.py b/pytype/tests/test_overriding.py
index 84d5eb011..c69504302 100644
--- a/pytype/tests/test_overriding.py
+++ b/pytype/tests/test_overriding.py
@@ -576,8 +576,7 @@ def f(self, t) -> Sequence[A]: # signature-mismatch
return [A()]
""")
- def test_subclass_of_generic_type_mismatch(self):
- # Note: we don't detect mismatch in type parameters yet.
+ def test_subclass_of_generic_for_builtin_types(self):
self.CheckWithErrors("""
from typing import Generic, TypeVar
@@ -591,11 +590,305 @@ def g(self, t: int) -> None:
pass
class B(A[int]):
- def f(self, t: str) -> None:
+ def f(self, t: str) -> None: # signature-mismatch
pass
def g(self, t: str) -> None: # signature-mismatch
pass
+
+ class C(A[list]):
+ def f(self, t: list) -> None:
+ pass
+
+ def g(self, t: int) -> None:
+ pass
+ """)
+
+ def test_subclass_of_generic_for_simple_types(self):
+ self.CheckWithErrors("""
+ from typing import Generic, TypeVar
+
+ T = TypeVar('T')
+ U = TypeVar('U')
+
+ class A(Generic[T, U]):
+ def f(self, t: T) -> U:
+ pass
+
+ class Y:
+ pass
+
+ class X(Y):
+ pass
+
+ class B(A[X, Y]):
+ def f(self, t: X) -> Y:
+ return Y()
+
+ class C(A[X, Y]):
+ def f(self, t: Y) -> X:
+ return X()
+
+ class D(A[Y, X]):
+ def f(self, t: X) -> X: # signature-mismatch
+ return X()
+
+ class E(A[Y, X]):
+ def f(self, t: Y) -> Y: # signature-mismatch
+ return Y()
+ """)
+
+ def test_subclass_of_generic_for_bound_types(self):
+ self.CheckWithErrors("""
+ from typing import Generic, TypeVar
+
+ class X:
+ pass
+
+ T = TypeVar('T', bound=X)
+
+ class A(Generic[T]):
+ def f(self, t: T) -> T:
+ return T()
+
+ class Y(X):
+ pass
+
+ class B(A[Y]):
+ def f(self, t: Y) -> Y:
+ return Y()
+
+ class C(A[Y]):
+ def f(self, t: X) -> Y:
+ return Y()
+
+ class D(A[Y]):
+ def f(self, t: Y) -> X: # signature-mismatch
+ return X()
+ """)
+
+ def test_subclass_of_generic_match_for_generic_types(self):
+ self.Check("""
+ from typing import Generic, List, Sequence, TypeVar
+
+ T = TypeVar('T')
+ U = TypeVar('U')
+
+ class A(Generic[T, U]):
+ def f(self, t: List[T]) -> Sequence[U]:
+ return []
+
+ class X:
+ pass
+
+ class Y:
+ pass
+
+ class B(A[X, Y]):
+ def f(self, t: Sequence[X]) -> List[Y]:
+ return []
+
+ class Z(X):
+ pass
+
+ class C(A[Z, X]):
+ def f(self, t: List[X]) -> Sequence[Z]:
+ return []
+ """)
+
+ def test_subclass_of_generic_mismatch_for_generic_types(self):
+ self.CheckWithErrors("""
+ from typing import Generic, List, Sequence, TypeVar
+
+ T = TypeVar('T')
+ U = TypeVar('U')
+
+ class A(Generic[T, U]):
+ def f(self, t: Sequence[T]) -> List[U]:
+ return []
+
+ class X:
+ pass
+
+ class Y:
+ pass
+
+ class B(A[X, Y]):
+ def f(self, t: List[X]) -> List[Y]: # signature-mismatch
+ return []
+
+ class C(A[X, Y]):
+ def f(self, t: Sequence[X]) -> Sequence[Y]: # signature-mismatch
+ return []
+
+ class Z(X):
+ pass
+
+ class D(A[X, Z]):
+ def f(self, t: Sequence[Z]) -> List[Z]: # signature-mismatch
+ return []
+
+ class E(A[X, Z]):
+ def f(self, t: Sequence[X]) -> List[X]: # signature-mismatch
+ return []
+ """)
+
+ def test_nested_generic_types(self):
+ self.CheckWithErrors("""
+ from typing import Callable, Generic, TypeVar
+
+ T = TypeVar('T')
+ U = TypeVar('U')
+ V = TypeVar('V')
+
+ class A(Generic[T, U]):
+ def f(self, t: Callable[[T], U]) -> None:
+ pass
+
+ class Super:
+ pass
+
+ class Sub(Super):
+ pass
+
+ class B(A[Sub, Super]):
+ def f(self, t: Callable[[Sub], Super]) -> None:
+ pass
+
+ class C(A[Sub, Super]):
+ def f(self, t: Callable[[Super], Super]) -> None: # signature-mismatch
+ pass
+
+ class D(A[Sub, Super]):
+ def f(self, t: Callable[[Sub], Sub]) -> None: # signature-mismatch
+ pass
+ """)
+
+ def test_nested_generic_types2(self):
+ self.CheckWithErrors("""
+ from typing import Callable, Generic, TypeVar
+
+ T = TypeVar('T')
+ U = TypeVar('U')
+ V = TypeVar('V') # not in the class template
+
+ class A(Generic[T, U]):
+ def f(self, t: Callable[[T, Callable[[T], V]], U]) -> None:
+ pass
+
+ class Super:
+ pass
+
+ class Sub(Super):
+ pass
+
+ class B(Generic[T], A[Sub, T]):
+ pass
+
+ class C(B[Super]):
+ def f(self, t: Callable[[Sub, Callable[[Sub], V]], Super]) -> None:
+ pass
+
+ class D(B[Super]):
+ def f(self, t: Callable[[Sub, Callable[[Super], Sub]], Super]) -> None:
+ pass
+
+ class E(B[Super]):
+ def f(self, t: Callable[[Super, Callable[[Sub], V]], Super]) -> None: # signature-mismatch
+ pass
+
+ class F(Generic[T], B[T]):
+ def f(self, t: Callable[[Sub, Callable[[Sub], V]], T]) -> None:
+ pass
+
+ class G(Generic[T], B[T]):
+ def f(self, t: Callable[[Sub, Callable[[Super], Super]], T]) -> None:
+ pass
+
+ class H(Generic[T], B[T]):
+ def f(self, t: Callable[[Super, Callable[[Sub], V]], T]) -> None: # signature-mismatch
+ pass
+ """)
+
+ def test_subclass_of_generic_for_renamed_type_parameters(self):
+ self.Check("""
+ from typing import Generic, TypeVar
+
+ T = TypeVar('T')
+ U = TypeVar('U')
+
+ class A(Generic[T]):
+ def f(self, t: T) -> None:
+ pass
+
+ class B(Generic[U], A[U]):
+ pass
+
+ class X:
+ pass
+
+ class C(B[X]):
+ def f(self, t: X) -> None:
+ pass
+ """)
+
+ def test_subclass_of_generic_for_renamed_type_parameters2(self):
+ self.CheckWithErrors("""
+ from typing import Generic, TypeVar
+
+ T = TypeVar('T')
+ U = TypeVar('U')
+
+ class A(Generic[T, U]):
+ def f(self, t: T) -> U:
+ return U()
+
+ class X:
+ pass
+
+ class B(Generic[T], A[X, T]):
+ pass
+
+ class Y:
+ pass
+
+ class C(B[Y]):
+ def f(self, t: X) -> Y:
+ return Y()
+
+ class D(B[Y]):
+ def f(self, t: X) -> X: # signature-mismatch
+ return X()
+ """)
+
+ def test_subclass_of_generic_for_generic_method(self):
+ self.CheckWithErrors("""
+ from typing import Generic, TypeVar
+
+ T = TypeVar('T')
+ U = TypeVar('U')
+
+ class A(Generic[T]):
+ def f(self, t: T, u: U) -> U:
+ return U()
+
+ class Y:
+ pass
+
+ class X(Y):
+ pass
+
+ class B(A[X]):
+ def f(self, t: X, u: U) -> U:
+ return U()
+
+ class C(A[X]):
+ def f(self, t: Y, u: U) -> U:
+ return U()
+
+ class D(A[Y]):
+ def f(self, t: X, u: U) -> U: # signature-mismatch
+ return U()
""")
def test_varargs_match(self):
@@ -816,6 +1109,28 @@ class Bar(Callable, Foo):
pass
""")
+ def test_async(self):
+ with self.DepTree([("foo.py", """
+ class Foo:
+ async def f(self) -> int:
+ return 0
+ def g(self) -> int:
+ return 0
+ """)]):
+ self.CheckWithErrors("""
+ import foo
+ class Good(foo.Foo):
+ async def f(self) -> int:
+ return 0
+ class Bad(foo.Foo):
+ async def f(self) -> str: # signature-mismatch
+ return ''
+ # Test that we catch the non-async/async mismatch even without a
+ # return annotation.
+ async def g(self): # signature-mismatch
+ return 0
+ """)
+
if __name__ == "__main__":
test_base.main()
diff --git a/pytype/tests/test_quick2.py b/pytype/tests/test_quick2.py
index 45d00d8b0..a95bba9bd 100644
--- a/pytype/tests/test_quick2.py
+++ b/pytype/tests/test_quick2.py
@@ -41,6 +41,26 @@ def f3():
return 42
""", pythonpath=[d.path], quick=True)
+ def test_noreturn(self):
+ self.Check("""
+ from typing import NoReturn
+
+ class A:
+ pass
+
+ class B:
+ def _raise_notimplemented(self) -> NoReturn:
+ raise NotImplementedError()
+ def f(self, x):
+ if __random__:
+ outputs = 42
+ else:
+ self._raise_notimplemented()
+ return outputs
+ def g(self):
+ outputs = self.f(A())
+ """, quick=True)
+
if __name__ == "__main__":
test_base.main()
diff --git a/pytype/tests/test_returns.py b/pytype/tests/test_returns.py
new file mode 100644
index 000000000..362d81fd2
--- /dev/null
+++ b/pytype/tests/test_returns.py
@@ -0,0 +1,63 @@
+"""Tests for bad-return-type errors."""
+
+from pytype.tests import test_base
+
+
+class TestReturns(test_base.BaseTest):
+ """Tests for bad-return-type."""
+
+ def test_implicit_none(self):
+ self.CheckWithErrors("""
+ def f(x) -> int:
+ pass # bad-return-type
+ """)
+
+ def test_if(self):
+ # NOTE(b/233047104): The implict `return None` gets reported at the end of
+ # the function even though there is also a correct return on that line.
+ self.CheckWithErrors("""
+ def f(x) -> int:
+ if x:
+ pass
+ else:
+ return 10 # bad-return-type
+ """)
+
+ def test_nested_if(self):
+ self.CheckWithErrors("""
+ def f(x) -> int:
+ if x:
+ if __random__:
+ pass
+ else:
+ return 'a' # bad-return-type
+ else:
+ return 10
+ pass # bad-return-type
+ """)
+
+ def test_with(self):
+ self.CheckWithErrors("""
+ def f(x) -> int:
+ with open('foo'):
+ if __random__:
+ pass
+ else:
+ return 'a' # bad-return-type # bad-return-type
+ """)
+
+ def test_nested_with(self):
+ self.CheckWithErrors("""
+ def f(x) -> int:
+ with open('foo'):
+ if __random__:
+ with open('bar'):
+ if __random__:
+ pass
+ else:
+ return 'a' # bad-return-type # bad-return-type
+ """)
+
+
+if __name__ == "__main__":
+ test_base.main()
diff --git a/pytype/tests/test_stdlib2.py b/pytype/tests/test_stdlib2.py
index a1cf98d6e..cbe33924c 100644
--- a/pytype/tests/test_stdlib2.py
+++ b/pytype/tests/test_stdlib2.py
@@ -99,6 +99,15 @@ class MyClass(fractions.Fraction): ...
def foo() -> MyClass: ...
""")
+ def test_codetype(self):
+ self.Check("""
+ import types
+ class Foo:
+ x: types.CodeType
+ def set_x(self):
+ self.x = compile('', '', '')
+ """)
+
class StdlibTestsFeatures(test_base.BaseTest,
test_utils.TestCollectionsMixin):
diff --git a/pytype/vm.py b/pytype/vm.py
index 399158d7e..2d3823ebb 100644
--- a/pytype/vm.py
+++ b/pytype/vm.py
@@ -230,6 +230,7 @@ def run_frame(self, frame, node, annotated_locals=None):
can_return = False
return_nodes = []
finally_tracker = vm_utils.FinallyStateTracker()
+ vm_utils.adjust_block_returns(frame.f_code, self._director.block_returns)
for block in frame.f_code.order:
state = frame.states.get(block[0])
if not state:
@@ -430,6 +431,7 @@ def run_program(self, src, filename, maximum_depth):
self.ctx.errorlog.ignored_type_comment(self.filename, line,
self._director.type_comments[line])
code = constant_folding.optimize(code)
+ vm_utils.adjust_block_returns(code, self._director.block_returns)
node = self.ctx.root_node.ConnectNew("init")
node, f_globals, f_locals, _ = self.run_bytecode(node, code)
diff --git a/pytype/vm_utils.py b/pytype/vm_utils.py
index b04944529..b0e23b8da 100644
--- a/pytype/vm_utils.py
+++ b/pytype/vm_utils.py
@@ -127,6 +127,7 @@ class _NameInOuterClassErrorDetails(_NameErrorDetails):
"""Name error details for a name defined in an outer class."""
def __init__(self, attr, prefix, class_name):
+ super().__init__()
self._attr = attr
self._prefix = prefix
self._class_name = class_name
@@ -142,8 +143,10 @@ def to_error_message(self):
class _NameInOuterFunctionErrorDetails(_NameErrorDetails):
+ """Name error details for a name defined in an outer function."""
def __init__(self, attr, outer_scope, inner_scope):
+ super().__init__()
self._attr = attr
self._outer_scope = outer_scope
self._inner_scope = inner_scope
@@ -1134,3 +1137,17 @@ def to_coroutine(state, obj, top, ctx):
for b in obj.bindings:
state = _binding_to_coroutine(state, b, bad_bindings, ret, top, ctx)
return state, ret
+
+
+def adjust_block_returns(code, block_returns):
+ """Adjust line numbers for return statements in with blocks."""
+
+ rets = {k: iter(v) for k, v in block_returns}
+ for block in code.order:
+ for op in block:
+ if op.__class__.__name__ == "RETURN_VALUE":
+ if op.line in rets:
+ lines = rets[op.line]
+ new_line = next(lines, None)
+ if new_line:
+ op.line = new_line