Skip to content

Commit

Permalink
Merge pull request #625 from asottile/split-py36-plugin
Browse files Browse the repository at this point in the history
move typed class rewrite to a plugin
  • Loading branch information
asottile authored Apr 6, 2022
2 parents d9e0b90 + ca0e677 commit 087c7e6
Show file tree
Hide file tree
Showing 5 changed files with 371 additions and 356 deletions.
1 change: 1 addition & 0 deletions pyupgrade/_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class State(NamedTuple):
'subprocess',
'sys',
'typing',
'typing_extensions',
))

FUNCS = collections.defaultdict(list)
Expand Down
196 changes: 1 addition & 195 deletions pyupgrade/_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import argparse
import ast
import collections
import re
import string
import sys
Expand All @@ -26,7 +25,6 @@
from pyupgrade._ast_helpers import ast_to_offset
from pyupgrade._ast_helpers import contains_await
from pyupgrade._ast_helpers import has_starargs
from pyupgrade._ast_helpers import is_name_attr
from pyupgrade._data import FUNCS
from pyupgrade._data import Settings
from pyupgrade._data import Version
Expand All @@ -35,7 +33,6 @@
from pyupgrade._string_helpers import is_codec
from pyupgrade._string_helpers import NAMED_UNICODE_RE
from pyupgrade._token_helpers import CLOSING
from pyupgrade._token_helpers import KEYWORDS
from pyupgrade._token_helpers import OPENING
from pyupgrade._token_helpers import parse_call_args
from pyupgrade._token_helpers import remove_brace
Expand Down Expand Up @@ -538,19 +535,8 @@ def _format_params(call: ast.Call) -> set[str]:
class FindPy36Plus(ast.NodeVisitor):
def __init__(self, *, min_version: Version) -> None:
self.fstrings: dict[Offset, ast.Call] = {}
self.named_tuples: dict[Offset, ast.Call] = {}
self.dict_typed_dicts: dict[Offset, ast.Call] = {}
self.kw_typed_dicts: dict[Offset, ast.Call] = {}
self._from_imports: dict[str, set[str]] = collections.defaultdict(set)
self.min_version = min_version

def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
if node.level == 0 and node.module in {'typing', 'typing_extensions'}:
for name in node.names:
if not name.asname:
self._from_imports[node.module].add(name.name)
self.generic_visit(node)

