From ca0e6770d129c9320a49d6f9008d4ad9e63005e2 Mon Sep 17 00:00:00 2001 From: Anthony Sottile Date: Wed, 6 Apr 2022 18:21:12 -0400 Subject: [PATCH] move typed class rewrite to a plugin --- pyupgrade/_data.py | 1 + pyupgrade/_main.py | 196 +--------------- pyupgrade/_plugins/typing_classes.py | 212 ++++++++++++++++++ ...d_tuple_test.py => typing_classes_test.py} | 160 ++++++++++++- tests/features/typing_typed_dict_test.py | 158 ------------- 5 files changed, 371 insertions(+), 356 deletions(-) create mode 100644 pyupgrade/_plugins/typing_classes.py rename tests/features/{typing_named_tuple_test.py => typing_classes_test.py} (50%) delete mode 100644 tests/features/typing_typed_dict_test.py diff --git a/pyupgrade/_data.py b/pyupgrade/_data.py index 37ff8eae..06d99891 100644 --- a/pyupgrade/_data.py +++ b/pyupgrade/_data.py @@ -53,6 +53,7 @@ class State(NamedTuple): 'subprocess', 'sys', 'typing', + 'typing_extensions', )) FUNCS = collections.defaultdict(list) diff --git a/pyupgrade/_main.py b/pyupgrade/_main.py index b1fc43b1..7bba452d 100644 --- a/pyupgrade/_main.py +++ b/pyupgrade/_main.py @@ -2,7 +2,6 @@ import argparse import ast -import collections import re import string import sys @@ -26,7 +25,6 @@ from pyupgrade._ast_helpers import ast_to_offset from pyupgrade._ast_helpers import contains_await from pyupgrade._ast_helpers import has_starargs -from pyupgrade._ast_helpers import is_name_attr from pyupgrade._data import FUNCS from pyupgrade._data import Settings from pyupgrade._data import Version @@ -35,7 +33,6 @@ from pyupgrade._string_helpers import is_codec from pyupgrade._string_helpers import NAMED_UNICODE_RE from pyupgrade._token_helpers import CLOSING -from pyupgrade._token_helpers import KEYWORDS from pyupgrade._token_helpers import OPENING from pyupgrade._token_helpers import parse_call_args from pyupgrade._token_helpers import remove_brace @@ -538,19 +535,8 @@ def _format_params(call: ast.Call) -> set[str]: class FindPy36Plus(ast.NodeVisitor): def __init__(self, *, min_version: Version) -> None: self.fstrings: dict[Offset, ast.Call] = {} - self.named_tuples: dict[Offset, ast.Call] = {} - self.dict_typed_dicts: dict[Offset, ast.Call] = {} - self.kw_typed_dicts: dict[Offset, ast.Call] = {} - self._from_imports: dict[str, set[str]] = collections.defaultdict(set) self.min_version = min_version - def visit_ImportFrom(self, node: ast.ImportFrom) -> None: - if node.level == 0 and node.module in {'typing', 'typing_extensions'}: - for name in node.names: - if not name.asname: - self._from_imports[node.module].add(name.name) - self.generic_visit(node) - def _parse(self, node: ast.Call) -> tuple[DotFormatPart, ...] | None: if not ( isinstance(node.func, ast.Attribute) and @@ -597,117 +583,6 @@ def visit_Call(self, node: ast.Call) -> None: self.generic_visit(node) - def visit_Assign(self, node: ast.Assign) -> None: - if ( - # NT = ...("NT", ...) - len(node.targets) == 1 and - isinstance(node.targets[0], ast.Name) and - isinstance(node.value, ast.Call) and - len(node.value.args) >= 1 and - isinstance(node.value.args[0], ast.Str) and - node.targets[0].id == node.value.args[0].s and - not has_starargs(node.value) - ): - if ( - is_name_attr( - node.value.func, - self._from_imports, - ('typing',), - ('NamedTuple',), - ) and - len(node.value.args) == 2 and - not node.value.keywords and - isinstance(node.value.args[1], (ast.List, ast.Tuple)) and - len(node.value.args[1].elts) > 0 and - all( - isinstance(tup, ast.Tuple) and - len(tup.elts) == 2 and - isinstance(tup.elts[0], ast.Str) and - tup.elts[0].s.isidentifier() and - tup.elts[0].s not in KEYWORDS - for tup in node.value.args[1].elts - ) - ): - self.named_tuples[ast_to_offset(node)] = node.value - elif ( - is_name_attr( - node.value.func, - self._from_imports, - ('typing', 'typing_extensions'), - ('TypedDict',), - ) and - len(node.value.args) == 1 and - len(node.value.keywords) > 0 and - not any( - keyword.arg == 'total' - for keyword in node.value.keywords - ) - ): - self.kw_typed_dicts[ast_to_offset(node)] = node.value - elif ( - is_name_attr( - node.value.func, - self._from_imports, - ('typing', 'typing_extensions'), - ('TypedDict',), - ) and - len(node.value.args) == 2 and - ( - not node.value.keywords or - ( - len(node.value.keywords) == 1 and - node.value.keywords[0].arg == 'total' and - isinstance( - node.value.keywords[0].value, - (ast.Constant, ast.NameConstant), - ) - ) - ) and - isinstance(node.value.args[1], ast.Dict) and - node.value.args[1].keys and - all( - isinstance(k, ast.Str) and - k.s.isidentifier() and - k.s not in KEYWORDS - for k in node.value.args[1].keys - ) - ): - self.dict_typed_dicts[ast_to_offset(node)] = node.value - - self.generic_visit(node) - - -def _unparse(node: ast.expr) -> str: - if isinstance(node, ast.Name): - return node.id - elif isinstance(node, ast.Attribute): - return ''.join((_unparse(node.value), '.', node.attr)) - elif isinstance(node, ast.Subscript): - if sys.version_info >= (3, 9): # pragma: >=3.9 cover - node_slice: ast.expr = node.slice - elif isinstance(node.slice, ast.Index): # pragma: <3.9 cover - node_slice = node.slice.value - else: - raise AssertionError(f'expected Slice: {ast.dump(node)}') - if isinstance(node_slice, ast.Tuple): - if len(node_slice.elts) == 1: - slice_s = f'{_unparse(node_slice.elts[0])},' - else: - slice_s = ', '.join(_unparse(elt) for elt in node_slice.elts) - else: - slice_s = _unparse(node_slice) - return f'{_unparse(node.value)}[{slice_s}]' - elif isinstance(node, ast.Str): - return repr(node.s) - elif isinstance(node, ast.Ellipsis): - return '...' - elif isinstance(node, ast.List): - return '[{}]'.format(', '.join(_unparse(elt) for elt in node.elts)) - elif isinstance(node, ast.NameConstant): - return repr(node.value) - else: - raise NotImplementedError(ast.dump(node)) - def _skip_unimportant_ws(tokens: list[Token], i: int) -> int: while tokens[i].name == 'UNIMPORTANT_WS': @@ -744,29 +619,6 @@ def _to_fstring( return unparse_parsed_string(parts) -def _typed_class_replacement( - tokens: list[Token], - i: int, - call: ast.Call, - types: dict[str, ast.expr], -) -> tuple[int, str]: - while i > 0 and tokens[i - 1].name == 'DEDENT': - i -= 1 - if i > 0 and tokens[i - 1].name in {'INDENT', UNIMPORTANT_WS}: - indent = f'{tokens[i - 1].src}{" " * 4}' - else: - indent = ' ' * 4 - - # NT = NamedTuple("nt", [("a", int)]) - # ^i ^end - end = i + 1 - while end < len(tokens) and tokens[end].name != 'NEWLINE': - end += 1 - - attrs = '\n'.join(f'{indent}{k}: {_unparse(v)}' for k, v in types.items()) - return end, attrs - - def _fix_py36_plus(contents_text: str, *, min_version: Version) -> str: try: ast_obj = ast_parse(contents_text) @@ -776,12 +628,7 @@ def _fix_py36_plus(contents_text: str, *, min_version: Version) -> str: visitor = FindPy36Plus(min_version=min_version) visitor.visit(ast_obj) - if not any(( - visitor.fstrings, - visitor.named_tuples, - visitor.dict_typed_dicts, - visitor.kw_typed_dicts, - )): + if not visitor.fstrings: return contents_text try: @@ -807,47 +654,6 @@ def _fix_py36_plus(contents_text: str, *, min_version: Version) -> str: src=_to_fstring(token.src, tokens, args), ) del tokens[i + 1:end] - elif token.offset in visitor.named_tuples and token.name == 'NAME': - call = visitor.named_tuples[token.offset] - types: dict[str, ast.expr] = { - tup.elts[0].s: tup.elts[1] - for tup in call.args[1].elts # type: ignore # (checked above) - } - end, attrs = _typed_class_replacement(tokens, i, call, types) - src = f'class {tokens[i].src}({_unparse(call.func)}):\n{attrs}' - tokens[i:end] = [Token('CODE', src)] - elif token.offset in visitor.kw_typed_dicts and token.name == 'NAME': - call = visitor.kw_typed_dicts[token.offset] - types = { - arg.arg: arg.value # type: ignore # (checked above) - for arg in call.keywords - } - end, attrs = _typed_class_replacement(tokens, i, call, types) - src = f'class {tokens[i].src}({_unparse(call.func)}):\n{attrs}' - tokens[i:end] = [Token('CODE', src)] - elif token.offset in visitor.dict_typed_dicts and token.name == 'NAME': - call = visitor.dict_typed_dicts[token.offset] - types = { - k.s: v - for k, v in zip( - call.args[1].keys, # type: ignore # (checked above) - call.args[1].values, # type: ignore # (checked above) - ) - } - if call.keywords: - total = call.keywords[0].value.value # type: ignore # (checked above) # noqa: E501 - end, attrs = _typed_class_replacement(tokens, i, call, types) - src = ( - f'class {tokens[i].src}(' - f'{_unparse(call.func)}, total={total}' - f'):\n' - f'{attrs}' - ) - tokens[i:end] = [Token('CODE', src)] - else: - end, attrs = _typed_class_replacement(tokens, i, call, types) - src = f'class {tokens[i].src}({_unparse(call.func)}):\n{attrs}' - tokens[i:end] = [Token('CODE', src)] return tokens_to_src(tokens) diff --git a/pyupgrade/_plugins/typing_classes.py b/pyupgrade/_plugins/typing_classes.py new file mode 100644 index 00000000..bb029d32 --- /dev/null +++ b/pyupgrade/_plugins/typing_classes.py @@ -0,0 +1,212 @@ +from __future__ import annotations + +import ast +import functools +import sys +from typing import Iterable + +from tokenize_rt import Offset +from tokenize_rt import Token +from tokenize_rt import UNIMPORTANT_WS + +from pyupgrade._ast_helpers import ast_to_offset +from pyupgrade._ast_helpers import has_starargs +from pyupgrade._ast_helpers import is_name_attr +from pyupgrade._data import register +from pyupgrade._data import State +from pyupgrade._data import TokenFunc +from pyupgrade._token_helpers import KEYWORDS + + +def _unparse(node: ast.expr) -> str: + if isinstance(node, ast.Name): + return node.id + elif isinstance(node, ast.Attribute): + return ''.join((_unparse(node.value), '.', node.attr)) + elif isinstance(node, ast.Subscript): + if sys.version_info >= (3, 9): # pragma: >=3.9 cover + node_slice: ast.expr = node.slice + elif isinstance(node.slice, ast.Index): # pragma: <3.9 cover + node_slice = node.slice.value + else: + raise AssertionError(f'expected Slice: {ast.dump(node)}') + if isinstance(node_slice, ast.Tuple): + if len(node_slice.elts) == 1: + slice_s = f'{_unparse(node_slice.elts[0])},' + else: + slice_s = ', '.join(_unparse(elt) for elt in node_slice.elts) + else: + slice_s = _unparse(node_slice) + return f'{_unparse(node.value)}[{slice_s}]' + elif isinstance(node, ast.Str): + return repr(node.s) + elif isinstance(node, ast.Ellipsis): + return '...' + elif isinstance(node, ast.List): + return '[{}]'.format(', '.join(_unparse(elt) for elt in node.elts)) + elif isinstance(node, ast.NameConstant): + return repr(node.value) + else: + raise NotImplementedError(ast.dump(node)) + + +def _typed_class_replacement( + tokens: list[Token], + i: int, + call: ast.Call, + types: dict[str, ast.expr], +) -> tuple[int, str]: + while i > 0 and tokens[i - 1].name == 'DEDENT': + i -= 1 + if i > 0 and tokens[i - 1].name in {'INDENT', UNIMPORTANT_WS}: + indent = f'{tokens[i - 1].src}{" " * 4}' + else: + indent = ' ' * 4 + + # NT = NamedTuple("nt", [("a", int)]) + # ^i ^end + end = i + 1 + while end < len(tokens) and tokens[end].name != 'NEWLINE': + end += 1 + + attrs = '\n'.join(f'{indent}{k}: {_unparse(v)}' for k, v in types.items()) + return end, attrs + + +def _fix_named_tuple(i: int, tokens: list[Token], *, call: ast.Call) -> None: + types = { + tup.elts[0].s: tup.elts[1] + for tup in call.args[1].elts # type: ignore # (checked below) + } + end, attrs = _typed_class_replacement(tokens, i, call, types) + src = f'class {tokens[i].src}({_unparse(call.func)}):\n{attrs}' + tokens[i:end] = [Token('CODE', src)] + + +def _fix_kw_typed_dict(i: int, tokens: list[Token], *, call: ast.Call) -> None: + types = { + arg.arg: arg.value + for arg in call.keywords + if arg.arg is not None + } + end, attrs = _typed_class_replacement(tokens, i, call, types) + src = f'class {tokens[i].src}({_unparse(call.func)}):\n{attrs}' + tokens[i:end] = [Token('CODE', src)] + + +def _fix_dict_typed_dict( + i: int, + tokens: list[Token], + *, + call: ast.Call, +) -> None: + types = { + k.s: v + for k, v in zip( + call.args[1].keys, # type: ignore # (checked below) + call.args[1].values, # type: ignore # (checked below) + ) + } + if call.keywords: + total = call.keywords[0].value.value # type: ignore # (checked below) # noqa: E501 + end, attrs = _typed_class_replacement(tokens, i, call, types) + src = ( + f'class {tokens[i].src}(' + f'{_unparse(call.func)}, total={total}' + f'):\n' + f'{attrs}' + ) + tokens[i:end] = [Token('CODE', src)] + else: + end, attrs = _typed_class_replacement(tokens, i, call, types) + src = f'class {tokens[i].src}({_unparse(call.func)}):\n{attrs}' + tokens[i:end] = [Token('CODE', src)] + + +@register(ast.Assign) +def visit_Assign( + state: State, + node: ast.Assign, + parent: ast.AST, +) -> Iterable[tuple[Offset, TokenFunc]]: + if state.settings.min_version < (3, 6): + return + + if ( + # NT = ...("NT", ...) + len(node.targets) == 1 and + isinstance(node.targets[0], ast.Name) and + isinstance(node.value, ast.Call) and + len(node.value.args) >= 1 and + isinstance(node.value.args[0], ast.Str) and + node.targets[0].id == node.value.args[0].s and + not has_starargs(node.value) + ): + if ( + is_name_attr( + node.value.func, + state.from_imports, + ('typing',), + ('NamedTuple',), + ) and + len(node.value.args) == 2 and + not node.value.keywords and + isinstance(node.value.args[1], (ast.List, ast.Tuple)) and + len(node.value.args[1].elts) > 0 and + all( + isinstance(tup, ast.Tuple) and + len(tup.elts) == 2 and + isinstance(tup.elts[0], ast.Str) and + tup.elts[0].s.isidentifier() and + tup.elts[0].s not in KEYWORDS + for tup in node.value.args[1].elts + ) + ): + func = functools.partial(_fix_named_tuple, call=node.value) + yield ast_to_offset(node), func + elif ( + is_name_attr( + node.value.func, + state.from_imports, + ('typing', 'typing_extensions'), + ('TypedDict',), + ) and + len(node.value.args) == 1 and + len(node.value.keywords) > 0 and + not any( + keyword.arg == 'total' + for keyword in node.value.keywords + ) + ): + func = functools.partial(_fix_kw_typed_dict, call=node.value) + yield ast_to_offset(node), func + elif ( + is_name_attr( + node.value.func, + state.from_imports, + ('typing', 'typing_extensions'), + ('TypedDict',), + ) and + len(node.value.args) == 2 and + ( + not node.value.keywords or + ( + len(node.value.keywords) == 1 and + node.value.keywords[0].arg == 'total' and + isinstance( + node.value.keywords[0].value, + (ast.Constant, ast.NameConstant), + ) + ) + ) and + isinstance(node.value.args[1], ast.Dict) and + node.value.args[1].keys and + all( + isinstance(k, ast.Str) and + k.s.isidentifier() and + k.s not in KEYWORDS + for k in node.value.args[1].keys + ) + ): + func = functools.partial(_fix_dict_typed_dict, call=node.value) + yield ast_to_offset(node), func diff --git a/tests/features/typing_named_tuple_test.py b/tests/features/typing_classes_test.py similarity index 50% rename from tests/features/typing_named_tuple_test.py rename to tests/features/typing_classes_test.py index 9cf8a602..f5b99beb 100644 --- a/tests/features/typing_named_tuple_test.py +++ b/tests/features/typing_classes_test.py @@ -2,7 +2,8 @@ import pytest -from pyupgrade._main import _fix_py36_plus +from pyupgrade._data import Settings +from pyupgrade._main import _fix_plugins @pytest.mark.parametrize( @@ -59,7 +60,7 @@ ), ) def test_typing_named_tuple_noop(s): - assert _fix_py36_plus(s, min_version=(3, 6)) == s + assert _fix_plugins(s, settings=Settings(min_version=(3, 6))) == s @pytest.mark.parametrize( @@ -173,4 +174,157 @@ def test_typing_named_tuple_noop(s): ), ) def test_fix_typing_named_tuple(s, expected): - assert _fix_py36_plus(s, min_version=(3, 6)) == expected + assert _fix_plugins(s, settings=Settings(min_version=(3, 6))) == expected + + +@pytest.mark.parametrize( + 's', + ( + pytest.param( + 'from wat import TypedDict\n' + 'Q = TypedDict("Q")\n', + id='from imported from elsewhere', + ), + pytest.param('D = typing.TypedDict("D")', id='no typed kwargs'), + pytest.param('D = typing.TypedDict("D", {})', id='no typed args'), + pytest.param('D = typing.TypedDict("D", {}, a=int)', id='both'), + pytest.param('D = typing.TypedDict("D", 1)', id='not a dict'), + pytest.param( + 'D = typing.TypedDict("D", {1: str})', + id='key is not a string', + ), + pytest.param( + 'D = typing.TypedDict("D", {"a-b": str})', + id='key is not an identifier', + ), + pytest.param( + 'D = typing.TypedDict("D", {"class": str})', + id='key is a keyword', + ), + pytest.param( + 'D = typing.TypedDict("D", {**d, "a": str})', + id='dictionary splat operator', + ), + pytest.param( + 'C = typing.TypedDict("C", *types)', + id='starargs', + ), + pytest.param( + 'D = typing.TypedDict("D", **types)', + id='starstarkwargs', + ), + pytest.param( + 'D = typing.TypedDict("D", x=int, total=False)', + id='kw_typed_dict with total', + ), + ), +) +def test_typing_typed_dict_noop(s): + assert _fix_plugins(s, settings=Settings(min_version=(3, 6))) == s + + +@pytest.mark.parametrize( + ('s', 'expected'), + ( + pytest.param( + 'from typing import TypedDict\n' + 'D = TypedDict("D", a=int)\n', + + 'from typing import TypedDict\n' + 'class D(TypedDict):\n' + ' a: int\n', + + id='keyword TypedDict from imported', + ), + pytest.param( + 'import typing\n' + 'D = typing.TypedDict("D", a=int)\n', + + 'import typing\n' + 'class D(typing.TypedDict):\n' + ' a: int\n', + + id='keyword TypedDict from attribute', + ), + pytest.param( + 'import typing\n' + 'D = typing.TypedDict("D", {"a": int})\n', + + 'import typing\n' + 'class D(typing.TypedDict):\n' + ' a: int\n', + + id='TypedDict from dict literal', + ), + pytest.param( + 'import typing\n' + 'D = typing.TypedDict("D", {"a": int}, total=False)\n', + + 'import typing\n' + 'class D(typing.TypedDict, total=False):\n' + ' a: int\n', + + id='TypedDict from dict literal with total', + ), + pytest.param( + 'from typing_extensions import TypedDict\n' + 'D = TypedDict("D", a=int)\n', + + 'from typing_extensions import TypedDict\n' + 'class D(TypedDict):\n' + ' a: int\n', + + id='keyword TypedDict from typing_extensions', + ), + pytest.param( + 'import typing_extensions\n' + 'D = typing_extensions.TypedDict("D", {"a": int})\n', + + 'import typing_extensions\n' + 'class D(typing_extensions.TypedDict):\n' + ' a: int\n', + + id='dict TypedDict from typing_extensions', + ), + pytest.param( + 'import typing_extensions\n' + 'D = typing_extensions.TypedDict("D", {"a": int}, total=True)\n', + + 'import typing_extensions\n' + 'class D(typing_extensions.TypedDict, total=True):\n' + ' a: int\n', + + id='keyword TypedDict from typing_extensions, with total', + ), + pytest.param( + 'from typing import List\n' + 'from typing_extensions import TypedDict\n' + 'Foo = TypedDict("Foo", {"lsts": List[List[int]]})', + + 'from typing import List\n' + 'from typing_extensions import TypedDict\n' + 'class Foo(TypedDict):\n' + ' lsts: List[List[int]]', + + id='index unparse error', + ), + pytest.param( + 'import typing\n' + 'if True:\n' + ' if False:\n' + ' pass\n' + ' D = typing.TypedDict("D", a=int)\n', + + 'import typing\n' + 'if True:\n' + ' if False:\n' + ' pass\n' + ' class D(typing.TypedDict):\n' + ' a: int\n', + + id='right after a dedent', + ), + ), +) +def test_typing_typed_dict(s, expected): + assert _fix_plugins(s, settings=Settings(min_version=(3, 6))) == expected diff --git a/tests/features/typing_typed_dict_test.py b/tests/features/typing_typed_dict_test.py deleted file mode 100644 index b585f208..00000000 --- a/tests/features/typing_typed_dict_test.py +++ /dev/null @@ -1,158 +0,0 @@ -from __future__ import annotations - -import pytest - -from pyupgrade._main import _fix_py36_plus - - -@pytest.mark.parametrize( - 's', - ( - pytest.param( - 'from wat import TypedDict\n' - 'Q = TypedDict("Q")\n', - id='from imported from elsewhere', - ), - pytest.param('D = typing.TypedDict("D")', id='no typed kwargs'), - pytest.param('D = typing.TypedDict("D", {})', id='no typed args'), - pytest.param('D = typing.TypedDict("D", {}, a=int)', id='both'), - pytest.param('D = typing.TypedDict("D", 1)', id='not a dict'), - pytest.param( - 'D = typing.TypedDict("D", {1: str})', - id='key is not a string', - ), - pytest.param( - 'D = typing.TypedDict("D", {"a-b": str})', - id='key is not an identifier', - ), - pytest.param( - 'D = typing.TypedDict("D", {"class": str})', - id='key is a keyword', - ), - pytest.param( - 'D = typing.TypedDict("D", {**d, "a": str})', - id='dictionary splat operator', - ), - pytest.param( - 'C = typing.TypedDict("C", *types)', - id='starargs', - ), - pytest.param( - 'D = typing.TypedDict("D", **types)', - id='starstarkwargs', - ), - pytest.param( - 'D = typing.TypedDict("D", x=int, total=False)', - id='kw_typed_dict with total', - ), - ), -) -def test_typing_typed_dict_noop(s): - assert _fix_py36_plus(s, min_version=(3, 6)) == s - - -@pytest.mark.parametrize( - ('s', 'expected'), - ( - pytest.param( - 'from typing import TypedDict\n' - 'D = TypedDict("D", a=int)\n', - - 'from typing import TypedDict\n' - 'class D(TypedDict):\n' - ' a: int\n', - - id='keyword TypedDict from imported', - ), - pytest.param( - 'import typing\n' - 'D = typing.TypedDict("D", a=int)\n', - - 'import typing\n' - 'class D(typing.TypedDict):\n' - ' a: int\n', - - id='keyword TypedDict from attribute', - ), - pytest.param( - 'import typing\n' - 'D = typing.TypedDict("D", {"a": int})\n', - - 'import typing\n' - 'class D(typing.TypedDict):\n' - ' a: int\n', - - id='TypedDict from dict literal', - ), - pytest.param( - 'import typing\n' - 'D = typing.TypedDict("D", {"a": int}, total=False)\n', - - 'import typing\n' - 'class D(typing.TypedDict, total=False):\n' - ' a: int\n', - - id='TypedDict from dict literal with total', - ), - pytest.param( - 'from typing_extensions import TypedDict\n' - 'D = TypedDict("D", a=int)\n', - - 'from typing_extensions import TypedDict\n' - 'class D(TypedDict):\n' - ' a: int\n', - - id='keyword TypedDict from typing_extensions', - ), - pytest.param( - 'import typing_extensions\n' - 'D = typing_extensions.TypedDict("D", {"a": int})\n', - - 'import typing_extensions\n' - 'class D(typing_extensions.TypedDict):\n' - ' a: int\n', - - id='keyword TypedDict from typing_extensions', - ), - pytest.param( - 'import typing_extensions\n' - 'D = typing_extensions.TypedDict("D", {"a": int}, total=True)\n', - - 'import typing_extensions\n' - 'class D(typing_extensions.TypedDict, total=True):\n' - ' a: int\n', - - id='keyword TypedDict from typing_extensions, with total', - ), - pytest.param( - 'from typing import List\n' - 'from typing_extensions import TypedDict\n' - 'Foo = TypedDict("Foo", {"lsts": List[List[int]]})', - - 'from typing import List\n' - 'from typing_extensions import TypedDict\n' - 'class Foo(TypedDict):\n' - ' lsts: List[List[int]]', - - id='index unparse error', - ), - pytest.param( - 'import typing\n' - 'if True:\n' - ' if False:\n' - ' pass\n' - ' D = typing.TypedDict("D", a=int)\n', - - 'import typing\n' - 'if True:\n' - ' if False:\n' - ' pass\n' - ' class D(typing.TypedDict):\n' - ' a: int\n', - - id='right after a dedent', - ), - ), -) -def test_typing_typed_dict(s, expected): - assert _fix_py36_plus(s, min_version=(3, 6)) == expected