Skip to content

Commit

Permalink
use validate_expected_type
Browse files Browse the repository at this point in the history
  • Loading branch information
trocher committed Feb 21, 2024
1 parent 0aa4fa6 commit 5ebcc59
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 30 deletions.
41 changes: 17 additions & 24 deletions tests/functional/syntax/test_for_range.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,7 @@
import pytest

from vyper import compiler
from vyper.exceptions import (
ArgumentException,
StateAccessViolation,
StructureException,
TypeCheckFailure,
TypeMismatch,
UnknownType,
)
from vyper.exceptions import ArgumentException, StructureException, TypeMismatch, UnknownType

fail_list = [
(
Expand Down Expand Up @@ -45,7 +38,7 @@ def foo():
for _: uint256 in range(10, bound=x):
pass
""",
StateAccessViolation,
TypeMismatch,
"Bound must be a literal integer",
None,
"x",
Expand Down Expand Up @@ -107,7 +100,7 @@ def bar():
for i: uint256 in range(x):
pass
""",
StateAccessViolation,
TypeMismatch,
"Value must be a literal integer, unless a bound is specified",
None,
"x",
Expand All @@ -120,7 +113,7 @@ def bar():
for i: uint256 in range(0, x):
pass
""",
StateAccessViolation,
TypeMismatch,
"Value must be a literal integer, unless a bound is specified",
None,
"x",
Expand All @@ -133,7 +126,7 @@ def repeat(n: uint256) -> uint256:
pass
return n
""",
StateAccessViolation,
TypeMismatch,
"Value must be a literal integer, unless a bound is specified",
None,
"n * 10",
Expand All @@ -146,7 +139,7 @@ def bar():
for i: uint256 in range(0, x + 1):
pass
""",
StateAccessViolation,
TypeMismatch,
"Value must be a literal integer, unless a bound is specified",
None,
"x + 1",
Expand All @@ -171,7 +164,7 @@ def bar():
for i: uint256 in range(x, x):
pass
""",
StateAccessViolation,
TypeMismatch,
"Value must be a literal integer, unless a bound is specified",
None,
"x",
Expand All @@ -184,7 +177,7 @@ def foo():
for i: int128 in range(x, x + 10):
pass
""",
StateAccessViolation,
TypeMismatch,
"Value must be a literal integer, unless a bound is specified",
None,
"x",
Expand All @@ -197,7 +190,7 @@ def repeat(n: uint256) -> uint256:
pass
return x
""",
StateAccessViolation,
TypeMismatch,
"Value must be a literal integer, unless a bound is specified",
None,
"n",
Expand All @@ -210,7 +203,7 @@ def foo(x: int128):
for i: int128 in range(x, x + y):
pass
""",
StateAccessViolation,
TypeMismatch,
"Value must be a literal integer, unless a bound is specified",
None,
"x",
Expand All @@ -222,7 +215,7 @@ def bar(x: uint256):
for i: uint256 in range(3, x):
pass
""",
StateAccessViolation,
TypeMismatch,
"Value must be a literal integer, unless a bound is specified",
None,
"x",
Expand Down Expand Up @@ -311,10 +304,10 @@ def foo():
for i:decimal in range(1.1, 2.2):
pass
""",
TypeCheckFailure,
"Range can only be defined over an integer type",
TypeMismatch,
"Value must be a literal integer, unless a bound is specified",
None,
"decimal",
"1.1",
),
(
"""
Expand All @@ -324,10 +317,10 @@ def foo():
for i:decimal in range(x, x + 2.0, bound=10.1):
pass
""",
TypeCheckFailure,
"Range can only be defined over an integer type",
TypeMismatch,
"Bound must be a literal integer",
None,
"decimal",
"10.1",
),
]

Expand Down
15 changes: 9 additions & 6 deletions vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,12 +512,10 @@ def visit_For(self, node):
target_type = type_from_annotation(node.target.annotation, DataLocation.MEMORY)

iter_var = None
is_range = False
if isinstance(node.iter, vy_ast.Call):
if not isinstance(target_type, IntegerT):
raise TypeCheckFailure(
"Range can only be defined over an integer type", node.target.annotation
)
self._analyse_range_iter(node.iter, target_type)
is_range = True
else:
iter_var = self._analyse_list_iter(node.iter, target_type)

Expand All @@ -527,6 +525,11 @@ def visit_For(self, node):
self.namespace[target_name] = VarInfo(
target_type, modifiability=Modifiability.RUNTIME_CONSTANT
)
# ideally should be performed before calling _analyse_range_iter
# but there is a dependence on the namespace update
if is_range:
validate_expected_type(node.target.target, IntegerT.any())

self.expr_visitor.visit(node.target.target, target_type)

for stmt in node.body:
Expand Down Expand Up @@ -876,7 +879,7 @@ def _validate_range_call(node: vy_ast.Call):
if bound.has_folded_value:
bound = bound.get_folded_value()
if not isinstance(bound, vy_ast.Int):
raise StateAccessViolation("Bound must be a literal integer", bound)
raise TypeMismatch("Bound must be a literal integer", bound)
if bound.value <= 0:
raise StructureException("Bound must be at least 1", bound)
if isinstance(start, vy_ast.Int) and isinstance(end, vy_ast.Int):
Expand All @@ -886,6 +889,6 @@ def _validate_range_call(node: vy_ast.Call):
for arg in (start, end):
if not isinstance(arg, vy_ast.Int):
error = "Value must be a literal integer, unless a bound is specified"
raise StateAccessViolation(error, arg)
raise TypeMismatch(error, arg)
if end.value <= start.value:
raise StructureException("End must be greater than start", end)

0 comments on commit 5ebcc59

Please sign in to comment.