diff --git a/benchmark/Microsoft.IdentityModel.Benchmarks/Microsoft.IdentityModel.Benchmarks.csproj b/benchmark/Microsoft.IdentityModel.Benchmarks/Microsoft.IdentityModel.Benchmarks.csproj
index 5279cc3f14..6d2e5581d3 100644
--- a/benchmark/Microsoft.IdentityModel.Benchmarks/Microsoft.IdentityModel.Benchmarks.csproj
+++ b/benchmark/Microsoft.IdentityModel.Benchmarks/Microsoft.IdentityModel.Benchmarks.csproj
@@ -23,7 +23,7 @@
-->
-
+
diff --git a/src/Microsoft.IdentityModel.JsonWebTokens/JsonWebToken.cs b/src/Microsoft.IdentityModel.JsonWebTokens/JsonWebToken.cs
index fd6921f1f4..255bdc21d9 100644
--- a/src/Microsoft.IdentityModel.JsonWebTokens/JsonWebToken.cs
+++ b/src/Microsoft.IdentityModel.JsonWebTokens/JsonWebToken.cs
@@ -504,7 +504,7 @@ internal void ReadToken(ReadOnlyMemory encodedTokenMemory)
try
{
- Header = CreateHeaderClaimSet(Base64UrlEncoder.UnsafeDecode(headerSpan).AsSpan());
+ Header = CreateHeaderClaimSet(Base64UrlEncoder.Decode(headerSpan).AsSpan());
}
catch (Exception ex)
{
@@ -518,7 +518,7 @@ internal void ReadToken(ReadOnlyMemory encodedTokenMemory)
ReadOnlySpan encryptedKeyBytes = encodedTokenSpan.Slice(Dot1 + 1, Dot2 - Dot1 - 1);
if (!encryptedKeyBytes.IsEmpty)
{
- EncryptedKeyBytes = Base64UrlEncoder.UnsafeDecode(encryptedKeyBytes);
+ EncryptedKeyBytes = Base64UrlEncoder.Decode(encryptedKeyBytes);
_encryptedKey = encryptedKeyBytes.ToString();
}
else
@@ -532,7 +532,7 @@ internal void ReadToken(ReadOnlyMemory encodedTokenMemory)
try
{
- InitializationVectorBytes = Base64UrlEncoder.UnsafeDecode(initializationVectorSpan);
+ InitializationVectorBytes = Base64UrlEncoder.Decode(initializationVectorSpan);
}
catch (Exception ex)
{
@@ -545,7 +545,7 @@ internal void ReadToken(ReadOnlyMemory encodedTokenMemory)
try
{
- AuthenticationTagBytes = Base64UrlEncoder.UnsafeDecode(authTagSpan);
+ AuthenticationTagBytes = Base64UrlEncoder.Decode(authTagSpan);
}
catch (Exception ex)
{
@@ -558,7 +558,7 @@ internal void ReadToken(ReadOnlyMemory encodedTokenMemory)
try
{
- CipherTextBytes = Base64UrlEncoder.UnsafeDecode(cipherTextSpan);
+ CipherTextBytes = Base64UrlEncoder.Decode(cipherTextSpan);
}
catch (Exception ex)
{
@@ -570,15 +570,16 @@ internal void ReadToken(ReadOnlyMemory encodedTokenMemory)
internal JsonClaimSet CreateClaimSet(ReadOnlySpan strSpan, int startIndex, int length, bool createHeaderClaimSet)
{
int outputSize = Base64UrlEncoding.ValidateAndGetOutputSize(strSpan, startIndex, length);
+
byte[] output = ArrayPool.Shared.Rent(outputSize);
try
{
- Base64UrlEncoding.Decode(strSpan, startIndex, length, output);
+ Base64UrlEncoder.Decode(strSpan.Slice(startIndex, length), output);
return createHeaderClaimSet ? CreateHeaderClaimSet(output.AsSpan()) : CreatePayloadClaimSet(output.AsSpan());
}
finally
{
- ArrayPool.Shared.Return(output);
+ ArrayPool.Shared.Return(output, true);
}
}
diff --git a/src/Microsoft.IdentityModel.Tokens/Base64UrlEncoder.cs b/src/Microsoft.IdentityModel.Tokens/Base64UrlEncoder.cs
index 8a11a07c3c..cb1ed39940 100644
--- a/src/Microsoft.IdentityModel.Tokens/Base64UrlEncoder.cs
+++ b/src/Microsoft.IdentityModel.Tokens/Base64UrlEncoder.cs
@@ -173,13 +173,31 @@ public static int Encode(ReadOnlySpan inArray, Span output)
public static byte[] DecodeBytes(string str)
{
_ = str ?? throw LogHelper.LogExceptionMessage(new ArgumentNullException(nameof(str)));
- return UnsafeDecode(str.AsSpan());
+ return Decode(str.AsSpan());
}
+ internal static byte[] Decode(ReadOnlySpan strSpan)
+ {
+ int mod = strSpan.Length % 4;
+ if (mod == 1)
+ throw LogHelper.LogExceptionMessage(new FormatException(LogHelper.FormatInvariant(LogMessages.IDX10400, strSpan.ToString())));
+
+ bool needReplace = strSpan.IndexOfAny(base64UrlCharacter62, base64UrlCharacter63) >= 0;
+ int decodedLength = strSpan.Length + (4 - mod) % 4;
+
#if NET6_0_OR_GREATER
- [SkipLocalsInit]
+
+ Span output = new byte[decodedLength];
+
+ int length = Decode(strSpan, output, needReplace, decodedLength);
+
+ return output.Slice(0, length).ToArray();
+#else
+ return UnsafeDecode(strSpan, needReplace, decodedLength);
#endif
- internal static unsafe byte[] UnsafeDecode(ReadOnlySpan strSpan)
+ }
+
+ internal static void Decode(ReadOnlySpan strSpan, Span output)
{
int mod = strSpan.Length % 4;
if (mod == 1)
@@ -189,11 +207,22 @@ internal static unsafe byte[] UnsafeDecode(ReadOnlySpan strSpan)
int decodedLength = strSpan.Length + (4 - mod) % 4;
#if NET6_0_OR_GREATER
+ Decode(strSpan, output, needReplace, decodedLength);
+#else
+ Decode(strSpan, output, needReplace, decodedLength);
+#endif
+ }
+
+#if NET6_0_OR_GREATER
+
+ [SkipLocalsInit]
+ private static int Decode(ReadOnlySpan strSpan, Span output, bool needReplace, int decodedLength)
+ {
// If the incoming chars don't contain any of the base64url characters that need to be replaced,
// and if the incoming chars are of the exact right length, then we'll be able to just pass the
- // incoming chars directly to Convert.TryFromBase64Chars. Otherwise, rent an array, copy all the
+ // incoming chars directly to DecodeFromUtf8InPlace. Otherwise, rent an array, copy all the
// data into it, and do whatever fixups are necessary on that copy, then pass that copy into
- // Convert.TryFromBase64Chars.
+ // DecodeFromUtf8InPlace.
const int StackAllocThreshold = 512;
char[] arrayPoolChars = null;
@@ -207,28 +236,7 @@ internal static unsafe byte[] UnsafeDecode(ReadOnlySpan strSpan)
arrayPoolChars = ArrayPool.Shared.Rent(decodedLength);
charsSpan = charsSpan.Slice(0, decodedLength);
- source.CopyTo(charsSpan);
- if (source.Length < charsSpan.Length)
- {
- charsSpan[source.Length] = base64PadCharacter;
- if (source.Length + 1 < charsSpan.Length)
- {
- charsSpan[source.Length + 1] = base64PadCharacter;
- }
- }
-
- if (needReplace)
- {
- Span remaining = charsSpan;
- int pos;
- while ((pos = remaining.IndexOfAny(base64UrlCharacter62, base64UrlCharacter63)) >= 0)
- {
- remaining[pos] = (remaining[pos] == base64UrlCharacter62) ? base64Character62 : base64Character63;
- remaining = remaining.Slice(pos + 1);
- }
- }
-
- source = charsSpan;
+ source = HandlePaddingAndReplace(source, charsSpan, needReplace);
}
byte[] arrayPoolBytes = null;
@@ -236,24 +244,65 @@ internal static unsafe byte[] UnsafeDecode(ReadOnlySpan strSpan)
stackalloc byte[StackAllocThreshold] :
arrayPoolBytes = ArrayPool.Shared.Rent(decodedLength);
- bool converted = Convert.TryFromBase64Chars(source, bytesSpan, out int bytesWritten);
- Debug.Assert(converted, "Expected TryFromBase64Chars to be successful");
- byte[] result = bytesSpan.Slice(0, bytesWritten).ToArray();
+ int length = Encoding.UTF8.GetBytes(source, bytesSpan);
+ Span utf8Span = bytesSpan.Slice(0, length);
+
+ try
+ {
+ OperationStatus status = System.Buffers.Text.Base64.DecodeFromUtf8InPlace(utf8Span, out int bytesWritten);
+ if (status != OperationStatus.Done)
+ throw LogHelper.LogExceptionMessage(new FormatException(LogHelper.FormatInvariant(LogMessages.IDX10400, strSpan.ToString())));
+
+ utf8Span.Slice(0, bytesWritten).CopyTo(output);
+
+ return bytesWritten;
+ }
+ finally
+ {
+ if (arrayPoolBytes is not null)
+ {
+ bytesSpan.Clear();
+ ArrayPool.Shared.Return(arrayPoolBytes);
+ }
+
+ if (arrayPoolChars is not null)
+ {
+ charsSpan.Clear();
+ ArrayPool.Shared.Return(arrayPoolChars);
+ }
+ }
+ }
- if (arrayPoolBytes is not null)
+ private static ReadOnlySpan HandlePaddingAndReplace(ReadOnlySpan source, Span charsSpan, bool needReplace)
+ {
+ source.CopyTo(charsSpan);
+ if (source.Length < charsSpan.Length)
{
- bytesSpan.Clear();
- ArrayPool.Shared.Return(arrayPoolBytes);
+ charsSpan[source.Length] = base64PadCharacter;
+ if (source.Length + 1 < charsSpan.Length)
+ {
+ charsSpan[source.Length + 1] = base64PadCharacter;
+ }
}
- if (arrayPoolChars is not null)
+ if (needReplace)
{
- charsSpan.Clear();
- ArrayPool.Shared.Return(arrayPoolChars);
+ Span remaining = charsSpan;
+ int pos;
+ while ((pos = remaining.IndexOfAny(base64UrlCharacter62, base64UrlCharacter63)) >= 0)
+ {
+ remaining[pos] = (remaining[pos] == base64UrlCharacter62) ? base64Character62 : base64Character63;
+ remaining = remaining.Slice(pos + 1);
+ }
}
- return result;
+ return charsSpan;
+ }
+
#else
+
+ private static unsafe byte[] UnsafeDecode(ReadOnlySpan strSpan, bool needReplace, int decodedLength)
+ {
if (needReplace)
{
string decodedString = new(char.MinValue, decodedLength);
@@ -298,9 +347,15 @@ internal static unsafe byte[] UnsafeDecode(ReadOnlySpan strSpan)
return Convert.FromBase64String(decodedString);
}
}
-#endif
}
+ private static void Decode(ReadOnlySpan strSpan, Span output, bool needReplace, int decodedLength)
+ {
+ byte[] result = UnsafeDecode(strSpan, needReplace, decodedLength);
+ result.CopyTo(output);
+ }
+#endif
+
///
/// Decodes the string from Base64UrlEncoded to UTF8.
///