Skip to content

Commit

Permalink
Dynamic Call (neo-project#2167)
Browse files Browse the repository at this point in the history
  • Loading branch information
erikzhang authored and cloud8little committed Jan 24, 2021
1 parent 8ca3e09 commit f69ddf2
Show file tree
Hide file tree
Showing 22 changed files with 157 additions and 263 deletions.
2 changes: 1 addition & 1 deletion src/neo/Network/P2P/Payloads/OracleResponse.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public class OracleResponse : TransactionAttribute
static OracleResponse()
{
using ScriptBuilder sb = new ScriptBuilder();
sb.EmitAppCall(NativeContract.Oracle.Hash, "finish");
sb.EmitDynamicCall(NativeContract.Oracle.Hash, "finish", false);
FixedScript = sb.ToArray();
}

Expand Down
35 changes: 16 additions & 19 deletions src/neo/SmartContract/ApplicationEngine.Contract.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
using Neo.SmartContract.Manifest;
using Neo.SmartContract.Native;
using Neo.VM;
using Neo.VM.Types;
using System;
using Array = Neo.VM.Types.Array;

namespace Neo.SmartContract
{
partial class ApplicationEngine
{
public static readonly InteropDescriptor System_Contract_CallEx = Register("System.Contract.CallEx", nameof(CallContractEx), 1 << 15, CallFlags.AllowCall);
public static readonly InteropDescriptor System_Contract_Call = Register("System.Contract.Call", nameof(CallContract), 1 << 15, CallFlags.AllowCall);
public static readonly InteropDescriptor System_Contract_CallNative = Register("System.Contract.CallNative", nameof(CallNativeContract), 0, CallFlags.None);
public static readonly InteropDescriptor System_Contract_IsStandard = Register("System.Contract.IsStandard", nameof(IsStandardContract), 1 << 10, CallFlags.ReadStates);
public static readonly InteropDescriptor System_Contract_GetCallFlags = Register("System.Contract.GetCallFlags", nameof(GetCallFlags), 1 << 10, CallFlags.None);
Expand All @@ -22,15 +22,20 @@ partial class ApplicationEngine
public static readonly InteropDescriptor System_Contract_NativeOnPersist = Register("System.Contract.NativeOnPersist", nameof(NativeOnPersist), 0, CallFlags.WriteStates);
public static readonly InteropDescriptor System_Contract_NativePostPersist = Register("System.Contract.NativePostPersist", nameof(NativePostPersist), 0, CallFlags.WriteStates);

protected internal void CallContractEx(UInt160 contractHash, string method, Array args, CallFlags callFlags)
protected internal void CallContract(UInt160 contractHash, string method, CallFlags callFlags, bool hasReturnValue, ushort pcount)
{
if (method.StartsWith('_')) throw new ArgumentException($"Invalid Method Name: {method}");
if ((callFlags & ~CallFlags.All) != 0)
throw new ArgumentOutOfRangeException(nameof(callFlags));
CallContractInternal(contractHash, method, args, callFlags, ReturnTypeConvention.EnsureNotEmpty);
if (pcount > CurrentContext.EvaluationStack.Count)
throw new InvalidOperationException();
StackItem[] args = new StackItem[pcount];
for (int i = 0; i < pcount; i++)
args[i] = Pop();
CallContractInternal(contractHash, method, callFlags, hasReturnValue, args);
}

private void CallContractInternal(UInt160 contractHash, string method, Array args, CallFlags flags, ReturnTypeConvention convention)
private void CallContractInternal(UInt160 contractHash, string method, CallFlags flags, bool hasReturnValue, StackItem[] args)
{
ContractState contract = NativeContract.ContractManagement.GetContract(Snapshot, contractHash);
if (contract is null) throw new InvalidOperationException($"Called Contract Does Not Exist: {contractHash}");
Expand All @@ -48,10 +53,10 @@ private void CallContractInternal(UInt160 contractHash, string method, Array arg
throw new InvalidOperationException($"Cannot Call Method {method} Of Contract {contractHash} From Contract {CurrentScriptHash}");
}

CallContractInternal(contract, md, args, flags, convention);
CallContractInternal(contract, md, flags, hasReturnValue, args);
}

private void CallContractInternal(ContractState contract, ContractMethodDescriptor method, Array args, CallFlags flags, ReturnTypeConvention convention)
private void CallContractInternal(ContractState contract, ContractMethodDescriptor method, CallFlags flags, bool hasReturnValue, StackItem[] args)
{
if (invocationCounter.TryGetValue(contract.Hash, out var counter))
{
Expand All @@ -62,27 +67,19 @@ private void CallContractInternal(ContractState contract, ContractMethodDescript
invocationCounter[contract.Hash] = 1;
}

GetInvocationState(CurrentContext).Convention = convention;

ExecutionContextState state = CurrentContext.GetState<ExecutionContextState>();
UInt160 callingScriptHash = state.ScriptHash;
CallFlags callingFlags = state.CallFlags;

if (args.Count != method.Parameters.Length) throw new InvalidOperationException($"Method {method.Name} Expects {method.Parameters.Length} Arguments But Receives {args.Count} Arguments");
ExecutionContext context_new = LoadContract(contract, method.Name, flags & callingFlags, false);
if (args.Length != method.Parameters.Length) throw new InvalidOperationException($"Method {method.Name} Expects {method.Parameters.Length} Arguments But Receives {args.Length} Arguments");
ExecutionContext context_new = LoadContract(contract, method.Name, flags & callingFlags, hasReturnValue, (ushort)args.Length);
state = context_new.GetState<ExecutionContextState>();
state.CallingScriptHash = callingScriptHash;

for (int i = args.Length - 1; i >= 0; i--)
context_new.EvaluationStack.Push(args[i]);
if (NativeContract.IsNative(contract.Hash))
{
context_new.EvaluationStack.Push(args);
context_new.EvaluationStack.Push(method.Name);
}
else
{
for (int i = args.Count - 1; i >= 0; i--)
context_new.EvaluationStack.Push(args[i]);
}
}

protected internal void CallNativeContract(string name)
Expand Down
1 change: 1 addition & 0 deletions src/neo/SmartContract/ApplicationEngine.OpCodePrices.cs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ partial class ApplicationEngine
[OpCode.CALL] = 1 << 9,
[OpCode.CALL_L] = 1 << 9,
[OpCode.CALLA] = 1 << 9,
[OpCode.CALLT] = 1 << 15,
[OpCode.ABORT] = 0,
[OpCode.ASSERT] = 1 << 0,
[OpCode.THROW] = 1 << 9,
Expand Down
97 changes: 21 additions & 76 deletions src/neo/SmartContract/ApplicationEngine.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,6 @@ namespace Neo.SmartContract
{
public partial class ApplicationEngine : ExecutionEngine
{
private enum ReturnTypeConvention : byte
{
None = 0,
EnsureIsEmpty = 1,
EnsureNotEmpty = 2
}

private class InvocationState
{
public ReturnTypeConvention Convention;
}

/// <summary>
/// This constant can be used for testing scripts.
/// </summary>
Expand All @@ -46,7 +34,6 @@ private class InvocationState
private List<NotifyEventArgs> notifications;
private List<IDisposable> disposables;
private readonly Dictionary<UInt160, int> invocationCounter = new Dictionary<UInt160, int>();
private readonly Dictionary<ExecutionContext, InvocationState> invocationStates = new Dictionary<ExecutionContext, InvocationState>();
private readonly uint exec_fee_factor;
internal readonly uint StoragePrice;

Expand Down Expand Up @@ -88,59 +75,27 @@ protected override void OnFault(Exception e)

internal void CallFromNativeContract(UInt160 callingScriptHash, UInt160 hash, string method, params StackItem[] args)
{
CallContractInternal(hash, method, new VMArray(ReferenceCounter, args), CallFlags.All, ReturnTypeConvention.EnsureIsEmpty);
CallContractInternal(hash, method, CallFlags.All, false, args);
ExecutionContextState state = CurrentContext.GetState<ExecutionContextState>();
state.CallingScriptHash = callingScriptHash;
StepOut();
}

internal T CallFromNativeContract<T>(UInt160 callingScriptHash, UInt160 hash, string method, params StackItem[] args)
{
CallFromNativeContract(callingScriptHash, hash, method, args);
CallContractInternal(hash, method, CallFlags.All, true, args);
ExecutionContextState state = CurrentContext.GetState<ExecutionContextState>();
state.CallingScriptHash = callingScriptHash;
StepOut();
return (T)Convert(Pop(), new InteropParameterDescriptor(typeof(T)));
}

protected override void ContextUnloaded(ExecutionContext context)
{
base.ContextUnloaded(context);
if (!(UncaughtException is null)) return;
if (invocationStates.Count == 0) return;
if (!invocationStates.Remove(CurrentContext, out InvocationState state)) return;
switch (state.Convention)
{
case ReturnTypeConvention.EnsureIsEmpty:
{
if (context.EvaluationStack.Count != 0)
throw new InvalidOperationException();
break;
}
case ReturnTypeConvention.EnsureNotEmpty:
{
if (context.EvaluationStack.Count == 0)
Push(StackItem.Null);
else if (context.EvaluationStack.Count > 1)
throw new InvalidOperationException();
break;
}
}
}

public static ApplicationEngine Create(TriggerType trigger, IVerifiable container, StoreView snapshot, long gas = TestModeGas)
{
return applicationEngineProvider?.Create(trigger, container, snapshot, gas)
?? new ApplicationEngine(trigger, container, snapshot, gas);
}

private InvocationState GetInvocationState(ExecutionContext context)
{
if (!invocationStates.TryGetValue(context, out InvocationState state))
{
state = new InvocationState();
invocationStates.Add(context, state);
}
return state;
}

protected override void LoadContext(ExecutionContext context)
{
// Set default execution context state
Expand All @@ -152,42 +107,32 @@ protected override void LoadContext(ExecutionContext context)
base.LoadContext(context);
}

public ExecutionContext LoadContract(ContractState contract, string method, CallFlags callFlags, bool packParameters = false)
public ExecutionContext LoadContract(ContractState contract, string method, CallFlags callFlags, bool hasReturnValue, ushort pcount)
{
ContractMethodDescriptor md = contract.Manifest.Abi.GetMethod(method);
if (md is null) return null;

ExecutionContext context = LoadScript(contract.Script, callFlags, contract.Hash, md.Offset);
ExecutionContext context = LoadScript(contract.Script,
pcount: pcount,
rvcount: hasReturnValue ? 1 : 0,
initialPosition: md.Offset,
callFlags: callFlags,
scriptHash: contract.Hash);

if (NativeContract.IsNative(contract.Hash))
{
if (packParameters)
{
using ScriptBuilder sb = new ScriptBuilder();
sb.Emit(OpCode.DEPTH, OpCode.PACK);
sb.EmitPush(md.Name);
LoadScript(sb.ToArray(), CallFlags.None);
}
}
else
// Call initialization
var init = contract.Manifest.Abi.GetMethod("_initialize");
if (init != null)
{
// Call initialization

var init = contract.Manifest.Abi.GetMethod("_initialize");

if (init != null)
{
LoadContext(context.Clone(init.Offset));
}
LoadContext(context.Clone(init.Offset));
}

return context;
}

public ExecutionContext LoadScript(Script script, CallFlags callFlags, UInt160 scriptHash = null, int initialPosition = 0)
public ExecutionContext LoadScript(Script script, ushort pcount = 0, int rvcount = -1, int initialPosition = 0, CallFlags callFlags = CallFlags.All, UInt160 scriptHash = null)
{
// Create and configure context
ExecutionContext context = CreateContext(script, initialPosition);
ExecutionContext context = CreateContext(script, pcount, rvcount, initialPosition);
var state = context.GetState<ExecutionContextState>();
state.CallFlags = callFlags;
state.ScriptHash = scriptHash ?? ((byte[])script).ToScriptHash();
Expand Down Expand Up @@ -294,13 +239,13 @@ protected override void PreExecuteInstruction()
AddGas(exec_fee_factor * OpCodePrices[CurrentContext.CurrentInstruction.OpCode]);
}

private void StepOut()
internal void StepOut()
{
int c = InvocationStack.Count;
while (State != VMState.HALT && State != VMState.FAULT && InvocationStack.Count >= c)
ExecuteNext();
if (State == VMState.FAULT)
throw new InvalidOperationException("Call from native contract failed.", FaultException);
throw new InvalidOperationException("StepOut failed.", FaultException);
}

private static Block CreateDummyBlock(StoreView snapshot)
Expand Down Expand Up @@ -350,7 +295,7 @@ public static ApplicationEngine Run(byte[] script, StoreView snapshot = null, IV
snapshot.PersistingBlock = persistingBlock ?? snapshot.PersistingBlock ?? CreateDummyBlock(snapshot);
ApplicationEngine engine = Create(TriggerType.Application, container, snapshot, gas);
if (disposable != null) engine.Disposables.Add(disposable);
engine.LoadScript(script, offset);
engine.LoadScript(script, initialPosition: offset);
engine.Execute();
return engine;
}
Expand Down
19 changes: 15 additions & 4 deletions src/neo/SmartContract/Helper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -187,19 +187,30 @@ internal static bool VerifyWitness(this IVerifiable verifiable, StoreView snapsh
{
ContractState cs = NativeContract.ContractManagement.GetContract(snapshot, hash);
if (cs is null) return false;
if (engine.LoadContract(cs, "verify", callFlags, true) is null)
if (engine.LoadContract(cs, "verify", callFlags, true, 0) is null)
return false;
}
else
{
if (NativeContract.IsNative(hash)) return false;
if (hash != witness.ScriptHash) return false;
engine.LoadScript(verification, callFlags, hash, 0);
engine.LoadScript(verification, callFlags: callFlags, scriptHash: hash, initialPosition: 0);
}

engine.LoadScript(witness.InvocationScript, callFlags: CallFlags.None);

if (NativeContract.IsNative(hash))
{
try
{
engine.StepOut();
engine.Push("verify");
}
catch { }
}

engine.LoadScript(witness.InvocationScript, CallFlags.None);
if (engine.Execute() == VMState.FAULT) return false;
if (engine.ResultStack.Count != 1 || !engine.ResultStack.Peek().GetBoolean()) return false;
if (!engine.ResultStack.Peek().GetBoolean()) return false;
fee = engine.GasConsumed;
}
return true;
Expand Down
10 changes: 3 additions & 7 deletions src/neo/SmartContract/Native/NativeContract.cs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ internal void Invoke(ApplicationEngine engine)
throw new InvalidOperationException("It is not allowed to use Neo.Native.Call directly to call native contracts. System.Contract.Call should be used.");
ExecutionContext context = engine.CurrentContext;
string operation = context.EvaluationStack.Pop().GetString();
Array args = context.EvaluationStack.Pop<Array>();
ContractMethodMetadata method = methods[operation];
ExecutionContextState state = context.GetState<ExecutionContextState>();
if (!state.CallFlags.HasFlag(method.RequiredCallFlags))
Expand All @@ -117,10 +116,7 @@ internal void Invoke(ApplicationEngine engine)
if (method.NeedApplicationEngine) parameters.Add(engine);
if (method.NeedSnapshot) parameters.Add(engine.Snapshot);
for (int i = 0; i < method.Parameters.Length; i++)
{
StackItem item = i < args.Count ? args[i] : StackItem.Null;
parameters.Add(engine.Convert(item, method.Parameters[i]));
}
parameters.Add(engine.Convert(context.EvaluationStack.Pop(), method.Parameters[i]));
object returnValue = method.Handler.Invoke(this, parameters.ToArray());
if (method.Handler.ReturnType != typeof(void))
context.EvaluationStack.Push(engine.Convert(returnValue));
Expand All @@ -143,11 +139,11 @@ internal virtual void PostPersist(ApplicationEngine engine)
{
}

public ApplicationEngine TestCall(string operation, params object[] args)
public ApplicationEngine TestCall(string operation, bool hasReturnValue, params object[] args)
{
using (ScriptBuilder sb = new ScriptBuilder())
{
sb.EmitAppCall(Hash, operation, args);
sb.EmitDynamicCall(Hash, operation, hasReturnValue, args);
return ApplicationEngine.Run(sb.ToArray());
}
}
Expand Down
Loading

0 comments on commit f69ddf2

Please sign in to comment.