Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: prevent range over decimal #3798

Merged
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
InvalidType,
IteratorException,
NamespaceCollision,
StateAccessViolation,
StructureException,
SyntaxException,
TypeMismatch,
Expand Down Expand Up @@ -714,7 +713,7 @@ def foo():
for i: uint256 in range(a):
pass
""",
StateAccessViolation,
StructureException,
),
(
"""
Expand All @@ -724,7 +723,7 @@ def foo():
for i: int128 in range(a,a-3):
pass
""",
StateAccessViolation,
StructureException,
),
# invalid argument length
(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def foo() -> int128:
return 5
@external
def bar():
for i: int128 in range(self.foo(), self.foo() + 1):
for i: int128 in range(self.foo(), bound=100):
pass""",
"""
glob: int128
Expand All @@ -70,12 +70,6 @@ def bar():
for i: int128 in [1,2,3,4,self.foo()]:
pass""",
"""
@external
def foo():
x: int128 = 5
for i: int128 in range(x):
pass""",
"""
f:int128

@internal
Expand Down
55 changes: 37 additions & 18 deletions tests/functional/syntax/test_for_range.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,7 @@
import pytest

from vyper import compiler
from vyper.exceptions import (
ArgumentException,
StateAccessViolation,
StructureException,
TypeMismatch,
UnknownType,
)
from vyper.exceptions import ArgumentException, StructureException, TypeMismatch, UnknownType

fail_list = [
(
Expand Down Expand Up @@ -44,8 +38,8 @@ def foo():
for _: uint256 in range(10, bound=x):
pass
""",
StateAccessViolation,
"Bound must be a literal",
StructureException,
"Bound must be a literal integer",
None,
"x",
),
Expand Down Expand Up @@ -106,7 +100,7 @@ def bar():
for i: uint256 in range(x):
pass
""",
StateAccessViolation,
StructureException,
"Value must be a literal integer, unless a bound is specified",
None,
"x",
Expand All @@ -119,7 +113,7 @@ def bar():
for i: uint256 in range(0, x):
pass
""",
StateAccessViolation,
StructureException,
"Value must be a literal integer, unless a bound is specified",
None,
"x",
Expand All @@ -132,7 +126,7 @@ def repeat(n: uint256) -> uint256:
pass
return n
""",
StateAccessViolation,
StructureException,
"Value must be a literal integer, unless a bound is specified",
None,
"n * 10",
Expand All @@ -145,7 +139,7 @@ def bar():
for i: uint256 in range(0, x + 1):
pass
""",
StateAccessViolation,
StructureException,
"Value must be a literal integer, unless a bound is specified",
None,
"x + 1",
Expand All @@ -170,7 +164,7 @@ def bar():
for i: uint256 in range(x, x):
pass
""",
StateAccessViolation,
StructureException,
"Value must be a literal integer, unless a bound is specified",
None,
"x",
Expand All @@ -183,7 +177,7 @@ def foo():
for i: int128 in range(x, x + 10):
pass
""",
StateAccessViolation,
StructureException,
"Value must be a literal integer, unless a bound is specified",
None,
"x",
Expand All @@ -196,7 +190,7 @@ def repeat(n: uint256) -> uint256:
pass
return x
""",
StateAccessViolation,
StructureException,
"Value must be a literal integer, unless a bound is specified",
None,
"n",
Expand All @@ -209,7 +203,7 @@ def foo(x: int128):
for i: int128 in range(x, x + y):
pass
""",
StateAccessViolation,
StructureException,
"Value must be a literal integer, unless a bound is specified",
None,
"x",
Expand All @@ -221,7 +215,7 @@ def bar(x: uint256):
for i: uint256 in range(3, x):
pass
""",
StateAccessViolation,
StructureException,
"Value must be a literal integer, unless a bound is specified",
None,
"x",
Expand Down Expand Up @@ -303,6 +297,31 @@ def foo():
"Did you mean 'uint96', or maybe 'uint8'?",
"uint9",
),
(
"""
@external
def foo():
for i:decimal in range(1.1, 2.2):
pass
""",
StructureException,
"Value must be a literal integer, unless a bound is specified",
None,
"1.1",
),
(
"""
@external
def foo():
x:decimal = 1.1
for i:decimal in range(x, x + 2.0, bound=10.1):
pass
""",
StructureException,
"Bound must be a literal integer",
None,
"10.1",
),
]

for_code_regex = re.compile(r"for .+ in (.*):", re.DOTALL)
Expand Down
9 changes: 2 additions & 7 deletions tests/unit/semantics/analysis/test_for_loop.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
import pytest

from vyper.ast import parse_to_ast
from vyper.exceptions import (
ArgumentException,
ImmutableViolation,
StateAccessViolation,
TypeMismatch,
)
from vyper.exceptions import ArgumentException, ImmutableViolation, StructureException, TypeMismatch
from vyper.semantics.analysis import validate_semantics


Expand Down Expand Up @@ -88,7 +83,7 @@ def bar(n: uint256):
x += i
"""
vyper_module = parse_to_ast(code)
with pytest.raises(StateAccessViolation):
with pytest.raises(StructureException):
validate_semantics(vyper_module, dummy_input_bundle)


Expand Down
15 changes: 10 additions & 5 deletions vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
EventT,
FlagT,
HashMapT,
IntegerT,
SArrayT,
StringT,
StructT,
Expand Down Expand Up @@ -513,6 +514,9 @@ def visit_For(self, node):
iter_var = None
if isinstance(node.iter, vy_ast.Call):
self._analyse_range_iter(node.iter, target_type)

# sanity check the postcondition of analyse_range_iter
assert isinstance(target_type, IntegerT)
else:
iter_var = self._analyse_list_iter(node.iter, target_type)

Expand All @@ -522,6 +526,7 @@ def visit_For(self, node):
self.namespace[target_name] = VarInfo(
target_type, modifiability=Modifiability.RUNTIME_CONSTANT
)

self.expr_visitor.visit(node.target.target, target_type)

for stmt in node.body:
Expand Down Expand Up @@ -870,17 +875,17 @@ def _validate_range_call(node: vy_ast.Call):
bound = kwargs["bound"]
if bound.has_folded_value:
bound = bound.get_folded_value()
if not isinstance(bound, vy_ast.Num):
raise StateAccessViolation("Bound must be a literal", bound)
if not isinstance(bound, vy_ast.Int):
raise StructureException("Bound must be a literal integer", bound)
if bound.value <= 0:
raise StructureException("Bound must be at least 1", bound)
if isinstance(start, vy_ast.Num) and isinstance(end, vy_ast.Num):
if isinstance(start, vy_ast.Int) and isinstance(end, vy_ast.Int):
error = "Please remove the `bound=` kwarg when using range with constants"
raise StructureException(error, bound)
else:
for arg in (start, end):
if not isinstance(arg, vy_ast.Num):
if not isinstance(arg, vy_ast.Int):
error = "Value must be a literal integer, unless a bound is specified"
raise StateAccessViolation(error, arg)
raise StructureException(error, arg)
if end.value <= start.value:
raise StructureException("End must be greater than start", end)
Loading