diff --git a/src/libraries/Common/tests/Tests/System/IO/StreamConformanceTests.cs b/src/libraries/Common/tests/Tests/System/IO/StreamConformanceTests.cs index cde50f84d6d4ca..8e77d98834aad7 100644 --- a/src/libraries/Common/tests/Tests/System/IO/StreamConformanceTests.cs +++ b/src/libraries/Common/tests/Tests/System/IO/StreamConformanceTests.cs @@ -499,36 +499,36 @@ e is NotSupportedException || } } + protected async Task AssertCanceledAsync(CancellationToken cancellationToken, Func testCode) + { + OperationCanceledException oce = await Assert.ThrowsAnyAsync(testCode); + if (cancellationToken.CanBeCanceled) + { + Assert.Equal(cancellationToken, oce.CancellationToken); + } + } + protected async Task ValidatePrecanceledOperations_ThrowsCancellationException(Stream stream) { var cts = new CancellationTokenSource(); cts.Cancel(); - OperationCanceledException oce; - if (stream.CanRead) { - oce = await Assert.ThrowsAnyAsync(() => stream.ReadAsync(new byte[1], 0, 1, cts.Token)); - Assert.Equal(cts.Token, oce.CancellationToken); - - oce = await Assert.ThrowsAnyAsync(async () => { await stream.ReadAsync(new Memory(new byte[1]), cts.Token); }); - Assert.Equal(cts.Token, oce.CancellationToken); + await AssertCanceledAsync(cts.Token, () => stream.ReadAsync(new byte[1], 0, 1, cts.Token)); + await AssertCanceledAsync(cts.Token, async () => { await stream.ReadAsync(new Memory(new byte[1]), cts.Token); }); } if (stream.CanWrite) { - oce = await Assert.ThrowsAnyAsync(() => stream.WriteAsync(new byte[1], 0, 1, cts.Token)); - Assert.Equal(cts.Token, oce.CancellationToken); - - oce = await Assert.ThrowsAnyAsync(async () => { await stream.WriteAsync(new ReadOnlyMemory(new byte[1]), cts.Token); }); - Assert.Equal(cts.Token, oce.CancellationToken); + await AssertCanceledAsync(cts.Token, () => stream.WriteAsync(new byte[1], 0, 1, cts.Token)); + await AssertCanceledAsync(cts.Token, async () => { await stream.WriteAsync(new ReadOnlyMemory(new byte[1]), cts.Token); }); } Exception e = await Record.ExceptionAsync(() => stream.FlushAsync(cts.Token)); if (e != null) { - oce = Assert.IsAssignableFrom(e); - Assert.Equal(cts.Token, oce.CancellationToken); + Assert.Equal(cts.Token, Assert.IsAssignableFrom(e).CancellationToken); } } @@ -540,15 +540,12 @@ protected async Task ValidateCancelableReads_AfterInvocation_ThrowsCancellationE } CancellationTokenSource cts; - OperationCanceledException oce; cts = new CancellationTokenSource(1); - oce = await Assert.ThrowsAnyAsync(() => stream.ReadAsync(new byte[1], 0, 1, cts.Token)); - Assert.Equal(cts.Token, oce.CancellationToken); + await AssertCanceledAsync(cts.Token, () => stream.ReadAsync(new byte[1], 0, 1, cts.Token)); cts = new CancellationTokenSource(1); - oce = await Assert.ThrowsAnyAsync(async () => { await stream.ReadAsync(new Memory(new byte[1]), cts.Token); }); - Assert.Equal(cts.Token, oce.CancellationToken); + await AssertCanceledAsync(cts.Token, async () => { await stream.ReadAsync(new Memory(new byte[1]), cts.Token); }); } protected async Task WhenAllOrAnyFailed(Task task1, Task task2) @@ -2398,18 +2395,22 @@ public virtual async Task ReadAsync_CancelPendingRead_DoesntImpactSubsequentRead using StreamPair streams = await CreateConnectedStreamsAsync(); foreach ((Stream writeable, Stream readable) in GetReadWritePairs(streams)) { - await Assert.ThrowsAnyAsync(() => readable.ReadAsync(new byte[1], 0, 1, new CancellationToken(true))); - await Assert.ThrowsAnyAsync(async () => { await readable.ReadAsync(new Memory(new byte[1]), new CancellationToken(true)); }); + CancellationTokenSource cts; - var cts = new CancellationTokenSource(); + cts = new CancellationTokenSource(); + cts.Cancel(); + await AssertCanceledAsync(cts.Token, () => readable.ReadAsync(new byte[1], 0, 1, cts.Token)); + await AssertCanceledAsync(cts.Token, async () => { await readable.ReadAsync(new Memory(new byte[1]), cts.Token); }); + + cts = new CancellationTokenSource(); Task t = readable.ReadAsync(new byte[1], 0, 1, cts.Token); cts.Cancel(); - await Assert.ThrowsAnyAsync(() => t); + await AssertCanceledAsync(cts.Token, () => t); cts = new CancellationTokenSource(); ValueTask vt = readable.ReadAsync(new Memory(new byte[1]), cts.Token); cts.Cancel(); - await Assert.ThrowsAnyAsync(async () => await vt); + await AssertCanceledAsync(cts.Token, async () => await vt); byte[] buffer = new byte[1]; vt = readable.ReadAsync(new Memory(buffer));