Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor[ux]: refactor preparser #4293

Merged
merged 11 commits into from
Nov 19, 2024
7 changes: 4 additions & 3 deletions tests/functional/grammar/test_grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from vyper.ast import Module, parse_to_ast
from vyper.ast.grammar import parse_vyper_source, vyper_grammar
from vyper.ast.pre_parser import pre_parse
from vyper.ast.pre_parser import PreParser


def test_basic_grammar():
Expand Down Expand Up @@ -102,6 +102,7 @@ def has_no_docstrings(c):
max_examples=500, suppress_health_check=[HealthCheck.too_slow, HealthCheck.filter_too_much]
)
def test_grammar_bruteforce(code):
pre_parse_result = pre_parse(code + "\n")
tree = parse_to_ast(pre_parse_result.reformatted_code)
pre_parser = PreParser()
pre_parser.parse(code + "\n")
tree = parse_to_ast(pre_parser.reformatted_code)
assert isinstance(tree, Module)
11 changes: 6 additions & 5 deletions tests/unit/ast/test_annotate_and_optimize_ast.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import ast as python_ast

from vyper.ast.parse import annotate_python_ast, pre_parse
from vyper.ast.parse import PreParser, annotate_python_ast


class AssertionVisitor(python_ast.NodeVisitor):
Expand Down Expand Up @@ -28,12 +28,13 @@ def foo() -> int128:


def get_contract_info(source_code):
pre_parse_result = pre_parse(source_code)
py_ast = python_ast.parse(pre_parse_result.reformatted_code)
pre_parser = PreParser()
pre_parser.parse(source_code)
py_ast = python_ast.parse(pre_parser.reformatted_code)

annotate_python_ast(py_ast, pre_parse_result.reformatted_code, pre_parse_result)
annotate_python_ast(py_ast, pre_parser.reformatted_code, pre_parser)

return py_ast, pre_parse_result.reformatted_code
return py_ast, pre_parser.reformatted_code


def test_it_annotates_ast_with_source_code():
Expand Down
14 changes: 8 additions & 6 deletions tests/unit/ast/test_pre_parser.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest

from vyper import compile_code
from vyper.ast.pre_parser import pre_parse, validate_version_pragma
from vyper.ast.pre_parser import PreParser, validate_version_pragma
from vyper.compiler.phases import CompilerData
from vyper.compiler.settings import OptimizationLevel, Settings
from vyper.exceptions import StructureException, VersionException
Expand Down Expand Up @@ -174,9 +174,10 @@ def test_prerelease_invalid_version_pragma(file_version, mock_version):
@pytest.mark.parametrize("code, pre_parse_settings, compiler_data_settings", pragma_examples)
def test_parse_pragmas(code, pre_parse_settings, compiler_data_settings, mock_version):
mock_version("0.3.10")
pre_parse_result = pre_parse(code)
pre_parser = PreParser()
pre_parser.parse(code)

assert pre_parse_result.settings == pre_parse_settings
assert pre_parser.settings == pre_parse_settings

compiler_data = CompilerData(code)

