diff --git a/src/System.Threading.Tasks.Extensions/src/System.Threading.Tasks.Extensions.csproj b/src/System.Threading.Tasks.Extensions/src/System.Threading.Tasks.Extensions.csproj index 82fb040b46eb..2c4ae67226a5 100644 --- a/src/System.Threading.Tasks.Extensions/src/System.Threading.Tasks.Extensions.csproj +++ b/src/System.Threading.Tasks.Extensions/src/System.Threading.Tasks.Extensions.csproj @@ -4,7 +4,7 @@ {F24D3391-2928-4E83-AADE-B34423498750} System.Threading.Tasks.Extensions - 4.0.1.0 + 4.1.0.0 $(OutputPath)$(AssemblyName).xml true @@ -22,6 +22,7 @@ + diff --git a/src/System.Threading.Tasks.Extensions/src/System/Runtime/CompilerServices/AsyncValueTaskMethodBuilder.cs b/src/System.Threading.Tasks.Extensions/src/System/Runtime/CompilerServices/AsyncValueTaskMethodBuilder.cs new file mode 100644 index 000000000000..e26a790e3af8 --- /dev/null +++ b/src/System.Threading.Tasks.Extensions/src/System/Runtime/CompilerServices/AsyncValueTaskMethodBuilder.cs @@ -0,0 +1,103 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Runtime.InteropServices; +using System.Security; +using System.Threading.Tasks; + +namespace System.Runtime.CompilerServices +{ + /// Represents a builder for asynchronous methods that returns a . + /// The type of the result. + [StructLayout(LayoutKind.Auto)] + public struct AsyncValueTaskMethodBuilder + { + /// The to which most operations are delegated. + private AsyncTaskMethodBuilder _methodBuilder; + /// The result for this builder, if it's completed before any awaits occur. + private TResult _result; + /// true if contains the synchronous result for the async method; otherwise, false. + private bool _haveResult; + /// true if the builder should be used for setting/getting the result; otherwise, false. + private bool _useBuilder; + + /// Creates an instance of the struct. + /// The initialized instance. + public static AsyncValueTaskMethodBuilder Create() => + new AsyncValueTaskMethodBuilder() { _methodBuilder = AsyncTaskMethodBuilder.Create() }; + + /// Begins running the builder with the associated state machine. + /// The type of the state machine. + /// The state machine instance, passed by reference. + public void Start(ref TStateMachine stateMachine) where TStateMachine : IAsyncStateMachine + { + _methodBuilder.Start(ref stateMachine); // will provide the right ExecutionContext semantics + } + + /// Associates the builder with the specified state machine. + /// The state machine instance to associate with the builder. + public void SetStateMachine(IAsyncStateMachine stateMachine) => _methodBuilder.SetStateMachine(stateMachine); + + /// Marks the task as successfully completed. + /// The result to use to complete the task. + public void SetResult(TResult result) + { + if (_useBuilder) + { + _methodBuilder.SetResult(result); + } + else + { + _result = result; + _haveResult = true; + } + } + + /// Marks the task as failed and binds the specified exception to the task. + /// The exception to bind to the task. + public void SetException(Exception exception) => _methodBuilder.SetException(exception); + + /// Gets the task for this builder. + public ValueTask Task + { + get + { + if (_haveResult) + { + return new ValueTask(_result); + } + else + { + _useBuilder = true; + return new ValueTask(_methodBuilder.Task); + } + } + } + + /// Schedules the state machine to proceed to the next action when the specified awaiter completes. + /// The type of the awaiter. + /// The type of the state machine. + /// the awaiter + /// The state machine. + public void AwaitOnCompleted(ref TAwaiter awaiter, ref TStateMachine stateMachine) + where TAwaiter : INotifyCompletion + where TStateMachine : IAsyncStateMachine + { + _methodBuilder.AwaitOnCompleted(ref awaiter, ref stateMachine); + } + + /// Schedules the state machine to proceed to the next action when the specified awaiter completes. + /// The type of the awaiter. + /// The type of the state machine. + /// the awaiter + /// The state machine. + [SecuritySafeCritical] + public void AwaitUnsafeOnCompleted(ref TAwaiter awaiter, ref TStateMachine stateMachine) + where TAwaiter : ICriticalNotifyCompletion + where TStateMachine : IAsyncStateMachine + { + _methodBuilder.AwaitUnsafeOnCompleted(ref awaiter, ref stateMachine); + } + } +} diff --git a/src/System.Threading.Tasks.Extensions/src/System/Threading/Tasks/ValueTask.cs b/src/System.Threading.Tasks.Extensions/src/System/Threading/Tasks/ValueTask.cs index 9a7102f7a42e..7109ec5c6f85 100644 --- a/src/System.Threading.Tasks.Extensions/src/System/Threading/Tasks/ValueTask.cs +++ b/src/System.Threading.Tasks.Extensions/src/System/Threading/Tasks/ValueTask.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System.Collections.Generic; +using System.ComponentModel; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; @@ -174,5 +175,10 @@ public override string ToString() string.Empty; } } + + /// Creates a method builder for use with an async method. + /// The created builder. + [EditorBrowsable(EditorBrowsableState.Never)] // intended only for compiler consumption + public static AsyncValueTaskMethodBuilder CreateAsyncMethodBuilder() => AsyncValueTaskMethodBuilder.Create(); } } diff --git a/src/System.Threading.Tasks.Extensions/tests/AsyncValueTaskMethodBuilderTests.cs b/src/System.Threading.Tasks.Extensions/tests/AsyncValueTaskMethodBuilderTests.cs new file mode 100644 index 000000000000..8d11cbac674f --- /dev/null +++ b/src/System.Threading.Tasks.Extensions/tests/AsyncValueTaskMethodBuilderTests.cs @@ -0,0 +1,165 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Runtime.CompilerServices; +using Xunit; + +namespace System.Threading.Tasks.Tests +{ + public class AsyncValueTaskMethodBuilderTests + { + [Fact] + public void Create_ReturnsDefaultInstance() + { + AsyncValueTaskMethodBuilder b = ValueTask.CreateAsyncMethodBuilder(); + Assert.Equal(default(AsyncValueTaskMethodBuilder), b); // implementation detail being verified + } + + [Fact] + public void SetResult_BeforeAccessTask_ValueTaskContainsValue() + { + AsyncValueTaskMethodBuilder b = ValueTask.CreateAsyncMethodBuilder(); + b.SetResult(42); + ValueTask vt = b.Task; + Assert.True(vt.IsCompletedSuccessfully); + Assert.NotSame(vt.AsTask(), vt.AsTask()); // will be different if completed synchronously + Assert.Equal(42, vt.Result); + } + + [Fact] + public void SetResult_AfterAccessTask_ValueTaskContainsValue() + { + AsyncValueTaskMethodBuilder b = ValueTask.CreateAsyncMethodBuilder(); + ValueTask vt = b.Task; + b.SetResult(42); + Assert.True(vt.IsCompletedSuccessfully); + Assert.Same(vt.AsTask(), vt.AsTask()); // will be safe if completed asynchronously + Assert.Equal(42, vt.Result); + } + + [Fact] + public void SetException_BeforeAccessTask_FaultsTask() + { + AsyncValueTaskMethodBuilder b = ValueTask.CreateAsyncMethodBuilder(); + var e = new FormatException(); + b.SetException(e); + ValueTask vt = b.Task; + Assert.True(vt.IsFaulted); + Assert.Same(e, Assert.Throws(() => vt.GetAwaiter().GetResult())); + } + + [Fact] + public void SetException_AfterAccessTask_FaultsTask() + { + AsyncValueTaskMethodBuilder b = ValueTask.CreateAsyncMethodBuilder(); + var e = new FormatException(); + ValueTask vt = b.Task; + b.SetException(e); + Assert.True(vt.IsFaulted); + Assert.Same(e, Assert.Throws(() => vt.GetAwaiter().GetResult())); + } + + [Fact] + public void SetException_OperationCanceledException_CancelsTask() + { + AsyncValueTaskMethodBuilder b = ValueTask.CreateAsyncMethodBuilder(); + var e = new OperationCanceledException(); + ValueTask vt = b.Task; + b.SetException(e); + Assert.True(vt.IsCanceled); + Assert.Same(e, Assert.Throws(() => vt.GetAwaiter().GetResult())); + } + + [Fact] + public void Start_InvokesMoveNext() + { + AsyncValueTaskMethodBuilder b = ValueTask.CreateAsyncMethodBuilder(); + int invokes = 0; + var dsm = new DelegateStateMachine { MoveNextDelegate = () => invokes++ }; + b.Start(ref dsm); + Assert.Equal(1, invokes); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task AwaitOnCompleted_InvokesStateMachineMethods(bool awaitUnsafe) + { + AsyncValueTaskMethodBuilder b = ValueTask.CreateAsyncMethodBuilder(); + var ignored = b.Task; + + var callbackCompleted = new TaskCompletionSource(); + IAsyncStateMachine foundSm = null; + var dsm = new DelegateStateMachine + { + MoveNextDelegate = () => callbackCompleted.SetResult(true), + SetStateMachineDelegate = sm => foundSm = sm + }; + + TaskAwaiter t = Task.CompletedTask.GetAwaiter(); + if (awaitUnsafe) + { + b.AwaitUnsafeOnCompleted(ref t, ref dsm); + } + else + { + b.AwaitOnCompleted(ref t, ref dsm); + } + + await callbackCompleted.Task; + Assert.Equal(dsm, foundSm); + } + + [Fact] + public void SetStateMachine_InvalidArgument_ThrowsException() + { + AsyncValueTaskMethodBuilder b = ValueTask.CreateAsyncMethodBuilder(); + Assert.Throws("stateMachine", () => b.SetStateMachine(null)); + b.SetStateMachine(new DelegateStateMachine()); + } + + [Fact] + public void Start_ExecutionContextChangesInMoveNextDontFlowOut() + { + var al = new AsyncLocal { Value = 0 }; + int calls = 0; + + var dsm = new DelegateStateMachine + { + MoveNextDelegate = () => + { + al.Value++; + calls++; + } + }; + + dsm.MoveNext(); + Assert.Equal(1, al.Value); + Assert.Equal(1, calls); + + dsm.MoveNext(); + Assert.Equal(2, al.Value); + Assert.Equal(2, calls); + + AsyncValueTaskMethodBuilder b = ValueTask.CreateAsyncMethodBuilder(); + b.Start(ref dsm); + Assert.Equal(2, al.Value); // change should not be visible + Assert.Equal(3, calls); + + // Make sure we've not caused the Task to be allocated + b.SetResult(42); + ValueTask vt = b.Task; + Assert.NotSame(vt.AsTask(), vt.AsTask()); + } + + private struct DelegateStateMachine : IAsyncStateMachine + { + internal Action MoveNextDelegate; + public void MoveNext() => MoveNextDelegate?.Invoke(); + + internal Action SetStateMachineDelegate; + public void SetStateMachine(IAsyncStateMachine stateMachine) => SetStateMachineDelegate?.Invoke(stateMachine); + } + } +} diff --git a/src/System.Threading.Tasks.Extensions/tests/System.Threading.Tasks.Extensions.Tests.csproj b/src/System.Threading.Tasks.Extensions/tests/System.Threading.Tasks.Extensions.Tests.csproj index f1e33a31b641..94a8862d4c28 100644 --- a/src/System.Threading.Tasks.Extensions/tests/System.Threading.Tasks.Extensions.Tests.csproj +++ b/src/System.Threading.Tasks.Extensions/tests/System.Threading.Tasks.Extensions.Tests.csproj @@ -1,4 +1,4 @@ - + @@ -13,6 +13,7 @@ + @@ -25,4 +26,4 @@ - + \ No newline at end of file