Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix[lang]: recursion in uses analysis for nonreentrant functions #3971

Merged
80 changes: 57 additions & 23 deletions tests/functional/syntax/modules/test_initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1300,52 +1300,86 @@ def foo():
assert e.value._hint == "try importing lib1 first"


def test_nonreentrant_exports(make_input_bundle):
@pytest.fixture
def nonreentrant_library_bundle(make_input_bundle):
# test simple case
lib1 = """
# lib1.vy
@external
@internal
@nonreentrant
def bar():
pass

# lib1.vy
@external
@nonreentrant
def ext_bar():
pass
"""
main = """
# test case with recursion
lib2 = """
@internal
def bar():
self.baz()

@external
def ext_bar():
self.baz()

@nonreentrant
@internal
def baz():
return
"""
# test case with nested recursion
lib3 = """
import lib1
uses: lib1

exports: lib1.bar # line 4
@internal
def bar():
lib1.bar()

@external
def ext_bar():
lib1.bar()
"""

return make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2, "lib3.vy": lib3})


@pytest.mark.parametrize("lib", ("lib1", "lib2", "lib3"))
def test_nonreentrant_exports(nonreentrant_library_bundle, lib):
main = f"""
import {lib}

exports: {lib}.ext_bar # line 4

@external
def foo():
pass
"""
input_bundle = make_input_bundle({"lib1.vy": lib1})
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)
assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE
hint = "add `uses: lib1` or `initializes: lib1` as a top-level statement to your contract"
compile_code(main, input_bundle=nonreentrant_library_bundle)
assert e.value._message == f"Cannot access `{lib}` state!" + NONREENTRANT_NOTE
hint = f"add `uses: {lib}` or `initializes: {lib}` as a top-level statement to your contract"
assert e.value._hint == hint
assert e.value.annotations[0].lineno == 4


def test_internal_nonreentrant_import(make_input_bundle):
lib1 = """
# lib1.vy
@internal
@nonreentrant
def bar():
pass
"""
main = """
import lib1
@pytest.mark.parametrize("lib", ("lib1", "lib2", "lib3"))
def test_internal_nonreentrant_import(nonreentrant_library_bundle, lib):
main = f"""
import {lib}

@external
def foo():
lib1.bar() # line 6
{lib}.bar() # line 6
"""
input_bundle = make_input_bundle({"lib1.vy": lib1})
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)
assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE
compile_code(main, input_bundle=nonreentrant_library_bundle)
assert e.value._message == f"Cannot access `{lib}` state!" + NONREENTRANT_NOTE

hint = "add `uses: lib1` or `initializes: lib1` as a top-level statement to your contract"
hint = f"add `uses: {lib}` or `initializes: {lib}` as a top-level statement to your contract"
assert e.value._hint == hint
assert e.value.annotations[0].lineno == 6
6 changes: 5 additions & 1 deletion vyper/semantics/types/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,11 @@ def get_variable_accesses(self):
return self._variable_reads | self._variable_writes

def uses_state(self):
return self.nonreentrant or uses_state(self.get_variable_accesses())
return (
self.nonreentrant
or uses_state(self.get_variable_accesses())
or any(f.nonreentrant for f in self.reachable_internal_functions)
)

def get_used_modules(self):
# _used_modules is populated during analysis
Expand Down
Loading