Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix[lang]: fix uses analysis for nonreentrant functions #3927

Merged
merged 12 commits into from
Apr 12, 2024
78 changes: 78 additions & 0 deletions tests/functional/codegen/modules/test_nonreentrant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
def test_export_nonreentrant(make_input_bundle, get_contract, tx_failed):
lib1 = """
interface Foo:
def foo() -> uint256: nonpayable
implements: Foo
@external
@nonreentrant
def foo() -> uint256:
return 5
"""
main = """
import lib1
initializes: lib1
exports: lib1.foo
@external
@nonreentrant
def re_enter():
extcall lib1.Foo(self).foo() # should always throw
@external
def __default__():
# sanity: make sure we don't revert due to bad selector
pass
"""

input_bundle = make_input_bundle({"lib1.vy": lib1})

c = get_contract(main, input_bundle=input_bundle)
assert c.foo() == 5
with tx_failed():
c.re_enter()


def test_internal_nonreentrant(make_input_bundle, get_contract, tx_failed):
lib1 = """
interface Foo:
def foo() -> uint256: nonpayable
implements: Foo
@external
def foo() -> uint256:
return self._safe_fn()
@internal
@nonreentrant
def _safe_fn() -> uint256:
return 10
"""
main = """
import lib1
initializes: lib1
exports: lib1.foo
@external
@nonreentrant
def re_enter():
extcall lib1.Foo(self).foo() # should always throw
@external
def __default__():
# sanity: make sure we don't revert due to bad selector
pass
"""

input_bundle = make_input_bundle({"lib1.vy": lib1})

c = get_contract(main, input_bundle=input_bundle)
assert c.foo() == 10
with tx_failed():
c.re_enter()
6 changes: 4 additions & 2 deletions tests/functional/syntax/modules/test_exports.py
Original file line number Diff line number Diff line change
@@ -3,6 +3,8 @@
from vyper.compiler import compile_code
from vyper.exceptions import ImmutableViolation, NamespaceCollision, StructureException

from .helpers import NONREENTRANT_NOTE


def test_exports_no_uses(make_input_bundle):
lib1 = """
@@ -21,7 +23,7 @@ def get_counter() -> uint256:
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)

assert e.value._message == "Cannot access `lib1` state!"
assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE

expected_hint = "add `uses: lib1` or `initializes: lib1` as a "
expected_hint += "top-level statement to your contract"
@@ -40,7 +42,7 @@ def test_exports_no_uses_variable(make_input_bundle):
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)

assert e.value._message == "Cannot access `lib1` state!"
assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE

expected_hint = "add `uses: lib1` or `initializes: lib1` as a "
expected_hint += "top-level statement to your contract"
77 changes: 65 additions & 12 deletions tests/functional/syntax/modules/test_initializers.py
Original file line number Diff line number Diff line change
@@ -17,6 +17,8 @@
UndeclaredDefinition,
)

from .helpers import NONREENTRANT_NOTE


def test_initialize_uses(make_input_bundle):
lib1 = """
@@ -413,7 +415,7 @@ def foo():
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)

assert e.value._message == "Cannot access `lib1` state!"
assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE

expected_hint = "add `uses: lib1` or `initializes: lib1` as a "
expected_hint += "top-level statement to your contract"
@@ -450,7 +452,7 @@ def __init__():
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)

assert e.value._message == "Cannot access `lib1` state!"
assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE

expected_hint = "add `uses: lib1` or `initializes: lib1` as a "
expected_hint += "top-level statement to your contract"
@@ -491,7 +493,7 @@ def __init__():
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)

assert e.value._message == "Cannot access `lib1` state!"
assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE

expected_hint = "add `uses: lib1` or `initializes: lib1` as a "
expected_hint += "top-level statement to your contract"
@@ -536,7 +538,7 @@ def __init__():
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)

assert e.value._message == "Cannot access `lib1` state!"
assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE

expected_hint = "add `uses: lib1` or `initializes: lib1` as a "
expected_hint += "top-level statement to your contract"
@@ -571,7 +573,7 @@ def __init__():
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)

assert e.value._message == "Cannot access `lib1` state!"
assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE

expected_hint = "add `uses: lib1` or `initializes: lib1` as a "
expected_hint += "top-level statement to your contract"
@@ -612,7 +614,7 @@ def __init__():
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)

assert e.value._message == "Cannot access `lib1` state!"
assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE

expected_hint = "add `uses: lib1` or `initializes: lib1` as a "
expected_hint += "top-level statement to your contract"
@@ -656,7 +658,7 @@ def __init__():
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)

assert e.value._message == "Cannot access `lib1` state!"
assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE

expected_hint = "add `uses: lib1` or `initializes: lib1` as a "
expected_hint += "top-level statement to your contract"
@@ -695,7 +697,7 @@ def foo():
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)

assert e.value._message == "Cannot access `lib1` state!"
assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE

