Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: iterator modification analysis #3764

Merged
merged 27 commits into from
Feb 13, 2024
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
bb6df4f
fix iterator modification analysis
charles-cooper Feb 8, 2024
aa4c582
impose topsort on function analysis
charles-cooper Feb 8, 2024
9454687
refactor and clean up FunctionAnalyzer.visit_For
charles-cooper Feb 9, 2024
b997fe0
add a comment
charles-cooper Feb 9, 2024
a057f39
fix missing enter_scope()
charles-cooper Feb 9, 2024
d8fa41a
fix typechecker for darray, sarray
charles-cooper Feb 9, 2024
c1d693d
add tests for repros from issue
charles-cooper Feb 9, 2024
0021efa
add tests for iterators imported from modules
charles-cooper Feb 9, 2024
6005532
add test for topsort analysis
charles-cooper Feb 9, 2024
f3f683c
fix topsort for function calls
charles-cooper Feb 9, 2024
dc0908c
fix type comparison for SelfT
charles-cooper Feb 9, 2024
4ffd3aa
fix topsort (again!)
charles-cooper Feb 9, 2024
7729ab7
refactor: improve the API for ContractFunctionT, protect some private…
charles-cooper Feb 11, 2024
40ed362
Merge branch 'master' into fix/iterator_analysis
charles-cooper Feb 11, 2024
049f6e8
remove protect_analysed
charles-cooper Feb 11, 2024
d8353ae
fix: struct touching
charles-cooper Feb 11, 2024
35a9bca
yeet VarAttributeInfo
charles-cooper Feb 11, 2024
9c2af79
fix mypy
charles-cooper Feb 12, 2024
3d32b76
fix more complicated case
charles-cooper Feb 12, 2024
0a5376c
fix bugs
charles-cooper Feb 12, 2024
5e90066
fix mypy
charles-cooper Feb 12, 2024
b647e97
add more tests
charles-cooper Feb 12, 2024
111333e
remove a comment
charles-cooper Feb 12, 2024
274de10
refactor get_variable_access
charles-cooper Feb 12, 2024
b4e7390
remove attribute_chain, use explicit traversal of the attribute/subsc…
charles-cooper Feb 12, 2024
68faed2
fix: while -> do while
charles-cooper Feb 12, 2024
591b97b
add some subscript/attribute tests for module uses
charles-cooper Feb 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 55 additions & 1 deletion tests/functional/codegen/features/iteration/test_for_in_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import pytest

from vyper.compiler import compile_code
from vyper.exceptions import (
ArgumentException,
ImmutableViolation,
Expand Down Expand Up @@ -841,6 +842,59 @@ def foo():
]


# TODO: move these to tests/functional/syntax
@pytest.mark.parametrize("code,err", BAD_CODE, ids=bad_code_names)
def test_bad_code(assert_compile_failed, get_contract, code, err):
assert_compile_failed(lambda: get_contract(code), err)
with pytest.raises(err):
compile_code(code)


def test_iterator_modification_module_attribute(make_input_bundle):
# test modifying iterator via attribute
lib1 = """
queue: DynArray[uint256, 5]
"""
main = """
import lib1

initializes: lib1

@external
def foo():
for i: uint256 in lib1.queue:
lib1.queue.pop()
"""

input_bundle = make_input_bundle({"lib1.vy": lib1})

with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)

assert e.value._message == "Cannot modify loop variable `queue`"


def test_iterator_modification_module_function_call(make_input_bundle):
lib1 = """
queue: DynArray[uint256, 5]

@internal
def popqueue():
self.queue.pop()
"""
main = """
import lib1

initializes: lib1

@external
def foo():
for i: uint256 in lib1.queue:
lib1.popqueue()
"""

input_bundle = make_input_bundle({"lib1.vy": lib1})

with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)

assert e.value._message == "Cannot modify loop variable `queue`"
105 changes: 105 additions & 0 deletions tests/unit/semantics/analysis/test_for_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,111 @@ def baz():
validate_semantics(vyper_module, dummy_input_bundle)


def test_modify_iterator_recursive_function_call_topsort(dummy_input_bundle):
# test the analysis works no matter the order of functions
code = """
a: uint256[3]

@internal
def baz():
for i: uint256 in self.a:
self.bar()

@internal
def bar():
self.foo()

@internal
def foo():
self.a[0] = 1
"""
vyper_module = parse_to_ast(code)
with pytest.raises(ImmutableViolation) as e:
validate_semantics(vyper_module, dummy_input_bundle)

assert e.value._message == "Cannot modify loop variable `a`"


