diff --git a/benchmark/Microsoft.IdentityModel.Benchmarks/Microsoft.IdentityModel.Benchmarks.csproj b/benchmark/Microsoft.IdentityModel.Benchmarks/Microsoft.IdentityModel.Benchmarks.csproj index 5279cc3f14..6d2e5581d3 100644 --- a/benchmark/Microsoft.IdentityModel.Benchmarks/Microsoft.IdentityModel.Benchmarks.csproj +++ b/benchmark/Microsoft.IdentityModel.Benchmarks/Microsoft.IdentityModel.Benchmarks.csproj @@ -23,7 +23,7 @@ --> - + diff --git a/src/Microsoft.IdentityModel.JsonWebTokens/JsonWebToken.cs b/src/Microsoft.IdentityModel.JsonWebTokens/JsonWebToken.cs index fd6921f1f4..255bdc21d9 100644 --- a/src/Microsoft.IdentityModel.JsonWebTokens/JsonWebToken.cs +++ b/src/Microsoft.IdentityModel.JsonWebTokens/JsonWebToken.cs @@ -504,7 +504,7 @@ internal void ReadToken(ReadOnlyMemory encodedTokenMemory) try { - Header = CreateHeaderClaimSet(Base64UrlEncoder.UnsafeDecode(headerSpan).AsSpan()); + Header = CreateHeaderClaimSet(Base64UrlEncoder.Decode(headerSpan).AsSpan()); } catch (Exception ex) { @@ -518,7 +518,7 @@ internal void ReadToken(ReadOnlyMemory encodedTokenMemory) ReadOnlySpan encryptedKeyBytes = encodedTokenSpan.Slice(Dot1 + 1, Dot2 - Dot1 - 1); if (!encryptedKeyBytes.IsEmpty) { - EncryptedKeyBytes = Base64UrlEncoder.UnsafeDecode(encryptedKeyBytes); + EncryptedKeyBytes = Base64UrlEncoder.Decode(encryptedKeyBytes); _encryptedKey = encryptedKeyBytes.ToString(); } else @@ -532,7 +532,7 @@ internal void ReadToken(ReadOnlyMemory encodedTokenMemory) try { - InitializationVectorBytes = Base64UrlEncoder.UnsafeDecode(initializationVectorSpan); + InitializationVectorBytes = Base64UrlEncoder.Decode(initializationVectorSpan); } catch (Exception ex) { @@ -545,7 +545,7 @@ internal void ReadToken(ReadOnlyMemory encodedTokenMemory) try { - AuthenticationTagBytes = Base64UrlEncoder.UnsafeDecode(authTagSpan); + AuthenticationTagBytes = Base64UrlEncoder.Decode(authTagSpan); } catch (Exception ex) { @@ -558,7 +558,7 @@ internal void ReadToken(ReadOnlyMemory encodedTokenMemory) try { - CipherTextBytes = Base64UrlEncoder.UnsafeDecode(cipherTextSpan); + CipherTextBytes = Base64UrlEncoder.Decode(cipherTextSpan); } catch (Exception ex) { @@ -570,15 +570,16 @@ internal void ReadToken(ReadOnlyMemory encodedTokenMemory) internal JsonClaimSet CreateClaimSet(ReadOnlySpan strSpan, int startIndex, int length, bool createHeaderClaimSet) { int outputSize = Base64UrlEncoding.ValidateAndGetOutputSize(strSpan, startIndex, length); + byte[] output = ArrayPool.Shared.Rent(outputSize); try { - Base64UrlEncoding.Decode(strSpan, startIndex, length, output); + Base64UrlEncoder.Decode(strSpan.Slice(startIndex, length), output); return createHeaderClaimSet ? CreateHeaderClaimSet(output.AsSpan()) : CreatePayloadClaimSet(output.AsSpan()); } finally { - ArrayPool.Shared.Return(output); + ArrayPool.Shared.Return(output, true); } } diff --git a/src/Microsoft.IdentityModel.Tokens/Base64UrlEncoder.cs b/src/Microsoft.IdentityModel.Tokens/Base64UrlEncoder.cs index 8a11a07c3c..cb1ed39940 100644 --- a/src/Microsoft.IdentityModel.Tokens/Base64UrlEncoder.cs +++ b/src/Microsoft.IdentityModel.Tokens/Base64UrlEncoder.cs @@ -173,13 +173,31 @@ public static int Encode(ReadOnlySpan inArray, Span output) public static byte[] DecodeBytes(string str) { _ = str ?? throw LogHelper.LogExceptionMessage(new ArgumentNullException(nameof(str))); - return UnsafeDecode(str.AsSpan()); + return Decode(str.AsSpan()); } + internal static byte[] Decode(ReadOnlySpan strSpan) + { + int mod = strSpan.Length % 4; + if (mod == 1) + throw LogHelper.LogExceptionMessage(new FormatException(LogHelper.FormatInvariant(LogMessages.IDX10400, strSpan.ToString()))); + + bool needReplace = strSpan.IndexOfAny(base64UrlCharacter62, base64UrlCharacter63) >= 0; + int decodedLength = strSpan.Length + (4 - mod) % 4; + #if NET6_0_OR_GREATER - [SkipLocalsInit] + + Span output = new byte[decodedLength]; + + int length = Decode(strSpan, output, needReplace, decodedLength); + + return output.Slice(0, length).ToArray(); +#else + return UnsafeDecode(strSpan, needReplace, decodedLength); #endif - internal static unsafe byte[] UnsafeDecode(ReadOnlySpan strSpan) + } + + internal static void Decode(ReadOnlySpan strSpan, Span output) { int mod = strSpan.Length % 4; if (mod == 1) @@ -189,11 +207,22 @@ internal static unsafe byte[] UnsafeDecode(ReadOnlySpan strSpan) int decodedLength = strSpan.Length + (4 - mod) % 4; #if NET6_0_OR_GREATER + Decode(strSpan, output, needReplace, decodedLength); +#else + Decode(strSpan, output, needReplace, decodedLength); +#endif + } + +#if NET6_0_OR_GREATER + + [SkipLocalsInit] + private static int Decode(ReadOnlySpan strSpan, Span output, bool needReplace, int decodedLength) + { // If the incoming chars don't contain any of the base64url characters that need to be replaced, // and if the incoming chars are of the exact right length, then we'll be able to just pass the - // incoming chars directly to Convert.TryFromBase64Chars. Otherwise, rent an array, copy all the + // incoming chars directly to DecodeFromUtf8InPlace. Otherwise, rent an array, copy all the // data into it, and do whatever fixups are necessary on that copy, then pass that copy into - // Convert.TryFromBase64Chars. + // DecodeFromUtf8InPlace. const int StackAllocThreshold = 512; char[] arrayPoolChars = null; @@ -207,28 +236,7 @@ internal static unsafe byte[] UnsafeDecode(ReadOnlySpan strSpan) arrayPoolChars = ArrayPool.Shared.Rent(decodedLength); charsSpan = charsSpan.Slice(0, decodedLength); - source.CopyTo(charsSpan); - if (source.Length < charsSpan.Length) - { - charsSpan[source.Length] = base64PadCharacter; - if (source.Length + 1 < charsSpan.Length) - { - charsSpan[source.Length + 1] = base64PadCharacter; - } - } - - if (needReplace) - { - Span remaining = charsSpan; - int pos; - while ((pos = remaining.IndexOfAny(base64UrlCharacter62, base64UrlCharacter63)) >= 0) - { - remaining[pos] = (remaining[pos] == base64UrlCharacter62) ? base64Character62 : base64Character63; - remaining = remaining.Slice(pos + 1); - } - } - - source = charsSpan; + source = HandlePaddingAndReplace(source, charsSpan, needReplace); } byte[] arrayPoolBytes = null; @@ -236,24 +244,65 @@ internal static unsafe byte[] UnsafeDecode(ReadOnlySpan strSpan) stackalloc byte[StackAllocThreshold] : arrayPoolBytes = ArrayPool.Shared.Rent(decodedLength); - bool converted = Convert.TryFromBase64Chars(source, bytesSpan, out int bytesWritten); - Debug.Assert(converted, "Expected TryFromBase64Chars to be successful"); - byte[] result = bytesSpan.Slice(0, bytesWritten).ToArray(); + int length = Encoding.UTF8.GetBytes(source, bytesSpan); + Span utf8Span = bytesSpan.Slice(0, length); + + try + { + OperationStatus status = System.Buffers.Text.Base64.DecodeFromUtf8InPlace(utf8Span, out int bytesWritten); + if (status != OperationStatus.Done) + throw LogHelper.LogExceptionMessage(new FormatException(LogHelper.FormatInvariant(LogMessages.IDX10400, strSpan.ToString()))); + + utf8Span.Slice(0, bytesWritten).CopyTo(output); + + return bytesWritten; + } + finally + { + if (arrayPoolBytes is not null) + { + bytesSpan.Clear(); + ArrayPool.Shared.Return(arrayPoolBytes); + } + + if (arrayPoolChars is not null) + { + charsSpan.Clear(); + ArrayPool.Shared.Return(arrayPoolChars); + } + } + } - if (arrayPoolBytes is not null) + private static ReadOnlySpan HandlePaddingAndReplace(ReadOnlySpan source, Span charsSpan, bool needReplace) + { + source.CopyTo(charsSpan); + if (source.Length < charsSpan.Length) { - bytesSpan.Clear(); - ArrayPool.Shared.Return(arrayPoolBytes); + charsSpan[source.Length] = base64PadCharacter; + if (source.Length + 1 < charsSpan.Length) + { + charsSpan[source.Length + 1] = base64PadCharacter; + } } - if (arrayPoolChars is not null) + if (needReplace) { - charsSpan.Clear(); - ArrayPool.Shared.Return(arrayPoolChars); + Span remaining = charsSpan; + int pos; + while ((pos = remaining.IndexOfAny(base64UrlCharacter62, base64UrlCharacter63)) >= 0) + { + remaining[pos] = (remaining[pos] == base64UrlCharacter62) ? base64Character62 : base64Character63; + remaining = remaining.Slice(pos + 1); + } } - return result; + return charsSpan; + } + #else + + private static unsafe byte[] UnsafeDecode(ReadOnlySpan strSpan, bool needReplace, int decodedLength) + { if (needReplace) { string decodedString = new(char.MinValue, decodedLength); @@ -298,9 +347,15 @@ internal static unsafe byte[] UnsafeDecode(ReadOnlySpan strSpan) return Convert.FromBase64String(decodedString); } } -#endif } + private static void Decode(ReadOnlySpan strSpan, Span output, bool needReplace, int decodedLength) + { + byte[] result = UnsafeDecode(strSpan, needReplace, decodedLength); + result.CopyTo(output); + } +#endif + /// /// Decodes the string from Base64UrlEncoded to UTF8. ///