From 5e5a5db3320b8f39a89058875a63839413ee3274 Mon Sep 17 00:00:00 2001 From: id4s Date: Mon, 22 Apr 2024 20:24:00 -0700 Subject: [PATCH] Do not dispose SignatureProvider when AsymmetricAdapter faults. Move compacted SignatureProviders to new cache to ensure dispose is called. Add delegate to check if signature provider should be removed from cache. Separate expired from compaction Dispose SignatureProvider if it was never cached. --- buildpackLocal.bat | 3 + .../AsymmetricSignatureProvider.cs | 18 +- .../CryptoProviderFactory.cs | 36 +- .../EventBasedLRUCache.cs | 320 +++++++++++------- .../InMemoryCryptoProviderCache.cs | 63 ++-- .../LogMessages.cs | 5 +- .../SignatureProvider.cs | 2 + .../SymmetricSignatureProvider.cs | 28 +- .../CryptoProviderFactoryTests.cs | 8 +- 9 files changed, 290 insertions(+), 193 deletions(-) create mode 100644 buildpackLocal.bat diff --git a/buildpackLocal.bat b/buildpackLocal.bat new file mode 100644 index 0000000000..44038b4f98 --- /dev/null +++ b/buildpackLocal.bat @@ -0,0 +1,3 @@ +dotnet clean Product.proj > clean.log +dotnet build /r Product.proj +dotnet pack --no-restore -o c:\localpackages --no-build Product.proj diff --git a/src/Microsoft.IdentityModel.Tokens/AsymmetricSignatureProvider.cs b/src/Microsoft.IdentityModel.Tokens/AsymmetricSignatureProvider.cs index 6183e64781..d03d3efc69 100644 --- a/src/Microsoft.IdentityModel.Tokens/AsymmetricSignatureProvider.cs +++ b/src/Microsoft.IdentityModel.Tokens/AsymmetricSignatureProvider.cs @@ -208,15 +208,13 @@ public override bool Sign(ReadOnlySpan input, Span signature, out in catch { CryptoProviderCache?.TryRemove(this); - Dispose(true); throw; } finally { - if (!_disposed) + if (asym != null) _asymmetricAdapterObjectPool.Free(asym); } - } #endif @@ -248,12 +246,11 @@ public override byte[] Sign(byte[] input) catch { CryptoProviderCache?.TryRemove(this); - Dispose(true); throw; } finally { - if (!_disposed) + if (asym != null) _asymmetricAdapterObjectPool.Free(asym); } } @@ -279,12 +276,11 @@ public override byte[] Sign(byte[] input, int offset, int count) catch { CryptoProviderCache?.TryRemove(this); - Dispose(true); throw; } finally { - if (!_disposed) + if (asym != null) _asymmetricAdapterObjectPool.Free(asym); } } @@ -380,12 +376,11 @@ public override bool Verify(byte[] input, byte[] signature) catch { CryptoProviderCache?.TryRemove(this); - Dispose(true); throw; } finally { - if (!_disposed) + if (asym != null) _asymmetricAdapterObjectPool.Free(asym); } } @@ -474,15 +469,14 @@ public override bool Verify(byte[] input, int inputOffset, int inputLength, byte } catch { - Dispose(true); + CryptoProviderCache?.TryRemove(this); throw; } finally { - if (!_disposed) + if (asym != null) _asymmetricAdapterObjectPool.Free(asym); } - } /// diff --git a/src/Microsoft.IdentityModel.Tokens/CryptoProviderFactory.cs b/src/Microsoft.IdentityModel.Tokens/CryptoProviderFactory.cs index a94551be69..d3d7213170 100644 --- a/src/Microsoft.IdentityModel.Tokens/CryptoProviderFactory.cs +++ b/src/Microsoft.IdentityModel.Tokens/CryptoProviderFactory.cs @@ -18,6 +18,8 @@ public class CryptoProviderFactory private static readonly ConcurrentDictionary _typeToAlgorithmMap = new ConcurrentDictionary(); private static readonly object _cacheLock = new object(); private static int _defaultSignatureProviderObjectPoolCacheSize = Environment.ProcessorCount * 4; + private static string _typeofAsymmetricSignatureProvider = typeof(AsymmetricSignatureProvider).ToString(); + private static string _typeofSymmetricSignatureProvider = typeof(SymmetricSignatureProvider).ToString(); private int _signatureProviderObjectPoolCacheSize = _defaultSignatureProviderObjectPoolCacheSize; /// @@ -513,7 +515,13 @@ private SignatureProvider CreateSignatureProvider(SecurityKey key, string algori { signatureProvider = CustomCryptoProvider.Create(algorithm, key, willCreateSignatures) as SignatureProvider; if (signatureProvider == null) - throw LogHelper.LogExceptionMessage(new InvalidOperationException(LogHelper.FormatInvariant(LogMessages.IDX10646, LogHelper.MarkAsNonPII(algorithm), key, LogHelper.MarkAsNonPII(typeof(SignatureProvider))))); + throw LogHelper.LogExceptionMessage( + new InvalidOperationException( + LogHelper.FormatInvariant( + LogMessages.IDX10646, + LogHelper.MarkAsNonPII(algorithm), + key, + LogHelper.MarkAsNonPII(typeof(SignatureProvider))))); return signatureProvider; } @@ -523,7 +531,7 @@ private SignatureProvider CreateSignatureProvider(SecurityKey key, string algori bool createAsymmetric = true; if (key is AsymmetricSecurityKey) { - typeofSignatureProvider = typeof(AsymmetricSignatureProvider).ToString(); + typeofSignatureProvider = _typeofAsymmetricSignatureProvider; } else if (key is JsonWebKey jsonWebKey) { @@ -533,11 +541,11 @@ private SignatureProvider CreateSignatureProvider(SecurityKey key, string algori { if (convertedSecurityKey is AsymmetricSecurityKey) { - typeofSignatureProvider = typeof(AsymmetricSignatureProvider).ToString(); + typeofSignatureProvider = _typeofAsymmetricSignatureProvider; } else if (convertedSecurityKey is SymmetricSecurityKey) { - typeofSignatureProvider = typeof(SymmetricSignatureProvider).ToString(); + typeofSignatureProvider = _typeofSymmetricSignatureProvider; createAsymmetric = false; } } @@ -545,10 +553,10 @@ private SignatureProvider CreateSignatureProvider(SecurityKey key, string algori else { if (jsonWebKey.Kty == JsonWebAlgorithmsKeyTypes.RSA || jsonWebKey.Kty == JsonWebAlgorithmsKeyTypes.EllipticCurve) - typeofSignatureProvider = typeof(AsymmetricSignatureProvider).ToString(); + typeofSignatureProvider = _typeofAsymmetricSignatureProvider; else if (jsonWebKey.Kty == JsonWebAlgorithmsKeyTypes.Octet) { - typeofSignatureProvider = typeof(SymmetricSignatureProvider).ToString(); + typeofSignatureProvider = _typeofSymmetricSignatureProvider; createAsymmetric = false; } } @@ -560,12 +568,20 @@ private SignatureProvider CreateSignatureProvider(SecurityKey key, string algori } else if (key is SymmetricSecurityKey) { - typeofSignatureProvider = typeof(SymmetricSignatureProvider).ToString(); + typeofSignatureProvider = _typeofSymmetricSignatureProvider; createAsymmetric = false; } if (typeofSignatureProvider == null) - throw LogHelper.LogExceptionMessage(new NotSupportedException(LogHelper.FormatInvariant(LogMessages.IDX10621, LogHelper.MarkAsNonPII(typeof(SymmetricSignatureProvider)), LogHelper.MarkAsNonPII(typeof(SecurityKey)), LogHelper.MarkAsNonPII(typeof(AsymmetricSecurityKey)), LogHelper.MarkAsNonPII(typeof(SymmetricSecurityKey)), LogHelper.MarkAsNonPII(key.GetType())))); + throw LogHelper.LogExceptionMessage( + new NotSupportedException( + LogHelper.FormatInvariant( + LogMessages.IDX10621, + LogHelper.MarkAsNonPII(typeof(SymmetricSignatureProvider)), + LogHelper.MarkAsNonPII(typeof(SecurityKey)), + LogHelper.MarkAsNonPII(typeof(AsymmetricSecurityKey)), + LogHelper.MarkAsNonPII(typeof(SymmetricSecurityKey)), + LogHelper.MarkAsNonPII(key.GetType())))); if (CacheSignatureProviders && cacheProvider) { @@ -592,7 +608,7 @@ private SignatureProvider CreateSignatureProvider(SecurityKey key, string algori signatureProvider = new SymmetricSignatureProvider(key, algorithm, willCreateSignatures); if (ShouldCacheSignatureProvider(signatureProvider)) - CryptoProviderCache.TryAdd(signatureProvider); + signatureProvider.IsCached = CryptoProviderCache.TryAdd(signatureProvider); } } else @@ -737,7 +753,7 @@ public virtual void ReleaseSignatureProvider(SignatureProvider signatureProvider signatureProvider.Release(); if (CustomCryptoProvider != null && CustomCryptoProvider.IsSupportedAlgorithm(signatureProvider.Algorithm)) CustomCryptoProvider.Release(signatureProvider); - else if (signatureProvider.CryptoProviderCache == null && signatureProvider.RefCount == 0) + else if (signatureProvider.CryptoProviderCache == null && signatureProvider.RefCount == 0 && !signatureProvider.IsCached) signatureProvider.Dispose(); } } diff --git a/src/Microsoft.IdentityModel.Tokens/EventBasedLRUCache.cs b/src/Microsoft.IdentityModel.Tokens/EventBasedLRUCache.cs index 1a329252ee..64e3a70fed 100644 --- a/src/Microsoft.IdentityModel.Tokens/EventBasedLRUCache.cs +++ b/src/Microsoft.IdentityModel.Tokens/EventBasedLRUCache.cs @@ -5,7 +5,6 @@ using System.Collections.Concurrent; using System.Collections.Generic; using System.Linq; -using System.Runtime.InteropServices; using System.Threading; using System.Threading.Tasks; using Microsoft.IdentityModel.Abstractions; @@ -29,63 +28,74 @@ namespace Microsoft.IdentityModel.Tokens /// The value type to be used by the cache internal class EventBasedLRUCache { + internal delegate void ItemCompacted(TValue Value); + internal delegate void ItemExpired(TValue Value); internal delegate void ItemRemoved(TValue Value); + internal delegate bool ShouldRemove(TValue Value); private readonly int _capacity; - + private List> _compactedItems = new List>(); // The percentage of the cache to be removed when _maxCapacityPercentage is reached. private readonly double _compactionPercentage = .20; private LinkedList> _doubleLinkedList = new LinkedList>(); private ConcurrentQueue _eventQueue = new ConcurrentQueue(); + private readonly TaskCreationOptions _options; + // if true, then items will be maintained in a LRU fashion, moving to front of list when accessed in the cache. + private readonly bool _maintainLRU; private ConcurrentDictionary> _map; - // When the current cache size gets to this percentage of _capacity, _compactionPercentage% of the cache will be removed. private readonly double _maxCapacityPercentage = .95; + private readonly int _compactIntervalInSeconds; // if true, expired values will not be added to the cache and clean-up of expired values will occur on a 5 minute interval private readonly bool _removeExpiredValues; private readonly int _removeExpiredValuesIntervalInSeconds; - // if true, then items will be maintained in a LRU fashion, moving to front of list when accessed in the cache. - private readonly bool _maintainLRU; - - private readonly TaskCreationOptions _options; - private DateTime _dueForExpiredValuesRemoval; - // for testing purpose only to verify the task count private int _taskCount = 0; + private DateTime _timeForNextExpiredValuesRemoval; + private DateTime _timeForNextCompaction; #region event queue - private int _eventQueuePollingInterval = 50; - // The idle timeout, the _eventQueueTask will end after being idle for the specified time interval (execution continues even if the queue is empty to reduce the task startup overhead), default to 120 seconds. // TODO: consider implementing a better algorithm that tracks and predicts the usage patterns and adjusts this value dynamically. private long _eventQueueTaskIdleTimeoutInSeconds = 120; - // The time when the _eventQueueTask should end. The intent is to reduce the overhead costs of starting/ending tasks too frequently // but at the same time keep the _eventQueueTask a short running task. // Since Task is based on thread pool the overhead should be reasonable. private DateTime _eventQueueTaskStopTime; - // task states used to ensure thread safety (Interlocked.CompareExchange) private const int EventQueueTaskStopped = 0; // task not started yet private const int EventQueueTaskRunning = 1; // task is running private const int EventQueueTaskDoNotStop = 2; // force the task to continue even it has past the _eventQueueTaskStopTime, see StartEventQueueTaskIfNotRunning() for more details. private int _eventQueueTaskState = EventQueueTaskStopped; - private const int CompactionNotQueued = 0; // compaction action not in the event queue - private const int CompactionQueuedOrRunning = 1; // compaction action in the event queue or currently in progress - private int _compactionState = CompactionNotQueued; + private const int ActionNotQueued = 0; // compaction action not in the event queue + private const int ActionQueuedOrRunning = 1; // compaction action in the event queue or currently in progress + private int _compactValuesState = ActionNotQueued; + private int _removeExpiredValuesState = ActionNotQueued; + private int _processCompactedValuesState = ActionNotQueued; // set to true when the AppDomain is to be unloaded or the default AppDomain process is ready to exit private bool _shouldStopImmediately = false; - internal ItemRemoved OnItemRemoved + internal ItemExpired OnItemExpired { get; set; } + + /// + /// For back compat any friend would be broken, this is the same as OnItemExpired. + /// + internal ItemExpired OnItemRemoved { - get; - set; + get { return OnItemExpired; } + set { OnItemExpired = value; } } + internal ItemCompacted OnItemMovedToCompactedList { get; set; } + + internal ItemRemoved OnItemRemovedFromCompactedList { get; set; } + + internal ShouldRemove OnShouldRemoveFromCompactedList { get; set; } + internal long EventQueueTaskIdleTimeoutInSeconds { get => _eventQueueTaskIdleTimeoutInSeconds; @@ -96,20 +106,6 @@ internal long EventQueueTaskIdleTimeoutInSeconds _eventQueueTaskIdleTimeoutInSeconds = value; } } - - // If the task operating on the _eventQueue has not timed out and the _eventQueue is empty, this polling interval will be used - // to determine how often the cache should be checked for the presence of a new action. - private int EventQueuePollingInterval - { - get => _eventQueuePollingInterval; - set - { - if (value <= 0) - throw new ArgumentOutOfRangeException(nameof(value), "EventQueuePollingInterval must be positive."); - _eventQueuePollingInterval = value; - } - } - #endregion /// @@ -121,22 +117,26 @@ private int EventQueuePollingInterval /// Whether or not to remove expired items. /// The period to wait to remove expired items, in seconds. /// Whether or not to maintain items in a LRU fashion, moving to front of list when accessed in the cache. + /// The period to wait to compact items, in seconds. internal EventBasedLRUCache( int capacity, TaskCreationOptions options = TaskCreationOptions.None, IEqualityComparer comparer = null, bool removeExpiredValues = false, int removeExpiredValuesIntervalInSeconds = 300, - bool maintainLRU = false) + bool maintainLRU = false, + int compactIntervalInSeconds = 20) { _capacity = capacity > 0 ? capacity : throw LogHelper.LogExceptionMessage(new ArgumentOutOfRangeException(nameof(capacity))); _options = options; _map = new ConcurrentDictionary>(comparer ?? EqualityComparer.Default); _removeExpiredValuesIntervalInSeconds = removeExpiredValuesIntervalInSeconds; _removeExpiredValues = removeExpiredValues; + _compactIntervalInSeconds = compactIntervalInSeconds; + _timeForNextExpiredValuesRemoval = DateTime.UtcNow.AddSeconds(_removeExpiredValuesIntervalInSeconds); + _timeForNextCompaction = DateTime.UtcNow.AddSeconds(_compactIntervalInSeconds); _eventQueueTaskStopTime = DateTime.UtcNow; _maintainLRU = maintainLRU; - _dueForExpiredValuesRemoval = DateTime.UtcNow.AddSeconds(_removeExpiredValuesIntervalInSeconds); } /// @@ -168,6 +168,7 @@ internal EventBasedLRUCache( private void AddActionToEventQueue(Action action) { _eventQueue.Enqueue(action); + // start the event queue task if it is not running StartEventQueueTaskIfNotRunning(); } @@ -186,88 +187,97 @@ public bool Contains(TKey key) private void EventQueueTaskAction() { Interlocked.Increment(ref _taskCount); - // Keep running until the queue is empty or the AppDomain is about to be unloaded or the application is ready to exit. - while (!_shouldStopImmediately) + try { - // always set the state to EventQueueTaskRunning in case it was set to EventQueueTaskDoNotStop - Interlocked.Exchange(ref _eventQueueTaskState, EventQueueTaskRunning); - - try + // Keep running until the queue is empty or the AppDomain is about to be unloaded or the application is ready to exit. + while (!_shouldStopImmediately) { - // remove expired items if needed - if (_removeExpiredValues && DateTime.UtcNow >= _dueForExpiredValuesRemoval) - { - if (_maintainLRU) - RemoveExpiredValuesLRU(); - else - RemoveExpiredValues(); + // always set the state to EventQueueTaskRunning in case it was set to EventQueueTaskDoNotStop + Interlocked.Exchange(ref _eventQueueTaskState, EventQueueTaskRunning); - _dueForExpiredValuesRemoval = DateTime.UtcNow.AddSeconds(_removeExpiredValuesIntervalInSeconds); - } - - // process all events in the queue and exit - if (_eventQueue.TryDequeue(out var action)) - { - action?.Invoke(); - } - else if (DateTime.UtcNow > _eventQueueTaskStopTime) // no more event to be processed, exit if expired + try { - // Setting _eventQueueTaskState = EventQueueTaskStopped if the _eventQueueTaskEndTime has past and _eventQueueTaskState == EventQueueTaskRunning. - // This means no other thread came in and it is safe to end this task. - // If another thread adds new events while this task is still running, it will set the _eventQueueTaskState = EventQueueTaskDoNotStop instead of starting a new task. - // The Interlocked.CompareExchange() call below will not succeed and the loop continues (until the event queue is empty and the _eventQueueTaskEndTime expires again). - // This should prevent a rare (but theoretically possible) scenario caused by context switching. - if (Interlocked.CompareExchange(ref _eventQueueTaskState, EventQueueTaskStopped, EventQueueTaskRunning) == EventQueueTaskRunning) - break; - + // remove expired items if needed + if (_removeExpiredValues && DateTime.UtcNow >= _timeForNextExpiredValuesRemoval) + { + if (Interlocked.CompareExchange(ref _removeExpiredValuesState, ActionNotQueued, ActionQueuedOrRunning) == ActionQueuedOrRunning) + { + if (_maintainLRU) + RemoveExpiredValuesLRU(); + else + RemoveExpiredValues(); + } + } + + // process all events in the queue and exit + if (_eventQueue.TryDequeue(out var action)) + { + action?.Invoke(); + } + else if (DateTime.UtcNow > _eventQueueTaskStopTime) // no more event to be processed, exit if expired + { + // Setting _eventQueueTaskState = EventQueueTaskStopped if the _eventQueueStopTime has past and _eventQueueTaskState == EventQueueTaskRunning. + // This means no other thread came in and it is safe to end this task. + // If another thread adds new events while this task is still running, it will set the _eventQueueTaskState = EventQueueTaskDoNotStop instead of starting a new task. + // The Interlocked.CompareExchange() call below will not succeed and the loop continues (until the event queue is empty and the _eventQueueTaskEndTime expires again). + // This should prevent a rare (but theoretically possible) scenario caused by context switching. + if (Interlocked.CompareExchange(ref _eventQueueTaskState, EventQueueTaskStopped, EventQueueTaskRunning) == EventQueueTaskRunning) + break; + + } + else // if empty, let the thread sleep for a specified number of milliseconds before attempting to retrieve another value from the queue + { + Thread.Sleep(_eventQueuePollingInterval); + } } - else // if empty, let the thread sleep for a specified number of milliseconds before attempting to retrieve another value from the queue + catch (Exception ex) { - Thread.Sleep(_eventQueuePollingInterval); + if (LogHelper.IsEnabled(EventLogLevel.Warning)) + LogHelper.LogWarning(LogHelper.FormatInvariant(LogMessages.IDX10900, ex)); } } - catch (Exception ex) - { - if (LogHelper.IsEnabled(EventLogLevel.Warning)) - LogHelper.LogWarning(LogHelper.FormatInvariant(LogMessages.IDX10900, ex)); - } } - - Interlocked.Decrement(ref _taskCount); + finally + { + Interlocked.Decrement(ref _taskCount); + Interlocked.Exchange(ref _eventQueueTaskState, EventQueueTaskStopped); + } } /// /// Remove all expired cache items from _doubleLinkedList and _map. /// /// Number of items removed. - internal int RemoveExpiredValuesLRU() + internal void RemoveExpiredValuesLRU() { - int numItemsRemoved = 0; +#pragma warning disable CA1031 // Do not catch general exception types try { - var node = _doubleLinkedList.First; + LinkedListNode> node = _doubleLinkedList.First; while (node != null) { - var nextNode = node.Next; + LinkedListNode> nextNode = node.Next; if (node.Value.ExpirationTime < DateTime.UtcNow) { _doubleLinkedList.Remove(node); - if (_map.TryRemove(node.Value.Key, out var cacheItem)) - OnItemRemoved?.Invoke(cacheItem.Value); - - numItemsRemoved++; + if (_map.TryRemove(node.Value.Key, out LRUCacheItem cacheItem)) + OnItemExpired?.Invoke(cacheItem.Value); } node = nextNode; } } - catch (ObjectDisposedException ex) + catch(Exception ex) { if (LogHelper.IsEnabled(EventLogLevel.Warning)) LogHelper.LogWarning(LogHelper.FormatInvariant(LogMessages.IDX10902, LogHelper.MarkAsNonPII(nameof(RemoveExpiredValuesLRU)), ex)); } - - return numItemsRemoved; + finally + { + _removeExpiredValuesState = ActionNotQueued; + _timeForNextExpiredValuesRemoval = DateTime.UtcNow.AddSeconds(_removeExpiredValuesIntervalInSeconds); + } +#pragma warning restore CA1031 // Do not catch general exception types } /// @@ -275,29 +285,64 @@ internal int RemoveExpiredValuesLRU() /// The enumerator returned from the dictionary is safe to use concurrently with reads and writes to the dictionary, according to the MS document. /// /// Number of items removed. - internal int RemoveExpiredValues() + internal void RemoveExpiredValues() { - int numItemsRemoved = 0; +#pragma warning disable CA1031 // Do not catch general exception types try { - foreach (var node in _map) + foreach (KeyValuePair> node in _map) { if (node.Value.ExpirationTime < DateTime.UtcNow) { if (_map.TryRemove(node.Value.Key, out var cacheItem)) - OnItemRemoved?.Invoke(cacheItem.Value); + OnItemExpired?.Invoke(cacheItem.Value); + } + } + } + catch(Exception ex) + { + if (LogHelper.IsEnabled(EventLogLevel.Warning)) + LogHelper.LogWarning(LogHelper.FormatInvariant(LogMessages.IDX10902, LogHelper.MarkAsNonPII(nameof(ProcessCompactedValues)), ex)); + + } + finally + { + _removeExpiredValuesState = ActionNotQueued; + _timeForNextExpiredValuesRemoval = DateTime.UtcNow.AddSeconds(_removeExpiredValuesIntervalInSeconds); + } - numItemsRemoved++; +#pragma warning restore CA1031 // Do not catch general exception types + } + + /// + /// Remove all compacted items. + /// + internal void ProcessCompactedValues() + { +#pragma warning disable CA1031 // Do not catch general exception types + try + { + for (int i = _compactedItems.Count - 1; i >= 0; i--) + { + if ((OnShouldRemoveFromCompactedList == null) || OnShouldRemoveFromCompactedList(_compactedItems[i].Value)) + { + OnItemRemovedFromCompactedList?.Invoke(_compactedItems[i].Value); + _compactedItems.RemoveAt(i); } } } - catch (ObjectDisposedException ex) + catch(Exception ex) { if (LogHelper.IsEnabled(EventLogLevel.Warning)) - LogHelper.LogWarning(LogHelper.FormatInvariant(LogMessages.IDX10902, LogHelper.MarkAsNonPII(nameof(RemoveExpiredValues)), ex)); + LogHelper.LogWarning(LogHelper.FormatInvariant(LogMessages.IDX10906, LogHelper.MarkAsNonPII(nameof(ProcessCompactedValues)), ex)); + } + finally + { + _processCompactedValuesState = ActionNotQueued; + _timeForNextCompaction = DateTime.UtcNow.AddSeconds(_compactIntervalInSeconds); } - return numItemsRemoved; +#pragma warning restore CA1031 // Do not catch general exception types } /// @@ -306,18 +351,23 @@ internal int RemoveExpiredValues() /// private void CompactLRU() { - var newCacheSize = CalculateNewCacheSize(); - while (_map.Count > newCacheSize && _doubleLinkedList.Count > 0) + try { - var lru = _doubleLinkedList.Last; - if (_map.TryRemove(lru.Value.Key, out var cacheItem)) - OnItemRemoved?.Invoke(cacheItem.Value); + int newCacheSize = CalculateNewCacheSize(); + while (_map.Count > newCacheSize && _doubleLinkedList.Count > 0) + { + LinkedListNode> node = _doubleLinkedList.Last; + if (_map.TryRemove(node.Value.Key, out LRUCacheItem cacheItem)) + OnItemMovedToCompactedList?.Invoke(cacheItem.Value); - _doubleLinkedList.RemoveLast(); + _compactedItems.Add(cacheItem); + _doubleLinkedList.RemoveLast(); + } + } + finally + { + _compactValuesState = ActionNotQueued; } - - // reset _compactionState so the compaction action can be queued again when needed - _compactionState = CompactionNotQueued; } /// @@ -326,21 +376,28 @@ private void CompactLRU() /// private void Compact() { - var newCacheSize = CalculateNewCacheSize(); - while (_map.Count > newCacheSize) + try { - // Since all items could have been removed by the public TryRemove() method, leaving the map empty, we need to check if a default value is returned. - // Remove the item from the map only if the returned item is NOT default value. - var item = _map.FirstOrDefault(); - if (!item.Equals(default)) + int newCacheSize = CalculateNewCacheSize(); + while (_map.Count > newCacheSize) { - if (_map.TryRemove(item.Key, out var cacheItem)) - OnItemRemoved?.Invoke(cacheItem.Value); + // Since all items could have been removed by the public TryRemove() method, leaving the map empty, we need to check if a default value is returned. + // Remove the item from the map only if the returned item is NOT default value. + KeyValuePair> item = _map.FirstOrDefault(); + if (!item.Equals(default)) + { + if (_map.TryRemove(item.Key, out LRUCacheItem cacheItem)) + { + OnItemMovedToCompactedList?.Invoke(cacheItem.Value); + _compactedItems.Add(cacheItem); + } + } } } - - // reset _compactionState so the compaction action can be queued again when needed - _compactionState = CompactionNotQueued; + finally + { + _compactValuesState = ActionNotQueued; + } } /// @@ -408,12 +465,20 @@ public bool SetValue(TKey key, TValue value, DateTime expirationTime) // if cache is at _maxCapacityPercentage, trim it by _compactionPercentage if ((double)_map.Count / _capacity >= _maxCapacityPercentage) { - if (Interlocked.CompareExchange(ref _compactionState, CompactionQueuedOrRunning, CompactionNotQueued) == CompactionNotQueued) + if (Interlocked.CompareExchange(ref _compactValuesState, ActionQueuedOrRunning, ActionNotQueued) == ActionNotQueued) { if (_maintainLRU) AddActionToEventQueue(CompactLRU); else AddActionToEventQueue(Compact); + + if (DateTime.UtcNow >= _timeForNextCompaction) + { + if (Interlocked.CompareExchange(ref _processCompactedValuesState, ActionQueuedOrRunning, ActionNotQueued) == ActionNotQueued) + { + _eventQueue.Enqueue(ProcessCompactedValues); + } + } } } @@ -476,8 +541,7 @@ private void StartEventQueueTaskIfNotRunning() // the caller's TaskScheduler (if there is one) as some custom TaskSchedulers might be single-threaded and its execution can be blocked. if (Interlocked.CompareExchange(ref _eventQueueTaskState, EventQueueTaskRunning, EventQueueTaskStopped) == EventQueueTaskStopped) { - // EventQueueTaskAction manages its own state. - _ = Task.Run(EventQueueTaskAction); + _ = Task.Run(EventQueueTaskAction); } } @@ -514,6 +578,23 @@ public bool TryGetValue(TKey key, out TValue value) return cacheItem != null; } + // These Try methods are not thread safe and they rely on the SignatureProviders to have logic to dispose of important objects. + // A better design would be to have TryRemove move the SignatureProvider to the compacted list. + // This would need a new action in LRUCache, AddItemToCompactedList. + + /// Removes a particular key from the cache. + public bool TryRemove(TKey key) + { + if (key == null) + throw LogHelper.LogArgumentNullException(nameof(key)); + + if (!_map.TryRemove(key, out var cacheItem)) + return false; + + OnItemMovedToCompactedList?.Invoke(cacheItem.Value); + return true; + } + /// Removes a particular key from the cache. public bool TryRemove(TKey key, out TValue value) { @@ -534,7 +615,7 @@ public bool TryRemove(TKey key, out TValue value) } value = cacheItem.Value; - OnItemRemoved?.Invoke(cacheItem.Value); + OnItemMovedToCompactedList?.Invoke(cacheItem.Value); return true; } @@ -579,7 +660,9 @@ public bool TryRemove(TKey key, out TValue value) /// internal void WaitForProcessing() { - while (!_eventQueue.IsEmpty); + while (!_eventQueue.IsEmpty) + { + }; } #endregion @@ -613,4 +696,3 @@ public override bool Equals(object obj) public override int GetHashCode() => 990326508 + EqualityComparer.Default.GetHashCode(Key); } } - diff --git a/src/Microsoft.IdentityModel.Tokens/InMemoryCryptoProviderCache.cs b/src/Microsoft.IdentityModel.Tokens/InMemoryCryptoProviderCache.cs index 00ee2c023f..0ac44300f8 100644 --- a/src/Microsoft.IdentityModel.Tokens/InMemoryCryptoProviderCache.cs +++ b/src/Microsoft.IdentityModel.Tokens/InMemoryCryptoProviderCache.cs @@ -2,7 +2,6 @@ // Licensed under the MIT License. using System; -using System.Globalization; using System.Threading.Tasks; using Microsoft.IdentityModel.Abstractions; using Microsoft.IdentityModel.Logging; @@ -14,7 +13,6 @@ namespace Microsoft.IdentityModel.Tokens /// Current support is limited to only. /// public class InMemoryCryptoProviderCache: CryptoProviderCache, IDisposable - { internal CryptoProviderCacheOptions _cryptoProviderCacheOptions; private bool _disposed = false; @@ -28,39 +26,56 @@ public InMemoryCryptoProviderCache() : this(new CryptoProviderCacheOptions()) { } - internal CryptoProviderFactory CryptoProviderFactory { get; set; } - /// /// Creates a new instance of using the specified . /// /// The options used to configure the . - public InMemoryCryptoProviderCache(CryptoProviderCacheOptions cryptoProviderCacheOptions) + public InMemoryCryptoProviderCache(CryptoProviderCacheOptions cryptoProviderCacheOptions) : this(cryptoProviderCacheOptions, TaskCreationOptions.None) { - if (cryptoProviderCacheOptions == null) - throw LogHelper.LogArgumentNullException(nameof(cryptoProviderCacheOptions)); - - _cryptoProviderCacheOptions = cryptoProviderCacheOptions; - _signingSignatureProviders = new EventBasedLRUCache(cryptoProviderCacheOptions.SizeLimit, removeExpiredValues: false, comparer: StringComparer.Ordinal) { OnItemRemoved = (SignatureProvider signatureProvider) => signatureProvider.CryptoProviderCache = null }; - _verifyingSignatureProviders = new EventBasedLRUCache(cryptoProviderCacheOptions.SizeLimit, removeExpiredValues: false, comparer: StringComparer.Ordinal) { OnItemRemoved = (SignatureProvider signatureProvider) => signatureProvider.CryptoProviderCache = null }; } - /// - /// Creates a new instance of using the specified . - /// - /// The options used to configure the . - /// Options used to create the event queue thread. - /// The time used in ms for the timeout interval of the event queue. Defaults to 500 ms. internal InMemoryCryptoProviderCache(CryptoProviderCacheOptions cryptoProviderCacheOptions, TaskCreationOptions options, int tryTakeTimeout = 500) { - if (cryptoProviderCacheOptions == null) - throw LogHelper.LogArgumentNullException(nameof(cryptoProviderCacheOptions)); - + _cryptoProviderCacheOptions = cryptoProviderCacheOptions ?? throw LogHelper.LogArgumentNullException(nameof(cryptoProviderCacheOptions)); if (tryTakeTimeout <= 0) throw LogHelper.LogArgumentException(nameof(tryTakeTimeout), $"{nameof(tryTakeTimeout)} must be greater than zero"); - _cryptoProviderCacheOptions = cryptoProviderCacheOptions; - _signingSignatureProviders = new EventBasedLRUCache(cryptoProviderCacheOptions.SizeLimit, options, StringComparer.Ordinal, false) { OnItemRemoved = (SignatureProvider signatureProvider) => signatureProvider.CryptoProviderCache = null }; - _verifyingSignatureProviders = new EventBasedLRUCache(cryptoProviderCacheOptions.SizeLimit, options, StringComparer.Ordinal, false) { OnItemRemoved = (SignatureProvider signatureProvider) => signatureProvider.CryptoProviderCache = null }; + _signingSignatureProviders = new EventBasedLRUCache( + cryptoProviderCacheOptions.SizeLimit, + options, + comparer: StringComparer.Ordinal) + { + OnItemMovedToCompactedList = SetCryptoProviderCacheToNull, + OnItemRemovedFromCompactedList = DisposeSignatureProvider, + OnShouldRemoveFromCompactedList = IsCacheNullAndRefCountZero + }; + + _verifyingSignatureProviders = new EventBasedLRUCache( + cryptoProviderCacheOptions.SizeLimit, + options, + comparer: StringComparer.Ordinal) + { + OnItemMovedToCompactedList = SetCryptoProviderCacheToNull, + OnItemRemovedFromCompactedList = DisposeSignatureProvider, + OnShouldRemoveFromCompactedList = IsCacheNullAndRefCountZero + }; + } + + internal CryptoProviderFactory CryptoProviderFactory { get; set; } + + private static void DisposeSignatureProvider(SignatureProvider signatureProvider) + { + signatureProvider.Dispose(); + } + + private void SetCryptoProviderCacheToNull(SignatureProvider signatureProvider) + { + signatureProvider.CryptoProviderCache = null; + } + + private static bool IsCacheNullAndRefCountZero(SignatureProvider signatureProvider) + { + return signatureProvider.CryptoProviderCache == null && signatureProvider.RefCount == 0; } /// @@ -195,7 +210,7 @@ public override bool TryRemove(SignatureProvider signatureProvider) try { - return signatureProviderCache.TryRemove(cacheKey, out SignatureProvider provider); + return signatureProviderCache.TryRemove(cacheKey); } catch (Exception ex) { diff --git a/src/Microsoft.IdentityModel.Tokens/LogMessages.cs b/src/Microsoft.IdentityModel.Tokens/LogMessages.cs index c644c0936f..94f41493d8 100644 --- a/src/Microsoft.IdentityModel.Tokens/LogMessages.cs +++ b/src/Microsoft.IdentityModel.Tokens/LogMessages.cs @@ -149,7 +149,7 @@ internal static class LogMessages public const string IDX10640 = "IDX10640: Algorithm is not supported: '{0}'."; // public const string IDX10641 = "IDX10641:"; public const string IDX10642 = "IDX10642: Creating signature using the input: '{0}'."; - public const string IDX10643 = "IDX10643: Comparing the signature created over the input with the token signature: '{0}'."; + // public const string IDX10643 = "IDX10643:"; // public const string IDX10644 = "IDX10644:"; public const string IDX10645 = "IDX10645: Elliptical Curve not supported for curveId: '{0}'"; public const string IDX10646 = "IDX10646: A CustomCryptoProvider was set and returned 'true' for IsSupportedAlgorithm(Algorithm: '{0}', Key: '{1}'), but Create.(algorithm, args) as '{2}' == NULL."; @@ -253,7 +253,8 @@ internal static class LogMessages //EventBasedLRUCache errors public const string IDX10900 = "IDX10900: EventBasedLRUCache._eventQueue encountered an error while processing a cache operation. Exception '{0}'."; public const string IDX10901 = "IDX10901: CryptoProviderCacheOptions.SizeLimit must be greater than 10. Value: '{0}'"; - public const string IDX10902 = "IDX10902: Object disposed exception in '{0}': '{1}'"; + public const string IDX10902 = "IDX10902: Exception caught while removing expired items: '{0}', Exception: '{1}'"; + public const string IDX10906 = "IDX10906: Exception caught while compacting items: '{0}', Exception: '{1}'"; // Crypto Errors public const string IDX11000 = "IDX11000: Cannot create EcdhKeyExchangeProvider. '{0}'\'s Curve '{1}' does not match with '{2}'\'s curve '{3}'."; diff --git a/src/Microsoft.IdentityModel.Tokens/SignatureProvider.cs b/src/Microsoft.IdentityModel.Tokens/SignatureProvider.cs index 1e80d5290c..fbc0fc727e 100644 --- a/src/Microsoft.IdentityModel.Tokens/SignatureProvider.cs +++ b/src/Microsoft.IdentityModel.Tokens/SignatureProvider.cs @@ -68,6 +68,8 @@ public void Dispose() /// true, if called from Dispose(), false, if invoked inside a finalizer protected abstract void Dispose(bool disposing); + internal bool IsCached { get; set; } + /// /// Gets the . /// diff --git a/src/Microsoft.IdentityModel.Tokens/SymmetricSignatureProvider.cs b/src/Microsoft.IdentityModel.Tokens/SymmetricSignatureProvider.cs index e066b57f48..bb665cd9d0 100644 --- a/src/Microsoft.IdentityModel.Tokens/SymmetricSignatureProvider.cs +++ b/src/Microsoft.IdentityModel.Tokens/SymmetricSignatureProvider.cs @@ -193,13 +193,11 @@ public override byte[] Sign(byte[] input) catch { CryptoProviderCache?.TryRemove(this); - Dispose(true); throw; } finally { - if (!_disposed) - ReleaseKeyedHashAlgorithm(keyedHashAlgorithm); + ReleaseKeyedHashAlgorithm(keyedHashAlgorithm); } } @@ -225,13 +223,11 @@ public override bool Sign(ReadOnlySpan input, Span signature, out in catch { CryptoProviderCache?.TryRemove(this); - Dispose(true); throw; } finally { - if (!_disposed) - ReleaseKeyedHashAlgorithm(keyedHashAlgorithm); + ReleaseKeyedHashAlgorithm(keyedHashAlgorithm); } } #endif @@ -260,13 +256,11 @@ public override byte[] Sign(byte[] input, int offset, int count) catch { CryptoProviderCache?.TryRemove(this); - Dispose(true); throw; } finally { - if (!_disposed) - ReleaseKeyedHashAlgorithm(keyedHashAlgorithm); + ReleaseKeyedHashAlgorithm(keyedHashAlgorithm); } } @@ -301,9 +295,6 @@ public override bool Verify(byte[] input, byte[] signature) throw LogHelper.LogExceptionMessage(new ObjectDisposedException(GetType().ToString())); } - if (LogHelper.IsEnabled(EventLogLevel.Informational)) - LogHelper.LogInformation(LogMessages.IDX10643, input); - KeyedHashAlgorithm keyedHashAlgorithm = GetKeyedHashAlgorithm(GetKeyBytes(Key), Algorithm); try { @@ -312,13 +303,11 @@ public override bool Verify(byte[] input, byte[] signature) catch { CryptoProviderCache?.TryRemove(this); - Dispose(true); throw; } finally { - if (!_disposed) - ReleaseKeyedHashAlgorithm(keyedHashAlgorithm); + ReleaseKeyedHashAlgorithm(keyedHashAlgorithm); } } @@ -449,9 +438,6 @@ internal bool Verify(byte[] input, int inputOffset, int inputLength, byte[] sign throw LogHelper.LogExceptionMessage(new ObjectDisposedException(GetType().ToString())); } - if (LogHelper.IsEnabled(EventLogLevel.Informational)) - LogHelper.LogInformation(LogMessages.IDX10643, input); - KeyedHashAlgorithm keyedHashAlgorithm = null; try { @@ -465,18 +451,16 @@ internal bool Verify(byte[] input, int inputOffset, int inputLength, byte[] sign #else hash = keyedHashAlgorithm.ComputeHash(input, inputOffset, inputLength).AsSpan(); #endif - return Utility.AreEqual(signature, hash, signatureLength); } catch { - Dispose(true); + CryptoProviderCache?.TryRemove(this); throw; } finally { - if (!_disposed) - ReleaseKeyedHashAlgorithm(keyedHashAlgorithm); + ReleaseKeyedHashAlgorithm(keyedHashAlgorithm); } } diff --git a/test/Microsoft.IdentityModel.Tokens.Tests/CryptoProviderFactoryTests.cs b/test/Microsoft.IdentityModel.Tokens.Tests/CryptoProviderFactoryTests.cs index ec42035a02..42fbe8c5cf 100644 --- a/test/Microsoft.IdentityModel.Tokens.Tests/CryptoProviderFactoryTests.cs +++ b/test/Microsoft.IdentityModel.Tokens.Tests/CryptoProviderFactoryTests.cs @@ -47,13 +47,13 @@ public void CreateAndReleaseSignatureProviders(SignatureProviderTheoryData theor { var disposeCalled = GetSignatureProviderIsDisposedByReflect(signatureProvider); if (!disposeCalled) - context.Diffs.Add("Dispose wasn't called on the AsymmetricSignatureProvider."); + context.Diffs.Add("Dispose was supposed to be called on the AsymmetricSignatureProvider."); } else // signatureProvider.GetType().Equals(typeof(SymmetricSignatureProvider)) { var disposeCalled = GetSignatureProviderIsDisposedByReflect(signatureProvider); if (!disposeCalled) - context.Diffs.Add("Dispose wasn't called on the SymmetricSignatureProvider."); + context.Diffs.Add("Dispose was supposed to be called on the SymmetricSignatureProvider."); } } catch (Exception ex) @@ -917,8 +917,8 @@ public void ReferenceCountingTest_Caching() cryptoProviderFactory.ReleaseSignatureProvider(signing); - if (!GetSignatureProviderIsDisposedByReflect(signing)) - context.AddDiff($"{nameof(signing2)} should have been disposed"); + if (GetSignatureProviderIsDisposedByReflect(signing)) + context.AddDiff($"{nameof(signing)} should not have been disposed"); TestUtilities.AssertFailIfErrors(context); }