def test_modify_iterator_through_struct(dummy_input_bundle):
# GH issue 3429
code = """
struct A:
iter: DynArray[uint256, 5]

a: A

@external
def foo():
self.a.iter = [1, 2, 3]
for i: uint256 in self.a.iter:
self.a = A({iter: [1, 2, 3, 4]})
"""
vyper_module = parse_to_ast(code)
with pytest.raises(ImmutableViolation) as e:
validate_semantics(vyper_module, dummy_input_bundle)

assert e.value._message == "Cannot modify loop variable `a`"


def test_modify_iterator_complex_expr(dummy_input_bundle):
# GH issue 3429
# avoid false positive!
code = """
a: DynArray[uint256, 5]
b: uint256[10]

@external
def foo():
self.a = [1, 2, 3]
for i: uint256 in self.a:
self.b[self.a[1]] = i
"""
vyper_module = parse_to_ast(code)
validate_semantics(vyper_module, dummy_input_bundle)


def test_modify_iterator_siblings(dummy_input_bundle):
# test we can modify siblings in an access tree
code = """
struct Foo:
a: uint256[2]
b: uint256

f: Foo

@external
def foo():
for i: uint256 in self.f.a:
self.f.b += i
"""
vyper_module = parse_to_ast(code)
validate_semantics(vyper_module, dummy_input_bundle)


def test_modify_subscript_barrier(dummy_input_bundle):
# test that Subscript nodes are a barrier for analysis
code = """
struct Foo:
x: uint256[2]
y: uint256

struct Bar:
f: Foo[2]

b: Bar

@external
def foo():
for i: uint256 in self.b.f[1].x:
self.b.f[0].y += i
"""
vyper_module = parse_to_ast(code)
with pytest.raises(ImmutableViolation) as e:
validate_semantics(vyper_module, dummy_input_bundle)

assert e.value._message == "Cannot modify loop variable `b`"


