Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[tool]: add all imported modules to -f annotated_ast output #4209

Merged
merged 14 commits into from
Oct 10, 2024
Merged
58 changes: 54 additions & 4 deletions tests/unit/ast/test_ast_dict.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import json

from vyper import compiler
Expand Down Expand Up @@ -216,24 +217,27 @@ def foo():
input_bundle = make_input_bundle({"lib1.vy": lib1, "main.vy": main})

lib1_file = input_bundle.load_file("lib1.vy")
out = compiler.compile_from_file_input(
lib1_out = compiler.compile_from_file_input(
lib1_file, input_bundle=input_bundle, output_formats=["annotated_ast_dict"]
)
lib1_ast = out["annotated_ast_dict"]["ast"]

lib1_ast = copy.deepcopy(lib1_out["annotated_ast_dict"]["ast"])
lib1_sha256sum = lib1_ast.pop("source_sha256sum")
assert lib1_sha256sum == lib1_file.sha256sum
to_strip = NODE_SRC_ATTRIBUTES + ("resolved_path", "variable_reads", "variable_writes")
_strip_source_annotations(lib1_ast, to_strip=to_strip)

main_file = input_bundle.load_file("main.vy")
out = compiler.compile_from_file_input(
main_out = compiler.compile_from_file_input(
main_file, input_bundle=input_bundle, output_formats=["annotated_ast_dict"]
)
main_ast = out["annotated_ast_dict"]["ast"]
main_ast = main_out["annotated_ast_dict"]["ast"]
main_sha256sum = main_ast.pop("source_sha256sum")
assert main_sha256sum == main_file.sha256sum
_strip_source_annotations(main_ast, to_strip=to_strip)

assert main_out["annotated_ast_dict"]["imports"][0] == lib1_out["annotated_ast_dict"]["ast"]

# TODO: would be nice to refactor this into bunch of small test cases
assert main_ast == {
"ast_type": "Module",
Expand Down Expand Up @@ -1776,3 +1780,49 @@ def qux2():
},
}
]


def test_annotated_ast_export_recursion(make_input_bundle):
sources = {
"main.vy": """
import lib1

@external
def foo():
lib1.foo()
""",
"lib1.vy": """
import lib2

def foo():
lib2.foo()
""",
"lib2.vy": """
def foo():
pass
""",
}

input_bundle = make_input_bundle(sources)

def compile_and_get_ast(file_name):
file = input_bundle.load_file(file_name)
output = compiler.compile_from_file_input(
file, input_bundle=input_bundle, output_formats=["annotated_ast_dict"]
)
return output["annotated_ast_dict"]

lib1_ast = compile_and_get_ast("lib1.vy")["ast"]
lib2_ast = compile_and_get_ast("lib2.vy")["ast"]
main_out = compile_and_get_ast("main.vy")

lib1_import_ast = main_out["imports"][1]
lib2_import_ast = main_out["imports"][0]

# path is once virtual, once libX.vy
# type contains name which is based on path
keys = [s for s in lib1_import_ast.keys() if s not in {"path", "type"}]

for key in keys:
assert lib1_ast[key] == lib1_import_ast[key]
assert lib2_ast[key] == lib2_import_ast[key]
6 changes: 6 additions & 0 deletions vyper/ast/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -949,6 +949,12 @@ def validate(self):
class Ellipsis(Constant):
__slots__ = ()

def to_dict(self):
ast_dict = super().to_dict()
# python ast ellipsis() is not json serializable; use a string
ast_dict["value"] = self.node_source_code
return ast_dict


class Dict(ExprNode):
__slots__ = ("keys", "values")
Expand Down
1 change: 1 addition & 0 deletions vyper/ast/nodes.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class TopLevel(VyperNode):
class Module(TopLevel):
path: str = ...
resolved_path: str = ...
source_id: int = ...
def namespace(self) -> Any: ... # context manager

class FunctionDef(TopLevel):
Expand Down
28 changes: 27 additions & 1 deletion vyper/compiler/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,18 @@
from collections import deque
from pathlib import PurePath

from vyper.ast import ast_to_dict
import vyper.ast as vy_ast
from vyper.ast.utils import ast_to_dict
from vyper.codegen.ir_node import IRnode
from vyper.compiler.output_bundle import SolcJSONWriter, VyperArchiveWriter
from vyper.compiler.phases import CompilerData
from vyper.compiler.utils import build_gas_estimates
from vyper.evm import opcodes
from vyper.exceptions import VyperException
from vyper.ir import compile_ir
from vyper.semantics.analysis.base import ModuleInfo
from vyper.semantics.types.function import FunctionVisibility, StateMutability
from vyper.semantics.types.module import InterfaceT
from vyper.typing import StorageLayout
from vyper.utils import vyper_warn
from vyper.warnings import ContractSizeLimitWarning
Expand All @@ -26,9 +29,32 @@ def build_ast_dict(compiler_data: CompilerData) -> dict:


def build_annotated_ast_dict(compiler_data: CompilerData) -> dict:
module_t = compiler_data.annotated_vyper_module._metadata["type"]
# get all reachable imports including recursion
imported_module_infos = module_t.reachable_imports
unique_modules: dict[str, vy_ast.Module] = {}
for info in imported_module_infos:
if isinstance(info.typ, InterfaceT):
ast = info.typ.decl_node
if ast is None: # json abi
continue
else:
assert isinstance(info.typ, ModuleInfo)
ast = info.typ.module_t._module

assert isinstance(ast, vy_ast.Module) # help mypy
# use resolved_path for uniqueness, since Module objects can actually
# come from multiple InputBundles (particularly builtin interfaces),
# so source_id is not guaranteed to be unique.
if ast.resolved_path in unique_modules:
# sanity check -- objects must be identical
assert unique_modules[ast.resolved_path] is ast
unique_modules[ast.resolved_path] = ast

annotated_ast_dict = {
"contract_name": str(compiler_data.contract_path),
"ast": ast_to_dict(compiler_data.annotated_vyper_module),
"imports": [ast_to_dict(ast) for ast in unique_modules.values()],
}
return annotated_ast_dict

Expand Down
Loading