Skip to content

Commit

Permalink
Introduce overload for UnsafeDecode that writes to an output buffer
Browse files Browse the repository at this point in the history
  • Loading branch information
Keegan Caruso committed Feb 28, 2024
1 parent ecfdd53 commit e8e6f23
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 27 deletions.
15 changes: 11 additions & 4 deletions src/Microsoft.IdentityModel.JsonWebTokens/JsonWebToken.cs
Original file line number Diff line number Diff line change
Expand Up @@ -571,10 +571,17 @@ internal JsonClaimSet CreateClaimSet(ReadOnlySpan<char> strSpan, int startIndex,
{
int outputSize = Base64UrlEncoding.ValidateAndGetOutputSize(strSpan, startIndex, length);

var slice = strSpan.Slice(startIndex, length);
var output = Base64UrlEncoder.UnsafeDecode(slice);

return createHeaderClaimSet ? CreateHeaderClaimSet(output.AsSpan()) : CreatePayloadClaimSet(output.AsSpan());
byte[] output = ArrayPool<byte>.Shared.Rent(outputSize);
try
{
ReadOnlySpan<char> slice = strSpan.Slice(startIndex, length);
Base64UrlEncoder.UnsafeDecode(slice, output);
return createHeaderClaimSet ? CreateHeaderClaimSet(output.AsSpan()) : CreatePayloadClaimSet(output.AsSpan());
}
finally
{
ArrayPool<byte>.Shared.Return(output, true);
}
}

/// <summary>
Expand Down
139 changes: 116 additions & 23 deletions src/Microsoft.IdentityModel.Tokens/Base64UrlEncoder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,35 @@ internal static unsafe byte[] UnsafeDecode(ReadOnlySpan<char> strSpan)
int decodedLength = strSpan.Length + (4 - mod) % 4;

#if NET6_0_OR_GREATER
return UnsafeNetCoreDecode(strSpan, needReplace, decodedLength);
#else
return UnsafeNetFWDecode(strSpan, needReplace, decodedLength);
#endif
}

#if NET6_0_OR_GREATER
[SkipLocalsInit]
#endif
internal static unsafe void UnsafeDecode(ReadOnlySpan<char> strSpan, Span<byte> output)
{
int mod = strSpan.Length % 4;
if (mod == 1)
throw LogHelper.LogExceptionMessage(new FormatException(LogHelper.FormatInvariant(LogMessages.IDX10400, strSpan.ToString())));

bool needReplace = strSpan.IndexOfAny(base64UrlCharacter62, base64UrlCharacter63) >= 0;
int decodedLength = strSpan.Length + (4 - mod) % 4;

#if NET6_0_OR_GREATER
UnsafeNetCoreDecode(strSpan, output, needReplace, decodedLength);
#else
UnsafeNetFWDecode(strSpan, output, needReplace, decodedLength);
#endif
}

#if NET6_0_OR_GREATER

private static unsafe byte[] UnsafeNetCoreDecode(ReadOnlySpan<char> strSpan, bool needReplace, int decodedLength)
{
// If the incoming chars don't contain any of the base64url characters that need to be replaced,
// and if the incoming chars are of the exact right length, then we'll be able to just pass the
// incoming chars directly to DecodeFromUtf8InPlace. Otherwise, rent an array, copy all the
Expand All @@ -207,28 +236,7 @@ internal static unsafe byte[] UnsafeDecode(ReadOnlySpan<char> strSpan)
arrayPoolChars = ArrayPool<char>.Shared.Rent(decodedLength);
charsSpan = charsSpan.Slice(0, decodedLength);

source.CopyTo(charsSpan);
if (source.Length < charsSpan.Length)
{
charsSpan[source.Length] = base64PadCharacter;
if (source.Length + 1 < charsSpan.Length)
{
charsSpan[source.Length + 1] = base64PadCharacter;
}
}

if (needReplace)
{
Span<char> remaining = charsSpan;
int pos;
while ((pos = remaining.IndexOfAny(base64UrlCharacter62, base64UrlCharacter63)) >= 0)
{
remaining[pos] = (remaining[pos] == base64UrlCharacter62) ? base64Character62 : base64Character63;
remaining = remaining.Slice(pos + 1);
}
}

source = charsSpan;
source = HandlePaddingAndReplace(needReplace, charsSpan, source);
}