iterator_inference_codes = [
"""
@external
Expand Down
8 changes: 4 additions & 4 deletions vyper/ast/nodes.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -200,13 +200,13 @@ class Call(ExprNode):

class keyword(VyperNode): ...

class Attribute(VyperNode):
class Attribute(ExprNode):
attr: str = ...
value: ExprNode = ...

class Subscript(VyperNode):
slice: VyperNode = ...
value: VyperNode = ...
class Subscript(ExprNode):
slice: ExprNode = ...
value: ExprNode = ...

class Assign(VyperNode): ...

Expand Down
58 changes: 30 additions & 28 deletions vyper/codegen/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,24 +263,6 @@ def parse_Attribute(self):
if addr.value == "address": # for `self.code`
return IRnode.from_list(["~selfcode"], typ=BytesT(0))
return IRnode.from_list(["~extcode", addr], typ=BytesT(0))
# self.x: global attribute
elif (varinfo := self.expr._expr_info.var_info) is not None:
if varinfo.is_constant:
return Expr.parse_value_expr(varinfo.decl_node.value, self.context)

location = data_location_to_address_space(
varinfo.location, self.context.is_ctor_context
)

ret = IRnode.from_list(
varinfo.position.position,
typ=varinfo.typ,
location=location,
annotation="self." + self.expr.attr,
)
ret._referenced_variables = {varinfo}

return ret

# Reserved keywords
elif (
Expand Down Expand Up @@ -336,17 +318,37 @@ def parse_Attribute(self):
"chain.id is unavailable prior to istanbul ruleset", self.expr
)
return IRnode.from_list(["chainid"], typ=UINT256_T)

# Other variables
else:
sub = Expr(self.expr.value, self.context).ir_node
# contract type
if isinstance(sub.typ, InterfaceT):
# MyInterface.address
assert self.expr.attr == "address"
sub.typ = typ
return sub
if isinstance(sub.typ, StructT) and self.expr.attr in sub.typ.member_types:
return get_element_ptr(sub, self.expr.attr)

# self.x: global attribute
if (varinfo := self.expr._expr_info.var_info) is not None:
if varinfo.is_constant:
return Expr.parse_value_expr(varinfo.decl_node.value, self.context)

location = data_location_to_address_space(
varinfo.location, self.context.is_ctor_context
)

ret = IRnode.from_list(
varinfo.position.position,
typ=varinfo.typ,
location=location,
annotation="self." + self.expr.attr,
)
ret._referenced_variables = {varinfo}

return ret

sub = Expr(self.expr.value, self.context).ir_node
# contract type
if isinstance(sub.typ, InterfaceT):
# MyInterface.address
assert self.expr.attr == "address"
sub.typ = typ
return sub
if isinstance(sub.typ, StructT) and self.expr.attr in sub.typ.member_types:
return get_element_ptr(sub, self.expr.attr)

def parse_Subscript(self):
sub = Expr(self.expr.value, self.context).ir_node
Expand Down
60 changes: 40 additions & 20 deletions vyper/semantics/analysis/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from vyper.exceptions import CompilerPanic, StructureException
from vyper.semantics.data_locations import DataLocation
from vyper.semantics.types.base import VyperType
from vyper.semantics.types.primitives import SelfT
Fixed Show fixed Hide fixed
from vyper.utils import OrderedSet, StringEnum

if TYPE_CHECKING:
Expand Down Expand Up @@ -193,6 +194,17 @@
return res


@dataclass(frozen=True)
class VarAccess:
variable: VarInfo
attrs: tuple[str, ...]

def contains(self, other):
# VarAccess("v", ("a")) `contains` VarAccess("v", ("a", "b", "c"))
sub_attrs = other.attrs[: len(self.attrs)]
return self.variable == other.variable and sub_attrs == self.attrs


@dataclass
class ExprInfo:
"""
Expand All @@ -204,9 +216,8 @@
module_info: Optional[ModuleInfo] = None
location: DataLocation = DataLocation.UNSET
modifiability: Modifiability = Modifiability.MODIFIABLE

# the chain of attribute parents for this expr
attribute_chain: list["ExprInfo"] = field(default_factory=list)
attr: Optional[str] = None

def __post_init__(self):
should_match = ("typ", "location", "modifiability")
Expand All @@ -215,48 +226,57 @@
if getattr(self.var_info, attr) != getattr(self, attr):
raise CompilerPanic("Bad analysis: non-matching {attr}: {self}")

self._writes: OrderedSet[VarInfo] = OrderedSet()
self._reads: OrderedSet[VarInfo] = OrderedSet()
self.attribute_chain = self.attribute_chain or []

self._writes: OrderedSet[VarAccess] = OrderedSet()
self._reads: OrderedSet[VarAccess] = OrderedSet()

# find exprinfo in the attribute chain which has a varinfo
# e.x. `x` will return varinfo for `x`
# `module.foo` will return varinfo for `module.foo`
# `self.my_struct.x.y` will return varinfo for `self.my_struct`
def get_root_varinfo(self) -> Optional[VarInfo]:
for expr_info in self.attribute_chain + [self]:
if expr_info.var_info is not None:
return expr_info.var_info
# `self.my_struct.x.y` will return varinfo for `self.my_struct.x.y`
def get_variable_access(self) -> Optional[VarAccess]:
chain = self.attribute_chain + [self]
for i, expr_info in enumerate(chain):
varinfo = expr_info.var_info
if varinfo is None or isinstance(varinfo.typ, SelfT):
continue

attrs = []
for expr_info in chain[i:]:
Fixed Show fixed Hide fixed
if expr_info.attr is None:
continue
attrs.append(expr_info.attr)
return VarAccess(varinfo, tuple(attrs))

return None

@classmethod
def from_varinfo(cls, var_info: VarInfo, attribute_chain=None) -> "ExprInfo":
def from_varinfo(cls, var_info: VarInfo, **kwargs) -> "ExprInfo":
return cls(
var_info.typ,
var_info=var_info,
location=var_info.location,
modifiability=var_info.modifiability,
attribute_chain=attribute_chain or [],
**kwargs,
)

@classmethod
def from_moduleinfo(cls, module_info: ModuleInfo, attribute_chain=None) -> "ExprInfo":
def from_moduleinfo(cls, module_info: ModuleInfo, **kwargs) -> "ExprInfo":
modifiability = Modifiability.RUNTIME_CONSTANT
if module_info.ownership >= ModuleOwnership.USES:
modifiability = Modifiability.MODIFIABLE

return cls(
module_info.module_t,
module_info=module_info,
modifiability=modifiability,
attribute_chain=attribute_chain or [],
module_info.module_t, module_info=module_info, modifiability=modifiability, **kwargs
)

def copy_with_type(self, typ: VyperType, attribute_chain=None) -> "ExprInfo":
def copy_with_type(self, typ: VyperType, **kwargs) -> "ExprInfo":
"""
Return a copy of the ExprInfo but with the type set to something else
"""
to_copy = ("location", "modifiability")
fields = {k: getattr(self, k) for k in to_copy}
if attribute_chain is not None:
fields["attribute_chain"] = attribute_chain
return self.__class__(typ=typ, **fields)
for t in to_copy:
assert t not in kwargs
return self.__class__(typ=typ, **fields, **kwargs)
Loading
Loading