Skip to content

Commit

Permalink
Reduce allocation with Base64UrlEncoder (#2162)
Browse files Browse the repository at this point in the history
- Stop calling ToCharArray to produce inputs to it; instead pass in a ReadOnlyMemory sliced appropriately
- Move the char[] s_base64Table into a u8 span
- Use IndexOfAny to determine whether replacements are needed in UnsafeDecode
- Stackalloc a span or rent a char[] from the ArrayPool when a temporary is needed in UnsafeDecode
  • Loading branch information
stephentoub authored and brentschmaltz committed Jul 28, 2023
1 parent be6ba72 commit d4d0c61
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 102 deletions.
41 changes: 27 additions & 14 deletions src/Microsoft.IdentityModel.JsonWebTokens/JsonWebToken.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ namespace Microsoft.IdentityModel.JsonWebTokens
/// </summary>
public class JsonWebToken : SecurityToken
{
private char[] _hChars;
private ClaimsIdentity _claimsIdentity;
private bool _wasClaimsIdentitySet;

Expand Down Expand Up @@ -470,25 +469,39 @@ private void ReadToken(string encodedJson)
throw LogHelper.LogExceptionMessage(new SecurityTokenMalformedException(LogHelper.FormatInvariant(LogMessages.IDX14310, encodedJson)));

// right number of dots for JWE
_hChars = encodedJson.ToCharArray(0, Dot1);
ReadOnlyMemory<char> hChars = encodedJson.AsMemory(0, Dot1);

// header cannot be empty
if (_hChars.Length == 0)
if (hChars.IsEmpty)
throw LogHelper.LogExceptionMessage(new ArgumentException(LogHelper.FormatInvariant(LogMessages.IDX14307, encodedJson)));

HeaderAsciiBytes = Encoding.ASCII.GetBytes(_hChars);
byte[] headerAsciiBytes = new byte[hChars.Length];
#if NET6_0_OR_GREATER
Encoding.ASCII.GetBytes(hChars.Span, headerAsciiBytes);
#else
unsafe
{
fixed (char* hCharsPtr = hChars.Span)
fixed (byte* headerAsciiBytesPtr = headerAsciiBytes)
{
Encoding.ASCII.GetBytes(hCharsPtr, hChars.Length, headerAsciiBytesPtr, headerAsciiBytes.Length);
}
}
#endif
HeaderAsciiBytes = headerAsciiBytes;

try
{
Header = new JsonClaimSet(Base64UrlEncoder.UnsafeDecode(_hChars));
Header = new JsonClaimSet(Base64UrlEncoder.UnsafeDecode(hChars));
}
catch (Exception ex)
{
throw LogHelper.LogExceptionMessage(new ArgumentException(LogHelper.FormatInvariant(LogMessages.IDX14102, encodedJson.Substring(0, Dot1), encodedJson), ex));
}

// dir does not have any key bytes
char[] encryptedKeyBytes = encodedJson.ToCharArray(Dot1 + 1, Dot2 - Dot1 - 1);
if (encryptedKeyBytes.Length != 0)
ReadOnlyMemory<char> encryptedKeyBytes = encodedJson.AsMemory(Dot1 + 1, Dot2 - Dot1 - 1);
if (!encryptedKeyBytes.IsEmpty)
{
EncryptedKeyBytes = Base64UrlEncoder.UnsafeDecode(encryptedKeyBytes);
_encryptedKey = encodedJson.Substring(Dot1 + 1, Dot2 - Dot1 - 1);
Expand All @@ -498,8 +511,8 @@ private void ReadToken(string encodedJson)
_encryptedKey = string.Empty;
}

char[] initializationVectorChars = encodedJson.ToCharArray(Dot2 + 1, Dot3 - Dot2 - 1);
if (initializationVectorChars.Length == 0)
ReadOnlyMemory<char> initializationVectorChars = encodedJson.AsMemory(Dot2 + 1, Dot3 - Dot2 - 1);
if (initializationVectorChars.IsEmpty)
throw LogHelper.LogExceptionMessage(new ArgumentException(LogHelper.FormatInvariant(LogMessages.IDX14308, encodedJson)));

try
Expand All @@ -511,8 +524,8 @@ private void ReadToken(string encodedJson)
throw LogHelper.LogExceptionMessage(new ArgumentException(LogHelper.FormatInvariant(LogMessages.IDX14309, encodedJson, encodedJson), ex));
}

char[] authTagChars = encodedJson.ToCharArray(Dot4 + 1, encodedJson.Length - Dot4 - 1);
if (authTagChars.Length == 0)
ReadOnlyMemory<char> authTagChars = encodedJson.AsMemory(Dot4 + 1);
if (authTagChars.IsEmpty)
throw LogHelper.LogExceptionMessage(new ArgumentException(LogHelper.FormatInvariant(LogMessages.IDX14310, encodedJson)));

try
Expand All @@ -524,13 +537,13 @@ private void ReadToken(string encodedJson)
throw LogHelper.LogExceptionMessage(new ArgumentException(LogHelper.FormatInvariant(LogMessages.IDX14311, encodedJson, encodedJson), ex));
}

