From 4225331fdf0858ed476959dc2c07f1d666994d2e Mon Sep 17 00:00:00 2001 From: Erik Zhang Date: Sun, 24 Jul 2022 10:18:37 +0800 Subject: [PATCH 1/2] Improve ApplicationEngine --- .../Conditions/CalledByEntryCondition.cs | 3 +- .../ApplicationEngine.Contract.cs | 2 +- .../ApplicationEngine.Runtime.cs | 2 ++ src/Neo/SmartContract/ApplicationEngine.cs | 28 ++++++++++++++----- .../SmartContract/ExecutionContextState.cs | 11 ++++++-- .../Extensions/NativeContractExtensions.cs | 4 +-- .../UT_ApplicationEngine.Runtime.cs | 1 + .../SmartContract/UT_InteropService.cs | 2 ++ 8 files changed, 39 insertions(+), 14 deletions(-) diff --git a/src/Neo/Network/P2P/Payloads/Conditions/CalledByEntryCondition.cs b/src/Neo/Network/P2P/Payloads/Conditions/CalledByEntryCondition.cs index a4dd2277f6..6fa15a2965 100644 --- a/src/Neo/Network/P2P/Payloads/Conditions/CalledByEntryCondition.cs +++ b/src/Neo/Network/P2P/Payloads/Conditions/CalledByEntryCondition.cs @@ -24,7 +24,8 @@ protected override void DeserializeWithoutType(ref MemoryReader reader, int maxN public override bool Match(ApplicationEngine engine) { - return engine.CallingScriptHash is null || engine.CallingScriptHash == engine.EntryScriptHash; + var state = engine.CurrentContext.GetState(); + return state.CallingContext is null || state.CallingContext == engine.EntryContext; } protected override void SerializeWithoutType(BinaryWriter writer) diff --git a/src/Neo/SmartContract/ApplicationEngine.Contract.cs b/src/Neo/SmartContract/ApplicationEngine.Contract.cs index b4ad8719b1..f3d33e7d77 100644 --- a/src/Neo/SmartContract/ApplicationEngine.Contract.cs +++ b/src/Neo/SmartContract/ApplicationEngine.Contract.cs @@ -82,7 +82,7 @@ protected internal void CallContract(UInt160 contractHash, string method, CallFl bool hasReturnValue = md.ReturnType != ContractParameterType.Void; ExecutionContext context = CallContractInternal(contract, md, callFlags, hasReturnValue, args); - if (!hasReturnValue) context.GetState().PushNullWhenReturn = true; + context.GetState().IsDynamicCall = true; } /// diff --git a/src/Neo/SmartContract/ApplicationEngine.Runtime.cs b/src/Neo/SmartContract/ApplicationEngine.Runtime.cs index 6b9a20a804..f4f4140d42 100644 --- a/src/Neo/SmartContract/ApplicationEngine.Runtime.cs +++ b/src/Neo/SmartContract/ApplicationEngine.Runtime.cs @@ -306,6 +306,8 @@ protected internal void RuntimeLog(byte[] state) protected internal void RuntimeNotify(byte[] eventName, Array state) { if (eventName.Length > MaxEventName) throw new ArgumentException(null, nameof(eventName)); + if (CurrentContext.GetState().Contract is null) + throw new InvalidOperationException("Notifications are not allowed in dynamic scripts."); using MemoryStream ms = new(MaxNotificationSize); using BinaryWriter writer = new(ms, Utility.StrictUTF8, true); BinarySerializer.Serialize(writer, state, MaxNotificationSize); diff --git a/src/Neo/SmartContract/ApplicationEngine.cs b/src/Neo/SmartContract/ApplicationEngine.cs index 66bc453eed..5889c3c11b 100644 --- a/src/Neo/SmartContract/ApplicationEngine.cs +++ b/src/Neo/SmartContract/ApplicationEngine.cs @@ -123,7 +123,15 @@ public partial class ApplicationEngine : ExecutionEngine /// /// The script hash of the calling contract. This field could be if the current context is the entry context. /// - public UInt160 CallingScriptHash => CurrentContext?.GetState().CallingScriptHash; + public UInt160 CallingScriptHash + { + get + { + if (CurrentContext is null) return null; + var state = CurrentContext.GetState(); + return state.NativeCallingScriptHash ?? state.CallingContext?.GetState().ScriptHash; + } + } /// /// The script hash of the entry context. This field could be if no context is loaded to the engine. @@ -224,15 +232,15 @@ private ExecutionContext CallContractInternal(ContractState contract, ContractMe invocationCounter[contract.Hash] = 1; } - ExecutionContextState state = CurrentContext.GetState(); - UInt160 callingScriptHash = state.ScriptHash; + ExecutionContext currentContext = CurrentContext; + ExecutionContextState state = currentContext.GetState(); CallFlags callingFlags = state.CallFlags; if (args.Count != method.Parameters.Length) throw new InvalidOperationException($"Method {method} Expects {method.Parameters.Length} Arguments But Receives {args.Count} Arguments"); if (hasReturnValue ^ (method.ReturnType != ContractParameterType.Void)) throw new InvalidOperationException("The return value type does not match."); ExecutionContext context_new = LoadContract(contract, method, flags & callingFlags); state = context_new.GetState(); - state.CallingScriptHash = callingScriptHash; + state.CallingContext = currentContext; for (int i = args.Count - 1; i >= 0; i--) context_new.EvaluationStack.Push(args[i]); @@ -244,7 +252,7 @@ internal ContractTask CallFromNativeContract(UInt160 callingScriptHash, UInt160 { ExecutionContext context_new = CallContractInternal(hash, method, CallFlags.All, false, args); ExecutionContextState state = context_new.GetState(); - state.CallingScriptHash = callingScriptHash; + state.NativeCallingScriptHash = callingScriptHash; ContractTask task = new(); contractTasks.Add(context_new, task.GetAwaiter()); return task; @@ -254,7 +262,7 @@ internal ContractTask CallFromNativeContract(UInt160 callingScriptHash, UI { ExecutionContext context_new = CallContractInternal(hash, method, CallFlags.All, true, args); ExecutionContextState state = context_new.GetState(); - state.CallingScriptHash = callingScriptHash; + state.NativeCallingScriptHash = callingScriptHash; ContractTask task = new(); contractTasks.Add(context_new, task.GetAwaiter()); return task; @@ -273,7 +281,13 @@ protected override void ContextUnloaded(ExecutionContext context) { ExecutionContextState contextState = CurrentContext.GetState(); contextState.NotificationCount += state.NotificationCount; - if (state.PushNullWhenReturn) Push(StackItem.Null); + if (state.IsDynamicCall) + { + if (context.EvaluationStack.Count == 0) + Push(StackItem.Null); + else if (context.EvaluationStack.Count > 1) + throw new NotSupportedException("Multiple return values are not allowed in cross-contract calls."); + } } } else diff --git a/src/Neo/SmartContract/ExecutionContextState.cs b/src/Neo/SmartContract/ExecutionContextState.cs index e912fba9e8..5e3aa3ef42 100644 --- a/src/Neo/SmartContract/ExecutionContextState.cs +++ b/src/Neo/SmartContract/ExecutionContextState.cs @@ -24,9 +24,14 @@ public class ExecutionContextState public UInt160 ScriptHash { get; set; } /// - /// The script hash of the calling contract. + /// The calling context. /// - public UInt160 CallingScriptHash { get; set; } + public ExecutionContext CallingContext { get; set; } + + /// + /// The script hash of the calling native contract. Used in native contracts only. + /// + internal UInt160 NativeCallingScriptHash { get; set; } /// /// The of the current context. @@ -42,6 +47,6 @@ public class ExecutionContextState public int NotificationCount { get; set; } - public bool PushNullWhenReturn { get; set; } + public bool IsDynamicCall { get; set; } } } diff --git a/tests/Neo.UnitTests/Extensions/NativeContractExtensions.cs b/tests/Neo.UnitTests/Extensions/NativeContractExtensions.cs index 58ab1bf5b1..437f98bbe4 100644 --- a/tests/Neo.UnitTests/Extensions/NativeContractExtensions.cs +++ b/tests/Neo.UnitTests/Extensions/NativeContractExtensions.cs @@ -42,7 +42,7 @@ public static void UpdateContract(this DataCache snapshot, UInt160 callingScript // Fake calling script hash if (callingScriptHash != null) { - engine.CurrentContext.GetState().CallingScriptHash = callingScriptHash; + engine.CurrentContext.GetState().NativeCallingScriptHash = callingScriptHash; engine.CurrentContext.GetState().ScriptHash = callingScriptHash; } @@ -65,7 +65,7 @@ public static void DestroyContract(this DataCache snapshot, UInt160 callingScrip // Fake calling script hash if (callingScriptHash != null) { - engine.CurrentContext.GetState().CallingScriptHash = callingScriptHash; + engine.CurrentContext.GetState().NativeCallingScriptHash = callingScriptHash; engine.CurrentContext.GetState().ScriptHash = callingScriptHash; } diff --git a/tests/Neo.UnitTests/SmartContract/UT_ApplicationEngine.Runtime.cs b/tests/Neo.UnitTests/SmartContract/UT_ApplicationEngine.Runtime.cs index 20e2508e83..61bef60e2c 100644 --- a/tests/Neo.UnitTests/SmartContract/UT_ApplicationEngine.Runtime.cs +++ b/tests/Neo.UnitTests/SmartContract/UT_ApplicationEngine.Runtime.cs @@ -24,6 +24,7 @@ public void TestNotSupportedNotification() { using var engine = ApplicationEngine.Create(TriggerType.Application, null, null, TestBlockchain.TheNeoSystem.GenesisBlock, settings: TestBlockchain.TheNeoSystem.Settings, gas: 1100_00000000); engine.LoadScript(Array.Empty()); + engine.CurrentContext.GetState().Contract = new(); // circular diff --git a/tests/Neo.UnitTests/SmartContract/UT_InteropService.cs b/tests/Neo.UnitTests/SmartContract/UT_InteropService.cs index b56198f5a0..0494ab22e0 100644 --- a/tests/Neo.UnitTests/SmartContract/UT_InteropService.cs +++ b/tests/Neo.UnitTests/SmartContract/UT_InteropService.cs @@ -93,6 +93,7 @@ public void Runtime_GetNotifications_Test() // Execute engine.LoadScript(script.ToArray()); + engine.CurrentContext.GetState().Contract = new(); var currentScriptHash = engine.EntryScriptHash; Assert.AreEqual(VMState.HALT, engine.Execute()); @@ -145,6 +146,7 @@ public void Runtime_GetNotifications_Test() // Execute engine.LoadScript(script.ToArray()); + engine.CurrentContext.GetState().Contract = new(); var currentScriptHash = engine.EntryScriptHash; Assert.AreEqual(VMState.HALT, engine.Execute()); From 6321f860b6a10bb0df1903ca9605e4e2bb3fd9ec Mon Sep 17 00:00:00 2001 From: Erik Zhang Date: Mon, 1 Aug 2022 08:09:09 +0800 Subject: [PATCH 2/2] Improve CalledByEntryCondition.Match() --- .../Network/P2P/Payloads/Conditions/CalledByEntryCondition.cs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/Neo/Network/P2P/Payloads/Conditions/CalledByEntryCondition.cs b/src/Neo/Network/P2P/Payloads/Conditions/CalledByEntryCondition.cs index 6fa15a2965..4e5455af57 100644 --- a/src/Neo/Network/P2P/Payloads/Conditions/CalledByEntryCondition.cs +++ b/src/Neo/Network/P2P/Payloads/Conditions/CalledByEntryCondition.cs @@ -25,7 +25,9 @@ protected override void DeserializeWithoutType(ref MemoryReader reader, int maxN public override bool Match(ApplicationEngine engine) { var state = engine.CurrentContext.GetState(); - return state.CallingContext is null || state.CallingContext == engine.EntryContext; + if (state.CallingContext is null) return true; + state = state.CallingContext.GetState(); + return state.CallingContext is null; } protected override void SerializeWithoutType(BinaryWriter writer)