Skip to content

Commit

Permalink
LoadToken (#2200)
Browse files Browse the repository at this point in the history
  • Loading branch information
erikzhang authored Jan 10, 2021
1 parent 581a681 commit 3ee76ba
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 59 deletions.
49 changes: 0 additions & 49 deletions src/neo/SmartContract/ApplicationEngine.Contract.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
using Neo.Cryptography.ECC;
using Neo.Network.P2P.Payloads;
using Neo.SmartContract.Manifest;
using Neo.SmartContract.Native;
using Neo.VM;
using Neo.VM.Types;
using System;

Expand Down Expand Up @@ -35,53 +33,6 @@ protected internal void CallContract(UInt160 contractHash, string method, CallFl
CallContractInternal(contractHash, method, callFlags, hasReturnValue, args);
}

private void CallContractInternal(UInt160 contractHash, string method, CallFlags flags, bool hasReturnValue, StackItem[] args)
{
ContractState contract = NativeContract.ContractManagement.GetContract(Snapshot, contractHash);
if (contract is null) throw new InvalidOperationException($"Called Contract Does Not Exist: {contractHash}");
ContractMethodDescriptor md = contract.Manifest.Abi.GetMethod(method);
if (md is null) throw new InvalidOperationException($"Method {method} Does Not Exist In Contract {contractHash}");

if (md.Safe)
{
flags &= ~CallFlags.WriteStates;
}
else
{
ContractState currentContract = NativeContract.ContractManagement.GetContract(Snapshot, CurrentScriptHash);
if (currentContract?.CanCall(contract, method) == false)
throw new InvalidOperationException($"Cannot Call Method {method} Of Contract {contractHash} From Contract {CurrentScriptHash}");
}

CallContractInternal(contract, md, flags, hasReturnValue, args);
}

private void CallContractInternal(ContractState contract, ContractMethodDescriptor method, CallFlags flags, bool hasReturnValue, StackItem[] args)
{
if (invocationCounter.TryGetValue(contract.Hash, out var counter))
{
invocationCounter[contract.Hash] = counter + 1;
}
else
{
invocationCounter[contract.Hash] = 1;
}

ExecutionContextState state = CurrentContext.GetState<ExecutionContextState>();
UInt160 callingScriptHash = state.ScriptHash;
CallFlags callingFlags = state.CallFlags;

if (args.Length != method.Parameters.Length) throw new InvalidOperationException($"Method {method.Name} Expects {method.Parameters.Length} Arguments But Receives {args.Length} Arguments");
ExecutionContext context_new = LoadContract(contract, method.Name, flags & callingFlags, hasReturnValue, (ushort)args.Length);
state = context_new.GetState<ExecutionContextState>();
state.CallingScriptHash = callingScriptHash;

for (int i = args.Length - 1; i >= 0; i--)
context_new.EvaluationStack.Push(args[i]);
if (NativeContract.IsNative(contract.Hash))
context_new.EvaluationStack.Push(method.Name);
}

protected internal void CallNativeContract(string name)
{
NativeContract contract = NativeContract.GetContract(name);
Expand Down
76 changes: 70 additions & 6 deletions src/neo/SmartContract/ApplicationEngine.cs
Original file line number Diff line number Diff line change
Expand Up @@ -76,20 +76,69 @@ protected override void OnFault(Exception e)
base.OnFault(e);
}

internal void CallFromNativeContract(UInt160 callingScriptHash, UInt160 hash, string method, params StackItem[] args)
private ExecutionContext CallContractInternal(UInt160 contractHash, string method, CallFlags flags, bool hasReturnValue, StackItem[] args)
{
CallContractInternal(hash, method, CallFlags.All, false, args);
ContractState contract = NativeContract.ContractManagement.GetContract(Snapshot, contractHash);
if (contract is null) throw new InvalidOperationException($"Called Contract Does Not Exist: {contractHash}");
ContractMethodDescriptor md = contract.Manifest.Abi.GetMethod(method);
if (md is null) throw new InvalidOperationException($"Method {method} Does Not Exist In Contract {contractHash}");

if (md.Safe)
{
flags &= ~CallFlags.WriteStates;
}
else
{
ContractState currentContract = NativeContract.ContractManagement.GetContract(Snapshot, CurrentScriptHash);
if (currentContract?.CanCall(contract, method) == false)
throw new InvalidOperationException($"Cannot Call Method {method} Of Contract {contractHash} From Contract {CurrentScriptHash}");
}

if (invocationCounter.TryGetValue(contract.Hash, out var counter))
{
invocationCounter[contract.Hash] = counter + 1;
}
else
{
invocationCounter[contract.Hash] = 1;
}

ExecutionContextState state = CurrentContext.GetState<ExecutionContextState>();
UInt160 callingScriptHash = state.ScriptHash;
CallFlags callingFlags = state.CallFlags;

if (args.Length != md.Parameters.Length) throw new InvalidOperationException($"Method {method} Expects {md.Parameters.Length} Arguments But Receives {args.Length} Arguments");
if (hasReturnValue ^ (md.ReturnType != ContractParameterType.Void)) throw new InvalidOperationException("The return value type does not match.");
ExecutionContext context_new = LoadContract(contract, method, flags & callingFlags, hasReturnValue, (ushort)args.Length);
state = context_new.GetState<ExecutionContextState>();
state.CallingScriptHash = callingScriptHash;
StepOut();

for (int i = args.Length - 1; i >= 0; i--)
context_new.EvaluationStack.Push(args[i]);
if (NativeContract.IsNative(contract.Hash))
context_new.EvaluationStack.Push(method);

return context_new;
}

internal void CallFromNativeContract(UInt160 callingScriptHash, UInt160 hash, string method, params StackItem[] args)
{
ExecutionContext context_current = CurrentContext;
ExecutionContext context_new = CallContractInternal(hash, method, CallFlags.All, false, args);
ExecutionContextState state = context_new.GetState<ExecutionContextState>();
state.CallingScriptHash = callingScriptHash;
while (CurrentContext != context_current)
StepOut();
}

internal T CallFromNativeContract<T>(UInt160 callingScriptHash, UInt160 hash, string method, params StackItem[] args)
{
CallContractInternal(hash, method, CallFlags.All, true, args);
ExecutionContextState state = CurrentContext.GetState<ExecutionContextState>();
ExecutionContext context_current = CurrentContext;
ExecutionContext context_new = CallContractInternal(hash, method, CallFlags.All, true, args);
ExecutionContextState state = context_new.GetState<ExecutionContextState>();
state.CallingScriptHash = callingScriptHash;
StepOut();
while (CurrentContext != context_current)
StepOut();
return (T)Convert(Pop(), new InteropParameterDescriptor(typeof(T)));
}

Expand Down Expand Up @@ -123,6 +172,7 @@ public ExecutionContext LoadContract(ContractState contract, string method, Call
{
p.CallFlags = callFlags;
p.ScriptHash = contract.Hash;
p.Contract = contract;
});

// Call initialization
Expand All @@ -145,6 +195,20 @@ public ExecutionContext LoadScript(Script script, ushort pcount = 0, int rvcount
return context;
}

protected override ExecutionContext LoadToken(ushort tokenId)
{
ContractState contract = CurrentContext.GetState<ExecutionContextState>().Contract;
if (contract is null || tokenId >= contract.Nef.Tokens.Length)
throw new InvalidOperationException();
MethodToken token = contract.Nef.Tokens[tokenId];
if (token.ParametersCount > CurrentContext.EvaluationStack.Count)
throw new InvalidOperationException();
StackItem[] args = new StackItem[token.ParametersCount];
for (int i = 0; i < token.ParametersCount; i++)
args[i] = Pop();
return CallContractInternal(token.Hash, token.Method, token.CallFlags, token.HasReturnValue, args);
}

protected internal StackItem Convert(object value)
{
if (value is IDisposable disposable) Disposables.Add(disposable);
Expand Down
5 changes: 5 additions & 0 deletions src/neo/SmartContract/ExecutionContextState.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ public class ExecutionContextState
/// </summary>
public UInt160 CallingScriptHash { get; set; }

/// <summary>
/// The ContractState of the current context.
/// </summary>
public ContractState Contract { get; set; }

/// <summary>
/// Execution context rights
/// </summary>
Expand Down
8 changes: 4 additions & 4 deletions tests/neo.UnitTests/SmartContract/UT_InteropService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -597,20 +597,20 @@ public void TestContract_Call()
engine.LoadScript(new byte[] { 0x01 });

engine.Push(args[1]); engine.Push(args[0]);
engine.CallContract(state.Hash, method, CallFlags.All, false, (ushort)args.Count);
engine.CallContract(state.Hash, method, CallFlags.All, true, (ushort)args.Count);
engine.CurrentContext.EvaluationStack.Pop().Should().Be(args[0]);
engine.CurrentContext.EvaluationStack.Pop().Should().Be(args[1]);

state.Manifest.Permissions[0].Methods = WildcardContainer<string>.Create("a");
engine.Push(args[1]); engine.Push(args[0]);
Assert.ThrowsException<InvalidOperationException>(() => engine.CallContract(state.Hash, method, CallFlags.All, false, (ushort)args.Count));
Assert.ThrowsException<InvalidOperationException>(() => engine.CallContract(state.Hash, method, CallFlags.All, true, (ushort)args.Count));

state.Manifest.Permissions[0].Methods = WildcardContainer<string>.CreateWildcard();
engine.Push(args[1]); engine.Push(args[0]);
engine.CallContract(state.Hash, method, CallFlags.All, false, (ushort)args.Count);
engine.CallContract(state.Hash, method, CallFlags.All, true, (ushort)args.Count);

engine.Push(args[1]); engine.Push(args[0]);
Assert.ThrowsException<InvalidOperationException>(() => engine.CallContract(UInt160.Zero, method, CallFlags.All, false, (ushort)args.Count));
Assert.ThrowsException<InvalidOperationException>(() => engine.CallContract(UInt160.Zero, method, CallFlags.All, true, (ushort)args.Count));
}

[TestMethod]
Expand Down

0 comments on commit 3ee76ba

Please sign in to comment.