char[] cipherTextBytes = encodedJson.ToCharArray(Dot3 + 1, Dot4 - Dot3 - 1);
if (cipherTextBytes.Length == 0)
ReadOnlyMemory<char> cipherTextBytes = encodedJson.AsMemory(Dot3 + 1, Dot4 - Dot3 - 1);
if (cipherTextBytes.IsEmpty)
throw LogHelper.LogExceptionMessage(new ArgumentException(LogHelper.FormatInvariant(LogMessages.IDX14306, encodedJson)));

try
{
CipherTextBytes = Base64UrlEncoder.UnsafeDecode(encodedJson.ToCharArray(Dot3 + 1, Dot4 - Dot3 - 1));
CipherTextBytes = Base64UrlEncoder.UnsafeDecode(cipherTextBytes);
}
catch (Exception ex)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
<GenerateDocumentationFile>true</GenerateDocumentationFile>
<PackageId>Microsoft.IdentityModel.JsonWebTokens</PackageId>
<PackageTags>.NET;Windows;Authentication;Identity;Json Web Token</PackageTags>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
</PropertyGroup>

<PropertyGroup Condition="'$(Configuration)'=='Debug'">
Expand Down
173 changes: 85 additions & 88 deletions src/Microsoft.IdentityModel.Tokens/Base64UrlEncoder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
// Licensed under the MIT License.

using System;
using System.Buffers;
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Text;
using Microsoft.IdentityModel.Logging;

Expand All @@ -18,18 +22,6 @@ public static class Base64UrlEncoder
private const char base64UrlCharacter62 = '-';
private const char base64UrlCharacter63 = '_';

/// <summary>
/// Encoding table
/// </summary>
internal static readonly char[] s_base64Table =
{
'A','B','C','D','E','F','G','H','I','J','K','L','M','N','O','P','Q','R','S','T','U','V','W','X','Y','Z',
'a','b','c','d','e','f','g','h','i','j','k','l','m','n','o','p','q','r','s','t','u','v','w','x','y','z',
'0','1','2','3','4','5','6','7','8','9',
base64UrlCharacter62,
base64UrlCharacter63
};

/// <summary>
/// The following functions perform base64url encoding which differs from regular base64 encoding as follows
/// * padding is skipped so the pad character '=' doesn't have to be percent encoded
Expand Down Expand Up @@ -90,7 +82,8 @@ public static string Encode(byte[] inArray, int offset, int length)
int lengthmod3 = length % 3;
int limit = offset + (length - lengthmod3);
char[] output = new char[(length + 2) / 3 * 4];
char[] table = s_base64Table;
ReadOnlySpan<byte> table = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"u8;

int i, j = 0;

// takes 3 bytes from inArray and insert 4 bytes into output
Expand All @@ -100,10 +93,10 @@ public static string Encode(byte[] inArray, int offset, int length)
byte d1 = inArray[i + 1];
byte d2 = inArray[i + 2];

output[j + 0] = table[d0 >> 2];
output[j + 1] = table[((d0 & 0x03) << 4) | (d1 >> 4)];
output[j + 2] = table[((d1 & 0x0f) << 2) | (d2 >> 6)];
output[j + 3] = table[d2 & 0x3f];
output[j + 0] = (char)table[d0 >> 2];
output[j + 1] = (char)table[((d0 & 0x03) << 4) | (d1 >> 4)];
output[j + 2] = (char)table[((d1 & 0x0f) << 2) | (d2 >> 6)];
output[j + 3] = (char)table[d2 & 0x3f];
j += 4;
}

