Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: typechecking for folding of literal ops #3201

Closed
wants to merge 15 commits into from
26 changes: 20 additions & 6 deletions tests/parser/syntax/test_minmax_value.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,36 @@
import pytest

from vyper.exceptions import InvalidType
from vyper.exceptions import InvalidType, TypeMismatch

fail_list = [
"""
(
"""
@external
def foo():
a: address = min_value(address)
""",
"""
InvalidType,
),
(
"""
@external
def foo():
a: address = max_value(address)
""",
InvalidType,
),
(
"""
@external
def foo():
a: int256 = min(-1, max_value(int256) + 1)
""",
TypeMismatch,
),
]


@pytest.mark.parametrize("bad_code", fail_list)
def test_block_fail(assert_compile_failed, get_contract_with_gas_estimation, bad_code):
@pytest.mark.parametrize("bad_code,exc", fail_list)
def test_block_fail(assert_compile_failed, get_contract_with_gas_estimation, bad_code, exc):

assert_compile_failed(lambda: get_contract_with_gas_estimation(bad_code), InvalidType)
assert_compile_failed(lambda: get_contract_with_gas_estimation(bad_code), exc)
97 changes: 96 additions & 1 deletion tests/parser/types/numbers/test_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest

from vyper.compiler import compile_code
from vyper.exceptions import InvalidType
from vyper.exceptions import InvalidType, OverflowException, TypeMismatch
from vyper.utils import MemoryPositions


Expand Down Expand Up @@ -123,6 +123,8 @@ def zoo() -> uint256:
def test_custom_constants(get_contract):
code = """
X_VALUE: constant(uint256) = 33
Y_VALUE: constant(uint256) = 34
Z_VALUE: constant(int256) = 35

@external
def test() -> uint256:
Expand All @@ -131,11 +133,26 @@ def test() -> uint256:
@external
def test_add(a: uint256) -> uint256:
return X_VALUE + a

@external
def test_add_constants() -> uint256:
return X_VALUE + Y_VALUE

@external
def test_compare_constants() -> bool:
return X_VALUE > Y_VALUE

@external
def test_unary_constant() -> int256:
return -Z_VALUE
"""
c = get_contract(code)

assert c.test() == 33
assert c.test_add(7) == 40
assert c.test_add_constants() == 67
assert c.test_compare_constants() is False
assert c.test_unary_constant() == -35


# Would be nice to put this somewhere accessible, like in vyper.types or something
Expand Down Expand Up @@ -240,3 +257,81 @@ def contains(a: int128) -> bool:
assert c.contains(44) is True
assert c.contains(33) is True
assert c.contains(3) is False


fail_list = [
(
"""
a: constant(uint16) = 200

@external
def foo() -> int16:
return a - 201
""",
OverflowException,
),
(
"""
a: constant(uint16) = 200
b: constant(int248) = 100

@external
def foo() -> int16:
return a - b
""",
TypeMismatch,
),
(
"""
a: constant(int8) = 25
b: constant(uint8) = 38

@external
def foo() -> bool:
return a < b
""",
TypeMismatch,
),
(
"""
a: constant(uint256) = 16

@external
def foo() -> int8:
return -a
""",
OverflowException,
),
(
"""
a: constant(uint256) = 3
b: constant(uint256) = 4

@external
def foo() -> int8:
return a + b
""",
InvalidType,
),
(
"""
a: constant(uint256) = 1
b: constant(uint16) = a + 0
""",
InvalidType,
),
(
"""
a: constant(uint256) = 1
b: constant(uint16) = a
""",
InvalidType,
),
]


@pytest.mark.parametrize("bad_code,exc", fail_list)
def test_invalid_constant_folds(
assert_compile_failed, get_contract_with_gas_estimation, bad_code, exc
):
assert_compile_failed(lambda: get_contract_with_gas_estimation(bad_code), exc)
31 changes: 30 additions & 1 deletion vyper/ast/folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from vyper.ast import nodes as vy_ast
from vyper.builtins.functions import DISPATCH_TABLE
from vyper.exceptions import UnfoldableNode, UnknownType
from vyper.exceptions import TypeMismatch, UnfoldableNode, UnknownType
from vyper.semantics.types.base import VyperType
from vyper.semantics.types.utils import type_from_annotation
from vyper.utils import SizeLimits
Expand Down Expand Up @@ -63,7 +63,36 @@ def replace_literal_ops(vyper_module: vy_ast.Module) -> int:
node_types = (vy_ast.BoolOp, vy_ast.BinOp, vy_ast.UnaryOp, vy_ast.Compare)
for node in vyper_module.get_descendants(node_types, reverse=True):
try:
typ = None
if isinstance(node, (vy_ast.BinOp, vy_ast.Compare)):
propagated_types = [
n._metadata["type"]
for n in (node.left, node.right) # type: ignore
if n._metadata.get("type") is not None
]
# if there are two propagated types, check for type mismatch
if len(propagated_types) == 2 and propagated_types[0] != propagated_types[1]:
raise TypeMismatch(
f"Unable to perform {node.op.description} on "
f"{propagated_types[0]} and {propagated_types[1]}",
node,
)
elif isinstance(node, vy_ast.BinOp) and len(propagated_types) >= 1:
# if there is one propagated type for vy_ast.BinOp, set
# folded node to that type for typechecking downstream
# this is not needed for vy_ast.Compare because it must be of BoolT
typ = propagated_types.pop()
elif isinstance(node, vy_ast.UnaryOp):
typ = node.operand._metadata.get("type") # type: ignore

# Propagate the type to the to-be-folded node to check bounds
# in validate_numeric_bounds
if typ is not None:
node._metadata["type"] = typ

new_node = node.evaluate()
if typ is not None:
new_node._metadata["type"] = typ
except UnfoldableNode:
continue

Expand Down
20 changes: 15 additions & 5 deletions vyper/ast/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,18 +194,27 @@ def _raise_syntax_exc(error_msg: str, ast_struct: dict) -> None:
def _validate_numeric_bounds(
node: Union["BinOp", "UnaryOp"], value: Union[decimal.Decimal, int]
) -> None:
typ = node._metadata.get("type")
if isinstance(value, decimal.Decimal):
# this will change if/when we add more decimal types
lower, upper = SizeLimits.MIN_AST_DECIMAL, SizeLimits.MAX_AST_DECIMAL
elif isinstance(value, int):
lower, upper = SizeLimits.MIN_INT256, SizeLimits.MAX_UINT256
if typ is not None:
lower, upper = typ.ast_bounds
else:
lower, upper = SizeLimits.MIN_INT256, SizeLimits.MAX_UINT256
else:
raise CompilerPanic(f"Unexpected return type from {node._op}: {type(value)}")
if not lower <= value <= upper:
raise OverflowException(
f"Result of {node.op.description} ({value}) is outside bounds of all numeric types",
node,
)
if typ is not None:
raise OverflowException(
f"Result of {node.op.description} ({value}) is outside bounds of {typ}", node
)
else:
raise OverflowException(
f"Result of {node.op.description} ({value}) is outside bounds of all numeric types",
node,
)


class VyperNode:
Expand Down Expand Up @@ -961,6 +970,7 @@ def evaluate(self) -> ExprNode:
raise UnfoldableNode("Node contains invalid field(s) for evaluation")

value = self.op._op(left.value, right.value)

_validate_numeric_bounds(self, value)
return type(left).from_node(self, value=value)

Expand Down