byte[] arrayPoolBytes = null;
Expand Down Expand Up @@ -256,7 +264,86 @@ internal static unsafe byte[] UnsafeDecode(ReadOnlySpan<char> strSpan)
}

return result;
}

private static unsafe void UnsafeNetCoreDecode(ReadOnlySpan<char> strSpan, Span<byte> output, bool needReplace, int decodedLength)
{
// If the incoming chars don't contain any of the base64url characters that need to be replaced,
// and if the incoming chars are of the exact right length, then we'll be able to just pass the
// incoming chars directly to DecodeFromUtf8InPlace. Otherwise, rent an array, copy all the
// data into it, and do whatever fixups are necessary on that copy, then pass that copy into
// DecodeFromUtf8InPlace.

const int StackAllocThreshold = 512;
char[] arrayPoolChars = null;
scoped Span<char> charsSpan = default;
scoped ReadOnlySpan<char> source = strSpan;

if (needReplace || decodedLength != source.Length)
{
charsSpan = decodedLength <= StackAllocThreshold ?
stackalloc char[StackAllocThreshold] :
arrayPoolChars = ArrayPool<char>.Shared.Rent(decodedLength);
charsSpan = charsSpan.Slice(0, decodedLength);

source = HandlePaddingAndReplace(needReplace, charsSpan, source);
}

byte[] arrayPoolBytes = null;
Span<byte> bytesSpan = decodedLength <= StackAllocThreshold ?
stackalloc byte[StackAllocThreshold] :
arrayPoolBytes = ArrayPool<byte>.Shared.Rent(decodedLength);

int length = Encoding.UTF8.GetBytes(source, bytesSpan);
Span<byte> utf8Span = bytesSpan.Slice(0, length);
OperationStatus status = System.Buffers.Text.Base64.DecodeFromUtf8InPlace(utf8Span, out int bytesWritten);
Debug.Assert(status == OperationStatus.Done, "Expected DecodeFromUtf8 to be successful");

utf8Span.Slice(0, bytesWritten).CopyTo(output);

if (arrayPoolBytes is not null)
{
bytesSpan.Clear();
ArrayPool<byte>.Shared.Return(arrayPoolBytes);
}

if (arrayPoolChars is not null)
{
charsSpan.Clear();
ArrayPool<char>.Shared.Return(arrayPoolChars);
}
}

private static unsafe ReadOnlySpan<char> HandlePaddingAndReplace(bool needReplace, Span<char> charsSpan, ReadOnlySpan<char> source)
{
source.CopyTo(charsSpan);
if (source.Length < charsSpan.Length)
{
charsSpan[source.Length] = base64PadCharacter;
if (source.Length + 1 < charsSpan.Length)
{
charsSpan[source.Length + 1] = base64PadCharacter;
}
}

if (needReplace)
{
Span<char> remaining = charsSpan;
int pos;
while ((pos = remaining.IndexOfAny(base64UrlCharacter62, base64UrlCharacter63)) >= 0)
{
remaining[pos] = (remaining[pos] == base64UrlCharacter62) ? base64Character62 : base64Character63;
remaining = remaining.Slice(pos + 1);
}
}

return charsSpan;
}

#else

private static unsafe byte[] UnsafeNetFWDecode(ReadOnlySpan<char> strSpan, bool needReplace, int decodedLength)
{
if (needReplace)
{
string decodedString = new(char.MinValue, decodedLength);
Expand Down Expand Up @@ -301,9 +388,15 @@ internal static unsafe byte[] UnsafeDecode(ReadOnlySpan<char> strSpan)
return Convert.FromBase64String(decodedString);
}
}
#endif
}

private static unsafe void UnsafeNetFWDecode(ReadOnlySpan<char> strSpan, Span<byte> output, bool needReplace, int decodedLength)
{
byte[] result = UnsafeNetFWDecode(strSpan, needReplace, decodedLength);
result.CopyTo(output);
}
#endif

/// <summary>
/// Decodes the string from Base64UrlEncoded to UTF8.
/// </summary>
Expand Down

0 comments on commit e8e6f23

Please sign in to comment.