Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved reliability of Task.GetResultOrDefault #3410

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 11 additions & 19 deletions Microsoft.Toolkit/Extensions/TaskExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Diagnostics.CodeAnalysis;
using System.Diagnostics.Contracts;
using System.Reflection;
Expand Down Expand Up @@ -41,28 +40,21 @@ public static class TaskExtensions
#endif
)
{
Type taskType = task.GetType();

// Check if the task is actually some Task<T>
if (
#if NETSTANDARD1_4
taskType.GetTypeInfo().IsGenericType &&
#else
taskType.IsGenericType &&
#endif
taskType.GetGenericTypeDefinition() == typeof(Task<>))
{
// Get the Task<T>.Result property
PropertyInfo propertyInfo =
// Try to get the Task<T>.Result property. This method would've
// been called anyway after the type checks, but using that to
// validate the input type saves some additional reflection calls.
// Furthermore, doing this also makes the method flexible enough to
// cases whether the input Task<T> is actually an instance of some
// runtime-specific type that inherits from Task<T>.
PropertyInfo? propertyInfo =
#if NETSTANDARD1_4
taskType.GetRuntimeProperty(nameof(Task<object>.Result));
task.GetType().GetRuntimeProperty(nameof(Task<object>.Result));
#else
taskType.GetProperty(nameof(Task<object>.Result));
task.GetType().GetProperty(nameof(Task<object>.Result));
#endif

// Finally retrieve the result
return propertyInfo!.GetValue(task);
}
// Return the result, if possible
return propertyInfo?.GetValue(task);
}

return null;
Expand Down
30 changes: 30 additions & 0 deletions UnitTests/UnitTests.Shared/Extensions/Test_TaskExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,25 @@ public void Test_TaskExtensions_ResultOrDefault()
Assert.AreEqual(42, ((Task)tcs.Task).GetResultOrDefault());
}

[TestCategory("TaskExtensions")]
[TestMethod]
public async Task Test_TaskExtensions_ResultOrDefault_FromAsyncTaskMethodBuilder()
{
var tcs = new TaskCompletionSource<object>();

Task<string> taskFromBuilder = GetTaskFromAsyncMethodBuilder("Test", tcs);

Assert.IsNull(((Task)taskFromBuilder).GetResultOrDefault());
Assert.IsNull(taskFromBuilder.GetResultOrDefault());

tcs.SetResult(null);

await taskFromBuilder;

Assert.AreEqual(((Task)taskFromBuilder).GetResultOrDefault(), "Test");
Assert.AreEqual(taskFromBuilder.GetResultOrDefault(), "Test");
}

[TestCategory("TaskExtensions")]
[TestMethod]
public void Test_TaskExtensions_ResultOrDefault_OfT_Int32()
Expand Down Expand Up @@ -86,5 +105,16 @@ public void Test_TaskExtensions_ResultOrDefault_OfT_String()

Assert.AreEqual("Hello world", tcs.Task.GetResultOrDefault());
}

// Creates a Task<T> of a given type which is actually an instance of
// System.Runtime.CompilerServices.AsyncTaskMethodBuilder<TResult>.AsyncStateMachineBox<TStateMachine>.
// See https://source.dot.net/#System.Private.CoreLib/AsyncTaskMethodBuilderT.cs,f8f35fd356112b30.
// This is needed to verify that the extension also works when the input Task<T> is of a derived type.
private static async Task<T> GetTaskFromAsyncMethodBuilder<T>(T result, TaskCompletionSource<object> tcs)
{
await tcs.Task;

return result;
}
}
}