diff --git a/CHANGELOG b/CHANGELOG index bed93916c..2457edac6 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -1,3 +1,14 @@ +Version 2022.07.18: + +Bug fixes: +* Look up methods properly on classes with _HAS_DYNAMIC_ATTRIBUTES. +* Handle .pyi-1 files in load_pytd.Module.is_package(). +* Adjust opcode line numbers for return statements in python 3.10+. +* Remove optimize.Factorize, which unnecessarily flattens overloaded functions. +* Fix coroutine signatures in overriding_checks. +* Handle generic types correctly in signature compatibility checks. +* Respect NoReturn annotations even when maximum depth is reached. + Version 2022.06.30: Updates: diff --git a/docs/faq.md b/docs/faq.md index 898429429..62713cb53 100644 --- a/docs/faq.md +++ b/docs/faq.md @@ -22,7 +22,7 @@ * [How do I annotate *args and **kwargs?](#how-do-i-annotate-args-and-kwargs) * [Why are signature mismatches in subclasses bad? {#signature-mismatch}](#why-are-signature-mismatches-in-subclasses-bad-signature-mismatch) - + @@ -371,7 +371,6 @@ TypeError: func() got multiple values for argument 'y' [pep-561-issue]: https://github.com/google/pytype/issues/151 [typeshed]: https://github.com/python/typeshed [typing-faq]: typing_faq.md -[why-is-pytype-taking-so-long]: #why-is-pytype-taking-so-long diff --git a/pytype/__version__.py b/pytype/__version__.py index 8cffc8b5f..b3afc0ee5 100644 --- a/pytype/__version__.py +++ b/pytype/__version__.py @@ -1,2 +1,2 @@ # pylint: skip-file -__version__ = '2022.06.30' +__version__ = '2022.07.18' diff --git a/pytype/abstract/_interpreter_function.py b/pytype/abstract/_interpreter_function.py index a802d5adb..6b66bd812 100644 --- a/pytype/abstract/_interpreter_function.py +++ b/pytype/abstract/_interpreter_function.py @@ -642,7 +642,12 @@ def call(self, node, func, args, alias_map=None, new_locals=False, not abstract_utils.func_name_is_class_init(self.name)): log.info("Maximum depth reached. Not analyzing %r", self.name) self._set_callself_maybe_missing_members() - return node, self.ctx.new_unsolvable(node) + if self.signature.annotations.get("return") == self.ctx.convert.no_return: + # TODO(b/147230757): Use all return annotations, not just NoReturn. + ret = self.signature.annotations["return"] + else: + ret = self.ctx.convert.unsolvable + return node, ret.to_variable(node) args = self._fix_args_for_unannotated_contextmanager_exit(node, func, args) args = args.simplify(node, self.ctx, self.signature) sig, substs, callargs = self._find_matching_sig(node, args, alias_map) diff --git a/pytype/annotation_utils.py b/pytype/annotation_utils.py index b85d03f86..a3b890f70 100644 --- a/pytype/annotation_utils.py +++ b/pytype/annotation_utils.py @@ -54,6 +54,14 @@ def _get_type_parameter_subst(self, node, annot, substs, instantiate_unbound): return self.ctx.convert.merge_classes(vals) def sub_one_annotation(self, node, annot, substs, instantiate_unbound=True): + + def get_type_parameter_subst(annotation): + return self._get_type_parameter_subst(node, annotation, substs, + instantiate_unbound) + + return self._do_sub_one_annotation(node, annot, get_type_parameter_subst) + + def _do_sub_one_annotation(self, node, annot, get_type_parameter_subst_fn): """Apply type parameter substitutions to an annotation.""" # We push annotations onto 'stack' and move them to the 'done' stack as they # are processed. For each annotation, we also track an 'inner_type_keys' @@ -77,8 +85,8 @@ def sub_one_annotation(self, node, annot, substs, instantiate_unbound=True): if cur not in late_annotations: param_strings = [] for t in utils.unique_list(self.get_type_parameters(cur)): - s = pytd_utils.Print(self._get_type_parameter_subst( - node, t, substs, instantiate_unbound).get_instance_type(node)) + s = pytd_utils.Print( + get_type_parameter_subst_fn(t).get_instance_type(node)) param_strings.append(s) expr = f"{cur.expr}[{', '.join(param_strings)}]" late_annot = abstract.LateAnnotation(expr, cur.stack, cur.ctx) @@ -107,11 +115,53 @@ def sub_one_annotation(self, node, annot, substs, instantiate_unbound=True): late_annot.expr.split("[", 1)[0]].append(late_annot) done.append(done_annot) else: - done.append(self._get_type_parameter_subst( - node, cur, substs, instantiate_unbound)) + done.append(get_type_parameter_subst_fn(cur)) assert len(done) == 1 return done[0] + def sub_annotations_for_parameterized_class(self, cls, annotations): + """Apply type parameter substitutions to a dictionary of annotations. + + Args: + cls: ParameterizedClass that defines type parameter substitutions. + annotations: A dictionary of annotations to which type parameter + substition should be applied. + + Returns: + Annotations with type parameters substituted. + """ + assert isinstance(cls, abstract.ParameterizedClass) + formal_type_parameters = cls.get_formal_type_parameters() + + def get_type_parameter_subst(annotation): + assert isinstance(annotation, abstract.TypeParameter) + # Normally the type parameter module is set correctly at this point. + # Except for the case when a method that references this type parameter + # is inherited in a subclass that does not specialize this parameter: + # class A(Generic[T]): + # def f(self, t: T): ... + # + # class B(Generic[T], A[T]): + # pass + # + # class C(B[int]): ... + # In this case t in A[T].f will be annotated with T with no module set, + # since we don't know the correct module until T is specialized in + # B[int]. + annotation = annotation.with_module(cls.full_name) + # Method parameter can be annotated with a typevar that doesn't + # belong to the class template: + # class A(Generic[T]): + # def f(self, t: U): ... + # In this case we return it as is. + return formal_type_parameters.get(annotation.full_name, annotation) + + return { + name: self._do_sub_one_annotation(self.ctx.root_node, annot, + get_type_parameter_subst) + for name, annot in annotations.items() + } + def get_late_annotations(self, annot): if annot.is_late_annotation() and not annot.resolved: yield annot diff --git a/pytype/convert.py b/pytype/convert.py index 8185a4205..9be7dcd48 100644 --- a/pytype/convert.py +++ b/pytype/convert.py @@ -615,6 +615,8 @@ def _special_constant_to_value(self, name): return self.function_type elif name == "types.NoneType": return self.none_type + elif name == "types.CodeType": + return self.primitive_classes[types.CodeType] else: return None diff --git a/pytype/directors/directors.py b/pytype/directors/directors.py index 5dffdc32d..883034dbc 100644 --- a/pytype/directors/directors.py +++ b/pytype/directors/directors.py @@ -280,6 +280,7 @@ def __init__(self, src_tree, errorlog, filename, disable, code): # Store function ranges and return lines to distinguish explicit and # implicit returns (the bytecode has a `RETURN None` for implcit returns). self._return_lines = set() + self.block_returns = None self._function_ranges = _BlockRanges({}) # Parse the source code for directives. self._parse_src_tree(src_tree, code) @@ -313,7 +314,8 @@ def _parse_src_tree(self, src_tree, code): else: opcode_lines = None - self._return_lines = visitor.returns + self.block_returns = visitor.block_returns + self._return_lines = visitor.block_returns.all_returns() self._function_ranges = _BlockRanges(visitor.function_ranges) for line_range, group in visitor.structured_comment_groups.items(): diff --git a/pytype/directors/parser.py b/pytype/directors/parser.py index 3c8f02698..f105898ef 100644 --- a/pytype/directors/parser.py +++ b/pytype/directors/parser.py @@ -32,6 +32,9 @@ class LineRange: def from_node(cls, node): return cls(node.lineno, node.end_lineno) + def __contains__(self, line): + return self.start_line <= line <= self.end_line + @dataclasses.dataclass(frozen=True) class Call(LineRange): @@ -66,6 +69,44 @@ class _SourceTree: structured_comments: Mapping[int, Sequence[_StructuredComment]] +class BlockReturns: + """Tracks return statements in with/try blocks.""" + + def __init__(self): + self._block_ranges = [] + self._returns = [] + self._block_returns = {} + self._final = False + + def add_block(self, node): + line_range = LineRange.from_node(node) + self._block_ranges.append(line_range) + + def add_return(self, node): + self._returns.append(node.lineno) + + def finalize(self): + for br in self._block_ranges: + self._block_returns[br.start_line] = sorted( + r for r in self._returns if r in br + ) + self._final = True + + def all_returns(self): + return set(self._returns) + + def __iter__(self): + assert self._final + return iter(self._block_returns.items()) + + def __repr__(self): + return f""" + Blocks: {self._block_ranges} + Returns: {self._returns} + {self._block_returns} + """ + + class _ParseVisitor(visitor.BaseVisitor): """Visitor for parsing a source tree. @@ -97,8 +138,9 @@ def __init__(self, raw_structured_comments): self.variable_annotations = [] self.decorators = [] self.defs_start = None - self.returns = set() self.function_ranges = {} + self.block_returns = BlockReturns() + self.block_depth = 0 def _add_structured_comment_group(self, start_line, end_line, cls=LineRange): """Adds an empty _StructuredComment group with the given line range.""" @@ -171,6 +213,9 @@ def should_add(comment, group): if cls is not LineRange: group.extend(c for c in structured_comments if should_add(c, group)) + def leave_Module(self, node): + self.block_returns.finalize() + def visit_Call(self, node): self._process_structured_comments(LineRange.from_node(node), cls=Call) @@ -200,8 +245,22 @@ def visit_Try(self, node): def _visit_with(self, node): item = node.items[-1] end_lineno = (item.optional_vars or item.context_expr).end_lineno + if self.block_depth == 1: + self.block_returns.add_block(node) self._process_structured_comments(LineRange(node.lineno, end_lineno)) + def enter_With(self, node): + self.block_depth += 1 + + def leave_With(self, node): + self.block_depth -= 1 + + def enter_AsyncWith(self, node): + self.block_depth += 1 + + def leave_AsyncWith(self, node): + self.block_depth -= 1 + def visit_With(self, node): self._visit_with(node) @@ -226,8 +285,8 @@ def generic_visit(self, node): self._process_structured_comments(LineRange.from_node(node)) def visit_Return(self, node): + self.block_returns.add_return(node) self._process_structured_comments(LineRange.from_node(node)) - self.returns.add(node.lineno) def _visit_decorators(self, node): if not node.decorator_list: diff --git a/pytype/directors/parser_libcst.py b/pytype/directors/parser_libcst.py index cabb36c6a..6f20a3315 100644 --- a/pytype/directors/parser_libcst.py +++ b/pytype/directors/parser_libcst.py @@ -25,6 +25,9 @@ class LineRange: start_line: int end_line: int + def __contains__(self, line): + return self.start_line <= line <= self.end_line + @dataclasses.dataclass(frozen=True) class Call(LineRange): @@ -53,6 +56,31 @@ class _VariableAnnotation(LineRange): annotation: str +class BlockReturns: + """Tracks return statements in with/try blocks.""" + + def __init__(self): + self._block_ranges = [] + self._returns = [] + self._block_returns = {} + + def add_return(self, pos): + self._returns.append(pos.start.line) + + def all_returns(self): + return set(self._returns) + + def __iter__(self): + return iter(self._block_returns.items()) + + def __repr__(self): + return f""" + Blocks: {self._block_ranges} + Returns: {self._returns} + {self._block_returns} + """ + + class _ParseVisitor(libcst.CSTVisitor): """Visitor for parsing a source tree. @@ -80,8 +108,8 @@ def __init__(self): self.variable_annotations = [] self.decorators = [] self.defs_start = None - self.returns = set() self.function_ranges = {} + self.block_returns = BlockReturns() def _get_containing_groups(self, start_line, end_line=None): """Get _StructuredComment groups that fully contain the given line range.""" @@ -240,7 +268,7 @@ def visit_AnnAssign(self, node): _VariableAnnotation(pos.start.line, pos.end.line, annotation)) def visit_Return(self, node): - self.returns.add(self._get_position(node).start.line) + self.block_returns.add_return(self._get_position(node)) def _visit_decorators(self, node): if not node.decorators: diff --git a/pytype/load_pytd.py b/pytype/load_pytd.py index 1f760b25b..ee955a002 100644 --- a/pytype/load_pytd.py +++ b/pytype/load_pytd.py @@ -70,8 +70,6 @@ class Module: metadata: The metadata extracted from the picked file. """ - _INIT_NAMES = ("__init__.pyi", f"__init__.{pytd_utils.PICKLE_EXT}") - # pylint: disable=redefined-outer-name def __init__(self, module_name, filename, ast, metadata=None, pickle=None, has_unresolved_pointers=True): @@ -91,7 +89,10 @@ def is_package(self): # imports_map_loader adds os.devnull entries for __init__.py files in # intermediate directories. return True - return self.filename and os.path.basename(self.filename) in self._INIT_NAMES + if self.filename: + base, _ = os.path.splitext(os.path.basename(self.filename)) + return base == "__init__" + return False class BadDependencyError(Exception): @@ -209,7 +210,7 @@ def _unpickle_module(self, module): newly_loaded_asts.append(loaded_ast) m.ast = loaded_ast.ast if loaded_ast.is_package: - init_file = f"__init__.{pytd_utils.PICKLE_EXT}" + init_file = f"__init__{pytd_utils.PICKLE_EXT}" if m.filename and os.path.basename(m.filename) != init_file: base, _ = os.path.splitext(m.filename) m.filename = os.path.join(base, init_file) diff --git a/pytype/load_pytd_test.py b/pytype/load_pytd_test.py index 979ae52ba..47e569780 100644 --- a/pytype/load_pytd_test.py +++ b/pytype/load_pytd_test.py @@ -19,6 +19,20 @@ import unittest +class ModuleTest(test_base.UnitTest): + """Tests for load_pytd.Module.""" + + def test_is_package(self): + for filename, is_package in [("foo/bar.pyi", False), + ("foo/__init__.pyi", True), + ("foo/__init__.pyi-1", True), + ("foo/__init__.pickled", True), + (os.devnull, True)]: + with self.subTest(filename=filename): + mod = load_pytd.Module(module_name=None, filename=filename, ast=None) + self.assertEqual(mod.is_package(), is_package) + + class _LoaderTest(test_base.UnitTest): @contextlib.contextmanager diff --git a/pytype/overriding_checks.py b/pytype/overriding_checks.py index 490c0de32..173b951f4 100644 --- a/pytype/overriding_checks.py +++ b/pytype/overriding_checks.py @@ -464,6 +464,32 @@ def _get_pytd_class_signature_map(cls, ctx): return method_signature_map +def _get_parameterized_class_signature_map(cls, ctx): + """Returns a map from method names to signatures for a ParameterizedClass.""" + assert isinstance(cls, abstract.ParameterizedClass) + if cls in ctx.method_signature_map: + return ctx.method_signature_map[cls] + + base_class = cls.base_cls + + if isinstance(base_class, abstract.InterpreterClass): + base_signature_map = ctx.method_signature_map[base_class] + else: + assert isinstance(base_class, abstract.PyTDClass) + base_signature_map = _get_pytd_class_signature_map(base_class, ctx) + + method_signature_map = {} + for base_method_name, base_method_signature in base_signature_map.items(): + # Replace formal type parameters with their values. + annotations = ctx.annotation_utils.sub_annotations_for_parameterized_class( + cls, base_method_signature.annotations) + method_signature_map[base_method_name] = base_method_signature._replace( + annotations=annotations) + + ctx.method_signature_map[cls] = method_signature_map + return method_signature_map + + def check_overriding_members(cls, bases, members, matcher, ctx): """Check that the method signatures of the new class match base classes.""" @@ -484,19 +510,31 @@ def check_overriding_members(cls, bases, members, matcher, ctx): assert member_name not in class_method_map class_method_map[member_name] = method - class_signature_map = { - method_name: method.signature - for method_name, method in class_method_map.items() - } + class_signature_map = {} + for method_name, method in class_method_map.items(): + if method.is_coroutine(): + annotations = dict(method.signature.annotations) + coroutine_params = { + abstract_utils.T: ctx.convert.unsolvable, + abstract_utils.T2: ctx.convert.unsolvable, + abstract_utils.V: annotations.get("return", ctx.convert.unsolvable), + } + annotations["return"] = abstract.ParameterizedClass( + ctx.convert.coroutine_type, coroutine_params, ctx) + signature = method.signature._replace(annotations=annotations) + else: + signature = method.signature + class_signature_map[method_name] = signature for base in bases: try: base_class = abstract_utils.get_atomic_value(base) except abstract_utils.ConversionError: continue - if isinstance(base_class, abstract.ParameterizedClass): - base_class = base_class.base_cls if isinstance(base_class, abstract.InterpreterClass): base_signature_map = ctx.method_signature_map[base_class] + elif isinstance(base_class, abstract.ParameterizedClass): + base_signature_map = _get_parameterized_class_signature_map( + base_class, ctx) elif isinstance(base_class, abstract.PyTDClass): base_signature_map = _get_pytd_class_signature_map(base_class, ctx) else: diff --git a/pytype/pytd/main_test.py b/pytype/pytd/main_test.py index fb4d0feaa..bdfabfd25 100644 --- a/pytype/pytd/main_test.py +++ b/pytype/pytd/main_test.py @@ -77,26 +77,6 @@ def f(x: str) -> str: ... with open(outpath) as f: self.assertMultiLineEqual(f.read(), src) - def test_optimize(self): - with file_utils.Tempdir() as d: - inpath = d.create_file("in.pytd", """ - from typing import overload - - @overload - def f(x: int) -> str: ... - @overload - def f(x: str) -> str: ... - """) - outpath = os.path.join(d.path, "out.pytd") - sys.argv = ["main.py", "--optimize", inpath, outpath] - pytd_tool.main() - with open(outpath) as f: - self.assertMultiLineEqual(f.read(), textwrap.dedent(""" - from typing import Union - - def f(x: Union[int, str]) -> str: ... - """).strip()) - if __name__ == "__main__": unittest.main() diff --git a/pytype/pytd/optimize.py b/pytype/pytd/optimize.py index ee6382ecf..044f897f0 100644 --- a/pytype/pytd/optimize.py +++ b/pytype/pytd/optimize.py @@ -307,93 +307,6 @@ def VisitUnionType(self, union): return result -class Factorize(visitors.Visitor): - """Opposite of ExpandSignatures. Factorizes cartesian products of functions. - - For example, this transforms - def f(x: int, y: int) - def f(x: int, y: float) - def f(x: float, y: int) - def f(x: float, y: float) - to - def f(x: Union[int, float], y: Union[int, float]) - """ - - def _GroupByOmittedArg(self, signatures, i): - """Group functions that are identical if you ignore one of the arguments. - - Arguments: - signatures: A list of function signatures - i: The index of the argument to ignore during comparison. - - Returns: - A list of tuples (signature, types). "signature" is a signature with - argument i omitted, "types" is the list of types that argument was - found to have. signatures that don't have argument i are represented - as (original, None). - """ - groups = {} - for sig in signatures: - if i >= len(sig.params): - # We can't omit argument i, because this signature has too few - # arguments. Represent this signature as (original, None). - groups[sig] = None - continue - if sig.params[i].mutated_type is not None: - # We can't group mutable parameters. Leave this signature alone. - groups[sig] = None - continue - - # Set type of parameter i to None - params = list(sig.params) - param_i = params[i] - params[i] = param_i.Replace(type=None) - - stripped_signature = sig.Replace(params=tuple(params)) - existing = groups.get(stripped_signature) - if existing: - existing.append(param_i.type) - else: - groups[stripped_signature] = [param_i.type] - return groups.items() - - def VisitFunction(self, f): - """Shrink a function, by factorizing cartesian products of arguments. - - Greedily groups signatures, looking at the arguments from left to right. - This algorithm is *not* optimal. But it does the right thing for the - typical cases. - - Arguments: - f: An instance of pytd.Function. If this function has more - than one signature, we will try to combine some of these signatures by - introducing union types. - - Returns: - A new, potentially optimized, instance of pytd.Function. - - """ - max_argument_count = max(len(s.params) for s in f.signatures) - signatures = f.signatures - - for i in range(max_argument_count): - new_sigs = [] - for sig, types in self._GroupByOmittedArg(signatures, i): - if types: - # One or more options for argument : - new_params = list(sig.params) - new_params[i] = sig.params[i].Replace( - type=pytd_utils.JoinTypes(types)) - sig = sig.Replace(params=tuple(new_params)) - new_sigs.append(sig) - else: - # Signature doesn't have argument , so we store the original: - new_sigs.append(sig) - signatures = new_sigs - - return f.Replace(signatures=tuple(signatures)) - - class SuperClassHierarchy: """Utility class for optimizations working with superclasses.""" @@ -962,7 +875,6 @@ def Optimize(node, node = node.Visit(RemoveDuplicates()) node = node.Visit(SimplifyUnions()) node = node.Visit(CombineReturnsAndExceptions()) - node = node.Visit(Factorize()) node = node.Visit(CombineContainers()) node = node.Visit(SimplifyContainers()) if builtins: diff --git a/pytype/pytd/optimize_test.py b/pytype/pytd/optimize_test.py index eb34e13ba..051a402b5 100644 --- a/pytype/pytd/optimize_test.py +++ b/pytype/pytd/optimize_test.py @@ -225,46 +225,6 @@ def test_simplify_unions(self): self.ApplyVisitorToString(src, optimize.SimplifyUnions()), new_src) - def test_factorize(self): - src = pytd_src(""" - def foo(a: int) -> file: ... - def foo(a: int, x: complex) -> file: ... - def foo(a: int, x: str) -> file: ... - def foo(a: float, x: complex) -> file: ... - def foo(a: float, x: str) -> file: ... - def foo(a: int, x: file, *args) -> file: ... - """) - new_src = pytd_src(""" - def foo(a: int) -> file: ... - def foo(a: float, x: Union[complex, str]) -> file: ... - def foo(a: int, x: file, *args) -> file: ... - """) - self.AssertSourceEquals( - self.ApplyVisitorToString(src, optimize.Factorize()), new_src) - - def test_factorize_mutable(self): - src = pytd_src(""" - def foo(a: list[bool], b: X) -> file: - a = list[int] - def foo(a: list[bool], b: Y) -> file: - a = list[int] - # not groupable: - def bar(a: int, b: list[int]) -> file: - b = list[complex] - def bar(a: int, b: list[float]) -> file: - b = list[str] - """) - new_src = pytd_src(""" - def foo(a: list[bool], b: Union[X, Y]) -> file: - a = list[int] - def bar(a: int, b: list[int]) -> file: - b = list[complex] - def bar(a: int, b: list[float]) -> file: - b = list[str] - """) - self.AssertSourceEquals( - self.ApplyVisitorToString(src, optimize.Factorize()), new_src) - def test_builtin_superclasses(self): src = pytd_src(""" def f(x: Union[list, object], y: Union[complex, memoryview]) -> Union[int, bool]: ... @@ -633,5 +593,18 @@ def ipsum(self, x: K) -> K: ... new_tree = tree.Visit(optimize.MergeTypeParameters()) self.AssertSourceEquals(new_tree, expected) + def test_overloads_not_flattened(self): + # This test checks that @overloaded functions are not flattened into a + # single signature. + src = pytd_src(""" + from typing import overload + @overload + def f(x: int) -> str: ... + @overload + def f(x: str) -> str: ... + """) + self.AssertOptimizeEquals(src, src) + + if __name__ == "__main__": unittest.main() diff --git a/pytype/stubs/builtins/builtins.pytd b/pytype/stubs/builtins/builtins.pytd index 3860ed19a..a9fe714b4 100644 --- a/pytype/stubs/builtins/builtins.pytd +++ b/pytype/stubs/builtins/builtins.pytd @@ -1011,8 +1011,200 @@ class PyCapsule(object): pass # types.CodeType, a.k.a., [type 'code'] +# Definition copied from +# https://github.com/python/typeshed/blob/master/stdlib/types.pyi class code(object): - pass + @property + def co_argcount(self) -> int: ... + if sys.version_info >= (3, 8): + @property + def co_posonlyargcount(self) -> int: ... + + @property + def co_kwonlyargcount(self) -> int: ... + @property + def co_nlocals(self) -> int: ... + @property + def co_stacksize(self) -> int: ... + @property + def co_flags(self) -> int: ... + @property + def co_code(self) -> bytes: ... + @property + def co_consts(self) -> tuple[Any, ...]: ... + @property + def co_names(self) -> tuple[str, ...]: ... + @property + def co_varnames(self) -> tuple[str, ...]: ... + @property + def co_filename(self) -> str: ... + @property + def co_name(self) -> str: ... + @property + def co_firstlineno(self) -> int: ... + @property + def co_lnotab(self) -> bytes: ... + @property + def co_freevars(self) -> tuple[str, ...]: ... + @property + def co_cellvars(self) -> tuple[str, ...]: ... + if sys.version_info >= (3, 10): + @property + def co_linetable(self) -> bytes: ... + def co_lines(self) -> Iterator[tuple[int, int, int | None]]: ... + if sys.version_info >= (3, 11): + @property + def co_exceptiontable(self) -> bytes: ... + @property + def co_qualname(self) -> str: ... + def co_positions(self) -> Iterable[tuple[int | None, int | None, int | None, int | None]]: ... + + if sys.version_info >= (3, 11): + def __init__( + self, + __argcount: int, + __posonlyargcount: int, + __kwonlyargcount: int, + __nlocals: int, + __stacksize: int, + __flags: int, + __codestring: bytes, + __constants: tuple[object, ...], + __names: tuple[str, ...], + __varnames: tuple[str, ...], + __filename: str, + __name: str, + __qualname: str, + __firstlineno: int, + __linetable: bytes, + __exceptiontable: bytes, + __freevars: tuple[str, ...] = ..., + __cellvars: tuple[str, ...] = ..., + ) -> None: ... + elif sys.version_info >= (3, 10): + def __init__( + self, + __argcount: int, + __posonlyargcount: int, + __kwonlyargcount: int, + __nlocals: int, + __stacksize: int, + __flags: int, + __codestring: bytes, + __constants: tuple[object, ...], + __names: tuple[str, ...], + __varnames: tuple[str, ...], + __filename: str, + __name: str, + __firstlineno: int, + __linetable: bytes, + __freevars: tuple[str, ...] = ..., + __cellvars: tuple[str, ...] = ..., + ) -> None: ... + elif sys.version_info >= (3, 8): + def __init__( + self, + __argcount: int, + __posonlyargcount: int, + __kwonlyargcount: int, + __nlocals: int, + __stacksize: int, + __flags: int, + __codestring: bytes, + __constants: tuple[object, ...], + __names: tuple[str, ...], + __varnames: tuple[str, ...], + __filename: str, + __name: str, + __firstlineno: int, + __lnotab: bytes, + __freevars: tuple[str, ...] = ..., + __cellvars: tuple[str, ...] = ..., + ) -> None: ... + else: + def __init__( + self, + __argcount: int, + __kwonlyargcount: int, + __nlocals: int, + __stacksize: int, + __flags: int, + __codestring: bytes, + __constants: tuple[object, ...], + __names: tuple[str, ...], + __varnames: tuple[str, ...], + __filename: str, + __name: str, + __firstlineno: int, + __lnotab: bytes, + __freevars: tuple[str, ...] = ..., + __cellvars: tuple[str, ...] = ..., + ) -> None: ... + if sys.version_info >= (3, 11): + def replace( + self, + *, + co_argcount: int = ..., + co_posonlyargcount: int = ..., + co_kwonlyargcount: int = ..., + co_nlocals: int = ..., + co_stacksize: int = ..., + co_flags: int = ..., + co_firstlineno: int = ..., + co_code: bytes = ..., + co_consts: tuple[object, ...] = ..., + co_names: tuple[str, ...] = ..., + co_varnames: tuple[str, ...] = ..., + co_freevars: tuple[str, ...] = ..., + co_cellvars: tuple[str, ...] = ..., + co_filename: str = ..., + co_name: str = ..., + co_qualname: str = ..., + co_linetable: bytes = ..., + co_exceptiontable: bytes = ..., + ) -> code: ... + elif sys.version_info >= (3, 10): + def replace( + self, + *, + co_argcount: int = ..., + co_posonlyargcount: int = ..., + co_kwonlyargcount: int = ..., + co_nlocals: int = ..., + co_stacksize: int = ..., + co_flags: int = ..., + co_firstlineno: int = ..., + co_code: bytes = ..., + co_consts: tuple[object, ...] = ..., + co_names: tuple[str, ...] = ..., + co_varnames: tuple[str, ...] = ..., + co_freevars: tuple[str, ...] = ..., + co_cellvars: tuple[str, ...] = ..., + co_filename: str = ..., + co_name: str = ..., + co_linetable: bytes = ..., + ) -> code: ... + elif sys.version_info >= (3, 8): + def replace( + self, + *, + co_argcount: int = ..., + co_posonlyargcount: int = ..., + co_kwonlyargcount: int = ..., + co_nlocals: int = ..., + co_stacksize: int = ..., + co_flags: int = ..., + co_firstlineno: int = ..., + co_code: bytes = ..., + co_consts: tuple[object, ...] = ..., + co_names: tuple[str, ...] = ..., + co_varnames: tuple[str, ...] = ..., + co_freevars: tuple[str, ...] = ..., + co_cellvars: tuple[str, ...] = ..., + co_filename: str = ..., + co_name: str = ..., + co_lnotab: bytes = ..., + ) -> code: ... class ArithmeticError(StandardError): pass diff --git a/pytype/tests/CMakeLists.txt b/pytype/tests/CMakeLists.txt index 09c54d19d..6aec1f482 100644 --- a/pytype/tests/CMakeLists.txt +++ b/pytype/tests/CMakeLists.txt @@ -667,6 +667,15 @@ py_test( .test_base ) +py_test( + NAME + test_returns + SRCS + test_returns.py + DEPS + .test_base +) + py_test( NAME test_list1 diff --git a/pytype/tests/test_builtins1.py b/pytype/tests/test_builtins1.py index f1b32959b..ec4af91e7 100644 --- a/pytype/tests/test_builtins1.py +++ b/pytype/tests/test_builtins1.py @@ -19,6 +19,7 @@ def t_testRepr1(x): def t_testRepr1(x: int) -> str: ... """) + @test_base.skip("b/238794928: Function inference will be removed.") def test_repr2(self): ty = self.Infer(""" def t_testRepr2(x): @@ -266,6 +267,7 @@ def t_testCmp(x, y): def t_testCmp(x, y) -> int: ... """) + @test_base.skip("b/238794928: Function inference will be removed.") def test_cmp_multi(self): ty = self.Infer(""" def t_testCmpMulti(x, y): diff --git a/pytype/tests/test_methods1.py b/pytype/tests/test_methods1.py index 19177a68a..83360fd97 100644 --- a/pytype/tests/test_methods1.py +++ b/pytype/tests/test_methods1.py @@ -57,6 +57,7 @@ def f(x): self.assertHasSignature(ty.Lookup("f"), (self.int,), self.int) self.assertHasSignature(ty.Lookup("f"), (self.float,), self.float) + @test_base.skip("b/238794928: Function inference will be removed.") def test_add_float(self): ty = self.Infer(""" def f(x): @@ -473,6 +474,7 @@ def f(x): """, deep=False, show_library_calls=True) self.assertHasSignature(ty.Lookup("f"), (self.int,), self.int) + @test_base.skip("b/238794928: Function inference will be removed.") def test_ambiguous_starstar(self): ty = self.Infer(""" def f(x): diff --git a/pytype/tests/test_overriding.py b/pytype/tests/test_overriding.py index 84d5eb011..c69504302 100644 --- a/pytype/tests/test_overriding.py +++ b/pytype/tests/test_overriding.py @@ -576,8 +576,7 @@ def f(self, t) -> Sequence[A]: # signature-mismatch return [A()] """) - def test_subclass_of_generic_type_mismatch(self): - # Note: we don't detect mismatch in type parameters yet. + def test_subclass_of_generic_for_builtin_types(self): self.CheckWithErrors(""" from typing import Generic, TypeVar @@ -591,11 +590,305 @@ def g(self, t: int) -> None: pass class B(A[int]): - def f(self, t: str) -> None: + def f(self, t: str) -> None: # signature-mismatch pass def g(self, t: str) -> None: # signature-mismatch pass + + class C(A[list]): + def f(self, t: list) -> None: + pass + + def g(self, t: int) -> None: + pass + """) + + def test_subclass_of_generic_for_simple_types(self): + self.CheckWithErrors(""" + from typing import Generic, TypeVar + + T = TypeVar('T') + U = TypeVar('U') + + class A(Generic[T, U]): + def f(self, t: T) -> U: + pass + + class Y: + pass + + class X(Y): + pass + + class B(A[X, Y]): + def f(self, t: X) -> Y: + return Y() + + class C(A[X, Y]): + def f(self, t: Y) -> X: + return X() + + class D(A[Y, X]): + def f(self, t: X) -> X: # signature-mismatch + return X() + + class E(A[Y, X]): + def f(self, t: Y) -> Y: # signature-mismatch + return Y() + """) + + def test_subclass_of_generic_for_bound_types(self): + self.CheckWithErrors(""" + from typing import Generic, TypeVar + + class X: + pass + + T = TypeVar('T', bound=X) + + class A(Generic[T]): + def f(self, t: T) -> T: + return T() + + class Y(X): + pass + + class B(A[Y]): + def f(self, t: Y) -> Y: + return Y() + + class C(A[Y]): + def f(self, t: X) -> Y: + return Y() + + class D(A[Y]): + def f(self, t: Y) -> X: # signature-mismatch + return X() + """) + + def test_subclass_of_generic_match_for_generic_types(self): + self.Check(""" + from typing import Generic, List, Sequence, TypeVar + + T = TypeVar('T') + U = TypeVar('U') + + class A(Generic[T, U]): + def f(self, t: List[T]) -> Sequence[U]: + return [] + + class X: + pass + + class Y: + pass + + class B(A[X, Y]): + def f(self, t: Sequence[X]) -> List[Y]: + return [] + + class Z(X): + pass + + class C(A[Z, X]): + def f(self, t: List[X]) -> Sequence[Z]: + return [] + """) + + def test_subclass_of_generic_mismatch_for_generic_types(self): + self.CheckWithErrors(""" + from typing import Generic, List, Sequence, TypeVar + + T = TypeVar('T') + U = TypeVar('U') + + class A(Generic[T, U]): + def f(self, t: Sequence[T]) -> List[U]: + return [] + + class X: + pass + + class Y: + pass + + class B(A[X, Y]): + def f(self, t: List[X]) -> List[Y]: # signature-mismatch + return [] + + class C(A[X, Y]): + def f(self, t: Sequence[X]) -> Sequence[Y]: # signature-mismatch + return [] + + class Z(X): + pass + + class D(A[X, Z]): + def f(self, t: Sequence[Z]) -> List[Z]: # signature-mismatch + return [] + + class E(A[X, Z]): + def f(self, t: Sequence[X]) -> List[X]: # signature-mismatch + return [] + """) + + def test_nested_generic_types(self): + self.CheckWithErrors(""" + from typing import Callable, Generic, TypeVar + + T = TypeVar('T') + U = TypeVar('U') + V = TypeVar('V') + + class A(Generic[T, U]): + def f(self, t: Callable[[T], U]) -> None: + pass + + class Super: + pass + + class Sub(Super): + pass + + class B(A[Sub, Super]): + def f(self, t: Callable[[Sub], Super]) -> None: + pass + + class C(A[Sub, Super]): + def f(self, t: Callable[[Super], Super]) -> None: # signature-mismatch + pass + + class D(A[Sub, Super]): + def f(self, t: Callable[[Sub], Sub]) -> None: # signature-mismatch + pass + """) + + def test_nested_generic_types2(self): + self.CheckWithErrors(""" + from typing import Callable, Generic, TypeVar + + T = TypeVar('T') + U = TypeVar('U') + V = TypeVar('V') # not in the class template + + class A(Generic[T, U]): + def f(self, t: Callable[[T, Callable[[T], V]], U]) -> None: + pass + + class Super: + pass + + class Sub(Super): + pass + + class B(Generic[T], A[Sub, T]): + pass + + class C(B[Super]): + def f(self, t: Callable[[Sub, Callable[[Sub], V]], Super]) -> None: + pass + + class D(B[Super]): + def f(self, t: Callable[[Sub, Callable[[Super], Sub]], Super]) -> None: + pass + + class E(B[Super]): + def f(self, t: Callable[[Super, Callable[[Sub], V]], Super]) -> None: # signature-mismatch + pass + + class F(Generic[T], B[T]): + def f(self, t: Callable[[Sub, Callable[[Sub], V]], T]) -> None: + pass + + class G(Generic[T], B[T]): + def f(self, t: Callable[[Sub, Callable[[Super], Super]], T]) -> None: + pass + + class H(Generic[T], B[T]): + def f(self, t: Callable[[Super, Callable[[Sub], V]], T]) -> None: # signature-mismatch + pass + """) + + def test_subclass_of_generic_for_renamed_type_parameters(self): + self.Check(""" + from typing import Generic, TypeVar + + T = TypeVar('T') + U = TypeVar('U') + + class A(Generic[T]): + def f(self, t: T) -> None: + pass + + class B(Generic[U], A[U]): + pass + + class X: + pass + + class C(B[X]): + def f(self, t: X) -> None: + pass + """) + + def test_subclass_of_generic_for_renamed_type_parameters2(self): + self.CheckWithErrors(""" + from typing import Generic, TypeVar + + T = TypeVar('T') + U = TypeVar('U') + + class A(Generic[T, U]): + def f(self, t: T) -> U: + return U() + + class X: + pass + + class B(Generic[T], A[X, T]): + pass + + class Y: + pass + + class C(B[Y]): + def f(self, t: X) -> Y: + return Y() + + class D(B[Y]): + def f(self, t: X) -> X: # signature-mismatch + return X() + """) + + def test_subclass_of_generic_for_generic_method(self): + self.CheckWithErrors(""" + from typing import Generic, TypeVar + + T = TypeVar('T') + U = TypeVar('U') + + class A(Generic[T]): + def f(self, t: T, u: U) -> U: + return U() + + class Y: + pass + + class X(Y): + pass + + class B(A[X]): + def f(self, t: X, u: U) -> U: + return U() + + class C(A[X]): + def f(self, t: Y, u: U) -> U: + return U() + + class D(A[Y]): + def f(self, t: X, u: U) -> U: # signature-mismatch + return U() """) def test_varargs_match(self): @@ -816,6 +1109,28 @@ class Bar(Callable, Foo): pass """) + def test_async(self): + with self.DepTree([("foo.py", """ + class Foo: + async def f(self) -> int: + return 0 + def g(self) -> int: + return 0 + """)]): + self.CheckWithErrors(""" + import foo + class Good(foo.Foo): + async def f(self) -> int: + return 0 + class Bad(foo.Foo): + async def f(self) -> str: # signature-mismatch + return '' + # Test that we catch the non-async/async mismatch even without a + # return annotation. + async def g(self): # signature-mismatch + return 0 + """) + if __name__ == "__main__": test_base.main() diff --git a/pytype/tests/test_quick2.py b/pytype/tests/test_quick2.py index 45d00d8b0..a95bba9bd 100644 --- a/pytype/tests/test_quick2.py +++ b/pytype/tests/test_quick2.py @@ -41,6 +41,26 @@ def f3(): return 42 """, pythonpath=[d.path], quick=True) + def test_noreturn(self): + self.Check(""" + from typing import NoReturn + + class A: + pass + + class B: + def _raise_notimplemented(self) -> NoReturn: + raise NotImplementedError() + def f(self, x): + if __random__: + outputs = 42 + else: + self._raise_notimplemented() + return outputs + def g(self): + outputs = self.f(A()) + """, quick=True) + if __name__ == "__main__": test_base.main() diff --git a/pytype/tests/test_returns.py b/pytype/tests/test_returns.py new file mode 100644 index 000000000..362d81fd2 --- /dev/null +++ b/pytype/tests/test_returns.py @@ -0,0 +1,63 @@ +"""Tests for bad-return-type errors.""" + +from pytype.tests import test_base + + +class TestReturns(test_base.BaseTest): + """Tests for bad-return-type.""" + + def test_implicit_none(self): + self.CheckWithErrors(""" + def f(x) -> int: + pass # bad-return-type + """) + + def test_if(self): + # NOTE(b/233047104): The implict `return None` gets reported at the end of + # the function even though there is also a correct return on that line. + self.CheckWithErrors(""" + def f(x) -> int: + if x: + pass + else: + return 10 # bad-return-type + """) + + def test_nested_if(self): + self.CheckWithErrors(""" + def f(x) -> int: + if x: + if __random__: + pass + else: + return 'a' # bad-return-type + else: + return 10 + pass # bad-return-type + """) + + def test_with(self): + self.CheckWithErrors(""" + def f(x) -> int: + with open('foo'): + if __random__: + pass + else: + return 'a' # bad-return-type # bad-return-type + """) + + def test_nested_with(self): + self.CheckWithErrors(""" + def f(x) -> int: + with open('foo'): + if __random__: + with open('bar'): + if __random__: + pass + else: + return 'a' # bad-return-type # bad-return-type + """) + + +if __name__ == "__main__": + test_base.main() diff --git a/pytype/tests/test_stdlib2.py b/pytype/tests/test_stdlib2.py index a1cf98d6e..cbe33924c 100644 --- a/pytype/tests/test_stdlib2.py +++ b/pytype/tests/test_stdlib2.py @@ -99,6 +99,15 @@ class MyClass(fractions.Fraction): ... def foo() -> MyClass: ... """) + def test_codetype(self): + self.Check(""" + import types + class Foo: + x: types.CodeType + def set_x(self): + self.x = compile('', '', '') + """) + class StdlibTestsFeatures(test_base.BaseTest, test_utils.TestCollectionsMixin): diff --git a/pytype/vm.py b/pytype/vm.py index 399158d7e..2d3823ebb 100644 --- a/pytype/vm.py +++ b/pytype/vm.py @@ -230,6 +230,7 @@ def run_frame(self, frame, node, annotated_locals=None): can_return = False return_nodes = [] finally_tracker = vm_utils.FinallyStateTracker() + vm_utils.adjust_block_returns(frame.f_code, self._director.block_returns) for block in frame.f_code.order: state = frame.states.get(block[0]) if not state: @@ -430,6 +431,7 @@ def run_program(self, src, filename, maximum_depth): self.ctx.errorlog.ignored_type_comment(self.filename, line, self._director.type_comments[line]) code = constant_folding.optimize(code) + vm_utils.adjust_block_returns(code, self._director.block_returns) node = self.ctx.root_node.ConnectNew("init") node, f_globals, f_locals, _ = self.run_bytecode(node, code) diff --git a/pytype/vm_utils.py b/pytype/vm_utils.py index b04944529..b0e23b8da 100644 --- a/pytype/vm_utils.py +++ b/pytype/vm_utils.py @@ -127,6 +127,7 @@ class _NameInOuterClassErrorDetails(_NameErrorDetails): """Name error details for a name defined in an outer class.""" def __init__(self, attr, prefix, class_name): + super().__init__() self._attr = attr self._prefix = prefix self._class_name = class_name @@ -142,8 +143,10 @@ def to_error_message(self): class _NameInOuterFunctionErrorDetails(_NameErrorDetails): + """Name error details for a name defined in an outer function.""" def __init__(self, attr, outer_scope, inner_scope): + super().__init__() self._attr = attr self._outer_scope = outer_scope self._inner_scope = inner_scope @@ -1134,3 +1137,17 @@ def to_coroutine(state, obj, top, ctx): for b in obj.bindings: state = _binding_to_coroutine(state, b, bad_bindings, ret, top, ctx) return state, ret + + +def adjust_block_returns(code, block_returns): + """Adjust line numbers for return statements in with blocks.""" + + rets = {k: iter(v) for k, v in block_returns} + for block in code.order: + for op in block: + if op.__class__.__name__ == "RETURN_VALUE": + if op.line in rets: + lines = rets[op.line] + new_line = next(lines, None) + if new_line: + op.line = new_line