Expand All @@ -117,9 +110,9 @@ public static string Encode(byte[] inArray, int offset, int length)
byte d0 = inArray[i];
byte d1 = inArray[i + 1];

output[j + 0] = table[d0 >> 2];
output[j + 1] = table[((d0 & 0x03) << 4) | (d1 >> 4)];
output[j + 2] = table[(d1 & 0x0f) << 2];
output[j + 0] = (char)table[d0 >> 2];
output[j + 1] = (char)table[((d0 & 0x03) << 4) | (d1 >> 4)];
output[j + 2] = (char)table[(d1 & 0x0f) << 2];
j += 3;
}
break;
Expand All @@ -128,8 +121,8 @@ public static string Encode(byte[] inArray, int offset, int length)
{
byte d0 = inArray[i];

output[j + 0] = table[d0 >> 2];
output[j + 1] = table[(d0 & 0x03) << 4];
output[j + 0] = (char)table[d0 >> 2];
output[j + 1] = (char)table[(d0 & 0x03) << 4];
j += 2;
}
break;
Expand Down Expand Up @@ -168,106 +161,100 @@ internal static string EncodeString(string str)
public static byte[] DecodeBytes(string str)
{
_ = str ?? throw LogHelper.LogExceptionMessage(new ArgumentNullException(nameof(str)));
return UnsafeDecode(str);
return UnsafeDecode(str.AsMemory());
}

internal static unsafe byte[] UnsafeDecode(string str)
#if NET6_0_OR_GREATER
[SkipLocalsInit]
#endif
internal static unsafe byte[] UnsafeDecode(ReadOnlyMemory<char> str)
{
int mod = str.Length % 4;
if (mod == 1)
throw LogHelper.LogExceptionMessage(new FormatException(LogHelper.FormatInvariant(LogMessages.IDX10400, str)));
throw LogHelper.LogExceptionMessage(new FormatException(LogHelper.FormatInvariant(LogMessages.IDX10400, str.ToString())));

bool needReplace = false;
bool needReplace = str.Span.IndexOfAny(base64UrlCharacter62, base64UrlCharacter63) >= 0;
int decodedLength = str.Length + (4 - mod) % 4;

for (int i = 0; i < str.Length; i++)
{
if (str[i] == base64UrlCharacter62 || str[i] == base64UrlCharacter63)
{
needReplace = true;
break;
}
}
#if NET6_0_OR_GREATER
// If the incoming chars don't contain any of the base64url characters that need to be replaced,
// and if the incoming chars are of the exact right length, then we'll be able to just pass the
// incoming chars directly to Convert.TryFromBase64Chars. Otherwise, rent an array, copy all the
// data into it, and do whatever fixups are necessary on that copy, then pass that copy into
// Convert.TryFromBase64Chars.

if (needReplace)
const int StackAllocThreshold = 512;
char[] arrayPoolChars = null;
scoped Span<char> charsSpan = default;
scoped ReadOnlySpan<char> source = str.Span;

if (needReplace || decodedLength != source.Length)
{
string decodedString = new(char.MinValue, decodedLength);
fixed (char* dest = decodedString)
charsSpan = decodedLength <= StackAllocThreshold ?
stackalloc char[StackAllocThreshold] :
arrayPoolChars = ArrayPool<char>.Shared.Rent(decodedLength);
charsSpan = charsSpan.Slice(0, decodedLength);

source.CopyTo(charsSpan);
if (source.Length < charsSpan.Length)
{
int i = 0;
for (; i < str.Length; i++)
charsSpan[source.Length] = base64PadCharacter;
if (source.Length + 1 < charsSpan.Length)
{
if (str[i] == base64UrlCharacter62)
dest[i] = base64Character62;
else if (str[i] == base64UrlCharacter63)
dest[i] = base64Character63;
else
dest[i] = str[i];
charsSpan[source.Length + 1] = base64PadCharacter;
}

for (; i < decodedLength; i++)
dest[i] = base64PadCharacter;
}

return Convert.FromBase64String(decodedString);
}
else
{
if (decodedLength == str.Length)
if (needReplace)
{
return Convert.FromBase64String(str);
}
else
{
string decodedString = new(char.MinValue, decodedLength);
fixed (char* src = str)
fixed (char* dest = decodedString)
int pos;
while ((pos = charsSpan.IndexOfAny(base64UrlCharacter62, base64UrlCharacter63)) >= 0)
{
Buffer.MemoryCopy(src, dest, str.Length * 2, str.Length * 2);

dest[str.Length] = base64PadCharacter;
if (str.Length + 2 == decodedLength)
dest[str.Length + 1] = base64PadCharacter;
charsSpan[pos] = charsSpan[pos] == base64UrlCharacter62 ? base64Character62 : base64Character63;
}

return Convert.FromBase64String(decodedString);
}

source = charsSpan;
}
}

