From 471a8646fbaeb5377f3a82f71cf5e59e9a24033a Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 9 Apr 2024 08:01:36 -0400 Subject: [PATCH 1/9] fix[lang]: fix `uses` analysis for nonreentrant functions `uses` analysis ignores nonreentrant functions, even those (implicitly) use state. this commit adds checks both for internally (called) and external (exported) modules --- vyper/semantics/analysis/local.py | 21 +++++++++++---------- vyper/semantics/analysis/module.py | 2 +- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index b0a6e38d10..f726e60b4f 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -1,7 +1,7 @@ # CMC 2024-02-03 TODO: split me into function.py and expr.py import contextlib -from typing import Optional +from typing import Iterable, Optional from vyper import ast as vy_ast from vyper.ast.validation import validate_call_args @@ -229,6 +229,10 @@ def _get_variable_access(node: vy_ast.ExprNode) -> Optional[VarAccess]: return VarAccess(info.var_info, tuple(path)) +def _uses_state(var_accesses: Iterable[VarAccess]) -> bool: + return any(s.variable.is_state_variable() for s in var_accesses) + + # get the chain of modules, e.g. # mod1.mod2.x.y -> [ModuleInfo(mod1), ModuleInfo(mod2)] # CMC 2024-02-12 note that the Attribute/Subscript traversal in this and @@ -443,10 +447,7 @@ def _handle_modification(self, target: vy_ast.ExprNode): info._writes.add(var_access) - def _handle_module_access(self, var_access: VarAccess, target: vy_ast.ExprNode): - if not var_access.variable.is_state_variable(): - return - + def _handle_module_access(self, target: vy_ast.ExprNode): root_module_info = check_module_uses(target) if root_module_info is not None: @@ -684,9 +685,9 @@ def visit(self, node, typ): msg += f" `{var.decl_node.target.id}`" raise ImmutableViolation(msg, var.decl_node, node) - variable_accesses = info._writes | info._reads - for s in variable_accesses: - self.function_analyzer._handle_module_access(s, node) + var_accesses = info._writes | info._reads + if _uses_state(var_accesses): + self.function_analyzer._handle_module_access(node) self.func.mark_variable_writes(info._writes) self.func.mark_variable_reads(info._reads) @@ -789,8 +790,8 @@ def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None: if self.function_analyzer: self._check_call_mutability(func_type.mutability) - for s in func_type.get_variable_accesses(): - self.function_analyzer._handle_module_access(s, node.func) + if func_type.nonreentrant or _uses_state(func_type.get_variable_accesses()): + self.function_analyzer._handle_module_access(node.func) if func_type.is_deploy and not self.func.is_deploy: raise CallViolation( diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index 619f4e4c10..1ae796a684 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -561,7 +561,7 @@ def visit_ExportsDecl(self, node): # check module uses var_accesses = func_t.get_variable_accesses() - if any(s.variable.is_state_variable() for s in var_accesses): + if func_t.nonreentrant or any(s.variable.is_state_variable() for s in var_accesses): module_info = check_module_uses(item) assert module_info is not None # guaranteed by above checks used_modules.add(module_info) From 9db4bfb76773a551d28f967e40e3d9db026e42f0 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 9 Apr 2024 08:37:19 -0400 Subject: [PATCH 2/9] refactor: factor out uses_state util also, rename `validate_functions` to more accurate `analyze_functions` --- vyper/semantics/analysis/local.py | 23 ++++++++++------------- vyper/semantics/analysis/module.py | 7 +++---- vyper/semantics/analysis/utils.py | 8 ++++++-- vyper/semantics/types/function.py | 5 +++++ 4 files changed, 24 insertions(+), 19 deletions(-) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index f726e60b4f..a61fe4dadb 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -1,7 +1,7 @@ -# CMC 2024-02-03 TODO: split me into function.py and expr.py +# CMC 2024-02-03 TODO: rename me to function.py import contextlib -from typing import Iterable, Optional +from typing import Optional from vyper import ast as vy_ast from vyper.ast.validation import validate_call_args @@ -33,6 +33,7 @@ get_exact_type_from_node, get_expr_info, get_possible_types_from_node, + uses_state, validate_expected_type, ) from vyper.semantics.data_locations import DataLocation @@ -64,22 +65,22 @@ from vyper.semantics.types.utils import type_from_annotation -def validate_functions(vy_module: vy_ast.Module) -> None: +def analyze_functions(vy_module: vy_ast.Module) -> None: """Analyzes a vyper ast and validates the function bodies""" err_list = ExceptionList() for node in vy_module.get_children(vy_ast.FunctionDef): - _validate_function_r(vy_module, node, err_list) + _analyze_function_r(vy_module, node, err_list) for node in vy_module.get_children(vy_ast.VariableDecl): if not node.is_public: continue - _validate_function_r(vy_module, node._expanded_getter, err_list) + _analyze_function_r(vy_module, node._expanded_getter, err_list) err_list.raise_if_not_empty() -def _validate_function_r( +def _analyze_function_r( vy_module: vy_ast.Module, node: vy_ast.FunctionDef, err_list: ExceptionList ): func_t = node._metadata["func_type"] @@ -87,7 +88,7 @@ def _validate_function_r( for call_t in func_t.called_functions: if isinstance(call_t, ContractFunctionT): assert isinstance(call_t.ast_def, vy_ast.FunctionDef) # help mypy - _validate_function_r(vy_module, call_t.ast_def, err_list) + _analyze_function_r(vy_module, call_t.ast_def, err_list) namespace = get_namespace() @@ -229,10 +230,6 @@ def _get_variable_access(node: vy_ast.ExprNode) -> Optional[VarAccess]: return VarAccess(info.var_info, tuple(path)) -def _uses_state(var_accesses: Iterable[VarAccess]) -> bool: - return any(s.variable.is_state_variable() for s in var_accesses) - - # get the chain of modules, e.g. # mod1.mod2.x.y -> [ModuleInfo(mod1), ModuleInfo(mod2)] # CMC 2024-02-12 note that the Attribute/Subscript traversal in this and @@ -686,7 +683,7 @@ def visit(self, node, typ): raise ImmutableViolation(msg, var.decl_node, node) var_accesses = info._writes | info._reads - if _uses_state(var_accesses): + if uses_state(var_accesses): self.function_analyzer._handle_module_access(node) self.func.mark_variable_writes(info._writes) @@ -790,7 +787,7 @@ def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None: if self.function_analyzer: self._check_call_mutability(func_type.mutability) - if func_type.nonreentrant or _uses_state(func_type.get_variable_accesses()): + if func_type.uses_state(): self.function_analyzer._handle_module_access(node.func) if func_type.is_deploy and not self.func.is_deploy: diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index 1ae796a684..0353d73bfc 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -43,7 +43,7 @@ from vyper.semantics.analysis.constant_folding import constant_fold from vyper.semantics.analysis.getters import generate_public_variable_getters from vyper.semantics.analysis.import_graph import ImportGraph -from vyper.semantics.analysis.local import ExprVisitor, check_module_uses, validate_functions +from vyper.semantics.analysis.local import ExprVisitor, analyze_functions, check_module_uses from vyper.semantics.analysis.utils import ( check_modifiability, get_exact_type_from_node, @@ -102,7 +102,7 @@ def _analyze_module_r( # if this is an interface, the function is already validated # in `ContractFunction.from_vyi()` if not is_interface: - validate_functions(module_ast) + analyze_functions(module_ast) analyzer.validate_initialized_modules() analyzer.validate_used_modules() @@ -560,8 +560,7 @@ def visit_ExportsDecl(self, node): funcs.append(func_t) # check module uses - var_accesses = func_t.get_variable_accesses() - if func_t.nonreentrant or any(s.variable.is_state_variable() for s in var_accesses): + if func_t.uses_state(): module_info = check_module_uses(item) assert module_info is not None # guaranteed by above checks used_modules.add(module_info) diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index 4b751e7406..b4b31ca358 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -1,5 +1,5 @@ import itertools -from typing import Callable, List +from typing import Callable, Iterable, List from vyper import ast as vy_ast from vyper.exceptions import ( @@ -17,7 +17,7 @@ ZeroDivisionException, ) from vyper.semantics import types -from vyper.semantics.analysis.base import ExprInfo, Modifiability, ModuleInfo, VarInfo +from vyper.semantics.analysis.base import ExprInfo, Modifiability, ModuleInfo, VarAccess, VarInfo from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions from vyper.semantics.namespace import get_namespace from vyper.semantics.types.base import TYPE_T, VyperType @@ -48,6 +48,10 @@ def _validate_op(node, types_list, validation_fn_name): raise err_list[0] +def uses_state(var_accesses: Iterable[VarAccess]) -> bool: + return any(s.variable.is_state_variable() for s in var_accesses) + + class _ExprAnalyser: """ Node type-checker class. diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index fbeb3e37cd..86fd90f0f9 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -27,6 +27,7 @@ from vyper.semantics.analysis.utils import ( check_modifiability, get_exact_type_from_node, + uses_state, validate_expected_type, ) from vyper.semantics.data_locations import DataLocation @@ -163,7 +164,11 @@ def get_variable_writes(self): def get_variable_accesses(self): return self._variable_reads | self._variable_writes + def uses_state(self): + return self.nonreentrant or uses_state(self.get_variable_accesses()) + def get_used_modules(self): + # _used_modules is populated during analysis return self._used_modules def mark_used_module(self, module_info): From 0cb0f93fb588e85c2f202ac1d23411eb33ebbdfd Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 9 Apr 2024 08:46:48 -0400 Subject: [PATCH 3/9] add syntax tests --- .../syntax/modules/test_initializers.py | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/tests/functional/syntax/modules/test_initializers.py b/tests/functional/syntax/modules/test_initializers.py index 9825e4618f..3cbc987a95 100644 --- a/tests/functional/syntax/modules/test_initializers.py +++ b/tests/functional/syntax/modules/test_initializers.py @@ -1292,3 +1292,53 @@ def foo(): compile_code(main, input_bundle=input_bundle) assert e.value._message == "`lib2` uses `lib1`, but it is not initialized with `lib1`" assert e.value._hint == "try importing lib1 first" + + +def test_nonreentrant_exports(make_input_bundle): + lib1 = """ +# lib1.vy +@external +@nonreentrant +def bar(): + pass + """ + main = """ +import lib1 + +exports: lib1.bar # line 4 + +@external +def foo(): + pass + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "Cannot access `lib1` state!" + hint = "add `uses: lib1` or `initializes: lib1` as a top-level statement to your contract" + assert e.value._hint == hint + assert e.value.annotations[0].lineno == 4 + + +def test_internal_nonreentrant_import(make_input_bundle): + lib1 = """ +# lib1.vy +@internal +@nonreentrant +def bar(): + pass + """ + main = """ +import lib1 + +@external +def foo(): + lib1.bar() # line 6 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "Cannot access `lib1` state!" + hint = "add `uses: lib1` or `initializes: lib1` as a top-level statement to your contract" + assert e.value._hint == hint + assert e.value.annotations[0].lineno == 6 From 2f4525ccff0d35bb2fa06273e9a67fe8310b13a2 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 9 Apr 2024 08:51:48 -0400 Subject: [PATCH 4/9] add nonreentrant export test --- .../codegen/modules/test_exports.py | 38 +++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/tests/functional/codegen/modules/test_exports.py b/tests/functional/codegen/modules/test_exports.py index 2dc90bfe74..91b6a01138 100644 --- a/tests/functional/codegen/modules/test_exports.py +++ b/tests/functional/codegen/modules/test_exports.py @@ -147,3 +147,41 @@ def foo() -> uint256: c = get_contract(main, input_bundle=input_bundle) assert c.foo() == 5 + + +def test_export_nonreentrant(make_input_bundle, get_contract, tx_failed): + lib1 = """ +interface Foo: + def foo() -> uint256: nonpayable + +implements: Foo + +@external +@nonreentrant +def foo() -> uint256: + return 5 + """ + main = """ +import lib1 + +initializes: lib1 + +exports: lib1.foo + +@external +@nonreentrant +def re_enter(): + extcall lib1.Foo(self).foo() # should always throw + +@external +def __default__(): + # sanity: make sure we don't revert due to bad selector + pass + """ + + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + c = get_contract(main, input_bundle=input_bundle) + assert c.foo() == 5 + with tx_failed(): + c.re_enter() From 8a69b834cca19e021b8607414694a5ef3da8fdc2 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 9 Apr 2024 08:55:06 -0400 Subject: [PATCH 5/9] add nonreentrant internal import test, move exported nonreentrant function to new file --- .../codegen/modules/test_exports.py | 38 --------- .../codegen/modules/test_nonreentrant.py | 78 +++++++++++++++++++ 2 files changed, 78 insertions(+), 38 deletions(-) create mode 100644 tests/functional/codegen/modules/test_nonreentrant.py diff --git a/tests/functional/codegen/modules/test_exports.py b/tests/functional/codegen/modules/test_exports.py index 91b6a01138..2dc90bfe74 100644 --- a/tests/functional/codegen/modules/test_exports.py +++ b/tests/functional/codegen/modules/test_exports.py @@ -147,41 +147,3 @@ def foo() -> uint256: c = get_contract(main, input_bundle=input_bundle) assert c.foo() == 5 - - -def test_export_nonreentrant(make_input_bundle, get_contract, tx_failed): - lib1 = """ -interface Foo: - def foo() -> uint256: nonpayable - -implements: Foo - -@external -@nonreentrant -def foo() -> uint256: - return 5 - """ - main = """ -import lib1 - -initializes: lib1 - -exports: lib1.foo - -@external -@nonreentrant -def re_enter(): - extcall lib1.Foo(self).foo() # should always throw - -@external -def __default__(): - # sanity: make sure we don't revert due to bad selector - pass - """ - - input_bundle = make_input_bundle({"lib1.vy": lib1}) - - c = get_contract(main, input_bundle=input_bundle) - assert c.foo() == 5 - with tx_failed(): - c.re_enter() diff --git a/tests/functional/codegen/modules/test_nonreentrant.py b/tests/functional/codegen/modules/test_nonreentrant.py new file mode 100644 index 0000000000..69b17cbfa2 --- /dev/null +++ b/tests/functional/codegen/modules/test_nonreentrant.py @@ -0,0 +1,78 @@ +def test_export_nonreentrant(make_input_bundle, get_contract, tx_failed): + lib1 = """ +interface Foo: + def foo() -> uint256: nonpayable + +implements: Foo + +@external +@nonreentrant +def foo() -> uint256: + return 5 + """ + main = """ +import lib1 + +initializes: lib1 + +exports: lib1.foo + +@external +@nonreentrant +def re_enter(): + extcall lib1.Foo(self).foo() # should always throw + +@external +def __default__(): + # sanity: make sure we don't revert due to bad selector + pass + """ + + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + c = get_contract(main, input_bundle=input_bundle) + assert c.foo() == 5 + with tx_failed(): + c.re_enter() + + +def test_internal_nonreentrant(make_input_bundle, get_contract, tx_failed): + lib1 = """ +interface Foo: + def foo() -> uint256: nonpayable + +implements: Foo + +@external +def foo() -> uint256: + return self._safe_fn() + +@internal +@nonreentrant +def _safe_fn() -> uint256: + return 10 + """ + main = """ +import lib1 + +initializes: lib1 + +exports: lib1.foo + +@external +@nonreentrant +def re_enter(): + extcall lib1.Foo(self).foo() # should always throw + +@external +def __default__(): + # sanity: make sure we don't revert due to bad selector + pass + """ + + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + c = get_contract(main, input_bundle=input_bundle) + assert c.foo() == 10 + with tx_failed(): + c.re_enter() From f6d1c81b14e976d82f6e5a3c2e7d4236b97199bd Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 11 Apr 2024 07:15:37 -0400 Subject: [PATCH 6/9] improve locality of exceptions thrown in check_module_uses --- vyper/semantics/analysis/module.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index 0353d73bfc..af449e2a8e 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -557,13 +557,13 @@ def visit_ExportsDecl(self, node): with tag_exceptions(item): # tag with specific item self._self_t.typ.add_member(func_t.name, func_t) - funcs.append(func_t) + funcs.append(func_t) - # check module uses - if func_t.uses_state(): - module_info = check_module_uses(item) - assert module_info is not None # guaranteed by above checks - used_modules.add(module_info) + # check module uses + if func_t.uses_state(): + module_info = check_module_uses(item) + assert module_info is not None # guaranteed by above checks + used_modules.add(module_info) node._metadata["exports_info"] = ExportsInfo(funcs, used_modules) From eaa59886274ac5e6af41c86a2e7d3d8d47f17c86 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 11 Apr 2024 08:36:50 -0400 Subject: [PATCH 7/9] update error message --- vyper/semantics/analysis/local.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index a61fe4dadb..c08c3c0706 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -268,7 +268,14 @@ def check_module_uses(node: vy_ast.ExprNode) -> Optional[ModuleInfo]: for module_info in module_infos: if module_info.ownership < ModuleOwnership.USES: - msg = f"Cannot access `{module_info.alias}` state!" + msg = f"Cannot access `{module_info.alias}` state!\n note that" + # CMC 2024-04-12 add UX note about nonreentrant. might be nice + # in the future to be more specific about exactly which state is + # used, although that requires threading a bit more context into + # this function. + msg += " use of the `@nonreentrant` decorator is also considered" + msg += " state access" + hint = f"add `uses: {module_info.alias}` or " hint += f"`initializes: {module_info.alias}` as " hint += "a top-level statement to your contract" From 16ca98f71a621c12621ce579709f80609ce52cf2 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Fri, 12 Apr 2024 10:48:42 -0400 Subject: [PATCH 8/9] update tests --- .../functional/syntax/modules/test_exports.py | 6 ++-- .../syntax/modules/test_initializers.py | 31 ++++++++++--------- 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/tests/functional/syntax/modules/test_exports.py b/tests/functional/syntax/modules/test_exports.py index 24a233da9d..1edb99bc7f 100644 --- a/tests/functional/syntax/modules/test_exports.py +++ b/tests/functional/syntax/modules/test_exports.py @@ -3,6 +3,8 @@ from vyper.compiler import compile_code from vyper.exceptions import ImmutableViolation, NamespaceCollision, StructureException +from .helpers import NONREENTRANT_NOTE + def test_exports_no_uses(make_input_bundle): lib1 = """ @@ -21,7 +23,7 @@ def get_counter() -> uint256: with pytest.raises(ImmutableViolation) as e: compile_code(main, input_bundle=input_bundle) - assert e.value._message == "Cannot access `lib1` state!" + assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE expected_hint = "add `uses: lib1` or `initializes: lib1` as a " expected_hint += "top-level statement to your contract" @@ -40,7 +42,7 @@ def test_exports_no_uses_variable(make_input_bundle): with pytest.raises(ImmutableViolation) as e: compile_code(main, input_bundle=input_bundle) - assert e.value._message == "Cannot access `lib1` state!" + assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE expected_hint = "add `uses: lib1` or `initializes: lib1` as a " expected_hint += "top-level statement to your contract" diff --git a/tests/functional/syntax/modules/test_initializers.py b/tests/functional/syntax/modules/test_initializers.py index 276742963f..2193050a5f 100644 --- a/tests/functional/syntax/modules/test_initializers.py +++ b/tests/functional/syntax/modules/test_initializers.py @@ -17,6 +17,8 @@ UndeclaredDefinition, ) +from .helpers import NONREENTRANT_NOTE + def test_initialize_uses(make_input_bundle): lib1 = """ @@ -413,7 +415,7 @@ def foo(): with pytest.raises(ImmutableViolation) as e: compile_code(main, input_bundle=input_bundle) - assert e.value._message == "Cannot access `lib1` state!" + assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE expected_hint = "add `uses: lib1` or `initializes: lib1` as a " expected_hint += "top-level statement to your contract" @@ -450,7 +452,7 @@ def __init__(): with pytest.raises(ImmutableViolation) as e: compile_code(main, input_bundle=input_bundle) - assert e.value._message == "Cannot access `lib1` state!" + assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE expected_hint = "add `uses: lib1` or `initializes: lib1` as a " expected_hint += "top-level statement to your contract" @@ -491,7 +493,7 @@ def __init__(): with pytest.raises(ImmutableViolation) as e: compile_code(main, input_bundle=input_bundle) - assert e.value._message == "Cannot access `lib1` state!" + assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE expected_hint = "add `uses: lib1` or `initializes: lib1` as a " expected_hint += "top-level statement to your contract" @@ -536,7 +538,7 @@ def __init__(): with pytest.raises(ImmutableViolation) as e: compile_code(main, input_bundle=input_bundle) - assert e.value._message == "Cannot access `lib1` state!" + assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE expected_hint = "add `uses: lib1` or `initializes: lib1` as a " expected_hint += "top-level statement to your contract" @@ -571,7 +573,7 @@ def __init__(): with pytest.raises(ImmutableViolation) as e: compile_code(main, input_bundle=input_bundle) - assert e.value._message == "Cannot access `lib1` state!" + assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE expected_hint = "add `uses: lib1` or `initializes: lib1` as a " expected_hint += "top-level statement to your contract" @@ -612,7 +614,7 @@ def __init__(): with pytest.raises(ImmutableViolation) as e: compile_code(main, input_bundle=input_bundle) - assert e.value._message == "Cannot access `lib1` state!" + assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE expected_hint = "add `uses: lib1` or `initializes: lib1` as a " expected_hint += "top-level statement to your contract" @@ -656,7 +658,7 @@ def __init__(): with pytest.raises(ImmutableViolation) as e: compile_code(main, input_bundle=input_bundle) - assert e.value._message == "Cannot access `lib1` state!" + assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE expected_hint = "add `uses: lib1` or `initializes: lib1` as a " expected_hint += "top-level statement to your contract" @@ -695,7 +697,7 @@ def foo(): with pytest.raises(ImmutableViolation) as e: compile_code(main, input_bundle=input_bundle) - assert e.value._message == "Cannot access `lib1` state!" + assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE expected_hint = "add `uses: lib1` or `initializes: lib1` as a " expected_hint += "top-level statement to your contract" @@ -734,7 +736,7 @@ def foo(new_value: uint256): with pytest.raises(ImmutableViolation) as e: compile_code(main, input_bundle=input_bundle) - assert e.value._message == "Cannot access `lib2` state!" + assert e.value._message == "Cannot access `lib2` state!" + NONREENTRANT_NOTE expected_hint = "add `uses: lib2` or `initializes: lib2` as a " expected_hint += "top-level statement to your contract" @@ -776,7 +778,7 @@ def foo(new_value: uint256): with pytest.raises(ImmutableViolation) as e: compile_code(main, input_bundle=input_bundle) - assert e.value._message == "Cannot access `lib2` state!" + assert e.value._message == "Cannot access `lib2` state!" + NONREENTRANT_NOTE expected_hint = "add `uses: lib2` or `initializes: lib2` as a " expected_hint += "top-level statement to your contract" @@ -819,7 +821,7 @@ def foo(new_value: uint256): with pytest.raises(ImmutableViolation) as e: compile_code(main, input_bundle=input_bundle) - assert e.value._message == "Cannot access `lib2` state!" + assert e.value._message == "Cannot access `lib2` state!" + NONREENTRANT_NOTE expected_hint = "add `uses: lib2` or `initializes: lib2` as a " expected_hint += "top-level statement to your contract" @@ -853,7 +855,7 @@ def foo(new_value: uint256): with pytest.raises(ImmutableViolation) as e: compile_code(main, input_bundle=input_bundle) - assert e.value._message == "Cannot access `lib1` state!" + assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE expected_hint = "add `uses: lib1` or `initializes: lib1` as a " expected_hint += "top-level statement to your contract" @@ -1318,7 +1320,7 @@ def foo(): input_bundle = make_input_bundle({"lib1.vy": lib1}) with pytest.raises(ImmutableViolation) as e: compile_code(main, input_bundle=input_bundle) - assert e.value._message == "Cannot access `lib1` state!" + assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE hint = "add `uses: lib1` or `initializes: lib1` as a top-level statement to your contract" assert e.value._hint == hint assert e.value.annotations[0].lineno == 4 @@ -1342,7 +1344,8 @@ def foo(): input_bundle = make_input_bundle({"lib1.vy": lib1}) with pytest.raises(ImmutableViolation) as e: compile_code(main, input_bundle=input_bundle) - assert e.value._message == "Cannot access `lib1` state!" + assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE + hint = "add `uses: lib1` or `initializes: lib1` as a top-level statement to your contract" assert e.value._hint == hint assert e.value.annotations[0].lineno == 6 From d899831ae5fd9ca405b3c955e6fdc812c259a1c3 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Fri, 12 Apr 2024 11:07:38 -0400 Subject: [PATCH 9/9] add missing file --- tests/functional/syntax/modules/helpers.py | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 tests/functional/syntax/modules/helpers.py diff --git a/tests/functional/syntax/modules/helpers.py b/tests/functional/syntax/modules/helpers.py new file mode 100644 index 0000000000..2a54073afb --- /dev/null +++ b/tests/functional/syntax/modules/helpers.py @@ -0,0 +1,3 @@ +NONREENTRANT_NOTE = ( + "\n note that use of the `@nonreentrant` decorator is also considered state access" +)