Skip to content

Commit

Permalink
Merge pull request #123 from HyperLink-Technology/coverage-memory
Browse files Browse the repository at this point in the history
Final commits for v1.0.0b5
  • Loading branch information
iamdefinitelyahuman authored May 13, 2019
2 parents cb6150f + a2467e3 commit 96fb1cb
Show file tree
Hide file tree
Showing 11 changed files with 115 additions and 65 deletions.
73 changes: 40 additions & 33 deletions brownie/cli/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
from brownie.cli.utils import color
from brownie.test.coverage import (
analyze_coverage,
merge_coverage,
merge_coverage_eval,
merge_coverage_files,
generate_report
)
from brownie.exceptions import ExpectedFailing
Expand All @@ -33,6 +34,8 @@
WARN = "{0[error]}WARNING{0}: ".format(color)
ERROR = "{0[error]}ERROR{0}: ".format(color)

history = TxHistory()

__doc__ = """Usage: brownie test [<filename>] [<range>] [options]
Arguments:
Expand All @@ -55,9 +58,10 @@
def main():
args = docopt(__doc__)
ARGV._update_from_args(args)
# if type(CONFIG['test']['gas_limit']) is int:
# network.gas_limit(CONFIG['test']['gas_limit'])
if ARGV['coverage']:
ARGV['always_transact'] = True
history = TxHistory()
history._revert_lock = True

test_paths = get_test_paths(args['<filename>'])
Expand Down Expand Up @@ -109,8 +113,11 @@ def get_test_data(test_paths):
coverage_files = []
test_data = []
project = Path(CONFIG['folders']['project'])
build_path = project.joinpath('build/coverage')
for path in test_paths:
coverage_json = project.joinpath("build/coverage/"+path.stem+".json")
path = path.relative_to(project)
coverage_json = build_path.joinpath(path.parent.relative_to('tests'))
coverage_json = coverage_json.joinpath(path.stem+".json")
coverage_files.append(coverage_json)
if coverage_json.exists():
coverage_eval = json.load(coverage_json.open())['coverage']
Expand All @@ -120,7 +127,7 @@ def get_test_data(test_paths):
coverage_eval = {}
for p in list(coverage_json.parents)[::-1]:
p.mkdir(exist_ok=True)
module_name = str(path.relative_to(project))[:-3].replace(os.sep, '.')
module_name = str(path)[:-3].replace(os.sep, '.')
module = importlib.import_module(module_name)
test_names = re.findall(r'\ndef[\s ]{1,}([^_]\w*)[\s ]*\([^)]*\)', path.open().read())
if not test_names:
Expand All @@ -129,7 +136,7 @@ def get_test_data(test_paths):
duplicates = set([i for i in test_names if test_names.count(i) > 1])
if duplicates:
raise ValueError("{} contains multiple test methods of the same name: {}".format(
path.relative_to(project),
path,
", ".join(duplicates)
))
if 'setup' in test_names:
Expand All @@ -145,6 +152,10 @@ def run_test_modules(test_data, save):
count = sum([len([x for x in i[3] if x != "setup"]) for i in test_data])
print("Running {} tests across {} modules.".format(count, len(test_data)))
network.connect(ARGV['network'])
for key in ('broadcast_reverting_tx', 'gas_limit'):
CONFIG['active_network'][key] = CONFIG['test'][key]
if not CONFIG['active_network']['broadcast_reverting_tx']:
print("{0[error]}WARNING{0}: Reverting transactions will NOT be broadcasted.".format(color))
traceback_info = []
start_time = time.time()
try:
Expand Down Expand Up @@ -178,7 +189,7 @@ def run_test_modules(test_data, save):
print("\n\nTest execution has been terminated by KeyboardInterrupt.")
sys.exit()
finally:
print("\nTotal runtime: {:.4}s".format(time.time() - start_time))
print("\nTotal runtime: {:.4f}s".format(time.time() - start_time))
if traceback_info:
print("{0}{1} test{2} failed.".format(
WARN,
Expand All @@ -193,10 +204,7 @@ def run_test_modules(test_data, save):

def run_test(module, network, test_names):
network.rpc.reset()
if type(CONFIG['test']['gas_limit']) is int:
network.gas_limit(CONFIG['test']['gas_limit'])

traceback_info = []
if 'setup' in test_names:
test_names.remove('setup')
fn, default_args = _get_fn(module, 'setup')
Expand All @@ -206,70 +214,69 @@ def run_test(module, network, test_names):
):
return [], {}
p = TestPrinter(module.__file__, 0, len(test_names))
traceback_info += run_test_method(fn, default_args, p)
if traceback_info:
return traceback_info, {}
tb, coverage_eval = run_test_method(fn, default_args, {}, p)
if tb:
return tb, {}
else:
p = TestPrinter(module.__file__, 1, len(test_names))
default_args = FalseyDict()
coverage_eval = {}
network.rpc.snapshot()
traceback_info = []
for t in test_names:
network.rpc.revert()
fn, fn_args = _get_fn(module, t)
args = default_args.copy()
args.update(fn_args)
traceback_info += run_test_method(fn, args, p)
if traceback_info and traceback_info[-1][2] == ReadTimeout:
tb, coverage_eval = run_test_method(fn, args, coverage_eval, p)
traceback_info += tb
if tb and tb[0][2] == ReadTimeout:
print(WARN+"RPC crashed, terminating test")
network.rpc.kill(False)
network.rpc.launch(CONFIG['active_network']['test-rpc'])
break
coverage_eval = {}
if not traceback_info and ARGV['coverage']:
p.start("Evaluating test coverage")
coverage_eval = analyze_coverage(TxHistory().copy())
p.stop()
if traceback_info and ARGV['coverage']:
coverage_eval = {}
p.finish()
return traceback_info, coverage_eval


def run_test_method(fn, args, p):
def run_test_method(fn, args, coverage_eval, p):
desc = fn.__doc__ or fn.__name__
if args['skip'] is True or (args['skip'] == "coverage" and ARGV['coverage']):
p.skip(desc)
return []
return [], coverage_eval
p.start(desc)
try:
if ARGV['coverage'] and 'always_transact' in args:
ARGV['always_transact'] = args['always_transact']
fn()
if ARGV['coverage']:
ARGV['always_transact'] = True
coverage_eval = merge_coverage_eval(
coverage_eval,
analyze_coverage(history.copy())
)
history.clear()
if args['pending']:
raise ExpectedFailing("Test was expected to fail")
p.stop()
return []
return [], coverage_eval
except Exception as e:
p.stop(e, args['pending'])
if type(e) != ExpectedFailing and args['pending']:
return []
return [], coverage_eval
path = Path(sys.modules[fn.__module__].__file__).relative_to(CONFIG['folders']['project'])
path = "{0[module]}{1}.{0[callable]}{2}{0}".format(color, str(path)[:-3], fn.__name__)
return [(
path,
color.format_tb(
sys.exc_info(),
sys.modules[fn.__module__].__file__,
),
type(e)
)]
tb = color.format_tb(sys.exc_info(), sys.modules[fn.__module__].__file__)
return [(path, tb, type(e))], coverage_eval


def display_report(coverage_files, save):
coverage_eval = merge_coverage(coverage_files)
coverage_eval = merge_coverage_files(coverage_files)
report = generate_report(coverage_eval)
print("\nCoverage analysis:")
for name in coverage_eval:
for name in sorted(coverage_eval):
pct = coverage_eval[name].pop('pct')
c = color(next(i[1] for i in COVERAGE_COLORS if pct <= i[0]))
print("\n contract: {0[contract]}{1}{0} - {2}{3:.1%}{0}".format(color, name, c, pct))
Expand Down
3 changes: 2 additions & 1 deletion brownie/data/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
},
"test":{
"gas_limit": 6721975,
"default_contract_owner": false
"default_contract_owner": false,
"broadcast_reverting_tx": true
},
"solc":{
"optimize": true,
Expand Down
4 changes: 4 additions & 0 deletions brownie/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ def __init__(self, cmd, proc, uri):
super().__init__("Able to launch RPC client, but unable to connect.", cmd, proc, uri)


class RPCRequestError(Exception):
pass


class VirtualMachineError(Exception):

'''Raised when a call to a contract causes an EVM exception.
Expand Down
11 changes: 7 additions & 4 deletions brownie/network/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def call(self, *args):
return format_output(result[0])
return KwargTuple(result, self.abi)

def transact(self, *args):
def transact(self, *args, _rpc_clear=True):
'''Broadcasts a transaction that calls this contract method.
Args:
Expand All @@ -299,7 +299,8 @@ def transact(self, *args):
"Contract has no owner, you must supply a tx dict"
" with a 'from' field as the last argument."
)
rpc._internal_clear()
if _rpc_clear:
rpc._internal_clear()
return tx['from'].transfer(
self._address,
tx['value'],
Expand Down Expand Up @@ -331,7 +332,7 @@ class ContractTx(_ContractMethod):

def __init__(self, fn, abi, name, owner):
if (
ARGV['cli'] != "console" and not
ARGV['cli'] == "test" and not
CONFIG['test']['default_contract_owner']
):
owner = None
Expand Down Expand Up @@ -368,7 +369,9 @@ def __call__(self, *args):
Contract method return value(s).'''
if ARGV['always_transact']:
rpc._internal_snap()
tx = self.transact(*args, {'gas_price': 0})
args, tx = _get_tx(self._owner, args)
tx['gas_price'] = 0
tx = self.transact(*args, tx, _rpc_clear=False)
if tx.modified_state:
rpc._internal_revert()
return tx.return_value
Expand Down
3 changes: 3 additions & 0 deletions brownie/network/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ def _console_repr(self):
def _add_tx(self, tx):
self._list.append(tx)

def clear(self):
self._list.clear()

def copy(self):
'''Returns a shallow copy of the object as a list'''
return self._list.copy()
Expand Down
14 changes: 11 additions & 3 deletions brownie/network/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from .web3 import Web3

from brownie.types.types import _Singleton
from brownie.exceptions import RPCProcessError, RPCConnectionError
from brownie.exceptions import RPCProcessError, RPCConnectionError, RPCRequestError


web3 = Web3()
Expand Down Expand Up @@ -131,9 +131,12 @@ def _request(self, *args):
if not self.is_active():
raise SystemError("RPC is not active.")
try:
return web3.providers[0].make_request(*args)['result']
response = web3.providers[0].make_request(*args)
if 'result' in response:
return response['result']
except IndexError:
raise RPCConnectionError("Web3 is not connected.")
raise RPCRequestError(response['error']['message'])

def _snap(self):
return self._request("evm_snapshot", [])
Expand All @@ -145,7 +148,10 @@ def _revert(self, id_):
id_ = self._snap()
self.sleep(0)
for i in self._objects:
i._revert()
if web3.eth.blockNumber == 0:
i._reset()
else:
i._revert()
return id_

def _reset(self):
Expand Down Expand Up @@ -201,12 +207,14 @@ def revert(self):
'''Reverts the EVM to the most recently taken snapshot.'''
if not self._snapshot_id:
raise ValueError("No snapshot set")
self._internal_id = None
self._snapshot_id = self._revert(self._snapshot_id)
return "Block height reverted to {}".format(web3.eth.blockNumber)

def reset(self):
'''Reverts the EVM to the genesis state.'''
self._snapshot_id = None
self._internal_id = None
self._reset_id = self._revert(self._reset_id)
return "Block height reset to 0"

Expand Down
9 changes: 6 additions & 3 deletions brownie/project/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ def __init__(self):
self._build = {}
self._path = None

def __getitem__(self, contract_name):
return self._build[contract_name.replace('.json', '')]

def __contains__(self, contract_name):
return contract_name.replace('.json', '') in self._build

def _load(self):
self. _path = Path(CONFIG['folders']['project']).joinpath('build/contracts')
# check build paths
Expand Down Expand Up @@ -119,9 +125,6 @@ def _check_coverage_hashes(self):
coverage_json.unlink()
break

def __getitem__(self, contract_name):
return self._build[contract_name.replace('.json', '')]

def items(self):
return self._build.items()

Expand Down
16 changes: 10 additions & 6 deletions brownie/project/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,19 +121,23 @@ def get_fn(self, name, start, stop):
return False if stop > offset[2] else offset[0]

def get_fn_offset(self, name, fn_name):
if name not in self._data:
name = next(
k for k, v in self._data.items() if v['sourcePath'] == str(name) and
fn_name in [i[0] for i in v['fn_offsets']]
)
return next(i for i in self._data[name]['fn_offsets'] if i[0] == fn_name)[1:3]
try:
if name not in self._data:
name = next(
k for k, v in self._data.items() if v['sourcePath'] == str(name) and
fn_name in [i[0] for i in v['fn_offsets']]
)
return next(i for i in self._data[name]['fn_offsets'] if i[0] == fn_name)[1:3]
except StopIteration:
raise ValueError("Unknown function '{}' in contract {}".format(fn_name, name))

def inheritance_map(self):
return dict((k, v['inherited'].copy()) for k, v in self._data.items())

def add_source(self, source):
path = "<string-{}>".format(self._string_iter)
self._source[path] = source
self._remove_comments(path)
self._get_contract_data(path)
self._string_iter += 1
return path
Loading

0 comments on commit 96fb1cb

Please sign in to comment.