internal static unsafe byte[] UnsafeDecode(char[] str)
{
int mod = str.Length % 4;
if (mod == 1)
throw LogHelper.LogExceptionMessage(new FormatException(LogHelper.FormatInvariant(LogMessages.IDX10400, str)));
byte[] arrayPoolBytes = null;
Span<byte> bytesSpan = decodedLength <= StackAllocThreshold ?
stackalloc byte[StackAllocThreshold] :
arrayPoolBytes = ArrayPool<byte>.Shared.Rent(decodedLength);

bool needReplace = false;
// the decoded length
int decodedLength = str.Length + (4 - mod) % 4;
bool converted = Convert.TryFromBase64Chars(source, bytesSpan, out int bytesWritten);
Debug.Assert(converted, "Expected TryFromBase64Chars to be successful");
byte[] result = bytesSpan.Slice(0, bytesWritten).ToArray();

for (int i = 0; i < str.Length; i++)
if (arrayPoolBytes is not null)
{
if (str[i] == base64UrlCharacter62 || str[i] == base64UrlCharacter63)
{
needReplace = true;
break;
}
bytesSpan.Clear();
ArrayPool<byte>.Shared.Return(arrayPoolBytes);
}

if (arrayPoolChars is not null)
{
charsSpan.Clear();
ArrayPool<char>.Shared.Return(arrayPoolChars);
}

return result;
#else
if (needReplace)
{
ReadOnlySpan<char> strSpan = str.Span;
string decodedString = new(char.MinValue, decodedLength);
fixed (char* dest = decodedString)
{
int i = 0;
for (; i < str.Length; i++)
for (; i < strSpan.Length; i++)
{
if (str[i] == base64UrlCharacter62)
if (strSpan[i] == base64UrlCharacter62)
dest[i] = base64Character62;
else if (str[i] == base64UrlCharacter63)
else if (strSpan[i] == base64UrlCharacter63)
dest[i] = base64Character63;
else
dest[i] = str[i];
dest[i] = strSpan[i];
}

for (; i < decodedLength; i++)
Expand All @@ -280,12 +267,21 @@ internal static unsafe byte[] UnsafeDecode(char[] str)
{
if (decodedLength == str.Length)
{
return Convert.FromBase64CharArray(str, 0, str.Length);
if (MemoryMarshal.TryGetArray(str, out ArraySegment<char> segment))
{
return Convert.FromBase64CharArray(segment.Array, segment.Offset, segment.Count);
}
else
{
bool gotString = MemoryMarshal.TryGetString(str, out string text, out int start, out int length);
Debug.Assert(gotString, "Expected ReadOnlyMemory to wrap either array or string");
return Convert.FromBase64String(text.Substring(start, length));
}
}
else
{
string decodedString = new(char.MinValue, decodedLength);
fixed (char* src = str)
fixed (char* src = str.Span)
fixed (char* dest = decodedString)
{
Buffer.MemoryCopy(src, dest, str.Length * 2, str.Length * 2);
Expand All @@ -298,6 +294,7 @@ internal static unsafe byte[] UnsafeDecode(char[] str)
return Convert.FromBase64String(decodedString);
}
}
#endif
}

/// <summary>
Expand Down

0 comments on commit d4d0c61

Please sign in to comment.