Skip to content

Commit

Permalink
refactor[parser]: remove ASTTokens (#4364)
Browse files Browse the repository at this point in the history
this commit removes `asttokens` from the parse machinery, since the
method is buggy (see below bugs) and slow. this commit brings down
parse time (time spent in ast generation) between 40-70%.

the `mark_tokens()` machinery is replaced with a modified version of
the Python `ast` module's `fix_missing_locations()` function, which
recurses through the AST and adds missing line info based on the
parent node.

it also changes to a more consistent method for updating source
offsets that are modified by the `pre_parse` step, which fixes several
outstanding bugs with source location reporting.

there were some exceptions to the line info fixup working, the issues
and corresponding workarounds are described as follows:

- some python AST nodes returned by `ast.parse()` are singletons, which
  we work around by deepcopying the AST before operating on it.

- notably, there is an interaction between our AST annotation and
  `coverage.py` in the case of `USub`. in this commit we paper over the
  issue by simply always overriding line info for `USub` nodes. in the
  future, we should refactor `VyperNode` generation by bypassing the
  python AST annotation step entirely, which is a more proper fix to the
  problems encountered in this PR.

the `asttokens` package is not removed entirely since it still has a
limited usage inside of the natspec parser. we could remove it in a
future PR; for now it is out-of-scope.

referenced bugs:
- #2258
- #3059
- #3430
- #4139
  • Loading branch information
charles-cooper authored Jan 12, 2025
1 parent 10e91d5 commit db8dcc7
Show file tree
Hide file tree
Showing 5 changed files with 236 additions and 89 deletions.
94 changes: 94 additions & 0 deletions tests/unit/ast/test_tokenizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""
Tests that the tokenizer / parser are passing correct source location
info to the AST
"""
import pytest

from vyper.ast.parse import parse_to_ast
from vyper.compiler import compile_code
from vyper.exceptions import UndeclaredDefinition


def test_log_token_aligned():
    # GH issue 3430: the error marker must point at the correct column of
    # the offending name inside a `log` statement.
    code = """
event A:
    b: uint256

@external
def f():
    log A(b=d)
    """
    with pytest.raises(UndeclaredDefinition) as e:
        compile_code(code)

    # NOTE(review): whitespace inside this expected string is significant --
    # it must match the compiler's error formatting exactly (alignment
    # reconstructed from the "line 7:12" position; confirm against output).
    expected = """
'd' has not been declared.
  function "f", line 7:12 
       6 def f():
  ---> 7     log A(b=d)
  -------------------^
       8
    """  # noqa: W291
    assert expected.strip() == str(e.value).strip()


def test_log_token_aligned2():
    # GH issue 3059: log of an event whose argument is an attribute access
    # (`c.address`) should compile without source-location errors.
    code = """
interface Contract:
    def foo(): nonpayable

event MyEvent:
    a: address

@external
def foo(c: Contract):
    log MyEvent(a=c.address)
    """
    compile_code(code)


def test_log_token_aligned3():
    # regression from code review:
    # https://github.com/vyperlang/vyper/pull/3808#pullrequestreview-1900570163
    code = """
import ITest

implements: ITest

event Foo:
    a: address

@external
def foo(u: uint256):
    log Foo(empty(address))
    log i.Foo(empty(address))
    """
    # not semantically valid code, check we can at least parse it
    assert parse_to_ast(code) is not None


def test_log_token_aligned4():
    # GH issue 4139: log statement with a call expression argument used to
    # produce misaligned source locations.
    code = """
b: public(uint256)

event Transfer:
    random: indexed(uint256)
    shi: uint256

@external
def transfer():
    log Transfer(T(self).b(), 10)
    return
    """
    # not semantically valid code, check we can at least parse it
    assert parse_to_ast(code) is not None


def test_long_string_non_coding_token():
    # GH issue 2258: a `\r` line ending plus a long run of quote characters
    # (a degenerate docstring) used to break source-location tracking.
    code = '\r[[]]\ndef _(e:[],l:[]):\n """"""""""""""""""""""""""""""""""""""""""""""""""""""\n f.n()'  # noqa: E501
    # not valid code, but should at least parse
    assert parse_to_ast(code) is not None
1 change: 1 addition & 0 deletions vyper/ast/natspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from dataclasses import dataclass
from typing import Optional, Tuple

# NOTE: this is our only use of asttokens -- consider vendoring in the implementation.
from asttokens import LineNumbers

from vyper.ast import nodes as vy_ast
Expand Down
165 changes: 104 additions & 61 deletions vyper/ast/parse.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import ast as python_ast
import pickle
import tokenize
from decimal import Decimal
from functools import cached_property
from typing import Any, Dict, List, Optional, Union

import asttokens

from vyper.ast import nodes as vy_ast
from vyper.ast.pre_parser import PreParser
from vyper.compiler.settings import Settings
Expand Down Expand Up @@ -80,12 +80,16 @@ def _parse_to_ast_with_settings(
try:
py_ast = python_ast.parse(pre_parser.reformatted_code)
except SyntaxError as e:
# TODO: Ensure 1-to-1 match of source_code:reformatted_code SyntaxErrors
offset = e.offset
if offset is not None:
# SyntaxError offset is 1-based, not 0-based (see:
# https://docs.python.org/3/library/exceptions.html#SyntaxError.offset)
offset -= 1

# adjust the column of the error if it was modified by the pre-parser
if e.lineno is not None: # help mypy
offset += pre_parser.adjustments.get((e.lineno, offset), 0)

new_e = SyntaxException(str(e), vyper_source, e.lineno, offset)

likely_errors = ("staticall", "staticcal")
Expand All @@ -97,6 +101,11 @@ def _parse_to_ast_with_settings(

raise new_e from None

# some python AST node instances are singletons and are reused between
# parse() invocations. copy the python AST so that we are using fresh
# objects.
py_ast = _deepcopy_ast(py_ast)

# Add dummy function node to ensure local variables are treated as `AnnAssign`
# instead of state variables (`VariableDecl`)
if add_fn_node:
Expand Down Expand Up @@ -129,6 +138,9 @@ def _parse_to_ast_with_settings(
return pre_parser.settings, module


# the AST attributes that together describe a node's source location span
LINE_INFO_FIELDS = ("lineno", "col_offset", "end_lineno", "end_col_offset")


def ast_to_dict(ast_struct: Union[vy_ast.VyperNode, List]) -> Union[Dict, List]:
"""
Converts a Vyper AST node, or list of nodes, into a dictionary suitable for
Expand All @@ -155,7 +167,7 @@ def dict_to_ast(ast_struct: Union[Dict, List]) -> Union[vy_ast.VyperNode, List]:


def annotate_python_ast(
parsed_ast: python_ast.AST,
parsed_ast: python_ast.Module,
vyper_source: str,
pre_parser: PreParser,
source_id: int = 0,
Expand All @@ -178,22 +190,19 @@ def annotate_python_ast(
-------
The annotated and optimized AST.
"""
tokens = asttokens.ASTTokens(vyper_source)
assert isinstance(parsed_ast, python_ast.Module) # help mypy
tokens.mark_tokens(parsed_ast)
visitor = AnnotatingVisitor(
vyper_source,
pre_parser,
tokens,
source_id,
module_path=module_path,
resolved_path=resolved_path,
vyper_source, pre_parser, source_id, module_path=module_path, resolved_path=resolved_path
)
visitor.visit(parsed_ast)
visitor.start(parsed_ast)

return parsed_ast


def _deepcopy_ast(ast_node: python_ast.AST):
# pickle roundtrip is faster than copy.deepcopy() here.
return pickle.loads(pickle.dumps(ast_node))


class AnnotatingVisitor(python_ast.NodeTransformer):
_source_code: str
_pre_parser: PreParser
Expand All @@ -202,12 +211,10 @@ def __init__(
self,
source_code: str,
pre_parser: PreParser,
tokens: asttokens.ASTTokens,
source_id: int,
module_path: Optional[str] = None,
resolved_path: Optional[str] = None,
):
self._tokens = tokens
self._source_id = source_id
self._module_path = module_path
self._resolved_path = resolved_path
Expand All @@ -216,45 +223,87 @@ def __init__(

self.counter: int = 0

    @cached_property
    def source_lines(self):
        # source split into lines with line endings preserved, so that
        # absolute character offsets can be recovered by summing line lengths
        return self._source_code.splitlines(keepends=True)

    @cached_property
    def line_offsets(self):
        # map from 1-indexed line number to the character offset (into
        # self._source_code) of the start of that line
        ofst = 0
        # ensure line_offsets has at least 1 entry for 0-line source
        ret = {1: ofst}
        for lineno, line in enumerate(self.source_lines):
            ret[lineno + 1] = ofst
            ofst += len(line)
        return ret

    def start(self, node: python_ast.Module):
        # entry point: first backfill missing location info on every node,
        # then run the annotating/transforming visitor over the tree
        self._fix_missing_locations(node)
        self.visit(node)

    def _fix_missing_locations(self, ast_node: python_ast.Module):
        """
        adapted from cpython Lib/ast.py. adds line/col info to ast,
        but unlike Lib/ast.py, adjusts *all* ast nodes, not just the
        one that python defines to have line/col info.
        https://github.com/python/cpython/blob/62729d79206014886f5d/Lib/ast.py#L228
        """
        assert isinstance(ast_node, python_ast.Module)
        # the Module node spans the whole source file
        ast_node.lineno = 1
        ast_node.col_offset = 0
        # max(1, ...) keeps end_lineno valid even for empty source
        ast_node.end_lineno = max(1, len(self.source_lines))

        if len(self.source_lines) > 0:
            ast_node.end_col_offset = len(self.source_lines[-1])
        else:
            ast_node.end_col_offset = 0

        def _fix(node, parent=None):
            # inherit any missing location field from the parent node
            for field in LINE_INFO_FIELDS:
                if parent is not None:
                    val = getattr(node, field, None)
                    # special case for USub - heisenbug when coverage is
                    # enabled in the test suite.
                    if val is None or isinstance(node, python_ast.USub):
                        val = getattr(parent, field)
                    setattr(node, field, val)
                else:
                    # the root must already carry all fields (set above)
                    assert hasattr(node, field), node

            for child in python_ast.iter_child_nodes(node):
                _fix(child, node)

        _fix(ast_node)

def generic_visit(self, node):
"""
Annotate a node with information that simplifies Vyper node generation.
"""
# Decorate every node with the original source code to allow pretty-printing errors
node.full_source_code = self._source_code
node.node_id = self.counter
node.ast_type = node.__class__.__name__
self.counter += 1
node.ast_type = node.__class__.__name__

# Decorate every node with source end offsets
start = (None, None)
if hasattr(node, "first_token"):
start = node.first_token.start
end = (None, None)
if hasattr(node, "last_token"):
end = node.last_token.end
if node.last_token.type == 4:
# token type 4 is a `\n`, some nodes include a trailing newline
# here we ignore it when building the node offsets
end = (end[0], end[1] - 1)

node.lineno = start[0]
node.col_offset = start[1]
node.end_lineno = end[0]
node.end_col_offset = end[1]

# TODO: adjust end_lineno and end_col_offset when this node is in
# modification_offsets

if hasattr(node, "last_token"):
start_pos = node.first_token.startpos
end_pos = node.last_token.endpos

if node.last_token.type == 4:
# ignore trailing newline once more
end_pos -= 1
node.src = f"{start_pos}:{end_pos-start_pos}:{self._source_id}"
node.node_source_code = self._source_code[start_pos:end_pos]
adjustments = self._pre_parser.adjustments

# Load and Store behave differently inside of fix_missing_locations;
# we don't use them in the vyper AST so just skip adjusting the line
# info.
if isinstance(node, (python_ast.Load, python_ast.Store)):
return super().generic_visit(node)

adj = adjustments.get((node.lineno, node.col_offset), 0)
node.col_offset += adj

adj = adjustments.get((node.end_lineno, node.end_col_offset), 0)
node.end_col_offset += adj

start_pos = self.line_offsets[node.lineno] + node.col_offset
end_pos = self.line_offsets[node.end_lineno] + node.end_col_offset

node.src = f"{start_pos}:{end_pos-start_pos}:{self._source_id}"
node.node_source_code = self._source_code[start_pos:end_pos]

return super().generic_visit(node)

Expand Down Expand Up @@ -288,12 +337,6 @@ def visit_Module(self, node):
return self._visit_docstring(node)

def visit_FunctionDef(self, node):
if node.decorator_list:
# start the source highlight at `def` to improve annotation readability
decorator_token = node.decorator_list[-1].last_token
def_token = self._tokens.find_token(decorator_token, tokenize.NAME, tok_str="def")
node.first_token = def_token

return self._visit_docstring(node)

def visit_ClassDef(self, node):
Expand All @@ -306,7 +349,7 @@ def visit_ClassDef(self, node):
"""
self.generic_visit(node)

node.ast_type = self._pre_parser.modification_offsets[(node.lineno, node.col_offset)]
node.ast_type = self._pre_parser.keyword_translations[(node.lineno, node.col_offset)]
return node

def visit_For(self, node):
Expand Down Expand Up @@ -349,16 +392,13 @@ def visit_For(self, node):

try:
fake_node = python_ast.parse(annotation_str).body[0]
# do we need to fix location info here?
fake_node = _deepcopy_ast(fake_node)
except SyntaxError as e:
raise SyntaxException(
"invalid type annotation", self._source_code, node.lineno, node.col_offset
) from e

# fill in with asttokens info. note we can use `self._tokens` because
# it is indented to exactly the same position where it appeared
# in the original source!
self._tokens.mark_tokens(fake_node)

# replace the dummy target name with the real target name.
fake_node.target = node.target
# replace the For node target with the new ann_assign
Expand All @@ -383,14 +423,14 @@ def visit_Expr(self, node):
# CMC 2024-03-03 consider unremoving this from the enclosing Expr
node = node.value
key = (node.lineno, node.col_offset)
node.ast_type = self._pre_parser.modification_offsets[key]
node.ast_type = self._pre_parser.keyword_translations[key]

return node

def visit_Await(self, node):
start_pos = node.lineno, node.col_offset # grab these before generic_visit modifies them
start_pos = node.lineno, node.col_offset
self.generic_visit(node)
node.ast_type = self._pre_parser.modification_offsets[start_pos]
node.ast_type = self._pre_parser.keyword_translations[start_pos]
return node

def visit_Call(self, node):
Expand All @@ -410,6 +450,9 @@ def visit_Call(self, node):
assert len(dict_.keys) == len(dict_.values)
for key, value in zip(dict_.keys, dict_.values):
replacement_kw_node = python_ast.keyword(key.id, value)
# set locations
for attr in LINE_INFO_FIELDS:
setattr(replacement_kw_node, attr, getattr(key, attr))
kw_list.append(replacement_kw_node)

node.args = []
Expand Down
Loading

0 comments on commit db8dcc7

Please sign in to comment.