expected_hint = "add `uses: lib1` or `initializes: lib1` as a "
expected_hint += "top-level statement to your contract"
@@ -734,7 +736,7 @@ def foo(new_value: uint256):
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)

assert e.value._message == "Cannot access `lib2` state!"
assert e.value._message == "Cannot access `lib2` state!" + NONREENTRANT_NOTE

expected_hint = "add `uses: lib2` or `initializes: lib2` as a "
expected_hint += "top-level statement to your contract"
@@ -776,7 +778,7 @@ def foo(new_value: uint256):
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)

assert e.value._message == "Cannot access `lib2` state!"
assert e.value._message == "Cannot access `lib2` state!" + NONREENTRANT_NOTE

expected_hint = "add `uses: lib2` or `initializes: lib2` as a "
expected_hint += "top-level statement to your contract"
@@ -819,7 +821,7 @@ def foo(new_value: uint256):
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)

assert e.value._message == "Cannot access `lib2` state!"
assert e.value._message == "Cannot access `lib2` state!" + NONREENTRANT_NOTE

expected_hint = "add `uses: lib2` or `initializes: lib2` as a "
expected_hint += "top-level statement to your contract"
@@ -853,7 +855,7 @@ def foo(new_value: uint256):
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)

assert e.value._message == "Cannot access `lib1` state!"
assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE

expected_hint = "add `uses: lib1` or `initializes: lib1` as a "
expected_hint += "top-level statement to your contract"
@@ -1296,3 +1298,54 @@ def foo():
compile_code(main, input_bundle=input_bundle)
assert e.value._message == "`lib2` uses `lib1`, but it is not initialized with `lib1`"
assert e.value._hint == "try importing lib1 first"


def test_nonreentrant_exports(make_input_bundle):
lib1 = """
# lib1.vy
@external
@nonreentrant
def bar():
pass
"""
main = """
import lib1
exports: lib1.bar # line 4
@external
def foo():
pass
"""
input_bundle = make_input_bundle({"lib1.vy": lib1})
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)
assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE
hint = "add `uses: lib1` or `initializes: lib1` as a top-level statement to your contract"
assert e.value._hint == hint
assert e.value.annotations[0].lineno == 4


def test_internal_nonreentrant_import(make_input_bundle):
lib1 = """
# lib1.vy
@internal
@nonreentrant
def bar():
pass
"""
main = """
import lib1
@external
def foo():
lib1.bar() # line 6
"""
input_bundle = make_input_bundle({"lib1.vy": lib1})
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)
assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE

hint = "add `uses: lib1` or `initializes: lib1` as a top-level statement to your contract"
assert e.value._hint == hint
assert e.value.annotations[0].lineno == 6
37 changes: 21 additions & 16 deletions vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# CMC 2024-02-03 TODO: split me into function.py and expr.py
# CMC 2024-02-03 TODO: rename me to function.py

import contextlib
from typing import Optional
@@ -35,6 +35,7 @@
get_exact_type_from_node,
get_expr_info,
get_possible_types_from_node,
uses_state,
validate_expected_type,
)
from vyper.semantics.data_locations import DataLocation
@@ -64,30 +65,30 @@
from vyper.semantics.types.utils import type_from_annotation


def validate_functions(vy_module: vy_ast.Module) -> None:
def analyze_functions(vy_module: vy_ast.Module) -> None:
"""Analyzes a vyper ast and validates the function bodies"""
err_list = ExceptionList()

for node in vy_module.get_children(vy_ast.FunctionDef):
_validate_function_r(vy_module, node, err_list)
_analyze_function_r(vy_module, node, err_list)

for node in vy_module.get_children(vy_ast.VariableDecl):
if not node.is_public:
continue
_validate_function_r(vy_module, node._expanded_getter, err_list)
_analyze_function_r(vy_module, node._expanded_getter, err_list)

err_list.raise_if_not_empty()


