diff --git a/src/libraries/Common/tests/StreamConformanceTests/System/IO/StreamConformanceTests.cs b/src/libraries/Common/tests/StreamConformanceTests/System/IO/StreamConformanceTests.cs
index d41e07217ee335..43451d45301128 100644
--- a/src/libraries/Common/tests/StreamConformanceTests/System/IO/StreamConformanceTests.cs
+++ b/src/libraries/Common/tests/StreamConformanceTests/System/IO/StreamConformanceTests.cs
@@ -1562,11 +1562,6 @@ public abstract class ConnectedStreamConformanceTests : StreamConformanceTests
/// Gets whether the stream guarantees that all data written to it will be flushed as part of Flush{Async}.
///
protected virtual bool FlushGuaranteesAllDataWritten => true;
- ///
- /// Gets whether a stream implements an aggressive read that tries to fill the supplied buffer and only
- /// stops when it does so or hits EOF.
- ///
- protected virtual bool ReadsMayBlockUntilBufferFullOrEOF => false;
/// Gets whether reads for a count of 0 bytes block if no bytes are available to read.
protected virtual bool BlocksOnZeroByteReads => false;
///
@@ -1709,6 +1704,10 @@ public virtual async Task ReadWriteByte_Success()
}
}
+ public static IEnumerable ReadWrite_Modes =>
+ from mode in Enum.GetValues()
+ select new object[] { mode };
+
public static IEnumerable ReadWrite_Success_MemberData() =>
from mode in Enum.GetValues()
from writeSize in new[] { 1, 42, 10 * 1024 }
@@ -1785,6 +1784,54 @@ public virtual async Task ReadWrite_Success(ReadWriteMode mode, int writeSize, b
}
}
+ [Theory]
+ [MemberData(nameof(ReadWrite_Modes))]
+ [ActiveIssue("https://github.com/dotnet/runtime/issues/51371", TestPlatforms.iOS | TestPlatforms.tvOS | TestPlatforms.MacCatalyst)]
+ public virtual async Task ReadWrite_MessagesSmallerThanReadBuffer_Success(ReadWriteMode mode)
+ {
+ if (!FlushGuaranteesAllDataWritten)
+ {
+ return;
+ }
+
+ foreach (CancellationToken nonCanceledToken in new[] { CancellationToken.None, new CancellationTokenSource().Token })
+ {
+ using StreamPair streams = await CreateConnectedStreamsAsync();
+
+ foreach ((Stream writeable, Stream readable) in GetReadWritePairs(streams))
+ {
+ byte[] writerBytes = RandomNumberGenerator.GetBytes(512);
+ var readerBytes = new byte[writerBytes.Length * 2];
+
+ // Repeatedly write then read a message smaller in size than the read buffer
+ for (int i = 0; i < 5; i++)
+ {
+ Task writes = Task.Run(async () =>
+ {
+ await WriteAsync(mode, writeable, writerBytes, 0, writerBytes.Length, nonCanceledToken);
+ if (FlushRequiredToWriteData)
+ {
+ await writeable.FlushAsync();
+ }
+ });
+
+ int n = 0;
+ while (n < writerBytes.Length)
+ {
+ int r = await ReadAsync(mode, readable, readerBytes, n, readerBytes.Length - n);
+ Assert.InRange(r, 1, writerBytes.Length - n);
+ n += r;
+ }
+
+ Assert.Equal(writerBytes.Length, n);
+ AssertExtensions.SequenceEqual(writerBytes, readerBytes.AsSpan(0, writerBytes.Length));
+
+ await writes;
+ }
+ }
+ }
+ }
+
[Theory]
[MemberData(nameof(AllReadWriteModesAndValue), false)]
[MemberData(nameof(AllReadWriteModesAndValue), true)]
@@ -2160,6 +2207,10 @@ public virtual async Task ZeroByteRead_BlocksUntilDataAvailableOrNops(ReadWriteM
});
Assert.Equal(0, await zeroByteRead);
+ // Perform a second zero-byte read.
+ await Task.Run(() => ReadAsync(mode, readable, Array.Empty(), 0, 0));
+
+ // Now consume all the data.
var readBytes = new byte[5];
int count = 0;
while (count < readBytes.Length)
@@ -2684,7 +2735,7 @@ public virtual async Task Flush_FlushesUnderlyingStream(bool flushAsync)
[InlineData(true, true)]
public virtual async Task Dispose_Flushes(bool useAsync, bool leaveOpen)
{
- if (leaveOpen && (!SupportsLeaveOpen || ReadsMayBlockUntilBufferFullOrEOF))
+ if (leaveOpen && !SupportsLeaveOpen)
{
return;
}
diff --git a/src/libraries/Common/tests/System/IO/Compression/CompressionStreamTestBase.cs b/src/libraries/Common/tests/System/IO/Compression/CompressionStreamTestBase.cs
index 366c547d2c600d..24f32fab04b045 100644
--- a/src/libraries/Common/tests/System/IO/Compression/CompressionStreamTestBase.cs
+++ b/src/libraries/Common/tests/System/IO/Compression/CompressionStreamTestBase.cs
@@ -54,6 +54,6 @@ protected override Task CreateWrappedConnectedStreamsAsync(StreamPai
protected override Type UnsupportedReadWriteExceptionType => typeof(InvalidOperationException);
protected override bool WrappedUsableAfterClose => false;
protected override bool FlushRequiredToWriteData => true;
- protected override bool FlushGuaranteesAllDataWritten => false;
+ protected override bool BlocksOnZeroByteReads => true;
}
}
diff --git a/src/libraries/Common/tests/System/IO/Compression/ZipTestHelper.cs b/src/libraries/Common/tests/System/IO/Compression/ZipTestHelper.cs
index e8a0c08c3df26c..6f319461c35ef4 100644
--- a/src/libraries/Common/tests/System/IO/Compression/ZipTestHelper.cs
+++ b/src/libraries/Common/tests/System/IO/Compression/ZipTestHelper.cs
@@ -65,6 +65,17 @@ public static void ReadBytes(Stream stream, byte[] buffer, long bytesToRead)
}
}
+ public static int ReadAllBytes(Stream stream, byte[] buffer, int offset, int count)
+ {
+ int bytesRead;
+ int totalRead = 0;
+ while ((bytesRead = stream.Read(buffer, offset + totalRead, count - totalRead)) != 0)
+ {
+ totalRead += bytesRead;
+ }
+ return totalRead;
+ }
+
public static bool ArraysEqual(T[] a, T[] b) where T : IComparable
{
if (a.Length != b.Length) return false;
@@ -111,8 +122,8 @@ public static void StreamsEqual(Stream ast, Stream bst, int blocksToRead)
if (blocksToRead != -1 && blocksRead >= blocksToRead)
break;
- ac = ast.Read(ad, 0, 4096);
- bc = bst.Read(bd, 0, 4096);
+ ac = ReadAllBytes(ast, ad, 0, 4096);
+ bc = ReadAllBytes(bst, bd, 0, 4096);
if (ac != bc)
{
@@ -170,7 +181,7 @@ public static void IsZipSameAsDir(Stream archiveFile, string directory, ZipArchi
var buffer = new byte[entry.Length];
using (Stream entrystream = entry.Open())
{
- entrystream.Read(buffer, 0, buffer.Length);
+ ReadAllBytes(entrystream, buffer, 0, buffer.Length);
#if NETCOREAPP
uint zipcrc = entry.Crc32;
Assert.Equal(CRC.CalculateCRC(buffer), zipcrc);
diff --git a/src/libraries/System.IO.Compression.Brotli/src/System/IO/Compression/BrotliStream.cs b/src/libraries/System.IO.Compression.Brotli/src/System/IO/Compression/BrotliStream.cs
index 4401a4753d5ac1..73ecccf89262b7 100644
--- a/src/libraries/System.IO.Compression.Brotli/src/System/IO/Compression/BrotliStream.cs
+++ b/src/libraries/System.IO.Compression.Brotli/src/System/IO/Compression/BrotliStream.cs
@@ -173,7 +173,7 @@ private void EnsureNoActiveAsyncOperation()
private void AsyncOperationStarting()
{
- if (Interlocked.CompareExchange(ref _activeAsyncOperation, 1, 0) != 0)
+ if (Interlocked.Exchange(ref _activeAsyncOperation, 1) != 0)
{
ThrowInvalidBeginCall();
}
@@ -181,13 +181,11 @@ private void AsyncOperationStarting()
private void AsyncOperationCompleting()
{
- int oldValue = Interlocked.CompareExchange(ref _activeAsyncOperation, 0, 1);
- Debug.Assert(oldValue == 1, $"Expected {nameof(_activeAsyncOperation)} to be 1, got {oldValue}");
+ Debug.Assert(_activeAsyncOperation == 1);
+ Volatile.Write(ref _activeAsyncOperation, 0);
}
- private static void ThrowInvalidBeginCall()
- {
+ private static void ThrowInvalidBeginCall() =>
throw new InvalidOperationException(SR.InvalidBeginCall);
- }
}
}
diff --git a/src/libraries/System.IO.Compression.Brotli/src/System/IO/Compression/dec/BrotliStream.Decompress.cs b/src/libraries/System.IO.Compression.Brotli/src/System/IO/Compression/dec/BrotliStream.Decompress.cs
index f2f7d720c7ae7f..b9708e0c19df00 100644
--- a/src/libraries/System.IO.Compression.Brotli/src/System/IO/Compression/dec/BrotliStream.Decompress.cs
+++ b/src/libraries/System.IO.Compression.Brotli/src/System/IO/Compression/dec/BrotliStream.Decompress.cs
@@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
using System.Buffers;
+using System.Diagnostics;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
@@ -42,8 +43,8 @@ public override int Read(byte[] buffer, int offset, int count)
public override int ReadByte()
{
byte b = default;
- int numRead = Read(MemoryMarshal.CreateSpan(ref b, 1));
- return numRead != 0 ? b : -1;
+ int bytesRead = Read(MemoryMarshal.CreateSpan(ref b, 1));
+ return bytesRead != 0 ? b : -1;
}
/// Reads a sequence of bytes from the current Brotli stream to a byte span and advances the position within the Brotli stream by the number of bytes read.
@@ -57,59 +58,25 @@ public override int Read(Span buffer)
if (_mode != CompressionMode.Decompress)
throw new InvalidOperationException(SR.BrotliStream_Compress_UnsupportedOperation);
EnsureNotDisposed();
- int totalWritten = 0;
- OperationStatus lastResult = OperationStatus.DestinationTooSmall;
- // We want to continue calling Decompress until we're either out of space for output or until Decompress indicates it is finished.
- while (buffer.Length > 0 && lastResult != OperationStatus.Done)
+ int bytesWritten;
+ while (!TryDecompress(buffer, out bytesWritten))
{
- if (lastResult == OperationStatus.NeedMoreData)
+ int bytesRead = _stream.Read(_buffer, _bufferCount, _buffer.Length - _bufferCount);
+ if (bytesRead <= 0)
{
- // Ensure any left over data is at the beginning of the array so we can fill the remainder.
- if (_bufferCount > 0 && _bufferOffset != 0)
- {
- _buffer.AsSpan(_bufferOffset, _bufferCount).CopyTo(_buffer);
- }
- _bufferOffset = 0;
-
- int numRead = 0;
- while (_bufferCount < _buffer.Length && ((numRead = _stream.Read(_buffer, _bufferCount, _buffer.Length - _bufferCount)) > 0))
- {
- _bufferCount += numRead;
- if (_bufferCount > _buffer.Length)
- {
- // The stream is either malicious or poorly implemented and returned a number of
- // bytes larger than the buffer supplied to it.
- throw new InvalidDataException(SR.BrotliStream_Decompress_InvalidStream);
- }
- }
-
- if (_bufferCount <= 0)
- {
- break;
- }
- }
-
- lastResult = _decoder.Decompress(new ReadOnlySpan(_buffer, _bufferOffset, _bufferCount), buffer, out int bytesConsumed, out int bytesWritten);
- if (lastResult == OperationStatus.InvalidData)
- {
- throw new InvalidOperationException(SR.BrotliStream_Decompress_InvalidData);
+ break;
}
- if (bytesConsumed > 0)
- {
- _bufferOffset += bytesConsumed;
- _bufferCount -= bytesConsumed;
- }
+ _bufferCount += bytesRead;
- if (bytesWritten > 0)
+ if (_bufferCount > _buffer.Length)
{
- totalWritten += bytesWritten;
- buffer = buffer.Slice(bytesWritten);
+ ThrowInvalidStream();
}
}
- return totalWritten;
+ return bytesWritten;
}
/// Begins an asynchronous read operation. (Consider using the method instead.)
@@ -169,73 +136,100 @@ public override Task ReadAsync(byte[] buffer, int offset, int count, Cancel
{
return ValueTask.FromCanceled(cancellationToken);
}
- return FinishReadAsyncMemory(buffer, cancellationToken);
- }
- private async ValueTask FinishReadAsyncMemory(Memory buffer, CancellationToken cancellationToken)
- {
- AsyncOperationStarting();
- try
+ return Core(buffer, cancellationToken);
+
+ async ValueTask Core(Memory buffer, CancellationToken cancellationToken)
{
- int totalWritten = 0;
- OperationStatus lastResult = OperationStatus.DestinationTooSmall;
- // We want to continue calling Decompress until we're either out of space for output or until Decompress indicates it is finished.
- while (buffer.Length > 0 && lastResult != OperationStatus.Done)
+ AsyncOperationStarting();
+ try
{
- if (lastResult == OperationStatus.NeedMoreData)
+ int bytesWritten;
+ while (!TryDecompress(buffer.Span, out bytesWritten))
{
- // Ensure any left over data is at the beginning of the array so we can fill the remainder.
- if (_bufferCount > 0 && _bufferOffset != 0)
+ int bytesRead = await _stream.ReadAsync(_buffer.AsMemory(_bufferCount), cancellationToken).ConfigureAwait(false);
+ if (bytesRead <= 0)
{
- _buffer.AsSpan(_bufferOffset, _bufferCount).CopyTo(_buffer);
+ break;
}
- _bufferOffset = 0;
- int numRead = 0;
- while (_bufferCount < _buffer.Length &&
- ((numRead = await _stream.ReadAsync(new Memory(_buffer, _bufferCount, _buffer.Length - _bufferCount), cancellationToken).ConfigureAwait(false)) > 0))
- {
- _bufferCount += numRead;
- if (_bufferCount > _buffer.Length)
- {
- // The stream is either malicious or poorly implemented and returned a number of
- // bytes larger than the buffer supplied to it.
- throw new InvalidDataException(SR.BrotliStream_Decompress_InvalidStream);
- }
- }
+ _bufferCount += bytesRead;
- if (_bufferCount <= 0)
+ if (_bufferCount > _buffer.Length)
{
- break;
+ ThrowInvalidStream();
}
}
- cancellationToken.ThrowIfCancellationRequested();
- lastResult = _decoder.Decompress(new ReadOnlySpan(_buffer, _bufferOffset, _bufferCount), buffer.Span, out int bytesConsumed, out int bytesWritten);
- if (lastResult == OperationStatus.InvalidData)
- {
- throw new InvalidOperationException(SR.BrotliStream_Decompress_InvalidData);
- }
+ return bytesWritten;
+ }
+ finally
+ {
+ AsyncOperationCompleting();
+ }
+ }
+ }
- if (bytesConsumed > 0)
- {
- _bufferOffset += bytesConsumed;
- _bufferCount -= bytesConsumed;
- }
+ /// Tries to decode available data into the destination buffer.
+ /// The destination buffer for the decompressed data.
+ /// The number of bytes written to destination.
+ /// true if the caller should consider the read operation completed; otherwise, false.
+ private bool TryDecompress(Span destination, out int bytesWritten)
+ {
+ // Decompress any data we may have in our buffer.
+ OperationStatus lastResult = _decoder.Decompress(new ReadOnlySpan(_buffer, _bufferOffset, _bufferCount), destination, out int bytesConsumed, out bytesWritten);
+ if (lastResult == OperationStatus.InvalidData)
+ {
+ throw new InvalidOperationException(SR.BrotliStream_Decompress_InvalidData);
+ }
- if (bytesWritten > 0)
- {
- totalWritten += bytesWritten;
- buffer = buffer.Slice(bytesWritten);
- }
- }
+ if (bytesConsumed != 0)
+ {
+ _bufferOffset += bytesConsumed;
+ _bufferCount -= bytesConsumed;
+ }
+
+ // If we successfully decompressed any bytes, or if we've reached the end of the decompression, we're done.
+ if (bytesWritten != 0 || lastResult == OperationStatus.Done)
+ {
+ return true;
+ }
- return totalWritten;
+ if (destination.IsEmpty)
+ {
+ // The caller provided a zero-byte buffer. This is typically done in order to avoid allocating/renting
+ // a buffer until data is known to be available. We don't have perfect knowledge here, as _decoder.Decompress
+ // will return DestinationTooSmall whether or not more data is required. As such, we assume that if there's
+ // any data in our input buffer, it would have been decompressible into at least one byte of output, and
+ // otherwise we need to do a read on the underlying stream. This isn't perfect, because having input data
+ // doesn't necessarily mean it'll decompress into at least one byte of output, but it's a reasonable approximation
+ // for the 99% case. If it's wrong, it just means that a caller using zero-byte reads as a way to delay
+ // getting a buffer to use for a subsequent call may end up getting one earlier than otherwise preferred.
+ Debug.Assert(lastResult == OperationStatus.DestinationTooSmall);
+ if (_bufferCount != 0)
+ {
+ Debug.Assert(bytesWritten == 0);
+ return true;
+ }
}
- finally
+
+ Debug.Assert(
+ lastResult == OperationStatus.NeedMoreData ||
+ (lastResult == OperationStatus.DestinationTooSmall && destination.IsEmpty && _bufferCount == 0), $"{nameof(lastResult)} == {lastResult}, {nameof(destination.Length)} == {destination.Length}");
+
+ // Ensure any left over data is at the beginning of the array so we can fill the remainder.
+ if (_bufferCount != 0 && _bufferOffset != 0)
{
- AsyncOperationCompleting();
+ new ReadOnlySpan(_buffer, _bufferOffset, _bufferCount).CopyTo(_buffer);
}
+ _bufferOffset = 0;
+
+ return false;
}
+
+ private static void ThrowInvalidStream() =>
+ // The stream is either malicious or poorly implemented and returned a number of
+ // bytes larger than the buffer supplied to it.
+ throw new InvalidDataException(SR.BrotliStream_Decompress_InvalidStream);
}
}
diff --git a/src/libraries/System.IO.Compression.Brotli/src/System/IO/Compression/enc/BrotliStream.Compress.cs b/src/libraries/System.IO.Compression.Brotli/src/System/IO/Compression/enc/BrotliStream.Compress.cs
index 44a368f58eeec2..efa17bc5f9cbb7 100644
--- a/src/libraries/System.IO.Compression.Brotli/src/System/IO/Compression/enc/BrotliStream.Compress.cs
+++ b/src/libraries/System.IO.Compression.Brotli/src/System/IO/Compression/enc/BrotliStream.Compress.cs
@@ -68,8 +68,8 @@ internal void WriteCore(ReadOnlySpan buffer, bool isFinalBlock = false)
Span output = new Span(_buffer);
while (lastResult == OperationStatus.DestinationTooSmall)
{
- int bytesConsumed = 0;
- int bytesWritten = 0;
+ int bytesConsumed;
+ int bytesWritten;
lastResult = _encoder.Compress(buffer, output, out bytesConsumed, out bytesWritten, isFinalBlock);
if (lastResult == OperationStatus.InvalidData)
throw new InvalidOperationException(SR.BrotliStream_Compress_InvalidData);
@@ -176,7 +176,7 @@ public override void Flush()
Span output = new Span(_buffer);
while (lastResult == OperationStatus.DestinationTooSmall)
{
- int bytesWritten = 0;
+ int bytesWritten;
lastResult = _encoder.Flush(output, out bytesWritten);
if (lastResult == OperationStatus.InvalidData)
throw new InvalidDataException(SR.BrotliStream_Compress_InvalidData);
diff --git a/src/libraries/System.IO.Compression.Brotli/tests/CompressionStreamUnitTests.Brotli.cs b/src/libraries/System.IO.Compression.Brotli/tests/CompressionStreamUnitTests.Brotli.cs
index 3960233e5eafa8..96eff305a0ef3e 100644
--- a/src/libraries/System.IO.Compression.Brotli/tests/CompressionStreamUnitTests.Brotli.cs
+++ b/src/libraries/System.IO.Compression.Brotli/tests/CompressionStreamUnitTests.Brotli.cs
@@ -14,7 +14,8 @@ public class BrotliStreamUnitTests : CompressionStreamUnitTestBase
public override Stream CreateStream(Stream stream, CompressionLevel level) => new BrotliStream(stream, level);
public override Stream CreateStream(Stream stream, CompressionLevel level, bool leaveOpen) => new BrotliStream(stream, level, leaveOpen);
public override Stream BaseStream(Stream stream) => ((BrotliStream)stream).BaseStream;
- protected override bool ReadsMayBlockUntilBufferFullOrEOF => true;
+
+ protected override bool FlushGuaranteesAllDataWritten => false;
// The tests are relying on an implementation detail of BrotliStream, using knowledge of its internal buffer size
// in various test calculations. Currently the implementation is using the ArrayPool, which will round up to a
diff --git a/src/libraries/System.IO.Compression/src/System/IO/Compression/DeflateZLib/DeflateStream.cs b/src/libraries/System.IO.Compression/src/System/IO/Compression/DeflateZLib/DeflateStream.cs
index d9739b087b7041..84b7ad0e5acade 100644
--- a/src/libraries/System.IO.Compression/src/System/IO/Compression/DeflateZLib/DeflateStream.cs
+++ b/src/libraries/System.IO.Compression/src/System/IO/Compression/DeflateZLib/DeflateStream.cs
@@ -104,12 +104,14 @@ internal void InitializeDeflater(Stream stream, bool leaveOpen, int windowBits,
InitializeBuffer();
}
+ [MemberNotNull(nameof(_buffer))]
private void InitializeBuffer()
{
Debug.Assert(_buffer == null);
_buffer = ArrayPool.Shared.Rent(DefaultBufferSize);
}
+ [MemberNotNull(nameof(_buffer))]
private void EnsureBufferInitialized()
{
if (_buffer == null)
@@ -259,83 +261,94 @@ internal int ReadCore(Span buffer)
EnsureDecompressionMode();
EnsureNotDisposed();
EnsureBufferInitialized();
-
- int totalRead = 0;
-
Debug.Assert(_inflater != null);
+
+ int bytesRead;
while (true)
{
- int bytesRead = _inflater.Inflate(buffer.Slice(totalRead));
- totalRead += bytesRead;
- if (totalRead == buffer.Length)
- {
- break;
- }
-
- // If the stream is finished then we have a few potential cases here:
- // 1. DeflateStream => return
- // 2. GZipStream that is finished but may have an additional GZipStream appended => feed more input
- // 3. GZipStream that is finished and appended with garbage => return
- if (_inflater.Finished() && (!_inflater.IsGzipStream() || !_inflater.NeedsInput()))
+ // Try to decompress any data from the inflater into the caller's buffer.
+ // If we're able to decompress any bytes, or if decompression is completed, we're done.
+ bytesRead = _inflater.Inflate(buffer);
+ if (bytesRead != 0 || InflatorIsFinished)
{
break;
}
+ // We were unable to decompress any data. If the inflater needs additional input
+ // data to proceed, read some to populate it.
if (_inflater.NeedsInput())
{
- Debug.Assert(_buffer != null);
- int bytes = _stream.Read(_buffer, 0, _buffer.Length);
- if (bytes <= 0)
+ int n = _stream.Read(_buffer, 0, _buffer.Length);
+ if (n <= 0)
{
break;
}
- else if (bytes > _buffer.Length)
+ else if (n > _buffer.Length)
+ {
+ ThrowGenericInvalidData();
+ }
+ else
{
- // The stream is either malicious or poorly implemented and returned a number of
- // bytes larger than the buffer supplied to it.
- throw new InvalidDataException(SR.GenericInvalidData);
+ _inflater.SetInput(_buffer, 0, n);
}
+ }
- _inflater.SetInput(_buffer, 0, bytes);
+ if (buffer.IsEmpty)
+ {
+ // The caller provided a zero-byte buffer. This is typically done in order to avoid allocating/renting
+ // a buffer until data is known to be available. We don't have perfect knowledge here, as _inflater.Inflate
+ // will return 0 whether or not more data is required, and having input data doesn't necessarily mean it'll
+ // decompress into at least one byte of output, but it's a reasonable approximation for the 99% case. If it's
+ // wrong, it just means that a caller using zero-byte reads as a way to delay getting a buffer to use for a
+ // subsequent call may end up getting one earlier than otherwise preferred.
+ Debug.Assert(bytesRead == 0);
+ break;
}
}
- return totalRead;
+ return bytesRead;
}
+ private bool InflatorIsFinished =>
+ // If the stream is finished then we have a few potential cases here:
+ // 1. DeflateStream => return
+ // 2. GZipStream that is finished but may have an additional GZipStream appended => feed more input
+ // 3. GZipStream that is finished and appended with garbage => return
+ _inflater!.Finished() &&
+ (!_inflater.IsGzipStream() || !_inflater.NeedsInput());
+
private void EnsureNotDisposed()
{
if (_stream == null)
ThrowStreamClosedException();
- }
- private static void ThrowStreamClosedException()
- {
- throw new ObjectDisposedException(nameof(DeflateStream), SR.ObjectDisposed_StreamClosed);
+ static void ThrowStreamClosedException() =>
+ throw new ObjectDisposedException(nameof(DeflateStream), SR.ObjectDisposed_StreamClosed);
}
private void EnsureDecompressionMode()
{
if (_mode != CompressionMode.Decompress)
ThrowCannotReadFromDeflateStreamException();
- }
- private static void ThrowCannotReadFromDeflateStreamException()
- {
- throw new InvalidOperationException(SR.CannotReadFromDeflateStream);
+ static void ThrowCannotReadFromDeflateStreamException() =>
+ throw new InvalidOperationException(SR.CannotReadFromDeflateStream);
}
private void EnsureCompressionMode()
{
if (_mode != CompressionMode.Compress)
ThrowCannotWriteToDeflateStreamException();
- }
- private static void ThrowCannotWriteToDeflateStreamException()
- {
- throw new InvalidOperationException(SR.CannotWriteToDeflateStream);
+ static void ThrowCannotWriteToDeflateStreamException() =>
+ throw new InvalidOperationException(SR.CannotWriteToDeflateStream);
}
+ private static void ThrowGenericInvalidData() =>
+ // The stream is either malicious or poorly implemented and returned a number of
+ // bytes < 0 || > than the buffer supplied to it.
+ throw new InvalidDataException(SR.GenericInvalidData);
+
public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback? asyncCallback, object? asyncState) =>
TaskToApm.Begin(ReadAsync(buffer, offset, count, CancellationToken.None), asyncCallback, asyncState);
@@ -378,6 +391,7 @@ internal ValueTask ReadAsyncMemory(Memory buffer, CancellationToken c
}
EnsureBufferInitialized();
+ Debug.Assert(_inflater != null);
return Core(buffer, cancellationToken);
@@ -386,48 +400,49 @@ async ValueTask Core(Memory buffer, CancellationToken cancellationTok
AsyncOperationStarting();
try
{
- int totalRead = 0;
-
- Debug.Assert(_inflater != null);
+ int bytesRead;
while (true)
{
- int bytesRead = _inflater.Inflate(buffer.Span.Slice(totalRead));
- totalRead += bytesRead;
- if (totalRead == buffer.Length)
- {
- break;
- }
-
- // If the stream is finished then we have a few potential cases here:
- // 1. DeflateStream => return
- // 2. GZipStream that is finished but may have an additional GZipStream appended => feed more input
- // 3. GZipStream that is finished and appended with garbage => return
- if (_inflater.Finished() && (!_inflater.IsGzipStream() || !_inflater.NeedsInput()))
+ // Try to decompress any data from the inflater into the caller's buffer.
+ // If we're able to decompress any bytes, or if decompression is completed, we're done.
+ bytesRead = _inflater.Inflate(buffer.Span);
+ if (bytesRead != 0 || InflatorIsFinished)
{
break;
}
+ // We were unable to decompress any data. If the inflater needs additional input
+ // data to proceed, read some to populate it.
if (_inflater.NeedsInput())
{
- Debug.Assert(_buffer != null);
- int bytes = await _stream.ReadAsync(_buffer, cancellationToken).ConfigureAwait(false);
- EnsureNotDisposed();
- if (bytes <= 0)
+ int n = await _stream.ReadAsync(new Memory(_buffer, 0, _buffer.Length), cancellationToken).ConfigureAwait(false);
+ if (n <= 0)
{
break;
}
- else if (bytes > _buffer.Length)
+ else if (n > _buffer.Length)
{
- // The stream is either malicious or poorly implemented and returned a number of
- // bytes larger than the buffer supplied to it.
- throw new InvalidDataException(SR.GenericInvalidData);
+ ThrowGenericInvalidData();
}
+ else
+ {
+ _inflater.SetInput(_buffer, 0, n);
+ }
+ }
- _inflater.SetInput(_buffer, 0, bytes);
+ if (buffer.IsEmpty)
+ {
+ // The caller provided a zero-byte buffer. This is typically done in order to avoid allocating/renting
+ // a buffer until data is known to be available. We don't have perfect knowledge here, as _inflater.Inflate
+ // will return 0 whether or not more data is required, and having input data doesn't necessarily mean it'll
+ // decompress into at least one byte of output, but it's a reasonable approximation for the 99% case. If it's
+ // wrong, it just means that a caller using zero-byte reads as a way to delay getting a buffer to use for a
+ // subsequent call may end up getting one earlier than otherwise preferred.
+ break;
}
}
- return totalRead;
+ return bytesRead;
}
finally
{
@@ -1014,21 +1029,16 @@ private void EnsureNoActiveAsyncOperation()
private void AsyncOperationStarting()
{
- if (Interlocked.CompareExchange(ref _activeAsyncOperation, 1, 0) != 0)
+ if (Interlocked.Exchange(ref _activeAsyncOperation, 1) != 0)
{
ThrowInvalidBeginCall();
}
}
- private void AsyncOperationCompleting()
- {
- int oldValue = Interlocked.CompareExchange(ref _activeAsyncOperation, 0, 1);
- Debug.Assert(oldValue == 1, $"Expected {nameof(_activeAsyncOperation)} to be 1, got {oldValue}");
- }
+ private void AsyncOperationCompleting() =>
+ Volatile.Write(ref _activeAsyncOperation, 0);
- private static void ThrowInvalidBeginCall()
- {
+ private static void ThrowInvalidBeginCall() =>
throw new InvalidOperationException(SR.InvalidBeginCall);
- }
}
}
diff --git a/src/libraries/System.Security.Cryptography.Encoding/tests/Base64TransformsTests.cs b/src/libraries/System.Security.Cryptography.Encoding/tests/Base64TransformsTests.cs
index d2e813670f29d9..aab283c32c543e 100644
--- a/src/libraries/System.Security.Cryptography.Encoding/tests/Base64TransformsTests.cs
+++ b/src/libraries/System.Security.Cryptography.Encoding/tests/Base64TransformsTests.cs
@@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
using System.Collections.Generic;
+using System.Diagnostics;
using System.IO;
using Xunit;
@@ -123,7 +124,7 @@ private static void ValidateCryptoStream(string expected, string data, ICryptoTr
using (var ms = new MemoryStream(inputBytes))
using (var cs = new CryptoStream(ms, transform, CryptoStreamMode.Read))
{
- int bytesRead = cs.Read(outputBytes, 0, outputBytes.Length);
+ int bytesRead = ReadAll(cs, outputBytes);
string outputString = Text.Encoding.ASCII.GetString(outputBytes, 0, bytesRead);
Assert.Equal(expected, outputString);
}
@@ -195,7 +196,7 @@ public static void ValidateFromBase64_NoPadding(string data)
using (var ms = new MemoryStream(inputBytes))
using (var cs = new CryptoStream(ms, transform, CryptoStreamMode.Read))
{
- int bytesRead = cs.Read(outputBytes, 0, outputBytes.Length);
+ int bytesRead = ReadAll(cs, outputBytes);
// Missing padding bytes not supported (no exception, however)
Assert.NotEqual(inputBytes.Length, bytesRead);
@@ -230,7 +231,7 @@ public static void ValidateWhitespace(string expected, string data)
using (var ms = new MemoryStream(inputBytes))
using (var cs = new CryptoStream(ms, base64Transform, CryptoStreamMode.Read))
{
- int bytesRead = cs.Read(outputBytes, 0, outputBytes.Length);
+ int bytesRead = ReadAll(cs, outputBytes);
string outputString = Text.Encoding.ASCII.GetString(outputBytes, 0, bytesRead);
Assert.Equal(expected, outputString);
}
@@ -240,7 +241,7 @@ public static void ValidateWhitespace(string expected, string data)
using (var ms = new MemoryStream(inputBytes))
using (var cs = new CryptoStream(ms, base64Transform, CryptoStreamMode.Read))
{
- int bytesRead = cs.Read(outputBytes, 0, outputBytes.Length);
+ int bytesRead = ReadAll(cs, outputBytes);
string outputString = Text.Encoding.ASCII.GetString(outputBytes, 0, bytesRead);
Assert.Equal(expected, outputString);
}
@@ -293,5 +294,22 @@ public void TransformUsageFlags_FromBase64Transform()
Assert.True(transform.CanReuseTransform);
}
}
+
+ private static int ReadAll(Stream stream, Span buffer)
+ {
+ int totalRead = 0;
+ while (totalRead < buffer.Length)
+ {
+ int bytesRead = stream.Read(buffer.Slice(totalRead));
+ if (bytesRead == 0)
+ {
+ break;
+ }
+
+ totalRead += bytesRead;
+ }
+
+ return totalRead;
+ }
}
}
diff --git a/src/libraries/System.Security.Cryptography.Primitives/src/System/Security/Cryptography/CryptoStream.cs b/src/libraries/System.Security.Cryptography.Primitives/src/System/Security/Cryptography/CryptoStream.cs
index c023c9068f51ae..afe137dda9b969 100644
--- a/src/libraries/System.Security.Cryptography.Primitives/src/System/Security/Cryptography/CryptoStream.cs
+++ b/src/libraries/System.Security.Cryptography.Primitives/src/System/Security/Cryptography/CryptoStream.cs
@@ -16,10 +16,10 @@ public class CryptoStream : Stream, IDisposable
// Member variables
private readonly Stream _stream;
private readonly ICryptoTransform _transform;
- private byte[]? _inputBuffer; // read from _stream before _Transform
+ private byte[] _inputBuffer; // read from _stream before _Transform
private int _inputBufferIndex;
private int _inputBlockSize;
- private byte[]? _outputBuffer; // buffered output of _Transform
+ private byte[] _outputBuffer; // buffered output of _Transform
private int _outputBufferIndex;
private int _outputBlockSize;
private bool _canRead;
@@ -37,24 +37,41 @@ public CryptoStream(Stream stream, ICryptoTransform transform, CryptoStreamMode
public CryptoStream(Stream stream, ICryptoTransform transform, CryptoStreamMode mode, bool leaveOpen)
{
+ if (transform is null)
+ {
+ throw new ArgumentNullException(nameof(transform));
+ }
_stream = stream;
_transform = transform;
_leaveOpen = leaveOpen;
+
switch (mode)
{
case CryptoStreamMode.Read:
- if (!(_stream.CanRead)) throw new ArgumentException(SR.Format(SR.Argument_StreamNotReadable, nameof(stream)));
+ if (!_stream.CanRead)
+ {
+ throw new ArgumentException(SR.Format(SR.Argument_StreamNotReadable, nameof(stream)));
+ }
_canRead = true;
break;
+
case CryptoStreamMode.Write:
- if (!(_stream.CanWrite)) throw new ArgumentException(SR.Format(SR.Argument_StreamNotWritable, nameof(stream)));
+ if (!_stream.CanWrite)
+ {
+ throw new ArgumentException(SR.Format(SR.Argument_StreamNotWritable, nameof(stream)));
+ }
_canWrite = true;
break;
+
default:
- throw new ArgumentException(SR.Argument_InvalidValue);
+ throw new ArgumentException(SR.Argument_InvalidValue, nameof(mode));
}
- InitializeBuffer();
+
+ _inputBlockSize = _transform.InputBlockSize;
+ _inputBuffer = new byte[_inputBlockSize];
+ _outputBlockSize = _transform.OutputBlockSize;
+ _outputBuffer = new byte[_outputBlockSize];
}
public override bool CanRead
@@ -293,198 +310,149 @@ private void CheckReadArguments(byte[] buffer, int offset, int count)
private async ValueTask ReadAsyncCore(Memory buffer, CancellationToken cancellationToken, bool useAsync)
{
- // read <= count bytes from the input stream, transforming as we go.
- // Basic idea: first we deliver any bytes we already have in the
- // _OutputBuffer, because we know they're good. Then, if asked to deliver
- // more bytes, we read & transform a block at a time until either there are
- // no bytes ready or we've delivered enough.
- int bytesToDeliver = buffer.Length;
- int currentOutputIndex = 0;
- Debug.Assert(_outputBuffer != null);
- if (_outputBufferIndex != 0)
+ while (true)
{
- // we have some already-transformed bytes in the output buffer
- if (_outputBufferIndex <= buffer.Length)
+ // If there are currently any bytes stored in the output buffer, hand back as many as we can.
+ if (_outputBufferIndex != 0)
{
- _outputBuffer.AsSpan(0, _outputBufferIndex).CopyTo(buffer.Span);
- bytesToDeliver -= _outputBufferIndex;
- currentOutputIndex += _outputBufferIndex;
- int toClear = _outputBuffer.Length - _outputBufferIndex;
- CryptographicOperations.ZeroMemory(new Span(_outputBuffer, _outputBufferIndex, toClear));
- _outputBufferIndex = 0;
+ int bytesToCopy = Math.Min(_outputBufferIndex, buffer.Length);
+ if (bytesToCopy != 0)
+ {
+ // Copy as many bytes as we can, then shift down the remaining bytes.
+ new ReadOnlySpan(_outputBuffer, 0, bytesToCopy).CopyTo(buffer.Span);
+ _outputBufferIndex -= bytesToCopy;
+ _outputBuffer.AsSpan(bytesToCopy).CopyTo(_outputBuffer);
+ CryptographicOperations.ZeroMemory(_outputBuffer.AsSpan(_outputBufferIndex, bytesToCopy));
+ }
+ return bytesToCopy;
}
- else
- {
- _outputBuffer.AsSpan(0, buffer.Length).CopyTo(buffer.Span);
- Buffer.BlockCopy(_outputBuffer, buffer.Length, _outputBuffer, 0, _outputBufferIndex - buffer.Length);
- _outputBufferIndex -= buffer.Length;
- int toClear = _outputBuffer.Length - _outputBufferIndex;
- CryptographicOperations.ZeroMemory(new Span(_outputBuffer, _outputBufferIndex, toClear));
-
- return buffer.Length;
+ // If we've already hit the end of the stream, there's nothing more to do.
+ Debug.Assert(_outputBufferIndex == 0);
+ if (_finalBlockTransformed)
+ {
+ Debug.Assert(_inputBufferIndex == 0);
+ return 0;
}
- }
- // _finalBlockTransformed == true implies we're at the end of the input stream
- // if we got through the previous if block then _OutputBufferIndex = 0, meaning
- // we have no more transformed bytes to give
- // so return count-bytesToDeliver, the amount we were able to hand back
- // eventually, we'll just always return 0 here because there's no more to read
- if (_finalBlockTransformed)
- {
- return buffer.Length - bytesToDeliver;
- }
- // ok, now loop until we've delivered enough or there's nothing available
- int amountRead = 0;
- int numOutputBytes;
- // OK, see first if it's a multi-block transform and we can speed up things
- int blocksToProcess = bytesToDeliver / _outputBlockSize;
+ int bytesRead = 0;
+ bool eof = false;
- Debug.Assert(_inputBuffer != null);
- if (blocksToProcess > 1 && _transform.CanTransformMultipleBlocks)
- {
- int numWholeBlocksInBytes = blocksToProcess * _inputBlockSize;
-
- // Use ArrayPool.Shared instead of CryptoPool because the array is passed out.
- byte[]? tempInputBuffer = ArrayPool.Shared.Rent(numWholeBlocksInBytes);
- byte[]? tempOutputBuffer = null;
-
- try
+ // If the transform supports transforming multiple blocks, try to read as large a chunk as would yield
+ // data to fill the output buffer and do the appropriate transform directly into the output buffer.
+ int blocksToProcess = buffer.Length / _outputBlockSize;
+ if (blocksToProcess > 1 && _transform.CanTransformMultipleBlocks)
{
- amountRead = useAsync ?
- await _stream.ReadAsync(new Memory(tempInputBuffer, _inputBufferIndex, numWholeBlocksInBytes - _inputBufferIndex), cancellationToken).ConfigureAwait(false) :
- _stream.Read(tempInputBuffer, _inputBufferIndex, numWholeBlocksInBytes - _inputBufferIndex);
-
- int totalInput = _inputBufferIndex + amountRead;
-
- // If there's still less than a block, copy the new data into the hold buffer and move to the slow read.
- if (totalInput < _inputBlockSize)
- {
- Buffer.BlockCopy(tempInputBuffer, _inputBufferIndex, _inputBuffer, _inputBufferIndex, amountRead);
- _inputBufferIndex = totalInput;
- }
- else
+ // Use ArrayPool.Shared instead of CryptoPool because the array is passed out.
+ int numWholeBlocksInBytes = blocksToProcess * _inputBlockSize;
+ byte[] tempInputBuffer = ArrayPool.Shared.Rent(numWholeBlocksInBytes);
+ try
{
- // Copy any held data into tempInputBuffer now that we know we're proceeding
- Buffer.BlockCopy(_inputBuffer, 0, tempInputBuffer, 0, _inputBufferIndex);
- CryptographicOperations.ZeroMemory(new Span(_inputBuffer, 0, _inputBufferIndex));
- amountRead += _inputBufferIndex;
- _inputBufferIndex = 0;
-
- // Make amountRead an integral multiple of _InputBlockSize
- int numWholeReadBlocks = amountRead / _inputBlockSize;
- int numWholeReadBlocksInBytes = numWholeReadBlocks * _inputBlockSize;
- int numIgnoredBytes = amountRead - numWholeReadBlocksInBytes;
-
- if (numIgnoredBytes != 0)
+ // Read into our temporary input buffer, leaving enough room at the beginning for any existing data
+ // we have in _inputBuffer.
+ bytesRead = useAsync ?
+ await _stream.ReadAsync(new Memory(tempInputBuffer, _inputBufferIndex, numWholeBlocksInBytes - _inputBufferIndex), cancellationToken).ConfigureAwait(false) :
+ _stream.Read(tempInputBuffer, _inputBufferIndex, numWholeBlocksInBytes - _inputBufferIndex);
+ eof = bytesRead == 0;
+
+ // If we got enough data to form at least one block, transform as much as we can.
+ int totalInput = _inputBufferIndex + bytesRead;
+ if (totalInput >= _inputBlockSize)
{
- _inputBufferIndex = numIgnoredBytes;
- Buffer.BlockCopy(tempInputBuffer, numWholeReadBlocksInBytes, _inputBuffer, 0, numIgnoredBytes);
- }
-
- // Use ArrayPool.Shared instead of CryptoPool because the array is passed out.
- tempOutputBuffer = ArrayPool.Shared.Rent(numWholeReadBlocks * _outputBlockSize);
- numOutputBytes = _transform.TransformBlock(tempInputBuffer, 0, numWholeReadBlocksInBytes, tempOutputBuffer, 0);
- tempOutputBuffer.AsSpan(0, numOutputBytes).CopyTo(buffer.Span.Slice(currentOutputIndex));
-
- // Clear what was written while we know how much that was
- CryptographicOperations.ZeroMemory(new Span(tempOutputBuffer, 0, numOutputBytes));
- ArrayPool.Shared.Return(tempOutputBuffer);
- tempOutputBuffer = null;
+ // Copy any held data into tempInputBuffer now that we know we're proceeding to handle
+ // decrypting all the received data.
+ Buffer.BlockCopy(_inputBuffer, 0, tempInputBuffer, 0, _inputBufferIndex);
+ CryptographicOperations.ZeroMemory(new Span(_inputBuffer, 0, _inputBufferIndex));
+ bytesRead += _inputBufferIndex;
+
+ // Determine how many entire blocks worth of data we read.
+ int numWholeReadBlocks = bytesRead / _inputBlockSize;
+ int numWholeReadBlocksInBytes = numWholeReadBlocks * _inputBlockSize;
+
+ // If there's anything left over, copy that back into _inputBuffer for a later read.
+ _inputBufferIndex = bytesRead - numWholeReadBlocksInBytes;
+ if (_inputBufferIndex != 0)
+ {
+ Buffer.BlockCopy(tempInputBuffer, numWholeReadBlocksInBytes, _inputBuffer, 0, _inputBufferIndex);
+ }
- bytesToDeliver -= numOutputBytes;
- currentOutputIndex += numOutputBytes;
- }
+ // Transform the read data into the caller's buffer.
+ int numOutputBytes;
+ if (MemoryMarshal.TryGetArray(buffer, out ArraySegment bufferArray))
+ {
+ // Because TransformBlock is based on arrays, we can only write directly into the output
+ // buffer if it's backed by an array; otherwise, we need to rent from the pool.
+ numOutputBytes = _transform.TransformBlock(tempInputBuffer, 0, numWholeReadBlocksInBytes, bufferArray.Array!, bufferArray.Offset);
+ }
+ else
+ {
+ // Otherwise, we need to rent a temporary from the pool.
+ byte[] tempOutputBuffer = ArrayPool.Shared.Rent(numWholeReadBlocks * _outputBlockSize);
+ numOutputBytes = numWholeReadBlocks * _outputBlockSize;
+ try
+ {
+ numOutputBytes = _transform.TransformBlock(tempInputBuffer, 0, numWholeReadBlocksInBytes, tempOutputBuffer, 0);
+ tempOutputBuffer.AsSpan(0, numOutputBytes).CopyTo(buffer.Span);
+ }
+ finally
+ {
+ CryptographicOperations.ZeroMemory(new Span(tempOutputBuffer, 0, numOutputBytes));
+ ArrayPool.Shared.Return(tempOutputBuffer);
+ }
+ }
- CryptographicOperations.ZeroMemory(new Span(tempInputBuffer, 0, numWholeBlocksInBytes));
- ArrayPool.Shared.Return(tempInputBuffer);
- tempInputBuffer = null;
- }
- catch
- {
- // If we rented and then an exception happened we don't know how much was written to,
- // clear the whole thing and let it get reclaimed by the GC.
- if (tempOutputBuffer != null)
- {
- CryptographicOperations.ZeroMemory(tempOutputBuffer);
- tempOutputBuffer = null;
+ // Return anything we've got at this point.
+ if (numOutputBytes != 0)
+ {
+ return numOutputBytes;
+ }
+ }
+ else
+ {
+ // We have less than a block's worth of data. Copy the new data back into the _inputBuffer
+ // and fall back to using the single block code path.
+ Buffer.BlockCopy(tempInputBuffer, _inputBufferIndex, _inputBuffer, _inputBufferIndex, bytesRead);
+ _inputBufferIndex = totalInput;
+ }
}
-
- // For the input buffer we know how much was written, so clear that.
- // But still let it get reclaimed by the GC.
- if (tempInputBuffer != null)
+ finally
{
CryptographicOperations.ZeroMemory(new Span(tempInputBuffer, 0, numWholeBlocksInBytes));
- tempInputBuffer = null;
+ ArrayPool.Shared.Return(tempInputBuffer);
}
-
- throw;
}
- }
- // try to fill _InputBuffer so we have something to transform
- while (bytesToDeliver > 0)
- {
- while (_inputBufferIndex < _inputBlockSize)
+ // Read enough to fill one input block, as anything less won't be able to be transformed to produce output.
+ if (!eof)
{
- amountRead = useAsync ?
- await _stream.ReadAsync(new Memory(_inputBuffer, _inputBufferIndex, _inputBlockSize - _inputBufferIndex), cancellationToken).ConfigureAwait(false) :
- _stream.Read(_inputBuffer, _inputBufferIndex, _inputBlockSize - _inputBufferIndex);
+ while (_inputBufferIndex < _inputBlockSize)
+ {
+ bytesRead = useAsync ?
+ await _stream.ReadAsync(new Memory(_inputBuffer, _inputBufferIndex, _inputBlockSize - _inputBufferIndex), cancellationToken).ConfigureAwait(false) :
+ _stream.Read(_inputBuffer, _inputBufferIndex, _inputBlockSize - _inputBufferIndex);
+ if (bytesRead <= 0)
+ {
+ break;
+ }
- // first, check to see if we're at the end of the input stream
- if (amountRead == 0) goto ProcessFinalBlock;
- _inputBufferIndex += amountRead;
+ _inputBufferIndex += bytesRead;
+ }
}
- numOutputBytes = _transform.TransformBlock(_inputBuffer, 0, _inputBlockSize, _outputBuffer, 0);
- _inputBufferIndex = 0;
-
- if (bytesToDeliver >= numOutputBytes)
+ // Transform the received data.
+ if (bytesRead <= 0)
{
- _outputBuffer.AsSpan(0, numOutputBytes).CopyTo(buffer.Span.Slice(currentOutputIndex));
- CryptographicOperations.ZeroMemory(new Span(_outputBuffer, 0, numOutputBytes));
- currentOutputIndex += numOutputBytes;
- bytesToDeliver -= numOutputBytes;
+ _outputBuffer = _transform.TransformFinalBlock(_inputBuffer, 0, _inputBufferIndex);
+ _outputBufferIndex = _outputBuffer.Length;
+ _finalBlockTransformed = true;
}
else
{
- _outputBuffer.AsSpan(0, bytesToDeliver).CopyTo(buffer.Span.Slice(currentOutputIndex));
- _outputBufferIndex = numOutputBytes - bytesToDeliver;
- Buffer.BlockCopy(_outputBuffer, bytesToDeliver, _outputBuffer, 0, _outputBufferIndex);
- int toClear = _outputBuffer.Length - _outputBufferIndex;
- CryptographicOperations.ZeroMemory(new Span(_outputBuffer, _outputBufferIndex, toClear));
- return buffer.Length;
+ _outputBufferIndex = _transform.TransformBlock(_inputBuffer, 0, _inputBufferIndex, _outputBuffer, 0);
}
- }
- return buffer.Length;
-
- ProcessFinalBlock:
- // if so, then call TransformFinalBlock to get whatever is left
- byte[] finalBytes = _transform.TransformFinalBlock(_inputBuffer, 0, _inputBufferIndex);
- // now, since _OutputBufferIndex must be 0 if we're in the while loop at this point,
- // reset it to be what we just got back
- _outputBuffer = finalBytes;
- _outputBufferIndex = finalBytes.Length;
- // set the fact that we've transformed the final block
- _finalBlockTransformed = true;
- // now, return either everything we just got or just what's asked for, whichever is smaller
- if (bytesToDeliver < _outputBufferIndex)
- {
- _outputBuffer.AsSpan(0, bytesToDeliver).CopyTo(buffer.Span.Slice(currentOutputIndex));
- _outputBufferIndex -= bytesToDeliver;
- Buffer.BlockCopy(_outputBuffer, bytesToDeliver, _outputBuffer, 0, _outputBufferIndex);
- int toClear = _outputBuffer.Length - _outputBufferIndex;
- CryptographicOperations.ZeroMemory(new Span(_outputBuffer, _outputBufferIndex, toClear));
- return buffer.Length;
- }
- else
- {
- _outputBuffer.AsSpan(0, _outputBufferIndex).CopyTo(buffer.Span.Slice(currentOutputIndex));
- bytesToDeliver -= _outputBufferIndex;
- _outputBufferIndex = 0;
- CryptographicOperations.ZeroMemory(_outputBuffer);
- return buffer.Length - bytesToDeliver;
+
+ // All input data has been processed.
+ _inputBufferIndex = 0;
}
}
@@ -807,8 +775,8 @@ protected override void Dispose(bool disposing)
if (_outputBuffer != null)
Array.Clear(_outputBuffer);
- _inputBuffer = null;
- _outputBuffer = null;
+ _inputBuffer = null!;
+ _outputBuffer = null!;
_canRead = false;
_canWrite = false;
}
@@ -858,30 +826,13 @@ private async ValueTask DisposeAsyncCore()
Array.Clear(_outputBuffer);
}
- _inputBuffer = null;
- _outputBuffer = null;
+ _inputBuffer = null!;
+ _outputBuffer = null!;
_canRead = false;
_canWrite = false;
}
}
- // Private methods
-
- private void InitializeBuffer()
- {
- if (_transform != null)
- {
- _inputBlockSize = _transform.InputBlockSize;
- _inputBuffer = new byte[_inputBlockSize];
- _outputBlockSize = _transform.OutputBlockSize;
- _outputBuffer = new byte[_outputBlockSize];
- }
- else
- {
- throw new ArgumentNullException(nameof(_transform));
- }
- }
-
[MemberNotNull(nameof(_lazyAsyncActiveSemaphore))]
private SemaphoreSlim AsyncActiveSemaphore
{
diff --git a/src/libraries/System.Security.Cryptography.Primitives/tests/CryptoStream.cs b/src/libraries/System.Security.Cryptography.Primitives/tests/CryptoStream.cs
index ea43dc8ca10a2b..86fa12306f8ae8 100644
--- a/src/libraries/System.Security.Cryptography.Primitives/tests/CryptoStream.cs
+++ b/src/libraries/System.Security.Cryptography.Primitives/tests/CryptoStream.cs
@@ -27,6 +27,7 @@ protected override Task CreateWrappedConnectedStreamsAsync(StreamPai
}
protected override Type UnsupportedConcurrentExceptionType => null;
+ protected override bool BlocksOnZeroByteReads => true;
[ActiveIssue("https://github.com/dotnet/runtime/issues/45080")]
[Theory]
@@ -37,7 +38,7 @@ protected override Task CreateWrappedConnectedStreamsAsync(StreamPai
public static void Ctor()
{
var transform = new IdentityTransform(1, 1, true);
- AssertExtensions.Throws(null, () => new CryptoStream(new MemoryStream(), transform, (CryptoStreamMode)12345));
+ AssertExtensions.Throws("mode", () => new CryptoStream(new MemoryStream(), transform, (CryptoStreamMode)12345));
AssertExtensions.Throws(null, "stream", () => new CryptoStream(new MemoryStream(new byte[0], writable: false), transform, CryptoStreamMode.Write));
AssertExtensions.Throws(null, "stream", () => new CryptoStream(new CryptoStream(new MemoryStream(new byte[0]), transform, CryptoStreamMode.Write), transform, CryptoStreamMode.Read));
}