Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Refactor assert_tx_failed into a context #3706

Merged
merged 17 commits into from
Dec 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/testing-contracts-ethtester.rst
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ To test events and failed transactions we expand our simple storage contract to

Next, we take a look at the two fixtures that will allow us to read the event logs and to check for failed transactions.

.. literalinclude:: ../tests/base_conftest.py
.. literalinclude:: ../tests/conftest.py
:language: python
:pyobject: assert_tx_failed
:pyobject: tx_failed

The fixture to assert failed transactions defaults to check for a ``TransactionFailed`` exception, but can be used to check for different exceptions too, as shown below. Also note that the chain gets reverted to the state before the failed transaction.

Expand Down
28 changes: 6 additions & 22 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import logging
from contextlib import contextmanager
from functools import wraps

import hypothesis
Expand Down Expand Up @@ -411,23 +412,6 @@ def assert_compile_failed(function_to_test, exception=Exception):
return assert_compile_failed


# TODO this should not be a fixture
@pytest.fixture
def search_for_sublist():
def search_for_sublist(ir, sublist):
_list = ir.to_list() if hasattr(ir, "to_list") else ir
if _list == sublist:
return True
if isinstance(_list, list):
for i in _list:
ret = search_for_sublist(i, sublist)
if ret is True:
return ret
return False

return search_for_sublist


@pytest.fixture
def create2_address_of(keccak):
def _f(_addr, _salt, _initcode):
Expand Down Expand Up @@ -484,16 +468,16 @@ def get_logs(tx_hash, c, event_name):
return get_logs


# TODO replace me with function like `with anchor_state()`
@pytest.fixture(scope="module")
def assert_tx_failed(tester):
def assert_tx_failed(function_to_test, exception=TransactionFailed, exc_text=None):
def tx_failed(tester):
@contextmanager
def fn(exception=TransactionFailed, exc_text=None):
snapshot_id = tester.take_snapshot()
with pytest.raises(exception) as excinfo:
function_to_test()
yield excinfo
tester.revert_to_snapshot(snapshot_id)
if exc_text:
# TODO test equality
assert exc_text in str(excinfo.value), (exc_text, excinfo.value)

return assert_tx_failed
return fn
25 changes: 15 additions & 10 deletions tests/functional/builtins/codegen/test_abi_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def abi_decode(x: Bytes[32]) -> uint256:
b"\x01" * 96, # Length of byte array is beyond size bound of output type
],
)
def test_clamper(get_contract, assert_tx_failed, input_):
def test_clamper(get_contract, tx_failed, input_):
contract = """
@external
def abi_decode(x: Bytes[96]) -> (uint256, uint256):
Expand All @@ -341,10 +341,11 @@ def abi_decode(x: Bytes[96]) -> (uint256, uint256):
return a, b
"""
c = get_contract(contract)
assert_tx_failed(lambda: c.abi_decode(input_))
with tx_failed():
c.abi_decode(input_)


def test_clamper_nested_uint8(get_contract, assert_tx_failed):
def test_clamper_nested_uint8(get_contract, tx_failed):
# check that _abi_decode clamps on word-types even when it is in a nested expression
# decode -> validate uint8 -> revert if input >= 256 -> cast back to uint256
contract = """
Expand All @@ -355,10 +356,11 @@ def abi_decode(x: uint256) -> uint256:
"""
c = get_contract(contract)
assert c.abi_decode(255) == 255
assert_tx_failed(lambda: c.abi_decode(256))
with tx_failed():
c.abi_decode(256)


def test_clamper_nested_bytes(get_contract, assert_tx_failed):
def test_clamper_nested_bytes(get_contract, tx_failed):
# check that _abi_decode clamps dynamic even when it is in a nested expression
# decode -> validate Bytes[20] -> revert if len(input) > 20 -> convert back to -> add 1
contract = """
Expand All @@ -369,7 +371,8 @@ def abi_decode(x: Bytes[96]) -> Bytes[21]:
"""
c = get_contract(contract)
assert c.abi_decode(abi.encode("(bytes)", (b"bc",))) == b"abc"
assert_tx_failed(lambda: c.abi_decode(abi.encode("(bytes)", (b"a" * 22,))))
with tx_failed():
c.abi_decode(abi.encode("(bytes)", (b"a" * 22,)))


@pytest.mark.parametrize(
Expand All @@ -381,7 +384,7 @@ def abi_decode(x: Bytes[96]) -> Bytes[21]:
("Bytes[5]", b"\x01" * 192),
],
)
def test_clamper_dynamic(get_contract, assert_tx_failed, output_typ, input_):
def test_clamper_dynamic(get_contract, tx_failed, output_typ, input_):
contract = f"""
@external
def abi_decode(x: Bytes[192]) -> {output_typ}:
Expand All @@ -390,7 +393,8 @@ def abi_decode(x: Bytes[192]) -> {output_typ}:
return a
"""
c = get_contract(contract)
assert_tx_failed(lambda: c.abi_decode(input_))
with tx_failed():
c.abi_decode(input_)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -422,7 +426,7 @@ def abi_decode(x: Bytes[160]) -> uint256:
("Bytes[5]", "address", b"\x01" * 128),
],
)
def test_clamper_dynamic_tuple(get_contract, assert_tx_failed, output_typ1, output_typ2, input_):
def test_clamper_dynamic_tuple(get_contract, tx_failed, output_typ1, output_typ2, input_):
contract = f"""
@external
def abi_decode(x: Bytes[224]) -> ({output_typ1}, {output_typ2}):
Expand All @@ -432,7 +436,8 @@ def abi_decode(x: Bytes[224]) -> ({output_typ1}, {output_typ2}):
return a, b
"""
c = get_contract(contract)
assert_tx_failed(lambda: c.abi_decode(input_))
with tx_failed():
c.abi_decode(input_)