def _parse(self, node: ast.Call) -> tuple[DotFormatPart, ...] | None:
if not (
isinstance(node.func, ast.Attribute) and
Expand Down Expand Up @@ -597,117 +583,6 @@ def visit_Call(self, node: ast.Call) -> None:

self.generic_visit(node)

def visit_Assign(self, node: ast.Assign) -> None:
if (
# NT = ...("NT", ...)
len(node.targets) == 1 and
isinstance(node.targets[0], ast.Name) and
isinstance(node.value, ast.Call) and
len(node.value.args) >= 1 and
isinstance(node.value.args[0], ast.Str) and
node.targets[0].id == node.value.args[0].s and
not has_starargs(node.value)
):
if (
is_name_attr(
node.value.func,
self._from_imports,
('typing',),
('NamedTuple',),
) and
len(node.value.args) == 2 and
not node.value.keywords and
isinstance(node.value.args[1], (ast.List, ast.Tuple)) and
len(node.value.args[1].elts) > 0 and
all(
isinstance(tup, ast.Tuple) and
len(tup.elts) == 2 and
isinstance(tup.elts[0], ast.Str) and
tup.elts[0].s.isidentifier() and
tup.elts[0].s not in KEYWORDS
for tup in node.value.args[1].elts
)
):
self.named_tuples[ast_to_offset(node)] = node.value
elif (
is_name_attr(
node.value.func,
self._from_imports,
('typing', 'typing_extensions'),
('TypedDict',),
) and
len(node.value.args) == 1 and
len(node.value.keywords) > 0 and
not any(
keyword.arg == 'total'
for keyword in node.value.keywords
)
):
self.kw_typed_dicts[ast_to_offset(node)] = node.value
elif (
is_name_attr(
node.value.func,
self._from_imports,
('typing', 'typing_extensions'),
('TypedDict',),
) and
len(node.value.args) == 2 and
(
not node.value.keywords or
(
len(node.value.keywords) == 1 and
node.value.keywords[0].arg == 'total' and
isinstance(
node.value.keywords[0].value,
(ast.Constant, ast.NameConstant),
)
)
) and
isinstance(node.value.args[1], ast.Dict) and
node.value.args[1].keys and
all(
isinstance(k, ast.Str) and
k.s.isidentifier() and
k.s not in KEYWORDS
for k in node.value.args[1].keys
)
):
self.dict_typed_dicts[ast_to_offset(node)] = node.value

self.generic_visit(node)


def _unparse(node: ast.expr) -> str:
if isinstance(node, ast.Name):
return node.id
elif isinstance(node, ast.Attribute):
return ''.join((_unparse(node.value), '.', node.attr))
elif isinstance(node, ast.Subscript):
if sys.version_info >= (3, 9): # pragma: >=3.9 cover
node_slice: ast.expr = node.slice
elif isinstance(node.slice, ast.Index): # pragma: <3.9 cover
node_slice = node.slice.value
else:
raise AssertionError(f'expected Slice: {ast.dump(node)}')
if isinstance(node_slice, ast.Tuple):
if len(node_slice.elts) == 1:
slice_s = f'{_unparse(node_slice.elts[0])},'
else:
slice_s = ', '.join(_unparse(elt) for elt in node_slice.elts)
else:
slice_s = _unparse(node_slice)
return f'{_unparse(node.value)}[{slice_s}]'
elif isinstance(node, ast.Str):
return repr(node.s)
elif isinstance(node, ast.Ellipsis):
return '...'
elif isinstance(node, ast.List):
return '[{}]'.format(', '.join(_unparse(elt) for elt in node.elts))
elif isinstance(node, ast.NameConstant):
return repr(node.value)
else:
raise NotImplementedError(ast.dump(node))


def _skip_unimportant_ws(tokens: list[Token], i: int) -> int:
while tokens[i].name == 'UNIMPORTANT_WS':
Expand Down Expand Up @@ -744,29 +619,6 @@ def _to_fstring(
return unparse_parsed_string(parts)


def _typed_class_replacement(
tokens: list[Token],
i: int,
call: ast.Call,
types: dict[str, ast.expr],
) -> tuple[int, str]:
while i > 0 and tokens[i - 1].name == 'DEDENT':
i -= 1
if i > 0 and tokens[i - 1].name in {'INDENT', UNIMPORTANT_WS}:
indent = f'{tokens[i - 1].src}{" " * 4}'
else:
indent = ' ' * 4

# NT = NamedTuple("nt", [("a", int)])
# ^i ^end
end = i + 1
while end < len(tokens) and tokens[end].name != 'NEWLINE':
end += 1

attrs = '\n'.join(f'{indent}{k}: {_unparse(v)}' for k, v in types.items())
return end, attrs


def _fix_py36_plus(contents_text: str, *, min_version: Version) -> str:
try:
ast_obj = ast_parse(contents_text)
Expand All @@ -776,12 +628,7 @@ def _fix_py36_plus(contents_text: str, *, min_version: Version) -> str:
visitor = FindPy36Plus(min_version=min_version)
visitor.visit(ast_obj)

if not any((
visitor.fstrings,
visitor.named_tuples,
visitor.dict_typed_dicts,
visitor.kw_typed_dicts,
)):
if not visitor.fstrings:
return contents_text

try:
Expand All @@ -807,47 +654,6 @@ def _fix_py36_plus(contents_text: str, *, min_version: Version) -> str:
src=_to_fstring(token.src, tokens, args),
)
del tokens[i + 1:end]
elif token.offset in visitor.named_tuples and token.name == 'NAME':
call = visitor.named_tuples[token.offset]
types: dict[str, ast.expr] = {
tup.elts[0].s: tup.elts[1]
for tup in call.args[1].elts # type: ignore # (checked above)
}
end, attrs = _typed_class_replacement(tokens, i, call, types)
src = f'class {tokens[i].src}({_unparse(call.func)}):\n{attrs}'
tokens[i:end] = [Token('CODE', src)]
elif token.offset in visitor.kw_typed_dicts and token.name == 'NAME':
call = visitor.kw_typed_dicts[token.offset]
types = {
arg.arg: arg.value # type: ignore # (checked above)
for arg in call.keywords
}
end, attrs = _typed_class_replacement(tokens, i, call, types)
src = f'class {tokens[i].src}({_unparse(call.func)}):\n{attrs}'
tokens[i:end] = [Token('CODE', src)]
elif token.offset in visitor.dict_typed_dicts and token.name == 'NAME':
call = visitor.dict_typed_dicts[token.offset]
types = {
k.s: v
for k, v in zip(
call.args[1].keys, # type: ignore # (checked above)
call.args[1].values, # type: ignore # (checked above)
)
}
if call.keywords:
total = call.keywords[0].value.value # type: ignore # (checked above) # noqa: E501
end, attrs = _typed_class_replacement(tokens, i, call, types)
src = (
f'class {tokens[i].src}('
f'{_unparse(call.func)}, total={total}'
f'):\n'
f'{attrs}'
)
tokens[i:end] = [Token('CODE', src)]
else:
end, attrs = _typed_class_replacement(tokens, i, call, types)
src = f'class {tokens[i].src}({_unparse(call.func)}):\n{attrs}'
tokens[i:end] = [Token('CODE', src)]

return tokens_to_src(tokens)

Expand Down
Loading

0 comments on commit 087c7e6

Please sign in to comment.