Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Base64.DecodeFromUtf8InPlace #2504

Merged
merged 2 commits into from
Apr 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
</ItemGroup>-->

<ItemGroup>
<PackageReference Include="BenchmarkDotNet" Version="0.13.5" />
<PackageReference Include="BenchmarkDotNet" Version="0.13.12" />
keegan-caruso marked this conversation as resolved.
Show resolved Hide resolved
</ItemGroup>

<PropertyGroup Condition=" '$(Configuration)' == 'Release' ">
Expand Down
15 changes: 8 additions & 7 deletions src/Microsoft.IdentityModel.JsonWebTokens/JsonWebToken.cs
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,7 @@ internal void ReadToken(ReadOnlyMemory<char> encodedTokenMemory)

try
{
Header = CreateHeaderClaimSet(Base64UrlEncoder.UnsafeDecode(headerSpan).AsSpan());
Header = CreateHeaderClaimSet(Base64UrlEncoder.Decode(headerSpan).AsSpan());
}
catch (Exception ex)
{
Expand All @@ -518,7 +518,7 @@ internal void ReadToken(ReadOnlyMemory<char> encodedTokenMemory)
ReadOnlySpan<char> encryptedKeyBytes = encodedTokenSpan.Slice(Dot1 + 1, Dot2 - Dot1 - 1);
if (!encryptedKeyBytes.IsEmpty)
{
EncryptedKeyBytes = Base64UrlEncoder.UnsafeDecode(encryptedKeyBytes);
EncryptedKeyBytes = Base64UrlEncoder.Decode(encryptedKeyBytes);
_encryptedKey = encryptedKeyBytes.ToString();
}
else
Expand All @@ -532,7 +532,7 @@ internal void ReadToken(ReadOnlyMemory<char> encodedTokenMemory)

try
{
InitializationVectorBytes = Base64UrlEncoder.UnsafeDecode(initializationVectorSpan);
InitializationVectorBytes = Base64UrlEncoder.Decode(initializationVectorSpan);
}
catch (Exception ex)
{
Expand All @@ -545,7 +545,7 @@ internal void ReadToken(ReadOnlyMemory<char> encodedTokenMemory)

try
{
AuthenticationTagBytes = Base64UrlEncoder.UnsafeDecode(authTagSpan);
AuthenticationTagBytes = Base64UrlEncoder.Decode(authTagSpan);
}
catch (Exception ex)
{
Expand All @@ -558,7 +558,7 @@ internal void ReadToken(ReadOnlyMemory<char> encodedTokenMemory)

try
{
CipherTextBytes = Base64UrlEncoder.UnsafeDecode(cipherTextSpan);
CipherTextBytes = Base64UrlEncoder.Decode(cipherTextSpan);
}
catch (Exception ex)
{
Expand All @@ -570,15 +570,16 @@ internal void ReadToken(ReadOnlyMemory<char> encodedTokenMemory)
internal JsonClaimSet CreateClaimSet(ReadOnlySpan<char> strSpan, int startIndex, int length, bool createHeaderClaimSet)
{
int outputSize = Base64UrlEncoding.ValidateAndGetOutputSize(strSpan, startIndex, length);

byte[] output = ArrayPool<byte>.Shared.Rent(outputSize);
try
{
Base64UrlEncoding.Decode(strSpan, startIndex, length, output);
Base64UrlEncoder.Decode(strSpan.Slice(startIndex, length), output);
return createHeaderClaimSet ? CreateHeaderClaimSet(output.AsSpan()) : CreatePayloadClaimSet(output.AsSpan());
}
finally
{
ArrayPool<byte>.Shared.Return(output);
ArrayPool<byte>.Shared.Return(output, true);
}
}

Expand Down
131 changes: 93 additions & 38 deletions src/Microsoft.IdentityModel.Tokens/Base64UrlEncoder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -173,13 +173,31 @@ public static int Encode(ReadOnlySpan<byte> inArray, Span<char> output)
public static byte[] DecodeBytes(string str)
{
_ = str ?? throw LogHelper.LogExceptionMessage(new ArgumentNullException(nameof(str)));
return UnsafeDecode(str.AsSpan());
return Decode(str.AsSpan());
}

internal static byte[] Decode(ReadOnlySpan<char> strSpan)
{
int mod = strSpan.Length % 4;
if (mod == 1)
throw LogHelper.LogExceptionMessage(new FormatException(LogHelper.FormatInvariant(LogMessages.IDX10400, strSpan.ToString())));

bool needReplace = strSpan.IndexOfAny(base64UrlCharacter62, base64UrlCharacter63) >= 0;
int decodedLength = strSpan.Length + (4 - mod) % 4;

#if NET6_0_OR_GREATER
[SkipLocalsInit]

Span<byte> output = new byte[decodedLength];

int length = Decode(strSpan, output, needReplace, decodedLength);

return output.Slice(0, length).ToArray();
#else
return UnsafeDecode(strSpan, needReplace, decodedLength);
#endif
internal static unsafe byte[] UnsafeDecode(ReadOnlySpan<char> strSpan)
}

internal static void Decode(ReadOnlySpan<char> strSpan, Span<byte> output)
{
int mod = strSpan.Length % 4;
if (mod == 1)
Expand All @@ -189,11 +207,22 @@ internal static unsafe byte[] UnsafeDecode(ReadOnlySpan<char> strSpan)
int decodedLength = strSpan.Length + (4 - mod) % 4;

#if NET6_0_OR_GREATER
Decode(strSpan, output, needReplace, decodedLength);
#else
Decode(strSpan, output, needReplace, decodedLength);
#endif
}

#if NET6_0_OR_GREATER

[SkipLocalsInit]
private static int Decode(ReadOnlySpan<char> strSpan, Span<byte> output, bool needReplace, int decodedLength)
{
// If the incoming chars don't contain any of the base64url characters that need to be replaced,
// and if the incoming chars are of the exact right length, then we'll be able to just pass the
// incoming chars directly to Convert.TryFromBase64Chars. Otherwise, rent an array, copy all the
// incoming chars directly to DecodeFromUtf8InPlace. Otherwise, rent an array, copy all the
// data into it, and do whatever fixups are necessary on that copy, then pass that copy into
// Convert.TryFromBase64Chars.
// DecodeFromUtf8InPlace.

const int StackAllocThreshold = 512;
char[] arrayPoolChars = null;
Expand All @@ -207,53 +236,73 @@ internal static unsafe byte[] UnsafeDecode(ReadOnlySpan<char> strSpan)
arrayPoolChars = ArrayPool<char>.Shared.Rent(decodedLength);
charsSpan = charsSpan.Slice(0, decodedLength);

source.CopyTo(charsSpan);
if (source.Length < charsSpan.Length)
{
charsSpan[source.Length] = base64PadCharacter;
if (source.Length + 1 < charsSpan.Length)
{
charsSpan[source.Length + 1] = base64PadCharacter;
}
}

if (needReplace)
{
Span<char> remaining = charsSpan;
int pos;
while ((pos = remaining.IndexOfAny(base64UrlCharacter62, base64UrlCharacter63)) >= 0)
{
remaining[pos] = (remaining[pos] == base64UrlCharacter62) ? base64Character62 : base64Character63;
remaining = remaining.Slice(pos + 1);
}
}

source = charsSpan;
source = HandlePaddingAndReplace(source, charsSpan, needReplace);
}

byte[] arrayPoolBytes = null;
Span<byte> bytesSpan = decodedLength <= StackAllocThreshold ?
stackalloc byte[StackAllocThreshold] :
arrayPoolBytes = ArrayPool<byte>.Shared.Rent(decodedLength);

bool converted = Convert.TryFromBase64Chars(source, bytesSpan, out int bytesWritten);
Debug.Assert(converted, "Expected TryFromBase64Chars to be successful");
byte[] result = bytesSpan.Slice(0, bytesWritten).ToArray();
int length = Encoding.UTF8.GetBytes(source, bytesSpan);
Span<byte> utf8Span = bytesSpan.Slice(0, length);

try
{
OperationStatus status = System.Buffers.Text.Base64.DecodeFromUtf8InPlace(utf8Span, out int bytesWritten);
if (status != OperationStatus.Done)
throw LogHelper.LogExceptionMessage(new FormatException(LogHelper.FormatInvariant(LogMessages.IDX10400, strSpan.ToString())));

utf8Span.Slice(0, bytesWritten).CopyTo(output);

return bytesWritten;
}
finally
{
if (arrayPoolBytes is not null)
{
bytesSpan.Clear();
ArrayPool<byte>.Shared.Return(arrayPoolBytes);
}

if (arrayPoolChars is not null)
{
charsSpan.Clear();
ArrayPool<char>.Shared.Return(arrayPoolChars);
}
}
}

if (arrayPoolBytes is not null)
private static ReadOnlySpan<char> HandlePaddingAndReplace(ReadOnlySpan<char> source, Span<char> charsSpan, bool needReplace)
{
source.CopyTo(charsSpan);
keegan-caruso marked this conversation as resolved.
Show resolved Hide resolved
if (source.Length < charsSpan.Length)
{
bytesSpan.Clear();
ArrayPool<byte>.Shared.Return(arrayPoolBytes);
charsSpan[source.Length] = base64PadCharacter;
if (source.Length + 1 < charsSpan.Length)
{
charsSpan[source.Length + 1] = base64PadCharacter;
}
}

if (arrayPoolChars is not null)
if (needReplace)
{
charsSpan.Clear();
ArrayPool<char>.Shared.Return(arrayPoolChars);
Span<char> remaining = charsSpan;
int pos;
while ((pos = remaining.IndexOfAny(base64UrlCharacter62, base64UrlCharacter63)) >= 0)
{
remaining[pos] = (remaining[pos] == base64UrlCharacter62) ? base64Character62 : base64Character63;
remaining = remaining.Slice(pos + 1);
}
keegan-caruso marked this conversation as resolved.
Show resolved Hide resolved
}

return result;
return charsSpan;
}

#else

private static unsafe byte[] UnsafeDecode(ReadOnlySpan<char> strSpan, bool needReplace, int decodedLength)
{
if (needReplace)
{
string decodedString = new(char.MinValue, decodedLength);
Expand Down Expand Up @@ -298,9 +347,15 @@ internal static unsafe byte[] UnsafeDecode(ReadOnlySpan<char> strSpan)
return Convert.FromBase64String(decodedString);
}
}
#endif
}

private static void Decode(ReadOnlySpan<char> strSpan, Span<byte> output, bool needReplace, int decodedLength)
{
byte[] result = UnsafeDecode(strSpan, needReplace, decodedLength);
result.CopyTo(output);
}
#endif

/// <summary>
/// Decodes the string from Base64UrlEncoded to UTF8.
/// </summary>
Expand Down
Loading