def _validate_function_r(
def _analyze_function_r(
vy_module: vy_ast.Module, node: vy_ast.FunctionDef, err_list: ExceptionList
):
func_t = node._metadata["func_type"]

for call_t in func_t.called_functions:
if isinstance(call_t, ContractFunctionT):
assert isinstance(call_t.ast_def, vy_ast.FunctionDef) # help mypy
_validate_function_r(vy_module, call_t.ast_def, err_list)
_analyze_function_r(vy_module, call_t.ast_def, err_list)

namespace = get_namespace()

@@ -267,7 +268,14 @@ def check_module_uses(node: vy_ast.ExprNode) -> Optional[ModuleInfo]:

for module_info in module_infos:
if module_info.ownership < ModuleOwnership.USES:
msg = f"Cannot access `{module_info.alias}` state!"
msg = f"Cannot access `{module_info.alias}` state!\n note that"
# CMC 2024-04-12 add UX note about nonreentrant. might be nice
# in the future to be more specific about exactly which state is
# used, although that requires threading a bit more context into
# this function.
msg += " use of the `@nonreentrant` decorator is also considered"
msg += " state access"

hint = f"add `uses: {module_info.alias}` or "
hint += f"`initializes: {module_info.alias}` as "
hint += "a top-level statement to your contract"
@@ -443,10 +451,7 @@ def _handle_modification(self, target: vy_ast.ExprNode):

info._writes.add(var_access)

def _handle_module_access(self, var_access: VarAccess, target: vy_ast.ExprNode):
if not var_access.variable.is_state_variable():
return

def _handle_module_access(self, target: vy_ast.ExprNode):
root_module_info = check_module_uses(target)

if root_module_info is not None:
@@ -682,9 +687,9 @@ def visit(self, node, typ):
msg += f" `{var.decl_node.target.id}`"
raise ImmutableViolation(msg, var.decl_node, node)

variable_accesses = info._writes | info._reads
for s in variable_accesses:
self.function_analyzer._handle_module_access(s, node)
var_accesses = info._writes | info._reads
if uses_state(var_accesses):
self.function_analyzer._handle_module_access(node)

self.func.mark_variable_writes(info._writes)
self.func.mark_variable_reads(info._reads)
@@ -787,8 +792,8 @@ def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None:
if self.function_analyzer:
self._check_call_mutability(func_type.mutability)

for s in func_type.get_variable_accesses():
self.function_analyzer._handle_module_access(s, node.func)
if func_type.uses_state():
self.function_analyzer._handle_module_access(node.func)

if func_type.is_deploy and not self.func.is_deploy:
raise CallViolation(
17 changes: 8 additions & 9 deletions vyper/semantics/analysis/module.py
Original file line number Diff line number Diff line change
@@ -43,7 +43,7 @@
from vyper.semantics.analysis.constant_folding import constant_fold
from vyper.semantics.analysis.getters import generate_public_variable_getters
from vyper.semantics.analysis.import_graph import ImportGraph
from vyper.semantics.analysis.local import ExprVisitor, check_module_uses, validate_functions
from vyper.semantics.analysis.local import ExprVisitor, analyze_functions, check_module_uses
from vyper.semantics.analysis.utils import (
check_modifiability,
get_exact_type_from_node,
@@ -102,7 +102,7 @@ def _analyze_module_r(
# if this is an interface, the function is already validated
# in `ContractFunction.from_vyi()`
if not is_interface:
validate_functions(module_ast)
analyze_functions(module_ast)
analyzer.validate_initialized_modules()
analyzer.validate_used_modules()

@@ -557,14 +557,13 @@ def visit_ExportsDecl(self, node):
with tag_exceptions(item): # tag with specific item
self._self_t.typ.add_member(func_t.name, func_t)

funcs.append(func_t)
funcs.append(func_t)

# check module uses
var_accesses = func_t.get_variable_accesses()
if any(s.variable.is_state_variable() for s in var_accesses):
module_info = check_module_uses(item)
assert module_info is not None # guaranteed by above checks
used_modules.add(module_info)
# check module uses
if func_t.uses_state():
module_info = check_module_uses(item)
assert module_info is not None # guaranteed by above checks
used_modules.add(module_info)

node._metadata["exports_info"] = ExportsInfo(funcs, used_modules)

8 changes: 6 additions & 2 deletions vyper/semantics/analysis/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import itertools
from typing import Callable, List
from typing import Callable, Iterable, List

from vyper import ast as vy_ast
from vyper.exceptions import (
@@ -17,7 +17,7 @@
ZeroDivisionException,
)
from vyper.semantics import types
from vyper.semantics.analysis.base import ExprInfo, Modifiability, ModuleInfo, VarInfo
from vyper.semantics.analysis.base import ExprInfo, Modifiability, ModuleInfo, VarAccess, VarInfo
from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions
from vyper.semantics.namespace import get_namespace
from vyper.semantics.types.base import TYPE_T, VyperType
@@ -48,6 +48,10 @@ def _validate_op(node, types_list, validation_fn_name):
raise err_list[0]


def uses_state(var_accesses: Iterable[VarAccess]) -> bool:
return any(s.variable.is_state_variable() for s in var_accesses)


class _ExprAnalyser:
"""
Node type-checker class.
5 changes: 5 additions & 0 deletions vyper/semantics/types/function.py
Original file line number Diff line number Diff line change
@@ -27,6 +27,7 @@
from vyper.semantics.analysis.utils import (
check_modifiability,
get_exact_type_from_node,
uses_state,
validate_expected_type,
)
from vyper.semantics.data_locations import DataLocation
@@ -163,7 +164,11 @@ def get_variable_writes(self):
def get_variable_accesses(self):
return self._variable_reads | self._variable_writes

def uses_state(self):
return self.nonreentrant or uses_state(self.get_variable_accesses())

def get_used_modules(self):
# _used_modules is populated during analysis
return self._used_modules

def mark_used_module(self, module_info):