Expand All @@ -203,8 +204,9 @@ def test_parse_pragmas(code, pre_parse_settings, compiler_data_settings, mock_ve

@pytest.mark.parametrize("code", pragma_venom)
def test_parse_venom_pragma(code):
pre_parse_result = pre_parse(code)
assert pre_parse_result.settings.experimental_codegen is True
pre_parser = PreParser()
pre_parser.parse(code)
assert pre_parser.settings.experimental_codegen is True

compiler_data = CompilerData(code)
assert compiler_data.settings.experimental_codegen is True
Expand Down Expand Up @@ -252,7 +254,7 @@ def test_parse_venom_pragma(code):
@pytest.mark.parametrize("code", invalid_pragmas)
def test_invalid_pragma(code):
with pytest.raises(StructureException):
pre_parse(code)
PreParser().parse(code)


def test_version_exception_in_import(make_input_bundle):
Expand Down
39 changes: 20 additions & 19 deletions vyper/ast/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import asttokens

from vyper.ast import nodes as vy_ast
from vyper.ast.pre_parser import PreParseResult, pre_parse
from vyper.ast.pre_parser import PreParser
from vyper.compiler.settings import Settings
from vyper.exceptions import CompilerPanic, ParserException, SyntaxException
from vyper.utils import sha256sum, vyper_warn
Expand Down Expand Up @@ -54,9 +54,10 @@ def parse_to_ast_with_settings(
"""
if "\x00" in vyper_source:
raise ParserException("No null bytes (\\x00) allowed in the source code.")
pre_parse_result = pre_parse(vyper_source)
pre_parser = PreParser()
pre_parser.parse(vyper_source)
try:
py_ast = python_ast.parse(pre_parse_result.reformatted_code)
py_ast = python_ast.parse(pre_parser.reformatted_code)
except SyntaxError as e:
# TODO: Ensure 1-to-1 match of source_code:reformatted_code SyntaxErrors
raise SyntaxException(str(e), vyper_source, e.lineno, e.offset) from None
Expand All @@ -72,20 +73,20 @@ def parse_to_ast_with_settings(
annotate_python_ast(
py_ast,
vyper_source,
pre_parse_result,
pre_parser,
source_id=source_id,
module_path=module_path,
resolved_path=resolved_path,
)

# postcondition: consumed all the for loop annotations
assert len(pre_parse_result.for_loop_annotations) == 0
assert len(pre_parser.for_loop_annotations) == 0

# Convert to Vyper AST.
module = vy_ast.get_node(py_ast)
assert isinstance(module, vy_ast.Module) # mypy hint

return pre_parse_result.settings, module
return pre_parser.settings, module


def ast_to_dict(ast_struct: Union[vy_ast.VyperNode, List]) -> Union[Dict, List]:
Expand Down Expand Up @@ -116,7 +117,7 @@ def dict_to_ast(ast_struct: Union[Dict, List]) -> Union[vy_ast.VyperNode, List]:
def annotate_python_ast(
parsed_ast: python_ast.AST,
vyper_source: str,
pre_parse_result: PreParseResult,
pre_parser: PreParser,
source_id: int = 0,
module_path: Optional[str] = None,
resolved_path: Optional[str] = None,
Expand All @@ -130,8 +131,8 @@ def annotate_python_ast(
The AST to be annotated and optimized.
vyper_source: str
The original vyper source code
pre_parse_result: PreParseResult
Outputs from pre-parsing.
pre_parser: PreParser
PreParser object.

Returns
-------
Expand All @@ -142,7 +143,7 @@ def annotate_python_ast(
tokens.mark_tokens(parsed_ast)
visitor = AnnotatingVisitor(
vyper_source,
pre_parse_result,
pre_parser,
tokens,
source_id,
module_path=module_path,
Expand All @@ -155,12 +156,12 @@ def annotate_python_ast(

class AnnotatingVisitor(python_ast.NodeTransformer):
_source_code: str
_pre_parse_result: PreParseResult
_pre_parser: PreParser

def __init__(
self,
source_code: str,
pre_parse_result: PreParseResult,
pre_parser: PreParser,
tokens: asttokens.ASTTokens,
source_id: int,
module_path: Optional[str] = None,
Expand All @@ -171,7 +172,7 @@ def __init__(
self._module_path = module_path
self._resolved_path = resolved_path
self._source_code = source_code
self._pre_parse_result = pre_parse_result
self._pre_parser = pre_parser

self.counter: int = 0

Expand Down Expand Up @@ -265,7 +266,7 @@ def visit_ClassDef(self, node):
"""
self.generic_visit(node)

node.ast_type = self._pre_parse_result.modification_offsets[(node.lineno, node.col_offset)]
node.ast_type = self._pre_parser.modification_offsets[(node.lineno, node.col_offset)]
return node

def visit_For(self, node):
Expand All @@ -274,7 +275,7 @@ def visit_For(self, node):
the pre-parser
"""
key = (node.lineno, node.col_offset)
annotation_tokens = self._pre_parse_result.for_loop_annotations.pop(key)
annotation_tokens = self._pre_parser.for_loop_annotations.pop(key)

if not annotation_tokens:
# a common case for people migrating to 0.4.0, provide a more
Expand Down Expand Up @@ -342,14 +343,14 @@ def visit_Expr(self, node):
# CMC 2024-03-03 consider unremoving this from the enclosing Expr
node = node.value
key = (node.lineno, node.col_offset)
node.ast_type = self._pre_parse_result.modification_offsets[key]
node.ast_type = self._pre_parser.modification_offsets[key]

return node

def visit_Await(self, node):
start_pos = node.lineno, node.col_offset # grab these before generic_visit modifies them
self.generic_visit(node)
node.ast_type = self._pre_parse_result.modification_offsets[start_pos]
node.ast_type = self._pre_parser.modification_offsets[start_pos]
return node

def visit_Call(self, node):
Expand Down Expand Up @@ -394,10 +395,10 @@ def visit_Constant(self, node):
node.ast_type = "NameConstant"
elif isinstance(node.value, str):
key = (node.lineno, node.col_offset)
if key in self._pre_parse_result.native_hex_literal_locations:
if key in self._pre_parser.hex_string_locations:
if len(node.value) % 2 != 0:
raise SyntaxException(
"Native hex string must have an even number of characters",
"Hex string must have an even number of characters",
self._source_code,
node.lineno,
node.col_offset,
Expand Down
112 changes: 46 additions & 66 deletions vyper/ast/pre_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,67 +158,52 @@ def consume(self, token, result):
CUSTOM_EXPRESSION_TYPES = {"extcall": "ExtCall", "staticcall": "StaticCall"}


class PreParseResult:
class PreParser:
# Compilation settings based on the directives in the source code
settings: Settings
# A mapping of class names to their original class types.
modification_offsets: dict[tuple[int, int], str]
# A mapping of line/column offsets of `For` nodes to the annotation of the for loop target
for_loop_annotations: dict[tuple[int, int], list[TokenInfo]]
# A list of line/column offsets of native hex literals
native_hex_literal_locations: list[tuple[int, int]]
# A list of line/column offsets of hex string literals
hex_string_locations: list[tuple[int, int]]
charles-cooper marked this conversation as resolved.
Show resolved Hide resolved
# Reformatted python source string.
reformatted_code: str

def __init__(
self,
settings,
modification_offsets,
for_loop_annotations,
native_hex_literal_locations,
reformatted_code,
):
self.settings = settings
self.modification_offsets = modification_offsets
self.for_loop_annotations = for_loop_annotations
self.native_hex_literal_locations = native_hex_literal_locations
self.reformatted_code = reformatted_code


def pre_parse(code: str) -> PreParseResult:
"""
Re-formats a vyper source string into a python source string and performs
some validation. More specifically,

* Translates "interface", "struct", "flag", and "event" keywords into python "class" keyword
* Validates "@version" pragma against current compiler version
* Prevents direct use of python "class" keyword
* Prevents use of python semi-colon statement separator
* Extracts type annotation of for loop iterators into a separate dictionary

Also returns a mapping of detected interface and struct names to their
respective vyper class types ("interface" or "struct"), and a mapping of line numbers
of for loops to the type annotation of their iterators.

Parameters
----------
code : str
The vyper source code to be re-formatted.

Returns
-------
PreParseResult
Outputs for transforming the python AST to vyper AST
"""
result: list[TokenInfo] = []
modification_offsets: dict[tuple[int, int], str] = {}
settings = Settings()
for_parser = ForParser(code)
native_hex_parser = HexStringParser()
def parse(self, code: str):
charles-cooper marked this conversation as resolved.
Show resolved Hide resolved
"""
Re-formats a vyper source string into a python source string and performs
some validation. More specifically,

* Translates "interface", "struct", "flag", and "event" keywords into python "class" keyword
* Validates "@version" pragma against current compiler version
* Prevents direct use of python "class" keyword
* Prevents use of python semi-colon statement separator
* Extracts type annotation of for loop iterators into a separate dictionary

Stores a mapping of detected interface and struct names to their
respective vyper class types ("interface" or "struct"), and a mapping of line numbers
of for loops to the type annotation of their iterators.

Parameters
----------
code : str
The vyper source code to be re-formatted.
"""
try:
self._parse(code)
except TokenError as e:
raise SyntaxException(e.args[0], code, e.args[1][0], e.args[1][1]) from e

def _parse(self, code: str):
result: list[TokenInfo] = []
modification_offsets: dict[tuple[int, int], str] = {}
settings = Settings()
for_parser = ForParser(code)
hex_string_parser = HexStringParser()
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Renamed from native_hex_parser


_col_adjustments: dict[int, int] = defaultdict(lambda: 0)

_col_adjustments: dict[int, int] = defaultdict(lambda: 0)

try:
code_bytes = code.encode("utf-8")
token_list = list(tokenize(io.BytesIO(code_bytes).readline))

Expand Down Expand Up @@ -301,7 +286,7 @@ def pre_parse(code: str) -> PreParseResult:
# a bit cursed technique to get untokenize to put
# the new tokens in the right place so that modification_offsets
# will work correctly.
# (recommend comparing the result of pre_parse with the
# (recommend comparing the result of parse with the
# source code side by side to visualize the whitespace)
new_keyword = "await"
vyper_type = CUSTOM_EXPRESSION_TYPES[string]
Expand All @@ -322,20 +307,15 @@ def pre_parse(code: str) -> PreParseResult:
if (typ, string) == (OP, ";"):
raise SyntaxException("Semi-colon statements not allowed", code, start[0], start[1])

if not for_parser.consume(token) and not native_hex_parser.consume(token, result):
if not for_parser.consume(token) and not hex_string_parser.consume(token, result):
result.extend(toks)

except TokenError as e:
raise SyntaxException(e.args[0], code, e.args[1][0], e.args[1][1]) from e

for_loop_annotations = {}
for k, v in for_parser.annotations.items():
for_loop_annotations[k] = v.copy()
for_loop_annotations = {}
for k, v in for_parser.annotations.items():
for_loop_annotations[k] = v.copy()

return PreParseResult(
settings,
modification_offsets,
for_loop_annotations,
native_hex_parser.locations,
untokenize(result).decode("utf-8"),
)
self.settings = settings
self.modification_offsets = modification_offsets
self.for_loop_annotations = for_loop_annotations
self.hex_string_locations = hex_string_parser.locations
self.reformatted_code = untokenize(result).decode("utf-8")
Loading