diff --git a/tests/functional/syntax/modules/test_initializers.py b/tests/functional/syntax/modules/test_initializers.py index 162bd46199..8f18e18a26 100644 --- a/tests/functional/syntax/modules/test_initializers.py +++ b/tests/functional/syntax/modules/test_initializers.py @@ -1208,3 +1208,23 @@ def test_ownership_decl_errors_not_swallowed(make_input_bundle): with pytest.raises(UndeclaredDefinition) as e: compile_code(main, input_bundle=input_bundle) assert e.value._message == "'lib2' has not been declared." + + +def test_partial_compilation(make_input_bundle): + lib1 = """ +counter: uint256 + """ + main = """ +import lib1 + +uses: lib1 + +@internal +def use_lib1(): + lib1.counter += 1 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + assert ( + compile_code(main, input_bundle=input_bundle, output_formats=["annotated_ast_dict"]) + is not None + ) diff --git a/tests/unit/ast/nodes/test_hex.py b/tests/unit/ast/nodes/test_hex.py index ce86eb84ab..1b61764d57 100644 --- a/tests/unit/ast/nodes/test_hex.py +++ b/tests/unit/ast/nodes/test_hex.py @@ -40,4 +40,4 @@ def foo(): def test_invalid_checksum(code, dummy_input_bundle): with pytest.raises(InvalidLiteral): vyper_module = vy_ast.parse_to_ast(code) - semantics.validate_semantics(vyper_module, dummy_input_bundle) + semantics.analyze_module(vyper_module, dummy_input_bundle) diff --git a/tests/unit/semantics/analysis/test_array_index.py b/tests/unit/semantics/analysis/test_array_index.py index 5487a47d97..b5bf86494d 100644 --- a/tests/unit/semantics/analysis/test_array_index.py +++ b/tests/unit/semantics/analysis/test_array_index.py @@ -7,7 +7,7 @@ TypeMismatch, UndeclaredDefinition, ) -from vyper.semantics.analysis import validate_semantics +from vyper.semantics.analysis import analyze_module @pytest.mark.parametrize("value", ["address", "Bytes[10]", "decimal", "bool"]) @@ -22,7 +22,7 @@ def foo(b: {value}): """ vyper_module = parse_to_ast(code) with pytest.raises(TypeMismatch): - validate_semantics(vyper_module, dummy_input_bundle) + analyze_module(vyper_module, dummy_input_bundle) @pytest.mark.parametrize("value", ["1.0", "0.0", "'foo'", "0x00", "b'\x01'", "False"]) @@ -37,7 +37,7 @@ def foo(): """ vyper_module = parse_to_ast(code) with pytest.raises(TypeMismatch): - validate_semantics(vyper_module, dummy_input_bundle) + analyze_module(vyper_module, dummy_input_bundle) @pytest.mark.parametrize("value", [-1, 3, -(2**127), 2**127 - 1, 2**256 - 1]) @@ -52,7 +52,7 @@ def foo(): """ vyper_module = parse_to_ast(code) with pytest.raises(ArrayIndexException): - validate_semantics(vyper_module, dummy_input_bundle) + analyze_module(vyper_module, dummy_input_bundle) @pytest.mark.parametrize("value", ["b", "self.b"]) @@ -67,7 +67,7 @@ def foo(): """ vyper_module = parse_to_ast(code) with pytest.raises(UndeclaredDefinition): - validate_semantics(vyper_module, dummy_input_bundle) + analyze_module(vyper_module, dummy_input_bundle) @pytest.mark.parametrize("value", ["a", "foo", "int128"]) @@ -82,4 +82,4 @@ def foo(): """ vyper_module = parse_to_ast(code) with pytest.raises(InvalidReference): - validate_semantics(vyper_module, dummy_input_bundle) + analyze_module(vyper_module, dummy_input_bundle) diff --git a/tests/unit/semantics/analysis/test_cyclic_function_calls.py b/tests/unit/semantics/analysis/test_cyclic_function_calls.py index c31146b16f..990c839fde 100644 --- a/tests/unit/semantics/analysis/test_cyclic_function_calls.py +++ b/tests/unit/semantics/analysis/test_cyclic_function_calls.py @@ -2,7 +2,7 @@ from vyper.ast import parse_to_ast from vyper.exceptions import CallViolation, StructureException -from vyper.semantics.analysis import validate_semantics +from vyper.semantics.analysis import analyze_module def test_self_function_call(dummy_input_bundle): @@ -13,7 +13,7 @@ def foo(): """ vyper_module = parse_to_ast(code) with pytest.raises(CallViolation): - validate_semantics(vyper_module, dummy_input_bundle) + analyze_module(vyper_module, dummy_input_bundle) def test_cyclic_function_call(dummy_input_bundle): @@ -28,7 +28,7 @@ def bar(): """ vyper_module = parse_to_ast(code) with pytest.raises(CallViolation): - validate_semantics(vyper_module, dummy_input_bundle) + analyze_module(vyper_module, dummy_input_bundle) def test_multi_cyclic_function_call(dummy_input_bundle): @@ -51,7 +51,7 @@ def potato(): """ vyper_module = parse_to_ast(code) with pytest.raises(CallViolation): - validate_semantics(vyper_module, dummy_input_bundle) + analyze_module(vyper_module, dummy_input_bundle) def test_global_ann_assign_callable_no_crash(dummy_input_bundle): @@ -64,5 +64,5 @@ def foo(to : address): """ vyper_module = parse_to_ast(code) with pytest.raises(StructureException) as excinfo: - validate_semantics(vyper_module, dummy_input_bundle) + analyze_module(vyper_module, dummy_input_bundle) assert excinfo.value.message == "HashMap[address, uint256] is not callable" diff --git a/tests/unit/semantics/analysis/test_for_loop.py b/tests/unit/semantics/analysis/test_for_loop.py index 089eb6a661..d7d4f7083b 100644 --- a/tests/unit/semantics/analysis/test_for_loop.py +++ b/tests/unit/semantics/analysis/test_for_loop.py @@ -2,7 +2,7 @@ from vyper.ast import parse_to_ast from vyper.exceptions import ArgumentException, ImmutableViolation, StructureException, TypeMismatch -from vyper.semantics.analysis import validate_semantics +from vyper.semantics.analysis import analyze_module def test_modify_iterator_function_outside_loop(dummy_input_bundle): @@ -21,7 +21,7 @@ def bar(): pass """ vyper_module = parse_to_ast(code) - validate_semantics(vyper_module, dummy_input_bundle) + analyze_module(vyper_module, dummy_input_bundle) def test_pass_memory_var_to_other_function(dummy_input_bundle): @@ -41,7 +41,7 @@ def bar(): self.foo(a) """ vyper_module = parse_to_ast(code) - validate_semantics(vyper_module, dummy_input_bundle) + analyze_module(vyper_module, dummy_input_bundle) def test_modify_iterator(dummy_input_bundle): @@ -56,7 +56,7 @@ def bar(): """ vyper_module = parse_to_ast(code) with pytest.raises(ImmutableViolation): - validate_semantics(vyper_module, dummy_input_bundle) + analyze_module(vyper_module, dummy_input_bundle) def test_bad_keywords(dummy_input_bundle): @@ -70,7 +70,7 @@ def bar(n: uint256): """ vyper_module = parse_to_ast(code) with pytest.raises(ArgumentException): - validate_semantics(vyper_module, dummy_input_bundle) + analyze_module(vyper_module, dummy_input_bundle) def test_bad_bound(dummy_input_bundle): @@ -84,7 +84,7 @@ def bar(n: uint256): """ vyper_module = parse_to_ast(code) with pytest.raises(StructureException): - validate_semantics(vyper_module, dummy_input_bundle) + analyze_module(vyper_module, dummy_input_bundle) def test_modify_iterator_function_call(dummy_input_bundle): @@ -103,7 +103,7 @@ def bar(): """ vyper_module = parse_to_ast(code) with pytest.raises(ImmutableViolation): - validate_semantics(vyper_module, dummy_input_bundle) + analyze_module(vyper_module, dummy_input_bundle) def test_modify_iterator_recursive_function_call(dummy_input_bundle): @@ -126,7 +126,7 @@ def baz(): """ vyper_module = parse_to_ast(code) with pytest.raises(ImmutableViolation): - validate_semantics(vyper_module, dummy_input_bundle) + analyze_module(vyper_module, dummy_input_bundle) def test_modify_iterator_recursive_function_call_topsort(dummy_input_bundle): @@ -149,7 +149,7 @@ def foo(): """ vyper_module = parse_to_ast(code) with pytest.raises(ImmutableViolation) as e: - validate_semantics(vyper_module, dummy_input_bundle) + analyze_module(vyper_module, dummy_input_bundle) assert e.value._message == "Cannot modify loop variable `a`" @@ -170,7 +170,7 @@ def foo(): """ vyper_module = parse_to_ast(code) with pytest.raises(ImmutableViolation) as e: - validate_semantics(vyper_module, dummy_input_bundle) + analyze_module(vyper_module, dummy_input_bundle) assert e.value._message == "Cannot modify loop variable `a`" @@ -189,7 +189,7 @@ def foo(): self.b[self.a[1]] = i """ vyper_module = parse_to_ast(code) - validate_semantics(vyper_module, dummy_input_bundle) + analyze_module(vyper_module, dummy_input_bundle) def test_modify_iterator_siblings(dummy_input_bundle): @@ -207,7 +207,7 @@ def foo(): self.f.b += i """ vyper_module = parse_to_ast(code) - validate_semantics(vyper_module, dummy_input_bundle) + analyze_module(vyper_module, dummy_input_bundle) def test_modify_subscript_barrier(dummy_input_bundle): @@ -229,7 +229,7 @@ def foo(): """ vyper_module = parse_to_ast(code) with pytest.raises(ImmutableViolation) as e: - validate_semantics(vyper_module, dummy_input_bundle) + analyze_module(vyper_module, dummy_input_bundle) assert e.value._message == "Cannot modify loop variable `b`" @@ -272,4 +272,4 @@ def foo(): def test_iterator_type_inference_checker(code, dummy_input_bundle): vyper_module = parse_to_ast(code) with pytest.raises(TypeMismatch): - validate_semantics(vyper_module, dummy_input_bundle) + analyze_module(vyper_module, dummy_input_bundle) diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index af94011633..e343938021 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -12,7 +12,7 @@ from vyper.compiler.settings import OptimizationLevel, Settings from vyper.exceptions import StructureException from vyper.ir import compile_ir, optimizer -from vyper.semantics import set_data_positions, validate_semantics +from vyper.semantics import analyze_module, set_data_positions, validate_compilation_target from vyper.semantics.types.function import ContractFunctionT from vyper.semantics.types.module import ModuleT from vyper.typing import StorageLayout @@ -156,9 +156,19 @@ def vyper_module(self): def annotated_vyper_module(self) -> vy_ast.Module: return generate_annotated_ast(self.vyper_module, self.input_bundle) + @cached_property + def compilation_target(self): + """ + Get the annotated AST, and additionally run the global checks + required for a compilation target. + """ + module_t = self.annotated_vyper_module._metadata["type"] + validate_compilation_target(module_t) + return self.annotated_vyper_module + @cached_property def storage_layout(self) -> StorageLayout: - module_ast = self.annotated_vyper_module + module_ast = self.compilation_target return set_data_positions(module_ast, self.storage_layout_override) @property @@ -251,13 +261,11 @@ def generate_annotated_ast(vyper_module: vy_ast.Module, input_bundle: InputBundl ------- vy_ast.Module Annotated Vyper AST - StorageLayout - Layout of variables in storage """ vyper_module = copy.deepcopy(vyper_module) with input_bundle.search_path(Path(vyper_module.resolved_path).parent): - # note: validate_semantics does type inference on the AST - validate_semantics(vyper_module, input_bundle) + # note: analyze_module does type inference on the AST + analyze_module(vyper_module, input_bundle) return vyper_module diff --git a/vyper/semantics/__init__.py b/vyper/semantics/__init__.py index bb40c266a4..c5b4f62f5b 100644 --- a/vyper/semantics/__init__.py +++ b/vyper/semantics/__init__.py @@ -1,2 +1,2 @@ -from .analysis import validate_semantics +from .analysis import analyze_module, validate_compilation_target from .analysis.data_positions import set_data_positions diff --git a/vyper/semantics/analysis/__init__.py b/vyper/semantics/analysis/__init__.py index e23b2d2aa4..c15e0dd8b8 100644 --- a/vyper/semantics/analysis/__init__.py +++ b/vyper/semantics/analysis/__init__.py @@ -1,4 +1,5 @@ from .. import types # break a dependency cycle. -from .global_ import validate_semantics +from .global_ import validate_compilation_target +from .module import analyze_module -__all__ = ["validate_semantics"] +__all__ = [validate_compilation_target, analyze_module] # type: ignore[misc] diff --git a/vyper/semantics/analysis/global_.py b/vyper/semantics/analysis/global_.py index 92cdf35c5d..a632c9194c 100644 --- a/vyper/semantics/analysis/global_.py +++ b/vyper/semantics/analysis/global_.py @@ -2,17 +2,11 @@ from vyper.exceptions import ExceptionList, InitializerException from vyper.semantics.analysis.base import InitializesInfo, UsesInfo -from vyper.semantics.analysis.import_graph import ImportGraph -from vyper.semantics.analysis.module import validate_module_semantics_r from vyper.semantics.types.module import ModuleT -def validate_semantics(module_ast, input_bundle, is_interface=False) -> ModuleT: - ret = validate_module_semantics_r(module_ast, input_bundle, ImportGraph(), is_interface) - - _validate_global_initializes_constraint(ret) - - return ret +def validate_compilation_target(module_t: ModuleT): + _validate_global_initializes_constraint(module_t) def _collect_used_modules_r(module_t): diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index fe1248ede5..7a592275b6 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -51,16 +51,29 @@ from vyper.utils import OrderedSet -def validate_module_semantics_r( +def analyze_module( module_ast: vy_ast.Module, input_bundle: InputBundle, - import_graph: ImportGraph, - is_interface: bool, + import_graph: ImportGraph = None, + is_interface: bool = False, ) -> ModuleT: """ - Analyze a Vyper module AST node, add all module-level objects to the - namespace, type-check/validate semantics and annotate with type and analysis info + Analyze a Vyper module AST node, recursively analyze all its imports, + add all module-level objects to the namespace, type-check/validate + semantics and annotate with type and analysis info """ + if import_graph is None: + import_graph = ImportGraph() + + return _analyze_module_r(module_ast, input_bundle, import_graph, is_interface) + + +def _analyze_module_r( + module_ast: vy_ast.Module, + input_bundle: InputBundle, + import_graph: ImportGraph, + is_interface: bool = False, +): if "type" in module_ast._metadata: # we don't need to analyse again, skip out assert isinstance(module_ast._metadata["type"], ModuleT) @@ -742,7 +755,7 @@ def _load_import_helper( module_ast = self._ast_from_file(file) with override_global_namespace(Namespace()): - module_t = validate_module_semantics_r( + module_t = _analyze_module_r( module_ast, self.input_bundle, import_graph=self._import_graph, @@ -762,7 +775,7 @@ def _load_import_helper( module_ast = self._ast_from_file(file) with override_global_namespace(Namespace()): - validate_module_semantics_r( + _analyze_module_r( module_ast, self.input_bundle, import_graph=self._import_graph, @@ -871,7 +884,5 @@ def _load_builtin_import(level: int, module_str: str) -> InterfaceT: interface_ast = _parse_and_fold_ast(file) with override_global_namespace(Namespace()): - module_t = validate_module_semantics_r( - interface_ast, input_bundle, ImportGraph(), is_interface=True - ) + module_t = _analyze_module_r(interface_ast, input_bundle, ImportGraph(), is_interface=True) return module_t.interface