Skip to content

Commit

Permalink
https://github.com/neo-project/neo/pull/2200
Browse files Browse the repository at this point in the history
  • Loading branch information
ixje committed Feb 8, 2021
1 parent 6cb6980 commit f530446
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 79 deletions.
94 changes: 84 additions & 10 deletions neo3/contracts/applicationengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class ApplicationEngine(vm.ApplicationEngineCpp):
MAX_NOTIFICATION_SIZE = 1024
#: Maximum size of the smart contract script.
MAX_CONTRACT_LENGTH = 1024 * 1024

#: Multiplier for determining the costs of storing the contract including its manifest.

def __init__(self,
Expand Down Expand Up @@ -47,6 +48,8 @@ def __init__(self,
self.exec_fee_factor = contracts.PolicyContract().get_exec_fee_factor(snapshot)
self.STORAGE_PRICE = contracts.PolicyContract().get_storage_price(snapshot)

self._context_state: Dict[vm.ExecutionContext, storage.ContractState] = {}

def checkwitness(self, hash_: types.UInt160) -> bool:
"""
Check if the hash is a valid witness for the engines script_container
Expand Down Expand Up @@ -286,30 +289,29 @@ def load_script_with_callflags(self,
call_flags: contracts.CallFlags,
initial_position: int = 0,
pcount: int = 0,
rvcount: int = -1):
rvcount: int = -1,
contract_state: Optional[storage.ContractState] = None):
context = super(ApplicationEngine, self).load_script(script, initial_position, pcount, rvcount)
context.call_flags = int(call_flags)
if contract_state is not None:
self._context_state.update({context: contract_state})
return context

def call_from_native(self,
calling_scripthash: types.UInt160,
hash_: types.UInt160,
method: str,
args: List[vm.StackItem]) -> None:
ctx = self.current_context
contract_call_descriptor = interop.InteropService.get_descriptor(
contracts.syscall_name_to_int("contract_call_internal")
)
if contract_call_descriptor is None:
raise ValueError
contract_call_descriptor.handler(self,
hash_,
method,
contracts.CallFlags.ALL,
False,
args)

self._contract_call_internal(hash_, method, contracts.CallFlags.ALL, False, args)
self.current_context.calling_scripthash_bytes = calling_scripthash.to_array()
self.step_out()
while self.current_context != ctx:
self.step_out()

def step_out(self) -> None:
c = len(self.invocation_stack)
Expand All @@ -332,15 +334,87 @@ def load_contract(self,
flags,
method_descriptor.offset,
pcount,
int(has_return_value))
int(has_return_value),
contract)

init = contract.manifest.abi.get_method("_initialize")
if init is not None:
self.load_context(context.clone(init.offset))
return context

def load_token(self, token_id: int) -> vm.ExecutionContext:
contract = self._context_state.get(self.current_context, None)
if contract is None:
raise ValueError("Current context has no contract state")
if token_id >= len(contract.nef.tokens):
raise ValueError("token_id exceeds available tokens")

token = contract.nef.tokens[token_id]
if token.parameters_count > len(self.current_context.evaluation_stack):
raise ValueError("Token count exceeds available paremeters on evaluation stack")
args: List[vm.StackItem] = []
for _ in range(token.parameters_count):
args.append(self.pop())
return self._contract_call_internal(token.hash, token.method, token.call_flags, token.has_return_value, args)

def call_native(self, name: str) -> None:
contract = contracts.ManagementContract().get_contract_by_name(name)
if contract is None or contract.active_block_index > self.snapshot.persisting_block.index:
raise ValueError
contract.invoke(self)

def context_unloaded(self, context: vm.ExecutionContext) -> None:
self._context_state.pop(context, None)

def _contract_call_internal(engine: contracts.ApplicationEngine,
contract_hash: types.UInt160,
method: str,
flags: contracts.CallFlags,
has_return_value: bool,
args: List[vm.StackItem]) -> vm.ExecutionContext:
if method.startswith('_'):
raise ValueError("[System.Contract.Call] Method not allowed to start with _")

target_contract = contracts.ManagementContract().get_contract(engine.snapshot, contract_hash)
if target_contract is None:
raise ValueError("[System.Contract.Call] Can't find target contract")

method_descriptor = target_contract.manifest.abi.get_method(method)
if method_descriptor is None:
raise ValueError(f"[System.Contract.Call] Method '{method}' does not exist on target contract")

if method_descriptor.safe:
flags &= ~contracts.CallFlags.WRITE_STATES
else:
current_contract = contracts.ManagementContract().get_contract(engine.snapshot, engine.current_scripthash)
if current_contract and not current_contract.can_call(target_contract, method):
raise ValueError(
f"[System.Contract.Call] Not allowed to call target method '{method}' according to manifest")

counter = engine._invocation_counter.get(target_contract.hash, 0)
engine._invocation_counter.update({target_contract.hash: counter + 1})

state = engine.current_context
calling_flags = state.call_flags

arg_len = len(args)
expected_len = len(method_descriptor.parameters)
if arg_len != expected_len:
raise ValueError(
f"[System.Contract.Call] Invalid number of contract arguments. Expected {expected_len} actual {arg_len}") # noqa

context_new = engine.load_contract(target_contract,
method_descriptor.name,
flags & calling_flags,
has_return_value,
len(args))
if context_new is None:
raise ValueError
context_new.calling_scripthash_bytes = state.calling_scripthash_bytes

for item in reversed(args):
context_new.evaluation_stack.push(item)

if contracts.NativeContract.is_native(target_contract.hash):
context_new.evaluation_stack.push(vm.ByteStringStackItem(method_descriptor.name.encode('utf-8')))
return context_new
5 changes: 1 addition & 4 deletions neo3/contracts/interop/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,7 @@ def contract_call(engine: contracts.ApplicationEngine,
for _ in range(pcount):
args.append(engine.pop())

# TODO: fix
# contract_call_internal(engine,
# contract_hash, method, args, flags, contracts.ReturnTypeConvention.ENSURE_NOT_EMPTY)
pass
engine._contract_call_internal(contract_hash, method, call_flags, has_return_value, args)


@register("System.Contract.IsStandard", 1 << 10, contracts.CallFlags.READ_STATES, [types.UInt160])
Expand Down
64 changes: 0 additions & 64 deletions neo3/contracts/native/management.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,70 +8,6 @@
from neo3.contracts.interop import register


@register("contract_call_internal", 0, contracts.CallFlags.ALL, [])
def contract_call_internal(engine: contracts.ApplicationEngine,
contract_hash: types.UInt160,
method: str,
flags: contracts.CallFlags,
has_return_value: bool,
args: List[vm.StackItem]) -> None:
if method.startswith('_'):
raise ValueError("[System.Contract.Call] Method not allowed to start with _")

target_contract = ManagementContract().get_contract(engine.snapshot, contract_hash)
if target_contract is None:
raise ValueError("[System.Contract.Call] Can't find target contract")

method_descriptor = target_contract.manifest.abi.get_method(method)
if method_descriptor is None:
raise ValueError(f"[System.Contract.Call] Method '{method}' does not exist on target contract")

if method_descriptor.safe:
flags &= ~contracts.CallFlags.WRITE_STATES
else:
current_contract = ManagementContract().get_contract(engine.snapshot, engine.current_scripthash)
if current_contract and not current_contract.can_call(target_contract, method):
raise ValueError(
f"[System.Contract.Call] Not allowed to call target method '{method}' according to manifest")

contract_call_internal_ex(engine, target_contract, method_descriptor, flags, has_return_value, args)


def contract_call_internal_ex(engine: contracts.ApplicationEngine,
contract: storage.ContractState,
contract_method_descriptor: contracts.ContractMethodDescriptor,
flags: contracts.CallFlags,
has_return_value: bool,
args: List[vm.StackItem],
) -> None:
counter = engine._invocation_counter.get(contract.hash, 0)
engine._invocation_counter.update({contract.hash: counter + 1})

state = engine.current_context
calling_flags = state.call_flags

arg_len = len(args)
expected_len = len(contract_method_descriptor.parameters)
if arg_len != expected_len:
raise ValueError(
f"[System.Contract.Call] Invalid number of contract arguments. Expected {expected_len} actual {arg_len}") # noqa

context_new = engine.load_contract(contract,
contract_method_descriptor.name,
flags & calling_flags,
has_return_value,
len(args))
if context_new is None:
raise ValueError
context_new.calling_scripthash_bytes = state.calling_scripthash_bytes

for item in reversed(args):
context_new.evaluation_stack.push(item)

if contracts.NativeContract.is_native(contract.hash):
context_new.evaluation_stack.push(vm.ByteStringStackItem(contract_method_descriptor.name.encode('utf-8')))


class ManagementContract(NativeContract):
_id = 0
_service_name = "ContractManagement"
Expand Down
1 change: 0 additions & 1 deletion tests/contracts/interop/test_contract_interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from neo3 import vm, contracts, storage
from neo3.network import payloads
from neo3.contracts import syscall_name_to_int
from neo3.contracts.interop.contract import contract_call_internal
from neo3.core import to_script_hash, types, cryptography
from .utils import test_engine, test_block, test_tx
from copy import deepcopy
Expand Down

0 comments on commit f530446

Please sign in to comment.