Skip to content

Commit

Permalink
Merge branch 'master' into feature/mvvm-toolkit-part2
Browse files Browse the repository at this point in the history
  • Loading branch information
azchohfi authored Aug 11, 2020
2 parents b5c0272 + f8b78a3 commit 3579e24
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 19 deletions.
30 changes: 11 additions & 19 deletions Microsoft.Toolkit/Extensions/TaskExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Diagnostics.CodeAnalysis;
using System.Diagnostics.Contracts;
using System.Reflection;
Expand Down Expand Up @@ -41,28 +40,21 @@ public static class TaskExtensions
#endif
)
{
Type taskType = task.GetType();

// Check if the task is actually some Task<T>
if (
#if NETSTANDARD1_4
taskType.GetTypeInfo().IsGenericType &&
#else
taskType.IsGenericType &&
#endif
taskType.GetGenericTypeDefinition() == typeof(Task<>))
{
// Get the Task<T>.Result property
PropertyInfo propertyInfo =
// Try to get the Task<T>.Result property. This method would've
// been called anyway after the type checks, but using that to
// validate the input type saves some additional reflection calls.
// Furthermore, doing this also makes the method flexible enough to
// cases whether the input Task<T> is actually an instance of some
// runtime-specific type that inherits from Task<T>.
PropertyInfo? propertyInfo =
#if NETSTANDARD1_4
taskType.GetRuntimeProperty(nameof(Task<object>.Result));
task.GetType().GetRuntimeProperty(nameof(Task<object>.Result));
#else
taskType.GetProperty(nameof(Task<object>.Result));
task.GetType().GetProperty(nameof(Task<object>.Result));
#endif

// Finally retrieve the result
return propertyInfo!.GetValue(task);
}
// Return the result, if possible
return propertyInfo?.GetValue(task);
}

return null;
Expand Down
30 changes: 30 additions & 0 deletions UnitTests/UnitTests.Shared/Extensions/Test_TaskExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,25 @@ public void Test_TaskExtensions_ResultOrDefault()
Assert.AreEqual(42, ((Task)tcs.Task).GetResultOrDefault());
}

[TestCategory("TaskExtensions")]
[TestMethod]
public async Task Test_TaskExtensions_ResultOrDefault_FromAsyncTaskMethodBuilder()
{
var tcs = new TaskCompletionSource<object>();

Task<string> taskFromBuilder = GetTaskFromAsyncMethodBuilder("Test", tcs);

Assert.IsNull(((Task)taskFromBuilder).GetResultOrDefault());
Assert.IsNull(taskFromBuilder.GetResultOrDefault());

tcs.SetResult(null);

await taskFromBuilder;

Assert.AreEqual(((Task)taskFromBuilder).GetResultOrDefault(), "Test");
Assert.AreEqual(taskFromBuilder.GetResultOrDefault(), "Test");
}

[TestCategory("TaskExtensions")]
[TestMethod]
public void Test_TaskExtensions_ResultOrDefault_OfT_Int32()
Expand Down Expand Up @@ -86,5 +105,16 @@ public void Test_TaskExtensions_ResultOrDefault_OfT_String()

Assert.AreEqual("Hello world", tcs.Task.GetResultOrDefault());
}

// Creates a Task<T> of a given type which is actually an instance of
// System.Runtime.CompilerServices.AsyncTaskMethodBuilder<TResult>.AsyncStateMachineBox<TStateMachine>.
// See https://source.dot.net/#System.Private.CoreLib/AsyncTaskMethodBuilderT.cs,f8f35fd356112b30.
// This is needed to verify that the extension also works when the input Task<T> is of a derived type.
private static async Task<T> GetTaskFromAsyncMethodBuilder<T>(T result, TaskCompletionSource<object> tcs)
{
await tcs.Task;

return result;
}
}
}

0 comments on commit 3579e24

Please sign in to comment.