FAIL_LIST = [
Expand Down
5 changes: 3 additions & 2 deletions tests/functional/builtins/codegen/test_addmod.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
def test_uint256_addmod(assert_tx_failed, get_contract_with_gas_estimation):
def test_uint256_addmod(tx_failed, get_contract_with_gas_estimation):
uint256_code = """
@external
def _uint256_addmod(x: uint256, y: uint256, z: uint256) -> uint256:
Expand All @@ -11,7 +11,8 @@ def _uint256_addmod(x: uint256, y: uint256, z: uint256) -> uint256:
assert c._uint256_addmod(32, 2, 32) == 2
assert c._uint256_addmod((2**256) - 1, 0, 2) == 1
assert c._uint256_addmod(2**255, 2**255, 6) == 4
assert_tx_failed(lambda: c._uint256_addmod(1, 2, 0))
with tx_failed():
c._uint256_addmod(1, 2, 0)


def test_uint256_addmod_ext_call(
Expand Down
16 changes: 9 additions & 7 deletions tests/functional/builtins/codegen/test_as_wei_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@


@pytest.mark.parametrize("denom,multiplier", wei_denoms.items())
def test_wei_uint256(get_contract, assert_tx_failed, denom, multiplier):
def test_wei_uint256(get_contract, tx_failed, denom, multiplier):
code = f"""
@external
def foo(a: uint256) -> uint256:
Expand All @@ -36,11 +36,12 @@ def foo(a: uint256) -> uint256:
assert c.foo(value) == value * (10**multiplier)

value = (2**256 - 1) // (10 ** (multiplier - 1))
assert_tx_failed(lambda: c.foo(value))
with tx_failed():
c.foo(value)


@pytest.mark.parametrize("denom,multiplier", wei_denoms.items())
def test_wei_int128(get_contract, assert_tx_failed, denom, multiplier):
def test_wei_int128(get_contract, tx_failed, denom, multiplier):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unused fixture

code = f"""
@external
def foo(a: int128) -> uint256:
Expand All @@ -54,7 +55,7 @@ def foo(a: int128) -> uint256:


@pytest.mark.parametrize("denom,multiplier", wei_denoms.items())
def test_wei_decimal(get_contract, assert_tx_failed, denom, multiplier):
def test_wei_decimal(get_contract, tx_failed, denom, multiplier):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unused fixture

code = f"""
@external
def foo(a: decimal) -> uint256:
Expand All @@ -69,20 +70,21 @@ def foo(a: decimal) -> uint256:

@pytest.mark.parametrize("value", (-1, -(2**127)))
@pytest.mark.parametrize("data_type", ["decimal", "int128"])
def test_negative_value_reverts(get_contract, assert_tx_failed, value, data_type):
def test_negative_value_reverts(get_contract, tx_failed, value, data_type):
code = f"""
@external
def foo(a: {data_type}) -> uint256:
return as_wei_value(a, "ether")
"""

c = get_contract(code)
assert_tx_failed(lambda: c.foo(value))
with tx_failed():
c.foo(value)


@pytest.mark.parametrize("denom,multiplier", wei_denoms.items())
@pytest.mark.parametrize("data_type", ["decimal", "int128", "uint256"])
def test_zero_value(get_contract, assert_tx_failed, denom, multiplier, data_type):
def test_zero_value(get_contract, tx_failed, denom, multiplier, data_type):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unused fixture

code = f"""
@external
def foo(a: {data_type}) -> uint256:
Expand Down
13 changes: 8 additions & 5 deletions tests/functional/builtins/codegen/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ def bar(a: uint256) -> Roles:
@pytest.mark.parametrize("typ", ["uint8", "int128", "int256", "uint256"])
@pytest.mark.parametrize("val", [1, 2, 3, 4, 2**128, 2**256 - 1, 2**256 - 2])
def test_flag_conversion_2(
get_contract_with_gas_estimation, assert_compile_failed, assert_tx_failed, val, typ
get_contract_with_gas_estimation, assert_compile_failed, tx_failed, val, typ
):
contract = f"""
flag Status:
Expand All @@ -529,7 +529,8 @@ def foo(a: {typ}) -> Status:
if lo <= val <= hi:
assert c.foo(val) == val
else:
assert_tx_failed(lambda: c.foo(val))
with tx_failed():
c.foo(val)
else:
assert_compile_failed(lambda: get_contract_with_gas_estimation(contract), TypeMismatch)

Expand Down Expand Up @@ -608,7 +609,7 @@ def foo() -> {t_bytes}:
@pytest.mark.parametrize("i_typ,o_typ,val", generate_reverting_cases())
@pytest.mark.fuzzing
def test_conversion_failures(
get_contract_with_gas_estimation, assert_compile_failed, assert_tx_failed, i_typ, o_typ, val
get_contract_with_gas_estimation, assert_compile_failed, tx_failed, i_typ, o_typ, val
):
"""
Test multiple contracts and check for a specific exception.
Expand Down Expand Up @@ -650,7 +651,8 @@ def foo():
"""

c2 = get_contract_with_gas_estimation(contract_2)
assert_tx_failed(lambda: c2.foo())
with tx_failed():
c2.foo()

contract_3 = f"""
@external
Expand All @@ -659,4 +661,5 @@ def foo(bar: {i_typ}) -> {o_typ}:
"""

c3 = get_contract_with_gas_estimation(contract_3)
assert_tx_failed(lambda: c3.foo(val))
with tx_failed():
c3.foo(val)
Loading
Loading