diff --git a/tests/parser/syntax/test_minmax_value.py b/tests/parser/syntax/test_minmax_value.py index f71dc92e40..d81b0c698d 100644 --- a/tests/parser/syntax/test_minmax_value.py +++ b/tests/parser/syntax/test_minmax_value.py @@ -1,22 +1,36 @@ import pytest -from vyper.exceptions import InvalidType +from vyper.exceptions import InvalidType, TypeMismatch fail_list = [ - """ + ( + """ @external def foo(): a: address = min_value(address) """, - """ + InvalidType, + ), + ( + """ @external def foo(): a: address = max_value(address) """, + InvalidType, + ), + ( + """ +@external +def foo(): + a: int256 = min(-1, max_value(int256) + 1) + """, + TypeMismatch, + ), ] -@pytest.mark.parametrize("bad_code", fail_list) -def test_block_fail(assert_compile_failed, get_contract_with_gas_estimation, bad_code): +@pytest.mark.parametrize("bad_code,exc", fail_list) +def test_block_fail(assert_compile_failed, get_contract_with_gas_estimation, bad_code, exc): - assert_compile_failed(lambda: get_contract_with_gas_estimation(bad_code), InvalidType) + assert_compile_failed(lambda: get_contract_with_gas_estimation(bad_code), exc) diff --git a/tests/parser/types/numbers/test_constants.py b/tests/parser/types/numbers/test_constants.py index ea56968d14..5c713c56a9 100644 --- a/tests/parser/types/numbers/test_constants.py +++ b/tests/parser/types/numbers/test_constants.py @@ -4,7 +4,7 @@ import pytest from vyper.compiler import compile_code -from vyper.exceptions import InvalidType +from vyper.exceptions import InvalidType, OverflowException, TypeMismatch from vyper.utils import MemoryPositions @@ -123,6 +123,8 @@ def zoo() -> uint256: def test_custom_constants(get_contract): code = """ X_VALUE: constant(uint256) = 33 +Y_VALUE: constant(uint256) = 34 +Z_VALUE: constant(int256) = 35 @external def test() -> uint256: @@ -131,11 +133,26 @@ def test() -> uint256: @external def test_add(a: uint256) -> uint256: return X_VALUE + a + +@external +def test_add_constants() -> uint256: + return X_VALUE + Y_VALUE + +@external +def test_compare_constants() -> bool: + return X_VALUE > Y_VALUE + +@external +def test_unary_constant() -> int256: + return -Z_VALUE """ c = get_contract(code) assert c.test() == 33 assert c.test_add(7) == 40 + assert c.test_add_constants() == 67 + assert c.test_compare_constants() is False + assert c.test_unary_constant() == -35 # Would be nice to put this somewhere accessible, like in vyper.types or something @@ -240,3 +257,81 @@ def contains(a: int128) -> bool: assert c.contains(44) is True assert c.contains(33) is True assert c.contains(3) is False + + +fail_list = [ + ( + """ +a: constant(uint16) = 200 + +@external +def foo() -> int16: + return a - 201 + """, + OverflowException, + ), + ( + """ +a: constant(uint16) = 200 +b: constant(int248) = 100 + +@external +def foo() -> int16: + return a - b + """, + TypeMismatch, + ), + ( + """ +a: constant(int8) = 25 +b: constant(uint8) = 38 + +@external +def foo() -> bool: + return a < b + """, + TypeMismatch, + ), + ( + """ +a: constant(uint256) = 16 + +@external +def foo() -> int8: + return -a + """, + OverflowException, + ), + ( + """ +a: constant(uint256) = 3 +b: constant(uint256) = 4 + +@external +def foo() -> int8: + return a + b + """, + InvalidType, + ), + ( + """ +a: constant(uint256) = 1 +b: constant(uint16) = a + 0 + """, + InvalidType, + ), + ( + """ +a: constant(uint256) = 1 +b: constant(uint16) = a + """, + InvalidType, + ), +] + + +@pytest.mark.parametrize("bad_code,exc", fail_list) +def test_invalid_constant_folds( + assert_compile_failed, get_contract_with_gas_estimation, bad_code, exc +): + assert_compile_failed(lambda: get_contract_with_gas_estimation(bad_code), exc) diff --git a/vyper/ast/folding.py b/vyper/ast/folding.py index cd0fcc3c55..b4c82f8e28 100644 --- a/vyper/ast/folding.py +++ b/vyper/ast/folding.py @@ -3,7 +3,7 @@ from vyper.ast import nodes as vy_ast from vyper.builtins.functions import DISPATCH_TABLE -from vyper.exceptions import UnfoldableNode, UnknownType +from vyper.exceptions import TypeMismatch, UnfoldableNode, UnknownType from vyper.semantics.types.base import VyperType from vyper.semantics.types.utils import type_from_annotation from vyper.utils import SizeLimits @@ -63,7 +63,36 @@ def replace_literal_ops(vyper_module: vy_ast.Module) -> int: node_types = (vy_ast.BoolOp, vy_ast.BinOp, vy_ast.UnaryOp, vy_ast.Compare) for node in vyper_module.get_descendants(node_types, reverse=True): try: + typ = None + if isinstance(node, (vy_ast.BinOp, vy_ast.Compare)): + propagated_types = [ + n._metadata["type"] + for n in (node.left, node.right) # type: ignore + if n._metadata.get("type") is not None + ] + # if there are two propagated types, check for type mismatch + if len(propagated_types) == 2 and propagated_types[0] != propagated_types[1]: + raise TypeMismatch( + f"Unable to perform {node.op.description} on " + f"{propagated_types[0]} and {propagated_types[1]}", + node, + ) + elif isinstance(node, vy_ast.BinOp) and len(propagated_types) >= 1: + # if there is one propagated type for vy_ast.BinOp, set + # folded node to that type for typechecking downstream + # this is not needed for vy_ast.Compare because it must be of BoolT + typ = propagated_types.pop() + elif isinstance(node, vy_ast.UnaryOp): + typ = node.operand._metadata.get("type") # type: ignore + + # Propagate the type to the to-be-folded node to check bounds + # in validate_numeric_bounds + if typ is not None: + node._metadata["type"] = typ + new_node = node.evaluate() + if typ is not None: + new_node._metadata["type"] = typ except UnfoldableNode: continue diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index cc42acdf61..1b4303a53e 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -194,18 +194,27 @@ def _raise_syntax_exc(error_msg: str, ast_struct: dict) -> None: def _validate_numeric_bounds( node: Union["BinOp", "UnaryOp"], value: Union[decimal.Decimal, int] ) -> None: + typ = node._metadata.get("type") if isinstance(value, decimal.Decimal): # this will change if/when we add more decimal types lower, upper = SizeLimits.MIN_AST_DECIMAL, SizeLimits.MAX_AST_DECIMAL elif isinstance(value, int): - lower, upper = SizeLimits.MIN_INT256, SizeLimits.MAX_UINT256 + if typ is not None: + lower, upper = typ.ast_bounds + else: + lower, upper = SizeLimits.MIN_INT256, SizeLimits.MAX_UINT256 else: raise CompilerPanic(f"Unexpected return type from {node._op}: {type(value)}") if not lower <= value <= upper: - raise OverflowException( - f"Result of {node.op.description} ({value}) is outside bounds of all numeric types", - node, - ) + if typ is not None: + raise OverflowException( + f"Result of {node.op.description} ({value}) is outside bounds of {typ}", node + ) + else: + raise OverflowException( + f"Result of {node.op.description} ({value}) is outside bounds of all numeric types", + node, + ) class VyperNode: @@ -961,6 +970,7 @@ def evaluate(self) -> ExprNode: raise UnfoldableNode("Node contains invalid field(s) for evaluation") value = self.op._op(left.value, right.value) + _validate_numeric_bounds(self, value) return type(left).from_node(self, value=value)